Diffstat (limited to 'synapse')
-rw-r--r--  synapse/__init__.py | 2
-rw-r--r--  synapse/_scripts/review_recent_signups.py | 175
-rw-r--r--  synapse/api/auth.py | 79
-rw-r--r--  synapse/api/constants.py | 6
-rw-r--r--  synapse/app/_base.py | 11
-rw-r--r--  synapse/config/_base.pyi | 2
-rw-r--r--  synapse/config/cache.py | 70
-rw-r--r--  synapse/config/consent.py | 2
-rw-r--r--  synapse/config/database.py | 3
-rw-r--r--  synapse/config/jwt.py | 2
-rw-r--r--  synapse/config/logger.py | 2
-rw-r--r--  synapse/config/modules.py | 2
-rw-r--r--  synapse/config/oidc.py | 4
-rw-r--r--  synapse/config/password_auth_providers.py | 2
-rw-r--r--  synapse/config/registration.py | 21
-rw-r--r--  synapse/config/repository.py | 2
-rw-r--r--  synapse/config/server.py | 23
-rw-r--r--  synapse/config/spam_checker.py | 2
-rw-r--r--  synapse/config/stats.py | 2
-rw-r--r--  synapse/config/tracer.py | 2
-rw-r--r--  synapse/config/user_directory.py | 2
-rw-r--r--  synapse/event_auth.py | 5
-rw-r--r--  synapse/events/__init__.py | 2
-rw-r--r--  synapse/events/builder.py | 77
-rw-r--r--  synapse/federation/federation_base.py | 12
-rw-r--r--  synapse/federation/federation_server.py | 232
-rw-r--r--  synapse/federation/transport/server.py | 592
-rw-r--r--  synapse/handlers/admin.py | 7
-rw-r--r--  synapse/handlers/auth.py | 132
-rw-r--r--  synapse/handlers/event_auth.py | 62
-rw-r--r--  synapse/handlers/federation.py | 255
-rw-r--r--  synapse/handlers/message.py | 16
-rw-r--r--  synapse/handlers/register.py | 115
-rw-r--r--  synapse/handlers/room.py | 3
-rw-r--r--  synapse/handlers/space_summary.py | 17
-rw-r--r--  synapse/http/server.py | 2
-rw-r--r--  synapse/http/servlet.py | 50
-rw-r--r--  synapse/module_api/__init__.py | 2
-rw-r--r--  synapse/push/bulk_push_rule_evaluator.py | 4
-rw-r--r--  synapse/replication/http/login.py | 13
-rw-r--r--  synapse/rest/client/v1/login.py | 171
-rw-r--r--  synapse/rest/client/v2_alpha/register.py | 98
-rw-r--r--  synapse/rest/client/v2_alpha/sync.py | 69
-rw-r--r--  synapse/storage/database.py | 2
-rw-r--r--  synapse/storage/databases/main/event_federation.py | 114
-rw-r--r--  synapse/storage/databases/main/events_bg_updates.py | 186
-rw-r--r--  synapse/storage/databases/main/lock.py | 15
-rw-r--r--  synapse/storage/databases/main/registration.py | 207
-rw-r--r--  synapse/storage/engines/_base.py | 6
-rw-r--r--  synapse/storage/engines/postgres.py | 5
-rw-r--r--  synapse/storage/engines/sqlite.py | 5
-rw-r--r--  synapse/storage/schema/__init__.py | 2
-rw-r--r--  synapse/storage/schema/main/delta/59/14refresh_tokens.sql | 34
-rw-r--r--  synapse/storage/schema/main/delta/60/01recreate_stream_ordering.sql.postgres | 45
-rw-r--r--  synapse/storage/schema/main/delta/60/02change_stream_ordering_columns.sql.postgres | 30
-rw-r--r--  synapse/util/caches/lrucache.py | 237
-rw-r--r--  synapse/util/linked_list.py | 150
57 files changed, 2704 insertions(+), 686 deletions(-)
diff --git a/synapse/__init__.py b/synapse/__init__.py
index 1bd03462ac..5ecce24eee 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -47,7 +47,7 @@ try:
 except ImportError:
     pass
 
-__version__ = "1.37.1"
+__version__ = "1.38.0"
 
 if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
     # We import here so that we don't have to install a bunch of deps when
diff --git a/synapse/_scripts/review_recent_signups.py b/synapse/_scripts/review_recent_signups.py
new file mode 100644
index 0000000000..01dc0c4237
--- /dev/null
+++ b/synapse/_scripts/review_recent_signups.py
@@ -0,0 +1,175 @@
+#!/usr/bin/env python
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import sys
+import time
+from datetime import datetime
+from typing import List
+
+import attr
+
+from synapse.config._base import RootConfig, find_config_files, read_config_files
+from synapse.config.database import DatabaseConfig
+from synapse.storage.database import DatabasePool, LoggingTransaction, make_conn
+from synapse.storage.engines import create_engine
+
+
+class ReviewConfig(RootConfig):
+    "A config class that just pulls out the database config"
+    config_classes = [DatabaseConfig]
+
+
+@attr.s(auto_attribs=True)
+class UserInfo:
+    user_id: str
+    creation_ts: int
+    emails: List[str] = attr.Factory(list)
+    private_rooms: List[str] = attr.Factory(list)
+    public_rooms: List[str] = attr.Factory(list)
+    ips: List[str] = attr.Factory(list)
+
+
+def get_recent_users(txn: LoggingTransaction, since_ms: int) -> List[UserInfo]:
+    """Fetches recently registered users and some info on them."""
+
+    sql = """
+        SELECT name, creation_ts FROM users
+        WHERE
+            ? <= creation_ts
+            AND deactivated = 0
+    """
+
+    txn.execute(sql, (since_ms / 1000,))
+
+    user_infos = [UserInfo(user_id, creation_ts) for user_id, creation_ts in txn]
+
+    for user_info in user_infos:
+        user_info.emails = DatabasePool.simple_select_onecol_txn(
+            txn,
+            table="user_threepids",
+            keyvalues={"user_id": user_info.user_id, "medium": "email"},
+            retcol="address",
+        )
+
+        sql = """
+            SELECT room_id, canonical_alias, name, join_rules
+            FROM local_current_membership
+            INNER JOIN room_stats_state USING (room_id)
+            WHERE user_id = ? AND membership = 'join'
+        """
+
+        txn.execute(sql, (user_info.user_id,))
+        for room_id, canonical_alias, name, join_rules in txn:
+            if join_rules == "public":
+                user_info.public_rooms.append(canonical_alias or name or room_id)
+            else:
+                user_info.private_rooms.append(canonical_alias or name or room_id)
+
+        user_info.ips = DatabasePool.simple_select_onecol_txn(
+            txn,
+            table="user_ips",
+            keyvalues={"user_id": user_info.user_id},
+            retcol="ip",
+        )
+
+    return user_infos
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "-c",
+        "--config-path",
+        action="append",
+        metavar="CONFIG_FILE",
+        help="The config files for Synapse.",
+        required=True,
+    )
+    parser.add_argument(
+        "-s",
+        "--since",
+        metavar="duration",
+        help="Specify how far back to review user registrations for, defaults to 7d (i.e. 7 days).",
+        default="7d",
+    )
+    parser.add_argument(
+        "-e",
+        "--exclude-emails",
+        action="store_true",
+        help="Exclude users that have validated email addresses",
+    )
+    parser.add_argument(
+        "-u",
+        "--only-users",
+        action="store_true",
+        help="Only print user IDs that match.",
+    )
+
+    config = ReviewConfig()
+
+    config_args = parser.parse_args(sys.argv[1:])
+    config_files = find_config_files(search_paths=config_args.config_path)
+    config_dict = read_config_files(config_files)
+    config.parse_config_dict(
+        config_dict,
+    )
+
+    since_ms = time.time() * 1000 - config.parse_duration(config_args.since)
+    exclude_users_with_email = config_args.exclude_emails
+    include_context = not config_args.only_users
+
+    for database_config in config.database.databases:
+        if "main" in database_config.databases:
+            break
+
+    engine = create_engine(database_config.config)
+
+    with make_conn(database_config, engine, "review_recent_signups") as db_conn:
+        user_infos = get_recent_users(db_conn.cursor(), since_ms)
+
+    for user_info in user_infos:
+        if exclude_users_with_email and user_info.emails:
+            continue
+
+        if include_context:
+            print_public_rooms = ""
+            if user_info.public_rooms:
+                print_public_rooms = "(" + ", ".join(user_info.public_rooms[:3])
+
+                if len(user_info.public_rooms) > 3:
+                    print_public_rooms += ", ..."
+
+                print_public_rooms += ")"
+
+            print("# Created:", datetime.fromtimestamp(user_info.creation_ts))
+            print("# Email:", ", ".join(user_info.emails) or "None")
+            print("# IPs:", ", ".join(user_info.ips))
+            print(
+                "# Number joined public rooms:",
+                len(user_info.public_rooms),
+                print_public_rooms,
+            )
+            print("# Number joined private rooms:", len(user_info.private_rooms))
+            print("#")
+
+        print(user_info.user_id)
+
+        if include_context:
+            print()
+
+
+if __name__ == "__main__":
+    main()
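
For reference, a rough invocation sketch (config path and duration are illustrative, and assume Synapse is importable in the current environment):

    python -m synapse._scripts.review_recent_signups -c /etc/synapse/homeserver.yaml -s 3d --exclude-emails

This prints one matching user ID per line; without -u/--only-users, each ID is preceded by #-prefixed context lines (creation time, e-mails, IPs, and joined-room counts), as emitted by main() above.
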
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 0f7788b411..d26014ef4f 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Optional, Tuple
 
 import pymacaroons
 from netaddr import IPAddress
@@ -28,7 +28,6 @@ from synapse.api.errors import (
     InvalidClientTokenError,
     MissingClientTokenError,
 )
-from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
 from synapse.appservice import ApplicationService
 from synapse.events import EventBase
 from synapse.http import get_request_user_agent
@@ -38,7 +37,6 @@ from synapse.storage.databases.main.registration import TokenLookupResult
 from synapse.types import Requester, StateMap, UserID, create_requester
 from synapse.util.caches.lrucache import LruCache
 from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
-from synapse.util.metrics import Measure
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -46,15 +44,6 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
-AuthEventTypes = (
-    EventTypes.Create,
-    EventTypes.Member,
-    EventTypes.PowerLevels,
-    EventTypes.JoinRules,
-    EventTypes.RoomHistoryVisibility,
-    EventTypes.ThirdPartyInvite,
-)
-
 # guests always get this device id.
 GUEST_DEVICE_ID = "guest_device"
 
@@ -65,9 +54,7 @@ class _InvalidMacaroonException(Exception):
 class Auth:
     """
-    FIXME: This class contains a mix of functions for authenticating users
-    of our client-server API and authenticating events added to room graphs.
-    The latter should be moved to synapse.handlers.event_auth.EventAuthHandler.
+    This class contains functions for authenticating users of our client-server API.
     """
 
     def __init__(self, hs: "HomeServer"):
@@ -89,18 +76,6 @@ class Auth:
         self._macaroon_secret_key = hs.config.macaroon_secret_key
         self._force_tracing_for_users = hs.config.tracing.force_tracing_for_users
 
-    async def check_from_context(
-        self, room_version: str, event, context, do_sig_check=True
-    ) -> None:
-        auth_event_ids = event.auth_event_ids()
-        auth_events_by_id = await self.store.get_events(auth_event_ids)
-        auth_events = {(e.type, e.state_key): e for e in auth_events_by_id.values()}
-
-        room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
-        event_auth.check(
-            room_version_obj, event, auth_events=auth_events, do_sig_check=do_sig_check
-        )
-
     async def check_user_in_room(
         self,
         room_id: str,
@@ -151,13 +126,6 @@ class Auth:
 
         raise AuthError(403, "User %s not in room %s" % (user_id, room_id))
 
-    async def check_host_in_room(self, room_id: str, host: str) -> bool:
-        with Measure(self.clock, "check_host_in_room"):
-            return await self.store.is_host_joined(room_id, host)
-
-    def get_public_keys(self, invite_event: EventBase) -> List[Dict[str, Any]]:
-        return event_auth.get_public_keys(invite_event)
-
     async def get_user_by_req(
         self,
         request: SynapseRequest,
@@ -245,6 +213,11 @@ class Auth:
                     errcode=Codes.GUEST_ACCESS_FORBIDDEN,
                 )
 
+            # Mark the token as used. This is used to invalidate old refresh
+            # tokens after some time.
+            if not user_info.token_used and token_id is not None:
+                await self.store.mark_access_token_as_used(token_id)
+
             requester = create_requester(
                 user_info.user_id,
                 token_id,
@@ -488,44 +461,6 @@ class Auth:
         """
         return await self.store.is_server_admin(user)
 
-    def compute_auth_events(
-        self,
-        event,
-        current_state_ids: StateMap[str],
-        for_verification: bool = False,
-    ) -> List[str]:
-        """Given an event and current state return the list of event IDs used
-        to auth an event.
-
-        If `for_verification` is False then only return auth events that
-        should be added to the event's `auth_events`.
-
-        Returns:
-            List of event IDs.
-        """
-
-        if event.type == EventTypes.Create:
-            return []
-
-        # Currently we ignore the `for_verification` flag even though there are
-        # some situations where we can drop particular auth events when adding
-        # to the event's `auth_events` (e.g. joins pointing to previous joins
-        # when room is publicly joinable). Dropping event IDs has the
-        # advantage that the auth chain for the room grows slower, but we use
-        # the auth chain in state resolution v2 to order events, which means
-        # care must be taken if dropping events to ensure that it doesn't
-        # introduce undesirable "state reset" behaviour.
-        #
-        # All of which sounds a bit tricky so we don't bother for now.
-
-        auth_ids = []
-        for etype, state_key in event_auth.auth_types_for_event(event):
-            auth_ev_id = current_state_ids.get((etype, state_key))
-            if auth_ev_id:
-                auth_ids.append(auth_ev_id)
-
-        return auth_ids
-
     async def check_can_change_room_list(self, room_id: str, user: UserID) -> bool:
         """Determine whether the user is allowed to edit the room's entry in the
         published room list.
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index 414e4c019a..8363c2bb0f 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -201,6 +201,12 @@ class EventContentFields:
     )
 
 
+class RoomTypes:
+    """Understood values of the room_type field of m.room.create events."""
+
+    SPACE = "m.space"
+
+
 class RoomEncryptionAlgorithms:
     MEGOLM_V1_AES_SHA2 = "m.megolm.v1.aes-sha2"
     DEFAULT = MEGOLM_V1_AES_SHA2
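
For context: RoomTypes.SPACE is the creation-content type that marks a room as a space (MSC1772). A minimal sketch of the relevant part of an m.room.create event, with illustrative field values:

    {
        "type": "m.room.create",
        "state_key": "",
        "sender": "@alice:example.org",
        "content": {
            "room_version": "6",
            "type": "m.space"
        }
    }
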
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 8879136881..b30571fe49 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -21,7 +21,7 @@ import socket
 import sys
 import traceback
 import warnings
-from typing import Awaitable, Callable, Iterable
+from typing import TYPE_CHECKING, Awaitable, Callable, Iterable
 
 from cryptography.utils import CryptographyDeprecationWarning
 from typing_extensions import NoReturn
@@ -41,10 +41,14 @@ from synapse.events.spamcheck import load_legacy_spam_checkers
 from synapse.logging.context import PreserveLoggingContext
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.metrics.jemalloc import setup_jemalloc_stats
+from synapse.util.caches.lrucache import setup_expire_lru_cache_entries
 from synapse.util.daemonize import daemonize_process
 from synapse.util.rlimit import change_resource_limit
 from synapse.util.versionstring import get_version_string
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 # list of tuples of function, args list, kwargs dict
@@ -312,7 +316,7 @@ def refresh_certificate(hs):
     logger.info("Context factories updated.")
 
-async def start(hs: "synapse.server.HomeServer"):
+async def start(hs: "HomeServer"):
     """
     Start a Synapse server or worker.
@@ -365,6 +369,9 @@ async def start(hs: "synapse.server.HomeServer"):
 
     load_legacy_spam_checkers(hs)
 
+    # If we've configured an expiry time for caches, start the background job now.
+    setup_expire_lru_cache_entries(hs)
+
     # It is now safe to start your Synapse.
     hs.start_listening()
     hs.get_datastore().db_pool.start_profiling()
diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi
index 23ca0c83c1..06fbd1166b 100644
--- a/synapse/config/_base.pyi
+++ b/synapse/config/_base.pyi
@@ -5,6 +5,7 @@ from synapse.config import (
     api,
     appservice,
     auth,
+    cache,
     captcha,
     cas,
     consent,
@@ -88,6 +89,7 @@ class RootConfig:
     tracer: tracer.TracerConfig
     redis: redis.RedisConfig
     modules: modules.ModulesConfig
+    caches: cache.CacheConfig
     federation: federation.FederationConfig
 
     config_classes: List = ...
diff --git a/synapse/config/cache.py b/synapse/config/cache.py
index 91165ee1ce..7789b40323 100644
--- a/synapse/config/cache.py
+++ b/synapse/config/cache.py
@@ -116,35 +116,41 @@
         #event_cache_size: 10K
 
         caches:
-          # Controls the global cache factor, which is the default cache factor
-          # for all caches if a specific factor for that cache is not otherwise
-          # set.
-          #
-          # This can also be set by the "SYNAPSE_CACHE_FACTOR" environment
-          # variable. Setting by environment variable takes priority over
-          # setting through the config file.
-          #
-          # Defaults to 0.5, which will half the size of all caches.
-          #
-          #global_factor: 1.0
-
-          # A dictionary of cache name to cache factor for that individual
-          # cache. Overrides the global cache factor for a given cache.
-          #
-          # These can also be set through environment variables comprised
-          # of "SYNAPSE_CACHE_FACTOR_" + the name of the cache in capital
-          # letters and underscores. Setting by environment variable
-          # takes priority over setting through the config file.
-          # Ex. SYNAPSE_CACHE_FACTOR_GET_USERS_WHO_SHARE_ROOM_WITH_USER=2.0
-          #
-          # Some caches have '*' and other characters that are not
-          # alphanumeric or underscores. These caches can be named with or
-          # without the special characters stripped. For example, to specify
-          # the cache factor for `*stateGroupCache*` via an environment
-          # variable would be `SYNAPSE_CACHE_FACTOR_STATEGROUPCACHE=2.0`.
-          #
-          per_cache_factors:
-            #get_users_who_share_room_with_user: 2.0
+           # Controls the global cache factor, which is the default cache factor
+           # for all caches if a specific factor for that cache is not otherwise
+           # set.
+           #
+           # This can also be set by the "SYNAPSE_CACHE_FACTOR" environment
+           # variable. Setting by environment variable takes priority over
+           # setting through the config file.
+           #
+           # Defaults to 0.5, which will half the size of all caches.
+           #
+           #global_factor: 1.0
+
+           # A dictionary of cache name to cache factor for that individual
+           # cache. Overrides the global cache factor for a given cache.
+           #
+           # These can also be set through environment variables comprised
+           # of "SYNAPSE_CACHE_FACTOR_" + the name of the cache in capital
+           # letters and underscores. Setting by environment variable
+           # takes priority over setting through the config file.
+           # Ex. SYNAPSE_CACHE_FACTOR_GET_USERS_WHO_SHARE_ROOM_WITH_USER=2.0
+           #
+           # Some caches have '*' and other characters that are not
+           # alphanumeric or underscores. These caches can be named with or
+           # without the special characters stripped. For example, to specify
+           # the cache factor for `*stateGroupCache*` via an environment
+           # variable would be `SYNAPSE_CACHE_FACTOR_STATEGROUPCACHE=2.0`.
+           #
+           per_cache_factors:
+             #get_users_who_share_room_with_user: 2.0
+
+           # Controls how long an entry can be in a cache without having been
+           # accessed before being evicted. Defaults to None, which means
+           # entries are never evicted based on time.
+           #
+           #expiry_time: 30m
         """
 
     def read_config(self, config, **kwargs):
@@ -200,6 +206,12 @@ class CacheConfig(Config):
                 e.message  # noqa: B306, DependencyException.message is a property
             )
 
+        expiry_time = cache_config.get("expiry_time")
+        if expiry_time:
+            self.expiry_time_msec = self.parse_duration(expiry_time)
+        else:
+            self.expiry_time_msec = None
+
         # Resize all caches (if necessary) with the new factors we've loaded
         self.resize_all_caches()
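
Putting the new option together with the existing knobs, a homeserver.yaml fragment might look like this (values illustrative):

    caches:
      global_factor: 1.0
      per_cache_factors:
        get_users_who_share_room_with_user: 2.0
      # Evict any entry untouched for 30 minutes; omit to keep the previous
      # behaviour of never evicting on time.
      expiry_time: 30m

The duration string is parsed with the same parse_duration helper used above, so suffixes such as s, m, h and d apply.
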
diff --git a/synapse/config/consent.py b/synapse/config/consent.py
index 30d07cc219..b05a9bd97f 100644
--- a/synapse/config/consent.py
+++ b/synapse/config/consent.py
@@ -22,7 +22,7 @@ DEFAULT_CONFIG = """\
 # User Consent configuration
 #
 # for detailed instructions, see
-# https://github.com/matrix-org/synapse/blob/master/docs/consent_tracking.md
+# https://matrix-org.github.io/synapse/latest/consent_tracking.html
 #
 # Parts of this section are required if enabling the 'consent' resource under
 # 'listeners', in particular 'template_dir' and 'version'.
diff --git a/synapse/config/database.py b/synapse/config/database.py
index c76ef1e1de..3d7d92f615 100644
--- a/synapse/config/database.py
+++ b/synapse/config/database.py
@@ -62,7 +62,8 @@ DEFAULT_CONFIG = """\
 #        cp_min: 5
 #        cp_max: 10
 #
-# For more information on using Synapse with Postgres, see `docs/postgres.md`.
+# For more information on using Synapse with Postgres,
+# see https://matrix-org.github.io/synapse/latest/postgres.html.
 #
 database:
   name: sqlite3
diff --git a/synapse/config/jwt.py b/synapse/config/jwt.py
index 9e07e73008..9d295f5856 100644
--- a/synapse/config/jwt.py
+++ b/synapse/config/jwt.py
@@ -64,7 +64,7 @@ class JWTConfig(Config):
         # Note that this is a non-standard login type and client support is
         # expected to be non-existent.
         #
-        # See https://github.com/matrix-org/synapse/blob/master/docs/jwt.md.
+        # See https://matrix-org.github.io/synapse/latest/jwt.html.
         #
         #jwt_config:
            # Uncomment the following to enable authorization using JSON web
diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index 91d9bcf32e..ad4e6e61c3 100644
--- a/synapse/config/logger.py
+++ b/synapse/config/logger.py
@@ -49,7 +49,7 @@ DEFAULT_LOG_CONFIG = Template(
 # be ingested by ELK stacks. See [2] for details.
 #
 # [1]: https://docs.python.org/3.7/library/logging.config.html#configuration-dictionary-schema
-# [2]: https://github.com/matrix-org/synapse/blob/master/docs/structured_logging.md
+# [2]: https://matrix-org.github.io/synapse/latest/structured_logging.html
 
 version: 1
diff --git a/synapse/config/modules.py b/synapse/config/modules.py
index 3209e1c492..ae0821e5a5 100644
--- a/synapse/config/modules.py
+++ b/synapse/config/modules.py
@@ -37,7 +37,7 @@ class ModulesConfig(Config):
             # Server admins can expand Synapse's functionality with external modules.
             #
-            # See https://matrix-org.github.io/synapse/develop/modules.html for more
+            # See https://matrix-org.github.io/synapse/latest/modules.html for more
             # documentation on how to configure or create custom modules for Synapse.
             #
             modules:
diff --git a/synapse/config/oidc.py b/synapse/config/oidc.py
index ea0abf5aa2..942e2672a9 100644
--- a/synapse/config/oidc.py
+++ b/synapse/config/oidc.py
@@ -166,7 +166,7 @@ class OIDCConfig(Config):
         #
         #       module: The class name of a custom mapping module. Default is
         #           {mapping_provider!r}.
-        #           See https://github.com/matrix-org/synapse/blob/master/docs/sso_mapping_providers.md#openid-mapping-providers
+        #           See https://matrix-org.github.io/synapse/latest/sso_mapping_providers.html#openid-mapping-providers
         #           for information on implementing a custom mapping provider.
         #
         #       config: Configuration for the mapping provider module. This section will
@@ -217,7 +217,7 @@ class OIDCConfig(Config):
         #         - attribute: groups
         #           value: "admin"
         #
-        # See https://github.com/matrix-org/synapse/blob/master/docs/openid.md
+        # See https://matrix-org.github.io/synapse/latest/openid.html
         # for information on how to configure these options.
         #
         # For backwards compatibility, it is also possible to configure a single OIDC
diff --git a/synapse/config/password_auth_providers.py b/synapse/config/password_auth_providers.py
index 1cf69734bb..fd90b79772 100644
--- a/synapse/config/password_auth_providers.py
+++ b/synapse/config/password_auth_providers.py
@@ -57,7 +57,7 @@ class PasswordAuthProviderConfig(Config):
         #    ex. LDAP, external tokens, etc.
         #
         # For more information and known implementations, please see
-        # https://github.com/matrix-org/synapse/blob/master/docs/password_auth_providers.md
+        # https://matrix-org.github.io/synapse/latest/password_auth_providers.html
         #
         # Note: instances wishing to use SAML or CAS authentication should
         # instead use the `saml2_config` or `cas_config` options,
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index eecc0478a7..6e9f405312 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -153,6 +153,27 @@ class RegistrationConfig(Config):
             session_lifetime = self.parse_duration(session_lifetime)
         self.session_lifetime = session_lifetime
 
+        # The `access_token_lifetime` applies for tokens that can be renewed
+        # using a refresh token, as per MSC2918. If it is `None`, the refresh
+        # token mechanism is disabled.
+        #
+        # Since it is incompatible with the `session_lifetime` mechanism, it is set to
+        # `None` by default if a `session_lifetime` is set.
+        access_token_lifetime = config.get(
+            "access_token_lifetime", "5m" if session_lifetime is None else None
+        )
+        if access_token_lifetime is not None:
+            access_token_lifetime = self.parse_duration(access_token_lifetime)
+        self.access_token_lifetime = access_token_lifetime
+
+        if session_lifetime is not None and access_token_lifetime is not None:
+            raise ConfigError(
+                "The refresh token mechanism is incompatible with the "
+                "`session_lifetime` option. Consider disabling the "
+                "`session_lifetime` option or disabling the refresh token "
+                "mechanism by removing the `access_token_lifetime` option."
+            )
+
         # The success template used during fallback auth.
         self.fallback_success_template = self.read_template("auth_success.html")
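
A corresponding homeserver.yaml sketch (the value shown is the code's own default; access_token_lifetime and session_lifetime are mutually exclusive, as the ConfigError above enforces):

    # Enable the MSC2918 refresh token mechanism: access tokens expire
    # after five minutes and must be renewed using a refresh token.
    access_token_lifetime: 5m
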
diff --git a/synapse/config/repository.py b/synapse/config/repository.py
index 1bd40a89f0..b5820616da 100644
--- a/synapse/config/repository.py
+++ b/synapse/config/repository.py
@@ -256,7 +256,7 @@ class ContentRepositoryConfig(Config):
         #
         # If you are using a reverse proxy you may also need to set this value in
         # your reverse proxy's config. Notably Nginx has a small max body size by default.
-        # See https://matrix-org.github.io/synapse/develop/reverse_proxy.html.
+        # See https://matrix-org.github.io/synapse/latest/reverse_proxy.html.
         #
         #max_upload_size: 50M
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 20022b963f..60afee5804 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -153,7 +153,7 @@ ROOM_COMPLEXITY_TOO_GREAT = (
 METRICS_PORT_WARNING = """\
 The metrics_port configuration option is deprecated in Synapse 0.31 in favour of
 a listener. Please see
-https://github.com/matrix-org/synapse/blob/master/docs/metrics-howto.md
+https://matrix-org.github.io/synapse/latest/metrics-howto.html
 on how to configure the new listener.
 --------------------------------------------------------------------------------"""
 
@@ -817,7 +817,7 @@ class ServerConfig(Config):
         # In most cases you should avoid using a matrix specific subdomain such as
         # matrix.example.com or synapse.example.com as the server_name for the same
         # reasons you wouldn't use user@email.example.com as your email address.
-        # See https://github.com/matrix-org/synapse/blob/master/docs/delegate.md
+        # See https://matrix-org.github.io/synapse/latest/delegate.html
         # for information on how to host Synapse on a subdomain while preserving
         # a clean server_name.
         #
@@ -994,9 +994,9 @@ class ServerConfig(Config):
         #   'all local interfaces'.
         #
         #   type: the type of listener. Normally 'http', but other valid options are:
-        #       'manhole' (see docs/manhole.md),
-        #       'metrics' (see docs/metrics-howto.md),
-        #       'replication' (see docs/workers.md).
+        #       'manhole' (see https://matrix-org.github.io/synapse/latest/manhole.html),
+        #       'metrics' (see https://matrix-org.github.io/synapse/latest/metrics-howto.html),
+        #       'replication' (see https://matrix-org.github.io/synapse/latest/workers.html).
         #
         #   tls: set to true to enable TLS for this listener. Will use the TLS
         #     key/cert specified in tls_private_key_path / tls_certificate_path.
@@ -1021,8 +1021,8 @@ class ServerConfig(Config):
         #   client: the client-server API (/_matrix/client), and the synapse admin
         #       API (/_synapse/admin). Also implies 'media' and 'static'.
         #
-        #   consent: user consent forms (/_matrix/consent). See
-        #       docs/consent_tracking.md.
+        #   consent: user consent forms (/_matrix/consent).
+        #       See https://matrix-org.github.io/synapse/latest/consent_tracking.html.
         #
         #   federation: the server-server API (/_matrix/federation). Also implies
         #       'media', 'keys', 'openid'
@@ -1031,12 +1031,13 @@ class ServerConfig(Config):
         #
         #   media: the media API (/_matrix/media).
         #
-        #   metrics: the metrics interface. See docs/metrics-howto.md.
+        #   metrics: the metrics interface.
+        #       See https://matrix-org.github.io/synapse/latest/metrics-howto.html.
         #
         #   openid: OpenID authentication.
         #
-        #   replication: the HTTP replication API (/_synapse/replication). See
-        #       docs/workers.md.
+        #   replication: the HTTP replication API (/_synapse/replication).
+        #       See https://matrix-org.github.io/synapse/latest/workers.html.
         #
         #   static: static resources under synapse/static (/_matrix/static). (Mostly
         #       useful for 'fallback authentication'.)
@@ -1056,7 +1057,7 @@ class ServerConfig(Config):
           # that unwraps TLS.
           #
           # If you plan to use a reverse proxy, please see
-          # https://github.com/matrix-org/synapse/blob/master/docs/reverse_proxy.md.
+          # https://matrix-org.github.io/synapse/latest/reverse_proxy.html.
           #
           %(unsecure_http_bindings)s
diff --git a/synapse/config/spam_checker.py b/synapse/config/spam_checker.py
index d0311d6468..cb7716c837 100644
--- a/synapse/config/spam_checker.py
+++ b/synapse/config/spam_checker.py
@@ -26,7 +26,7 @@ LEGACY_SPAM_CHECKER_WARNING = """
 This server is using a spam checker module that is implementing the deprecated spam
 checker interface. Please check with the module's maintainer to see if a new version
 supporting Synapse's generic modules system is available.
-For more information, please see https://matrix-org.github.io/synapse/develop/modules.html
+For more information, please see https://matrix-org.github.io/synapse/latest/modules.html
 ---------------------------------------------------------------------------------------"""
diff --git a/synapse/config/stats.py b/synapse/config/stats.py
index 3d44b51201..78f61fe9da 100644
--- a/synapse/config/stats.py
+++ b/synapse/config/stats.py
@@ -51,7 +51,7 @@ class StatsConfig(Config):
     def generate_config_section(self, config_dir_path, server_name, **kwargs):
         return """
         # Settings for local room and user statistics collection. See
-        # docs/room_and_user_statistics.md.
+        # https://matrix-org.github.io/synapse/latest/room_and_user_statistics.html.
         #
         stats:
           # Uncomment the following to disable room and user statistics. Note that doing
diff --git a/synapse/config/tracer.py b/synapse/config/tracer.py
index d0ea17261f..21b9a88353 100644
--- a/synapse/config/tracer.py
+++ b/synapse/config/tracer.py
@@ -81,7 +81,7 @@ class TracerConfig(Config):
         #enabled: true
 
         # The list of homeservers we wish to send and receive span contexts and span baggage.
-        # See docs/opentracing.rst.
+        # See https://matrix-org.github.io/synapse/latest/opentracing.html.
         #
         # This is a list of regexes which are matched against the server_name of the
         # homeserver.
diff --git a/synapse/config/user_directory.py b/synapse/config/user_directory.py
index 3ac1f2b5b1..f1beb87aea 100644
--- a/synapse/config/user_directory.py
+++ b/synapse/config/user_directory.py
@@ -53,7 +53,7 @@ class UserDirectoryConfig(Config):
         #
         # If you set it true, you'll have to rebuild the user_directory search
         # indexes, see:
-        # https://github.com/matrix-org/synapse/blob/master/docs/user_directory.md
+        # https://matrix-org.github.io/synapse/latest/user_directory.html
         #
         # Uncomment to return search results containing all known users, even if that
         # user does not share a room with the requester.
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index 33d7c60241..89bcf81515 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 import logging
-from typing import Any, Dict, List, Optional, Set, Tuple
+from typing import Any, Dict, List, Optional, Set, Tuple, Union
 
 from canonicaljson import encode_canonical_json
 from signedjson.key import decode_verify_key_bytes
@@ -29,6 +29,7 @@ from synapse.api.room_versions import (
     RoomVersion,
 )
 from synapse.events import EventBase
+from synapse.events.builder import EventBuilder
 from synapse.types import StateMap, UserID, get_domain_from_id
 
 logger = logging.getLogger(__name__)
@@ -724,7 +725,7 @@ def get_public_keys(invite_event: EventBase) -> List[Dict[str, Any]]:
     return public_keys
 
-def auth_types_for_event(event: EventBase) -> Set[Tuple[str, str]]:
+def auth_types_for_event(event: Union[EventBase, EventBuilder]) -> Set[Tuple[str, str]]:
     """Given an event, return a list of (EventType, StateKey) that may be
     needed to auth the event. The returned list may be a superset of what
     would actually be required depending on the full state of the room.
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index 0cb9c1cc1e..6286ad999a 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -118,7 +118,7 @@ class _EventInternalMetadata:
     proactively_send = DictProperty("proactively_send")  # type: bool
     redacted = DictProperty("redacted")  # type: bool
     txn_id = DictProperty("txn_id")  # type: str
-    token_id = DictProperty("token_id")  # type: str
+    token_id = DictProperty("token_id")  # type: int
     historical = DictProperty("historical")  # type: bool
 
     # XXX: These are set by StreamWorkerStore._set_before_and_after.
diff --git a/synapse/events/builder.py b/synapse/events/builder.py
index 81bf8615b7..26e3950859 100644
--- a/synapse/events/builder.py
+++ b/synapse/events/builder.py
@@ -12,12 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
 
 import attr
 from nacl.signing import SigningKey
 
-from synapse.api.auth import Auth
 from synapse.api.constants import MAX_DEPTH
 from synapse.api.errors import UnsupportedRoomVersionError
 from synapse.api.room_versions import (
@@ -34,10 +33,14 @@ from synapse.types import EventID, JsonDict
 from synapse.util import Clock
 from synapse.util.stringutils import random_string
 
+if TYPE_CHECKING:
+    from synapse.handlers.event_auth import EventAuthHandler
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
-@attr.s(slots=True, cmp=False, frozen=True)
+@attr.s(slots=True, cmp=False, frozen=True, auto_attribs=True)
 class EventBuilder:
     """A format independent event builder used to build up the event content
     before signing the event.
@@ -62,31 +65,30 @@ class EventBuilder:
         _signing_key: The signing key to use to sign the event as the server
     """
 
-    _state = attr.ib(type=StateHandler)
-    _auth = attr.ib(type=Auth)
-    _store = attr.ib(type=DataStore)
-    _clock = attr.ib(type=Clock)
-    _hostname = attr.ib(type=str)
-    _signing_key = attr.ib(type=SigningKey)
+    _state: StateHandler
+    _event_auth_handler: "EventAuthHandler"
+    _store: DataStore
+    _clock: Clock
+    _hostname: str
+    _signing_key: SigningKey
 
-    room_version = attr.ib(type=RoomVersion)
+    room_version: RoomVersion
 
-    room_id = attr.ib(type=str)
-    type = attr.ib(type=str)
-    sender = attr.ib(type=str)
+    room_id: str
+    type: str
+    sender: str
 
-    content = attr.ib(default=attr.Factory(dict), type=JsonDict)
-    unsigned = attr.ib(default=attr.Factory(dict), type=JsonDict)
+    content: JsonDict = attr.Factory(dict)
+    unsigned: JsonDict = attr.Factory(dict)
 
     # These only exist on a subset of events, so they raise AttributeError if
     # someone tries to get them when they don't exist.
-    _state_key = attr.ib(default=None, type=Optional[str])
-    _redacts = attr.ib(default=None, type=Optional[str])
-    _origin_server_ts = attr.ib(default=None, type=Optional[int])
+    _state_key: Optional[str] = None
+    _redacts: Optional[str] = None
+    _origin_server_ts: Optional[int] = None
 
-    internal_metadata = attr.ib(
-        default=attr.Factory(lambda: _EventInternalMetadata({})),
-        type=_EventInternalMetadata,
+    internal_metadata: _EventInternalMetadata = attr.Factory(
+        lambda: _EventInternalMetadata({})
     )
 
     @property
@@ -123,7 +125,9 @@ class EventBuilder:
         state_ids = await self._state.get_current_state_ids(
             self.room_id, prev_event_ids
         )
-        auth_event_ids = self._auth.compute_auth_events(self, state_ids)
+        auth_event_ids = self._event_auth_handler.compute_auth_events(
+            self, state_ids
+        )
 
         format_version = self.room_version.event_format
         if format_version == EventFormatVersions.V1:
@@ -184,24 +188,23 @@ class EventBuilder:
 
 
 class EventBuilderFactory:
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.clock = hs.get_clock()
         self.hostname = hs.hostname
         self.signing_key = hs.signing_key
 
         self.store = hs.get_datastore()
         self.state = hs.get_state_handler()
-        self.auth = hs.get_auth()
+        self._event_auth_handler = hs.get_event_auth_handler()
 
-    def new(self, room_version, key_values):
+    def new(self, room_version: str, key_values: dict) -> EventBuilder:
         """Generate an event builder appropriate for the given room version
 
         Deprecated: use for_room_version with a RoomVersion object instead
 
         Args:
-            room_version (str): Version of the room that we're creating an event builder
-                for
-            key_values (dict): Fields used as the basis of the new event
+            room_version: Version of the room that we're creating an event builder for
+            key_values: Fields used as the basis of the new event
 
         Returns:
             EventBuilder
@@ -212,13 +215,15 @@ class EventBuilderFactory:
             raise UnsupportedRoomVersionError()
         return self.for_room_version(v, key_values)
 
-    def for_room_version(self, room_version, key_values):
+    def for_room_version(
+        self, room_version: RoomVersion, key_values: dict
+    ) -> EventBuilder:
         """Generate an event builder appropriate for the given room version
 
         Args:
-            room_version (synapse.api.room_versions.RoomVersion):
+            room_version:
                 Version of the room that we're creating an event builder for
-            key_values (dict): Fields used as the basis of the new event
+            key_values: Fields used as the basis of the new event
 
         Returns:
             EventBuilder
@@ -226,7 +231,7 @@ class EventBuilderFactory:
         return EventBuilder(
             store=self.store,
             state=self.state,
-            auth=self.auth,
+            event_auth_handler=self._event_auth_handler,
             clock=self.clock,
             hostname=self.hostname,
             signing_key=self.signing_key,
@@ -286,15 +291,15 @@ def create_local_event_from_event_dict(
 _event_id_counter = 0
 
 
-def _create_event_id(clock, hostname):
+def _create_event_id(clock: Clock, hostname: str) -> str:
     """Create a new event ID
 
     Args:
-        clock (Clock)
-        hostname (str): The server name for the event ID
+        clock
+        hostname: The server name for the event ID
 
     Returns:
-        str
+        The new event ID
     """
     global _event_id_counter
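
The EventBuilder change above is a mechanical conversion from explicit attr.ib() declarations to attrs' auto_attribs style, in which plain type annotations define the fields. A minimal standalone sketch of the same pattern (class and fields invented for illustration):

    import attr

    # Old style: fields declared via attr.ib(), with types passed as arguments.
    @attr.s(slots=True, frozen=True)
    class OldStyle:
        name = attr.ib(type=str)
        tags = attr.ib(default=attr.Factory(list), type=list)

    # New style: auto_attribs=True turns annotated class attributes into
    # fields; defaults are written as ordinary assignments.
    @attr.s(slots=True, frozen=True, auto_attribs=True)
    class NewStyle:
        name: str
        tags: list = attr.Factory(list)

Both declarations generate equivalent __init__, __repr__ and comparison methods, so the conversion does not change behaviour.
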
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index c066617b92..2bfe6a3d37 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -89,12 +89,12 @@ class FederationBase:
         result = await self.spam_checker.check_event_for_spam(pdu)
         if result:
-            logger.warning(
-                "Event contains spam, redacting %s: %s",
-                pdu.event_id,
-                pdu.get_pdu_json(),
-            )
-            return prune_event(pdu)
+            logger.warning("Event contains spam, soft-failing %s", pdu.event_id)
+            # we redact (to save disk space) as well as soft-failing (to stop
+            # using the event in prev_events).
+            redacted_event = prune_event(pdu)
+            redacted_event.internal_metadata.soft_failed = True
+            return redacted_event
 
         return pdu
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 1d050e54e2..ac0f2ccfb3 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -34,7 +34,7 @@
 from twisted.internet import defer
 from twisted.internet.abstract import isIPAddress
 from twisted.python import failure
 
-from synapse.api.constants import EduTypes, EventTypes
+from synapse.api.constants import EduTypes, EventTypes, Membership
 from synapse.api.errors import (
     AuthError,
     Codes,
@@ -46,6 +46,7 @@ from synapse.api.errors import (
 )
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
 from synapse.events import EventBase
+from synapse.events.snapshot import EventContext
 from synapse.federation.federation_base import FederationBase, event_from_pdu_json
 from synapse.federation.persistence import TransactionActions
 from synapse.federation.units import Edu, Transaction
@@ -107,9 +108,9 @@ class FederationServer(FederationBase):
     def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
 
-        self.auth = hs.get_auth()
         self.handler = hs.get_federation_handler()
         self.state = hs.get_state_handler()
+        self._event_auth_handler = hs.get_event_auth_handler()
 
         self.device_handler = hs.get_device_handler()
 
@@ -147,6 +148,41 @@ class FederationServer(FederationBase):
 
         self._room_prejoin_state_types = hs.config.api.room_prejoin_state
 
+        # Whether we have started handling old events in the staging area.
+        self._started_handling_of_staged_events = False
+
+    @wrap_as_background_process("_handle_old_staged_events")
+    async def _handle_old_staged_events(self) -> None:
+        """Handle old staged events by fetching all rooms that have staged
+        events and start the processing of each of those rooms.
+        """
+
+        # Get all the rooms IDs with staged events.
+        room_ids = await self.store.get_all_rooms_with_staged_incoming_events()
+
+        # We then shuffle them so that if there are multiple instances doing
+        # this work they're less likely to collide.
+        random.shuffle(room_ids)
+
+        for room_id in room_ids:
+            room_version = await self.store.get_room_version(room_id)
+
+            # Try and acquire the processing lock for the room, if we get it start a
+            # background process for handling the events in the room.
+            lock = await self.store.try_acquire_lock(
+                _INBOUND_EVENT_HANDLING_LOCK_NAME, room_id
+            )
+            if lock:
+                logger.info("Handling old staged inbound events in %s", room_id)
+                self._process_incoming_pdus_in_room_inner(
+                    room_id,
+                    room_version,
+                    lock,
+                )
+
+                # We pause a bit so that we don't start handling all rooms at once.
+                await self._clock.sleep(random.uniform(0, 0.1))
+
     async def on_backfill_request(
         self, origin: str, room_id: str, versions: List[str], limit: int
     ) -> Tuple[int, Dict[str, Any]]:
@@ -165,6 +201,12 @@ class FederationServer(FederationBase):
     async def on_incoming_transaction(
         self, origin: str, transaction_data: JsonDict
     ) -> Tuple[int, Dict[str, Any]]:
+        # If we receive a transaction we should make sure that kick off handling
+        # any old events in the staging area.
+        if not self._started_handling_of_staged_events:
+            self._started_handling_of_staged_events = True
+            self._handle_old_staged_events()
+
         # keep this as early as possible to make the calculated origin ts as
         # accurate as possible.
         request_time = self._clock.time_msec()
@@ -368,22 +410,21 @@ class FederationServer(FederationBase):
         async def process_pdu(pdu: EventBase) -> JsonDict:
             event_id = pdu.event_id
-            with pdu_process_time.time():
-                with nested_logging_context(event_id):
-                    try:
-                        await self._handle_received_pdu(origin, pdu)
-                        return {}
-                    except FederationError as e:
-                        logger.warning("Error handling PDU %s: %s", event_id, e)
-                        return {"error": str(e)}
-                    except Exception as e:
-                        f = failure.Failure()
-                        logger.error(
-                            "Failed to handle PDU %s",
-                            event_id,
-                            exc_info=(f.type, f.value, f.getTracebackObject()),  # type: ignore
-                        )
-                        return {"error": str(e)}
+            with nested_logging_context(event_id):
+                try:
+                    await self._handle_received_pdu(origin, pdu)
+                    return {}
+                except FederationError as e:
+                    logger.warning("Error handling PDU %s: %s", event_id, e)
+                    return {"error": str(e)}
+                except Exception as e:
+                    f = failure.Failure()
+                    logger.error(
+                        "Failed to handle PDU %s",
+                        event_id,
+                        exc_info=(f.type, f.value, f.getTracebackObject()),  # type: ignore
+                    )
+                    return {"error": str(e)}
 
         await concurrently_execute(
             process_pdus_for_room, pdus_by_room.keys(), TRANSACTION_CONCURRENCY_LIMIT
@@ -420,7 +461,7 @@ class FederationServer(FederationBase):
         origin_host, _ = parse_server_name(origin)
         await self.check_server_matches_acl(origin_host, room_id)
 
-        in_room = await self.auth.check_host_in_room(room_id, origin)
+        in_room = await self._event_auth_handler.check_host_in_room(room_id, origin)
         if not in_room:
             raise AuthError(403, "Host not in room.")
 
@@ -453,7 +494,7 @@ class FederationServer(FederationBase):
         origin_host, _ = parse_server_name(origin)
         await self.check_server_matches_acl(origin_host, room_id)
 
-        in_room = await self.auth.check_host_in_room(room_id, origin)
+        in_room = await self._event_auth_handler.check_host_in_room(room_id, origin)
         if not in_room:
             raise AuthError(403, "Host not in room.")
 
@@ -544,26 +585,21 @@ class FederationServer(FederationBase):
         return {"event": ret_pdu.get_pdu_json(time_now)}
 
     async def on_send_join_request(
-        self, origin: str, content: JsonDict
+        self, origin: str, content: JsonDict, room_id: str
     ) -> Dict[str, Any]:
-        logger.debug("on_send_join_request: content: %s", content)
-
-        assert_params_in_dict(content, ["room_id"])
-        room_version = await self.store.get_room_version(content["room_id"])
-        pdu = event_from_pdu_json(content, room_version)
-
-        origin_host, _ = parse_server_name(origin)
-        await self.check_server_matches_acl(origin_host, pdu.room_id)
-
-        logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures)
-
-        pdu = await self._check_sigs_and_hash(room_version, pdu)
-
-        res_pdus = await self.handler.on_send_join_request(origin, pdu)
+        context = await self._on_send_membership_event(
+            origin, content, Membership.JOIN, room_id
+        )
+
+        prev_state_ids = await context.get_prev_state_ids()
+        state_ids = list(prev_state_ids.values())
+        auth_chain = await self.store.get_auth_chain(room_id, state_ids)
+        state = await self.store.get_events(state_ids)
+
         time_now = self._clock.time_msec()
         return {
-            "state": [p.get_pdu_json(time_now) for p in res_pdus["state"]],
-            "auth_chain": [p.get_pdu_json(time_now) for p in res_pdus["auth_chain"]],
+            "state": [p.get_pdu_json(time_now) for p in state.values()],
+            "auth_chain": [p.get_pdu_json(time_now) for p in auth_chain],
         }
 
     async def on_make_leave_request(
@@ -578,21 +614,11 @@ class FederationServer(FederationBase):
         time_now = self._clock.time_msec()
         return {"event": pdu.get_pdu_json(time_now), "room_version": room_version}
 
-    async def on_send_leave_request(self, origin: str, content: JsonDict) -> dict:
+    async def on_send_leave_request(
+        self, origin: str, content: JsonDict, room_id: str
+    ) -> dict:
         logger.debug("on_send_leave_request: content: %s", content)
-
-        assert_params_in_dict(content, ["room_id"])
-        room_version = await self.store.get_room_version(content["room_id"])
-        pdu = event_from_pdu_json(content, room_version)
-
-        origin_host, _ = parse_server_name(origin)
-        await self.check_server_matches_acl(origin_host, pdu.room_id)
-
-        logger.debug("on_send_leave_request: pdu sigs: %s", pdu.signatures)
-
-        pdu = await self._check_sigs_and_hash(room_version, pdu)
-
-        await self.handler.on_send_leave_request(origin, pdu)
+        await self._on_send_membership_event(origin, content, Membership.LEAVE, room_id)
         return {}
 
     async def on_make_knock_request(
@@ -658,39 +684,76 @@ class FederationServer(FederationBase):
         Returns:
             The stripped room state.
         """
-        logger.debug("on_send_knock_request: content: %s", content)
+        event_context = await self._on_send_membership_event(
+            origin, content, Membership.KNOCK, room_id
+        )
+
+        # Retrieve stripped state events from the room and send them back to the remote
+        # server. This will allow the remote server's clients to display information
+        # related to the room while the knock request is pending.
+        stripped_room_state = (
+            await self.store.get_stripped_room_state_from_event_context(
+                event_context, self._room_prejoin_state_types
+            )
+        )
+        return {"knock_state_events": stripped_room_state}
+
+    async def _on_send_membership_event(
+        self, origin: str, content: JsonDict, membership_type: str, room_id: str
+    ) -> EventContext:
+        """Handle an on_send_{join,leave,knock} request
+
+        Does some preliminary validation before passing the request on to the
+        federation handler.
+
+        Args:
+            origin: The (authenticated) requesting server
+            content: The body of the send_* request - a complete membership event
+            membership_type: The expected membership type (join or leave, depending
+                on the endpoint)
+            room_id: The room_id from the request, to be validated against the room_id
+                in the event
+
+        Returns:
+            The context of the event after inserting it into the room graph.
+
+        Raises:
+            SynapseError if there is a problem with the request, including things like
+               the room_id not matching or the event not being authorized.
+        """
+        assert_params_in_dict(content, ["room_id"])
+        if content["room_id"] != room_id:
+            raise SynapseError(
+                400,
+                "Room ID in body does not match that in request path",
+                Codes.BAD_JSON,
+            )
 
         room_version = await self.store.get_room_version(room_id)
 
-        # Check that this room supports knocking as defined by its room version
-        if not room_version.msc2403_knocking:
+        if membership_type == Membership.KNOCK and not room_version.msc2403_knocking:
             raise SynapseError(
                 403,
                 "This room version does not support knocking",
                 errcode=Codes.FORBIDDEN,
             )
 
-        pdu = event_from_pdu_json(content, room_version)
+        event = event_from_pdu_json(content, room_version)
 
-        origin_host, _ = parse_server_name(origin)
-        await self.check_server_matches_acl(origin_host, pdu.room_id)
+        if event.type != EventTypes.Member or not event.is_state():
+            raise SynapseError(400, "Not an m.room.member event", Codes.BAD_JSON)
 
-        logger.debug("on_send_knock_request: pdu sigs: %s", pdu.signatures)
+        if event.content.get("membership") != membership_type:
+            raise SynapseError(400, "Not a %s event" % membership_type, Codes.BAD_JSON)
 
-        pdu = await self._check_sigs_and_hash(room_version, pdu)
+        origin_host, _ = parse_server_name(origin)
+        await self.check_server_matches_acl(origin_host, event.room_id)
 
-        # Handle the event, and retrieve the EventContext
-        event_context = await self.handler.on_send_knock_request(origin, pdu)
+        logger.debug("_on_send_membership_event: pdu sigs: %s", event.signatures)
 
-        # Retrieve stripped state events from the room and send them back to the remote
-        # server. This will allow the remote server's clients to display information
-        # related to the room while the knock request is pending.
-        stripped_room_state = (
-            await self.store.get_stripped_room_state_from_event_context(
-                event_context, self._room_prejoin_state_types
-            )
-        )
-        return {"knock_state_events": stripped_room_state}
+        event = await self._check_sigs_and_hash(room_version, event)
+
+        return await self.handler.on_send_membership_event(origin, event)
 
     async def on_event_auth(
         self, origin: str, room_id: str, event_id: str
@@ -860,32 +923,39 @@ class FederationServer(FederationBase):
         room_id: str,
         room_version: RoomVersion,
         lock: Lock,
-        latest_origin: str,
-        latest_event: EventBase,
+        latest_origin: Optional[str] = None,
+        latest_event: Optional[EventBase] = None,
     ) -> None:
         """Process events in the staging area for the given room.
 
         The latest_origin and latest_event args are the latest origin and event
-        received.
+        received (or None to simply pull the next event from the database).
         """
 
         # The common path is for the event we just received be the only event in
         # the room, so instead of pulling the event out of the DB and parsing
         # the event we just pull out the next event ID and check if that matches.
-        next_origin, next_event_id = await self.store.get_next_staged_event_id_for_room(
-            room_id
-        )
-        if next_origin == latest_origin and next_event_id == latest_event.event_id:
-            origin = latest_origin
-            event = latest_event
-        else:
+        if latest_event is not None and latest_origin is not None:
+            (
+                next_origin,
+                next_event_id,
+            ) = await self.store.get_next_staged_event_id_for_room(room_id)
+            if next_origin != latest_origin or next_event_id != latest_event.event_id:
+                latest_origin = None
+                latest_event = None
+
+        if latest_origin is None or latest_event is None:
             next = await self.store.get_next_staged_event_for_room(
                 room_id, room_version
            )
             if not next:
+                await lock.release()
                 return
 
             origin, event = next
+        else:
+            origin = latest_origin
+            event = latest_event
 
         # We loop round until there are no more events in the room in the
         # staging area, or we fail to get the lock (which means another process
@@ -909,9 +979,13 @@ class FederationServer(FederationBase):
                     exc_info=(f.type, f.value, f.getTracebackObject()),  # type: ignore
                 )
 
-            await self.store.remove_received_event_from_staging(
+            received_ts = await self.store.remove_received_event_from_staging(
                 origin, event.event_id
             )
+            if received_ts is not None:
+                pdu_process_time.observe(
+                    (self._clock.time_msec() - received_ts) / 1000
+                )
 
         # We need to do this check outside the lock to avoid a race between
        # a new event being inserted by another instance and it attempting
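
The metric change above moves pdu_process_time from timing only the in-memory handler (the removed with pdu_process_time.time(): block) to observing the full staging latency. A minimal sketch of the underlying prometheus_client pattern (metric name and sample value are illustrative, not Synapse's registered metric):

    from prometheus_client import Histogram

    pdu_process_time = Histogram(
        "pdu_process_time",  # illustrative name
        "Time taken to process a PDU, in seconds",
    )

    # observe() records a single sample; here milliseconds are converted to
    # seconds, mirroring (time_msec() - received_ts) / 1000 in the diff above.
    pdu_process_time.observe(0.25)
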
index a9942b41fb..5685a71a4b 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py
@@ -15,7 +15,19 @@ import functools import logging import re -from typing import Container, Mapping, Optional, Sequence, Tuple, Type +from typing import ( + Container, + Dict, + List, + Mapping, + Optional, + Sequence, + Tuple, + Type, + Union, +) + +from typing_extensions import Literal import synapse from synapse.api.constants import MAX_GROUP_CATEGORYID_LENGTH, MAX_GROUP_ROLEID_LENGTH @@ -57,15 +69,15 @@ logger = logging.getLogger(__name__) class TransportLayerServer(JsonResource): """Handles incoming federation HTTP requests""" - def __init__(self, hs, servlet_groups=None): + def __init__(self, hs: HomeServer, servlet_groups: Optional[List[str]] = None): """Initialize the TransportLayerServer Will by default register all servlets. For custom behaviour, pass in a list of servlet_groups to register. Args: - hs (synapse.server.HomeServer): homeserver - servlet_groups (list[str], optional): List of servlet groups to register. + hs: homeserver + servlet_groups: List of servlet groups to register. Defaults to ``DEFAULT_SERVLET_GROUPS``. """ self.hs = hs @@ -79,7 +91,7 @@ class TransportLayerServer(JsonResource): self.register_servlets() - def register_servlets(self): + def register_servlets(self) -> None: register_servlets( self.hs, resource=self, @@ -92,14 +104,10 @@ class TransportLayerServer(JsonResource): class AuthenticationError(SynapseError): """There was a problem authenticating the request""" - pass - class NoAuthenticationError(AuthenticationError): """The request had no authentication information""" - pass - class Authenticator: def __init__(self, hs: HomeServer): @@ -411,13 +419,18 @@ class FederationSendServlet(BaseFederationServerServlet): RATELIMIT = False # This is when someone is trying to send us a bunch of data. - async def on_PUT(self, origin, content, query, transaction_id): + async def on_PUT( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + transaction_id: str, + ) -> Tuple[int, JsonDict]: """Called on PUT /send/<transaction_id>/ Args: - request (twisted.web.http.Request): The HTTP request. - transaction_id (str): The transaction_id associated with this - request. This is *not* None. + transaction_id: The transaction_id associated with this request. This + is *not* None. Returns: Tuple of `(code, response)`, where @@ -462,7 +475,13 @@ class FederationEventServlet(BaseFederationServerServlet): PATH = "/event/(?P<event_id>[^/]*)/?" # This is when someone asks for a data item for a given server data_id pair. - async def on_GET(self, origin, content, query, event_id): + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + event_id: str, + ) -> Tuple[int, Union[JsonDict, str]]: return await self.handler.on_pdu_request(origin, event_id) @@ -470,7 +489,13 @@ class FederationStateV1Servlet(BaseFederationServerServlet): PATH = "/state/(?P<room_id>[^/]*)/?" # This is when someone asks for all data for a given room. - async def on_GET(self, origin, content, query, room_id): + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + room_id: str, + ) -> Tuple[int, JsonDict]: return await self.handler.on_room_state_request( origin, room_id, @@ -481,7 +506,13 @@ class FederationStateV1Servlet(BaseFederationServerServlet): class FederationStateIdsServlet(BaseFederationServerServlet): PATH = "/state_ids/(?P<room_id>[^/]*)/?" 
- async def on_GET(self, origin, content, query, room_id): + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + room_id: str, + ) -> Tuple[int, JsonDict]: return await self.handler.on_state_ids_request( origin, room_id, @@ -492,7 +523,13 @@ class FederationStateIdsServlet(BaseFederationServerServlet): class FederationBackfillServlet(BaseFederationServerServlet): PATH = "/backfill/(?P<room_id>[^/]*)/?" - async def on_GET(self, origin, content, query, room_id): + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + room_id: str, + ) -> Tuple[int, JsonDict]: versions = [x.decode("ascii") for x in query[b"v"]] limit = parse_integer_from_args(query, "limit", None) @@ -506,7 +543,13 @@ class FederationQueryServlet(BaseFederationServerServlet): PATH = "/query/(?P<query_type>[^/]*)" # This is when we receive a server-server Query - async def on_GET(self, origin, content, query, query_type): + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + query_type: str, + ) -> Tuple[int, JsonDict]: args = {k.decode("utf8"): v[0].decode("utf-8") for k, v in query.items()} args["origin"] = origin return await self.handler.on_query_request(query_type, args) @@ -515,47 +558,66 @@ class FederationQueryServlet(BaseFederationServerServlet): class FederationMakeJoinServlet(BaseFederationServerServlet): PATH = "/make_join/(?P<room_id>[^/]*)/(?P<user_id>[^/]*)" - async def on_GET(self, origin, _content, query, room_id, user_id): + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + room_id: str, + user_id: str, + ) -> Tuple[int, JsonDict]: """ Args: - origin (unicode): The authenticated server_name of the calling server + origin: The authenticated server_name of the calling server - _content (None): (GETs don't have bodies) + content: (GETs don't have bodies) - query (dict[bytes, list[bytes]]): Query params from the request. + query: Query params from the request. - **kwargs (dict[unicode, unicode]): the dict mapping keys to path - components as specified in the path match regexp. + **kwargs: the dict mapping keys to path components as specified in + the path match regexp. 
Returns: - Tuple[int, object]: (response code, response object) + Tuple of (response code, response object) """ - versions = query.get(b"ver") - if versions is not None: - supported_versions = [v.decode("utf-8") for v in versions] - else: + supported_versions = parse_strings_from_args(query, "ver", encoding="utf-8") + if supported_versions is None: supported_versions = ["1"] - content = await self.handler.on_make_join_request( + result = await self.handler.on_make_join_request( origin, room_id, user_id, supported_versions=supported_versions ) - return 200, content + return 200, result class FederationMakeLeaveServlet(BaseFederationServerServlet): PATH = "/make_leave/(?P<room_id>[^/]*)/(?P<user_id>[^/]*)" - async def on_GET(self, origin, content, query, room_id, user_id): - content = await self.handler.on_make_leave_request(origin, room_id, user_id) - return 200, content + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + room_id: str, + user_id: str, + ) -> Tuple[int, JsonDict]: + result = await self.handler.on_make_leave_request(origin, room_id, user_id) + return 200, result class FederationV1SendLeaveServlet(BaseFederationServerServlet): PATH = "/send_leave/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)" - async def on_PUT(self, origin, content, query, room_id, event_id): - content = await self.handler.on_send_leave_request(origin, content) - return 200, (200, content) + async def on_PUT( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + room_id: str, + event_id: str, + ) -> Tuple[int, Tuple[int, JsonDict]]: + result = await self.handler.on_send_leave_request(origin, content, room_id) + return 200, (200, result) class FederationV2SendLeaveServlet(BaseFederationServerServlet): @@ -563,50 +625,84 @@ class FederationV2SendLeaveServlet(BaseFederationServerServlet): PREFIX = FEDERATION_V2_PREFIX - async def on_PUT(self, origin, content, query, room_id, event_id): - content = await self.handler.on_send_leave_request(origin, content) - return 200, content + async def on_PUT( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + room_id: str, + event_id: str, + ) -> Tuple[int, JsonDict]: + result = await self.handler.on_send_leave_request(origin, content, room_id) + return 200, result class FederationMakeKnockServlet(BaseFederationServerServlet): PATH = "/make_knock/(?P<room_id>[^/]*)/(?P<user_id>[^/]*)" - async def on_GET(self, origin, content, query, room_id, user_id): - try: - # Retrieve the room versions the remote homeserver claims to support - supported_versions = parse_strings_from_args(query, "ver", encoding="utf-8") - except KeyError: - raise SynapseError(400, "Missing required query parameter 'ver'") + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + room_id: str, + user_id: str, + ) -> Tuple[int, JsonDict]: + # Retrieve the room versions the remote homeserver claims to support + supported_versions = parse_strings_from_args( + query, "ver", required=True, encoding="utf-8" + ) - content = await self.handler.on_make_knock_request( + result = await self.handler.on_make_knock_request( origin, room_id, user_id, supported_versions=supported_versions ) - return 200, content + return 200, result class FederationV1SendKnockServlet(BaseFederationServerServlet): PATH = "/send_knock/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)" - async def on_PUT(self, origin, content, query, room_id, event_id): - content = await 
self.handler.on_send_knock_request(origin, content, room_id) - return 200, content + async def on_PUT( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + room_id: str, + event_id: str, + ) -> Tuple[int, JsonDict]: + result = await self.handler.on_send_knock_request(origin, content, room_id) + return 200, result class FederationEventAuthServlet(BaseFederationServerServlet): PATH = "/event_auth/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)" - async def on_GET(self, origin, content, query, room_id, event_id): + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + room_id: str, + event_id: str, + ) -> Tuple[int, JsonDict]: return await self.handler.on_event_auth(origin, room_id, event_id) class FederationV1SendJoinServlet(BaseFederationServerServlet): PATH = "/send_join/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)" - async def on_PUT(self, origin, content, query, room_id, event_id): - # TODO(paul): assert that room_id/event_id parsed from path actually + async def on_PUT( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + room_id: str, + event_id: str, + ) -> Tuple[int, Tuple[int, JsonDict]]: + # TODO(paul): assert that the event_id parsed from the path actually # matches the one given in content - content = await self.handler.on_send_join_request(origin, content) - return 200, (200, content) + result = await self.handler.on_send_join_request(origin, content, room_id) + return 200, (200, result) class FederationV2SendJoinServlet(BaseFederationServerServlet): @@ -614,28 +710,42 @@ class FederationV2SendJoinServlet(BaseFederationServerServlet): PREFIX = FEDERATION_V2_PREFIX - async def on_PUT(self, origin, content, query, room_id, event_id): - # TODO(paul): assert that room_id/event_id parsed from path actually + async def on_PUT( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + room_id: str, + event_id: str, + ) -> Tuple[int, JsonDict]: + # TODO(paul): assert that the event_id parsed from the path actually # matches the one given in content - content = await self.handler.on_send_join_request(origin, content) - return 200, content + result = await self.handler.on_send_join_request(origin, content, room_id) + return 200, result class FederationV1InviteServlet(BaseFederationServerServlet): PATH = "/invite/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)" - async def on_PUT(self, origin, content, query, room_id, event_id): + async def on_PUT( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + room_id: str, + event_id: str, + ) -> Tuple[int, Tuple[int, JsonDict]]: # We don't get a room version, so we have to assume it's EITHER v1 or # v2. This is "fine" as the only difference between V1 and V2 is the # state resolution algorithm, and we don't use that for processing # invites - content = await self.handler.on_invite_request( + result = await self.handler.on_invite_request( origin, content, room_version_id=RoomVersions.V1.identifier ) # V1 federation API is defined to return a content of `[200, {...}]` # due to a historical bug.
- return 200, (200, content) + return 200, (200, result) class FederationV2InviteServlet(BaseFederationServerServlet): @@ -643,7 +753,14 @@ class FederationV2InviteServlet(BaseFederationServerServlet): PREFIX = FEDERATION_V2_PREFIX - async def on_PUT(self, origin, content, query, room_id, event_id): + async def on_PUT( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + room_id: str, + event_id: str, + ) -> Tuple[int, JsonDict]: # TODO(paul): assert that room_id/event_id parsed from path actually # match those given in content @@ -656,16 +773,22 @@ class FederationV2InviteServlet(BaseFederationServerServlet): event.setdefault("unsigned", {})["invite_room_state"] = invite_room_state - content = await self.handler.on_invite_request( + result = await self.handler.on_invite_request( origin, event, room_version_id=room_version ) - return 200, content + return 200, result class FederationThirdPartyInviteExchangeServlet(BaseFederationServerServlet): PATH = "/exchange_third_party_invite/(?P<room_id>[^/]*)" - async def on_PUT(self, origin, content, query, room_id): + async def on_PUT( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + room_id: str, + ) -> Tuple[int, JsonDict]: await self.handler.on_exchange_third_party_invite_request(content) return 200, {} @@ -673,21 +796,31 @@ class FederationThirdPartyInviteExchangeServlet(BaseFederationServerServlet): class FederationClientKeysQueryServlet(BaseFederationServerServlet): PATH = "/user/keys/query" - async def on_POST(self, origin, content, query): + async def on_POST( + self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]] + ) -> Tuple[int, JsonDict]: return await self.handler.on_query_client_keys(origin, content) class FederationUserDevicesQueryServlet(BaseFederationServerServlet): PATH = "/user/devices/(?P<user_id>[^/]*)" - async def on_GET(self, origin, content, query, user_id): + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + user_id: str, + ) -> Tuple[int, JsonDict]: return await self.handler.on_query_user_devices(origin, user_id) class FederationClientKeysClaimServlet(BaseFederationServerServlet): PATH = "/user/keys/claim" - async def on_POST(self, origin, content, query): + async def on_POST( + self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]] + ) -> Tuple[int, JsonDict]: response = await self.handler.on_claim_client_keys(origin, content) return 200, response @@ -696,12 +829,18 @@ class FederationGetMissingEventsServlet(BaseFederationServerServlet): # TODO(paul): Why does this path alone end with "/?" optional? PATH = "/get_missing_events/(?P<room_id>[^/]*)/?" 
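The get_missing_events handler below reads three fields out of the request body, with limit defaulting to 10. A sketch of the expected body (the event IDs are made up for illustration):

request_body = {
    "limit": 10,                                    # optional; falls back to 10
    "earliest_events": ["$an_event_we_already_have"],
    "latest_events": ["$an_event_whose_ancestry_we_are_missing"],
}
limit = int(request_body.get("limit", 10))
earliest_events = request_body.get("earliest_events", [])
latest_events = request_body.get("latest_events", [])
assert (limit, earliest_events[0]) == (10, "$an_event_we_already_have")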
- async def on_POST(self, origin, content, query, room_id): + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + room_id: str, + ) -> Tuple[int, JsonDict]: limit = int(content.get("limit", 10)) earliest_events = content.get("earliest_events", []) latest_events = content.get("latest_events", []) - content = await self.handler.on_get_missing_events( + result = await self.handler.on_get_missing_events( origin, room_id=room_id, earliest_events=earliest_events, @@ -709,7 +848,7 @@ class FederationGetMissingEventsServlet(BaseFederationServerServlet): limit=limit, ) - return 200, content + return 200, result class On3pidBindServlet(BaseFederationServerServlet): @@ -717,7 +856,9 @@ class On3pidBindServlet(BaseFederationServerServlet): REQUIRE_AUTH = False - async def on_POST(self, origin, content, query): + async def on_POST( + self, origin: Optional[str], content: JsonDict, query: Dict[bytes, List[bytes]] + ) -> Tuple[int, JsonDict]: if "invites" in content: last_exception = None for invite in content["invites"]: @@ -763,15 +904,20 @@ class OpenIdUserInfo(BaseFederationServerServlet): REQUIRE_AUTH = False - async def on_GET(self, origin, content, query): - token = query.get(b"access_token", [None])[0] + async def on_GET( + self, + origin: Optional[str], + content: Literal[None], + query: Dict[bytes, List[bytes]], + ) -> Tuple[int, JsonDict]: + token = parse_string_from_args(query, "access_token") if token is None: return ( 401, {"errcode": "M_MISSING_TOKEN", "error": "Access Token required"}, ) - user_id = await self.handler.on_openid_userinfo(token.decode("ascii")) + user_id = await self.handler.on_openid_userinfo(token) if user_id is None: return ( @@ -830,7 +976,9 @@ class PublicRoomList(BaseFederationServlet): self.handler = hs.get_room_list_handler() self.allow_access = allow_access - async def on_GET(self, origin, content, query): + async def on_GET( + self, origin: str, content: Literal[None], query: Dict[bytes, List[bytes]] + ) -> Tuple[int, JsonDict]: if not self.allow_access: raise FederationDeniedError(origin) @@ -859,7 +1007,9 @@ class PublicRoomList(BaseFederationServlet): ) return 200, data - async def on_POST(self, origin, content, query): + async def on_POST( + self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]] + ) -> Tuple[int, JsonDict]: # This implements MSC2197 (Search Filtering over Federation) if not self.allow_access: raise FederationDeniedError(origin) @@ -956,7 +1106,12 @@ class FederationVersionServlet(BaseFederationServlet): REQUIRE_AUTH = False - async def on_GET(self, origin, content, query): + async def on_GET( + self, + origin: Optional[str], + content: Literal[None], + query: Dict[bytes, List[bytes]], + ) -> Tuple[int, JsonDict]: return ( 200, {"server": {"name": "Synapse", "version": get_version_string(synapse)}}, @@ -985,7 +1140,13 @@ class FederationGroupsProfileServlet(BaseGroupsServerServlet): PATH = "/groups/(?P<group_id>[^/]*)/profile" - async def on_GET(self, origin, content, query, group_id): + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -994,7 +1155,13 @@ class FederationGroupsProfileServlet(BaseGroupsServerServlet): return 200, new_content - async def on_POST(self, origin, content, query, 
group_id): + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1009,7 +1176,13 @@ class FederationGroupsProfileServlet(BaseGroupsServerServlet): class FederationGroupsSummaryServlet(BaseGroupsServerServlet): PATH = "/groups/(?P<group_id>[^/]*)/summary" - async def on_GET(self, origin, content, query, group_id): + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1024,7 +1197,13 @@ class FederationGroupsRoomsServlet(BaseGroupsServerServlet): PATH = "/groups/(?P<group_id>[^/]*)/rooms" - async def on_GET(self, origin, content, query, group_id): + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1039,7 +1218,14 @@ class FederationGroupsAddRoomsServlet(BaseGroupsServerServlet): PATH = "/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)" - async def on_POST(self, origin, content, query, group_id, room_id): + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + room_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1050,7 +1236,14 @@ class FederationGroupsAddRoomsServlet(BaseGroupsServerServlet): return 200, new_content - async def on_DELETE(self, origin, content, query, group_id, room_id): + async def on_DELETE( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + room_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1070,7 +1263,15 @@ class FederationGroupsAddRoomsConfigServlet(BaseGroupsServerServlet): "/config/(?P<config_key>[^/]*)" ) - async def on_POST(self, origin, content, query, group_id, room_id, config_key): + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + room_id: str, + config_key: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1087,7 +1288,13 @@ class FederationGroupsUsersServlet(BaseGroupsServerServlet): PATH = "/groups/(?P<group_id>[^/]*)/users" - async def on_GET(self, origin, content, query, group_id): + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, 
"requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1102,7 +1309,13 @@ class FederationGroupsInvitedUsersServlet(BaseGroupsServerServlet): PATH = "/groups/(?P<group_id>[^/]*)/invited_users" - async def on_GET(self, origin, content, query, group_id): + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1119,7 +1332,14 @@ class FederationGroupsInviteServlet(BaseGroupsServerServlet): PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite" - async def on_POST(self, origin, content, query, group_id, user_id): + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + user_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1136,7 +1356,14 @@ class FederationGroupsAcceptInviteServlet(BaseGroupsServerServlet): PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/accept_invite" - async def on_POST(self, origin, content, query, group_id, user_id): + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + user_id: str, + ) -> Tuple[int, JsonDict]: if get_domain_from_id(user_id) != origin: raise SynapseError(403, "user_id doesn't match origin") @@ -1150,7 +1377,14 @@ class FederationGroupsJoinServlet(BaseGroupsServerServlet): PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/join" - async def on_POST(self, origin, content, query, group_id, user_id): + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + user_id: str, + ) -> Tuple[int, JsonDict]: if get_domain_from_id(user_id) != origin: raise SynapseError(403, "user_id doesn't match origin") @@ -1164,7 +1398,14 @@ class FederationGroupsRemoveUserServlet(BaseGroupsServerServlet): PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove" - async def on_POST(self, origin, content, query, group_id, user_id): + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + user_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1198,7 +1439,14 @@ class FederationGroupsLocalInviteServlet(BaseGroupsLocalServlet): PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite" - async def on_POST(self, origin, content, query, group_id, user_id): + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + user_id: str, + ) -> Tuple[int, JsonDict]: if get_domain_from_id(group_id) != origin: raise SynapseError(403, "group_id doesn't match origin") @@ -1216,7 +1464,14 @@ class FederationGroupsRemoveLocalUserServlet(BaseGroupsLocalServlet): PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove" - async def on_POST(self, origin, content, query, group_id, user_id): + async 
def on_POST(self, origin, content, query, group_id, user_id): + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + user_id: str, + ) -> Tuple[int, None]: if get_domain_from_id(group_id) != origin: raise SynapseError(403, "group_id doesn't match origin") @@ -1224,11 +1479,9 @@ class FederationGroupsRemoveLocalUserServlet(BaseGroupsLocalServlet): assert isinstance( self.handler, GroupsLocalHandler ), "Workers cannot handle group removals." - new_content = await self.handler.user_removed_from_group( - group_id, user_id, content - ) + await self.handler.user_removed_from_group(group_id, user_id, content) - return 200, new_content + return 200, None class FederationGroupsRenewAttestaionServlet(BaseFederationServlet): @@ -1246,7 +1499,14 @@ class FederationGroupsRenewAttestaionServlet(BaseFederationServlet): super().__init__(hs, authenticator, ratelimiter, server_name) self.handler = hs.get_groups_attestation_renewer() - async def on_POST(self, origin, content, query, group_id, user_id): + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + user_id: str, + ) -> Tuple[int, JsonDict]: # We don't need to check auth here as we check the attestation signatures new_content = await self.handler.on_renew_attestation( @@ -1270,7 +1530,15 @@ class FederationGroupsSummaryRoomsServlet(BaseGroupsServerServlet): "/rooms/(?P<room_id>[^/]*)" ) - async def on_POST(self, origin, content, query, group_id, category_id, room_id): + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + category_id: str, + room_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1298,7 +1566,15 @@ class FederationGroupsSummaryRoomsServlet(BaseGroupsServerServlet): return 200, resp - async def on_DELETE(self, origin, content, query, group_id, category_id, room_id): + async def on_DELETE( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + category_id: str, + room_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1318,7 +1594,13 @@ class FederationGroupsCategoriesServlet(BaseGroupsServerServlet): PATH = "/groups/(?P<group_id>[^/]*)/categories/?"
- async def on_GET(self, origin, content, query, group_id): + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1333,7 +1615,14 @@ class FederationGroupsCategoryServlet(BaseGroupsServerServlet): PATH = "/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)" - async def on_GET(self, origin, content, query, group_id, category_id): + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + category_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1344,7 +1633,14 @@ class FederationGroupsCategoryServlet(BaseGroupsServerServlet): return 200, resp - async def on_POST(self, origin, content, query, group_id, category_id): + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + category_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1366,7 +1662,14 @@ class FederationGroupsCategoryServlet(BaseGroupsServerServlet): return 200, resp - async def on_DELETE(self, origin, content, query, group_id, category_id): + async def on_DELETE( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + category_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1386,7 +1689,13 @@ class FederationGroupsRolesServlet(BaseGroupsServerServlet): PATH = "/groups/(?P<group_id>[^/]*)/roles/?" 
- async def on_GET(self, origin, content, query, group_id): + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1401,7 +1710,14 @@ class FederationGroupsRoleServlet(BaseGroupsServerServlet): PATH = "/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)" - async def on_GET(self, origin, content, query, group_id, role_id): + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + role_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1410,7 +1726,14 @@ class FederationGroupsRoleServlet(BaseGroupsServerServlet): return 200, resp - async def on_POST(self, origin, content, query, group_id, role_id): + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + role_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1434,7 +1757,14 @@ class FederationGroupsRoleServlet(BaseGroupsServerServlet): return 200, resp - async def on_DELETE(self, origin, content, query, group_id, role_id): + async def on_DELETE( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + role_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1463,7 +1793,15 @@ class FederationGroupsSummaryUsersServlet(BaseGroupsServerServlet): "/users/(?P<user_id>[^/]*)" ) - async def on_POST(self, origin, content, query, group_id, role_id, user_id): + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + role_id: str, + user_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1489,7 +1827,15 @@ class FederationGroupsSummaryUsersServlet(BaseGroupsServerServlet): return 200, resp - async def on_DELETE(self, origin, content, query, group_id, role_id, user_id): + async def on_DELETE( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + role_id: str, + user_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1509,7 +1855,9 @@ class FederationGroupsBulkPublicisedServlet(BaseGroupsLocalServlet): PATH = "/get_groups_publicised" - async def on_POST(self, origin, content, query): + async def on_POST( + self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]] + ) -> Tuple[int, JsonDict]: resp = await self.handler.bulk_get_publicised_groups( content["user_ids"], proxy=False ) @@ -1522,7 +1870,13 
@@ class FederationGroupsSettingJoinPolicyServlet(BaseGroupsServerServlet): PATH = "/groups/(?P<group_id>[^/]*)/settings/m.join_policy" - async def on_PUT(self, origin, content, query, group_id): + async def on_PUT( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1551,7 +1905,7 @@ class FederationSpaceSummaryServlet(BaseFederationServlet): async def on_GET( self, origin: str, - content: JsonDict, + content: Literal[None], query: Mapping[bytes, Sequence[bytes]], room_id: str, ) -> Tuple[int, JsonDict]: @@ -1623,7 +1977,13 @@ class RoomComplexityServlet(BaseFederationServlet): super().__init__(hs, authenticator, ratelimiter, server_name) self._store = self.hs.get_datastore() - async def on_GET(self, origin, content, query, room_id): + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + room_id: str, + ) -> Tuple[int, JsonDict]: is_public = await self._store.is_room_world_readable_or_publicly_joinable( room_id ) diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index f72ded038e..d75a8b15c3 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py
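The hunk below adds the user's SSO identity mappings to the admin whois response, one entry per auth provider. A sketch of the resulting shape (all values illustrative):

ret = {
    "displayname": "Alice",
    "avatar_url": "mxc://example.com/abc123",
    "threepids": [{"medium": "email", "address": "alice@example.com"}],
    # one entry per SSO identity mapped to this user
    "external_ids": [{"auth_provider": "oidc-github", "external_id": "12345"}],
}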
@@ -62,9 +62,16 @@ class AdminHandler(BaseHandler): if ret: profile = await self.store.get_profileinfo(user.localpart) threepids = await self.store.user_get_threepids(user.to_string()) + external_ids = [ + {"auth_provider": auth_provider, "external_id": external_id} + for auth_provider, external_id in await self.store.get_external_ids_by_user( + user.to_string() + ) + ] ret["displayname"] = profile.display_name ret["avatar_url"] = profile.avatar_url ret["threepids"] = threepids + ret["external_ids"] = external_ids return ret async def export_user_data(self, user_id: str, writer: "ExfiltrationWriter") -> Any: diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 1971e373ed..e2ac595a62 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py
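The additions below give Synapse single-use refresh tokens: refresh_token consumes the presented token, mints a new access/refresh pair, and rejects any token whose successor has already been used. A deliberately simplified in-memory model of those semantics; it invalidates the old token as soon as it is exchanged, whereas the real code below is more forgiving and only invalidates it once the new pair is first used, so a client that never saw the response can retry. Names here are illustrative:

import secrets

class ToyRefreshStore:
    """Toy stand-in for the refresh-token storage."""

    def __init__(self) -> None:
        self._tokens = {}  # token -> {"user_id": str, "superseded": bool}

    def issue(self, user_id: str) -> str:
        token = "refresh-" + secrets.token_hex(16)
        self._tokens[token] = {"user_id": user_id, "superseded": False}
        return token

    def exchange(self, token: str) -> str:
        entry = self._tokens.get(token)
        if entry is None or entry["superseded"]:
            raise PermissionError("refresh token isn't valid anymore")
        entry["superseded"] = True           # single use: replay must fail
        return self.issue(entry["user_id"])  # rotate to a fresh token

store = ToyRefreshStore()
first = store.issue("@alice:example.com")
second = store.exchange(first)  # rotation succeeds
assert second != first
try:
    store.exchange(first)       # replaying the consumed token...
except PermissionError:
    pass                        # ...is rejected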
@@ -30,6 +30,7 @@ from typing import ( Optional, Tuple, Union, + cast, ) import attr @@ -72,6 +73,7 @@ from synapse.util.stringutils import base62_encode from synapse.util.threepids import canonicalise_email if TYPE_CHECKING: + from synapse.rest.client.v1.login import LoginResponse from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -777,6 +779,108 @@ class AuthHandler(BaseHandler): "params": params, } + async def refresh_token( + self, + refresh_token: str, + valid_until_ms: Optional[int], + ) -> Tuple[str, str]: + """ + Consumes a refresh token and generates both a new access token and a new refresh token from it. + + The consumed refresh token is considered invalid after the first use of the new access token or the new refresh token. + + Args: + refresh_token: The token to consume. + valid_until_ms: The expiration timestamp of the new access token. + + Returns: + A tuple containing the new access token and refresh token. + """ + + # Verify the shape of the token (including its CRC) before looking it up in the database + if not self._verify_refresh_token(refresh_token): + raise SynapseError(401, "invalid refresh token", Codes.UNKNOWN_TOKEN) + + existing_token = await self.store.lookup_refresh_token(refresh_token) + if existing_token is None: + raise SynapseError(401, "refresh token does not exist", Codes.UNKNOWN_TOKEN) + + if ( + existing_token.has_next_access_token_been_used + or existing_token.has_next_refresh_token_been_refreshed + ): + raise SynapseError( + 403, "refresh token isn't valid anymore", Codes.FORBIDDEN + ) + + ( + new_refresh_token, + new_refresh_token_id, + ) = await self.get_refresh_token_for_user_id( + user_id=existing_token.user_id, device_id=existing_token.device_id + ) + access_token = await self.get_access_token_for_user_id( + user_id=existing_token.user_id, + device_id=existing_token.device_id, + valid_until_ms=valid_until_ms, + refresh_token_id=new_refresh_token_id, + ) + await self.store.replace_refresh_token( + existing_token.token_id, new_refresh_token_id + ) + return access_token, new_refresh_token + + def _verify_refresh_token(self, token: str) -> bool: + """ + Verifies the shape of a refresh token. + + Args: + token: The refresh token to verify + + Returns: + Whether the token has the right shape + """ + parts = token.split("_", maxsplit=4) + if len(parts) != 4: + return False + + type, localpart, rand, crc = parts + + # Refresh tokens are prefixed by "syr_", let's check that + if type != "syr": + return False + + # Check the CRC + base = f"{type}_{localpart}_{rand}" + expected_crc = base62_encode(crc32(base.encode("ascii")), minwidth=6) + if crc != expected_crc: + return False + + return True + + async def get_refresh_token_for_user_id( + self, + user_id: str, + device_id: str, + ) -> Tuple[str, int]: + """ + Creates a new refresh token for the user with the given user ID. + + Args: + user_id: canonical user ID + device_id: the device ID to associate with the token.
+ + Returns: + The newly created refresh token and its ID in the database + """ + refresh_token = self.generate_refresh_token(UserID.from_string(user_id)) + refresh_token_id = await self.store.add_refresh_token_to_user( + user_id=user_id, + token=refresh_token, + device_id=device_id, + ) + return refresh_token, refresh_token_id + async def get_access_token_for_user_id( self, user_id: str, @@ -784,6 +888,7 @@ class AuthHandler(BaseHandler): valid_until_ms: Optional[int], puppets_user_id: Optional[str] = None, is_appservice_ghost: bool = False, + refresh_token_id: Optional[int] = None, ) -> str: """ Creates a new access token for the user with the given user ID. @@ -801,6 +906,8 @@ class AuthHandler(BaseHandler): valid_until_ms: when the token is valid until. None for no expiry. is_appservice_ghost: Whether the user is an application ghost user + refresh_token_id: the refresh token ID that will be associated with + this access token. Returns: The access token for the user's session. Raises: @@ -836,6 +943,7 @@ class AuthHandler(BaseHandler): device_id=device_id, valid_until_ms=valid_until_ms, puppets_user_id=puppets_user_id, + refresh_token_id=refresh_token_id, ) # the device *should* have been registered before we got here; however, @@ -928,7 +1036,7 @@ class AuthHandler(BaseHandler): self, login_submission: Dict[str, Any], ratelimit: bool = False, - ) -> Tuple[str, Optional[Callable[[Dict[str, str]], Awaitable[None]]]]: + ) -> Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]: """Authenticates the user for the /login API Also used by the user-interactive auth flow to validate auth types which don't @@ -1073,7 +1181,7 @@ class AuthHandler(BaseHandler): self, username: str, login_submission: Dict[str, Any], - ) -> Tuple[str, Optional[Callable[[Dict[str, str]], Awaitable[None]]]]: + ) -> Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]: """Helper for validate_login Handles login, once we've mapped 3pids onto userids @@ -1151,7 +1259,7 @@ class AuthHandler(BaseHandler): async def check_password_provider_3pid( self, medium: str, address: str, password: str - ) -> Tuple[Optional[str], Optional[Callable[[Dict[str, str]], Awaitable[None]]]]: + ) -> Tuple[Optional[str], Optional[Callable[["LoginResponse"], Awaitable[None]]]]: """Check if a password provider is able to validate a thirdparty login Args: @@ -1215,6 +1323,19 @@ class AuthHandler(BaseHandler): crc = base62_encode(crc32(base.encode("ascii")), minwidth=6) return f"{base}_{crc}" + def generate_refresh_token(self, for_user: UserID) -> str: + """Generates an opaque string, for use as a refresh token""" + + # we use the following format for refresh tokens: + # syr_<base64 local part>_<random string>_<base62 crc check> + + b64local = unpaddedbase64.encode_base64(for_user.localpart.encode("utf-8")) + random_string = stringutils.random_string(20) + base = f"syr_{b64local}_{random_string}" + + crc = base62_encode(crc32(base.encode("ascii")), minwidth=6) + return f"{base}_{crc}" + async def validate_short_term_login_token( self, login_token: str ) -> LoginTokenAttributes: @@ -1563,7 +1684,7 @@ class AuthHandler(BaseHandler): ) respond_with_html(request, 200, html) - async def _sso_login_callback(self, login_result: JsonDict) -> None: + async def _sso_login_callback(self, login_result: "LoginResponse") -> None: """ A login callback which might add additional attributes to the login response. 
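The token format laid out above (syr_<base64 local part>_<random string>_<base62 crc check>) can be reproduced and verified end to end. A standalone sketch, with a local base62 helper standing in for synapse.util.stringutils.base62_encode (the alphabet order here is illustrative); the point of the CRC suffix is that malformed tokens can be rejected without a database round trip:

import base64
import secrets
import string
from zlib import crc32

_B62 = string.digits + string.ascii_uppercase + string.ascii_lowercase

def base62_encode(num: int, minwidth: int = 1) -> str:
    digits = ""
    while num:
        num, rem = divmod(num, 62)
        digits = _B62[rem] + digits
    return digits.rjust(minwidth, "0")

def make_refresh_token(localpart: str) -> str:
    # unpadded base64, as unpaddedbase64.encode_base64 produces
    b64local = base64.b64encode(localpart.encode("utf-8")).decode("ascii").rstrip("=")
    base = f"syr_{b64local}_{secrets.token_hex(10)}"
    return f"{base}_{base62_encode(crc32(base.encode('ascii')), minwidth=6)}"

def shape_is_valid(token: str) -> bool:
    parts = token.split("_")
    if len(parts) != 4 or parts[0] != "syr":
        return False
    base = "_".join(parts[:3])
    return parts[3] == base62_encode(crc32(base.encode("ascii")), minwidth=6)

assert shape_is_valid(make_refresh_token("alice"))
assert not shape_is_valid("not-a-refresh-token")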
@@ -1577,7 +1698,8 @@ class AuthHandler(BaseHandler): extra_attributes = self._extra_attributes.get(login_result["user_id"]) if extra_attributes: - login_result.update(extra_attributes.extra_attributes) + login_result_dict = cast(Dict[str, Any], login_result) + login_result_dict.update(extra_attributes.extra_attributes) def _expire_sso_extra_attributes(self) -> None: """ diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py
index 989996b628..41dbdfd0a1 100644 --- a/synapse/handlers/event_auth.py +++ b/synapse/handlers/event_auth.py
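The compute_auth_events method moving into EventAuthHandler below selects an event's auth events out of the current state by (event type, state key) pairs. A simplified, self-contained version of that selection, with auth_types_for_event stubbed out as the fixed pairs a plain message event needs:

from typing import Dict, List, Tuple

StateMap = Dict[Tuple[str, str], str]  # (event type, state_key) -> event ID

def compute_auth_event_ids(sender: str, current_state_ids: StateMap) -> List[str]:
    # A message event is authed by the create event, the power levels, and
    # the sender's own membership.
    needed = [
        ("m.room.create", ""),
        ("m.room.power_levels", ""),
        ("m.room.member", sender),
    ]
    return [current_state_ids[key] for key in needed if key in current_state_ids]

state = {
    ("m.room.create", ""): "$create",
    ("m.room.power_levels", ""): "$power",
    ("m.room.member", "@alice:example.com"): "$alice_join",
    ("m.room.member", "@bob:example.com"): "$bob_join",
}
assert compute_auth_event_ids("@alice:example.com", state) == [
    "$create",
    "$power",
    "$alice_join",
]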
@@ -11,8 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Collection, Optional +from typing import TYPE_CHECKING, Collection, List, Optional, Union +from synapse import event_auth from synapse.api.constants import ( EventTypes, JoinRules, @@ -20,9 +21,11 @@ from synapse.api.constants import ( RestrictedJoinRuleTypes, ) from synapse.api.errors import AuthError -from synapse.api.room_versions import RoomVersion +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion from synapse.events import EventBase +from synapse.events.builder import EventBuilder from synapse.types import StateMap +from synapse.util.metrics import Measure if TYPE_CHECKING: from synapse.server import HomeServer @@ -34,8 +37,63 @@ class EventAuthHandler: """ def __init__(self, hs: "HomeServer"): + self._clock = hs.get_clock() self._store = hs.get_datastore() + async def check_from_context( + self, room_version: str, event, context, do_sig_check=True + ) -> None: + auth_event_ids = event.auth_event_ids() + auth_events_by_id = await self._store.get_events(auth_event_ids) + auth_events = {(e.type, e.state_key): e for e in auth_events_by_id.values()} + + room_version_obj = KNOWN_ROOM_VERSIONS[room_version] + event_auth.check( + room_version_obj, event, auth_events=auth_events, do_sig_check=do_sig_check + ) + + def compute_auth_events( + self, + event: Union[EventBase, EventBuilder], + current_state_ids: StateMap[str], + for_verification: bool = False, + ) -> List[str]: + """Given an event and current state return the list of event IDs used + to auth an event. + + If `for_verification` is False then only return auth events that + should be added to the event's `auth_events`. + + Returns: + List of event IDs. + """ + + if event.type == EventTypes.Create: + return [] + + # Currently we ignore the `for_verification` flag even though there are + # some situations where we can drop particular auth events when adding + # to the event's `auth_events` (e.g. joins pointing to previous joins + # when room is publicly joinable). Dropping event IDs has the + # advantage that the auth chain for the room grows slower, but we use + # the auth chain in state resolution v2 to order events, which means + # care must be taken if dropping events to ensure that it doesn't + # introduce undesirable "state reset" behaviour. + # + # All of which sounds a bit tricky so we don't bother for now. + + auth_ids = [] + for etype, state_key in event_auth.auth_types_for_event(event): + auth_ev_id = current_state_ids.get((etype, state_key)) + if auth_ev_id: + auth_ids.append(auth_ev_id) + + return auth_ids + + async def check_host_in_room(self, room_id: str, host: str) -> bool: + with Measure(self._clock, "check_host_in_room"): + return await self._store.is_host_joined(room_id, host) + async def check_restricted_join_rules( self, state_ids: StateMap[str], diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index aaf209842c..f734ce1861 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py
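Several handlers below keep the same first line of defence: an event arriving over send_join/send_leave/send_knock is rejected unless its sender belongs to the server that sent it. A minimal version of that check:

def get_domain_from_id(user_id: str) -> str:
    # Matrix user IDs look like @localpart:domain; everything after the
    # first colon is the server name.
    return user_id.split(":", 1)[1]

def sender_matches_origin(sender: str, origin: str) -> bool:
    return get_domain_from_id(sender) == origin

assert sender_matches_origin("@alice:example.com", "example.com")
assert not sender_matches_origin("@mallory:evil.example", "example.com")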
@@ -250,7 +250,9 @@ class FederationHandler(BaseHandler): # # Note that if we were never in the room then we would have already # dropped the event, since we wouldn't know the room version. - is_in_room = await self.auth.check_host_in_room(room_id, self.server_name) + is_in_room = await self._event_auth_handler.check_host_in_room( + room_id, self.server_name + ) if not is_in_room: logger.info( "Ignoring PDU from %s as we're not in the room", @@ -1675,7 +1677,9 @@ class FederationHandler(BaseHandler): room_version = await self.store.get_room_version_id(room_id) # now check that we are *still* in the room - is_in_room = await self.auth.check_host_in_room(room_id, self.server_name) + is_in_room = await self._event_auth_handler.check_host_in_room( + room_id, self.server_name + ) if not is_in_room: logger.info( "Got /make_join request for room %s we are no longer in", @@ -1706,86 +1710,12 @@ class FederationHandler(BaseHandler): # The remote hasn't signed it yet, obviously. We'll do the full checks # when we get the event back in `on_send_join_request` - await self.auth.check_from_context( + await self._event_auth_handler.check_from_context( room_version, event, context, do_sig_check=False ) return event - async def on_send_join_request(self, origin: str, pdu: EventBase) -> JsonDict: - """We have received a join event for a room. Fully process it and - respond with the current state and auth chains. - """ - event = pdu - - logger.debug( - "on_send_join_request from %s: Got event: %s, signatures: %s", - origin, - event.event_id, - event.signatures, - ) - - if get_domain_from_id(event.sender) != origin: - logger.info( - "Got /send_join request for user %r from different origin %s", - event.sender, - origin, - ) - raise SynapseError(403, "User not from origin", Codes.FORBIDDEN) - - event.internal_metadata.outlier = False - # Send this event on behalf of the origin server. - # - # The reasons we have the destination server rather than the origin - # server send it are slightly mysterious: the origin server should have - # all the necessary state once it gets the response to the send_join, - # so it could send the event itself if it wanted to. It may be that - # doing it this way reduces failure modes, or avoids certain attacks - # where a new server selectively tells a subset of the federation that - # it has joined. - # - # The fact is that, as of the current writing, Synapse doesn't send out - # the join event over federation after joining, and changing it now - # would introduce the danger of backwards-compatibility problems. - event.internal_metadata.send_on_behalf_of = origin - - # Calculate the event context. - context = await self.state_handler.compute_event_context(event) - - # Get the state before the new event. - prev_state_ids = await context.get_prev_state_ids() - - # Check if the user is already in the room or invited to the room. - user_id = event.state_key - prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None) - prev_member_event = None - if prev_member_event_id: - prev_member_event = await self.store.get_event(prev_member_event_id) - - # Check if the member should be allowed access via membership in a space. - await self._event_auth_handler.check_restricted_join_rules( - prev_state_ids, - event.room_version, - user_id, - prev_member_event, - ) - - # Persist the event. 
- await self._auth_and_persist_event(origin, event, context) - - logger.debug( - "on_send_join_request: After _auth_and_persist_event: %s, sigs: %s", - event.event_id, - event.signatures, - ) - - state_ids = list(prev_state_ids.values()) - auth_chain = await self.store.get_auth_chain(event.room_id, state_ids) - - state = await self.store.get_events(list(prev_state_ids.values())) - - return {"state": list(state.values()), "auth_chain": auth_chain} - async def on_invite_request( self, origin: str, event: EventBase, room_version: RoomVersion ) -> EventBase: @@ -1959,7 +1889,7 @@ class FederationHandler(BaseHandler): try: # The remote hasn't signed it yet, obviously. We'll do the full checks # when we get the event back in `on_send_leave_request` - await self.auth.check_from_context( + await self._event_auth_handler.check_from_context( room_version, event, context, do_sig_check=False ) except AuthError as e: @@ -1968,37 +1898,6 @@ class FederationHandler(BaseHandler): return event - async def on_send_leave_request(self, origin: str, pdu: EventBase) -> None: - """We have received a leave event for a room. Fully process it.""" - event = pdu - - logger.debug( - "on_send_leave_request: Got event: %s, signatures: %s", - event.event_id, - event.signatures, - ) - - if get_domain_from_id(event.sender) != origin: - logger.info( - "Got /send_leave request for user %r from different origin %s", - event.sender, - origin, - ) - raise SynapseError(403, "User not from origin", Codes.FORBIDDEN) - - event.internal_metadata.outlier = False - - context = await self.state_handler.compute_event_context(event) - await self._auth_and_persist_event(origin, event, context) - - logger.debug( - "on_send_leave_request: After _auth_and_persist_event: %s, sigs: %s", - event.event_id, - event.signatures, - ) - - return None - @log_function async def on_make_knock_request( self, origin: str, room_id: str, user_id: str @@ -2052,7 +1951,7 @@ class FederationHandler(BaseHandler): try: # The remote hasn't signed it yet, obviously. We'll do the full checks # when we get the event back in `on_send_knock_request` - await self.auth.check_from_context( + await self._event_auth_handler.check_from_context( room_version, event, context, do_sig_check=False ) except AuthError as e: @@ -2062,51 +1961,115 @@ class FederationHandler(BaseHandler): return event @log_function - async def on_send_knock_request( + async def on_send_membership_event( self, origin: str, event: EventBase ) -> EventContext: """ - We have received a knock event for a room. Verify that event and send it into the room - on the knocking homeserver's behalf. + We have received a join/leave/knock event for a room via send_join/leave/knock. + + Verify that event and send it into the room on the remote homeserver's behalf. + + This is quite similar to on_receive_pdu, with the following principal + differences: + * only membership events are permitted (and only events with + sender==state_key -- i.e., no kicks or bans) + * *We* send out the event on behalf of the remote server. + * We enforce the membership restrictions of restricted rooms. + * Rejected events result in an exception rather than being stored. + + There are also other differences; however, it is not clear if these are by + design or omission. In particular, we do not attempt to backfill any missing + prev_events. Args: - origin: The remote homeserver of the knocking user. - event: The knocking member event that has been signed by the remote homeserver.
+ origin: The homeserver of the remote (joining/invited/knocking) user. + event: The member event that has been signed by the remote homeserver. Returns: The context of the event after inserting it into the room graph. + + Raises: + SynapseError if the event is not accepted into the room """ logger.debug( - "on_send_knock_request: Got event: %s, signatures: %s", + "on_send_membership_event: Got event: %s, signatures: %s", event.event_id, event.signatures, ) if get_domain_from_id(event.sender) != origin: logger.info( - "Got /send_knock request for user %r from different origin %s", + "Got send_membership request for user %r from different origin %s", event.sender, origin, ) raise SynapseError(403, "User not from origin", Codes.FORBIDDEN) - event.internal_metadata.outlier = False + if event.sender != event.state_key: + raise SynapseError(400, "state_key and sender must match", Codes.BAD_JSON) - context = await self.state_handler.compute_event_context(event) + assert not event.internal_metadata.outlier - event_allowed = await self.third_party_event_rules.check_event_allowed( - event, context - ) - if not event_allowed: - logger.info("Sending of knock %s forbidden by third-party rules", event) + # Send this event on behalf of the other server. + # + # The remote server isn't a full participant in the room at this point, so + # may not have an up-to-date list of the other homeservers participating in + # the room, so we send it on their behalf. + event.internal_metadata.send_on_behalf_of = origin + + context = await self.state_handler.compute_event_context(event) + context = await self._check_event_auth(origin, event, context) + if context.rejected: raise SynapseError( - 403, "This event is not allowed in this context", Codes.FORBIDDEN + 403, f"{event.membership} event was rejected", Codes.FORBIDDEN ) - await self._auth_and_persist_event(origin, event, context) + # for joins, we need to check the restrictions of restricted rooms + if event.membership == Membership.JOIN: + await self._check_join_restrictions(context, event) + # for knock events, we run the third-party event rules. It's not entirely clear + # why we don't do this for other sorts of membership events. + if event.membership == Membership.KNOCK: + event_allowed = await self.third_party_event_rules.check_event_allowed( + event, context + ) + if not event_allowed: + logger.info("Sending of knock %s forbidden by third-party rules", event) + raise SynapseError( + 403, "This event is not allowed in this context", Codes.FORBIDDEN + ) + + # all looks good, we can persist the event. + await self._run_push_actions_and_persist_event(event, context) return context + async def _check_join_restrictions( + self, context: EventContext, event: EventBase + ) -> None: + """Check that restrictions in restricted join rules are matched + + Called when we receive a join event via send_join. + + Raises an auth error if the restrictions are not matched. + """ + prev_state_ids = await context.get_prev_state_ids() + + # Check if the user is already in the room or invited to the room. + user_id = event.state_key + prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None) + prev_member_event = None + if prev_member_event_id: + prev_member_event = await self.store.get_event(prev_member_event_id) + + # Check if the member should be allowed access via membership in a space. 
+ await self._event_auth_handler.check_restricted_join_rules( + prev_state_ids, + event.room_version, + user_id, + prev_member_event, + ) + async def get_state_for_pdu(self, room_id: str, event_id: str) -> List[EventBase]: """Returns the state at the event. i.e. not including said event.""" @@ -2160,7 +2123,7 @@ class FederationHandler(BaseHandler): async def on_backfill_request( self, origin: str, room_id: str, pdu_list: List[str], limit: int ) -> List[EventBase]: - in_room = await self.auth.check_host_in_room(room_id, origin) + in_room = await self._event_auth_handler.check_host_in_room(room_id, origin) if not in_room: raise AuthError(403, "Host not in room.") @@ -2195,7 +2158,9 @@ class FederationHandler(BaseHandler): ) if event: - in_room = await self.auth.check_host_in_room(event.room_id, origin) + in_room = await self._event_auth_handler.check_host_in_room( + event.room_id, origin + ) if not in_room: raise AuthError(403, "Host not in room.") @@ -2248,6 +2213,18 @@ class FederationHandler(BaseHandler): backfilled=backfilled, ) + await self._run_push_actions_and_persist_event(event, context, backfilled) + + async def _run_push_actions_and_persist_event( + self, event: EventBase, context: EventContext, backfilled: bool = False + ): + """Run the push actions for a received event, and persist it. + + Args: + event: The event itself. + context: The event context. + backfilled: True if the event was backfilled. + """ try: if ( not event.internal_metadata.is_outlier() @@ -2536,7 +2513,7 @@ class FederationHandler(BaseHandler): latest_events: List[str], limit: int, ) -> List[EventBase]: - in_room = await self.auth.check_host_in_room(room_id, origin) + in_room = await self._event_auth_handler.check_host_in_room(room_id, origin) if not in_room: raise AuthError(403, "Host not in room.") @@ -2561,9 +2538,9 @@ class FederationHandler(BaseHandler): origin: str, event: EventBase, context: EventContext, - state: Optional[Iterable[EventBase]], - auth_events: Optional[MutableStateMap[EventBase]], - backfilled: bool, + state: Optional[Iterable[EventBase]] = None, + auth_events: Optional[MutableStateMap[EventBase]] = None, + backfilled: bool = False, ) -> EventContext: """ Checks whether an event should be rejected (for failing auth checks). 
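A toy decision function mirroring how _check_join_restrictions above is used: a join to a restricted room goes through if the user already had a join or invite in the prior state, or is a member of one of the allowed rooms/spaces. Purely illustrative; the real check also inspects the room's join rules event:

from typing import List, Optional, Set

def may_join_restricted(
    prev_membership: Optional[str],
    allowed_spaces: List[str],
    user_spaces: Set[str],
) -> bool:
    if prev_membership in ("join", "invite"):
        return True
    return any(space in user_spaces for space in allowed_spaces)

assert may_join_restricted("invite", [], set())
assert may_join_restricted(None, ["!space:example.com"], {"!space:example.com"})
assert not may_join_restricted(None, ["!space:example.com"], set())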
@@ -2599,7 +2576,7 @@ class FederationHandler(BaseHandler): if not auth_events: prev_state_ids = await context.get_prev_state_ids() - auth_events_ids = self.auth.compute_auth_events( + auth_events_ids = self._event_auth_handler.compute_auth_events( event, prev_state_ids, for_verification=True ) auth_events_x = await self.store.get_events(auth_events_ids) @@ -3028,7 +3005,7 @@ class FederationHandler(BaseHandler): "state_key": target_user_id, } - if await self.auth.check_host_in_room(room_id, self.hs.hostname): + if await self._event_auth_handler.check_host_in_room(room_id, self.hs.hostname): room_version = await self.store.get_room_version_id(room_id) builder = self.event_builder_factory.new(room_version, event_dict) @@ -3048,7 +3025,9 @@ class FederationHandler(BaseHandler): event.internal_metadata.send_on_behalf_of = self.hs.hostname try: - await self.auth.check_from_context(room_version, event, context) + await self._event_auth_handler.check_from_context( + room_version, event, context + ) except AuthError as e: logger.warning("Denying new third party invite %r because %s", event, e) raise e @@ -3091,7 +3070,9 @@ class FederationHandler(BaseHandler): ) try: - await self.auth.check_from_context(room_version, event, context) + await self._event_auth_handler.check_from_context( + room_version, event, context + ) except AuthError as e: logger.warning("Denying third party invite %r because %s", event, e) raise e @@ -3179,7 +3160,7 @@ class FederationHandler(BaseHandler): last_exception = None # type: Optional[Exception] # for each public key in the 3pid invite event - for public_key_object in self.hs.get_auth().get_public_keys(invite_event): + for public_key_object in event_auth.get_public_keys(invite_event): try: # for each sig on the third_party_invite block of the actual invite for server, signature_block in signed["signatures"].items(): diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 74a7fc7ea2..eead07d94e 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py
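The message.py hunks below strip a caller-supplied auth_event_ids list down to just the events actually needed to auth the new event, which is why prev_event_ids must now accompany it: a temporary event is built to work out what is needed. The stripping step, condensed, reusing the selection idea from the event_auth sketch earlier:

full_auth_state = {
    ("m.room.create", ""): "$create",
    ("m.room.power_levels", ""): "$power",
    ("m.room.member", "@alice:example.com"): "$alice_join",
    ("m.room.member", "@bob:example.com"): "$bob_join",  # irrelevant to Alice's event
}
needed = [
    ("m.room.create", ""),
    ("m.room.power_levels", ""),
    ("m.room.member", "@alice:example.com"),
]
stripped = [full_auth_state[key] for key in needed if key in full_auth_state]
assert "$bob_join" not in stripped  # extra memberships get dropped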
@@ -386,6 +386,7 @@ class EventCreationHandler: def __init__(self, hs: "HomeServer"): self.hs = hs self.auth = hs.get_auth() + self._event_auth_handler = hs.get_event_auth_handler() self.store = hs.get_datastore() self.storage = hs.get_storage() self.state = hs.get_state_handler() @@ -510,6 +511,8 @@ class EventCreationHandler: Should normally be left as None, which will cause them to be calculated based on the room state at the prev_events. + If non-None, prev_event_ids must also be provided. + require_consent: Whether to check if the requester has consented to the privacy policy. @@ -582,6 +585,9 @@ class EventCreationHandler: # Strip down the auth_event_ids to only what we need to auth the event. # For example, we don't need extra m.room.member that don't match event.sender if auth_event_ids is not None: + # If auth events are provided, prev events must be also. + assert prev_event_ids is not None + temp_event = await builder.build( prev_event_ids=prev_event_ids, auth_event_ids=auth_event_ids, @@ -593,7 +599,7 @@ class EventCreationHandler: (e.type, e.state_key): e.event_id for e in auth_events } # Actually strip down and use the necessary auth events - auth_event_ids = self.auth.compute_auth_events( + auth_event_ids = self._event_auth_handler.compute_auth_events( event=temp_event, current_state_ids=auth_event_state_map, for_verification=False, @@ -785,6 +791,8 @@ class EventCreationHandler: The event ids to use as the auth_events for the new event. Should normally be left as None, which will cause them to be calculated based on the room state at the prev_events. + + If non-None, prev_event_ids must also be provided. ratelimit: Whether to rate limit this send. txn_id: The transaction ID. ignore_shadow_ban: True if shadow-banned users should be allowed to @@ -1050,7 +1058,9 @@ class EventCreationHandler: assert event.content["membership"] == Membership.LEAVE else: try: - await self.auth.check_from_context(room_version, event, context) + await self._event_auth_handler.check_from_context( + room_version, event, context + ) except AuthError as err: logger.warning("Denying new event %r because %s", event, err) raise err @@ -1375,7 +1385,7 @@ class EventCreationHandler: raise AuthError(403, "Redacting server ACL events is not permitted") prev_state_ids = await context.get_prev_state_ids() - auth_events_ids = self.auth.compute_auth_events( + auth_events_ids = self._event_auth_handler.compute_auth_events( event, prev_state_ids, for_verification=True ) auth_events_map = await self.store.get_events(auth_events_ids) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 0880bd0496..c863504beb 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py
@@ -15,9 +15,10 @@ """Contains functions for registering clients.""" import logging -from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple +from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple from prometheus_client import Counter +from typing_extensions import TypedDict from synapse import types from synapse.api.constants import MAX_USERID_LENGTH, EventTypes, JoinRules, LoginType @@ -54,6 +55,16 @@ login_counter = Counter( ["guest", "auth_provider"], ) +LoginDict = TypedDict( + "LoginDict", + { + "device_id": str, + "access_token": str, + "valid_until_ms": Optional[int], + "refresh_token": Optional[str], + }, +) + class RegistrationHandler(BaseHandler): def __init__(self, hs: "HomeServer"): @@ -88,6 +99,7 @@ class RegistrationHandler(BaseHandler): self.pusher_pool = hs.get_pusherpool() self.session_lifetime = hs.config.session_lifetime + self.access_token_lifetime = hs.config.access_token_lifetime async def check_username( self, @@ -418,11 +430,32 @@ class RegistrationHandler(BaseHandler): room_alias = RoomAlias.from_string(r) if self.hs.hostname != room_alias.domain: - logger.warning( - "Cannot create room alias %s, " - "it does not match server domain", + # If the alias is remote, try to join the room. This might fail + # because the room might be invite only, but we don't have any local + # user in the room to invite this one with, so at this point that's + # the best we can do. + logger.info( + "Cannot automatically create room with alias %s as it isn't" + " local, trying to join the room instead", r, ) + + ( + room, + remote_room_hosts, + ) = await room_member_handler.lookup_room_alias(room_alias) + room_id = room.to_string() + + await room_member_handler.update_membership( + requester=create_requester( + user_id, authenticated_entity=self._server_name + ), + target=UserID.from_string(user_id), + room_id=room_id, + remote_room_hosts=remote_room_hosts, + action="join", + ratelimit=False, + ) else: # A shallow copy is OK here since the only key that is # modified is room_alias_name. @@ -480,22 +513,32 @@ class RegistrationHandler(BaseHandler): ) # Calculate whether the room requires an invite or can be - # joined directly. Note that unless a join rule of public exists, - # it is treated as requiring an invite. - requires_invite = True - - state = await self.store.get_filtered_current_state_ids( - room_id, StateFilter.from_types([(EventTypes.JoinRules, "")]) + # joined directly. By default, we consider the room as requiring an + # invite if the homeserver is in the room (unless told otherwise by the + # join rules). Otherwise we consider it as being joinable, at the risk of + # failing to join, but in this case there's little more we can do since + # we don't have a local user in the room to craft up an invite with. + requires_invite = await self.store.is_host_joined( + room_id, + self.server_name, ) - event_id = state.get((EventTypes.JoinRules, "")) - if event_id: - join_rules_event = await self.store.get_event( - event_id, allow_none=True + if requires_invite: + # If the server is in the room, check if the room is public. 
+ state = await self.store.get_filtered_current_state_ids( + room_id, StateFilter.from_types([(EventTypes.JoinRules, "")]) ) - if join_rules_event: - join_rule = join_rules_event.content.get("join_rule", None) - requires_invite = join_rule and join_rule != JoinRules.PUBLIC + + event_id = state.get((EventTypes.JoinRules, "")) + if event_id: + join_rules_event = await self.store.get_event( + event_id, allow_none=True + ) + if join_rules_event: + join_rule = join_rules_event.content.get("join_rule", None) + requires_invite = ( + join_rule and join_rule != JoinRules.PUBLIC + ) # Send the invite, if necessary. if requires_invite: @@ -745,7 +788,8 @@ class RegistrationHandler(BaseHandler): is_guest: bool = False, is_appservice_ghost: bool = False, auth_provider_id: Optional[str] = None, - ) -> Tuple[str, str]: + should_issue_refresh_token: bool = False, + ) -> Tuple[str, str, Optional[int], Optional[str]]: """Register a device for a user and generate an access token. The access token will be limited by the homeserver's session_lifetime config. @@ -757,8 +801,9 @@ class RegistrationHandler(BaseHandler): is_guest: Whether this is a guest account auth_provider_id: The SSO IdP the user used, if any (just used for the prometheus metrics). + should_issue_refresh_token: Whether it should also issue a refresh token Returns: - Tuple of device ID and access token + Tuple of device ID, access token, access token expiration time and refresh token """ res = await self._register_device_client( user_id=user_id, @@ -766,6 +811,7 @@ class RegistrationHandler(BaseHandler): initial_display_name=initial_display_name, is_guest=is_guest, is_appservice_ghost=is_appservice_ghost, + should_issue_refresh_token=should_issue_refresh_token, ) login_counter.labels( @@ -773,7 +819,12 @@ class RegistrationHandler(BaseHandler): auth_provider=(auth_provider_id or ""), ).inc() - return res["device_id"], res["access_token"] + return ( + res["device_id"], + res["access_token"], + res["valid_until_ms"], + res["refresh_token"], + ) async def register_device_inner( self, @@ -782,7 +833,8 @@ class RegistrationHandler(BaseHandler): initial_display_name: Optional[str], is_guest: bool = False, is_appservice_ghost: bool = False, - ) -> Dict[str, str]: + should_issue_refresh_token: bool = False, + ) -> LoginDict: """Helper for register_device Does the bits that need doing on the main process. 
Not for use outside this @@ -797,6 +849,9 @@ class RegistrationHandler(BaseHandler): ) valid_until_ms = self.clock.time_msec() + self.session_lifetime + refresh_token = None + refresh_token_id = None + registered_device_id = await self.device_handler.check_device_registered( user_id, device_id, initial_display_name ) @@ -804,14 +859,30 @@ class RegistrationHandler(BaseHandler): assert valid_until_ms is None access_token = self.macaroon_gen.generate_guest_access_token(user_id) else: + if should_issue_refresh_token: + ( + refresh_token, + refresh_token_id, + ) = await self._auth_handler.get_refresh_token_for_user_id( + user_id, + device_id=registered_device_id, + ) + valid_until_ms = self.clock.time_msec() + self.access_token_lifetime + access_token = await self._auth_handler.get_access_token_for_user_id( user_id, device_id=registered_device_id, valid_until_ms=valid_until_ms, is_appservice_ghost=is_appservice_ghost, + refresh_token_id=refresh_token_id, ) - return {"device_id": registered_device_id, "access_token": access_token} + return { + "device_id": registered_device_id, + "access_token": access_token, + "valid_until_ms": valid_until_ms, + "refresh_token": refresh_token, + } async def post_registration_actions( self, user_id: str, auth_result: dict, access_token: Optional[str] diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
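Every caller of `register_device` now unpacks four values instead of two, which is why the module_api, replication, login and register changes later in this diff all grow the same tuple. A hedged sketch of the new call shape (the handler object is assumed to be obtained elsewhere; function name is illustrative):

    from typing import Tuple

    async def make_device(registration_handler, user_id: str) -> Tuple[str, str]:
        device_id, access_token, valid_until_ms, refresh_token = (
            await registration_handler.register_device(
                user_id,
                device_id=None,
                initial_display_name="example device",
                should_issue_refresh_token=True,
            )
        )
        # valid_until_ms and refresh_token are None unless token expiry /
        # MSC2918 refresh tokens are enabled on the homeserver.
        return device_id, access_token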
index 58739e5016..0cb6855767 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py
@@ -83,6 +83,7 @@ class RoomCreationHandler(BaseHandler): self.spam_checker = hs.get_spam_checker() self.event_creation_handler = hs.get_event_creation_handler() self.room_member_handler = hs.get_room_member_handler() + self._event_auth_handler = hs.get_event_auth_handler() self.config = hs.config # Room state based off defined presets @@ -226,7 +227,7 @@ class RoomCreationHandler(BaseHandler): }, ) old_room_version = await self.store.get_room_version_id(old_room_id) - await self.auth.check_from_context( + await self._event_auth_handler.check_from_context( old_room_version, tombstone_event, tombstone_context ) diff --git a/synapse/handlers/space_summary.py b/synapse/handlers/space_summary.py
index 17fc47ce16..b585057ec3 100644 --- a/synapse/handlers/space_summary.py +++ b/synapse/handlers/space_summary.py
@@ -25,6 +25,7 @@ from synapse.api.constants import ( EventTypes, HistoryVisibility, Membership, + RoomTypes, ) from synapse.events import EventBase from synapse.events.utils import format_event_for_client_v2 @@ -318,7 +319,8 @@ class SpaceSummaryHandler: Returns: A tuple of: - An iterable of a single value of the room. + The room information, if the room should be returned to the + user. None, otherwise. An iterable of the sorted children events. This may be limited to a maximum size or may include all children. @@ -328,7 +330,11 @@ class SpaceSummaryHandler: room_entry = await self._build_room_entry(room_id) - # look for child rooms/spaces. + # If the room is not a space, return just the room information. + if room_entry.get("room_type") != RoomTypes.SPACE: + return room_entry, () + + # Otherwise, look for child rooms/spaces. child_events = await self._get_child_events(room_id) if suggested_only: @@ -348,6 +354,7 @@ class SpaceSummaryHandler: event_format=format_event_for_client_v2, ) ) + return room_entry, events_result async def _summarize_remote_room( @@ -465,7 +472,7 @@ class SpaceSummaryHandler: # If this is a request over federation, check if the host is in the room or # is in one of the spaces specified via the join rules. elif origin: - if await self._auth.check_host_in_room(room_id, origin): + if await self._event_auth_handler.check_host_in_room(room_id, origin): return True # Alternately, if the host has a user in any of the spaces specified @@ -478,7 +485,9 @@ class SpaceSummaryHandler: await self._event_auth_handler.get_rooms_that_allow_join(state_ids) ) for space_id in allowed_rooms: - if await self._auth.check_host_in_room(space_id, origin): + if await self._event_auth_handler.check_host_in_room( + space_id, origin + ): return True # otherwise, check if the room is peekable diff --git a/synapse/http/server.py b/synapse/http/server.py
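The summary handler now distinguishes plain rooms from spaces: only a room whose create event carries a space room type has its child events walked. A reduced sketch of that branch, with `RoomTypes.SPACE` spelled out as the literal it is assumed to hold in the constants module:

    from typing import Iterable, Tuple

    ROOM_TYPE_SPACE = "m.space"  # assumed value of RoomTypes.SPACE

    def summarize(room_entry: dict, child_events: Iterable[dict]) -> Tuple[dict, tuple]:
        if room_entry.get("room_type") != ROOM_TYPE_SPACE:
            # Not a space: return the room itself with no children rather
            # than recursing into meaningless child events.
            return room_entry, ()
        return room_entry, tuple(child_events)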
index 845651e606..efbc6d5b25 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py
@@ -728,7 +728,7 @@ def set_cors_headers(request: Request): ) request.setHeader( b"Access-Control-Allow-Headers", - b"Origin, X-Requested-With, Content-Type, Accept, Authorization, Date", + b"X-Requested-With, Content-Type, Authorization, Date", ) diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index 07eb4f439b..bf706b3a8f 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py
@@ -113,8 +113,18 @@ def parse_boolean_from_args(args, name, default=None, required=False): def parse_bytes_from_args( args: Dict[bytes, List[bytes]], name: str, + default: Optional[bytes] = None, +) -> Optional[bytes]: + ... + + +@overload +def parse_bytes_from_args( + args: Dict[bytes, List[bytes]], + name: str, default: Literal[None] = None, - required: Literal[True] = True, + *, + required: Literal[True], ) -> bytes: ... @@ -197,7 +207,12 @@ def parse_string( """ args = request.args # type: Dict[bytes, List[bytes]] # type: ignore return parse_string_from_args( - args, name, default, required, allowed_values, encoding + args, + name, + default, + required=required, + allowed_values=allowed_values, + encoding=encoding, ) @@ -227,7 +242,20 @@ def parse_strings_from_args( args: Dict[bytes, List[bytes]], name: str, default: Optional[List[str]] = None, - required: Literal[True] = True, + *, + allowed_values: Optional[Iterable[str]] = None, + encoding: str = "ascii", +) -> Optional[List[str]]: + ... + + +@overload +def parse_strings_from_args( + args: Dict[bytes, List[bytes]], + name: str, + default: Optional[List[str]] = None, + *, + required: Literal[True], allowed_values: Optional[Iterable[str]] = None, encoding: str = "ascii", ) -> List[str]: @@ -239,6 +267,7 @@ def parse_strings_from_args( args: Dict[bytes, List[bytes]], name: str, default: Optional[List[str]] = None, + *, required: bool = False, allowed_values: Optional[Iterable[str]] = None, encoding: str = "ascii", @@ -299,7 +328,20 @@ def parse_string_from_args( args: Dict[bytes, List[bytes]], name: str, default: Optional[str] = None, - required: Literal[True] = True, + *, + allowed_values: Optional[Iterable[str]] = None, + encoding: str = "ascii", +) -> Optional[str]: + ... + + +@overload +def parse_string_from_args( + args: Dict[bytes, List[bytes]], + name: str, + default: Optional[str] = None, + *, + required: Literal[True], allowed_values: Optional[Iterable[str]] = None, encoding: str = "ascii", ) -> str: diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
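The churn in these overloads is all in service of one typing trick: making `required` keyword-only and typing it `Literal[True]` in the second overload lets mypy narrow the return type to non-Optional exactly when a caller writes `required=True`. A self-contained sketch of the pattern (names simplified; the real functions also handle `allowed_values` and `encoding`):

    from typing import Dict, Optional, overload

    from typing_extensions import Literal

    @overload
    def parse_arg(
        args: Dict[str, str], name: str, default: Optional[str] = None
    ) -> Optional[str]:
        ...

    @overload
    def parse_arg(
        args: Dict[str, str],
        name: str,
        default: Optional[str] = None,
        *,
        required: Literal[True],
    ) -> str:
        ...

    def parse_arg(args, name, default=None, *, required=False):
        value = args.get(name, default)
        if required and value is None:
            raise KeyError("missing required argument: %s" % (name,))
        return value

    maybe = parse_arg({}, "foo")                            # typed Optional[str]
    surely = parse_arg({"foo": "1"}, "foo", required=True)  # typed str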
index 58b255eb1b..721c45abac 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py
@@ -168,7 +168,7 @@ class ModuleApi: "Using deprecated ModuleApi.register which creates a dummy user device." ) user_id = yield self.register_user(localpart, displayname, emails or []) - _, access_token = yield self.register_device(user_id) + _, access_token, _, _ = yield self.register_device(user_id) return user_id, access_token def register_user( diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 350646f458..669ea462e2 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -104,7 +104,7 @@ class BulkPushRuleEvaluator: def __init__(self, hs: "HomeServer"): self.hs = hs self.store = hs.get_datastore() - self.auth = hs.get_auth() + self._event_auth_handler = hs.get_event_auth_handler() # Used by `RulesForRoom` to ensure only one thing mutates the cache at a # time. Keyed off room_id. @@ -172,7 +172,7 @@ class BulkPushRuleEvaluator: # not having a power level event is an extreme edge case auth_events = {POWER_KEY: await self.store.get_event(pl_event_id)} else: - auth_events_ids = self.auth.compute_auth_events( + auth_events_ids = self._event_auth_handler.compute_auth_events( event, prev_state_ids, for_verification=False ) auth_events_dict = await self.store.get_events(auth_events_ids) diff --git a/synapse/replication/http/login.py b/synapse/replication/http/login.py
index c2e8c00293..550bd5c95f 100644 --- a/synapse/replication/http/login.py +++ b/synapse/replication/http/login.py
@@ -36,20 +36,29 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint): @staticmethod async def _serialize_payload( - user_id, device_id, initial_display_name, is_guest, is_appservice_ghost + user_id, + device_id, + initial_display_name, + is_guest, + is_appservice_ghost, + should_issue_refresh_token, ): """ Args: + user_id (str) device_id (str|None): Device ID to use, if None a new one is generated. initial_display_name (str|None) is_guest (bool) + is_appservice_ghost (bool) + should_issue_refresh_token (bool) """ return { "device_id": device_id, "initial_display_name": initial_display_name, "is_guest": is_guest, "is_appservice_ghost": is_appservice_ghost, + "should_issue_refresh_token": should_issue_refresh_token, } async def _handle_request(self, request, user_id): @@ -59,6 +68,7 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint): initial_display_name = content["initial_display_name"] is_guest = content["is_guest"] is_appservice_ghost = content["is_appservice_ghost"] + should_issue_refresh_token = content["should_issue_refresh_token"] res = await self.registration_handler.register_device_inner( user_id, @@ -66,6 +76,7 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint): initial_display_name, is_guest, is_appservice_ghost=is_appservice_ghost, + should_issue_refresh_token=should_issue_refresh_token, ) return 200, res diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index f6be5f1020..cbcb60fe31 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py
@@ -14,7 +14,9 @@ import logging import re -from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional + +from typing_extensions import TypedDict from synapse.api.errors import Codes, LoginError, SynapseError from synapse.api.ratelimiting import Ratelimiter @@ -25,6 +27,8 @@ from synapse.http import get_request_uri from synapse.http.server import HttpServer, finish_request from synapse.http.servlet import ( RestServlet, + assert_params_in_dict, + parse_boolean, parse_bytes_from_args, parse_json_object_from_request, parse_string, @@ -40,6 +44,21 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +LoginResponse = TypedDict( + "LoginResponse", + { + "user_id": str, + "access_token": str, + "home_server": str, + "expires_in_ms": Optional[int], + "refresh_token": Optional[str], + "device_id": str, + "well_known": Optional[Dict[str, Any]], + }, + total=False, +) + + class LoginRestServlet(RestServlet): PATTERNS = client_patterns("/login$", v1=True) CAS_TYPE = "m.login.cas" @@ -48,6 +67,7 @@ class LoginRestServlet(RestServlet): JWT_TYPE = "org.matrix.login.jwt" JWT_TYPE_DEPRECATED = "m.login.jwt" APPSERVICE_TYPE = "uk.half-shot.msc2778.login.application_service" + REFRESH_TOKEN_PARAM = "org.matrix.msc2918.refresh_token" def __init__(self, hs: "HomeServer"): super().__init__() @@ -65,9 +85,12 @@ class LoginRestServlet(RestServlet): self.cas_enabled = hs.config.cas_enabled self.oidc_enabled = hs.config.oidc_enabled self._msc2858_enabled = hs.config.experimental.msc2858_enabled + self._msc2918_enabled = hs.config.access_token_lifetime is not None self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.auth_handler = self.hs.get_auth_handler() self.registration_handler = hs.get_registration_handler() self._sso_handler = hs.get_sso_handler() @@ -138,6 +161,15 @@ class LoginRestServlet(RestServlet): async def on_POST(self, request: SynapseRequest): login_submission = parse_json_object_from_request(request) + if self._msc2918_enabled: + # Check if this login should also issue a refresh token, as per + # MSC2918 + should_issue_refresh_token = parse_boolean( + request, name=LoginRestServlet.REFRESH_TOKEN_PARAM, default=False + ) + else: + should_issue_refresh_token = False + try: if login_submission["type"] == LoginRestServlet.APPSERVICE_TYPE: appservice = self.auth.get_appservice_by_req(request) @@ -147,19 +179,32 @@ class LoginRestServlet(RestServlet): None, request.getClientIP() ) - result = await self._do_appservice_login(login_submission, appservice) + result = await self._do_appservice_login( + login_submission, + appservice, + should_issue_refresh_token=should_issue_refresh_token, + ) elif self.jwt_enabled and ( login_submission["type"] == LoginRestServlet.JWT_TYPE or login_submission["type"] == LoginRestServlet.JWT_TYPE_DEPRECATED ): await self._address_ratelimiter.ratelimit(None, request.getClientIP()) - result = await self._do_jwt_login(login_submission) + result = await self._do_jwt_login( + login_submission, + should_issue_refresh_token=should_issue_refresh_token, + ) elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE: await self._address_ratelimiter.ratelimit(None, request.getClientIP()) - result = await self._do_token_login(login_submission) + result = await self._do_token_login( + login_submission, + should_issue_refresh_token=should_issue_refresh_token, + ) else: await self._address_ratelimiter.ratelimit(None, request.getClientIP()) - result = await 
self._do_other_login(login_submission) + result = await self._do_other_login( + login_submission, + should_issue_refresh_token=should_issue_refresh_token, + ) except KeyError: raise SynapseError(400, "Missing JSON keys.") @@ -169,7 +214,10 @@ class LoginRestServlet(RestServlet): return 200, result async def _do_appservice_login( - self, login_submission: JsonDict, appservice: ApplicationService + self, + login_submission: JsonDict, + appservice: ApplicationService, + should_issue_refresh_token: bool = False, ): identifier = login_submission.get("identifier") logger.info("Got appservice login request with identifier: %r", identifier) @@ -198,14 +246,21 @@ class LoginRestServlet(RestServlet): raise LoginError(403, "Invalid access_token", errcode=Codes.FORBIDDEN) return await self._complete_login( - qualified_user_id, login_submission, ratelimit=appservice.is_rate_limited() + qualified_user_id, + login_submission, + ratelimit=appservice.is_rate_limited(), + should_issue_refresh_token=should_issue_refresh_token, ) - async def _do_other_login(self, login_submission: JsonDict) -> Dict[str, str]: + async def _do_other_login( + self, login_submission: JsonDict, should_issue_refresh_token: bool = False + ) -> LoginResponse: """Handle non-token/saml/jwt logins Args: login_submission: + should_issue_refresh_token: True if this login should issue + a refresh token alongside the access token. Returns: HTTP response @@ -224,7 +279,10 @@ class LoginRestServlet(RestServlet): login_submission, ratelimit=True ) result = await self._complete_login( - canonical_user_id, login_submission, callback + canonical_user_id, + login_submission, + callback, + should_issue_refresh_token=should_issue_refresh_token, ) return result @@ -232,11 +290,12 @@ class LoginRestServlet(RestServlet): self, user_id: str, login_submission: JsonDict, - callback: Optional[Callable[[Dict[str, str]], Awaitable[None]]] = None, + callback: Optional[Callable[[LoginResponse], Awaitable[None]]] = None, create_non_existent_users: bool = False, ratelimit: bool = True, auth_provider_id: Optional[str] = None, - ) -> Dict[str, str]: + should_issue_refresh_token: bool = False, + ) -> LoginResponse: """Called when we've successfully authed the user and now need to actually login them in (e.g. create devices). This gets called on all successful logins. @@ -253,6 +312,8 @@ class LoginRestServlet(RestServlet): ratelimit: Whether to ratelimit the login request. auth_provider_id: The SSO IdP the user used, if any (just used for the prometheus metrics). + should_issue_refresh_token: True if this login should issue + a refresh token alongside the access token. Returns: result: Dictionary of account information after successful login. 
@@ -274,28 +335,48 @@ class LoginRestServlet(RestServlet): device_id = login_submission.get("device_id") initial_display_name = login_submission.get("initial_device_display_name") - device_id, access_token = await self.registration_handler.register_device( - user_id, device_id, initial_display_name, auth_provider_id=auth_provider_id + ( + device_id, + access_token, + valid_until_ms, + refresh_token, + ) = await self.registration_handler.register_device( + user_id, + device_id, + initial_display_name, + auth_provider_id=auth_provider_id, + should_issue_refresh_token=should_issue_refresh_token, ) - result = { - "user_id": user_id, - "access_token": access_token, - "home_server": self.hs.hostname, - "device_id": device_id, - } + result = LoginResponse( + user_id=user_id, + access_token=access_token, + home_server=self.hs.hostname, + device_id=device_id, + ) + + if valid_until_ms is not None: + expires_in_ms = valid_until_ms - self.clock.time_msec() + result["expires_in_ms"] = expires_in_ms + + if refresh_token is not None: + result["refresh_token"] = refresh_token if callback is not None: await callback(result) return result - async def _do_token_login(self, login_submission: JsonDict) -> Dict[str, str]: + async def _do_token_login( + self, login_submission: JsonDict, should_issue_refresh_token: bool = False + ) -> LoginResponse: """ Handle the final stage of SSO login. Args: - login_submission: The JSON request body. + login_submission: The JSON request body. + should_issue_refresh_token: True if this login should issue + a refresh token alongside the access token. Returns: The body of the JSON response. @@ -309,9 +390,12 @@ class LoginRestServlet(RestServlet): login_submission, self.auth_handler._sso_login_callback, auth_provider_id=res.auth_provider_id, + should_issue_refresh_token=should_issue_refresh_token, ) - async def _do_jwt_login(self, login_submission: JsonDict) -> Dict[str, str]: + async def _do_jwt_login( + self, login_submission: JsonDict, should_issue_refresh_token: bool = False + ) -> LoginResponse: token = login_submission.get("token", None) if token is None: raise LoginError( @@ -342,7 +426,10 @@ class LoginRestServlet(RestServlet): user_id = UserID(user, self.hs.hostname).to_string() result = await self._complete_login( - user_id, login_submission, create_non_existent_users=True + user_id, + login_submission, + create_non_existent_users=True, + should_issue_refresh_token=should_issue_refresh_token, ) return result @@ -371,6 +458,42 @@ def _get_auth_flow_dict_for_idp( return e +class RefreshTokenServlet(RestServlet): + PATTERNS = client_patterns( + "/org.matrix.msc2918.refresh_token/refresh$", releases=(), unstable=True + ) + + def __init__(self, hs: "HomeServer"): + self._auth_handler = hs.get_auth_handler() + self._clock = hs.get_clock() + self.access_token_lifetime = hs.config.access_token_lifetime + + async def on_POST( + self, + request: SynapseRequest, + ): + refresh_submission = parse_json_object_from_request(request) + + assert_params_in_dict(refresh_submission, ["refresh_token"]) + token = refresh_submission["refresh_token"] + if not isinstance(token, str): + raise SynapseError(400, "Invalid param: refresh_token", Codes.INVALID_PARAM) + + valid_until_ms = self._clock.time_msec() + self.access_token_lifetime + access_token, refresh_token = await self._auth_handler.refresh_token( + token, valid_until_ms + ) + expires_in_ms = valid_until_ms - self._clock.time_msec() + return ( + 200, + { + "access_token": access_token, + "refresh_token": refresh_token, + 
"expires_in_ms": expires_in_ms, + }, + ) + + class SsoRedirectServlet(RestServlet): PATTERNS = list(client_patterns("/login/(cas|sso)/redirect$", v1=True)) + [ re.compile( @@ -477,6 +600,8 @@ class CasTicketServlet(RestServlet): def register_servlets(hs, http_server): LoginRestServlet(hs).register(http_server) + if hs.config.access_token_lifetime is not None: + RefreshTokenServlet(hs).register(http_server) SsoRedirectServlet(hs).register(http_server) if hs.config.cas_enabled: CasTicketServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index 0be5fbb7f7..f2c155dfae 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py
@@ -42,11 +42,13 @@ from synapse.http.server import finish_request, respond_with_html from synapse.http.servlet import ( RestServlet, assert_params_in_dict, + parse_boolean, parse_json_object_from_request, parse_string, ) from synapse.metrics import threepid_send_requests from synapse.push.mailer import Mailer +from synapse.types import JsonDict from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.ratelimitutils import FederationRateLimiter from synapse.util.stringutils import assert_valid_client_secret, random_string @@ -396,6 +398,7 @@ class RegisterRestServlet(RestServlet): self.password_policy_handler = hs.get_password_policy_handler() self.clock = hs.get_clock() self._registration_enabled = self.hs.config.enable_registration + self._msc2918_enabled = hs.config.access_token_lifetime is not None self._registration_flows = _calculate_registration_flows( hs.config, self.auth_handler @@ -421,6 +424,15 @@ class RegisterRestServlet(RestServlet): "Do not understand membership kind: %s" % (kind.decode("utf8"),) ) + if self._msc2918_enabled: + # Check if this registration should also issue a refresh token, as + # per MSC2918 + should_issue_refresh_token = parse_boolean( + request, name="org.matrix.msc2918.refresh_token", default=False + ) + else: + should_issue_refresh_token = False + # We don't care about usernames for this deployment. In fact, the act # of checking whether they exist already can leak metadata about # which users are already registered. @@ -472,7 +484,12 @@ class RegisterRestServlet(RestServlet): raise SynapseError(400, "Desired Username is missing or not a string") result = await self._do_appservice_registration( - desired_username, password, desired_display_name, access_token, body + desired_username, + password, + desired_display_name, + access_token, + body, + should_issue_refresh_token=should_issue_refresh_token, ) return 200, result @@ -744,7 +761,9 @@ class RegisterRestServlet(RestServlet): registered = True return_dict = await self._create_registration_details( - registered_user_id, params + registered_user_id, + params, + should_issue_refresh_token=should_issue_refresh_token, ) if registered: @@ -757,17 +776,21 @@ class RegisterRestServlet(RestServlet): return 200, return_dict async def _do_appservice_registration( - self, username, password, display_name, as_token, body + self, + username, + password, + display_name, + as_token, + body, + should_issue_refresh_token: bool = False, ): - # FIXME: appservice_register() is horribly duplicated with register() - # and they should probably just be combined together with a config flag. 
- if password: # Hash the password # # In mainline hashing of the password was moved further on in the registration # flow, but we need it here for the AS use-case of shadow servers password = await self.auth_handler.hash(password) + user_id = await self.registration_handler.appservice_register( username, as_token, password, display_name ) @@ -775,6 +798,7 @@ class RegisterRestServlet(RestServlet): user_id, body, is_appservice_ghost=True, + should_issue_refresh_token=should_issue_refresh_token, ) auth_result = body.get("auth_result") @@ -793,16 +817,23 @@ class RegisterRestServlet(RestServlet): return result async def _create_registration_details( - self, user_id, params, is_appservice_ghost=False + self, + user_id: str, + params: JsonDict, + is_appservice_ghost: bool = False, + should_issue_refresh_token: bool = False, ): """Complete registration of newly-registered user Allocates device_id if one was not given; also creates access_token. Args: - (str) user_id: full canonical @user:id - (object) params: registration parameters, from which we pull - device_id, initial_device_name and inhibit_login + user_id: full canonical @user:id + params: registration parameters, from which we pull device_id, + initial_device_name and inhibit_login + is_appservice_ghost + should_issue_refresh_token: True if this registration should issue + a refresh token alongside the access token. Returns: dictionary for response from /register """ @@ -810,15 +841,29 @@ class RegisterRestServlet(RestServlet): if not params.get("inhibit_login", False): device_id = params.get("device_id") initial_display_name = params.get("initial_device_display_name") - device_id, access_token = await self.registration_handler.register_device( + ( + device_id, + access_token, + valid_until_ms, + refresh_token, + ) = await self.registration_handler.register_device( user_id, device_id, initial_display_name, is_guest=False, is_appservice_ghost=is_appservice_ghost, + should_issue_refresh_token=should_issue_refresh_token, ) result.update({"access_token": access_token, "device_id": device_id}) + + if valid_until_ms is not None: + expires_in_ms = valid_until_ms - self.clock.time_msec() + result["expires_in_ms"] = expires_in_ms + + if refresh_token is not None: + result["refresh_token"] = refresh_token + return result async def _do_guest_registration(self, params, address=None): @@ -832,19 +877,30 @@ class RegisterRestServlet(RestServlet): # we have nowhere to store it. device_id = synapse.api.auth.GUEST_DEVICE_ID initial_display_name = params.get("initial_device_display_name") - device_id, access_token = await self.registration_handler.register_device( + ( + device_id, + access_token, + valid_until_ms, + refresh_token, + ) = await self.registration_handler.register_device( user_id, device_id, initial_display_name, is_guest=True ) - return ( - 200, - { - "user_id": user_id, - "device_id": device_id, - "access_token": access_token, - "home_server": self.hs.hostname, - }, - ) + result = { + "user_id": user_id, + "device_id": device_id, + "access_token": access_token, + "home_server": self.hs.hostname, + } + + if valid_until_ms is not None: + expires_in_ms = valid_until_ms - self.clock.time_msec() + result["expires_in_ms"] = expires_in_ms + + if refresh_token is not None: + result["refresh_token"] = refresh_token + + return 200, result def cap(name): diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
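Registration mirrors the login changes: the same opt-in query parameter is read, and the response only gains `expires_in_ms` and `refresh_token` when they were actually issued (note that guest registration never requests a refresh token). The conditional response assembly reduces to roughly the following sketch, with names invented for illustration:

    from typing import Optional

    def registration_response(
        user_id: str,
        device_id: str,
        access_token: str,
        home_server: str,
        now_ms: int,
        valid_until_ms: Optional[int] = None,
        refresh_token: Optional[str] = None,
    ) -> dict:
        result = {
            "user_id": user_id,
            "device_id": device_id,
            "access_token": access_token,
            "home_server": home_server,
        }
        if valid_until_ms is not None:
            # Relative lifetime, computed at response time.
            result["expires_in_ms"] = valid_until_ms - now_ms
        if refresh_token is not None:
            result["refresh_token"] = refresh_token
        return result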
index 042e1788b6..ecbbcf3851 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py
@@ -13,6 +13,7 @@ # limitations under the License. import itertools import logging +from collections import defaultdict from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple from synapse.api.constants import Membership, PresenceState @@ -232,29 +233,51 @@ class SyncRestServlet(RestServlet): ) logger.debug("building sync response dict") - return { - "account_data": {"events": sync_result.account_data}, - "to_device": {"events": sync_result.to_device}, - "device_lists": { - "changed": list(sync_result.device_lists.changed), - "left": list(sync_result.device_lists.left), - }, - "presence": SyncRestServlet.encode_presence(sync_result.presence, time_now), - "rooms": { - Membership.JOIN: joined, - Membership.INVITE: invited, - Membership.KNOCK: knocked, - Membership.LEAVE: archived, - }, - "groups": { - Membership.JOIN: sync_result.groups.join, - Membership.INVITE: sync_result.groups.invite, - Membership.LEAVE: sync_result.groups.leave, - }, - "device_one_time_keys_count": sync_result.device_one_time_keys_count, - "org.matrix.msc2732.device_unused_fallback_key_types": sync_result.device_unused_fallback_key_types, - "next_batch": await sync_result.next_batch.to_string(self.store), - } + + response: dict = defaultdict(dict) + response["next_batch"] = await sync_result.next_batch.to_string(self.store) + + if sync_result.account_data: + response["account_data"] = {"events": sync_result.account_data} + if sync_result.presence: + response["presence"] = SyncRestServlet.encode_presence( + sync_result.presence, time_now + ) + + if sync_result.to_device: + response["to_device"] = {"events": sync_result.to_device} + + if sync_result.device_lists.changed: + response["device_lists"]["changed"] = list(sync_result.device_lists.changed) + if sync_result.device_lists.left: + response["device_lists"]["left"] = list(sync_result.device_lists.left) + + if sync_result.device_one_time_keys_count: + response[ + "device_one_time_keys_count" + ] = sync_result.device_one_time_keys_count + if sync_result.device_unused_fallback_key_types: + response[ + "org.matrix.msc2732.device_unused_fallback_key_types" + ] = sync_result.device_unused_fallback_key_types + + if joined: + response["rooms"][Membership.JOIN] = joined + if invited: + response["rooms"][Membership.INVITE] = invited + if knocked: + response["rooms"][Membership.KNOCK] = knocked + if archived: + response["rooms"][Membership.LEAVE] = archived + + if sync_result.groups.join: + response["groups"][Membership.JOIN] = sync_result.groups.join + if sync_result.groups.invite: + response["groups"][Membership.INVITE] = sync_result.groups.invite + if sync_result.groups.leave: + response["groups"][Membership.LEAVE] = sync_result.groups.leave + + return response @staticmethod def encode_presence(events, time_now): diff --git a/synapse/storage/database.py b/synapse/storage/database.py
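The rewritten sync serializer leans on `defaultdict(dict)` so nested sections such as `rooms`, `groups` and `device_lists` spring into existence only when a sub-key is actually assigned; anything left empty is simply absent from the response. One caveat with this pattern is that merely reading a missing key on a defaultdict also creates it, which is why the code only ever assigns. A runnable illustration (`"join"` is the value of `Membership.JOIN`):

    from collections import defaultdict

    response: dict = defaultdict(dict)
    response["next_batch"] = "s72594_4483_2"

    joined = {"!room:example.org": {"timeline": {"events": []}}}
    if joined:
        response["rooms"]["join"] = joined  # "rooms" created on first write

    invited: dict = {}
    if invited:  # falsy: the "invite" key is never added
        response["rooms"]["invite"] = invited

    assert "invite" not in response["rooms"]
    assert set(response) == {"next_batch", "rooms"}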
index d470cdacde..33c42cf95a 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py
@@ -111,7 +111,7 @@ def make_conn( db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine, default_txn_name: str, -) -> Connection: +) -> "LoggingDatabaseConnection": """Make a new connection to the database and return it. Returns: diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index f23f8c6ecf..c4474df975 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py
@@ -16,6 +16,8 @@ import logging from queue import Empty, PriorityQueue from typing import Collection, Dict, Iterable, List, Optional, Set, Tuple +from prometheus_client import Gauge + from synapse.api.constants import MAX_DEPTH from synapse.api.errors import StoreError from synapse.api.room_versions import RoomVersion @@ -32,6 +34,16 @@ from synapse.util.caches.descriptors import cached from synapse.util.caches.lrucache import LruCache from synapse.util.iterutils import batch_iter +oldest_pdu_in_federation_staging = Gauge( + "synapse_federation_server_oldest_inbound_pdu_in_staging", + "The age in seconds since we received the oldest pdu in the federation staging area", +) + +number_pdus_in_federation_queue = Gauge( + "synapse_federation_server_number_inbound_pdu_in_staging", + "The total number of events in the inbound federation staging", +) + logger = logging.getLogger(__name__) @@ -54,6 +66,8 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas 500000, "_event_auth_cache", size_callback=len ) # type: LruCache[str, List[Tuple[str, int]]] + self._clock.looping_call(self._get_stats_for_federation_staging, 30 * 1000) + async def get_auth_chain( self, room_id: str, event_ids: Collection[str], include_given: bool = False ) -> List[EventBase]: @@ -1075,16 +1089,62 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas self, origin: str, event_id: str, - ) -> None: - """Remove the given event from the staging area""" - await self.db_pool.simple_delete( - table="federation_inbound_events_staging", - keyvalues={ - "origin": origin, - "event_id": event_id, - }, - desc="remove_received_event_from_staging", - ) + ) -> Optional[int]: + """Remove the given event from the staging area. + + Returns: + The received_ts of the row that was deleted, if any. + """ + if self.db_pool.engine.supports_returning: + + def _remove_received_event_from_staging_txn(txn): + sql = """ + DELETE FROM federation_inbound_events_staging + WHERE origin = ? AND event_id = ? 
+ RETURNING received_ts + """ + + txn.execute(sql, (origin, event_id)) + return txn.fetchone() + + row = await self.db_pool.runInteraction( + "remove_received_event_from_staging", + _remove_received_event_from_staging_txn, + db_autocommit=True, + ) + if row is None: + return None + + return row[0] + + else: + + def _remove_received_event_from_staging_txn(txn): + received_ts = self.db_pool.simple_select_one_onecol_txn( + txn, + table="federation_inbound_events_staging", + keyvalues={ + "origin": origin, + "event_id": event_id, + }, + retcol="received_ts", + allow_none=True, + ) + self.db_pool.simple_delete_txn( + txn, + table="federation_inbound_events_staging", + keyvalues={ + "origin": origin, + "event_id": event_id, + }, + ) + + return received_ts + + return await self.db_pool.runInteraction( + "remove_received_event_from_staging", + _remove_received_event_from_staging_txn, + ) async def get_next_staged_event_id_for_room( self, @@ -1147,6 +1207,40 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas return origin, event + async def get_all_rooms_with_staged_incoming_events(self) -> List[str]: + """Get the room IDs of all events currently staged.""" + return await self.db_pool.simple_select_onecol( + table="federation_inbound_events_staging", + keyvalues={}, + retcol="DISTINCT room_id", + desc="get_all_rooms_with_staged_incoming_events", + ) + + @wrap_as_background_process("_get_stats_for_federation_staging") + async def _get_stats_for_federation_staging(self): + """Update the prometheus metrics for the inbound federation staging area.""" + + def _get_stats_for_federation_staging_txn(txn): + txn.execute( + "SELECT coalesce(count(*), 0) FROM federation_inbound_events_staging" + ) + (count,) = txn.fetchone() + + txn.execute( + "SELECT coalesce(min(received_ts), 0) FROM federation_inbound_events_staging" + ) + + (age,) = txn.fetchone() + + return count, age + + count, age = await self.db_pool.runInteraction( + "_get_stats_for_federation_staging", _get_stats_for_federation_staging_txn + ) + + number_pdus_in_federation_queue.set(count) + oldest_pdu_in_federation_staging.set(age) + class EventFederationStore(EventFederationWorkerStore): """Responsible for storing and serving up the various graphs associated diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
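`DELETE ... RETURNING` lets the PostgreSQL path remove a staged event and learn its `received_ts` in a single statement (hence the new `supports_returning` engine flag near the end of this diff); other engines fall back to a select-then-delete inside one transaction, as in this runnable sqlite3 reduction. Note also that, as written, the oldest-PDU gauge is set to the raw `min(received_ts)` value (or 0 when the table is empty) even though its help text describes an age in seconds; producing a true age would need something like `(now_ms - received_ts) / 1000` at collection time.

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute(
        "CREATE TABLE federation_inbound_events_staging "
        "(origin TEXT, event_id TEXT, received_ts BIGINT)"
    )
    conn.execute(
        "INSERT INTO federation_inbound_events_staging VALUES (?, ?, ?)",
        ("other.example", "$event1", 1_625_000_000_000),
    )

    def remove_received_event_from_staging(txn, origin, event_id):
        # Fallback for engines without RETURNING: fetch the timestamp
        # first, then delete, both inside the same transaction.
        row = txn.execute(
            "SELECT received_ts FROM federation_inbound_events_staging "
            "WHERE origin = ? AND event_id = ?",
            (origin, event_id),
        ).fetchone()
        txn.execute(
            "DELETE FROM federation_inbound_events_staging "
            "WHERE origin = ? AND event_id = ?",
            (origin, event_id),
        )
        return row[0] if row else None

    print(remove_received_event_from_staging(conn, "other.example", "$event1"))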
index cbe4be1437..29f33bac55 100644 --- a/synapse/storage/databases/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -29,6 +29,34 @@ from synapse.types import JsonDict logger = logging.getLogger(__name__) +_REPLACE_STREAM_ORDERING_SQL_COMMANDS = ( + # there should be no leftover rows without a stream_ordering2, but just in case... + "UPDATE events SET stream_ordering2 = stream_ordering WHERE stream_ordering2 IS NULL", + # now we can drop the rule and switch the columns + "DROP RULE populate_stream_ordering2 ON events", + "ALTER TABLE events DROP COLUMN stream_ordering", + "ALTER TABLE events RENAME COLUMN stream_ordering2 TO stream_ordering", + # ... and finally, rename the indexes into place for consistency with sqlite + "ALTER INDEX event_contains_url_index2 RENAME TO event_contains_url_index", + "ALTER INDEX events_order_room2 RENAME TO events_order_room", + "ALTER INDEX events_room_stream2 RENAME TO events_room_stream", + "ALTER INDEX events_ts2 RENAME TO events_ts", +) + + +class _BackgroundUpdates: + EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts" + EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url" + DELETE_SOFT_FAILED_EXTREMITIES = "delete_soft_failed_extremities" + POPULATE_STREAM_ORDERING2 = "populate_stream_ordering2" + INDEX_STREAM_ORDERING2 = "index_stream_ordering2" + INDEX_STREAM_ORDERING2_CONTAINS_URL = "index_stream_ordering2_contains_url" + INDEX_STREAM_ORDERING2_ROOM_ORDER = "index_stream_ordering2_room_order" + INDEX_STREAM_ORDERING2_ROOM_STREAM = "index_stream_ordering2_room_stream" + INDEX_STREAM_ORDERING2_TS = "index_stream_ordering2_ts" + REPLACE_STREAM_ORDERING_COLUMN = "replace_stream_ordering_column" + + @attr.s(slots=True, frozen=True) class _CalculateChainCover: """Return value for _calculate_chain_cover_txn.""" @@ -48,19 +76,15 @@ class _CalculateChainCover: class EventsBackgroundUpdatesStore(SQLBaseStore): - - EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts" - EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url" - DELETE_SOFT_FAILED_EXTREMITIES = "delete_soft_failed_extremities" - def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_update_handler( - self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts + _BackgroundUpdates.EVENT_ORIGIN_SERVER_TS_NAME, + self._background_reindex_origin_server_ts, ) self.db_pool.updates.register_background_update_handler( - self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, + _BackgroundUpdates.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, self._background_reindex_fields_sender, ) @@ -85,7 +109,8 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): ) self.db_pool.updates.register_background_update_handler( - self.DELETE_SOFT_FAILED_EXTREMITIES, self._cleanup_extremities_bg_update + _BackgroundUpdates.DELETE_SOFT_FAILED_EXTREMITIES, + self._cleanup_extremities_bg_update, ) self.db_pool.updates.register_background_update_handler( @@ -139,6 +164,59 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): self._purged_chain_cover_index, ) + ################################################################################ + + # bg updates for replacing stream_ordering with a BIGINT + # (these only run on postgres.) 
+ + self.db_pool.updates.register_background_update_handler( + _BackgroundUpdates.POPULATE_STREAM_ORDERING2, + self._background_populate_stream_ordering2, + ) + # CREATE UNIQUE INDEX events_stream_ordering ON events(stream_ordering2); + self.db_pool.updates.register_background_index_update( + _BackgroundUpdates.INDEX_STREAM_ORDERING2, + index_name="events_stream_ordering", + table="events", + columns=["stream_ordering2"], + unique=True, + ) + # CREATE INDEX event_contains_url_index ON events(room_id, topological_ordering, stream_ordering) WHERE contains_url = true AND outlier = false; + self.db_pool.updates.register_background_index_update( + _BackgroundUpdates.INDEX_STREAM_ORDERING2_CONTAINS_URL, + index_name="event_contains_url_index2", + table="events", + columns=["room_id", "topological_ordering", "stream_ordering2"], + where_clause="contains_url = true AND outlier = false", + ) + # CREATE INDEX events_order_room ON events(room_id, topological_ordering, stream_ordering); + self.db_pool.updates.register_background_index_update( + _BackgroundUpdates.INDEX_STREAM_ORDERING2_ROOM_ORDER, + index_name="events_order_room2", + table="events", + columns=["room_id", "topological_ordering", "stream_ordering2"], + ) + # CREATE INDEX events_room_stream ON events(room_id, stream_ordering); + self.db_pool.updates.register_background_index_update( + _BackgroundUpdates.INDEX_STREAM_ORDERING2_ROOM_STREAM, + index_name="events_room_stream2", + table="events", + columns=["room_id", "stream_ordering2"], + ) + # CREATE INDEX events_ts ON events(origin_server_ts, stream_ordering); + self.db_pool.updates.register_background_index_update( + _BackgroundUpdates.INDEX_STREAM_ORDERING2_TS, + index_name="events_ts2", + table="events", + columns=["origin_server_ts", "stream_ordering2"], + ) + self.db_pool.updates.register_background_update_handler( + _BackgroundUpdates.REPLACE_STREAM_ORDERING_COLUMN, + self._background_replace_stream_ordering_column, + ) + + ################################################################################ + async def _background_reindex_fields_sender(self, progress, batch_size): target_min_stream_id = progress["target_min_stream_id_inclusive"] max_stream_id = progress["max_stream_id_exclusive"] @@ -190,18 +268,18 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): } self.db_pool.updates._background_update_progress_txn( - txn, self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, progress + txn, _BackgroundUpdates.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, progress ) return len(rows) result = await self.db_pool.runInteraction( - self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, reindex_txn + _BackgroundUpdates.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, reindex_txn ) if not result: await self.db_pool.updates._end_background_update( - self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME + _BackgroundUpdates.EVENT_FIELDS_SENDER_URL_UPDATE_NAME ) return result @@ -264,18 +342,18 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): } self.db_pool.updates._background_update_progress_txn( - txn, self.EVENT_ORIGIN_SERVER_TS_NAME, progress + txn, _BackgroundUpdates.EVENT_ORIGIN_SERVER_TS_NAME, progress ) return len(rows_to_update) result = await self.db_pool.runInteraction( - self.EVENT_ORIGIN_SERVER_TS_NAME, reindex_search_txn + _BackgroundUpdates.EVENT_ORIGIN_SERVER_TS_NAME, reindex_search_txn ) if not result: await self.db_pool.updates._end_background_update( - self.EVENT_ORIGIN_SERVER_TS_NAME + _BackgroundUpdates.EVENT_ORIGIN_SERVER_TS_NAME ) return result @@ -454,7 +532,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): if not 
num_handled: await self.db_pool.updates._end_background_update( - self.DELETE_SOFT_FAILED_EXTREMITIES + _BackgroundUpdates.DELETE_SOFT_FAILED_EXTREMITIES ) def _drop_table_txn(txn): @@ -1009,3 +1087,81 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): await self.db_pool.updates._end_background_update("purged_chain_cover") return result + + async def _background_populate_stream_ordering2( + self, progress: JsonDict, batch_size: int + ) -> int: + """Populate events.stream_ordering2, then replace stream_ordering + + This is to deal with the fact that stream_ordering was initially created as a + 32-bit integer field. + """ + batch_size = max(batch_size, 1) + + def process(txn: Cursor) -> int: + last_stream = progress.get("last_stream", -(1 << 31)) + txn.execute( + """ + UPDATE events SET stream_ordering2=stream_ordering + WHERE stream_ordering IN ( + SELECT stream_ordering FROM events WHERE stream_ordering > ? + ORDER BY stream_ordering LIMIT ? + ) + RETURNING stream_ordering; + """, + (last_stream, batch_size), + ) + row_count = txn.rowcount + if row_count == 0: + return 0 + last_stream = max(row[0] for row in txn) + logger.info("populated stream_ordering2 up to %i", last_stream) + + self.db_pool.updates._background_update_progress_txn( + txn, + _BackgroundUpdates.POPULATE_STREAM_ORDERING2, + {"last_stream": last_stream}, + ) + return row_count + + result = await self.db_pool.runInteraction( + "_background_populate_stream_ordering2", process + ) + + if result != 0: + return result + + await self.db_pool.updates._end_background_update( + _BackgroundUpdates.POPULATE_STREAM_ORDERING2 + ) + return 0 + + async def _background_replace_stream_ordering_column( + self, progress: JsonDict, batch_size: int + ) -> int: + """Drop the old 'stream_ordering' column and rename 'stream_ordering2' into its place.""" + + def process(txn: Cursor) -> None: + for sql in _REPLACE_STREAM_ORDERING_SQL_COMMANDS: + logger.info("completing stream_ordering migration: %s", sql) + txn.execute(sql) + + # ANALYZE the new column to build stats on it, to encourage PostgreSQL to use the + # indexes on it. + # We need to pass execute a dummy function to handle the txn's result otherwise + # it tries to call fetchall() on it and fails because there's no result to fetch. + await self.db_pool.execute( + "background_analyze_new_stream_ordering_column", + lambda txn: None, + "ANALYZE events(stream_ordering2)", + ) + + await self.db_pool.runInteraction( + "_background_replace_stream_ordering_column", process + ) + + await self.db_pool.updates._end_background_update( + _BackgroundUpdates.REPLACE_STREAM_ORDERING_COLUMN + ) + + return 0 diff --git a/synapse/storage/databases/main/lock.py b/synapse/storage/databases/main/lock.py
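The BIGINT migration is driven by the generic background-update machinery: each pass copies one batch of rows into `stream_ordering2` keyed on a `last_stream` cursor, and only when a pass touches zero rows does the destructive column swap run. A schematic of the batch step, with the database call stubbed out as a plain function (PostgreSQL only, since the loop relies on `UPDATE ... RETURNING`):

    from typing import Callable, List, Tuple

    def backfill_step(
        execute: Callable[[str, tuple], List[Tuple[int]]],
        last_stream: int,
        batch_size: int,
    ) -> int:
        rows = execute(
            """
            UPDATE events SET stream_ordering2 = stream_ordering
            WHERE stream_ordering IN (
                SELECT stream_ordering FROM events WHERE stream_ordering > ?
                ORDER BY stream_ordering LIMIT ?
            )
            RETURNING stream_ordering
            """,
            (last_stream, batch_size),
        )
        if not rows:
            return last_stream  # finished: safe to run the column swap
        # Persist the high-water mark so the next pass resumes after it.
        return max(row[0] for row in rows)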
index e76188328c..774861074c 100644 --- a/synapse/storage/databases/main/lock.py +++ b/synapse/storage/databases/main/lock.py
@@ -310,14 +310,25 @@ class Lock: _excinst: Optional[BaseException], _exctb: Optional[TracebackType], ) -> bool: + await self.release() + + return False + + async def release(self) -> None: + """Release the lock. + + This is automatically called when using the lock as a context manager. + """ + + if self._dropped: + return + if self._looping_call.running: self._looping_call.stop() await self._store._drop_lock(self._lock_name, self._lock_key, self._token) self._dropped = True - return False - def __del__(self) -> None: if not self._dropped: # We should not be dropped without the lock being released (unless diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
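Exposing `release()` separately from `__aexit__` means the lock can be dropped early, or from code paths that don't nest neatly inside a context manager, while the `_dropped` guard keeps a double release harmless. Both usage shapes, assuming the store's `try_acquire_lock` returns an `Optional[Lock]` (the lock and key names are illustrative):

    async def with_context_manager(store):
        lock = await store.try_acquire_lock("example_lock", "example_key")
        if lock is None:
            return  # another worker holds the lock
        async with lock:  # __aexit__ awaits release() for us
            ...

    async def with_explicit_release(store):
        lock = await store.try_acquire_lock("example_lock", "example_key")
        if lock is None:
            return
        try:
            ...  # do work, possibly handing the lock object around
        finally:
            await lock.release()  # idempotent thanks to the _dropped check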
index 77e2eb27db..b8bdfc721e 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py
@@ -53,6 +53,9 @@ class TokenLookupResult: valid_until_ms: The timestamp the token expires, if any. token_owner: The "owner" of the token. This is either the same as the user, or a server admin who is logged in as the user. + token_used: True if this token was used at least once in a request. + This field can be out of date since `get_user_by_access_token` is + cached. """ user_id = attr.ib(type=str) @@ -62,6 +65,7 @@ class TokenLookupResult: device_id = attr.ib(type=Optional[str], default=None) valid_until_ms = attr.ib(type=Optional[int], default=None) token_owner = attr.ib(type=str) + token_used = attr.ib(type=bool, default=False) # Make the token owner default to the user ID, which is the common case. @token_owner.default @@ -69,6 +73,29 @@ class TokenLookupResult: return self.user_id +@attr.s(frozen=True, slots=True) +class RefreshTokenLookupResult: + """Result of looking up a refresh token.""" + + user_id = attr.ib(type=str) + """The user this token belongs to.""" + + device_id = attr.ib(type=str) + """The device associated with this refresh token.""" + + token_id = attr.ib(type=int) + """The ID of this refresh token.""" + + next_token_id = attr.ib(type=Optional[int]) + """The ID of the refresh token which replaced this one.""" + + has_next_refresh_token_been_refreshed = attr.ib(type=bool) + """True if the next refresh token was used for another refresh.""" + + has_next_access_token_been_used = attr.ib(type=bool) + """True if the next access token was already used at least once.""" + + class RegistrationWorkerStore(CacheInvalidationWorkerStore): def __init__( self, @@ -521,7 +548,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): access_tokens.id as token_id, access_tokens.device_id, access_tokens.valid_until_ms, - access_tokens.user_id as token_owner + access_tokens.user_id as token_owner, + access_tokens.used as token_used FROM users INNER JOIN access_tokens on users.name = COALESCE(puppets_user_id, access_tokens.user_id) WHERE token = ? @@ -529,8 +557,15 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): txn.execute(sql, (token,)) rows = self.db_pool.cursor_to_dict(txn) + if rows: - return TokenLookupResult(**rows[0]) + row = rows[0] + + # This field is nullable, ensure it comes out as a boolean + if row["token_used"] is None: + row["token_used"] = False + + return TokenLookupResult(**row) return None @@ -1152,6 +1187,111 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): desc="update_access_token_last_validated", ) + @cached() + async def mark_access_token_as_used(self, token_id: int) -> None: + """ + Mark the access token as used, which invalidates the refresh token used + to obtain it. + + Because get_user_by_access_token is cached, this function might be + called multiple times for the same token, effectively doing unnecessary + SQL updates. Because updating the `used` field only goes one way (from + False to True) it is safe to cache this function as well to avoid this + issue. + + Args: + token_id: The ID of the access token to update. + Raises: + StoreError if there was a problem updating this. 
+ """ + await self.db_pool.simple_update_one( + "access_tokens", + {"id": token_id}, + {"used": True}, + desc="mark_access_token_as_used", + ) + + async def lookup_refresh_token( + self, token: str + ) -> Optional[RefreshTokenLookupResult]: + """Lookup a refresh token with hints about its validity.""" + + def _lookup_refresh_token_txn(txn) -> Optional[RefreshTokenLookupResult]: + txn.execute( + """ + SELECT + rt.id token_id, + rt.user_id, + rt.device_id, + rt.next_token_id, + (nrt.next_token_id IS NOT NULL) has_next_refresh_token_been_refreshed, + at.used has_next_access_token_been_used + FROM refresh_tokens rt + LEFT JOIN refresh_tokens nrt ON rt.next_token_id = nrt.id + LEFT JOIN access_tokens at ON at.refresh_token_id = nrt.id + WHERE rt.token = ? + """, + (token,), + ) + row = txn.fetchone() + + if row is None: + return None + + return RefreshTokenLookupResult( + token_id=row[0], + user_id=row[1], + device_id=row[2], + next_token_id=row[3], + has_next_refresh_token_been_refreshed=row[4], + # This column is nullable, ensure it's a boolean + has_next_access_token_been_used=(row[5] or False), + ) + + return await self.db_pool.runInteraction( + "lookup_refresh_token", _lookup_refresh_token_txn + ) + + async def replace_refresh_token(self, token_id: int, next_token_id: int) -> None: + """ + Set the successor of a refresh token, removing the existing successor + if any. + + Args: + token_id: ID of the refresh token to update. + next_token_id: ID of its successor. + """ + + def _replace_refresh_token_txn(txn) -> None: + # First check if there was an existing refresh token + old_next_token_id = self.db_pool.simple_select_one_onecol_txn( + txn, + "refresh_tokens", + {"id": token_id}, + "next_token_id", + allow_none=True, + ) + + self.db_pool.simple_update_one_txn( + txn, + "refresh_tokens", + {"id": token_id}, + {"next_token_id": next_token_id}, + ) + + # Delete the old "next" token if it exists. This should cascade and + # delete the associated access_token + if old_next_token_id is not None: + self.db_pool.simple_delete_one_txn( + txn, + "refresh_tokens", + {"id": old_next_token_id}, + ) + + await self.db_pool.runInteraction( + "replace_refresh_token", _replace_refresh_token_txn + ) + class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): def __init__( @@ -1343,6 +1483,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id") + self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id") async def add_access_token_to_user( self, @@ -1351,14 +1492,18 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): device_id: Optional[str], valid_until_ms: Optional[int], puppets_user_id: Optional[str] = None, + refresh_token_id: Optional[int] = None, ) -> int: """Adds an access token for the given user. Args: user_id: The user ID. token: The new access token to add. - device_id: ID of the device to associate with the access token + device_id: ID of the device to associate with the access token. valid_until_ms: when the token is valid until. None for no expiry. + puppets_user_id + refresh_token_id: ID of the refresh token generated alongside this + access token. Raises: StoreError if there was a problem adding this. 
         Returns:
@@ -1377,12 +1522,47 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
                 "valid_until_ms": valid_until_ms,
                 "puppets_user_id": puppets_user_id,
                 "last_validated": now,
+                "refresh_token_id": refresh_token_id,
+                "used": False,
             },
             desc="add_access_token_to_user",
         )

         return next_id

+    async def add_refresh_token_to_user(
+        self,
+        user_id: str,
+        token: str,
+        device_id: Optional[str],
+    ) -> int:
+        """Adds a refresh token for the given user.
+
+        Args:
+            user_id: The user ID.
+            token: The new refresh token to add.
+            device_id: ID of the device to associate with the refresh token.
+        Raises:
+            StoreError if there was a problem adding this.
+        Returns:
+            The token ID
+        """
+        next_id = self._refresh_tokens_id_gen.get_next()
+
+        await self.db_pool.simple_insert(
+            "refresh_tokens",
+            {
+                "id": next_id,
+                "user_id": user_id,
+                "device_id": device_id,
+                "token": token,
+                "next_token_id": None,
+            },
+            desc="add_refresh_token_to_user",
+        )
+
+        return next_id
+
     def _set_device_for_access_token_txn(self, txn, token: str, device_id: str) -> str:
         old_device_id = self.db_pool.simple_select_one_onecol_txn(
             txn, "access_tokens", {"token": token}, "device_id"
@@ -1625,7 +1805,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
         device_id: Optional[str] = None,
     ) -> List[Tuple[str, int, Optional[str]]]:
         """
-        Invalidate access tokens belonging to a user
+        Invalidate access and refresh tokens belonging to a user

         Args:
             user_id: ID of user the tokens belong to
@@ -1645,7 +1825,13 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
             items = keyvalues.items()
             where_clause = " AND ".join(k + " = ?" for k, _ in items)
             values = [v for _, v in items]  # type: List[Union[str, int]]
+            # Conveniently, refresh_tokens and access_tokens both use the user_id and device_id fields. The
+            # only caveat is the `except_token_id` param, which is tricky to get right, so for now we reuse
+            # the same where clause and values; this path only seems to be hit by the "set password" handler.
+            refresh_where_clause = where_clause
+            refresh_values = values.copy()
             if except_token_id:
+                # TODO: support excluding a token ID for refresh tokens as well
                 where_clause += " AND id != ?"
                 values.append(except_token_id)

@@ -1663,6 +1849,11 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):

             txn.execute("DELETE FROM access_tokens WHERE %s" % where_clause, values)

+            txn.execute(
+                "DELETE FROM refresh_tokens WHERE %s" % refresh_where_clause,
+                refresh_values,
+            )
+
             return tokens_and_devices

         return await self.db_pool.runInteraction("user_delete_access_tokens", f)
@@ -1679,6 +1870,14 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):

         await self.db_pool.runInteraction("delete_access_token", f)

+    async def delete_refresh_token(self, refresh_token: str) -> None:
+        def f(txn):
+            self.db_pool.simple_delete_one_txn(
+                txn, table="refresh_tokens", keyvalues={"token": refresh_token}
+            )
+
+        await self.db_pool.runInteraction("delete_refresh_token", f)
+
     async def add_user_pending_deactivation(self, user_id: str) -> None:
         """
         Adds a user to the table of users who need to be parted from all the rooms they're
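To make the flow concrete, here is a minimal sketch of how an MSC2918-style token refresh could be assembled from the store methods added above. It is an illustration only, not the handler code from this changeset: `store` stands in for a `RegistrationStore` instance, and `generate_token` is a hypothetical helper.

    import secrets

    def generate_token() -> str:
        # Hypothetical helper: an opaque, unguessable token string.
        return secrets.token_urlsafe(32)

    async def refresh(store, refresh_token: str):
        """Exchange a refresh token for a new (access, refresh) token pair."""
        existing = await store.lookup_refresh_token(refresh_token)
        if existing is None:
            raise Exception("unknown refresh token")

        # A refresh token is only valid while its successor is still pending:
        # once the next access token has been used, or the next refresh token
        # has itself been refreshed, this token must be rejected.
        if (
            existing.has_next_refresh_token_been_refreshed
            or existing.has_next_access_token_been_used
        ):
            raise Exception("refresh token has been invalidated")

        new_refresh_token = generate_token()
        new_refresh_token_id = await store.add_refresh_token_to_user(
            user_id=existing.user_id,
            token=new_refresh_token,
            device_id=existing.device_id,
        )
        # Point the old token at its successor; any previous successor (and
        # its access token) is removed via the ON DELETE CASCADE in the schema.
        await store.replace_refresh_token(existing.token_id, new_refresh_token_id)

        new_access_token = generate_token()
        await store.add_access_token_to_user(
            user_id=existing.user_id,
            token=new_access_token,
            device_id=existing.device_id,
            valid_until_ms=None,
            refresh_token_id=new_refresh_token_id,
        )
        # When the new access token is first used in a request,
        # mark_access_token_as_used() flips its `used` flag, which is what
        # finally invalidates the refresh token that was just consumed.
        return new_access_token, new_refresh_token

diff --git a/synapse/storage/engines/_base.py b/synapse/storage/engines/_base.py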
index 1882bfd9cf..20cd63c330 100644
--- a/synapse/storage/engines/_base.py
+++ b/synapse/storage/engines/_base.py
@@ -49,6 +49,12 @@ class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta):
         """
         ...

+    @property
+    @abc.abstractmethod
+    def supports_returning(self) -> bool:
+        """Do we support the `RETURNING` clause in insert/update/delete?"""
+        ...
+
     @abc.abstractmethod
     def check_database(
         self, db_conn: ConnectionType, allow_outdated_version: bool = False
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index 21411c5fea..30f948a0f7 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -133,6 +133,11 @@ class PostgresEngine(BaseDatabaseEngine):
         """Do we support using `a = ANY(?)` and passing a list"""
         return True

+    @property
+    def supports_returning(self) -> bool:
+        """Do we support the `RETURNING` clause in insert/update/delete?"""
+        return True
+
     def is_deadlock(self, error):
         if isinstance(error, self.module.DatabaseError):
             # https://www.postgresql.org/docs/current/static/errcodes-appendix.html
diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py
index 5fe1b205e1..70d17d4f2c 100644
--- a/synapse/storage/engines/sqlite.py
+++ b/synapse/storage/engines/sqlite.py
@@ -60,6 +60,11 @@ class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]):
         """Do we support using `a = ANY(?)` and passing a list"""
         return False

+    @property
+    def supports_returning(self) -> bool:
+        """Do we support the `RETURNING` clause in insert/update/delete?"""
+        return self.module.sqlite_version_info >= (3, 35, 0)
+
     def check_database(self, db_conn, allow_outdated_version: bool = False):
         if not allow_outdated_version:
             version = self.module.sqlite_version_info
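To illustrate how a storage function might branch on this new capability flag, the sketch below deletes a row and retrieves its ID in one round trip where `RETURNING` is available, falling back to a SELECT-then-DELETE otherwise. The function itself is hypothetical and not part of this changeset; only the `supports_returning` property comes from the engines above.

    def delete_returning_id(txn, database_engine, token: str):
        """Delete an access token row and return its ID, portably."""
        if database_engine.supports_returning:
            # One round trip: PostgreSQL, or SQLite >= 3.35.0.
            txn.execute(
                "DELETE FROM access_tokens WHERE token = ? RETURNING id", (token,)
            )
            row = txn.fetchone()
        else:
            # Fallback for older SQLite: look the row up first, then delete it.
            txn.execute("SELECT id FROM access_tokens WHERE token = ?", (token,))
            row = txn.fetchone()
            txn.execute("DELETE FROM access_tokens WHERE token = ?", (token,))
        return row[0] if row else None

diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py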
index d36ba1d773..0a53b73ccc 100644
--- a/synapse/storage/schema/__init__.py
+++ b/synapse/storage/schema/__init__.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-SCHEMA_VERSION = 59
+SCHEMA_VERSION = 60
 """Represents the expectations made by the codebase about the database schema

 This should be incremented whenever the codebase changes its requirements on the
diff --git a/synapse/storage/schema/main/delta/59/14refresh_tokens.sql b/synapse/storage/schema/main/delta/59/14refresh_tokens.sql
new file mode 100644
index 0000000000..9a6bce1e3e
--- /dev/null
+++ b/synapse/storage/schema/main/delta/59/14refresh_tokens.sql
@@ -0,0 +1,34 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Holds MSC2918 refresh tokens
+CREATE TABLE refresh_tokens (
+  id BIGINT PRIMARY KEY,
+  user_id TEXT NOT NULL,
+  device_id TEXT NOT NULL,
+  token TEXT NOT NULL,
+  -- When consumed, a new refresh token is generated, which is tracked by
+  -- this foreign key
+  next_token_id BIGINT REFERENCES refresh_tokens (id) ON DELETE CASCADE,
+  UNIQUE(token)
+);
+
+-- Add a reference to the refresh token generated alongside each access token
+ALTER TABLE "access_tokens"
+  ADD COLUMN refresh_token_id BIGINT REFERENCES refresh_tokens (id) ON DELETE CASCADE;
+
+-- Add a flag recording whether the token has already been used
+ALTER TABLE "access_tokens"
+  ADD COLUMN used BOOLEAN;
diff --git a/synapse/storage/schema/main/delta/60/01recreate_stream_ordering.sql.postgres b/synapse/storage/schema/main/delta/60/01recreate_stream_ordering.sql.postgres
new file mode 100644
index 0000000000..0edc9fe7a2
--- /dev/null
+++ b/synapse/storage/schema/main/delta/60/01recreate_stream_ordering.sql.postgres
@@ -0,0 +1,45 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- This migration handles the process of changing the type of `stream_ordering` to
+-- a BIGINT.
+--
+-- Note that this is only a problem on postgres as sqlite only has one "integer" type
+-- which can cope with values up to 2^63.
+
+-- First add a new column to contain the bigger stream_ordering
+ALTER TABLE events ADD COLUMN stream_ordering2 BIGINT;
+
+-- Create a rule which will populate it for new rows.
+CREATE OR REPLACE RULE "populate_stream_ordering2" AS
+  ON INSERT TO events
+  DO UPDATE events SET stream_ordering2=NEW.stream_ordering WHERE stream_ordering=NEW.stream_ordering;
+
+-- Start a bg process to populate it for old events
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+  (6001, 'populate_stream_ordering2', '{}');
+
+-- ... and some more to build indexes on it. These aren't really interdependent
+-- but the background_updates manager can only handle a single dependency per update.
+INSERT INTO background_updates (ordering, update_name, progress_json, depends_on) VALUES
+  (6001, 'index_stream_ordering2', '{}', 'populate_stream_ordering2'),
+  (6001, 'index_stream_ordering2_room_order', '{}', 'index_stream_ordering2'),
+  (6001, 'index_stream_ordering2_contains_url', '{}', 'index_stream_ordering2_room_order'),
+  (6001, 'index_stream_ordering2_room_stream', '{}', 'index_stream_ordering2_contains_url'),
+  (6001, 'index_stream_ordering2_ts', '{}', 'index_stream_ordering2_room_stream');
+
+-- ... and another to do the switcheroo
+INSERT INTO background_updates (ordering, update_name, progress_json, depends_on) VALUES
+  (6001, 'replace_stream_ordering_column', '{}', 'index_stream_ordering2_ts');
diff --git a/synapse/storage/schema/main/delta/60/02change_stream_ordering_columns.sql.postgres b/synapse/storage/schema/main/delta/60/02change_stream_ordering_columns.sql.postgres
new file mode 100644
index 0000000000..630c24fd9e
--- /dev/null
+++ b/synapse/storage/schema/main/delta/60/02change_stream_ordering_columns.sql.postgres
@@ -0,0 +1,30 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- This migration is closely related to '01recreate_stream_ordering.sql.postgres'.
+--
+-- It updates the other tables which use an INTEGER to refer to a stream ordering.
+-- These tables are all small enough that a re-create is tractable.
+ALTER TABLE pushers ALTER COLUMN last_stream_ordering SET DATA TYPE BIGINT;
+ALTER TABLE federation_stream_position ALTER COLUMN stream_id SET DATA TYPE BIGINT;
+
+-- These aren't actually event stream orderings, but they are numbers where two
+-- billion is a bit limiting; application_services_state is tiny, and I don't
+-- want to ever have to do this again.
+ALTER TABLE application_services_state ALTER COLUMN last_txn SET DATA TYPE BIGINT;
+ALTER TABLE application_services_state ALTER COLUMN read_receipt_stream_id SET DATA TYPE BIGINT;
+ALTER TABLE application_services_state ALTER COLUMN presence_stream_id SET DATA TYPE BIGINT;
+
+
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index d89e9d9b1d..4b9d0433ff 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -12,9 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import logging
 import threading
+import weakref
 from functools import wraps
 from typing import (
+    TYPE_CHECKING,
     Any,
     Callable,
     Collection,
@@ -31,10 +34,19 @@ from typing import (

 from typing_extensions import Literal

+from twisted.internet import reactor
+
 from synapse.config import cache as cache_config
-from synapse.util import caches
+from synapse.metrics.background_process_metrics import wrap_as_background_process
+from synapse.util import Clock, caches
 from synapse.util.caches import CacheMetric, register_cache
 from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
+from synapse.util.linked_list import ListNode
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)

 try:
     from pympler.asizeof import Asizer
@@ -82,19 +94,126 @@ def enumerate_leaves(node, depth):
             yield m


+P = TypeVar("P")
+
+
+class _TimedListNode(ListNode[P]):
+    """A `ListNode` that tracks last access time."""
+
+    __slots__ = ["last_access_ts_secs"]
+
+    def update_last_access(self, clock: Clock):
+        self.last_access_ts_secs = int(clock.time())
+
+
+# Whether to insert new cache entries into the global list. We only add to it
+# if time-based eviction is enabled.
+USE_GLOBAL_LIST = False
+
+# A linked list of all cache entries, allowing efficient time-based eviction.
+GLOBAL_ROOT = ListNode["_Node"].create_root_node()
+
+
+@wrap_as_background_process("LruCache._expire_old_entries")
+async def _expire_old_entries(clock: Clock, expiry_seconds: int):
+    """Walks the global cache list to find cache entries that haven't been
+    accessed in the given number of seconds.
+    """
+
+    now = int(clock.time())
+    node = GLOBAL_ROOT.prev_node
+    assert node is not None
+
+    i = 0
+
+    logger.debug("Searching for stale caches")
+
+    while node is not GLOBAL_ROOT:
+        # Only the root node isn't a `_TimedListNode`.
+        assert isinstance(node, _TimedListNode)
+
+        if node.last_access_ts_secs > now - expiry_seconds:
+            break
+
+        cache_entry = node.get_cache_entry()
+        next_node = node.prev_node
+
+        # The node should always have a reference to a cache entry and a valid
+        # `prev_node`, as we only drop them when we remove the node from the
+        # list.
+        assert next_node is not None
+        assert cache_entry is not None
+        cache_entry.drop_from_cache()
+
+        # If we do lots of work at once we yield to allow other stuff to happen.
+        if (i + 1) % 10000 == 0:
+            logger.debug("Waiting during drop")
+            await clock.sleep(0)
+            logger.debug("Waking during drop")
+
+        node = next_node
+
+        # If we've yielded then our current node may have been evicted, so we
+        # need to check that it's still valid.
+        if node.prev_node is None:
+            break
+
+        i += 1

+    logger.info("Dropped %d items from caches", i)
+
+
+def setup_expire_lru_cache_entries(hs: "HomeServer"):
+    """Start a background job that expires all cache entries if they have not
+    been accessed for the given number of seconds.
+    """
+    if not hs.config.caches.expiry_time_msec:
+        return
+
+    logger.info(
+        "Expiring LRU caches after %d seconds", hs.config.caches.expiry_time_msec / 1000
+    )
+
+    global USE_GLOBAL_LIST
+    USE_GLOBAL_LIST = True
+
+    clock = hs.get_clock()
+    clock.looping_call(
+        _expire_old_entries, 30 * 1000, clock, hs.config.caches.expiry_time_msec / 1000
+    )
+
+
 class _Node:
-    __slots__ = ["prev_node", "next_node", "key", "value", "callbacks", "memory"]
+    __slots__ = [
+        "_list_node",
+        "_global_list_node",
+        "_cache",
+        "key",
+        "value",
+        "callbacks",
+        "memory",
+    ]

     def __init__(
         self,
-        prev_node,
-        next_node,
+        root: "ListNode[_Node]",
         key,
         value,
+        cache: "weakref.ReferenceType[LruCache]",
+        clock: Clock,
         callbacks: Collection[Callable[[], None]] = (),
     ):
-        self.prev_node = prev_node
-        self.next_node = next_node
+        self._list_node = ListNode.insert_after(self, root)
+        self._global_list_node = None
+        if USE_GLOBAL_LIST:
+            self._global_list_node = _TimedListNode.insert_after(self, GLOBAL_ROOT)
+            self._global_list_node.update_last_access(clock)
+
+        # We store a weak reference to the cache object so that this _Node can
+        # remove itself from the cache. If the cache is dropped we ensure we
+        # remove our entries in the lists.
+        self._cache = cache
+
         self.key = key
         self.value = value
@@ -116,11 +235,16 @@ class _Node:
         self.memory = (
             _get_size_of(key)
             + _get_size_of(value)
+            + _get_size_of(self._list_node, recurse=False)
             + _get_size_of(self.callbacks, recurse=False)
             + _get_size_of(self, recurse=False)
         )
         self.memory += _get_size_of(self.memory, recurse=False)

+        if self._global_list_node:
+            self.memory += _get_size_of(self._global_list_node, recurse=False)
+            self.memory += _get_size_of(self._global_list_node.last_access_ts_secs)
+
     def add_callbacks(self, callbacks: Collection[Callable[[], None]]) -> None:
         """Add to stored list of callbacks, removing duplicates."""

@@ -147,6 +271,32 @@ class _Node:

         self.callbacks = None

+    def drop_from_cache(self) -> None:
+        """Drop this node from the cache.
+
+        Ensures that the entry gets removed from the cache and that we get
+        removed from all lists.
+        """
+        cache = self._cache()
+        if not cache or not cache.pop(self.key, None):
+            # `cache.pop` should call `drop_from_lists()`, unless this Node had
+            # already been removed from the cache.
+            self.drop_from_lists()
+
+    def drop_from_lists(self) -> None:
+        """Remove this node from the cache lists."""
+        self._list_node.remove_from_list()
+
+        if self._global_list_node:
+            self._global_list_node.remove_from_list()
+
+    def move_to_front(self, clock: Clock, cache_list_root: ListNode) -> None:
+        """Moves this node to the front of all the lists it's in."""
+        self._list_node.move_after(cache_list_root)
+        if self._global_list_node:
+            self._global_list_node.move_after(GLOBAL_ROOT)
+            self._global_list_node.update_last_access(clock)
+

 class LruCache(Generic[KT, VT]):
     """
@@ -163,6 +313,7 @@ class LruCache(Generic[KT, VT]):
         size_callback: Optional[Callable] = None,
         metrics_collection_callback: Optional[Callable[[], None]] = None,
         apply_cache_factor_from_config: bool = True,
+        clock: Optional[Clock] = None,
     ):
         """
         Args:
@@ -188,6 +339,13 @@ class LruCache(Generic[KT, VT]):
             apply_cache_factor_from_config (bool): If true, `max_size` will be
                 multiplied by a cache factor derived from the homeserver config
         """
+        # Default `clock` to something sensible. Note that we rename it to
+        # `real_clock` so that mypy doesn't think it's still `Optional`.
+        if clock is None:
+            real_clock = Clock(reactor)
+        else:
+            real_clock = clock
+
         cache = cache_type()
         self.cache = cache  # Used for introspection.
         self.apply_cache_factor_from_config = apply_cache_factor_from_config
@@ -219,17 +377,31 @@ class LruCache(Generic[KT, VT]):
         # this is exposed for access from outside this class
         self.metrics = metrics

-        list_root = _Node(None, None, None, None)
-        list_root.next_node = list_root
-        list_root.prev_node = list_root
+        # We create a single weakref to self here so that we don't need to keep
+        # creating more each time we create a `_Node`.
+        weak_ref_to_self = weakref.ref(self)
+
+        list_root = ListNode[_Node].create_root_node()

         lock = threading.Lock()

         def evict():
             while cache_len() > self.max_size:
+                # Get the last node in the list (i.e. the oldest node).
                 todelete = list_root.prev_node
-                evicted_len = delete_node(todelete)
-                cache.pop(todelete.key, None)
+
+                # The list root should always have a valid `prev_node` if the
+                # cache is not empty.
+                assert todelete is not None
+
+                # The node should always have a reference to a cache entry, as
+                # we only drop the cache entry when we remove the node from the
+                # list.
+                node = todelete.get_cache_entry()
+                assert node is not None
+
+                evicted_len = delete_node(node)
+                cache.pop(node.key, None)
                 if metrics:
                     metrics.inc_evictions(evicted_len)

@@ -255,11 +427,7 @@ class LruCache(Generic[KT, VT]):
         self.len = synchronized(cache_len)

         def add_node(key, value, callbacks: Collection[Callable[[], None]] = ()):
-            prev_node = list_root
-            next_node = prev_node.next_node
-            node = _Node(prev_node, next_node, key, value, callbacks)
-            prev_node.next_node = node
-            next_node.prev_node = node
+            node = _Node(list_root, key, value, weak_ref_to_self, real_clock, callbacks)
             cache[key] = node

             if size_callback:
@@ -268,23 +436,11 @@ class LruCache(Generic[KT, VT]):
             if caches.TRACK_MEMORY_USAGE and metrics:
                 metrics.inc_memory_usage(node.memory)

-        def move_node_to_front(node):
-            prev_node = node.prev_node
-            next_node = node.next_node
-            prev_node.next_node = next_node
-            next_node.prev_node = prev_node
-            prev_node = list_root
-            next_node = prev_node.next_node
-            node.prev_node = prev_node
-            node.next_node = next_node
-            prev_node.next_node = node
-            next_node.prev_node = node
-
-        def delete_node(node):
-            prev_node = node.prev_node
-            next_node = node.next_node
-            prev_node.next_node = next_node
-            next_node.prev_node = prev_node
+        def move_node_to_front(node: _Node):
+            node.move_to_front(real_clock, list_root)
+
+        def delete_node(node: _Node) -> int:
+            node.drop_from_lists()

             deleted_len = 1
             if size_callback:
@@ -411,10 +567,13 @@ class LruCache(Generic[KT, VT]):
         @synchronized
         def cache_clear() -> None:
-            list_root.next_node = list_root
-            list_root.prev_node = list_root
             for node in cache.values():
                 node.run_and_clear_callbacks()
+                node.drop_from_lists()
+
+            assert list_root.next_node == list_root
+            assert list_root.prev_node == list_root
+
             cache.clear()
             if size_callback:
                 cached_cache_len[0] = 0
@@ -484,3 +643,11 @@ class LruCache(Generic[KT, VT]):
                 self._on_resize()
             return True
         return False
+
+    def __del__(self) -> None:
+        # We're about to be deleted, so we make sure to clear up all the nodes
+        # and run callbacks, etc.
+        #
+        # This happens e.g. in the sync code where we have an expiring cache of
+        # lru caches.
+        self.clear()
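For context, a rough sketch of how the new time-based expiry is wired together, using `setup_expire_lru_cache_entries` and the `clock` parameter added above. It assumes a running `hs: HomeServer` with `caches.expiry_time_msec` configured; `get`/`set` are the cache's standard accessors.

    from synapse.util.caches.lrucache import LruCache, setup_expire_lru_cache_entries

    # At startup: if `caches.expiry_time_msec` is set, this flips on the global
    # list and schedules _expire_old_entries() to run every 30 seconds.
    setup_expire_lru_cache_entries(hs)

    # Entries added after setup are enrolled in the global list, so an entry
    # that is neither read nor written within the expiry window gets dropped
    # even if its cache is nowhere near max_size.
    cache: LruCache[str, int] = LruCache(max_size=1000, clock=hs.get_clock())
    cache.set("key", 1)
    value = cache.get("key", None)  # refreshes the entry's last-access time

diff --git a/synapse/util/linked_list.py b/synapse/util/linked_list.py
new file mode 100644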
index 0000000000..a456b136f0
--- /dev/null
+++ b/synapse/util/linked_list.py
@@ -0,0 +1,150 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""A circular doubly linked list implementation.
+"""
+
+import threading
+from typing import Generic, Optional, Type, TypeVar
+
+P = TypeVar("P")
+LN = TypeVar("LN", bound="ListNode")
+
+
+class ListNode(Generic[P]):
+    """A node in a circular doubly linked list, with an (optional) reference to
+    a cache entry.
+
+    The reference should only be `None` for the root node or if the node has
+    been removed from the list.
+    """
+
+    # A lock to protect mutating the list prev/next pointers.
+    _LOCK = threading.Lock()
+
+    # We don't use attrs here as in py3.6 you can't have `attr.s(slots=True)`
+    # and inherit from `Generic` for some reason
+    __slots__ = [
+        "cache_entry",
+        "prev_node",
+        "next_node",
+    ]
+
+    def __init__(self, cache_entry: Optional[P] = None) -> None:
+        self.cache_entry = cache_entry
+        self.prev_node: Optional[ListNode[P]] = None
+        self.next_node: Optional[ListNode[P]] = None
+
+    @classmethod
+    def create_root_node(cls: Type["ListNode[P]"]) -> "ListNode[P]":
+        """Create a new linked list by creating a "root" node, which is a node
+        that has prev_node/next_node pointing to itself and no associated cache
+        entry.
+        """
+        root = cls()
+        root.prev_node = root
+        root.next_node = root
+        return root
+
+    @classmethod
+    def insert_after(
+        cls: Type[LN],
+        cache_entry: P,
+        node: "ListNode[P]",
+    ) -> LN:
+        """Create a new list node that is placed after the given node.
+
+        Args:
+            cache_entry: The associated cache entry.
+            node: The existing node in the list to insert the new entry after.
+        """
+        new_node = cls(cache_entry)
+        with cls._LOCK:
+            new_node._refs_insert_after(node)
+        return new_node
+
+    def remove_from_list(self):
+        """Remove this node from the list."""
+        with self._LOCK:
+            self._refs_remove_node_from_list()
+
+        # We drop the reference to the cache entry to break the reference cycle
+        # between the list node and cache entry, allowing the two to be dropped
+        # immediately rather than at the next GC.
+        self.cache_entry = None
+
+    def move_after(self, node: "ListNode"):
+        """Move this node from its current location in the list to after the
+        given node.
+        """
+        with self._LOCK:
+            # We assert that both this node and the target node are still
+            # "alive".
+            assert self.prev_node
+            assert self.next_node
+            assert node.prev_node
+            assert node.next_node
+
+            assert self is not node
+
+            # Remove self from the list
+            self._refs_remove_node_from_list()
+
+            # Insert self back into the list, after target node
+            self._refs_insert_after(node)
+
+    def _refs_remove_node_from_list(self):
+        """Internal method to *just* remove the node from the list, without
+        e.g. clearing out the cache entry.
+        """
+        if self.prev_node is None or self.next_node is None:
+            # We've already been removed from the list.
+            return
+
+        prev_node = self.prev_node
+        next_node = self.next_node
+
+        prev_node.next_node = next_node
+        next_node.prev_node = prev_node
+
+        # We set these to None so that we don't get circular references,
+        # allowing us to be dropped without having to go via the GC.
+        self.prev_node = None
+        self.next_node = None
+
+    def _refs_insert_after(self, node: "ListNode"):
+        """Internal method to insert the node after the given node."""
+
+        # This method should only be called when we're not already in the list.
+        assert self.prev_node is None
+        assert self.next_node is None
+
+        # We expect the given node to be in the list and thus have valid
+        # prev/next refs.
+        assert node.next_node
+        assert node.prev_node
+
+        prev_node = node
+        next_node = node.next_node
+
+        self.prev_node = prev_node
+        self.next_node = next_node
+
+        prev_node.next_node = self
+        next_node.prev_node = self
+
+    def get_cache_entry(self) -> Optional[P]:
+        """Get the cache entry; returns None if this is the root node (i.e.
+        cache_entry is None) or if the entry has been dropped.
+        """
+        return self.cache_entry
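Finally, a small self-contained sketch of how this list is meant to be driven (mirroring what `LruCache` does above): create a root node, insert entries at the front, promote an entry on access, and unlink it on eviction.

    from synapse.util.linked_list import ListNode

    # The root node anchors the circular list; it carries no cache entry.
    root = ListNode[str].create_root_node()

    # New entries are inserted immediately after the root (the "front").
    first = ListNode.insert_after("first entry", root)
    second = ListNode.insert_after("second entry", root)

    # The least-recently inserted entry now sits at the back of the list.
    assert root.prev_node is first

    # On access an entry is moved back to the front...
    first.move_after(root)
    assert root.next_node is first

    # ...and on eviction it is unlinked, which also drops its reference to the
    # cache entry so the two can be freed without waiting for the GC.
    second.remove_from_list()
    assert second.get_cache_entry() is None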