diff --git a/synapse/__init__.py b/synapse/__init__.py
index 1bd03462ac..5ecce24eee 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -47,7 +47,7 @@ try:
except ImportError:
pass
-__version__ = "1.37.1"
+__version__ = "1.38.0"
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
# We import here so that we don't have to install a bunch of deps when
diff --git a/synapse/_scripts/review_recent_signups.py b/synapse/_scripts/review_recent_signups.py
new file mode 100644
index 0000000000..01dc0c4237
--- /dev/null
+++ b/synapse/_scripts/review_recent_signups.py
@@ -0,0 +1,175 @@
+#!/usr/bin/env python
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import sys
+import time
+from datetime import datetime
+from typing import List
+
+import attr
+
+from synapse.config._base import RootConfig, find_config_files, read_config_files
+from synapse.config.database import DatabaseConfig
+from synapse.storage.database import DatabasePool, LoggingTransaction, make_conn
+from synapse.storage.engines import create_engine
+
+
+class ReviewConfig(RootConfig):
+ "A config class that just pulls out the database config"
+ config_classes = [DatabaseConfig]
+
+
+@attr.s(auto_attribs=True)
+class UserInfo:
+ user_id: str
+ creation_ts: int
+ emails: List[str] = attr.Factory(list)
+ private_rooms: List[str] = attr.Factory(list)
+ public_rooms: List[str] = attr.Factory(list)
+ ips: List[str] = attr.Factory(list)
+
+
+def get_recent_users(txn: LoggingTransaction, since_ms: int) -> List[UserInfo]:
+ """Fetches recently registered users and some info on them."""
+
+ sql = """
+ SELECT name, creation_ts FROM users
+ WHERE
+ ? <= creation_ts
+ AND deactivated = 0
+ """
+
+ txn.execute(sql, (since_ms / 1000,))
+
+ user_infos = [UserInfo(user_id, creation_ts) for user_id, creation_ts in txn]
+
+ for user_info in user_infos:
+ user_info.emails = DatabasePool.simple_select_onecol_txn(
+ txn,
+ table="user_threepids",
+ keyvalues={"user_id": user_info.user_id, "medium": "email"},
+ retcol="address",
+ )
+
+ sql = """
+ SELECT room_id, canonical_alias, name, join_rules
+ FROM local_current_membership
+ INNER JOIN room_stats_state USING (room_id)
+ WHERE user_id = ? AND membership = 'join'
+ """
+
+ txn.execute(sql, (user_info.user_id,))
+ for room_id, canonical_alias, name, join_rules in txn:
+ if join_rules == "public":
+ user_info.public_rooms.append(canonical_alias or name or room_id)
+ else:
+ user_info.private_rooms.append(canonical_alias or name or room_id)
+
+ user_info.ips = DatabasePool.simple_select_onecol_txn(
+ txn,
+ table="user_ips",
+ keyvalues={"user_id": user_info.user_id},
+ retcol="ip",
+ )
+
+ return user_infos
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "-c",
+ "--config-path",
+ action="append",
+ metavar="CONFIG_FILE",
+ help="The config files for Synapse.",
+ required=True,
+ )
+ parser.add_argument(
+ "-s",
+ "--since",
+ metavar="duration",
+ help="Specify how far back to review user registrations for, defaults to 7d (i.e. 7 days).",
+ default="7d",
+ )
+ parser.add_argument(
+ "-e",
+ "--exclude-emails",
+ action="store_true",
+ help="Exclude users that have validated email addresses",
+ )
+ parser.add_argument(
+ "-u",
+ "--only-users",
+ action="store_true",
+ help="Only print user IDs that match.",
+ )
+
+ config = ReviewConfig()
+
+ config_args = parser.parse_args(sys.argv[1:])
+ config_files = find_config_files(search_paths=config_args.config_path)
+ config_dict = read_config_files(config_files)
+ config.parse_config_dict(
+ config_dict,
+ )
+
+ since_ms = time.time() * 1000 - config.parse_duration(config_args.since)
+ exclude_users_with_email = config_args.exclude_emails
+ include_context = not config_args.only_users
+
+ for database_config in config.database.databases:
+ if "main" in database_config.databases:
+ break
+
+ engine = create_engine(database_config.config)
+
+ with make_conn(database_config, engine, "review_recent_signups") as db_conn:
+ user_infos = get_recent_users(db_conn.cursor(), since_ms)
+
+ for user_info in user_infos:
+ if exclude_users_with_email and user_info.emails:
+ continue
+
+ if include_context:
+ print_public_rooms = ""
+ if user_info.public_rooms:
+ print_public_rooms = "(" + ", ".join(user_info.public_rooms[:3])
+
+ if len(user_info.public_rooms) > 3:
+ print_public_rooms += ", ..."
+
+ print_public_rooms += ")"
+
+ print("# Created:", datetime.fromtimestamp(user_info.creation_ts))
+ print("# Email:", ", ".join(user_info.emails) or "None")
+ print("# IPs:", ", ".join(user_info.ips))
+ print(
+ "# Number joined public rooms:",
+ len(user_info.public_rooms),
+ print_public_rooms,
+ )
+ print("# Number joined private rooms:", len(user_info.private_rooms))
+ print("#")
+
+ print(user_info.user_id)
+
+ if include_context:
+ print()
+
+
+if __name__ == "__main__":
+ main()
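
For reference, the `--since` value above is parsed by `Config.parse_duration`. A minimal sketch of the duration grammar it accepts (common suffixes assumed; treating bare numbers as milliseconds is an assumption — `synapse.config._base` is the authoritative implementation):

```python
# Hedged sketch of the duration grammar used by --since; Config.parse_duration
# in synapse.config._base is the authoritative implementation.
UNITS_MS = {"s": 1_000, "m": 60_000, "h": 3_600_000, "d": 86_400_000, "w": 604_800_000}

def parse_duration_ms(value: str) -> int:
    """Convert "7d" -> 604800000, "30m" -> 1800000, "1500" -> 1500 (assumed ms)."""
    if value[-1].isdigit():
        return int(value)
    return int(value[:-1]) * UNITS_MS[value[-1]]

assert parse_duration_ms("7d") == 7 * 24 * 60 * 60 * 1000
```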
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 0f7788b411..d26014ef4f 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Optional, Tuple
import pymacaroons
from netaddr import IPAddress
@@ -28,7 +28,6 @@ from synapse.api.errors import (
InvalidClientTokenError,
MissingClientTokenError,
)
-from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.appservice import ApplicationService
from synapse.events import EventBase
from synapse.http import get_request_user_agent
@@ -38,7 +37,6 @@ from synapse.storage.databases.main.registration import TokenLookupResult
from synapse.types import Requester, StateMap, UserID, create_requester
from synapse.util.caches.lrucache import LruCache
from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
-from synapse.util.metrics import Measure
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -46,15 +44,6 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-AuthEventTypes = (
- EventTypes.Create,
- EventTypes.Member,
- EventTypes.PowerLevels,
- EventTypes.JoinRules,
- EventTypes.RoomHistoryVisibility,
- EventTypes.ThirdPartyInvite,
-)
-
# guests always get this device id.
GUEST_DEVICE_ID = "guest_device"
@@ -65,9 +54,7 @@ class _InvalidMacaroonException(Exception):
class Auth:
"""
- FIXME: This class contains a mix of functions for authenticating users
- of our client-server API and authenticating events added to room graphs.
- The latter should be moved to synapse.handlers.event_auth.EventAuthHandler.
+ This class contains functions for authenticating users of our client-server API.
"""
def __init__(self, hs: "HomeServer"):
@@ -89,18 +76,6 @@ class Auth:
self._macaroon_secret_key = hs.config.macaroon_secret_key
self._force_tracing_for_users = hs.config.tracing.force_tracing_for_users
- async def check_from_context(
- self, room_version: str, event, context, do_sig_check=True
- ) -> None:
- auth_event_ids = event.auth_event_ids()
- auth_events_by_id = await self.store.get_events(auth_event_ids)
- auth_events = {(e.type, e.state_key): e for e in auth_events_by_id.values()}
-
- room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
- event_auth.check(
- room_version_obj, event, auth_events=auth_events, do_sig_check=do_sig_check
- )
-
async def check_user_in_room(
self,
room_id: str,
@@ -151,13 +126,6 @@ class Auth:
raise AuthError(403, "User %s not in room %s" % (user_id, room_id))
- async def check_host_in_room(self, room_id: str, host: str) -> bool:
- with Measure(self.clock, "check_host_in_room"):
- return await self.store.is_host_joined(room_id, host)
-
- def get_public_keys(self, invite_event: EventBase) -> List[Dict[str, Any]]:
- return event_auth.get_public_keys(invite_event)
-
async def get_user_by_req(
self,
request: SynapseRequest,
@@ -245,6 +213,11 @@ class Auth:
errcode=Codes.GUEST_ACCESS_FORBIDDEN,
)
+ # Mark the token as used. This is used to invalidate old refresh
+ # tokens after some time.
+ if not user_info.token_used and token_id is not None:
+ await self.store.mark_access_token_as_used(token_id)
+
requester = create_requester(
user_info.user_id,
token_id,
@@ -488,44 +461,6 @@ class Auth:
"""
return await self.store.is_server_admin(user)
- def compute_auth_events(
- self,
- event,
- current_state_ids: StateMap[str],
- for_verification: bool = False,
- ) -> List[str]:
- """Given an event and current state return the list of event IDs used
- to auth an event.
-
- If `for_verification` is False then only return auth events that
- should be added to the event's `auth_events`.
-
- Returns:
- List of event IDs.
- """
-
- if event.type == EventTypes.Create:
- return []
-
- # Currently we ignore the `for_verification` flag even though there are
- # some situations where we can drop particular auth events when adding
- # to the event's `auth_events` (e.g. joins pointing to previous joins
- # when room is publicly joinable). Dropping event IDs has the
- # advantage that the auth chain for the room grows slower, but we use
- # the auth chain in state resolution v2 to order events, which means
- # care must be taken if dropping events to ensure that it doesn't
- # introduce undesirable "state reset" behaviour.
- #
- # All of which sounds a bit tricky so we don't bother for now.
-
- auth_ids = []
- for etype, state_key in event_auth.auth_types_for_event(event):
- auth_ev_id = current_state_ids.get((etype, state_key))
- if auth_ev_id:
- auth_ids.append(auth_ev_id)
-
- return auth_ids
-
async def check_can_change_room_list(self, room_id: str, user: UserID) -> bool:
"""Determine whether the user is allowed to edit the room's entry in the
published room list.
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index 414e4c019a..8363c2bb0f 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -201,6 +201,12 @@ class EventContentFields:
)
+class RoomTypes:
+ """Understood values of the room_type field of m.room.create events."""
+
+ SPACE = "m.space"
+
+
class RoomEncryptionAlgorithms:
MEGOLM_V1_AES_SHA2 = "m.megolm.v1.aes-sha2"
DEFAULT = MEGOLM_V1_AES_SHA2
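
The new constant lets callers distinguish spaces from ordinary rooms by inspecting the create event. A hypothetical helper (the `"type"` content key follows MSC1772; `is_space` is illustrative and not part of this diff):

```python
from synapse.api.constants import EventTypes, RoomTypes
from synapse.events import EventBase

def is_space(create_event: EventBase) -> bool:
    # Spaces are ordinary rooms whose m.room.create content declares the
    # room type "m.space"; an absent room type means a normal room.
    assert create_event.type == EventTypes.Create
    return create_event.content.get("type") == RoomTypes.SPACE
```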
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 8879136881..b30571fe49 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -21,7 +21,7 @@ import socket
import sys
import traceback
import warnings
-from typing import Awaitable, Callable, Iterable
+from typing import TYPE_CHECKING, Awaitable, Callable, Iterable
from cryptography.utils import CryptographyDeprecationWarning
from typing_extensions import NoReturn
@@ -41,10 +41,14 @@ from synapse.events.spamcheck import load_legacy_spam_checkers
from synapse.logging.context import PreserveLoggingContext
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.metrics.jemalloc import setup_jemalloc_stats
+from synapse.util.caches.lrucache import setup_expire_lru_cache_entries
from synapse.util.daemonize import daemonize_process
from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
# list of tuples of function, args list, kwargs dict
@@ -312,7 +316,7 @@ def refresh_certificate(hs):
logger.info("Context factories updated.")
-async def start(hs: "synapse.server.HomeServer"):
+async def start(hs: "HomeServer"):
"""
Start a Synapse server or worker.
@@ -365,6 +369,9 @@ async def start(hs: "synapse.server.HomeServer"):
load_legacy_spam_checkers(hs)
+ # If we've configured an expiry time for caches, start the background job now.
+ setup_expire_lru_cache_entries(hs)
+
# It is now safe to start your Synapse.
hs.start_listening()
hs.get_datastore().db_pool.start_profiling()
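
The `TYPE_CHECKING` guard above is the standard way to annotate with `HomeServer` without importing it at runtime, avoiding an import cycle. The general pattern, for reference:

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only evaluated by static type checkers, never at runtime, so it
    # cannot introduce a circular import.
    from synapse.server import HomeServer

async def start(hs: "HomeServer") -> None:
    # The quoted annotation is resolved lazily; at runtime the name
    # HomeServer is never looked up.
    ...
```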
diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi
index 23ca0c83c1..06fbd1166b 100644
--- a/synapse/config/_base.pyi
+++ b/synapse/config/_base.pyi
@@ -5,6 +5,7 @@ from synapse.config import (
api,
appservice,
auth,
+ cache,
captcha,
cas,
consent,
@@ -88,6 +89,7 @@ class RootConfig:
tracer: tracer.TracerConfig
redis: redis.RedisConfig
modules: modules.ModulesConfig
+ caches: cache.CacheConfig
federation: federation.FederationConfig
config_classes: List = ...
diff --git a/synapse/config/cache.py b/synapse/config/cache.py
index 91165ee1ce..7789b40323 100644
--- a/synapse/config/cache.py
+++ b/synapse/config/cache.py
@@ -116,35 +116,41 @@ class CacheConfig(Config):
#event_cache_size: 10K
caches:
- # Controls the global cache factor, which is the default cache factor
- # for all caches if a specific factor for that cache is not otherwise
- # set.
- #
- # This can also be set by the "SYNAPSE_CACHE_FACTOR" environment
- # variable. Setting by environment variable takes priority over
- # setting through the config file.
- #
- # Defaults to 0.5, which will half the size of all caches.
- #
- #global_factor: 1.0
-
- # A dictionary of cache name to cache factor for that individual
- # cache. Overrides the global cache factor for a given cache.
- #
- # These can also be set through environment variables comprised
- # of "SYNAPSE_CACHE_FACTOR_" + the name of the cache in capital
- # letters and underscores. Setting by environment variable
- # takes priority over setting through the config file.
- # Ex. SYNAPSE_CACHE_FACTOR_GET_USERS_WHO_SHARE_ROOM_WITH_USER=2.0
- #
- # Some caches have '*' and other characters that are not
- # alphanumeric or underscores. These caches can be named with or
- # without the special characters stripped. For example, to specify
- # the cache factor for `*stateGroupCache*` via an environment
- # variable would be `SYNAPSE_CACHE_FACTOR_STATEGROUPCACHE=2.0`.
- #
- per_cache_factors:
- #get_users_who_share_room_with_user: 2.0
+ # Controls the global cache factor, which is the default cache factor
+ # for all caches if a specific factor for that cache is not otherwise
+ # set.
+ #
+ # This can also be set by the "SYNAPSE_CACHE_FACTOR" environment
+ # variable. Setting by environment variable takes priority over
+ # setting through the config file.
+ #
+ # Defaults to 0.5, which will halve the size of all caches.
+ #
+ #global_factor: 1.0
+
+ # A dictionary of cache name to cache factor for that individual
+ # cache. Overrides the global cache factor for a given cache.
+ #
+ # These can also be set through environment variables comprised
+ # of "SYNAPSE_CACHE_FACTOR_" + the name of the cache in capital
+ # letters and underscores. Setting by environment variable
+ # takes priority over setting through the config file.
+ # Ex. SYNAPSE_CACHE_FACTOR_GET_USERS_WHO_SHARE_ROOM_WITH_USER=2.0
+ #
+ # Some caches have '*' and other characters that are not
+ # alphanumeric or underscores. These caches can be named with or
+ # without the special characters stripped. For example, to specify
+ # the cache factor for `*stateGroupCache*` via an environment
+ # variable would be `SYNAPSE_CACHE_FACTOR_STATEGROUPCACHE=2.0`.
+ #
+ per_cache_factors:
+ #get_users_who_share_room_with_user: 2.0
+
+ # Controls how long an entry can be in a cache without having been
+ # accessed before being evicted. Defaults to None, which means
+ # entries are never evicted based on time.
+ #
+ #expiry_time: 30m
"""
def read_config(self, config, **kwargs):
@@ -200,6 +206,12 @@ class CacheConfig(Config):
e.message # noqa: B306, DependencyException.message is a property
)
+ expiry_time = cache_config.get("expiry_time")
+ if expiry_time:
+ self.expiry_time_msec = self.parse_duration(expiry_time)
+ else:
+ self.expiry_time_msec = None
+
# Resize all caches (if necessary) with the new factors we've loaded
self.resize_all_caches()
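
A sketch of how the new `expiry_time_msec` value translates into an eviction rule; the real background job lives in `synapse.util.caches.lrucache` (`setup_expire_lru_cache_entries`) and the details below are assumptions:

```python
# Assumed shape of the eviction rule driven by caches.expiry_time; the
# actual job in synapse.util.caches.lrucache may differ in detail.
expiry_time_msec = 30 * 60 * 1000  # "expiry_time: 30m" after parse_duration

def should_evict(last_access_msec: int, now_msec: int) -> bool:
    # Drop entries that have not been accessed within the configured window.
    return now_msec - last_access_msec > expiry_time_msec
```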
diff --git a/synapse/config/consent.py b/synapse/config/consent.py
index 30d07cc219..b05a9bd97f 100644
--- a/synapse/config/consent.py
+++ b/synapse/config/consent.py
@@ -22,7 +22,7 @@ DEFAULT_CONFIG = """\
# User Consent configuration
#
# for detailed instructions, see
-# https://github.com/matrix-org/synapse/blob/master/docs/consent_tracking.md
+# https://matrix-org.github.io/synapse/latest/consent_tracking.html
#
# Parts of this section are required if enabling the 'consent' resource under
# 'listeners', in particular 'template_dir' and 'version'.
diff --git a/synapse/config/database.py b/synapse/config/database.py
index c76ef1e1de..3d7d92f615 100644
--- a/synapse/config/database.py
+++ b/synapse/config/database.py
@@ -62,7 +62,8 @@ DEFAULT_CONFIG = """\
# cp_min: 5
# cp_max: 10
#
-# For more information on using Synapse with Postgres, see `docs/postgres.md`.
+# For more information on using Synapse with Postgres,
+# see https://matrix-org.github.io/synapse/latest/postgres.html.
#
database:
name: sqlite3
diff --git a/synapse/config/jwt.py b/synapse/config/jwt.py
index 9e07e73008..9d295f5856 100644
--- a/synapse/config/jwt.py
+++ b/synapse/config/jwt.py
@@ -64,7 +64,7 @@ class JWTConfig(Config):
# Note that this is a non-standard login type and client support is
# expected to be non-existent.
#
- # See https://github.com/matrix-org/synapse/blob/master/docs/jwt.md.
+ # See https://matrix-org.github.io/synapse/latest/jwt.html.
#
#jwt_config:
# Uncomment the following to enable authorization using JSON web
diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index 91d9bcf32e..ad4e6e61c3 100644
--- a/synapse/config/logger.py
+++ b/synapse/config/logger.py
@@ -49,7 +49,7 @@ DEFAULT_LOG_CONFIG = Template(
# be ingested by ELK stacks. See [2] for details.
#
# [1]: https://docs.python.org/3.7/library/logging.config.html#configuration-dictionary-schema
-# [2]: https://github.com/matrix-org/synapse/blob/master/docs/structured_logging.md
+# [2]: https://matrix-org.github.io/synapse/latest/structured_logging.html
version: 1
diff --git a/synapse/config/modules.py b/synapse/config/modules.py
index 3209e1c492..ae0821e5a5 100644
--- a/synapse/config/modules.py
+++ b/synapse/config/modules.py
@@ -37,7 +37,7 @@ class ModulesConfig(Config):
# Server admins can expand Synapse's functionality with external modules.
#
- # See https://matrix-org.github.io/synapse/develop/modules.html for more
+ # See https://matrix-org.github.io/synapse/latest/modules.html for more
# documentation on how to configure or create custom modules for Synapse.
#
modules:
diff --git a/synapse/config/oidc.py b/synapse/config/oidc.py
index ea0abf5aa2..942e2672a9 100644
--- a/synapse/config/oidc.py
+++ b/synapse/config/oidc.py
@@ -166,7 +166,7 @@ class OIDCConfig(Config):
#
# module: The class name of a custom mapping module. Default is
# {mapping_provider!r}.
- # See https://github.com/matrix-org/synapse/blob/master/docs/sso_mapping_providers.md#openid-mapping-providers
+ # See https://matrix-org.github.io/synapse/latest/sso_mapping_providers.html#openid-mapping-providers
# for information on implementing a custom mapping provider.
#
# config: Configuration for the mapping provider module. This section will
@@ -217,7 +217,7 @@ class OIDCConfig(Config):
# - attribute: groups
# value: "admin"
#
- # See https://github.com/matrix-org/synapse/blob/master/docs/openid.md
+ # See https://matrix-org.github.io/synapse/latest/openid.html
# for information on how to configure these options.
#
# For backwards compatibility, it is also possible to configure a single OIDC
diff --git a/synapse/config/password_auth_providers.py b/synapse/config/password_auth_providers.py
index 1cf69734bb..fd90b79772 100644
--- a/synapse/config/password_auth_providers.py
+++ b/synapse/config/password_auth_providers.py
@@ -57,7 +57,7 @@ class PasswordAuthProviderConfig(Config):
# ex. LDAP, external tokens, etc.
#
# For more information and known implementations, please see
- # https://github.com/matrix-org/synapse/blob/master/docs/password_auth_providers.md
+ # https://matrix-org.github.io/synapse/latest/password_auth_providers.html
#
# Note: instances wishing to use SAML or CAS authentication should
# instead use the `saml2_config` or `cas_config` options,
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index eecc0478a7..6e9f405312 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -153,6 +153,27 @@ class RegistrationConfig(Config):
session_lifetime = self.parse_duration(session_lifetime)
self.session_lifetime = session_lifetime
+ # The `access_token_lifetime` applies for tokens that can be renewed
+ # using a refresh token, as per MSC2918. If it is `None`, the refresh
+ # token mechanism is disabled.
+ #
+ # Since it is incompatible with the `session_lifetime` mechanism, it is set to
+ # `None` by default if a `session_lifetime` is set.
+ access_token_lifetime = config.get(
+ "access_token_lifetime", "5m" if session_lifetime is None else None
+ )
+ if access_token_lifetime is not None:
+ access_token_lifetime = self.parse_duration(access_token_lifetime)
+ self.access_token_lifetime = access_token_lifetime
+
+ if session_lifetime is not None and access_token_lifetime is not None:
+ raise ConfigError(
+ "The refresh token mechanism is incompatible with the "
+ "`session_lifetime` option. Consider disabling the "
+ "`session_lifetime` option or disabling the refresh token "
+ "mechanism by removing the `access_token_lifetime` option."
+ )
+
# The success template used during fallback auth.
self.fallback_success_template = self.read_template("auth_success.html")
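
A condensed restatement of the defaulting rules in the hunk above (standalone sketch; `resolve_access_token_lifetime` is illustrative, not a Synapse function):

```python
def resolve_access_token_lifetime(config: dict):
    # session_lifetime set, access_token_lifetime unset -> refresh tokens off
    # neither set                                       -> refresh tokens on ("5m")
    # both set explicitly                               -> configuration error
    session_lifetime = config.get("session_lifetime")
    lifetime = config.get(
        "access_token_lifetime", "5m" if session_lifetime is None else None
    )
    if session_lifetime is not None and lifetime is not None:
        raise ValueError("session_lifetime is incompatible with refresh tokens")
    return lifetime
```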
diff --git a/synapse/config/repository.py b/synapse/config/repository.py
index 1bd40a89f0..b5820616da 100644
--- a/synapse/config/repository.py
+++ b/synapse/config/repository.py
@@ -256,7 +256,7 @@ class ContentRepositoryConfig(Config):
#
# If you are using a reverse proxy you may also need to set this value in
# your reverse proxy's config. Notably Nginx has a small max body size by default.
- # See https://matrix-org.github.io/synapse/develop/reverse_proxy.html.
+ # See https://matrix-org.github.io/synapse/latest/reverse_proxy.html.
#
#max_upload_size: 50M
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 20022b963f..60afee5804 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -153,7 +153,7 @@ ROOM_COMPLEXITY_TOO_GREAT = (
METRICS_PORT_WARNING = """\
The metrics_port configuration option is deprecated in Synapse 0.31 in favour of
a listener. Please see
-https://github.com/matrix-org/synapse/blob/master/docs/metrics-howto.md
+https://matrix-org.github.io/synapse/latest/metrics-howto.html
on how to configure the new listener.
--------------------------------------------------------------------------------"""
@@ -817,7 +817,7 @@ class ServerConfig(Config):
# In most cases you should avoid using a matrix specific subdomain such as
# matrix.example.com or synapse.example.com as the server_name for the same
# reasons you wouldn't use user@email.example.com as your email address.
- # See https://github.com/matrix-org/synapse/blob/master/docs/delegate.md
+ # See https://matrix-org.github.io/synapse/latest/delegate.html
# for information on how to host Synapse on a subdomain while preserving
# a clean server_name.
#
@@ -994,9 +994,9 @@ class ServerConfig(Config):
# 'all local interfaces'.
#
# type: the type of listener. Normally 'http', but other valid options are:
- # 'manhole' (see docs/manhole.md),
- # 'metrics' (see docs/metrics-howto.md),
- # 'replication' (see docs/workers.md).
+ # 'manhole' (see https://matrix-org.github.io/synapse/latest/manhole.html),
+ # 'metrics' (see https://matrix-org.github.io/synapse/latest/metrics-howto.html),
+ # 'replication' (see https://matrix-org.github.io/synapse/latest/workers.html).
#
# tls: set to true to enable TLS for this listener. Will use the TLS
# key/cert specified in tls_private_key_path / tls_certificate_path.
@@ -1021,8 +1021,8 @@ class ServerConfig(Config):
# client: the client-server API (/_matrix/client), and the synapse admin
# API (/_synapse/admin). Also implies 'media' and 'static'.
#
- # consent: user consent forms (/_matrix/consent). See
- # docs/consent_tracking.md.
+ # consent: user consent forms (/_matrix/consent).
+ # See https://matrix-org.github.io/synapse/latest/consent_tracking.html.
#
# federation: the server-server API (/_matrix/federation). Also implies
# 'media', 'keys', 'openid'
@@ -1031,12 +1031,13 @@ class ServerConfig(Config):
#
# media: the media API (/_matrix/media).
#
- # metrics: the metrics interface. See docs/metrics-howto.md.
+ # metrics: the metrics interface.
+ # See https://matrix-org.github.io/synapse/latest/metrics-howto.html.
#
# openid: OpenID authentication.
#
- # replication: the HTTP replication API (/_synapse/replication). See
- # docs/workers.md.
+ # replication: the HTTP replication API (/_synapse/replication).
+ # See https://matrix-org.github.io/synapse/latest/workers.html.
#
# static: static resources under synapse/static (/_matrix/static). (Mostly
# useful for 'fallback authentication'.)
@@ -1056,7 +1057,7 @@ class ServerConfig(Config):
# that unwraps TLS.
#
# If you plan to use a reverse proxy, please see
- # https://github.com/matrix-org/synapse/blob/master/docs/reverse_proxy.md.
+ # https://matrix-org.github.io/synapse/latest/reverse_proxy.html.
#
%(unsecure_http_bindings)s
diff --git a/synapse/config/spam_checker.py b/synapse/config/spam_checker.py
index d0311d6468..cb7716c837 100644
--- a/synapse/config/spam_checker.py
+++ b/synapse/config/spam_checker.py
@@ -26,7 +26,7 @@ LEGACY_SPAM_CHECKER_WARNING = """
This server is using a spam checker module that is implementing the deprecated spam
checker interface. Please check with the module's maintainer to see if a new version
supporting Synapse's generic modules system is available.
-For more information, please see https://matrix-org.github.io/synapse/develop/modules.html
+For more information, please see https://matrix-org.github.io/synapse/latest/modules.html
---------------------------------------------------------------------------------------"""
diff --git a/synapse/config/stats.py b/synapse/config/stats.py
index 3d44b51201..78f61fe9da 100644
--- a/synapse/config/stats.py
+++ b/synapse/config/stats.py
@@ -51,7 +51,7 @@ class StatsConfig(Config):
def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """
# Settings for local room and user statistics collection. See
- # docs/room_and_user_statistics.md.
+ # https://matrix-org.github.io/synapse/latest/room_and_user_statistics.html.
#
stats:
# Uncomment the following to disable room and user statistics. Note that doing
diff --git a/synapse/config/tracer.py b/synapse/config/tracer.py
index d0ea17261f..21b9a88353 100644
--- a/synapse/config/tracer.py
+++ b/synapse/config/tracer.py
@@ -81,7 +81,7 @@ class TracerConfig(Config):
#enabled: true
# The list of homeservers we wish to send and receive span contexts and span baggage.
- # See docs/opentracing.rst.
+ # See https://matrix-org.github.io/synapse/latest/opentracing.html.
#
# This is a list of regexes which are matched against the server_name of the
# homeserver.
diff --git a/synapse/config/user_directory.py b/synapse/config/user_directory.py
index 3ac1f2b5b1..f1beb87aea 100644
--- a/synapse/config/user_directory.py
+++ b/synapse/config/user_directory.py
@@ -53,7 +53,7 @@ class UserDirectoryConfig(Config):
#
# If you set it true, you'll have to rebuild the user_directory search
# indexes, see:
- # https://github.com/matrix-org/synapse/blob/master/docs/user_directory.md
+ # https://matrix-org.github.io/synapse/latest/user_directory.html
#
# Uncomment to return search results containing all known users, even if that
# user does not share a room with the requester.
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index 33d7c60241..89bcf81515 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -14,7 +14,7 @@
# limitations under the License.
import logging
-from typing import Any, Dict, List, Optional, Set, Tuple
+from typing import Any, Dict, List, Optional, Set, Tuple, Union
from canonicaljson import encode_canonical_json
from signedjson.key import decode_verify_key_bytes
@@ -29,6 +29,7 @@ from synapse.api.room_versions import (
RoomVersion,
)
from synapse.events import EventBase
+from synapse.events.builder import EventBuilder
from synapse.types import StateMap, UserID, get_domain_from_id
logger = logging.getLogger(__name__)
@@ -724,7 +725,7 @@ def get_public_keys(invite_event: EventBase) -> List[Dict[str, Any]]:
return public_keys
-def auth_types_for_event(event: EventBase) -> Set[Tuple[str, str]]:
+def auth_types_for_event(event: Union[EventBase, EventBuilder]) -> Set[Tuple[str, str]]:
"""Given an event, return a list of (EventType, StateKey) that may be
needed to auth the event. The returned list may be a superset of what
would actually be required depending on the full state of the room.
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index 0cb9c1cc1e..6286ad999a 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -118,7 +118,7 @@ class _EventInternalMetadata:
proactively_send = DictProperty("proactively_send") # type: bool
redacted = DictProperty("redacted") # type: bool
txn_id = DictProperty("txn_id") # type: str
- token_id = DictProperty("token_id") # type: str
+ token_id = DictProperty("token_id") # type: int
historical = DictProperty("historical") # type: bool
# XXX: These are set by StreamWorkerStore._set_before_and_after.
diff --git a/synapse/events/builder.py b/synapse/events/builder.py
index 81bf8615b7..26e3950859 100644
--- a/synapse/events/builder.py
+++ b/synapse/events/builder.py
@@ -12,12 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import attr
from nacl.signing import SigningKey
-from synapse.api.auth import Auth
from synapse.api.constants import MAX_DEPTH
from synapse.api.errors import UnsupportedRoomVersionError
from synapse.api.room_versions import (
@@ -34,10 +33,14 @@ from synapse.types import EventID, JsonDict
from synapse.util import Clock
from synapse.util.stringutils import random_string
+if TYPE_CHECKING:
+ from synapse.handlers.event_auth import EventAuthHandler
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
-@attr.s(slots=True, cmp=False, frozen=True)
+@attr.s(slots=True, cmp=False, frozen=True, auto_attribs=True)
class EventBuilder:
"""A format independent event builder used to build up the event content
before signing the event.
@@ -62,31 +65,30 @@ class EventBuilder:
_signing_key: The signing key to use to sign the event as the server
"""
- _state = attr.ib(type=StateHandler)
- _auth = attr.ib(type=Auth)
- _store = attr.ib(type=DataStore)
- _clock = attr.ib(type=Clock)
- _hostname = attr.ib(type=str)
- _signing_key = attr.ib(type=SigningKey)
+ _state: StateHandler
+ _event_auth_handler: "EventAuthHandler"
+ _store: DataStore
+ _clock: Clock
+ _hostname: str
+ _signing_key: SigningKey
- room_version = attr.ib(type=RoomVersion)
+ room_version: RoomVersion
- room_id = attr.ib(type=str)
- type = attr.ib(type=str)
- sender = attr.ib(type=str)
+ room_id: str
+ type: str
+ sender: str
- content = attr.ib(default=attr.Factory(dict), type=JsonDict)
- unsigned = attr.ib(default=attr.Factory(dict), type=JsonDict)
+ content: JsonDict = attr.Factory(dict)
+ unsigned: JsonDict = attr.Factory(dict)
# These only exist on a subset of events, so they raise AttributeError if
# someone tries to get them when they don't exist.
- _state_key = attr.ib(default=None, type=Optional[str])
- _redacts = attr.ib(default=None, type=Optional[str])
- _origin_server_ts = attr.ib(default=None, type=Optional[int])
+ _state_key: Optional[str] = None
+ _redacts: Optional[str] = None
+ _origin_server_ts: Optional[int] = None
- internal_metadata = attr.ib(
- default=attr.Factory(lambda: _EventInternalMetadata({})),
- type=_EventInternalMetadata,
+ internal_metadata: _EventInternalMetadata = attr.Factory(
+ lambda: _EventInternalMetadata({})
)
@property
@@ -123,7 +125,9 @@ class EventBuilder:
state_ids = await self._state.get_current_state_ids(
self.room_id, prev_event_ids
)
- auth_event_ids = self._auth.compute_auth_events(self, state_ids)
+ auth_event_ids = self._event_auth_handler.compute_auth_events(
+ self, state_ids
+ )
format_version = self.room_version.event_format
if format_version == EventFormatVersions.V1:
@@ -184,24 +188,23 @@ class EventBuilder:
class EventBuilderFactory:
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.clock = hs.get_clock()
self.hostname = hs.hostname
self.signing_key = hs.signing_key
self.store = hs.get_datastore()
self.state = hs.get_state_handler()
- self.auth = hs.get_auth()
+ self._event_auth_handler = hs.get_event_auth_handler()
- def new(self, room_version, key_values):
+ def new(self, room_version: str, key_values: dict) -> EventBuilder:
"""Generate an event builder appropriate for the given room version
Deprecated: use for_room_version with a RoomVersion object instead
Args:
- room_version (str): Version of the room that we're creating an event builder
- for
- key_values (dict): Fields used as the basis of the new event
+ room_version: Version of the room that we're creating an event builder for
+ key_values: Fields used as the basis of the new event
Returns:
EventBuilder
@@ -212,13 +215,15 @@ class EventBuilderFactory:
raise UnsupportedRoomVersionError()
return self.for_room_version(v, key_values)
- def for_room_version(self, room_version, key_values):
+ def for_room_version(
+ self, room_version: RoomVersion, key_values: dict
+ ) -> EventBuilder:
"""Generate an event builder appropriate for the given room version
Args:
- room_version (synapse.api.room_versions.RoomVersion):
+ room_version:
Version of the room that we're creating an event builder for
- key_values (dict): Fields used as the basis of the new event
+ key_values: Fields used as the basis of the new event
Returns:
EventBuilder
@@ -226,7 +231,7 @@ class EventBuilderFactory:
return EventBuilder(
store=self.store,
state=self.state,
- auth=self.auth,
+ event_auth_handler=self._event_auth_handler,
clock=self.clock,
hostname=self.hostname,
signing_key=self.signing_key,
@@ -286,15 +291,15 @@ def create_local_event_from_event_dict(
_event_id_counter = 0
-def _create_event_id(clock, hostname):
+def _create_event_id(clock: Clock, hostname: str) -> str:
"""Create a new event ID
Args:
- clock (Clock)
- hostname (str): The server name for the event ID
+ clock
+ hostname: The server name for the event ID
Returns:
- str
+ The new event ID
"""
global _event_id_counter
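
The `EventBuilder` conversion above uses attrs' `auto_attribs` mode, which reads the class's type annotations instead of repeated `attr.ib(type=...)` calls. The two spellings are equivalent:

```python
import attr

# Equivalent declarations: auto_attribs derives attributes from annotations.
@attr.s(slots=True, frozen=True)
class PointOld:
    x = attr.ib(type=int)
    y = attr.ib(default=0, type=int)

@attr.s(slots=True, frozen=True, auto_attribs=True)
class PointNew:
    x: int
    y: int = 0  # defaults (and attr.Factory(...)) work unchanged

assert PointNew(1) == PointNew(1, 0)
```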
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index c066617b92..2bfe6a3d37 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -89,12 +89,12 @@ class FederationBase:
result = await self.spam_checker.check_event_for_spam(pdu)
if result:
- logger.warning(
- "Event contains spam, redacting %s: %s",
- pdu.event_id,
- pdu.get_pdu_json(),
- )
- return prune_event(pdu)
+ logger.warning("Event contains spam, soft-failing %s", pdu.event_id)
+ # we redact (to save disk space) as well as soft-fail (to stop
+ # using the event in prev_events).
+ redacted_event = prune_event(pdu)
+ redacted_event.internal_metadata.soft_failed = True
+ return redacted_event
return pdu
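
For context, `check_event_for_spam` is supplied by spam-checker modules; a truthy result now soft-fails the event as well as redacting it. A hypothetical module using the generic module interface (names in the body are illustrative):

```python
class ExampleSpamChecker:
    """Hypothetical module whose truthy verdict triggers the
    redact-and-soft-fail path above."""

    def __init__(self, config, api):
        api.register_spam_checker_callbacks(
            check_event_for_spam=self.check_event_for_spam,
        )

    async def check_event_for_spam(self, event) -> bool:
        # Flag events whose body contains a blocked phrase.
        return "buy cheap meds" in event.content.get("body", "")
```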
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 1d050e54e2..ac0f2ccfb3 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -34,7 +34,7 @@ from twisted.internet import defer
from twisted.internet.abstract import isIPAddress
from twisted.python import failure
-from synapse.api.constants import EduTypes, EventTypes
+from synapse.api.constants import EduTypes, EventTypes, Membership
from synapse.api.errors import (
AuthError,
Codes,
@@ -46,6 +46,7 @@ from synapse.api.errors import (
)
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
from synapse.events import EventBase
+from synapse.events.snapshot import EventContext
from synapse.federation.federation_base import FederationBase, event_from_pdu_json
from synapse.federation.persistence import TransactionActions
from synapse.federation.units import Edu, Transaction
@@ -107,9 +108,9 @@ class FederationServer(FederationBase):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
- self.auth = hs.get_auth()
self.handler = hs.get_federation_handler()
self.state = hs.get_state_handler()
+ self._event_auth_handler = hs.get_event_auth_handler()
self.device_handler = hs.get_device_handler()
@@ -147,6 +148,41 @@ class FederationServer(FederationBase):
self._room_prejoin_state_types = hs.config.api.room_prejoin_state
+ # Whether we have started handling old events in the staging area.
+ self._started_handling_of_staged_events = False
+
+ @wrap_as_background_process("_handle_old_staged_events")
+ async def _handle_old_staged_events(self) -> None:
+ """Handle old staged events by fetching all rooms that have staged
+ events and start the processing of each of those rooms.
+ """
+
+ # Get all the room IDs with staged events.
+ room_ids = await self.store.get_all_rooms_with_staged_incoming_events()
+
+ # We then shuffle them so that if there are multiple instances doing
+ # this work they're less likely to collide.
+ random.shuffle(room_ids)
+
+ for room_id in room_ids:
+ room_version = await self.store.get_room_version(room_id)
+
+ # Try to acquire the processing lock for the room; if we get it, start a
+ # background process for handling the events in the room.
+ lock = await self.store.try_acquire_lock(
+ _INBOUND_EVENT_HANDLING_LOCK_NAME, room_id
+ )
+ if lock:
+ logger.info("Handling old staged inbound events in %s", room_id)
+ self._process_incoming_pdus_in_room_inner(
+ room_id,
+ room_version,
+ lock,
+ )
+
+ # We pause a bit so that we don't start handling all rooms at once.
+ await self._clock.sleep(random.uniform(0, 0.1))
+
async def on_backfill_request(
self, origin: str, room_id: str, versions: List[str], limit: int
) -> Tuple[int, Dict[str, Any]]:
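
The acquire-or-skip pattern above, in isolation (a sketch with hypothetical names; the store methods are those used in this diff):

```python
import random

async def drain_staged_rooms(store, lock_name: str, process_room) -> None:
    # Shuffle so that multiple workers running this job start on different
    # rooms; a failed try_acquire_lock means another instance is already
    # draining that room's staging area, so we simply skip it.
    room_ids = await store.get_all_rooms_with_staged_incoming_events()
    random.shuffle(room_ids)
    for room_id in room_ids:
        lock = await store.try_acquire_lock(lock_name, room_id)
        if lock:
            process_room(room_id, lock)  # kicked off as a background process
```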
@@ -165,6 +201,12 @@ class FederationServer(FederationBase):
async def on_incoming_transaction(
self, origin: str, transaction_data: JsonDict
) -> Tuple[int, Dict[str, Any]]:
+ # If we receive a transaction we should make sure that we kick off
+ # handling any old events in the staging area.
+ if not self._started_handling_of_staged_events:
+ self._started_handling_of_staged_events = True
+ self._handle_old_staged_events()
+
# keep this as early as possible to make the calculated origin ts as
# accurate as possible.
request_time = self._clock.time_msec()
@@ -368,22 +410,21 @@ class FederationServer(FederationBase):
async def process_pdu(pdu: EventBase) -> JsonDict:
event_id = pdu.event_id
- with pdu_process_time.time():
- with nested_logging_context(event_id):
- try:
- await self._handle_received_pdu(origin, pdu)
- return {}
- except FederationError as e:
- logger.warning("Error handling PDU %s: %s", event_id, e)
- return {"error": str(e)}
- except Exception as e:
- f = failure.Failure()
- logger.error(
- "Failed to handle PDU %s",
- event_id,
- exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore
- )
- return {"error": str(e)}
+ with nested_logging_context(event_id):
+ try:
+ await self._handle_received_pdu(origin, pdu)
+ return {}
+ except FederationError as e:
+ logger.warning("Error handling PDU %s: %s", event_id, e)
+ return {"error": str(e)}
+ except Exception as e:
+ f = failure.Failure()
+ logger.error(
+ "Failed to handle PDU %s",
+ event_id,
+ exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore
+ )
+ return {"error": str(e)}
await concurrently_execute(
process_pdus_for_room, pdus_by_room.keys(), TRANSACTION_CONCURRENCY_LIMIT
@@ -420,7 +461,7 @@ class FederationServer(FederationBase):
origin_host, _ = parse_server_name(origin)
await self.check_server_matches_acl(origin_host, room_id)
- in_room = await self.auth.check_host_in_room(room_id, origin)
+ in_room = await self._event_auth_handler.check_host_in_room(room_id, origin)
if not in_room:
raise AuthError(403, "Host not in room.")
@@ -453,7 +494,7 @@ class FederationServer(FederationBase):
origin_host, _ = parse_server_name(origin)
await self.check_server_matches_acl(origin_host, room_id)
- in_room = await self.auth.check_host_in_room(room_id, origin)
+ in_room = await self._event_auth_handler.check_host_in_room(room_id, origin)
if not in_room:
raise AuthError(403, "Host not in room.")
@@ -544,26 +585,21 @@ class FederationServer(FederationBase):
return {"event": ret_pdu.get_pdu_json(time_now)}
async def on_send_join_request(
- self, origin: str, content: JsonDict
+ self, origin: str, content: JsonDict, room_id: str
) -> Dict[str, Any]:
- logger.debug("on_send_join_request: content: %s", content)
-
- assert_params_in_dict(content, ["room_id"])
- room_version = await self.store.get_room_version(content["room_id"])
- pdu = event_from_pdu_json(content, room_version)
-
- origin_host, _ = parse_server_name(origin)
- await self.check_server_matches_acl(origin_host, pdu.room_id)
-
- logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures)
+ context = await self._on_send_membership_event(
+ origin, content, Membership.JOIN, room_id
+ )
- pdu = await self._check_sigs_and_hash(room_version, pdu)
+ prev_state_ids = await context.get_prev_state_ids()
+ state_ids = list(prev_state_ids.values())
+ auth_chain = await self.store.get_auth_chain(room_id, state_ids)
+ state = await self.store.get_events(state_ids)
- res_pdus = await self.handler.on_send_join_request(origin, pdu)
time_now = self._clock.time_msec()
return {
- "state": [p.get_pdu_json(time_now) for p in res_pdus["state"]],
- "auth_chain": [p.get_pdu_json(time_now) for p in res_pdus["auth_chain"]],
+ "state": [p.get_pdu_json(time_now) for p in state.values()],
+ "auth_chain": [p.get_pdu_json(time_now) for p in auth_chain],
}
async def on_make_leave_request(
@@ -578,21 +614,11 @@ class FederationServer(FederationBase):
time_now = self._clock.time_msec()
return {"event": pdu.get_pdu_json(time_now), "room_version": room_version}
- async def on_send_leave_request(self, origin: str, content: JsonDict) -> dict:
+ async def on_send_leave_request(
+ self, origin: str, content: JsonDict, room_id: str
+ ) -> dict:
logger.debug("on_send_leave_request: content: %s", content)
-
- assert_params_in_dict(content, ["room_id"])
- room_version = await self.store.get_room_version(content["room_id"])
- pdu = event_from_pdu_json(content, room_version)
-
- origin_host, _ = parse_server_name(origin)
- await self.check_server_matches_acl(origin_host, pdu.room_id)
-
- logger.debug("on_send_leave_request: pdu sigs: %s", pdu.signatures)
-
- pdu = await self._check_sigs_and_hash(room_version, pdu)
-
- await self.handler.on_send_leave_request(origin, pdu)
+ await self._on_send_membership_event(origin, content, Membership.LEAVE, room_id)
return {}
async def on_make_knock_request(
@@ -658,39 +684,76 @@ class FederationServer(FederationBase):
Returns:
The stripped room state.
"""
- logger.debug("on_send_knock_request: content: %s", content)
+ event_context = await self._on_send_membership_event(
+ origin, content, Membership.KNOCK, room_id
+ )
+
+ # Retrieve stripped state events from the room and send them back to the remote
+ # server. This will allow the remote server's clients to display information
+ # related to the room while the knock request is pending.
+ stripped_room_state = (
+ await self.store.get_stripped_room_state_from_event_context(
+ event_context, self._room_prejoin_state_types
+ )
+ )
+ return {"knock_state_events": stripped_room_state}
+
+ async def _on_send_membership_event(
+ self, origin: str, content: JsonDict, membership_type: str, room_id: str
+ ) -> EventContext:
+ """Handle an on_send_{join,leave,knock} request
+
+ Does some preliminary validation before passing the request on to the
+ federation handler.
+
+ Args:
+ origin: The (authenticated) requesting server
+ content: The body of the send_* request - a complete membership event
+ membership_type: The expected membership type (join, leave or knock,
+ depending on the endpoint)
+ room_id: The room_id from the request, to be validated against the room_id
+ in the event
+
+ Returns:
+ The context of the event after inserting it into the room graph.
+
+ Raises:
+ SynapseError if there is a problem with the request, including things like
+ the room_id not matching or the event not being authorized.
+ """
+ assert_params_in_dict(content, ["room_id"])
+ if content["room_id"] != room_id:
+ raise SynapseError(
+ 400,
+ "Room ID in body does not match that in request path",
+ Codes.BAD_JSON,
+ )
room_version = await self.store.get_room_version(room_id)
- # Check that this room supports knocking as defined by its room version
- if not room_version.msc2403_knocking:
+ if membership_type == Membership.KNOCK and not room_version.msc2403_knocking:
raise SynapseError(
403,
"This room version does not support knocking",
errcode=Codes.FORBIDDEN,
)
- pdu = event_from_pdu_json(content, room_version)
+ event = event_from_pdu_json(content, room_version)
- origin_host, _ = parse_server_name(origin)
- await self.check_server_matches_acl(origin_host, pdu.room_id)
+ if event.type != EventTypes.Member or not event.is_state():
+ raise SynapseError(400, "Not an m.room.member event", Codes.BAD_JSON)
- logger.debug("on_send_knock_request: pdu sigs: %s", pdu.signatures)
+ if event.content.get("membership") != membership_type:
+ raise SynapseError(400, "Not a %s event" % membership_type, Codes.BAD_JSON)
- pdu = await self._check_sigs_and_hash(room_version, pdu)
+ origin_host, _ = parse_server_name(origin)
+ await self.check_server_matches_acl(origin_host, event.room_id)
- # Handle the event, and retrieve the EventContext
- event_context = await self.handler.on_send_knock_request(origin, pdu)
+ logger.debug("_on_send_membership_event: pdu sigs: %s", event.signatures)
- # Retrieve stripped state events from the room and send them back to the remote
- # server. This will allow the remote server's clients to display information
- # related to the room while the knock request is pending.
- stripped_room_state = (
- await self.store.get_stripped_room_state_from_event_context(
- event_context, self._room_prejoin_state_types
- )
- )
- return {"knock_state_events": stripped_room_state}
+ event = await self._check_sigs_and_hash(room_version, event)
+
+ return await self.handler.on_send_membership_event(origin, event)
async def on_event_auth(
self, origin: str, room_id: str, event_id: str
@@ -860,32 +923,39 @@ class FederationServer(FederationBase):
room_id: str,
room_version: RoomVersion,
lock: Lock,
- latest_origin: str,
- latest_event: EventBase,
+ latest_origin: Optional[str] = None,
+ latest_event: Optional[EventBase] = None,
) -> None:
"""Process events in the staging area for the given room.
The latest_origin and latest_event args are the latest origin and event
- received.
+ received (or None to simply pull the next event from the database).
"""
# The common path is for the event we just received to be the only event in
# the room, so instead of pulling the event out of the DB and parsing
# the event we just pull out the next event ID and check if that matches.
- next_origin, next_event_id = await self.store.get_next_staged_event_id_for_room(
- room_id
- )
- if next_origin == latest_origin and next_event_id == latest_event.event_id:
- origin = latest_origin
- event = latest_event
- else:
+ if latest_event is not None and latest_origin is not None:
+ (
+ next_origin,
+ next_event_id,
+ ) = await self.store.get_next_staged_event_id_for_room(room_id)
+ if next_origin != latest_origin or next_event_id != latest_event.event_id:
+ latest_origin = None
+ latest_event = None
+
+ if latest_origin is None or latest_event is None:
next = await self.store.get_next_staged_event_for_room(
room_id, room_version
)
if not next:
+ await lock.release()
return
origin, event = next
+ else:
+ origin = latest_origin
+ event = latest_event
# We loop round until there are no more events in the room in the
# staging area, or we fail to get the lock (which means another process
@@ -909,9 +979,13 @@ class FederationServer(FederationBase):
exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore
)
- await self.store.remove_received_event_from_staging(
+ received_ts = await self.store.remove_received_event_from_staging(
origin, event.event_id
)
+ if received_ts is not None:
+ pdu_process_time.observe(
+ (self._clock.time_msec() - received_ts) / 1000
+ )
# We need to do this check outside the lock to avoid a race between
# a new event being inserted by another instance and it attempting
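
For reference, the checks `_on_send_membership_event` performs before handing the event to the federation handler, condensed into a sketch (assert-style pseudocode, not the actual implementation):

```python
def validate_membership_request(content, room_id, membership_type, room_version):
    # 1. The room_id in the request path must match the event body.
    assert content["room_id"] == room_id
    # 2. Knocks require a room version with MSC2403 support.
    if membership_type == "knock":
        assert room_version.msc2403_knocking
    # 3. It must be an m.room.member state event of the expected membership.
    assert content["type"] == "m.room.member" and "state_key" in content
    assert content["content"]["membership"] == membership_type
    # 4. Then: server ACL check, signature/hash check, and finally
    #    handler.on_send_membership_event(origin, event).
```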
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index a9942b41fb..5685a71a4b 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -15,7 +15,19 @@
import functools
import logging
import re
-from typing import Container, Mapping, Optional, Sequence, Tuple, Type
+from typing import (
+ Container,
+ Dict,
+ List,
+ Mapping,
+ Optional,
+ Sequence,
+ Tuple,
+ Type,
+ Union,
+)
+
+from typing_extensions import Literal
import synapse
from synapse.api.constants import MAX_GROUP_CATEGORYID_LENGTH, MAX_GROUP_ROLEID_LENGTH
@@ -57,15 +69,15 @@ logger = logging.getLogger(__name__)
class TransportLayerServer(JsonResource):
"""Handles incoming federation HTTP requests"""
- def __init__(self, hs, servlet_groups=None):
+ def __init__(self, hs: HomeServer, servlet_groups: Optional[List[str]] = None):
"""Initialize the TransportLayerServer
Will by default register all servlets. For custom behaviour, pass in
a list of servlet_groups to register.
Args:
- hs (synapse.server.HomeServer): homeserver
- servlet_groups (list[str], optional): List of servlet groups to register.
+ hs: homeserver
+ servlet_groups: List of servlet groups to register.
Defaults to ``DEFAULT_SERVLET_GROUPS``.
"""
self.hs = hs
@@ -79,7 +91,7 @@ class TransportLayerServer(JsonResource):
self.register_servlets()
- def register_servlets(self):
+ def register_servlets(self) -> None:
register_servlets(
self.hs,
resource=self,
@@ -92,14 +104,10 @@ class TransportLayerServer(JsonResource):
class AuthenticationError(SynapseError):
"""There was a problem authenticating the request"""
- pass
-
class NoAuthenticationError(AuthenticationError):
"""The request had no authentication information"""
- pass
-
class Authenticator:
def __init__(self, hs: HomeServer):
@@ -411,13 +419,18 @@ class FederationSendServlet(BaseFederationServerServlet):
RATELIMIT = False
# This is when someone is trying to send us a bunch of data.
- async def on_PUT(self, origin, content, query, transaction_id):
+ async def on_PUT(
+ self,
+ origin: str,
+ content: JsonDict,
+ query: Dict[bytes, List[bytes]],
+ transaction_id: str,
+ ) -> Tuple[int, JsonDict]:
"""Called on PUT /send/<transaction_id>/
Args:
- request (twisted.web.http.Request): The HTTP request.
- transaction_id (str): The transaction_id associated with this
- request. This is *not* None.
+ transaction_id: The transaction_id associated with this request. This
+ is *not* None.
Returns:
Tuple of `(code, response)`, where
@@ -462,7 +475,13 @@ class FederationEventServlet(BaseFederationServerServlet):
PATH = "/event/(?P<event_id>[^/]*)/?"
# This is when someone asks for a data item for a given server data_id pair.
- async def on_GET(self, origin, content, query, event_id):
+ async def on_GET(
+ self,
+ origin: str,
+ content: Literal[None],
+ query: Dict[bytes, List[bytes]],
+ event_id: str,
+ ) -> Tuple[int, Union[JsonDict, str]]:
return await self.handler.on_pdu_request(origin, event_id)
@@ -470,7 +489,13 @@ class FederationStateV1Servlet(BaseFederationServerServlet):
PATH = "/state/(?P<room_id>[^/]*)/?"
# This is when someone asks for all data for a given room.
- async def on_GET(self, origin, content, query, room_id):
+ async def on_GET(
+ self,
+ origin: str,
+ content: Literal[None],
+ query: Dict[bytes, List[bytes]],
+ room_id: str,
+ ) -> Tuple[int, JsonDict]:
return await self.handler.on_room_state_request(
origin,
room_id,
@@ -481,7 +506,13 @@ class FederationStateV1Servlet(BaseFederationServerServlet):
class FederationStateIdsServlet(BaseFederationServerServlet):
PATH = "/state_ids/(?P<room_id>[^/]*)/?"
- async def on_GET(self, origin, content, query, room_id):
+ async def on_GET(
+ self,
+ origin: str,
+ content: Literal[None],
+ query: Dict[bytes, List[bytes]],
+ room_id: str,
+ ) -> Tuple[int, JsonDict]:
return await self.handler.on_state_ids_request(
origin,
room_id,
@@ -492,7 +523,13 @@ class FederationStateIdsServlet(BaseFederationServerServlet):
class FederationBackfillServlet(BaseFederationServerServlet):
PATH = "/backfill/(?P<room_id>[^/]*)/?"
- async def on_GET(self, origin, content, query, room_id):
+ async def on_GET(
+ self,
+ origin: str,
+ content: Literal[None],
+ query: Dict[bytes, List[bytes]],
+ room_id: str,
+ ) -> Tuple[int, JsonDict]:
versions = [x.decode("ascii") for x in query[b"v"]]
limit = parse_integer_from_args(query, "limit", None)
@@ -506,7 +543,13 @@ class FederationQueryServlet(BaseFederationServerServlet):
PATH = "/query/(?P<query_type>[^/]*)"
# This is when we receive a server-server Query
- async def on_GET(self, origin, content, query, query_type):
+ async def on_GET(
+ self,
+ origin: str,
+ content: Literal[None],
+ query: Dict[bytes, List[bytes]],
+ query_type: str,
+ ) -> Tuple[int, JsonDict]:
args = {k.decode("utf8"): v[0].decode("utf-8") for k, v in query.items()}
args["origin"] = origin
return await self.handler.on_query_request(query_type, args)
@@ -515,47 +558,66 @@ class FederationQueryServlet(BaseFederationServerServlet):
class FederationMakeJoinServlet(BaseFederationServerServlet):
PATH = "/make_join/(?P<room_id>[^/]*)/(?P<user_id>[^/]*)"
- async def on_GET(self, origin, _content, query, room_id, user_id):
+ async def on_GET(
+ self,
+ origin: str,
+ content: Literal[None],
+ query: Dict[bytes, List[bytes]],
+ room_id: str,
+ user_id: str,
+ ) -> Tuple[int, JsonDict]:
"""
Args:
- origin (unicode): The authenticated server_name of the calling server
+ origin: The authenticated server_name of the calling server
- _content (None): (GETs don't have bodies)
+ content: (GETs don't have bodies)
- query (dict[bytes, list[bytes]]): Query params from the request.
+ query: Query params from the request.
- **kwargs (dict[unicode, unicode]): the dict mapping keys to path
- components as specified in the path match regexp.
+ **kwargs: the dict mapping keys to path components as specified in
+ the path match regexp.
Returns:
- Tuple[int, object]: (response code, response object)
+ Tuple of (response code, response object)
"""
- versions = query.get(b"ver")
- if versions is not None:
- supported_versions = [v.decode("utf-8") for v in versions]
- else:
+ supported_versions = parse_strings_from_args(query, "ver", encoding="utf-8")
+ if supported_versions is None:
supported_versions = ["1"]
- content = await self.handler.on_make_join_request(
+ result = await self.handler.on_make_join_request(
origin, room_id, user_id, supported_versions=supported_versions
)
- return 200, content
+ return 200, result
class FederationMakeLeaveServlet(BaseFederationServerServlet):
PATH = "/make_leave/(?P<room_id>[^/]*)/(?P<user_id>[^/]*)"
- async def on_GET(self, origin, content, query, room_id, user_id):
- content = await self.handler.on_make_leave_request(origin, room_id, user_id)
- return 200, content
+ async def on_GET(
+ self,
+ origin: str,
+ content: Literal[None],
+ query: Dict[bytes, List[bytes]],
+ room_id: str,
+ user_id: str,
+ ) -> Tuple[int, JsonDict]:
+ result = await self.handler.on_make_leave_request(origin, room_id, user_id)
+ return 200, result
class FederationV1SendLeaveServlet(BaseFederationServerServlet):
PATH = "/send_leave/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
- async def on_PUT(self, origin, content, query, room_id, event_id):
- content = await self.handler.on_send_leave_request(origin, content)
- return 200, (200, content)
+ async def on_PUT(
+ self,
+ origin: str,
+ content: JsonDict,
+ query: Dict[bytes, List[bytes]],
+ room_id: str,
+ event_id: str,
+ ) -> Tuple[int, Tuple[int, JsonDict]]:
+ result = await self.handler.on_send_leave_request(origin, content, room_id)
+ return 200, (200, result)
class FederationV2SendLeaveServlet(BaseFederationServerServlet):
@@ -563,50 +625,84 @@ class FederationV2SendLeaveServlet(BaseFederationServerServlet):
PREFIX = FEDERATION_V2_PREFIX
- async def on_PUT(self, origin, content, query, room_id, event_id):
- content = await self.handler.on_send_leave_request(origin, content)
- return 200, content
+ async def on_PUT(
+ self,
+ origin: str,
+ content: JsonDict,
+ query: Dict[bytes, List[bytes]],
+ room_id: str,
+ event_id: str,
+ ) -> Tuple[int, JsonDict]:
+ result = await self.handler.on_send_leave_request(origin, content, room_id)
+ return 200, result
class FederationMakeKnockServlet(BaseFederationServerServlet):
PATH = "/make_knock/(?P<room_id>[^/]*)/(?P<user_id>[^/]*)"
- async def on_GET(self, origin, content, query, room_id, user_id):
- try:
- # Retrieve the room versions the remote homeserver claims to support
- supported_versions = parse_strings_from_args(query, "ver", encoding="utf-8")
- except KeyError:
- raise SynapseError(400, "Missing required query parameter 'ver'")
+ async def on_GET(
+ self,
+ origin: str,
+ content: Literal[None],
+ query: Dict[bytes, List[bytes]],
+ room_id: str,
+ user_id: str,
+ ) -> Tuple[int, JsonDict]:
+ # Retrieve the room versions the remote homeserver claims to support
+ supported_versions = parse_strings_from_args(
+ query, "ver", required=True, encoding="utf-8"
+ )
- content = await self.handler.on_make_knock_request(
+ result = await self.handler.on_make_knock_request(
origin, room_id, user_id, supported_versions=supported_versions
)
- return 200, content
+ return 200, result
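
Unlike make_join, which falls back to ["1"], make_knock now treats a missing "ver" as a client error via required=True rather than catching a KeyError by hand. A sketch of the required branch (the exact error type and message are assumptions; Synapse raises a 400 SynapseError):

def parse_required_strings_sketch(args, name: str, encoding: str = "ascii"):
    values = args.get(name.encode("ascii"))
    if values is None:
        # Synapse raises SynapseError(400, ...) here; ValueError stands in.
        raise ValueError(f"Missing required query parameter {name!r}")
    return [v.decode(encoding) for v in values]
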
class FederationV1SendKnockServlet(BaseFederationServerServlet):
PATH = "/send_knock/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
- async def on_PUT(self, origin, content, query, room_id, event_id):
- content = await self.handler.on_send_knock_request(origin, content, room_id)
- return 200, content
+ async def on_PUT(
+ self,
+ origin: str,
+ content: JsonDict,
+ query: Dict[bytes, List[bytes]],
+ room_id: str,
+ event_id: str,
+ ) -> Tuple[int, JsonDict]:
+ result = await self.handler.on_send_knock_request(origin, content, room_id)
+ return 200, result
class FederationEventAuthServlet(BaseFederationServerServlet):
PATH = "/event_auth/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
- async def on_GET(self, origin, content, query, room_id, event_id):
+ async def on_GET(
+ self,
+ origin: str,
+ content: Literal[None],
+ query: Dict[bytes, List[bytes]],
+ room_id: str,
+ event_id: str,
+ ) -> Tuple[int, JsonDict]:
return await self.handler.on_event_auth(origin, room_id, event_id)
class FederationV1SendJoinServlet(BaseFederationServerServlet):
PATH = "/send_join/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
- async def on_PUT(self, origin, content, query, room_id, event_id):
- # TODO(paul): assert that room_id/event_id parsed from path actually
+ async def on_PUT(
+ self,
+ origin: str,
+ content: JsonDict,
+ query: Dict[bytes, List[bytes]],
+ room_id: str,
+ event_id: str,
+ ) -> Tuple[int, Tuple[int, JsonDict]]:
+ # TODO(paul): assert that event_id parsed from path actually
# match those given in content
- content = await self.handler.on_send_join_request(origin, content)
- return 200, (200, content)
+ result = await self.handler.on_send_join_request(origin, content, room_id)
+ return 200, (200, result)
class FederationV2SendJoinServlet(BaseFederationServerServlet):
@@ -614,28 +710,42 @@ class FederationV2SendJoinServlet(BaseFederationServerServlet):
PREFIX = FEDERATION_V2_PREFIX
- async def on_PUT(self, origin, content, query, room_id, event_id):
- # TODO(paul): assert that room_id/event_id parsed from path actually
+ async def on_PUT(
+ self,
+ origin: str,
+ content: JsonDict,
+ query: Dict[bytes, List[bytes]],
+ room_id: str,
+ event_id: str,
+ ) -> Tuple[int, JsonDict]:
+ # TODO(paul): assert that event_id parsed from path actually
# match those given in content
- content = await self.handler.on_send_join_request(origin, content)
- return 200, content
+ result = await self.handler.on_send_join_request(origin, content, room_id)
+ return 200, result
class FederationV1InviteServlet(BaseFederationServerServlet):
PATH = "/invite/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
- async def on_PUT(self, origin, content, query, room_id, event_id):
+ async def on_PUT(
+ self,
+ origin: str,
+ content: JsonDict,
+ query: Dict[bytes, List[bytes]],
+ room_id: str,
+ event_id: str,
+ ) -> Tuple[int, Tuple[int, JsonDict]]:
# We don't get a room version, so we have to assume it's EITHER v1 or
# v2. This is "fine" as the only difference between V1 and V2 is the
# state resolution algorithm, and we don't use that for processing
# invites
- content = await self.handler.on_invite_request(
+ result = await self.handler.on_invite_request(
origin, content, room_version_id=RoomVersions.V1.identifier
)
# V1 federation API is defined to return a content of `[200, {...}]`
# due to a historical bug.
- return 200, (200, content)
+ return 200, (200, result)
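
The double wrapping is easiest to see on the wire; illustrative bodies only, with event contents elided:

# PUT /_matrix/federation/v1/invite/{roomId}/{eventId}
#   -> HTTP 200, JSON body: [200, {"event": {...}}]   # historical [status, body] wrapper
# PUT /_matrix/federation/v2/invite/{roomId}/{eventId}
#   -> HTTP 200, JSON body: {"event": {...}}          # plain object, no wrapper
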
class FederationV2InviteServlet(BaseFederationServerServlet):
@@ -643,7 +753,14 @@ class FederationV2InviteServlet(BaseFederationServerServlet):
PREFIX = FEDERATION_V2_PREFIX
- async def on_PUT(self, origin, content, query, room_id, event_id):
+ async def on_PUT(
+ self,
+ origin: str,
+ content: JsonDict,
+ query: Dict[bytes, List[bytes]],
+ room_id: str,
+ event_id: str,
+ ) -> Tuple[int, JsonDict]:
# TODO(paul): assert that room_id/event_id parsed from path actually
# match those given in content
@@ -656,16 +773,22 @@ class FederationV2InviteServlet(BaseFederationServerServlet):
event.setdefault("unsigned", {})["invite_room_state"] = invite_room_state
- content = await self.handler.on_invite_request(
+ result = await self.handler.on_invite_request(
origin, event, room_version_id=room_version
)
- return 200, content
+ return 200, result
class FederationThirdPartyInviteExchangeServlet(BaseFederationServerServlet):
PATH = "/exchange_third_party_invite/(?P<room_id>[^/]*)"
- async def on_PUT(self, origin, content, query, room_id):
+ async def on_PUT(
+ self,
+ origin: str,
+ content: JsonDict,
+ query: Dict[bytes, List[bytes]],
+ room_id: str,
+ ) -> Tuple[int, JsonDict]:
await self.handler.on_exchange_third_party_invite_request(content)
return 200, {}
@@ -673,21 +796,31 @@ class FederationThirdPartyInviteExchangeServlet(BaseFederationServerServlet):
class FederationClientKeysQueryServlet(BaseFederationServerServlet):
PATH = "/user/keys/query"
- async def on_POST(self, origin, content, query):
+ async def on_POST(
+ self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]]
+ ) -> Tuple[int, JsonDict]:
return await self.handler.on_query_client_keys(origin, content)
class FederationUserDevicesQueryServlet(BaseFederationServerServlet):
PATH = "/user/devices/(?P<user_id>[^/]*)"
- async def on_GET(self, origin, content, query, user_id):
+ async def on_GET(
+ self,
+ origin: str,
+ content: Literal[None],
+ query: Dict[bytes, List[bytes]],
+ user_id: str,
+ ) -> Tuple[int, JsonDict]:
return await self.handler.on_query_user_devices(origin, user_id)
class FederationClientKeysClaimServlet(BaseFederationServerServlet):
PATH = "/user/keys/claim"
- async def on_POST(self, origin, content, query):
+ async def on_POST(
+ self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]]
+ ) -> Tuple[int, JsonDict]:
response = await self.handler.on_claim_client_keys(origin, content)
return 200, response
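
For orientation, the claim endpoint's request and response are shaped roughly as follows (illustrative identifiers; key objects elided):

# POST /_matrix/federation/v1/user/keys/claim
# request:
#   {"one_time_keys": {"@alice:example.org": {"DEVICEID": "signed_curve25519"}}}
# response:
#   {"one_time_keys": {"@alice:example.org": {"DEVICEID":
#       {"signed_curve25519:AAAAHg": {"key": "...", "signatures": {...}}}}}}
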
@@ -696,12 +829,18 @@ class FederationGetMissingEventsServlet(BaseFederationServerServlet):
# TODO(paul): Why does this path alone end with an optional "/?"?
PATH = "/get_missing_events/(?P<room_id>[^/]*)/?"
- async def on_POST(self, origin, content, query, room_id):
+ async def on_POST(
+ self,
+ origin: str,
+ content: JsonDict,
+ query: Dict[bytes, List[bytes]],
+ room_id: str,
+ ) -> Tuple[int, JsonDict]:
limit = int(content.get("limit", 10))
earliest_events = content.get("earliest_events", [])
latest_events = content.get("latest_events", [])
- content = await self.handler.on_get_missing_events(
+ result = await self.handler.on_get_missing_events(
origin,
room_id=room_id,
earliest_events=earliest_events,
@@ -709,7 +848,7 @@ class FederationGetMissingEventsServlet(BaseFederationServerServlet):
limit=limit,
)
- return 200, content
+ return 200, result
class On3pidBindServlet(BaseFederationServerServlet):
@@ -717,7 +856,9 @@ class On3pidBindServlet(BaseFederationServerServlet):
REQUIRE_AUTH = False
- async def on_POST(self, origin, content, query):
+ async def on_POST(
+ self, origin: Optional[str], content: JsonDict, query: Dict[bytes, List[bytes]]
+ ) -> Tuple[int, JsonDict]:
if "invites" in content:
last_exception = None
for invite in content["invites"]:
@@ -763,15 +904,20 @@ class OpenIdUserInfo(BaseFederationServerServlet):
REQUIRE_AUTH = False
- async def on_GET(self, origin, content, query):
- token = query.get(b"access_token", [None])[0]
+ async def on_GET(
+ self,
+ origin: Optional[str],
+ content: Literal[None],
+ query: Dict[bytes, List[bytes]],
+ ) -> Tuple[int, JsonDict]:
+ token = parse_string_from_args(query, "access_token")
if token is None:
return (
401,
{"errcode": "M_MISSING_TOKEN", "error": "Access Token required"},
)
- user_id = await self.handler.on_openid_userinfo(token.decode("ascii"))
+ user_id = await self.handler.on_openid_userinfo(token)
if user_id is None:
return (
@@ -830,7 +976,9 @@ class PublicRoomList(BaseFederationServlet):
self.handler = hs.get_room_list_handler()
self.allow_access = allow_access
- async def on_GET(self, origin, content, query):
+ async def on_GET(
+ self, origin: str, content: Literal[None], query: Dict[bytes, List[bytes]]
+ ) -> Tuple[int, JsonDict]:
if not self.allow_access:
raise FederationDeniedError(origin)
@@ -859,7 +1007,9 @@ class PublicRoomList(BaseFederationServlet):
)
return 200, data
- async def on_POST(self, origin, content, query):
+ async def on_POST(
+ self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]]
+ ) -> Tuple[int, JsonDict]:
# This implements MSC2197 (Search Filtering over Federation)
if not self.allow_access:
raise FederationDeniedError(origin)
@@ -956,7 +1106,12 @@ class FederationVersionServlet(BaseFederationServlet):
REQUIRE_AUTH = False
- async def on_GET(self, origin, content, query):
+ async def on_GET(
+ self,
+ origin: Optional[str],
+ content: Literal[None],
+ query: Dict[bytes, List[bytes]],
+ ) -> Tuple[int, JsonDict]:
return (
200,
{"server": {"name": "Synapse", "version": get_version_string(synapse)}},
@@ -985,7 +1140,13 @@ class FederationGroupsProfileServlet(BaseGroupsServerServlet):
PATH = "/groups/(?P<group_id>[^/]*)/profile"
- async def on_GET(self, origin, content, query, group_id):
+ async def on_GET(
+ self,
+ origin: str,
+ content: Literal[None],
+ query: Dict[bytes, List[bytes]],
+ group_id: str,
+ ) -> Tuple[int, JsonDict]:
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -994,7 +1155,13 @@ class FederationGroupsProfileServlet(BaseGroupsServerServlet):
return 200, new_content
- async def on_POST(self, origin, content, query, group_id):
+ async def on_POST(
+ self,
+ origin: str,
+ content: JsonDict,
+ query: Dict[bytes, List[bytes]],
+ group_id: str,
+ ) -> Tuple[int, JsonDict]:
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1009,7 +1176,13 @@ class FederationGroupsProfileServlet(BaseGroupsServerServlet):
class FederationGroupsSummaryServlet(BaseGroupsServerServlet):
PATH = "/groups/(?P<group_id>[^/]*)/summary"
- async def on_GET(self, origin, content, query, group_id):
+ async def on_GET(
+ self,
+ origin: str,
+ content: Literal[None],
+ query: Dict[bytes, List[bytes]],
+ group_id: str,
+ ) -> Tuple[int, JsonDict]:
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1024,7 +1197,13 @@ class FederationGroupsRoomsServlet(BaseGroupsServerServlet):
PATH = "/groups/(?P<group_id>[^/]*)/rooms"
- async def on_GET(self, origin, content, query, group_id):
+ async def on_GET(
+ self,
+ origin: str,
+ content: Literal[None],
+ query: Dict[bytes, List[bytes]],
+ group_id: str,
+ ) -> Tuple[int, JsonDict]:
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1039,7 +1218,14 @@ class FederationGroupsAddRoomsServlet(BaseGroupsServerServlet):
PATH = "/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)"
- async def on_POST(self, origin, content, query, group_id, room_id):
+ async def on_POST(
+ self,
+ origin: str,
+ content: JsonDict,
+ query: Dict[bytes, List[bytes]],
+ group_id: str,
+ room_id: str,
+ ) -> Tuple[int, JsonDict]:
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1050,7 +1236,14 @@ class FederationGroupsAddRoomsServlet(BaseGroupsServerServlet):
return 200, new_content
- async def on_DELETE(self, origin, content, query, group_id, room_id):
+ async def on_DELETE(
+ self,
+ origin: str,
+ content: Literal[None],
+ query: Dict[bytes, List[bytes]],
+ group_id: str,
+ room_id: str,
+ ) -> Tuple[int, JsonDict]:
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1070,7 +1263,15 @@ class FederationGroupsAddRoomsConfigServlet(BaseGroupsServerServlet):
"/config/(?P<config_key>[^/]*)"
)
- async def on_POST(self, origin, content, query, group_id, room_id, config_key):
+ async def on_POST(
+ self,
+ origin: str,
+ content: JsonDict,
+ query: Dict[bytes, List[bytes]],
+ group_id: str,
+ room_id: str,
+ config_key: str,
+ ) -> Tuple[int, JsonDict]:
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1087,7 +1288,13 @@ class FederationGroupsUsersServlet(BaseGroupsServerServlet):
PATH = "/groups/(?P<group_id>[^/]*)/users"
- async def on_GET(self, origin, content, query, group_id):
+ async def on_GET(
+ self,
+ origin: str,
+ content: Literal[None],
+ query: Dict[bytes, List[bytes]],
+ group_id: str,
+ ) -> Tuple[int, JsonDict]:
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1102,7 +1309,13 @@ class FederationGroupsInvitedUsersServlet(BaseGroupsServerServlet):
PATH = "/groups/(?P<group_id>[^/]*)/invited_users"
- async def on_GET(self, origin, content, query, group_id):
+ async def on_GET(
+ self,
+ origin: str,
+ content: Literal[None],
+ query: Dict[bytes, List[bytes]],
+ group_id: str,
+ ) -> Tuple[int, JsonDict]:
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1119,7 +1332,14 @@ class FederationGroupsInviteServlet(BaseGroupsServerServlet):
PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite"
- async def on_POST(self, origin, content, query, group_id, user_id):
+ async def on_POST(
+ self,
+ origin: str,
+ content: JsonDict,
+ query: Dict[bytes, List[bytes]],
+ group_id: str,
+ user_id: str,
+ ) -> Tuple[int, JsonDict]:
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1136,7 +1356,14 @@ class FederationGroupsAcceptInviteServlet(BaseGroupsServerServlet):
PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/accept_invite"
- async def on_POST(self, origin, content, query, group_id, user_id):
+ async def on_POST(
+ self,
+ origin: str,
+ content: JsonDict,
+ query: Dict[bytes, List[bytes]],
+ group_id: str,
+ user_id: str,
+ ) -> Tuple[int, JsonDict]:
if get_domain_from_id(user_id) != origin:
raise SynapseError(403, "user_id doesn't match origin")
@@ -1150,7 +1377,14 @@ class FederationGroupsJoinServlet(BaseGroupsServerServlet):
PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/join"
- async def on_POST(self, origin, content, query, group_id, user_id):
+ async def on_POST(
+ self,
+ origin: str,
+ content: JsonDict,
+ query: Dict[bytes, List[bytes]],
+ group_id: str,
+ user_id: str,
+ ) -> Tuple[int, JsonDict]:
if get_domain_from_id(user_id) != origin:
raise SynapseError(403, "user_id doesn't match origin")
@@ -1164,7 +1398,14 @@ class FederationGroupsRemoveUserServlet(BaseGroupsServerServlet):
PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove"
- async def on_POST(self, origin, content, query, group_id, user_id):
+ async def on_POST(
+ self,
+ origin: str,
+ content: JsonDict,
+ query: Dict[bytes, List[bytes]],
+ group_id: str,
+ user_id: str,
+ ) -> Tuple[int, JsonDict]:
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1198,7 +1439,14 @@ class FederationGroupsLocalInviteServlet(BaseGroupsLocalServlet):
PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite"
- async def on_POST(self, origin, content, query, group_id, user_id):
+ async def on_POST(
+ self,
+ origin: str,
+ content: JsonDict,
+ query: Dict[bytes, List[bytes]],
+ group_id: str,
+ user_id: str,
+ ) -> Tuple[int, JsonDict]:
if get_domain_from_id(group_id) != origin:
raise SynapseError(403, "group_id doesn't match origin")
@@ -1216,7 +1464,14 @@ class FederationGroupsRemoveLocalUserServlet(BaseGroupsLocalServlet):
PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove"
- async def on_POST(self, origin, content, query, group_id, user_id):
+ async def on_POST(
+ self,
+ origin: str,
+ content: JsonDict,
+ query: Dict[bytes, List[bytes]],
+ group_id: str,
+ user_id: str,
+ ) -> Tuple[int, None]:
if get_domain_from_id(group_id) != origin:
raise SynapseError(403, "user_id doesn't match origin")
@@ -1224,11 +1479,9 @@ class FederationGroupsRemoveLocalUserServlet(BaseGroupsLocalServlet):
self.handler, GroupsLocalHandler
), "Workers cannot handle group removals."
- new_content = await self.handler.user_removed_from_group(
- group_id, user_id, content
- )
+ await self.handler.user_removed_from_group(group_id, user_id, content)
- return 200, new_content
+ return 200, None
class FederationGroupsRenewAttestaionServlet(BaseFederationServlet):
@@ -1246,7 +1499,14 @@ class FederationGroupsRenewAttestaionServlet(BaseFederationServlet):
super().__init__(hs, authenticator, ratelimiter, server_name)
self.handler = hs.get_groups_attestation_renewer()
- async def on_POST(self, origin, content, query, group_id, user_id):
+ async def on_POST(
+ self,
+ origin: str,
+ content: JsonDict,
+ query: Dict[bytes, List[bytes]],
+ group_id: str,
+ user_id: str,
+ ) -> Tuple[int, JsonDict]:
# We don't need to check auth here as we check the attestation signatures
new_content = await self.handler.on_renew_attestation(
@@ -1270,7 +1530,15 @@ class FederationGroupsSummaryRoomsServlet(BaseGroupsServerServlet):
"/rooms/(?P<room_id>[^/]*)"
)
- async def on_POST(self, origin, content, query, group_id, category_id, room_id):
+ async def on_POST(
+ self,
+ origin: str,
+ content: JsonDict,
+ query: Dict[bytes, List[bytes]],
+ group_id: str,
+ category_id: str,
+ room_id: str,
+ ) -> Tuple[int, JsonDict]:
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1298,7 +1566,15 @@ class FederationGroupsSummaryRoomsServlet(BaseGroupsServerServlet):
return 200, resp
- async def on_DELETE(self, origin, content, query, group_id, category_id, room_id):
+ async def on_DELETE(
+ self,
+ origin: str,
+ content: Literal[None],
+ query: Dict[bytes, List[bytes]],
+ group_id: str,
+ category_id: str,
+ room_id: str,
+ ) -> Tuple[int, JsonDict]:
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1318,7 +1594,13 @@ class FederationGroupsCategoriesServlet(BaseGroupsServerServlet):
PATH = "/groups/(?P<group_id>[^/]*)/categories/?"
- async def on_GET(self, origin, content, query, group_id):
+ async def on_GET(
+ self,
+ origin: str,
+ content: Literal[None],
+ query: Dict[bytes, List[bytes]],
+ group_id: str,
+ ) -> Tuple[int, JsonDict]:
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1333,7 +1615,14 @@ class FederationGroupsCategoryServlet(BaseGroupsServerServlet):
PATH = "/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)"
- async def on_GET(self, origin, content, query, group_id, category_id):
+ async def on_GET(
+ self,
+ origin: str,
+ content: Literal[None],
+ query: Dict[bytes, List[bytes]],
+ group_id: str,
+ category_id: str,
+ ) -> Tuple[int, JsonDict]:
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1344,7 +1633,14 @@ class FederationGroupsCategoryServlet(BaseGroupsServerServlet):
return 200, resp
- async def on_POST(self, origin, content, query, group_id, category_id):
+ async def on_POST(
+ self,
+ origin: str,
+ content: JsonDict,
+ query: Dict[bytes, List[bytes]],
+ group_id: str,
+ category_id: str,
+ ) -> Tuple[int, JsonDict]:
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1366,7 +1662,14 @@ class FederationGroupsCategoryServlet(BaseGroupsServerServlet):
return 200, resp
- async def on_DELETE(self, origin, content, query, group_id, category_id):
+ async def on_DELETE(
+ self,
+ origin: str,
+ content: Literal[None],
+ query: Dict[bytes, List[bytes]],
+ group_id: str,
+ category_id: str,
+ ) -> Tuple[int, JsonDict]:
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1386,7 +1689,13 @@ class FederationGroupsRolesServlet(BaseGroupsServerServlet):
PATH = "/groups/(?P<group_id>[^/]*)/roles/?"
- async def on_GET(self, origin, content, query, group_id):
+ async def on_GET(
+ self,
+ origin: str,
+ content: Literal[None],
+ query: Dict[bytes, List[bytes]],
+ group_id: str,
+ ) -> Tuple[int, JsonDict]:
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1401,7 +1710,14 @@ class FederationGroupsRoleServlet(BaseGroupsServerServlet):
PATH = "/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)"
- async def on_GET(self, origin, content, query, group_id, role_id):
+ async def on_GET(
+ self,
+ origin: str,
+ content: Literal[None],
+ query: Dict[bytes, List[bytes]],
+ group_id: str,
+ role_id: str,
+ ) -> Tuple[int, JsonDict]:
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1410,7 +1726,14 @@ class FederationGroupsRoleServlet(BaseGroupsServerServlet):
return 200, resp
- async def on_POST(self, origin, content, query, group_id, role_id):
+ async def on_POST(
+ self,
+ origin: str,
+ content: JsonDict,
+ query: Dict[bytes, List[bytes]],
+ group_id: str,
+ role_id: str,
+ ) -> Tuple[int, JsonDict]:
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1434,7 +1757,14 @@ class FederationGroupsRoleServlet(BaseGroupsServerServlet):
return 200, resp
- async def on_DELETE(self, origin, content, query, group_id, role_id):
+ async def on_DELETE(
+ self,
+ origin: str,
+ content: Literal[None],
+ query: Dict[bytes, List[bytes]],
+ group_id: str,
+ role_id: str,
+ ) -> Tuple[int, JsonDict]:
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1463,7 +1793,15 @@ class FederationGroupsSummaryUsersServlet(BaseGroupsServerServlet):
"/users/(?P<user_id>[^/]*)"
)
- async def on_POST(self, origin, content, query, group_id, role_id, user_id):
+ async def on_POST(
+ self,
+ origin: str,
+ content: JsonDict,
+ query: Dict[bytes, List[bytes]],
+ group_id: str,
+ role_id: str,
+ user_id: str,
+ ) -> Tuple[int, JsonDict]:
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1489,7 +1827,15 @@ class FederationGroupsSummaryUsersServlet(BaseGroupsServerServlet):
return 200, resp
- async def on_DELETE(self, origin, content, query, group_id, role_id, user_id):
+ async def on_DELETE(
+ self,
+ origin: str,
+ content: Literal[None],
+ query: Dict[bytes, List[bytes]],
+ group_id: str,
+ role_id: str,
+ user_id: str,
+ ) -> Tuple[int, JsonDict]:
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1509,7 +1855,9 @@ class FederationGroupsBulkPublicisedServlet(BaseGroupsLocalServlet):
PATH = "/get_groups_publicised"
- async def on_POST(self, origin, content, query):
+ async def on_POST(
+ self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]]
+ ) -> Tuple[int, JsonDict]:
resp = await self.handler.bulk_get_publicised_groups(
content["user_ids"], proxy=False
)
@@ -1522,7 +1870,13 @@ class FederationGroupsSettingJoinPolicyServlet(BaseGroupsServerServlet):
PATH = "/groups/(?P<group_id>[^/]*)/settings/m.join_policy"
- async def on_PUT(self, origin, content, query, group_id):
+ async def on_PUT(
+ self,
+ origin: str,
+ content: JsonDict,
+ query: Dict[bytes, List[bytes]],
+ group_id: str,
+ ) -> Tuple[int, JsonDict]:
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1551,7 +1905,7 @@ class FederationSpaceSummaryServlet(BaseFederationServlet):
async def on_GET(
self,
origin: str,
- content: JsonDict,
+ content: Literal[None],
query: Mapping[bytes, Sequence[bytes]],
room_id: str,
) -> Tuple[int, JsonDict]:
@@ -1623,7 +1977,13 @@ class RoomComplexityServlet(BaseFederationServlet):
super().__init__(hs, authenticator, ratelimiter, server_name)
self._store = self.hs.get_datastore()
- async def on_GET(self, origin, content, query, room_id):
+ async def on_GET(
+ self,
+ origin: str,
+ content: Literal[None],
+ query: Dict[bytes, List[bytes]],
+ room_id: str,
+ ) -> Tuple[int, JsonDict]:
is_public = await self._store.is_room_world_readable_or_publicly_joinable(
room_id
)
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index f72ded038e..d75a8b15c3 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -62,9 +62,16 @@ class AdminHandler(BaseHandler):
if ret:
profile = await self.store.get_profileinfo(user.localpart)
threepids = await self.store.user_get_threepids(user.to_string())
+ external_ids = [
+ ({"auth_provider": auth_provider, "external_id": external_id})
+ for auth_provider, external_id in await self.store.get_external_ids_by_user(
+ user.to_string()
+ )
+ ]
ret["displayname"] = profile.display_name
ret["avatar_url"] = profile.avatar_url
ret["threepids"] = threepids
+ ret["external_ids"] = external_ids
return ret
async def export_user_data(self, user_id: str, writer: "ExfiltrationWriter") -> Any:
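
With this change the admin user-details response gains an external_ids list alongside threepids. An abridged, illustrative payload:

# GET /_synapse/admin/v2/users/<user_id>  (abridged)
# {
#     "displayname": "Alice",
#     "threepids": [{"medium": "email", "address": "alice@example.com"}],
#     "external_ids": [{"auth_provider": "oidc", "external_id": "alice"}]
# }
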
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 1971e373ed..e2ac595a62 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -30,6 +30,7 @@ from typing import (
Optional,
Tuple,
Union,
+ cast,
)
import attr
@@ -72,6 +73,7 @@ from synapse.util.stringutils import base62_encode
from synapse.util.threepids import canonicalise_email
if TYPE_CHECKING:
+ from synapse.rest.client.v1.login import LoginResponse
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -777,6 +779,108 @@ class AuthHandler(BaseHandler):
"params": params,
}
+ async def refresh_token(
+ self,
+ refresh_token: str,
+ valid_until_ms: Optional[int],
+ ) -> Tuple[str, str]:
+ """
+ Consumes a refresh token and generates both a new access token and a new refresh token from it.
+
+ The consumed refresh token is considered invalid after the first use of the new access token or the new refresh token.
+
+ Args:
+ refresh_token: The token to consume.
+ valid_until_ms: The expiration timestamp of the new access token.
+
+ Returns:
+ A tuple containing the new access token and refresh token
+ """
+
+ # Verify the token signature first before looking up the token
+ if not self._verify_refresh_token(refresh_token):
+ raise SynapseError(401, "invalid refresh token", Codes.UNKNOWN_TOKEN)
+
+ existing_token = await self.store.lookup_refresh_token(refresh_token)
+ if existing_token is None:
+ raise SynapseError(401, "refresh token does not exist", Codes.UNKNOWN_TOKEN)
+
+ if (
+ existing_token.has_next_access_token_been_used
+ or existing_token.has_next_refresh_token_been_refreshed
+ ):
+ raise SynapseError(
+ 403, "refresh token isn't valid anymore", Codes.FORBIDDEN
+ )
+
+ (
+ new_refresh_token,
+ new_refresh_token_id,
+ ) = await self.get_refresh_token_for_user_id(
+ user_id=existing_token.user_id, device_id=existing_token.device_id
+ )
+ access_token = await self.get_access_token_for_user_id(
+ user_id=existing_token.user_id,
+ device_id=existing_token.device_id,
+ valid_until_ms=valid_until_ms,
+ refresh_token_id=new_refresh_token_id,
+ )
+ await self.store.replace_refresh_token(
+ existing_token.token_id, new_refresh_token_id
+ )
+ return access_token, new_refresh_token
+
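
A hedged sketch of how a caller could drive this rotation; only refresh_token()'s signature comes from the code above, and the surrounding names are assumptions:

async def handle_refresh(auth_handler, now_ms: int, lifetime_ms: int, token: str):
    # Mint a new access/refresh pair; the consumed token stays valid until
    # either new credential is first used, after which replaying it is a 403.
    access_token, new_refresh_token = await auth_handler.refresh_token(
        token, valid_until_ms=now_ms + lifetime_ms
    )
    return {
        "access_token": access_token,
        "refresh_token": new_refresh_token,
        "expires_in_ms": lifetime_ms,
    }
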
+ def _verify_refresh_token(self, token: str) -> bool:
+ """
+ Verifies the shape of a refresh token.
+
+ Args:
+ token: The refresh token to verify
+
+ Returns:
+ Whether the token has the right shape
+ """
+ parts = token.split("_", maxsplit=4)
+ if len(parts) != 4:
+ return False
+
+ type, localpart, rand, crc = parts
+
+ # Refresh tokens are prefixed by "syr_", let's check that
+ if type != "syr":
+ return False
+
+ # Check the CRC
+ base = f"{type}_{localpart}_{rand}"
+ expected_crc = base62_encode(crc32(base.encode("ascii")), minwidth=6)
+ if crc != expected_crc:
+ return False
+
+ return True
+
+ async def get_refresh_token_for_user_id(
+ self,
+ user_id: str,
+ device_id: str,
+ ) -> Tuple[str, int]:
+ """
+ Creates a new refresh token for the user with the given user ID.
+
+ Args:
+ user_id: canonical user ID
+ device_id: the device ID to associate with the token.
+
+ Returns:
+ The newly created refresh token and its ID in the database
+ """
+ refresh_token = self.generate_refresh_token(UserID.from_string(user_id))
+ refresh_token_id = await self.store.add_refresh_token_to_user(
+ user_id=user_id,
+ token=refresh_token,
+ device_id=device_id,
+ )
+ return refresh_token, refresh_token_id
+
async def get_access_token_for_user_id(
self,
user_id: str,
@@ -784,6 +888,7 @@ class AuthHandler(BaseHandler):
valid_until_ms: Optional[int],
puppets_user_id: Optional[str] = None,
is_appservice_ghost: bool = False,
+ refresh_token_id: Optional[int] = None,
) -> str:
"""
Creates a new access token for the user with the given user ID.
@@ -801,6 +906,8 @@ class AuthHandler(BaseHandler):
valid_until_ms: when the token is valid until. None for
no expiry.
is_appservice_ghost: Whether the user is an application ghost user
+ refresh_token_id: the refresh token ID that will be associated with
+ this access token.
Returns:
The access token for the user's session.
Raises:
@@ -836,6 +943,7 @@ class AuthHandler(BaseHandler):
device_id=device_id,
valid_until_ms=valid_until_ms,
puppets_user_id=puppets_user_id,
+ refresh_token_id=refresh_token_id,
)
# the device *should* have been registered before we got here; however,
@@ -928,7 +1036,7 @@ class AuthHandler(BaseHandler):
self,
login_submission: Dict[str, Any],
ratelimit: bool = False,
- ) -> Tuple[str, Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
+ ) -> Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]:
"""Authenticates the user for the /login API
Also used by the user-interactive auth flow to validate auth types which don't
@@ -1073,7 +1181,7 @@ class AuthHandler(BaseHandler):
self,
username: str,
login_submission: Dict[str, Any],
- ) -> Tuple[str, Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
+ ) -> Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]:
"""Helper for validate_login
Handles login, once we've mapped 3pids onto userids
@@ -1151,7 +1259,7 @@ class AuthHandler(BaseHandler):
async def check_password_provider_3pid(
self, medium: str, address: str, password: str
- ) -> Tuple[Optional[str], Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
+ ) -> Tuple[Optional[str], Optional[Callable[["LoginResponse"], Awaitable[None]]]]:
"""Check if a password provider is able to validate a thirdparty login
Args:
@@ -1215,6 +1323,19 @@ class AuthHandler(BaseHandler):
crc = base62_encode(crc32(base.encode("ascii")), minwidth=6)
return f"{base}_{crc}"
+ def generate_refresh_token(self, for_user: UserID) -> str:
+ """Generates an opaque string, for use as a refresh token"""
+
+ # we use the following format for refresh tokens:
+ # syr_<base64 local part>_<random string>_<base62 crc check>
+
+ b64local = unpaddedbase64.encode_base64(for_user.localpart.encode("utf-8"))
+ random_string = stringutils.random_string(20)
+ base = f"syr_{b64local}_{random_string}"
+
+ crc = base62_encode(crc32(base.encode("ascii")), minwidth=6)
+ return f"{base}_{crc}"
+
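
A self-contained sketch of the token shape that generate_refresh_token and _verify_refresh_token agree on. The helpers below stand in for Synapse's stringutils (the base62 alphabet and padding are assumptions, but generation and verification are consistent with each other):

import base64
import secrets
import string
from zlib import crc32

_ALPHABET = string.digits + string.ascii_uppercase + string.ascii_lowercase  # assumed order

def _b62(num: int, minwidth: int = 6) -> str:
    out = ""
    while num:
        num, rem = divmod(num, 62)
        out = _ALPHABET[rem] + out
    return out.rjust(minwidth, "0")

def make_token(localpart: str) -> str:
    # syr_<unpadded base64 localpart>_<random>_<crc>; none of the first three
    # segments can contain "_", so the split in check_shape is unambiguous.
    b64local = base64.b64encode(localpart.encode("utf-8")).decode("ascii").rstrip("=")
    base = f"syr_{b64local}_{secrets.token_hex(10)}"
    return f"{base}_{_b62(crc32(base.encode('ascii')))}"

def check_shape(token: str) -> bool:
    parts = token.split("_", maxsplit=4)
    if len(parts) != 4:
        return False
    prefix, localpart, rand, crc = parts
    base = f"{prefix}_{localpart}_{rand}"
    return prefix == "syr" and crc == _b62(crc32(base.encode("ascii")))

assert check_shape(make_token("alice"))
assert not check_shape("syt_not_a_refresh_token")
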
async def validate_short_term_login_token(
self, login_token: str
) -> LoginTokenAttributes:
@@ -1563,7 +1684,7 @@ class AuthHandler(BaseHandler):
)
respond_with_html(request, 200, html)
- async def _sso_login_callback(self, login_result: JsonDict) -> None:
+ async def _sso_login_callback(self, login_result: "LoginResponse") -> None:
"""
A login callback which might add additional attributes to the login response.
@@ -1577,7 +1698,8 @@ class AuthHandler(BaseHandler):
extra_attributes = self._extra_attributes.get(login_result["user_id"])
if extra_attributes:
- login_result.update(extra_attributes.extra_attributes)
+ login_result_dict = cast(Dict[str, Any], login_result)
+ login_result_dict.update(extra_attributes.extra_attributes)
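
The cast is needed because mypy only allows TypedDict.update() with keys it already knows about; viewing the value as a plain dict sidesteps that at no runtime cost. A minimal illustration (the stand-in TypedDict below is not LoginResponse's real definition):

from typing import Any, Dict, cast
from typing_extensions import TypedDict

class LoginResponseSketch(TypedDict, total=False):
    user_id: str
    access_token: str

def merge_extras(resp: LoginResponseSketch, extras: Dict[str, Any]) -> None:
    # resp.update(extras) is a type error: mypy cannot prove the extra keys
    # belong to the TypedDict. At runtime it is still an ordinary dict:
    cast(Dict[str, Any], resp).update(extras)
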
def _expire_sso_extra_attributes(self) -> None:
"""
diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py
index 989996b628..41dbdfd0a1 100644
--- a/synapse/handlers/event_auth.py
+++ b/synapse/handlers/event_auth.py
@@ -11,8 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import TYPE_CHECKING, Collection, Optional
+from typing import TYPE_CHECKING, Collection, List, Optional, Union
+from synapse import event_auth
from synapse.api.constants import (
EventTypes,
JoinRules,
@@ -20,9 +21,11 @@ from synapse.api.constants import (
RestrictedJoinRuleTypes,
)
from synapse.api.errors import AuthError
-from synapse.api.room_versions import RoomVersion
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
from synapse.events import EventBase
+from synapse.events.builder import EventBuilder
from synapse.types import StateMap
+from synapse.util.metrics import Measure
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -34,8 +37,63 @@ class EventAuthHandler:
"""
def __init__(self, hs: "HomeServer"):
+ self._clock = hs.get_clock()
self._store = hs.get_datastore()
+ async def check_from_context(
+ self, room_version: str, event, context, do_sig_check=True
+ ) -> None:
+ auth_event_ids = event.auth_event_ids()
+ auth_events_by_id = await self._store.get_events(auth_event_ids)
+ auth_events = {(e.type, e.state_key): e for e in auth_events_by_id.values()}
+
+ room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
+ event_auth.check(
+ room_version_obj, event, auth_events=auth_events, do_sig_check=do_sig_check
+ )
+
+ def compute_auth_events(
+ self,
+ event: Union[EventBase, EventBuilder],
+ current_state_ids: StateMap[str],
+ for_verification: bool = False,
+ ) -> List[str]:
+ """Given an event and current state return the list of event IDs used
+ to auth an event.
+
+ If `for_verification` is False then only return auth events that
+ should be added to the event's `auth_events`.
+
+ Returns:
+ List of event IDs.
+ """
+
+ if event.type == EventTypes.Create:
+ return []
+
+ # Currently we ignore the `for_verification` flag even though there are
+ # some situations where we can drop particular auth events when adding
+ # to the event's `auth_events` (e.g. joins pointing to previous joins
+ # when room is publicly joinable). Dropping event IDs has the
+ # advantage that the auth chain for the room grows slower, but we use
+ # the auth chain in state resolution v2 to order events, which means
+ # care must be taken if dropping events to ensure that it doesn't
+ # introduce undesirable "state reset" behaviour.
+ #
+ # All of which sounds a bit tricky so we don't bother for now.
+
+ auth_ids = []
+ for etype, state_key in event_auth.auth_types_for_event(event):
+ auth_ev_id = current_state_ids.get((etype, state_key))
+ if auth_ev_id:
+ auth_ids.append(auth_ev_id)
+
+ return auth_ids
+
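
Roughly what the loop above yields for a typical membership event; the event IDs are hypothetical and the real set of (type, state_key) pairs comes from event_auth.auth_types_for_event:

# current_state_ids = {
#     ("m.room.create", ""): "$create",
#     ("m.room.power_levels", ""): "$power",
#     ("m.room.join_rules", ""): "$rules",
#     ("m.room.member", "@alice:example.org"): "$alice",
# }
# compute_auth_events(join_event, current_state_ids)
#     -> ["$create", "$power", "$rules", "$alice"]
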
+ async def check_host_in_room(self, room_id: str, host: str) -> bool:
+ with Measure(self._clock, "check_host_in_room"):
+ return await self._store.is_host_joined(room_id, host)
+
async def check_restricted_join_rules(
self,
state_ids: StateMap[str],
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index aaf209842c..f734ce1861 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -250,7 +250,9 @@ class FederationHandler(BaseHandler):
#
# Note that if we were never in the room then we would have already
# dropped the event, since we wouldn't know the room version.
- is_in_room = await self.auth.check_host_in_room(room_id, self.server_name)
+ is_in_room = await self._event_auth_handler.check_host_in_room(
+ room_id, self.server_name
+ )
if not is_in_room:
logger.info(
"Ignoring PDU from %s as we're not in the room",
@@ -1675,7 +1677,9 @@ class FederationHandler(BaseHandler):
room_version = await self.store.get_room_version_id(room_id)
# now check that we are *still* in the room
- is_in_room = await self.auth.check_host_in_room(room_id, self.server_name)
+ is_in_room = await self._event_auth_handler.check_host_in_room(
+ room_id, self.server_name
+ )
if not is_in_room:
logger.info(
"Got /make_join request for room %s we are no longer in",
@@ -1706,86 +1710,12 @@ class FederationHandler(BaseHandler):
# The remote hasn't signed it yet, obviously. We'll do the full checks
# when we get the event back in `on_send_join_request`
- await self.auth.check_from_context(
+ await self._event_auth_handler.check_from_context(
room_version, event, context, do_sig_check=False
)
return event
- async def on_send_join_request(self, origin: str, pdu: EventBase) -> JsonDict:
- """We have received a join event for a room. Fully process it and
- respond with the current state and auth chains.
- """
- event = pdu
-
- logger.debug(
- "on_send_join_request from %s: Got event: %s, signatures: %s",
- origin,
- event.event_id,
- event.signatures,
- )
-
- if get_domain_from_id(event.sender) != origin:
- logger.info(
- "Got /send_join request for user %r from different origin %s",
- event.sender,
- origin,
- )
- raise SynapseError(403, "User not from origin", Codes.FORBIDDEN)
-
- event.internal_metadata.outlier = False
- # Send this event on behalf of the origin server.
- #
- # The reasons we have the destination server rather than the origin
- # server send it are slightly mysterious: the origin server should have
- # all the necessary state once it gets the response to the send_join,
- # so it could send the event itself if it wanted to. It may be that
- # doing it this way reduces failure modes, or avoids certain attacks
- # where a new server selectively tells a subset of the federation that
- # it has joined.
- #
- # The fact is that, as of the current writing, Synapse doesn't send out
- # the join event over federation after joining, and changing it now
- # would introduce the danger of backwards-compatibility problems.
- event.internal_metadata.send_on_behalf_of = origin
-
- # Calculate the event context.
- context = await self.state_handler.compute_event_context(event)
-
- # Get the state before the new event.
- prev_state_ids = await context.get_prev_state_ids()
-
- # Check if the user is already in the room or invited to the room.
- user_id = event.state_key
- prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None)
- prev_member_event = None
- if prev_member_event_id:
- prev_member_event = await self.store.get_event(prev_member_event_id)
-
- # Check if the member should be allowed access via membership in a space.
- await self._event_auth_handler.check_restricted_join_rules(
- prev_state_ids,
- event.room_version,
- user_id,
- prev_member_event,
- )
-
- # Persist the event.
- await self._auth_and_persist_event(origin, event, context)
-
- logger.debug(
- "on_send_join_request: After _auth_and_persist_event: %s, sigs: %s",
- event.event_id,
- event.signatures,
- )
-
- state_ids = list(prev_state_ids.values())
- auth_chain = await self.store.get_auth_chain(event.room_id, state_ids)
-
- state = await self.store.get_events(list(prev_state_ids.values()))
-
- return {"state": list(state.values()), "auth_chain": auth_chain}
-
async def on_invite_request(
self, origin: str, event: EventBase, room_version: RoomVersion
) -> EventBase:
@@ -1959,7 +1889,7 @@ class FederationHandler(BaseHandler):
try:
# The remote hasn't signed it yet, obviously. We'll do the full checks
# when we get the event back in `on_send_leave_request`
- await self.auth.check_from_context(
+ await self._event_auth_handler.check_from_context(
room_version, event, context, do_sig_check=False
)
except AuthError as e:
@@ -1968,37 +1898,6 @@ class FederationHandler(BaseHandler):
return event
- async def on_send_leave_request(self, origin: str, pdu: EventBase) -> None:
- """We have received a leave event for a room. Fully process it."""
- event = pdu
-
- logger.debug(
- "on_send_leave_request: Got event: %s, signatures: %s",
- event.event_id,
- event.signatures,
- )
-
- if get_domain_from_id(event.sender) != origin:
- logger.info(
- "Got /send_leave request for user %r from different origin %s",
- event.sender,
- origin,
- )
- raise SynapseError(403, "User not from origin", Codes.FORBIDDEN)
-
- event.internal_metadata.outlier = False
-
- context = await self.state_handler.compute_event_context(event)
- await self._auth_and_persist_event(origin, event, context)
-
- logger.debug(
- "on_send_leave_request: After _auth_and_persist_event: %s, sigs: %s",
- event.event_id,
- event.signatures,
- )
-
- return None
-
@log_function
async def on_make_knock_request(
self, origin: str, room_id: str, user_id: str
@@ -2052,7 +1951,7 @@ class FederationHandler(BaseHandler):
try:
# The remote hasn't signed it yet, obviously. We'll do the full checks
# when we get the event back in `on_send_knock_request`
- await self.auth.check_from_context(
+ await self._event_auth_handler.check_from_context(
room_version, event, context, do_sig_check=False
)
except AuthError as e:
@@ -2062,51 +1961,115 @@ class FederationHandler(BaseHandler):
return event
@log_function
- async def on_send_knock_request(
+ async def on_send_membership_event(
self, origin: str, event: EventBase
) -> EventContext:
"""
- We have received a knock event for a room. Verify that event and send it into the room
- on the knocking homeserver's behalf.
+ We have received a join/leave/knock event for a room via send_join/leave/knock.
+
+ Verify that event and send it into the room on the remote homeserver's behalf.
+
+ This is quite similar to on_receive_pdu, with the following principal
+ differences:
+ * only membership events are permitted (and only events with
+ sender==state_key -- ie, no kicks or bans)
+ * *We* send out the event on behalf of the remote server.
+ * We enforce the membership restrictions of restricted rooms.
+ * Rejected events result in an exception rather than being stored.
+
+ There are also other differences; however, it is not clear whether these are
+ by design or by omission. In particular, we do not attempt to backfill any
+ missing prev_events.
Args:
- origin: The remote homeserver of the knocking user.
- event: The knocking member event that has been signed by the remote homeserver.
+ origin: The homeserver of the remote (joining/invited/knocking) user.
+ event: The member event that has been signed by the remote homeserver.
Returns:
The context of the event after inserting it into the room graph.
+
+ Raises:
+ SynapseError if the event is not accepted into the room
"""
logger.debug(
- "on_send_knock_request: Got event: %s, signatures: %s",
+ "on_send_membership_event: Got event: %s, signatures: %s",
event.event_id,
event.signatures,
)
if get_domain_from_id(event.sender) != origin:
logger.info(
- "Got /send_knock request for user %r from different origin %s",
+ "Got send_membership request for user %r from different origin %s",
event.sender,
origin,
)
raise SynapseError(403, "User not from origin", Codes.FORBIDDEN)
- event.internal_metadata.outlier = False
+ if event.sender != event.state_key:
+ raise SynapseError(400, "state_key and sender must match", Codes.BAD_JSON)
- context = await self.state_handler.compute_event_context(event)
+ assert not event.internal_metadata.outlier
- event_allowed = await self.third_party_event_rules.check_event_allowed(
- event, context
- )
- if not event_allowed:
- logger.info("Sending of knock %s forbidden by third-party rules", event)
+ # Send this event on behalf of the other server.
+ #
+ # The remote server isn't a full participant in the room at this point, so
+ # may not have an up-to-date list of the other homeservers participating in
+ # the room, so we send it on their behalf.
+ event.internal_metadata.send_on_behalf_of = origin
+
+ context = await self.state_handler.compute_event_context(event)
+ context = await self._check_event_auth(origin, event, context)
+ if context.rejected:
raise SynapseError(
- 403, "This event is not allowed in this context", Codes.FORBIDDEN
+ 403, f"{event.membership} event was rejected", Codes.FORBIDDEN
)
- await self._auth_and_persist_event(origin, event, context)
+ # for joins, we need to check the restrictions of restricted rooms
+ if event.membership == Membership.JOIN:
+ await self._check_join_restrictions(context, event)
+ # for knock events, we run the third-party event rules. It's not entirely clear
+ # why we don't do this for other sorts of membership events.
+ if event.membership == Membership.KNOCK:
+ event_allowed = await self.third_party_event_rules.check_event_allowed(
+ event, context
+ )
+ if not event_allowed:
+ logger.info("Sending of knock %s forbidden by third-party rules", event)
+ raise SynapseError(
+ 403, "This event is not allowed in this context", Codes.FORBIDDEN
+ )
+
+ # all looks good, we can persist the event.
+ await self._run_push_actions_and_persist_event(event, context)
return context
+ async def _check_join_restrictions(
+ self, context: EventContext, event: EventBase
+ ) -> None:
+ """Check that restrictions in restricted join rules are matched
+
+ Called when we receive a join event via send_join.
+
+ Raises an auth error if the restrictions are not matched.
+ """
+ prev_state_ids = await context.get_prev_state_ids()
+
+ # Check if the user is already in the room or invited to the room.
+ user_id = event.state_key
+ prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None)
+ prev_member_event = None
+ if prev_member_event_id:
+ prev_member_event = await self.store.get_event(prev_member_event_id)
+
+ # Check if the member should be allowed access via membership in a space.
+ await self._event_auth_handler.check_restricted_join_rules(
+ prev_state_ids,
+ event.room_version,
+ user_id,
+ prev_member_event,
+ )
+
async def get_state_for_pdu(self, room_id: str, event_id: str) -> List[EventBase]:
"""Returns the state at the event. i.e. not including said event."""
@@ -2160,7 +2123,7 @@ class FederationHandler(BaseHandler):
async def on_backfill_request(
self, origin: str, room_id: str, pdu_list: List[str], limit: int
) -> List[EventBase]:
- in_room = await self.auth.check_host_in_room(room_id, origin)
+ in_room = await self._event_auth_handler.check_host_in_room(room_id, origin)
if not in_room:
raise AuthError(403, "Host not in room.")
@@ -2195,7 +2158,9 @@ class FederationHandler(BaseHandler):
)
if event:
- in_room = await self.auth.check_host_in_room(event.room_id, origin)
+ in_room = await self._event_auth_handler.check_host_in_room(
+ event.room_id, origin
+ )
if not in_room:
raise AuthError(403, "Host not in room.")
@@ -2248,6 +2213,18 @@ class FederationHandler(BaseHandler):
backfilled=backfilled,
)
+ await self._run_push_actions_and_persist_event(event, context, backfilled)
+
+ async def _run_push_actions_and_persist_event(
+ self, event: EventBase, context: EventContext, backfilled: bool = False
+ ) -> None:
+ """Run the push actions for a received event, and persist it.
+
+ Args:
+ event: The event itself.
+ context: The event context.
+ backfilled: True if the event was backfilled.
+ """
try:
if (
not event.internal_metadata.is_outlier()
@@ -2536,7 +2513,7 @@ class FederationHandler(BaseHandler):
latest_events: List[str],
limit: int,
) -> List[EventBase]:
- in_room = await self.auth.check_host_in_room(room_id, origin)
+ in_room = await self._event_auth_handler.check_host_in_room(room_id, origin)
if not in_room:
raise AuthError(403, "Host not in room.")
@@ -2561,9 +2538,9 @@ class FederationHandler(BaseHandler):
origin: str,
event: EventBase,
context: EventContext,
- state: Optional[Iterable[EventBase]],
- auth_events: Optional[MutableStateMap[EventBase]],
- backfilled: bool,
+ state: Optional[Iterable[EventBase]] = None,
+ auth_events: Optional[MutableStateMap[EventBase]] = None,
+ backfilled: bool = False,
) -> EventContext:
"""
Checks whether an event should be rejected (for failing auth checks).
@@ -2599,7 +2576,7 @@ class FederationHandler(BaseHandler):
if not auth_events:
prev_state_ids = await context.get_prev_state_ids()
- auth_events_ids = self.auth.compute_auth_events(
+ auth_events_ids = self._event_auth_handler.compute_auth_events(
event, prev_state_ids, for_verification=True
)
auth_events_x = await self.store.get_events(auth_events_ids)
@@ -3028,7 +3005,7 @@ class FederationHandler(BaseHandler):
"state_key": target_user_id,
}
- if await self.auth.check_host_in_room(room_id, self.hs.hostname):
+ if await self._event_auth_handler.check_host_in_room(room_id, self.hs.hostname):
room_version = await self.store.get_room_version_id(room_id)
builder = self.event_builder_factory.new(room_version, event_dict)
@@ -3048,7 +3025,9 @@ class FederationHandler(BaseHandler):
event.internal_metadata.send_on_behalf_of = self.hs.hostname
try:
- await self.auth.check_from_context(room_version, event, context)
+ await self._event_auth_handler.check_from_context(
+ room_version, event, context
+ )
except AuthError as e:
logger.warning("Denying new third party invite %r because %s", event, e)
raise e
@@ -3091,7 +3070,9 @@ class FederationHandler(BaseHandler):
)
try:
- await self.auth.check_from_context(room_version, event, context)
+ await self._event_auth_handler.check_from_context(
+ room_version, event, context
+ )
except AuthError as e:
logger.warning("Denying third party invite %r because %s", event, e)
raise e
@@ -3179,7 +3160,7 @@ class FederationHandler(BaseHandler):
last_exception = None # type: Optional[Exception]
# for each public key in the 3pid invite event
- for public_key_object in self.hs.get_auth().get_public_keys(invite_event):
+ for public_key_object in event_auth.get_public_keys(invite_event):
try:
# for each sig on the third_party_invite block of the actual invite
for server, signature_block in signed["signatures"].items():
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 74a7fc7ea2..eead07d94e 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -386,6 +386,7 @@ class EventCreationHandler:
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
+ self._event_auth_handler = hs.get_event_auth_handler()
self.store = hs.get_datastore()
self.storage = hs.get_storage()
self.state = hs.get_state_handler()
@@ -510,6 +511,8 @@ class EventCreationHandler:
Should normally be left as None, which will cause them to be calculated
based on the room state at the prev_events.
+ If non-None, prev_event_ids must also be provided.
+
require_consent: Whether to check if the requester has
consented to the privacy policy.
@@ -582,6 +585,9 @@ class EventCreationHandler:
# Strip down the auth_event_ids to only what we need to auth the event.
# For example, we don't need extra m.room.member events that don't match event.sender
if auth_event_ids is not None:
+ # If auth events are provided, prev events must be provided as well.
+ assert prev_event_ids is not None
+
temp_event = await builder.build(
prev_event_ids=prev_event_ids,
auth_event_ids=auth_event_ids,
@@ -593,7 +599,7 @@ class EventCreationHandler:
(e.type, e.state_key): e.event_id for e in auth_events
}
# Actually strip down and use the necessary auth events
- auth_event_ids = self.auth.compute_auth_events(
+ auth_event_ids = self._event_auth_handler.compute_auth_events(
event=temp_event,
current_state_ids=auth_event_state_map,
for_verification=False,
@@ -785,6 +791,8 @@ class EventCreationHandler:
The event ids to use as the auth_events for the new event.
Should normally be left as None, which will cause them to be calculated
based on the room state at the prev_events.
+
+ If non-None, prev_event_ids must also be provided.
ratelimit: Whether to rate limit this send.
txn_id: The transaction ID.
ignore_shadow_ban: True if shadow-banned users should be allowed to
@@ -1050,7 +1058,9 @@ class EventCreationHandler:
assert event.content["membership"] == Membership.LEAVE
else:
try:
- await self.auth.check_from_context(room_version, event, context)
+ await self._event_auth_handler.check_from_context(
+ room_version, event, context
+ )
except AuthError as err:
logger.warning("Denying new event %r because %s", event, err)
raise err
@@ -1375,7 +1385,7 @@ class EventCreationHandler:
raise AuthError(403, "Redacting server ACL events is not permitted")
prev_state_ids = await context.get_prev_state_ids()
- auth_events_ids = self.auth.compute_auth_events(
+ auth_events_ids = self._event_auth_handler.compute_auth_events(
event, prev_state_ids, for_verification=True
)
auth_events_map = await self.store.get_events(auth_events_ids)
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 0880bd0496..c863504beb 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -15,9 +15,10 @@
"""Contains functions for registering clients."""
import logging
-from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
from prometheus_client import Counter
+from typing_extensions import TypedDict
from synapse import types
from synapse.api.constants import MAX_USERID_LENGTH, EventTypes, JoinRules, LoginType
@@ -54,6 +55,16 @@ login_counter = Counter(
["guest", "auth_provider"],
)
+LoginDict = TypedDict(
+ "LoginDict",
+ {
+ "device_id": str,
+ "access_token": str,
+ "valid_until_ms": Optional[int],
+ "refresh_token": Optional[str],
+ },
+)
+
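
For orientation, a value of this shape (contents illustrative; tokens elided):

# login: LoginDict = {
#     "device_id": "ABCDEFGHIJ",
#     "access_token": "syt_...",   # elided
#     "valid_until_ms": None,      # None means no expiry
#     "refresh_token": None,       # only set when a refresh token was issued
# }
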
class RegistrationHandler(BaseHandler):
def __init__(self, hs: "HomeServer"):
@@ -88,6 +99,7 @@ class RegistrationHandler(BaseHandler):
self.pusher_pool = hs.get_pusherpool()
self.session_lifetime = hs.config.session_lifetime
+ self.access_token_lifetime = hs.config.access_token_lifetime
async def check_username(
self,
@@ -418,11 +430,32 @@ class RegistrationHandler(BaseHandler):
room_alias = RoomAlias.from_string(r)
if self.hs.hostname != room_alias.domain:
- logger.warning(
- "Cannot create room alias %s, "
- "it does not match server domain",
+ # If the alias is remote, try to join the room. This might fail
+ # because the room might be invite only, but we don't have any local
+ # user in the room to invite this one with, so at this point that's
+ # the best we can do.
+ logger.info(
+ "Cannot automatically create room with alias %s as it isn't"
+ " local, trying to join the room instead",
r,
)
+
+ (
+ room,
+ remote_room_hosts,
+ ) = await room_member_handler.lookup_room_alias(room_alias)
+ room_id = room.to_string()
+
+ await room_member_handler.update_membership(
+ requester=create_requester(
+ user_id, authenticated_entity=self._server_name
+ ),
+ target=UserID.from_string(user_id),
+ room_id=room_id,
+ remote_room_hosts=remote_room_hosts,
+ action="join",
+ ratelimit=False,
+ )
else:
# A shallow copy is OK here since the only key that is
# modified is room_alias_name.
@@ -480,22 +513,32 @@ class RegistrationHandler(BaseHandler):
)
# Calculate whether the room requires an invite or can be
- # joined directly. Note that unless a join rule of public exists,
- # it is treated as requiring an invite.
- requires_invite = True
-
- state = await self.store.get_filtered_current_state_ids(
- room_id, StateFilter.from_types([(EventTypes.JoinRules, "")])
+ # joined directly. By default, we consider the room as requiring an
+ # invite if the homeserver is in the room (unless told otherwise by the
+ # join rules). Otherwise we consider it as being joinable, at the risk of
+ # failing to join, but in this case there's little more we can do since
+ # we don't have a local user in the room to craft an invite with.
+ requires_invite = await self.store.is_host_joined(
+ room_id,
+ self.server_name,
)
- event_id = state.get((EventTypes.JoinRules, ""))
- if event_id:
- join_rules_event = await self.store.get_event(
- event_id, allow_none=True
+ if requires_invite:
+ # If the server is in the room, check if the room is public.
+ state = await self.store.get_filtered_current_state_ids(
+ room_id, StateFilter.from_types([(EventTypes.JoinRules, "")])
)
- if join_rules_event:
- join_rule = join_rules_event.content.get("join_rule", None)
- requires_invite = join_rule and join_rule != JoinRules.PUBLIC
+
+ event_id = state.get((EventTypes.JoinRules, ""))
+ if event_id:
+ join_rules_event = await self.store.get_event(
+ event_id, allow_none=True
+ )
+ if join_rules_event:
+ join_rule = join_rules_event.content.get("join_rule", None)
+ requires_invite = (
+ join_rule and join_rule != JoinRules.PUBLIC
+ )
# Send the invite, if necessary.
if requires_invite:
@@ -745,7 +788,8 @@ class RegistrationHandler(BaseHandler):
is_guest: bool = False,
is_appservice_ghost: bool = False,
auth_provider_id: Optional[str] = None,
- ) -> Tuple[str, str]:
+ should_issue_refresh_token: bool = False,
+ ) -> Tuple[str, str, Optional[int], Optional[str]]:
"""Register a device for a user and generate an access token.
The access token will be limited by the homeserver's session_lifetime config.
@@ -757,8 +801,9 @@ class RegistrationHandler(BaseHandler):
is_guest: Whether this is a guest account
auth_provider_id: The SSO IdP the user used, if any (just used for the
prometheus metrics).
+ should_issue_refresh_token: Whether to also issue a refresh token
Returns:
- Tuple of device ID and access token
+ Tuple of device ID, access token, access token expiration time and refresh token
"""
res = await self._register_device_client(
user_id=user_id,
@@ -766,6 +811,7 @@ class RegistrationHandler(BaseHandler):
initial_display_name=initial_display_name,
is_guest=is_guest,
is_appservice_ghost=is_appservice_ghost,
+ should_issue_refresh_token=should_issue_refresh_token,
)
login_counter.labels(
@@ -773,7 +819,12 @@ class RegistrationHandler(BaseHandler):
auth_provider=(auth_provider_id or ""),
).inc()
- return res["device_id"], res["access_token"]
+ return (
+ res["device_id"],
+ res["access_token"],
+ res["valid_until_ms"],
+ res["refresh_token"],
+ )
async def register_device_inner(
self,
@@ -782,7 +833,8 @@ class RegistrationHandler(BaseHandler):
initial_display_name: Optional[str],
is_guest: bool = False,
is_appservice_ghost: bool = False,
- ) -> Dict[str, str]:
+ should_issue_refresh_token: bool = False,
+ ) -> LoginDict:
"""Helper for register_device
Does the bits that need doing on the main process. Not for use outside this
@@ -797,6 +849,9 @@ class RegistrationHandler(BaseHandler):
)
valid_until_ms = self.clock.time_msec() + self.session_lifetime
+ refresh_token = None
+ refresh_token_id = None
+
registered_device_id = await self.device_handler.check_device_registered(
user_id, device_id, initial_display_name
)
@@ -804,14 +859,30 @@ class RegistrationHandler(BaseHandler):
assert valid_until_ms is None
access_token = self.macaroon_gen.generate_guest_access_token(user_id)
else:
+ if should_issue_refresh_token:
+ (
+ refresh_token,
+ refresh_token_id,
+ ) = await self._auth_handler.get_refresh_token_for_user_id(
+ user_id,
+ device_id=registered_device_id,
+ )
+ valid_until_ms = self.clock.time_msec() + self.access_token_lifetime
+
access_token = await self._auth_handler.get_access_token_for_user_id(
user_id,
device_id=registered_device_id,
valid_until_ms=valid_until_ms,
is_appservice_ghost=is_appservice_ghost,
+ refresh_token_id=refresh_token_id,
)
- return {"device_id": registered_device_id, "access_token": access_token}
+ return {
+ "device_id": registered_device_id,
+ "access_token": access_token,
+ "valid_until_ms": valid_until_ms,
+ "refresh_token": refresh_token,
+ }
async def post_registration_actions(
self, user_id: str, auth_result: dict, access_token: Optional[str]
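
Since `register_device` now returns a 4-tuple instead of a pair, every call site has to unpack four values, with `valid_until_ms` and `refresh_token` being `None` when the token does not expire or no refresh token was requested. A sketch of the expected usage; the handler reference is illustrative:

```python
from typing import Optional, Tuple


async def login_device(
    registration_handler, user_id: str, device_id: Optional[str]
) -> Tuple[str, str, Optional[int], Optional[str]]:
    # valid_until_ms is None for non-expiring tokens; refresh_token is None
    # unless should_issue_refresh_token was set.
    device_id, access_token, valid_until_ms, refresh_token = (
        await registration_handler.register_device(
            user_id,
            device_id,
            initial_display_name=None,
            should_issue_refresh_token=True,
        )
    )
    return device_id, access_token, valid_until_ms, refresh_token
```
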
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 58739e5016..0cb6855767 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -83,6 +83,7 @@ class RoomCreationHandler(BaseHandler):
self.spam_checker = hs.get_spam_checker()
self.event_creation_handler = hs.get_event_creation_handler()
self.room_member_handler = hs.get_room_member_handler()
+ self._event_auth_handler = hs.get_event_auth_handler()
self.config = hs.config
# Room state based off defined presets
@@ -226,7 +227,7 @@ class RoomCreationHandler(BaseHandler):
},
)
old_room_version = await self.store.get_room_version_id(old_room_id)
- await self.auth.check_from_context(
+ await self._event_auth_handler.check_from_context(
old_room_version, tombstone_event, tombstone_context
)
diff --git a/synapse/handlers/space_summary.py b/synapse/handlers/space_summary.py
index 17fc47ce16..b585057ec3 100644
--- a/synapse/handlers/space_summary.py
+++ b/synapse/handlers/space_summary.py
@@ -25,6 +25,7 @@ from synapse.api.constants import (
EventTypes,
HistoryVisibility,
Membership,
+ RoomTypes,
)
from synapse.events import EventBase
from synapse.events.utils import format_event_for_client_v2
@@ -318,7 +319,8 @@ class SpaceSummaryHandler:
Returns:
A tuple of:
- An iterable of a single value of the room.
+ The room information, if the room should be returned to the
+ user; None otherwise.
An iterable of the sorted children events. This may be limited
to a maximum size or may include all children.
@@ -328,7 +330,11 @@ class SpaceSummaryHandler:
room_entry = await self._build_room_entry(room_id)
- # look for child rooms/spaces.
+ # If the room is not a space, return just the room information.
+ if room_entry.get("room_type") != RoomTypes.SPACE:
+ return room_entry, ()
+
+ # Otherwise, look for child rooms/spaces.
child_events = await self._get_child_events(room_id)
if suggested_only:
@@ -348,6 +354,7 @@ class SpaceSummaryHandler:
event_format=format_event_for_client_v2,
)
)
+
return room_entry, events_result
async def _summarize_remote_room(
@@ -465,7 +472,7 @@ class SpaceSummaryHandler:
# If this is a request over federation, check if the host is in the room or
# is in one of the spaces specified via the join rules.
elif origin:
- if await self._auth.check_host_in_room(room_id, origin):
+ if await self._event_auth_handler.check_host_in_room(room_id, origin):
return True
# Alternately, if the host has a user in any of the spaces specified
@@ -478,7 +485,9 @@ class SpaceSummaryHandler:
await self._event_auth_handler.get_rooms_that_allow_join(state_ids)
)
for space_id in allowed_rooms:
- if await self._auth.check_host_in_room(space_id, origin):
+ if await self._event_auth_handler.check_host_in_room(
+ space_id, origin
+ ):
return True
# otherwise, check if the room is peekable
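
A minimal sketch of the gating added above: only entries whose `room_type` is `m.space` (the value of `RoomTypes.SPACE`) get their children expanded, while ordinary rooms come back as leaves with an empty child list. The function below is illustrative, not the handler's actual signature:

```python
from typing import Iterable, Tuple

M_SPACE = "m.space"  # the value of RoomTypes.SPACE


def gate_children(room_entry: dict, child_events: Iterable[dict]) -> Tuple[dict, tuple]:
    # Non-space rooms are returned on their own, with no children to walk.
    if room_entry.get("room_type") != M_SPACE:
        return room_entry, ()
    return room_entry, tuple(child_events)
```
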
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 845651e606..efbc6d5b25 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -728,7 +728,7 @@ def set_cors_headers(request: Request):
)
request.setHeader(
b"Access-Control-Allow-Headers",
- b"Origin, X-Requested-With, Content-Type, Accept, Authorization, Date",
+ b"X-Requested-With, Content-Type, Authorization, Date",
)
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index 07eb4f439b..bf706b3a8f 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -113,8 +113,18 @@ def parse_boolean_from_args(args, name, default=None, required=False):
def parse_bytes_from_args(
args: Dict[bytes, List[bytes]],
name: str,
+ default: Optional[bytes] = None,
+) -> Optional[bytes]:
+ ...
+
+
+@overload
+def parse_bytes_from_args(
+ args: Dict[bytes, List[bytes]],
+ name: str,
default: Literal[None] = None,
- required: Literal[True] = True,
+ *,
+ required: Literal[True],
) -> bytes:
...
@@ -197,7 +207,12 @@ def parse_string(
"""
args = request.args # type: Dict[bytes, List[bytes]] # type: ignore
return parse_string_from_args(
- args, name, default, required, allowed_values, encoding
+ args,
+ name,
+ default,
+ required=required,
+ allowed_values=allowed_values,
+ encoding=encoding,
)
@@ -227,7 +242,20 @@ def parse_strings_from_args(
args: Dict[bytes, List[bytes]],
name: str,
default: Optional[List[str]] = None,
- required: Literal[True] = True,
+ *,
+ allowed_values: Optional[Iterable[str]] = None,
+ encoding: str = "ascii",
+) -> Optional[List[str]]:
+ ...
+
+
+@overload
+def parse_strings_from_args(
+ args: Dict[bytes, List[bytes]],
+ name: str,
+ default: Optional[List[str]] = None,
+ *,
+ required: Literal[True],
allowed_values: Optional[Iterable[str]] = None,
encoding: str = "ascii",
) -> List[str]:
@@ -239,6 +267,7 @@ def parse_strings_from_args(
args: Dict[bytes, List[bytes]],
name: str,
default: Optional[List[str]] = None,
+ *,
required: bool = False,
allowed_values: Optional[Iterable[str]] = None,
encoding: str = "ascii",
@@ -299,7 +328,20 @@ def parse_string_from_args(
args: Dict[bytes, List[bytes]],
name: str,
default: Optional[str] = None,
- required: Literal[True] = True,
+ *,
+ allowed_values: Optional[Iterable[str]] = None,
+ encoding: str = "ascii",
+) -> Optional[str]:
+ ...
+
+
+@overload
+def parse_string_from_args(
+ args: Dict[bytes, List[bytes]],
+ name: str,
+ default: Optional[str] = None,
+ *,
+ required: Literal[True],
allowed_values: Optional[Iterable[str]] = None,
encoding: str = "ascii",
) -> str:
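
The reworked overloads make `required` keyword-only, so a type checker can select the right return type from the call site: `required=True` picks the overload returning `str`, everything else falls through to the `Optional` one. A self-contained sketch of the same pattern (the function and its behaviour are illustrative):

```python
from typing import Dict, List, Optional, overload

from typing_extensions import Literal


@overload
def get_arg(
    args: Dict[bytes, List[bytes]], name: str, default: Optional[str] = None
) -> Optional[str]:
    ...


@overload
def get_arg(
    args: Dict[bytes, List[bytes]],
    name: str,
    default: Optional[str] = None,
    *,
    required: Literal[True],
) -> str:
    ...


def get_arg(args, name, default=None, *, required=False):
    values = args.get(name.encode("ascii"))
    if values:
        return values[0].decode("ascii")
    if required:
        raise KeyError("Missing required query parameter: %s" % (name,))
    return default


args: Dict[bytes, List[bytes]] = {b"foo": [b"bar"]}
s: str = get_arg(args, "foo", required=True)  # mypy narrows this to str
t: Optional[str] = get_arg(args, "missing")  # may be None, as the type says
```
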
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 58b255eb1b..721c45abac 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -168,7 +168,7 @@ class ModuleApi:
"Using deprecated ModuleApi.register which creates a dummy user device."
)
user_id = yield self.register_user(localpart, displayname, emails or [])
- _, access_token = yield self.register_device(user_id)
+ _, access_token, _, _ = yield self.register_device(user_id)
return user_id, access_token
def register_user(
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 350646f458..669ea462e2 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -104,7 +104,7 @@ class BulkPushRuleEvaluator:
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.store = hs.get_datastore()
- self.auth = hs.get_auth()
+ self._event_auth_handler = hs.get_event_auth_handler()
# Used by `RulesForRoom` to ensure only one thing mutates the cache at a
# time. Keyed off room_id.
@@ -172,7 +172,7 @@ class BulkPushRuleEvaluator:
# not having a power level event is an extreme edge case
auth_events = {POWER_KEY: await self.store.get_event(pl_event_id)}
else:
- auth_events_ids = self.auth.compute_auth_events(
+ auth_events_ids = self._event_auth_handler.compute_auth_events(
event, prev_state_ids, for_verification=False
)
auth_events_dict = await self.store.get_events(auth_events_ids)
diff --git a/synapse/replication/http/login.py b/synapse/replication/http/login.py
index c2e8c00293..550bd5c95f 100644
--- a/synapse/replication/http/login.py
+++ b/synapse/replication/http/login.py
@@ -36,20 +36,29 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
@staticmethod
async def _serialize_payload(
- user_id, device_id, initial_display_name, is_guest, is_appservice_ghost
+ user_id,
+ device_id,
+ initial_display_name,
+ is_guest,
+ is_appservice_ghost,
+ should_issue_refresh_token,
):
"""
Args:
+ user_id (str)
device_id (str|None): Device ID to use, if None a new one is
generated.
initial_display_name (str|None)
is_guest (bool)
+ is_appservice_ghost (bool)
+ should_issue_refresh_token (bool)
"""
return {
"device_id": device_id,
"initial_display_name": initial_display_name,
"is_guest": is_guest,
"is_appservice_ghost": is_appservice_ghost,
+ "should_issue_refresh_token": should_issue_refresh_token,
}
async def _handle_request(self, request, user_id):
@@ -59,6 +68,7 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
initial_display_name = content["initial_display_name"]
is_guest = content["is_guest"]
is_appservice_ghost = content["is_appservice_ghost"]
+ should_issue_refresh_token = content["should_issue_refresh_token"]
res = await self.registration_handler.register_device_inner(
user_id,
@@ -66,6 +76,7 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
initial_display_name,
is_guest,
is_appservice_ghost=is_appservice_ghost,
+ should_issue_refresh_token=should_issue_refresh_token,
)
return 200, res
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index f6be5f1020..cbcb60fe31 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -14,7 +14,9 @@
import logging
import re
-from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional
+
+from typing_extensions import TypedDict
from synapse.api.errors import Codes, LoginError, SynapseError
from synapse.api.ratelimiting import Ratelimiter
@@ -25,6 +27,8 @@ from synapse.http import get_request_uri
from synapse.http.server import HttpServer, finish_request
from synapse.http.servlet import (
RestServlet,
+ assert_params_in_dict,
+ parse_boolean,
parse_bytes_from_args,
parse_json_object_from_request,
parse_string,
@@ -40,6 +44,21 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+LoginResponse = TypedDict(
+ "LoginResponse",
+ {
+ "user_id": str,
+ "access_token": str,
+ "home_server": str,
+ "expires_in_ms": Optional[int],
+ "refresh_token": Optional[str],
+ "device_id": str,
+ "well_known": Optional[Dict[str, Any]],
+ },
+ total=False,
+)
+
+
class LoginRestServlet(RestServlet):
PATTERNS = client_patterns("/login$", v1=True)
CAS_TYPE = "m.login.cas"
@@ -48,6 +67,7 @@ class LoginRestServlet(RestServlet):
JWT_TYPE = "org.matrix.login.jwt"
JWT_TYPE_DEPRECATED = "m.login.jwt"
APPSERVICE_TYPE = "uk.half-shot.msc2778.login.application_service"
+ REFRESH_TOKEN_PARAM = "org.matrix.msc2918.refresh_token"
def __init__(self, hs: "HomeServer"):
super().__init__()
@@ -65,9 +85,12 @@ class LoginRestServlet(RestServlet):
self.cas_enabled = hs.config.cas_enabled
self.oidc_enabled = hs.config.oidc_enabled
self._msc2858_enabled = hs.config.experimental.msc2858_enabled
+ self._msc2918_enabled = hs.config.access_token_lifetime is not None
self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+
self.auth_handler = self.hs.get_auth_handler()
self.registration_handler = hs.get_registration_handler()
self._sso_handler = hs.get_sso_handler()
@@ -138,6 +161,15 @@ class LoginRestServlet(RestServlet):
async def on_POST(self, request: SynapseRequest):
login_submission = parse_json_object_from_request(request)
+ if self._msc2918_enabled:
+ # Check if this login should also issue a refresh token, as per
+ # MSC2918
+ should_issue_refresh_token = parse_boolean(
+ request, name=LoginRestServlet.REFRESH_TOKEN_PARAM, default=False
+ )
+ else:
+ should_issue_refresh_token = False
+
try:
if login_submission["type"] == LoginRestServlet.APPSERVICE_TYPE:
appservice = self.auth.get_appservice_by_req(request)
@@ -147,19 +179,32 @@ class LoginRestServlet(RestServlet):
None, request.getClientIP()
)
- result = await self._do_appservice_login(login_submission, appservice)
+ result = await self._do_appservice_login(
+ login_submission,
+ appservice,
+ should_issue_refresh_token=should_issue_refresh_token,
+ )
elif self.jwt_enabled and (
login_submission["type"] == LoginRestServlet.JWT_TYPE
or login_submission["type"] == LoginRestServlet.JWT_TYPE_DEPRECATED
):
await self._address_ratelimiter.ratelimit(None, request.getClientIP())
- result = await self._do_jwt_login(login_submission)
+ result = await self._do_jwt_login(
+ login_submission,
+ should_issue_refresh_token=should_issue_refresh_token,
+ )
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
await self._address_ratelimiter.ratelimit(None, request.getClientIP())
- result = await self._do_token_login(login_submission)
+ result = await self._do_token_login(
+ login_submission,
+ should_issue_refresh_token=should_issue_refresh_token,
+ )
else:
await self._address_ratelimiter.ratelimit(None, request.getClientIP())
- result = await self._do_other_login(login_submission)
+ result = await self._do_other_login(
+ login_submission,
+ should_issue_refresh_token=should_issue_refresh_token,
+ )
except KeyError:
raise SynapseError(400, "Missing JSON keys.")
@@ -169,7 +214,10 @@ class LoginRestServlet(RestServlet):
return 200, result
async def _do_appservice_login(
- self, login_submission: JsonDict, appservice: ApplicationService
+ self,
+ login_submission: JsonDict,
+ appservice: ApplicationService,
+ should_issue_refresh_token: bool = False,
):
identifier = login_submission.get("identifier")
logger.info("Got appservice login request with identifier: %r", identifier)
@@ -198,14 +246,21 @@ class LoginRestServlet(RestServlet):
raise LoginError(403, "Invalid access_token", errcode=Codes.FORBIDDEN)
return await self._complete_login(
- qualified_user_id, login_submission, ratelimit=appservice.is_rate_limited()
+ qualified_user_id,
+ login_submission,
+ ratelimit=appservice.is_rate_limited(),
+ should_issue_refresh_token=should_issue_refresh_token,
)
- async def _do_other_login(self, login_submission: JsonDict) -> Dict[str, str]:
+ async def _do_other_login(
+ self, login_submission: JsonDict, should_issue_refresh_token: bool = False
+ ) -> LoginResponse:
"""Handle non-token/saml/jwt logins
Args:
login_submission:
+ should_issue_refresh_token: True if this login should issue
+ a refresh token alongside the access token.
Returns:
HTTP response
@@ -224,7 +279,10 @@ class LoginRestServlet(RestServlet):
login_submission, ratelimit=True
)
result = await self._complete_login(
- canonical_user_id, login_submission, callback
+ canonical_user_id,
+ login_submission,
+ callback,
+ should_issue_refresh_token=should_issue_refresh_token,
)
return result
@@ -232,11 +290,12 @@ class LoginRestServlet(RestServlet):
self,
user_id: str,
login_submission: JsonDict,
- callback: Optional[Callable[[Dict[str, str]], Awaitable[None]]] = None,
+ callback: Optional[Callable[[LoginResponse], Awaitable[None]]] = None,
create_non_existent_users: bool = False,
ratelimit: bool = True,
auth_provider_id: Optional[str] = None,
- ) -> Dict[str, str]:
+ should_issue_refresh_token: bool = False,
+ ) -> LoginResponse:
"""Called when we've successfully authed the user and now need to
actually login them in (e.g. create devices). This gets called on
all successful logins.
@@ -253,6 +312,8 @@ class LoginRestServlet(RestServlet):
ratelimit: Whether to ratelimit the login request.
auth_provider_id: The SSO IdP the user used, if any (just used for the
prometheus metrics).
+ should_issue_refresh_token: True if this login should issue
+ a refresh token alongside the access token.
Returns:
result: Dictionary of account information after successful login.
@@ -274,28 +335,48 @@ class LoginRestServlet(RestServlet):
device_id = login_submission.get("device_id")
initial_display_name = login_submission.get("initial_device_display_name")
- device_id, access_token = await self.registration_handler.register_device(
- user_id, device_id, initial_display_name, auth_provider_id=auth_provider_id
+ (
+ device_id,
+ access_token,
+ valid_until_ms,
+ refresh_token,
+ ) = await self.registration_handler.register_device(
+ user_id,
+ device_id,
+ initial_display_name,
+ auth_provider_id=auth_provider_id,
+ should_issue_refresh_token=should_issue_refresh_token,
)
- result = {
- "user_id": user_id,
- "access_token": access_token,
- "home_server": self.hs.hostname,
- "device_id": device_id,
- }
+ result = LoginResponse(
+ user_id=user_id,
+ access_token=access_token,
+ home_server=self.hs.hostname,
+ device_id=device_id,
+ )
+
+ if valid_until_ms is not None:
+ expires_in_ms = valid_until_ms - self.clock.time_msec()
+ result["expires_in_ms"] = expires_in_ms
+
+ if refresh_token is not None:
+ result["refresh_token"] = refresh_token
if callback is not None:
await callback(result)
return result
- async def _do_token_login(self, login_submission: JsonDict) -> Dict[str, str]:
+ async def _do_token_login(
+ self, login_submission: JsonDict, should_issue_refresh_token: bool = False
+ ) -> LoginResponse:
"""
Handle the final stage of SSO login.
Args:
- login_submission: The JSON request body.
+ login_submission: The JSON request body.
+ should_issue_refresh_token: True if this login should issue
+ a refresh token alongside the access token.
Returns:
The body of the JSON response.
@@ -309,9 +390,12 @@ class LoginRestServlet(RestServlet):
login_submission,
self.auth_handler._sso_login_callback,
auth_provider_id=res.auth_provider_id,
+ should_issue_refresh_token=should_issue_refresh_token,
)
- async def _do_jwt_login(self, login_submission: JsonDict) -> Dict[str, str]:
+ async def _do_jwt_login(
+ self, login_submission: JsonDict, should_issue_refresh_token: bool = False
+ ) -> LoginResponse:
token = login_submission.get("token", None)
if token is None:
raise LoginError(
@@ -342,7 +426,10 @@ class LoginRestServlet(RestServlet):
user_id = UserID(user, self.hs.hostname).to_string()
result = await self._complete_login(
- user_id, login_submission, create_non_existent_users=True
+ user_id,
+ login_submission,
+ create_non_existent_users=True,
+ should_issue_refresh_token=should_issue_refresh_token,
)
return result
@@ -371,6 +458,42 @@ def _get_auth_flow_dict_for_idp(
return e
+class RefreshTokenServlet(RestServlet):
+ PATTERNS = client_patterns(
+ "/org.matrix.msc2918.refresh_token/refresh$", releases=(), unstable=True
+ )
+
+ def __init__(self, hs: "HomeServer"):
+ self._auth_handler = hs.get_auth_handler()
+ self._clock = hs.get_clock()
+ self.access_token_lifetime = hs.config.access_token_lifetime
+
+ async def on_POST(
+ self,
+ request: SynapseRequest,
+ ):
+ refresh_submission = parse_json_object_from_request(request)
+
+ assert_params_in_dict(refresh_submission, ["refresh_token"])
+ token = refresh_submission["refresh_token"]
+ if not isinstance(token, str):
+ raise SynapseError(400, "Invalid param: refresh_token", Codes.INVALID_PARAM)
+
+ valid_until_ms = self._clock.time_msec() + self.access_token_lifetime
+ access_token, refresh_token = await self._auth_handler.refresh_token(
+ token, valid_until_ms
+ )
+ expires_in_ms = valid_until_ms - self._clock.time_msec()
+ return (
+ 200,
+ {
+ "access_token": access_token,
+ "refresh_token": refresh_token,
+ "expires_in_ms": expires_in_ms,
+ },
+ )
+
+
class SsoRedirectServlet(RestServlet):
PATTERNS = list(client_patterns("/login/(cas|sso)/redirect$", v1=True)) + [
re.compile(
@@ -477,6 +600,8 @@ class CasTicketServlet(RestServlet):
def register_servlets(hs, http_server):
LoginRestServlet(hs).register(http_server)
+ if hs.config.access_token_lifetime is not None:
+ RefreshTokenServlet(hs).register(http_server)
SsoRedirectServlet(hs).register(http_server)
if hs.config.cas_enabled:
CasTicketServlet(hs).register(http_server)
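
Client-side, the flow enabled here (MSC2918, under its unstable prefix) is: ask for a refresh token at login via the query parameter that `parse_boolean` reads above, then trade it in at the new refresh endpoint before the access token's `expires_in_ms` runs out. A hedged sketch using `requests`; the homeserver URL and credentials are placeholders:

```python
import requests

HS = "https://homeserver.example"  # placeholder

# 1. Log in, asking for a refresh token alongside the access token.
resp = requests.post(
    f"{HS}/_matrix/client/r0/login?org.matrix.msc2918.refresh_token=true",
    json={
        "type": "m.login.password",
        "identifier": {"type": "m.id.user", "user": "alice"},
        "password": "s3cret",
    },
).json()
access_token = resp["access_token"]
refresh_token = resp.get("refresh_token")  # present only if MSC2918 is enabled
expires_in_ms = resp.get("expires_in_ms")

# 2. Later, rotate both tokens. The old refresh token stops being usable once
#    the new access token has been used.
if refresh_token is not None:
    refreshed = requests.post(
        f"{HS}/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh",
        json={"refresh_token": refresh_token},
    ).json()
    access_token = refreshed["access_token"]
    refresh_token = refreshed["refresh_token"]
```
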
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index 0be5fbb7f7..f2c155dfae 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -42,11 +42,13 @@ from synapse.http.server import finish_request, respond_with_html
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
+ parse_boolean,
parse_json_object_from_request,
parse_string,
)
from synapse.metrics import threepid_send_requests
from synapse.push.mailer import Mailer
+from synapse.types import JsonDict
from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.ratelimitutils import FederationRateLimiter
from synapse.util.stringutils import assert_valid_client_secret, random_string
@@ -396,6 +398,7 @@ class RegisterRestServlet(RestServlet):
self.password_policy_handler = hs.get_password_policy_handler()
self.clock = hs.get_clock()
self._registration_enabled = self.hs.config.enable_registration
+ self._msc2918_enabled = hs.config.access_token_lifetime is not None
self._registration_flows = _calculate_registration_flows(
hs.config, self.auth_handler
@@ -421,6 +424,15 @@ class RegisterRestServlet(RestServlet):
"Do not understand membership kind: %s" % (kind.decode("utf8"),)
)
+ if self._msc2918_enabled:
+ # Check if this registration should also issue a refresh token, as
+ # per MSC2918
+ should_issue_refresh_token = parse_boolean(
+ request, name="org.matrix.msc2918.refresh_token", default=False
+ )
+ else:
+ should_issue_refresh_token = False
+
# We don't care about usernames for this deployment. In fact, the act
# of checking whether they exist already can leak metadata about
# which users are already registered.
@@ -472,7 +484,12 @@ class RegisterRestServlet(RestServlet):
raise SynapseError(400, "Desired Username is missing or not a string")
result = await self._do_appservice_registration(
- desired_username, password, desired_display_name, access_token, body
+ desired_username,
+ password,
+ desired_display_name,
+ access_token,
+ body,
+ should_issue_refresh_token=should_issue_refresh_token,
)
return 200, result
@@ -744,7 +761,9 @@ class RegisterRestServlet(RestServlet):
registered = True
return_dict = await self._create_registration_details(
- registered_user_id, params
+ registered_user_id,
+ params,
+ should_issue_refresh_token=should_issue_refresh_token,
)
if registered:
@@ -757,17 +776,21 @@ class RegisterRestServlet(RestServlet):
return 200, return_dict
async def _do_appservice_registration(
- self, username, password, display_name, as_token, body
+ self,
+ username,
+ password,
+ display_name,
+ as_token,
+ body,
+ should_issue_refresh_token: bool = False,
):
- # FIXME: appservice_register() is horribly duplicated with register()
- # and they should probably just be combined together with a config flag.
-
if password:
# Hash the password
#
# In mainline hashing of the password was moved further on in the registration
# flow, but we need it here for the AS use-case of shadow servers
password = await self.auth_handler.hash(password)
+
user_id = await self.registration_handler.appservice_register(
username, as_token, password, display_name
)
@@ -775,6 +798,7 @@ class RegisterRestServlet(RestServlet):
user_id,
body,
is_appservice_ghost=True,
+ should_issue_refresh_token=should_issue_refresh_token,
)
auth_result = body.get("auth_result")
@@ -793,16 +817,23 @@ class RegisterRestServlet(RestServlet):
return result
async def _create_registration_details(
- self, user_id, params, is_appservice_ghost=False
+ self,
+ user_id: str,
+ params: JsonDict,
+ is_appservice_ghost: bool = False,
+ should_issue_refresh_token: bool = False,
):
"""Complete registration of newly-registered user
Allocates device_id if one was not given; also creates access_token.
Args:
- (str) user_id: full canonical @user:id
- (object) params: registration parameters, from which we pull
- device_id, initial_device_name and inhibit_login
+ user_id: full canonical @user:id
+ params: registration parameters, from which we pull device_id,
+ initial_device_name and inhibit_login
+ is_appservice_ghost: True if the user was registered via an application service.
+ should_issue_refresh_token: True if this registration should issue
+ a refresh token alongside the access token.
Returns:
dictionary for response from /register
"""
@@ -810,15 +841,29 @@ class RegisterRestServlet(RestServlet):
if not params.get("inhibit_login", False):
device_id = params.get("device_id")
initial_display_name = params.get("initial_device_display_name")
- device_id, access_token = await self.registration_handler.register_device(
+ (
+ device_id,
+ access_token,
+ valid_until_ms,
+ refresh_token,
+ ) = await self.registration_handler.register_device(
user_id,
device_id,
initial_display_name,
is_guest=False,
is_appservice_ghost=is_appservice_ghost,
+ should_issue_refresh_token=should_issue_refresh_token,
)
result.update({"access_token": access_token, "device_id": device_id})
+
+ if valid_until_ms is not None:
+ expires_in_ms = valid_until_ms - self.clock.time_msec()
+ result["expires_in_ms"] = expires_in_ms
+
+ if refresh_token is not None:
+ result["refresh_token"] = refresh_token
+
return result
async def _do_guest_registration(self, params, address=None):
@@ -832,19 +877,30 @@ class RegisterRestServlet(RestServlet):
# we have nowhere to store it.
device_id = synapse.api.auth.GUEST_DEVICE_ID
initial_display_name = params.get("initial_device_display_name")
- device_id, access_token = await self.registration_handler.register_device(
+ (
+ device_id,
+ access_token,
+ valid_until_ms,
+ refresh_token,
+ ) = await self.registration_handler.register_device(
user_id, device_id, initial_display_name, is_guest=True
)
- return (
- 200,
- {
- "user_id": user_id,
- "device_id": device_id,
- "access_token": access_token,
- "home_server": self.hs.hostname,
- },
- )
+ result = {
+ "user_id": user_id,
+ "device_id": device_id,
+ "access_token": access_token,
+ "home_server": self.hs.hostname,
+ }
+
+ if valid_until_ms is not None:
+ expires_in_ms = valid_until_ms - self.clock.time_msec()
+ result["expires_in_ms"] = expires_in_ms
+
+ if refresh_token is not None:
+ result["refresh_token"] = refresh_token
+
+ return 200, result
def cap(name):
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index 042e1788b6..ecbbcf3851 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -13,6 +13,7 @@
# limitations under the License.
import itertools
import logging
+from collections import defaultdict
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple
from synapse.api.constants import Membership, PresenceState
@@ -232,29 +233,51 @@ class SyncRestServlet(RestServlet):
)
logger.debug("building sync response dict")
- return {
- "account_data": {"events": sync_result.account_data},
- "to_device": {"events": sync_result.to_device},
- "device_lists": {
- "changed": list(sync_result.device_lists.changed),
- "left": list(sync_result.device_lists.left),
- },
- "presence": SyncRestServlet.encode_presence(sync_result.presence, time_now),
- "rooms": {
- Membership.JOIN: joined,
- Membership.INVITE: invited,
- Membership.KNOCK: knocked,
- Membership.LEAVE: archived,
- },
- "groups": {
- Membership.JOIN: sync_result.groups.join,
- Membership.INVITE: sync_result.groups.invite,
- Membership.LEAVE: sync_result.groups.leave,
- },
- "device_one_time_keys_count": sync_result.device_one_time_keys_count,
- "org.matrix.msc2732.device_unused_fallback_key_types": sync_result.device_unused_fallback_key_types,
- "next_batch": await sync_result.next_batch.to_string(self.store),
- }
+
+ response: dict = defaultdict(dict)
+ response["next_batch"] = await sync_result.next_batch.to_string(self.store)
+
+ if sync_result.account_data:
+ response["account_data"] = {"events": sync_result.account_data}
+ if sync_result.presence:
+ response["presence"] = SyncRestServlet.encode_presence(
+ sync_result.presence, time_now
+ )
+
+ if sync_result.to_device:
+ response["to_device"] = {"events": sync_result.to_device}
+
+ if sync_result.device_lists.changed:
+ response["device_lists"]["changed"] = list(sync_result.device_lists.changed)
+ if sync_result.device_lists.left:
+ response["device_lists"]["left"] = list(sync_result.device_lists.left)
+
+ if sync_result.device_one_time_keys_count:
+ response[
+ "device_one_time_keys_count"
+ ] = sync_result.device_one_time_keys_count
+ if sync_result.device_unused_fallback_key_types:
+ response[
+ "org.matrix.msc2732.device_unused_fallback_key_types"
+ ] = sync_result.device_unused_fallback_key_types
+
+ if joined:
+ response["rooms"][Membership.JOIN] = joined
+ if invited:
+ response["rooms"][Membership.INVITE] = invited
+ if knocked:
+ response["rooms"][Membership.KNOCK] = knocked
+ if archived:
+ response["rooms"][Membership.LEAVE] = archived
+
+ if sync_result.groups.join:
+ response["groups"][Membership.JOIN] = sync_result.groups.join
+ if sync_result.groups.invite:
+ response["groups"][Membership.INVITE] = sync_result.groups.invite
+ if sync_result.groups.leave:
+ response["groups"][Membership.LEAVE] = sync_result.groups.leave
+
+ return response
@staticmethod
def encode_presence(events, time_now):
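
The rewritten response builder leans on `defaultdict(dict)`: nested sections such as `rooms` and `device_lists` spring into existence only when a sub-key is actually written, which is what keeps empty sections out of the serialised response. A minimal sketch of that behaviour (values are illustrative):

```python
from collections import defaultdict

response: dict = defaultdict(dict)
response["next_batch"] = "s72594_4483_1934"  # illustrative sync token

joined = {"!abc:example.org": {"timeline": {"events": []}}}
invited: dict = {}  # nothing to report this sync

if joined:
    response["rooms"]["join"] = joined  # "rooms" is created on first write
if invited:
    response["rooms"]["invite"] = invited  # skipped: key never appears

assert "invite" not in response["rooms"]
# Membership tests don't trigger the default factory, so untouched sections
# stay absent and are never serialised.
assert "device_lists" not in response
```
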
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index d470cdacde..33c42cf95a 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -111,7 +111,7 @@ def make_conn(
db_config: DatabaseConnectionConfig,
engine: BaseDatabaseEngine,
default_txn_name: str,
-) -> Connection:
+) -> "LoggingDatabaseConnection":
"""Make a new connection to the database and return it.
Returns:
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index f23f8c6ecf..c4474df975 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -16,6 +16,8 @@ import logging
from queue import Empty, PriorityQueue
from typing import Collection, Dict, Iterable, List, Optional, Set, Tuple
+from prometheus_client import Gauge
+
from synapse.api.constants import MAX_DEPTH
from synapse.api.errors import StoreError
from synapse.api.room_versions import RoomVersion
@@ -32,6 +34,16 @@ from synapse.util.caches.descriptors import cached
from synapse.util.caches.lrucache import LruCache
from synapse.util.iterutils import batch_iter
+oldest_pdu_in_federation_staging = Gauge(
+ "synapse_federation_server_oldest_inbound_pdu_in_staging",
+ "The age in seconds since we received the oldest pdu in the federation staging area",
+)
+
+number_pdus_in_federation_queue = Gauge(
+ "synapse_federation_server_number_inbound_pdu_in_staging",
+ "The total number of events in the inbound federation staging",
+)
+
logger = logging.getLogger(__name__)
@@ -54,6 +66,8 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
500000, "_event_auth_cache", size_callback=len
) # type: LruCache[str, List[Tuple[str, int]]]
+ self._clock.looping_call(self._get_stats_for_federation_staging, 30 * 1000)
+
async def get_auth_chain(
self, room_id: str, event_ids: Collection[str], include_given: bool = False
) -> List[EventBase]:
@@ -1075,16 +1089,62 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
self,
origin: str,
event_id: str,
- ) -> None:
- """Remove the given event from the staging area"""
- await self.db_pool.simple_delete(
- table="federation_inbound_events_staging",
- keyvalues={
- "origin": origin,
- "event_id": event_id,
- },
- desc="remove_received_event_from_staging",
- )
+ ) -> Optional[int]:
+ """Remove the given event from the staging area.
+
+ Returns:
+ The received_ts of the row that was deleted, if any.
+ """
+ if self.db_pool.engine.supports_returning:
+
+ def _remove_received_event_from_staging_txn(txn):
+ sql = """
+ DELETE FROM federation_inbound_events_staging
+ WHERE origin = ? AND event_id = ?
+ RETURNING received_ts
+ """
+
+ txn.execute(sql, (origin, event_id))
+ return txn.fetchone()
+
+ row = await self.db_pool.runInteraction(
+ "remove_received_event_from_staging",
+ _remove_received_event_from_staging_txn,
+ db_autocommit=True,
+ )
+ if row is None:
+ return None
+
+ return row[0]
+
+ else:
+
+ def _remove_received_event_from_staging_txn(txn):
+ received_ts = self.db_pool.simple_select_one_onecol_txn(
+ txn,
+ table="federation_inbound_events_staging",
+ keyvalues={
+ "origin": origin,
+ "event_id": event_id,
+ },
+ retcol="received_ts",
+ allow_none=True,
+ )
+ self.db_pool.simple_delete_txn(
+ txn,
+ table="federation_inbound_events_staging",
+ keyvalues={
+ "origin": origin,
+ "event_id": event_id,
+ },
+ )
+
+ return received_ts
+
+ return await self.db_pool.runInteraction(
+ "remove_received_event_from_staging",
+ _remove_received_event_from_staging_txn,
+ )
async def get_next_staged_event_id_for_room(
self,
@@ -1147,6 +1207,40 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
return origin, event
+ async def get_all_rooms_with_staged_incoming_events(self) -> List[str]:
+ """Get the room IDs of all events currently staged."""
+ return await self.db_pool.simple_select_onecol(
+ table="federation_inbound_events_staging",
+ keyvalues={},
+ retcol="DISTINCT room_id",
+ desc="get_all_rooms_with_staged_incoming_events",
+ )
+
+ @wrap_as_background_process("_get_stats_for_federation_staging")
+ async def _get_stats_for_federation_staging(self):
+ """Update the prometheus metrics for the inbound federation staging area."""
+
+ def _get_stats_for_federation_staging_txn(txn):
+ txn.execute(
+ "SELECT coalesce(count(*), 0) FROM federation_inbound_events_staging"
+ )
+ (count,) = txn.fetchone()
+
+ txn.execute(
+ "SELECT coalesce(min(received_ts), 0) FROM federation_inbound_events_staging"
+ )
+
+ (received_ts,) = txn.fetchone()
+
+ # received_ts is an absolute timestamp in milliseconds; convert it to
+ # an age in seconds, to match the gauge's description.
+ age = (self._clock.time_msec() - received_ts) / 1000 if received_ts else 0
+
+ return count, age
+
+ count, age = await self.db_pool.runInteraction(
+ "_get_stats_for_federation_staging", _get_stats_for_federation_staging_txn
+ )
+
+ number_pdus_in_federation_queue.set(count)
+ oldest_pdu_in_federation_staging.set(age)
+
class EventFederationStore(EventFederationWorkerStore):
"""Responsible for storing and serving up the various graphs associated
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index cbe4be1437..29f33bac55 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -29,6 +29,34 @@ from synapse.types import JsonDict
logger = logging.getLogger(__name__)
+_REPLACE_STREAM_ORDERING_SQL_COMMANDS = (
+ # there should be no leftover rows without a stream_ordering2, but just in case...
+ "UPDATE events SET stream_ordering2 = stream_ordering WHERE stream_ordering2 IS NULL",
+ # now we can drop the rule and switch the columns
+ "DROP RULE populate_stream_ordering2 ON events",
+ "ALTER TABLE events DROP COLUMN stream_ordering",
+ "ALTER TABLE events RENAME COLUMN stream_ordering2 TO stream_ordering",
+ # ... and finally, rename the indexes into place for consistency with sqlite
+ "ALTER INDEX event_contains_url_index2 RENAME TO event_contains_url_index",
+ "ALTER INDEX events_order_room2 RENAME TO events_order_room",
+ "ALTER INDEX events_room_stream2 RENAME TO events_room_stream",
+ "ALTER INDEX events_ts2 RENAME TO events_ts",
+)
+
+
+class _BackgroundUpdates:
+ EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts"
+ EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url"
+ DELETE_SOFT_FAILED_EXTREMITIES = "delete_soft_failed_extremities"
+ POPULATE_STREAM_ORDERING2 = "populate_stream_ordering2"
+ INDEX_STREAM_ORDERING2 = "index_stream_ordering2"
+ INDEX_STREAM_ORDERING2_CONTAINS_URL = "index_stream_ordering2_contains_url"
+ INDEX_STREAM_ORDERING2_ROOM_ORDER = "index_stream_ordering2_room_order"
+ INDEX_STREAM_ORDERING2_ROOM_STREAM = "index_stream_ordering2_room_stream"
+ INDEX_STREAM_ORDERING2_TS = "index_stream_ordering2_ts"
+ REPLACE_STREAM_ORDERING_COLUMN = "replace_stream_ordering_column"
+
+
@attr.s(slots=True, frozen=True)
class _CalculateChainCover:
"""Return value for _calculate_chain_cover_txn."""
@@ -48,19 +76,15 @@ class _CalculateChainCover:
class EventsBackgroundUpdatesStore(SQLBaseStore):
-
- EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts"
- EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url"
- DELETE_SOFT_FAILED_EXTREMITIES = "delete_soft_failed_extremities"
-
def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_update_handler(
- self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts
+ _BackgroundUpdates.EVENT_ORIGIN_SERVER_TS_NAME,
+ self._background_reindex_origin_server_ts,
)
self.db_pool.updates.register_background_update_handler(
- self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME,
+ _BackgroundUpdates.EVENT_FIELDS_SENDER_URL_UPDATE_NAME,
self._background_reindex_fields_sender,
)
@@ -85,7 +109,8 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
)
self.db_pool.updates.register_background_update_handler(
- self.DELETE_SOFT_FAILED_EXTREMITIES, self._cleanup_extremities_bg_update
+ _BackgroundUpdates.DELETE_SOFT_FAILED_EXTREMITIES,
+ self._cleanup_extremities_bg_update,
)
self.db_pool.updates.register_background_update_handler(
@@ -139,6 +164,59 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
self._purged_chain_cover_index,
)
+ ################################################################################
+
+ # bg updates for replacing stream_ordering with a BIGINT
+ # (these only run on postgres.)
+
+ self.db_pool.updates.register_background_update_handler(
+ _BackgroundUpdates.POPULATE_STREAM_ORDERING2,
+ self._background_populate_stream_ordering2,
+ )
+ # CREATE UNIQUE INDEX events_stream_ordering ON events(stream_ordering2);
+ self.db_pool.updates.register_background_index_update(
+ _BackgroundUpdates.INDEX_STREAM_ORDERING2,
+ index_name="events_stream_ordering",
+ table="events",
+ columns=["stream_ordering2"],
+ unique=True,
+ )
+ # CREATE INDEX event_contains_url_index ON events(room_id, topological_ordering, stream_ordering) WHERE contains_url = true AND outlier = false;
+ self.db_pool.updates.register_background_index_update(
+ _BackgroundUpdates.INDEX_STREAM_ORDERING2_CONTAINS_URL,
+ index_name="event_contains_url_index2",
+ table="events",
+ columns=["room_id", "topological_ordering", "stream_ordering2"],
+ where_clause="contains_url = true AND outlier = false",
+ )
+ # CREATE INDEX events_order_room ON events(room_id, topological_ordering, stream_ordering);
+ self.db_pool.updates.register_background_index_update(
+ _BackgroundUpdates.INDEX_STREAM_ORDERING2_ROOM_ORDER,
+ index_name="events_order_room2",
+ table="events",
+ columns=["room_id", "topological_ordering", "stream_ordering2"],
+ )
+ # CREATE INDEX events_room_stream ON events(room_id, stream_ordering);
+ self.db_pool.updates.register_background_index_update(
+ _BackgroundUpdates.INDEX_STREAM_ORDERING2_ROOM_STREAM,
+ index_name="events_room_stream2",
+ table="events",
+ columns=["room_id", "stream_ordering2"],
+ )
+ # CREATE INDEX events_ts ON events(origin_server_ts, stream_ordering);
+ self.db_pool.updates.register_background_index_update(
+ _BackgroundUpdates.INDEX_STREAM_ORDERING2_TS,
+ index_name="events_ts2",
+ table="events",
+ columns=["origin_server_ts", "stream_ordering2"],
+ )
+ self.db_pool.updates.register_background_update_handler(
+ _BackgroundUpdates.REPLACE_STREAM_ORDERING_COLUMN,
+ self._background_replace_stream_ordering_column,
+ )
+
+ ################################################################################
+
async def _background_reindex_fields_sender(self, progress, batch_size):
target_min_stream_id = progress["target_min_stream_id_inclusive"]
max_stream_id = progress["max_stream_id_exclusive"]
@@ -190,18 +268,18 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
}
self.db_pool.updates._background_update_progress_txn(
- txn, self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, progress
+ txn, _BackgroundUpdates.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, progress
)
return len(rows)
result = await self.db_pool.runInteraction(
- self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, reindex_txn
+ _BackgroundUpdates.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, reindex_txn
)
if not result:
await self.db_pool.updates._end_background_update(
- self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME
+ _BackgroundUpdates.EVENT_FIELDS_SENDER_URL_UPDATE_NAME
)
return result
@@ -264,18 +342,18 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
}
self.db_pool.updates._background_update_progress_txn(
- txn, self.EVENT_ORIGIN_SERVER_TS_NAME, progress
+ txn, _BackgroundUpdates.EVENT_ORIGIN_SERVER_TS_NAME, progress
)
return len(rows_to_update)
result = await self.db_pool.runInteraction(
- self.EVENT_ORIGIN_SERVER_TS_NAME, reindex_search_txn
+ _BackgroundUpdates.EVENT_ORIGIN_SERVER_TS_NAME, reindex_search_txn
)
if not result:
await self.db_pool.updates._end_background_update(
- self.EVENT_ORIGIN_SERVER_TS_NAME
+ _BackgroundUpdates.EVENT_ORIGIN_SERVER_TS_NAME
)
return result
@@ -454,7 +532,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
if not num_handled:
await self.db_pool.updates._end_background_update(
- self.DELETE_SOFT_FAILED_EXTREMITIES
+ _BackgroundUpdates.DELETE_SOFT_FAILED_EXTREMITIES
)
def _drop_table_txn(txn):
@@ -1009,3 +1087,81 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
await self.db_pool.updates._end_background_update("purged_chain_cover")
return result
+
+ async def _background_populate_stream_ordering2(
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
+ """Populate events.stream_ordering2, then replace stream_ordering
+
+ This is to deal with the fact that stream_ordering was initially created as a
+ 32-bit integer field.
+ """
+ batch_size = max(batch_size, 1)
+
+ def process(txn: Cursor) -> int:
+ last_stream = progress.get("last_stream", -(1 << 31))
+ txn.execute(
+ """
+ UPDATE events SET stream_ordering2=stream_ordering
+ WHERE stream_ordering IN (
+ SELECT stream_ordering FROM events WHERE stream_ordering > ?
+ ORDER BY stream_ordering LIMIT ?
+ )
+ RETURNING stream_ordering;
+ """,
+ (last_stream, batch_size),
+ )
+ row_count = txn.rowcount
+ if row_count == 0:
+ return 0
+ last_stream = max(row[0] for row in txn)
+ logger.info("populated stream_ordering2 up to %i", last_stream)
+
+ self.db_pool.updates._background_update_progress_txn(
+ txn,
+ _BackgroundUpdates.POPULATE_STREAM_ORDERING2,
+ {"last_stream": last_stream},
+ )
+ return row_count
+
+ result = await self.db_pool.runInteraction(
+ "_background_populate_stream_ordering2", process
+ )
+
+ if result != 0:
+ return result
+
+ await self.db_pool.updates._end_background_update(
+ _BackgroundUpdates.POPULATE_STREAM_ORDERING2
+ )
+ return 0
+
+ async def _background_replace_stream_ordering_column(
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
+ """Drop the old 'stream_ordering' column and rename 'stream_ordering2' into its place."""
+
+ def process(txn: Cursor) -> None:
+ for sql in _REPLACE_STREAM_ORDERING_SQL_COMMANDS:
+ logger.info("completing stream_ordering migration: %s", sql)
+ txn.execute(sql)
+
+ # ANALYZE the new column to build stats on it, to encourage PostgreSQL to use the
+ # indexes on it.
+ # We need to pass `execute` a dummy function to handle the txn's result,
+ # otherwise it tries to call fetchall() on it and fails because there's no
+ # result to fetch.
+ await self.db_pool.execute(
+ "background_analyze_new_stream_ordering_column",
+ lambda txn: None,
+ "ANALYZE events(stream_ordering2)",
+ )
+
+ await self.db_pool.runInteraction(
+ "_background_replace_stream_ordering_column", process
+ )
+
+ await self.db_pool.updates._end_background_update(
+ _BackgroundUpdates.REPLACE_STREAM_ORDERING_COLUMN
+ )
+
+ return 0
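
The populate step above follows a classic batched-backfill shape: copy at most `batch_size` rows per transaction, record the high-water mark in the background-update progress, and stop when a batch comes back empty, at which point the column swap runs. Stripped of the Synapse plumbing, one batch looks roughly like this (Postgres `%s` placeholders, since this migration only runs there; the cursor is assumed):

```python
from typing import Tuple


def backfill_batch(cur, last_stream: int, batch_size: int) -> Tuple[int, int]:
    """Copy one batch; returns (rows_copied, new_high_water_mark)."""
    cur.execute(
        """
        UPDATE events SET stream_ordering2 = stream_ordering
        WHERE stream_ordering IN (
            SELECT stream_ordering FROM events WHERE stream_ordering > %s
            ORDER BY stream_ordering LIMIT %s
        )
        RETURNING stream_ordering
        """,
        (last_stream, batch_size),
    )
    rows = cur.fetchall()
    if not rows:
        # Done: a later update drops the old column and renames the new one.
        return 0, last_stream
    return len(rows), max(r[0] for r in rows)
```
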
diff --git a/synapse/storage/databases/main/lock.py b/synapse/storage/databases/main/lock.py
index e76188328c..774861074c 100644
--- a/synapse/storage/databases/main/lock.py
+++ b/synapse/storage/databases/main/lock.py
@@ -310,14 +310,25 @@ class Lock:
_excinst: Optional[BaseException],
_exctb: Optional[TracebackType],
) -> bool:
+ await self.release()
+
+ return False
+
+ async def release(self) -> None:
+ """Release the lock.
+
+ This is automatically called when using the lock as a context manager.
+ """
+
+ if self._dropped:
+ return
+
if self._looping_call.running:
self._looping_call.stop()
await self._store._drop_lock(self._lock_name, self._lock_key, self._token)
self._dropped = True
- return False
-
def __del__(self) -> None:
if not self._dropped:
# We should not be dropped without the lock being released (unless
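
With `release()` split out, the lock can be used either as an async context manager (which still releases automatically) or held across a code path where `async with` is awkward. A sketch against the surrounding store API; `try_acquire_lock` is assumed from the rest of this class:

```python
async def do_guarded_work(store) -> None:
    lock = await store.try_acquire_lock("example_lock_name", "example_key")
    if lock is None:
        return  # another worker holds the lock

    try:
        ...  # the guarded work; the lock is renewed by its looping call
    finally:
        # Explicit release. release() is idempotent thanks to the `_dropped`
        # check, so releasing twice is harmless.
        await lock.release()
```
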
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 77e2eb27db..b8bdfc721e 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -53,6 +53,9 @@ class TokenLookupResult:
valid_until_ms: The timestamp the token expires, if any.
token_owner: The "owner" of the token. This is either the same as the
user, or a server admin who is logged in as the user.
+ token_used: True if this token was used at least once in a request.
+ This field can be out of date since `get_user_by_access_token` is
+ cached.
"""
user_id = attr.ib(type=str)
@@ -62,6 +65,7 @@ class TokenLookupResult:
device_id = attr.ib(type=Optional[str], default=None)
valid_until_ms = attr.ib(type=Optional[int], default=None)
token_owner = attr.ib(type=str)
+ token_used = attr.ib(type=bool, default=False)
# Make the token owner default to the user ID, which is the common case.
@token_owner.default
@@ -69,6 +73,29 @@ class TokenLookupResult:
return self.user_id
+@attr.s(frozen=True, slots=True)
+class RefreshTokenLookupResult:
+ """Result of looking up a refresh token."""
+
+ user_id = attr.ib(type=str)
+ """The user this token belongs to."""
+
+ device_id = attr.ib(type=str)
+ """The device associated with this refresh token."""
+
+ token_id = attr.ib(type=int)
+ """The ID of this refresh token."""
+
+ next_token_id = attr.ib(type=Optional[int])
+ """The ID of the refresh token which replaced this one."""
+
+ has_next_refresh_token_been_refreshed = attr.ib(type=bool)
+ """True if the next refresh token was used for another refresh."""
+
+ has_next_access_token_been_used = attr.ib(type=bool)
+ """True if the next access token was already used at least once."""
+
+
class RegistrationWorkerStore(CacheInvalidationWorkerStore):
def __init__(
self,
@@ -521,7 +548,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
access_tokens.id as token_id,
access_tokens.device_id,
access_tokens.valid_until_ms,
- access_tokens.user_id as token_owner
+ access_tokens.user_id as token_owner,
+ access_tokens.used as token_used
FROM users
INNER JOIN access_tokens on users.name = COALESCE(puppets_user_id, access_tokens.user_id)
WHERE token = ?
@@ -529,8 +557,15 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
txn.execute(sql, (token,))
rows = self.db_pool.cursor_to_dict(txn)
+
if rows:
- return TokenLookupResult(**rows[0])
+ row = rows[0]
+
+ # This field is nullable, ensure it comes out as a boolean
+ if row["token_used"] is None:
+ row["token_used"] = False
+
+ return TokenLookupResult(**row)
return None
@@ -1152,6 +1187,111 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
desc="update_access_token_last_validated",
)
+ @cached()
+ async def mark_access_token_as_used(self, token_id: int) -> None:
+ """
+ Mark the access token as used, which invalidates the refresh token used
+ to obtain it.
+
+ Because get_user_by_access_token is cached, this function might be
+ called multiple times for the same token, effectively doing unnecessary
+ SQL updates. Since updating the `used` field only goes one way (from
+ False to True), it is safe to cache this function as well to avoid this
+ issue.
+
+ Args:
+ token_id: The ID of the access token to update.
+ Raises:
+ StoreError if there was a problem updating this.
+ """
+ await self.db_pool.simple_update_one(
+ "access_tokens",
+ {"id": token_id},
+ {"used": True},
+ desc="mark_access_token_as_used",
+ )
+
+ async def lookup_refresh_token(
+ self, token: str
+ ) -> Optional[RefreshTokenLookupResult]:
+ """Lookup a refresh token with hints about its validity."""
+
+ def _lookup_refresh_token_txn(txn) -> Optional[RefreshTokenLookupResult]:
+ txn.execute(
+ """
+ SELECT
+ rt.id token_id,
+ rt.user_id,
+ rt.device_id,
+ rt.next_token_id,
+ (nrt.next_token_id IS NOT NULL) has_next_refresh_token_been_refreshed,
+ at.used has_next_access_token_been_used
+ FROM refresh_tokens rt
+ LEFT JOIN refresh_tokens nrt ON rt.next_token_id = nrt.id
+ LEFT JOIN access_tokens at ON at.refresh_token_id = nrt.id
+ WHERE rt.token = ?
+ """,
+ (token,),
+ )
+ row = txn.fetchone()
+
+ if row is None:
+ return None
+
+ return RefreshTokenLookupResult(
+ token_id=row[0],
+ user_id=row[1],
+ device_id=row[2],
+ next_token_id=row[3],
+ has_next_refresh_token_been_refreshed=row[4],
+ # This column is nullable, ensure it's a boolean
+ has_next_access_token_been_used=(row[5] or False),
+ )
+
+ return await self.db_pool.runInteraction(
+ "lookup_refresh_token", _lookup_refresh_token_txn
+ )
+
+ async def replace_refresh_token(self, token_id: int, next_token_id: int) -> None:
+ """
+ Set the successor of a refresh token, removing the existing successor
+ if any.
+
+ Args:
+ token_id: ID of the refresh token to update.
+ next_token_id: ID of its successor.
+ """
+
+ def _replace_refresh_token_txn(txn) -> None:
+ # First check if there was an existing refresh token
+ old_next_token_id = self.db_pool.simple_select_one_onecol_txn(
+ txn,
+ "refresh_tokens",
+ {"id": token_id},
+ "next_token_id",
+ allow_none=True,
+ )
+
+ self.db_pool.simple_update_one_txn(
+ txn,
+ "refresh_tokens",
+ {"id": token_id},
+ {"next_token_id": next_token_id},
+ )
+
+ # Delete the old "next" token if it exists. This should cascade and
+ # delete the associated access_token
+ if old_next_token_id is not None:
+ self.db_pool.simple_delete_one_txn(
+ txn,
+ "refresh_tokens",
+ {"id": old_next_token_id},
+ )
+
+ await self.db_pool.runInteraction(
+ "replace_refresh_token", _replace_refresh_token_txn
+ )
+
class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
def __init__(
@@ -1343,6 +1483,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors
self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
+ self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id")
async def add_access_token_to_user(
self,
@@ -1351,14 +1492,18 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
device_id: Optional[str],
valid_until_ms: Optional[int],
puppets_user_id: Optional[str] = None,
+ refresh_token_id: Optional[int] = None,
) -> int:
"""Adds an access token for the given user.
Args:
user_id: The user ID.
token: The new access token to add.
- device_id: ID of the device to associate with the access token
+ device_id: ID of the device to associate with the access token.
valid_until_ms: when the token is valid until. None for no expiry.
+ puppets_user_id: The user this token should act on behalf of, if any.
+ refresh_token_id: ID of the refresh token generated alongside this
+ access token.
Raises:
StoreError if there was a problem adding this.
Returns:
@@ -1377,12 +1522,47 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
"valid_until_ms": valid_until_ms,
"puppets_user_id": puppets_user_id,
"last_validated": now,
+ "refresh_token_id": refresh_token_id,
+ "used": False,
},
desc="add_access_token_to_user",
)
return next_id
+ async def add_refresh_token_to_user(
+ self,
+ user_id: str,
+ token: str,
+ device_id: Optional[str],
+ ) -> int:
+ """Adds a refresh token for the given user.
+
+ Args:
+ user_id: The user ID.
+ token: The new refresh token to add.
+ device_id: ID of the device to associate with the refresh token.
+ Raises:
+ StoreError if there was a problem adding this.
+ Returns:
+ The token ID
+ """
+ next_id = self._refresh_tokens_id_gen.get_next()
+
+ await self.db_pool.simple_insert(
+ "refresh_tokens",
+ {
+ "id": next_id,
+ "user_id": user_id,
+ "device_id": device_id,
+ "token": token,
+ "next_token_id": None,
+ },
+ desc="add_refresh_token_to_user",
+ )
+
+ return next_id
+
def _set_device_for_access_token_txn(self, txn, token: str, device_id: str) -> str:
old_device_id = self.db_pool.simple_select_one_onecol_txn(
txn, "access_tokens", {"token": token}, "device_id"
@@ -1625,7 +1805,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
device_id: Optional[str] = None,
) -> List[Tuple[str, int, Optional[str]]]:
"""
- Invalidate access tokens belonging to a user
+ Invalidate access and refresh tokens belonging to a user
Args:
user_id: ID of user the tokens belong to
@@ -1645,7 +1825,13 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
items = keyvalues.items()
where_clause = " AND ".join(k + " = ?" for k, _ in items)
values = [v for _, v in items] # type: List[Union[str, int]]
+ # Conveniently, refresh_tokens and access_tokens both use the user_id and device_id fields. The only
+ # caveat is the `except_token_id` param, which is tricky to get right, so for now we copy the where
+ # clause and values before it is applied. `except_token_id` appears to be used only by the "set password" handler.
+ refresh_where_clause = where_clause
+ refresh_values = values.copy()
if except_token_id:
+ # TODO: support `except_token_id` for refresh tokens as well
where_clause += " AND id != ?"
values.append(except_token_id)
@@ -1663,6 +1849,11 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
txn.execute("DELETE FROM access_tokens WHERE %s" % where_clause, values)
+ txn.execute(
+ "DELETE FROM refresh_tokens WHERE %s" % refresh_where_clause,
+ refresh_values,
+ )
+
return tokens_and_devices
return await self.db_pool.runInteraction("user_delete_access_tokens", f)
@@ -1679,6 +1870,14 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
await self.db_pool.runInteraction("delete_access_token", f)
+ async def delete_refresh_token(self, refresh_token: str) -> None:
+ def f(txn):
+ self.db_pool.simple_delete_one_txn(
+ txn, table="refresh_tokens", keyvalues={"token": refresh_token}
+ )
+
+ await self.db_pool.runInteraction("delete_refresh_token", f)
+
async def add_user_pending_deactivation(self, user_id: str) -> None:
"""
Adds a user to the table of users who need to be parted from all the rooms they're
diff --git a/synapse/storage/engines/_base.py b/synapse/storage/engines/_base.py
index 1882bfd9cf..20cd63c330 100644
--- a/synapse/storage/engines/_base.py
+++ b/synapse/storage/engines/_base.py
@@ -49,6 +49,12 @@ class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta):
"""
...
+ @property
+ @abc.abstractmethod
+ def supports_returning(self) -> bool:
+ """Do we support the `RETURNING` clause in insert/update/delete?"""
+ ...
+
@abc.abstractmethod
def check_database(
self, db_conn: ConnectionType, allow_outdated_version: bool = False
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index 21411c5fea..30f948a0f7 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -133,6 +133,11 @@ class PostgresEngine(BaseDatabaseEngine):
"""Do we support using `a = ANY(?)` and passing a list"""
return True
+ @property
+ def supports_returning(self) -> bool:
+ """Do we support the `RETURNING` clause in insert/update/delete?"""
+ return True
+
def is_deadlock(self, error):
if isinstance(error, self.module.DatabaseError):
# https://www.postgresql.org/docs/current/static/errcodes-appendix.html
diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py
index 5fe1b205e1..70d17d4f2c 100644
--- a/synapse/storage/engines/sqlite.py
+++ b/synapse/storage/engines/sqlite.py
@@ -60,6 +60,11 @@ class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]):
"""Do we support using `a = ANY(?)` and passing a list"""
return False
+ @property
+ def supports_returning(self) -> bool:
+ """Do we support the `RETURNING` clause in insert/update/delete?"""
+ return self.module.sqlite_version_info >= (3, 35, 0)
+
def check_database(self, db_conn, allow_outdated_version: bool = False):
if not allow_outdated_version:
version = self.module.sqlite_version_info
diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py
index d36ba1d773..0a53b73ccc 100644
--- a/synapse/storage/schema/__init__.py
+++ b/synapse/storage/schema/__init__.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-SCHEMA_VERSION = 59
+SCHEMA_VERSION = 60
"""Represents the expectations made by the codebase about the database schema
This should be incremented whenever the codebase changes its requirements on the
diff --git a/synapse/storage/schema/main/delta/59/14refresh_tokens.sql b/synapse/storage/schema/main/delta/59/14refresh_tokens.sql
new file mode 100644
index 0000000000..9a6bce1e3e
--- /dev/null
+++ b/synapse/storage/schema/main/delta/59/14refresh_tokens.sql
@@ -0,0 +1,34 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Holds MSC2918 refresh tokens
+CREATE TABLE refresh_tokens (
+ id BIGINT PRIMARY KEY,
+ user_id TEXT NOT NULL,
+ device_id TEXT NOT NULL,
+ token TEXT NOT NULL,
+ -- When consumed, a new refresh token is generated, which is tracked by
+ -- this foreign key
+ next_token_id BIGINT REFERENCES refresh_tokens (id) ON DELETE CASCADE,
+ UNIQUE(token)
+);
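+
+-- Illustrative semantics (comment only): deleting a refresh token cascades to
+-- any token whose next_token_id still points at it, and, via the access_tokens
+-- column added below, to the access token issued alongside it.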
+
+-- Add a reference to the refresh token generated alongside each access token
+ALTER TABLE "access_tokens"
+ ADD COLUMN refresh_token_id BIGINT REFERENCES refresh_tokens (id) ON DELETE CASCADE;
+
+-- Add a flag whether the token was already used or not
+ALTER TABLE "access_tokens"
+ ADD COLUMN used BOOLEAN;
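+
+-- (`used` is left nullable: access tokens created before this migration have
+-- used = NULL, which readers coerce to FALSE.)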
diff --git a/synapse/storage/schema/main/delta/60/01recreate_stream_ordering.sql.postgres b/synapse/storage/schema/main/delta/60/01recreate_stream_ordering.sql.postgres
new file mode 100644
index 0000000000..0edc9fe7a2
--- /dev/null
+++ b/synapse/storage/schema/main/delta/60/01recreate_stream_ordering.sql.postgres
@@ -0,0 +1,45 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- This migration handles the process of changing the type of `stream_ordering` to
+-- a BIGINT.
+--
+-- Note that this is only a problem on postgres as sqlite only has one "integer" type
+-- which can cope with values up to 2^63 - 1.
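+-- (A postgres INTEGER is 32 bits, so the existing column tops out at 2^31 - 1,
+-- roughly 2.1 billion events.)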
+
+-- First add a new column to contain the bigger stream_ordering
+ALTER TABLE events ADD COLUMN stream_ordering2 BIGINT;
+
+-- Create a rule which will populate it for new rows.
+CREATE OR REPLACE RULE "populate_stream_ordering2" AS
+ ON INSERT TO events
+ DO UPDATE events SET stream_ordering2=NEW.stream_ordering WHERE stream_ordering=NEW.stream_ordering;
+
+-- Start a bg process to populate it for old events
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+ (6001, 'populate_stream_ordering2', '{}');
+
+-- ... and some more to build indexes on it. These aren't really interdependent
+-- but the background_updates manager can only handle a single dependency per update.
+INSERT INTO background_updates (ordering, update_name, progress_json, depends_on) VALUES
+ (6001, 'index_stream_ordering2', '{}', 'populate_stream_ordering2'),
+ (6001, 'index_stream_ordering2_room_order', '{}', 'index_stream_ordering2'),
+ (6001, 'index_stream_ordering2_contains_url', '{}', 'index_stream_ordering2_room_order'),
+ (6001, 'index_stream_ordering2_room_stream', '{}', 'index_stream_ordering2_contains_url'),
+ (6001, 'index_stream_ordering2_ts', '{}', 'index_stream_ordering2_room_stream');
+
+-- ... and another to do the switcheroo
+INSERT INTO background_updates (ordering, update_name, progress_json, depends_on) VALUES
+ (6001, 'replace_stream_ordering_column', '{}', 'index_stream_ordering2_ts');
diff --git a/synapse/storage/schema/main/delta/60/02change_stream_ordering_columns.sql.postgres b/synapse/storage/schema/main/delta/60/02change_stream_ordering_columns.sql.postgres
new file mode 100644
index 0000000000..630c24fd9e
--- /dev/null
+++ b/synapse/storage/schema/main/delta/60/02change_stream_ordering_columns.sql.postgres
@@ -0,0 +1,30 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- This migration is closely related to '01recreate_stream_ordering.sql.postgres'.
+--
+-- It updates the other tables which use an INTEGER to refer to a stream ordering.
+-- These tables are all small enough that a re-create is tractable.
+ALTER TABLE pushers ALTER COLUMN last_stream_ordering SET DATA TYPE BIGINT;
+ALTER TABLE federation_stream_position ALTER COLUMN stream_id SET DATA TYPE BIGINT;
+
+-- these aren't actually event stream orderings, but they are numbers where 2 billion
+-- is a bit limiting, application_services_state is tiny, and I don't want to ever have
+-- to do this again.
+ALTER TABLE application_services_state ALTER COLUMN last_txn SET DATA TYPE BIGINT;
+ALTER TABLE application_services_state ALTER COLUMN read_receipt_stream_id SET DATA TYPE BIGINT;
+ALTER TABLE application_services_state ALTER COLUMN presence_stream_id SET DATA TYPE BIGINT;
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index d89e9d9b1d..4b9d0433ff 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -12,9 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
import threading
+import weakref
from functools import wraps
from typing import (
+ TYPE_CHECKING,
Any,
Callable,
Collection,
@@ -31,10 +34,19 @@ from typing import (
from typing_extensions import Literal
+from twisted.internet import reactor
+
from synapse.config import cache as cache_config
-from synapse.util import caches
+from synapse.metrics.background_process_metrics import wrap_as_background_process
+from synapse.util import Clock, caches
from synapse.util.caches import CacheMetric, register_cache
from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
+from synapse.util.linked_list import ListNode
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
try:
from pympler.asizeof import Asizer
@@ -82,19 +94,126 @@ def enumerate_leaves(node, depth):
yield m
+P = TypeVar("P")
+
+
+class _TimedListNode(ListNode[P]):
+ """A `ListNode` that tracks last access time."""
+
+ __slots__ = ["last_access_ts_secs"]
+
+ def update_last_access(self, clock: Clock):
+ self.last_access_ts_secs = int(clock.time())
+
+
+# Whether to insert new cache entries into the global list. We only add to it if
+# time based eviction is enabled.
+USE_GLOBAL_LIST = False
+
+# A linked list of all cache entries, allowing efficient time based eviction.
+GLOBAL_ROOT = ListNode["_Node"].create_root_node()
+
+
+@wrap_as_background_process("LruCache._expire_old_entries")
+async def _expire_old_entries(clock: Clock, expiry_seconds: int):
+ """Walks the global cache list to find cache entries that haven't been
+ accessed in the given number of seconds.
+ """
+
+ now = int(clock.time())
+ node = GLOBAL_ROOT.prev_node
+ assert node is not None
+
+ i = 0
+
+ logger.debug("Searching for stale caches")
+
+ while node is not GLOBAL_ROOT:
+ # Only the root node isn't a `_TimedListNode`.
+ assert isinstance(node, _TimedListNode)
+
+ if node.last_access_ts_secs > now - expiry_seconds:
+ break
+
+ cache_entry = node.get_cache_entry()
+ next_node = node.prev_node
+
+ # The node should always have a reference to a cache entry and a valid
+ # `prev_node`, as we only drop them when we remove the node from the
+ # list.
+ assert next_node is not None
+ assert cache_entry is not None
+ cache_entry.drop_from_cache()
+
+ # If we do lots of work at once we yield to allow other stuff to happen.
+ if (i + 1) % 10000 == 0:
+ logger.debug("Waiting during drop")
+ await clock.sleep(0)
+ logger.debug("Waking during drop")
+
+ node = next_node
+
+ # If we've yielded then our current node may have been evicted, so we
+ # need to check that it's still valid.
+ if node.prev_node is None:
+ break
+
+ i += 1
+
+ logger.info("Dropped %d items from caches", i)
+
+
+def setup_expire_lru_cache_entries(hs: "HomeServer"):
+ """Start a background job that expires all cache entries if they have not
+ been accessed for the given number of seconds.
+ """
+ if not hs.config.caches.expiry_time_msec:
+ return
+
+ logger.info(
+ "Expiring LRU caches after %d seconds", hs.config.caches.expiry_time_msec / 1000
+ )
+
+ global USE_GLOBAL_LIST
+ USE_GLOBAL_LIST = True
+
+ clock = hs.get_clock()
+ clock.looping_call(
+ _expire_old_entries, 30 * 1000, clock, hs.config.caches.expiry_time_msec / 1000
+ )
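+
+ # A hedged configuration sketch (homeserver.yaml): the option name here is
+ # assumed from `hs.config.caches.expiry_time_msec`:
+ #
+ #   caches:
+ #     expiry_time: 30m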
+
+
class _Node:
- __slots__ = ["prev_node", "next_node", "key", "value", "callbacks", "memory"]
+ __slots__ = [
+ "_list_node",
+ "_global_list_node",
+ "_cache",
+ "key",
+ "value",
+ "callbacks",
+ "memory",
+ ]
def __init__(
self,
- prev_node,
- next_node,
+ root: "ListNode[_Node]",
key,
value,
+ cache: "weakref.ReferenceType[LruCache]",
+ clock: Clock,
callbacks: Collection[Callable[[], None]] = (),
):
- self.prev_node = prev_node
- self.next_node = next_node
+ self._list_node = ListNode.insert_after(self, root)
+ self._global_list_node = None
+ if USE_GLOBAL_LIST:
+ self._global_list_node = _TimedListNode.insert_after(self, GLOBAL_ROOT)
+ self._global_list_node.update_last_access(clock)
+
+ # We store a weak reference to the cache object so that this _Node can
+ # remove itself from the cache. If the cache is dropped we ensure we
+ # remove our entries from the lists.
+ self._cache = cache
+
self.key = key
self.value = value
@@ -116,11 +235,16 @@ class _Node:
self.memory = (
_get_size_of(key)
+ _get_size_of(value)
+ + _get_size_of(self._list_node, recurse=False)
+ _get_size_of(self.callbacks, recurse=False)
+ _get_size_of(self, recurse=False)
)
self.memory += _get_size_of(self.memory, recurse=False)
+ if self._global_list_node:
+ self.memory += _get_size_of(self._global_list_node, recurse=False)
+ self.memory += _get_size_of(self._global_list_node.last_access_ts_secs)
+
def add_callbacks(self, callbacks: Collection[Callable[[], None]]) -> None:
"""Add to stored list of callbacks, removing duplicates."""
@@ -147,6 +271,32 @@ class _Node:
self.callbacks = None
+ def drop_from_cache(self) -> None:
+ """Drop this node from the cache.
+
+ Ensures that the entry gets removed from the cache and that we get
+ removed from all lists.
+ """
+ cache = self._cache()
+ if not cache or not cache.pop(self.key, None):
+ # `cache.pop` should call `drop_from_lists()`, unless this Node had
+ # already been removed from the cache.
+ self.drop_from_lists()
+
+ def drop_from_lists(self) -> None:
+ """Remove this node from the cache lists."""
+ self._list_node.remove_from_list()
+
+ if self._global_list_node:
+ self._global_list_node.remove_from_list()
+
+ def move_to_front(self, clock: Clock, cache_list_root: ListNode) -> None:
+ """Moves this node to the front of all the lists its in."""
+ self._list_node.move_after(cache_list_root)
+ if self._global_list_node:
+ self._global_list_node.move_after(GLOBAL_ROOT)
+ self._global_list_node.update_last_access(clock)
+
class LruCache(Generic[KT, VT]):
"""
@@ -163,6 +313,7 @@ class LruCache(Generic[KT, VT]):
size_callback: Optional[Callable] = None,
metrics_collection_callback: Optional[Callable[[], None]] = None,
apply_cache_factor_from_config: bool = True,
+ clock: Optional[Clock] = None,
):
"""
Args:
@@ -188,6 +339,13 @@ class LruCache(Generic[KT, VT]):
apply_cache_factor_from_config (bool): If true, `max_size` will be
multiplied by a cache factor derived from the homeserver config
"""
+ # Default `clock` to something sensible. Note that we rename it to
+ # `real_clock` so that mypy doesn't think it's still `Optional`.
+ if clock is None:
+ real_clock = Clock(reactor)
+ else:
+ real_clock = clock
+
cache = cache_type()
self.cache = cache # Used for introspection.
self.apply_cache_factor_from_config = apply_cache_factor_from_config
@@ -219,17 +377,31 @@ class LruCache(Generic[KT, VT]):
# this is exposed for access from outside this class
self.metrics = metrics
- list_root = _Node(None, None, None, None)
- list_root.next_node = list_root
- list_root.prev_node = list_root
+ # We create a single weakref to self here so that we don't need to keep
+ # creating more each time we create a `_Node`.
+ weak_ref_to_self = weakref.ref(self)
+
+ list_root = ListNode[_Node].create_root_node()
lock = threading.Lock()
def evict():
while cache_len() > self.max_size:
+ # Get the last node in the list (i.e. the oldest node).
todelete = list_root.prev_node
- evicted_len = delete_node(todelete)
- cache.pop(todelete.key, None)
+
+ # The list root should always have a valid `prev_node` if the
+ # cache is not empty.
+ assert todelete is not None
+
+ # The node should always have a reference to a cache entry, as
+ # we only drop the cache entry when we remove the node from the
+ # list.
+ node = todelete.get_cache_entry()
+ assert node is not None
+
+ evicted_len = delete_node(node)
+ cache.pop(node.key, None)
if metrics:
metrics.inc_evictions(evicted_len)
@@ -255,11 +427,7 @@ class LruCache(Generic[KT, VT]):
self.len = synchronized(cache_len)
def add_node(key, value, callbacks: Collection[Callable[[], None]] = ()):
- prev_node = list_root
- next_node = prev_node.next_node
- node = _Node(prev_node, next_node, key, value, callbacks)
- prev_node.next_node = node
- next_node.prev_node = node
+ node = _Node(list_root, key, value, weak_ref_to_self, real_clock, callbacks)
cache[key] = node
if size_callback:
@@ -268,23 +436,11 @@ class LruCache(Generic[KT, VT]):
if caches.TRACK_MEMORY_USAGE and metrics:
metrics.inc_memory_usage(node.memory)
- def move_node_to_front(node):
- prev_node = node.prev_node
- next_node = node.next_node
- prev_node.next_node = next_node
- next_node.prev_node = prev_node
- prev_node = list_root
- next_node = prev_node.next_node
- node.prev_node = prev_node
- node.next_node = next_node
- prev_node.next_node = node
- next_node.prev_node = node
-
- def delete_node(node):
- prev_node = node.prev_node
- next_node = node.next_node
- prev_node.next_node = next_node
- next_node.prev_node = prev_node
+ def move_node_to_front(node: _Node):
+ node.move_to_front(real_clock, list_root)
+
+ def delete_node(node: _Node) -> int:
+ node.drop_from_lists()
deleted_len = 1
if size_callback:
@@ -411,10 +567,13 @@ class LruCache(Generic[KT, VT]):
@synchronized
def cache_clear() -> None:
- list_root.next_node = list_root
- list_root.prev_node = list_root
for node in cache.values():
node.run_and_clear_callbacks()
+ node.drop_from_lists()
+
+ assert list_root.next_node == list_root
+ assert list_root.prev_node == list_root
+
cache.clear()
if size_callback:
cached_cache_len[0] = 0
@@ -484,3 +643,11 @@ class LruCache(Generic[KT, VT]):
self._on_resize()
return True
return False
+
+ def __del__(self) -> None:
+ # We're about to be deleted, so we make sure to clear up all the nodes
+ # and run callbacks, etc.
+ #
+ # This happens e.g. in the sync code where we have an expiring cache of
+ # lru caches.
+ self.clear()
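+
+ # Usage sketch (illustrative only; assumes a hypothetical `hs`): the clock
+ # is used to stamp each entry's last access time for the expiry job:
+ #
+ #   cache: LruCache[str, int] = LruCache(max_size=100, clock=hs.get_clock())
+ #   cache["key"] = 1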
diff --git a/synapse/util/linked_list.py b/synapse/util/linked_list.py
new file mode 100644
index 0000000000..a456b136f0
--- /dev/null
+++ b/synapse/util/linked_list.py
@@ -0,0 +1,150 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""A circular doubly linked list implementation.
+"""
+
+import threading
+from typing import Generic, Optional, Type, TypeVar
+
+P = TypeVar("P")
+LN = TypeVar("LN", bound="ListNode")
+
+
+class ListNode(Generic[P]):
+ """A node in a circular doubly linked list, with an (optional) reference to
+ a cache entry.
+
+ The reference should only be `None` for the root node or if the node has
+ been removed from the list.
+ """
+
+ # A lock to protect mutating the list prev/next pointers.
+ _LOCK = threading.Lock()
+
+ # We don't use attrs here as, on py3.6, you can't have `attr.s(slots=True)`
+ # and inherit from `Generic` at the same time.
+ __slots__ = [
+ "cache_entry",
+ "prev_node",
+ "next_node",
+ ]
+
+ def __init__(self, cache_entry: Optional[P] = None) -> None:
+ self.cache_entry = cache_entry
+ self.prev_node: Optional[ListNode[P]] = None
+ self.next_node: Optional[ListNode[P]] = None
+
+ @classmethod
+ def create_root_node(cls: Type["ListNode[P]"]) -> "ListNode[P]":
+ """Create a new linked list by creating a "root" node, which is a node
+ that has prev_node/next_node pointing to itself and no associated cache
+ entry.
+ """
+ root = cls()
+ root.prev_node = root
+ root.next_node = root
+ return root
+
+ @classmethod
+ def insert_after(
+ cls: Type[LN],
+ cache_entry: P,
+ node: "ListNode[P]",
+ ) -> LN:
+ """Create a new list node that is placed after the given node.
+
+ Args:
+ cache_entry: The associated cache entry.
+ node: The existing node in the list to insert the new entry after.
+ """
+ new_node = cls(cache_entry)
+ with cls._LOCK:
+ new_node._refs_insert_after(node)
+ return new_node
+
+ def remove_from_list(self):
+ """Remove this node from the list."""
+ with self._LOCK:
+ self._refs_remove_node_from_list()
+
+ # We drop the reference to the cache entry to break the reference cycle
+ # between the list node and cache entry, allowing the two to be dropped
+ # immediately rather than at the next GC.
+ self.cache_entry = None
+
+ def move_after(self, node: "ListNode"):
+ """Move this node from its current location in the list to after the
+ given node.
+ """
+ with self._LOCK:
+ # We assert that both this node and the target node are still "alive".
+ assert self.prev_node
+ assert self.next_node
+ assert node.prev_node
+ assert node.next_node
+
+ assert self is not node
+
+ # Remove self from the list
+ self._refs_remove_node_from_list()
+
+ # Insert self back into the list, after target node
+ self._refs_insert_after(node)
+
+ def _refs_remove_node_from_list(self):
+ """Internal method to *just* remove the node from the list, without
+ e.g. clearing out the cache entry.
+ """
+ if self.prev_node is None or self.next_node is None:
+ # We've already been removed from the list.
+ return
+
+ prev_node = self.prev_node
+ next_node = self.next_node
+
+ prev_node.next_node = next_node
+ next_node.prev_node = prev_node
+
+ # We set these to None so that we don't get circular references,
+ # allowing us to be dropped without having to go via the GC.
+ self.prev_node = None
+ self.next_node = None
+
+ def _refs_insert_after(self, node: "ListNode"):
+ """Internal method to insert the node after the given node."""
+
+ # This method should only be called when we're not already in the list.
+ assert self.prev_node is None
+ assert self.next_node is None
+
+ # We expect the given node to be in the list and thus have valid
+ # prev/next refs.
+ assert node.next_node
+ assert node.prev_node
+
+ prev_node = node
+ next_node = node.next_node
+
+ self.prev_node = prev_node
+ self.next_node = next_node
+
+ prev_node.next_node = self
+ next_node.prev_node = self
+
+ def get_cache_entry(self) -> Optional[P]:
+ """Get the cache entry, returns None if this is the root node (i.e.
+ cache_entry is None) or if the entry has been dropped.
+ """
+ return self.cache_entry
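+
+
+# A minimal usage sketch (illustrative only, not exercised by this module):
+#
+#   root = ListNode[str].create_root_node()
+#   first = ListNode.insert_after("first", root)
+#   second = ListNode.insert_after("second", root)
+#   # order is now: root <-> second <-> first <-> root
+#   second.move_after(first)
+#   # order is now: root <-> first <-> second <-> root
+#   first.remove_from_list()
+#   assert first.get_cache_entry() is None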