From 7010a3d0151b88b3a9a7451201eaf9c5bbe48d64 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 21 Dec 2022 13:05:21 -0500 Subject: Switch to ruff instead of flake8. (#14633) ruff is a flake8-compatible Python linter written in Rust. It supports the flake8 plugins that we use and is significantly faster in testing. --- synapse/config/_base.pyi | 2 ++ 1 file changed, 2 insertions(+) (limited to 'synapse') diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi index 01ea2b4dab..bd265de536 100644 --- a/synapse/config/_base.pyi +++ b/synapse/config/_base.pyi @@ -1,3 +1,5 @@ +from __future__ import annotations + import argparse from typing import ( Any, -- cgit 1.5.1 From 5c9be9c76021ac54f425f10e8f935532d3197de5 Mon Sep 17 00:00:00 2001 From: Jeyachandran Rathnam Date: Thu, 22 Dec 2022 13:26:37 -0500 Subject: Check sqlite database file exists before porting. (#14692) To avoid creating an empty SQLite file if the given path is incorrect. --- changelog.d/14692.misc | 1 + synapse/_scripts/synapse_port_db.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 changelog.d/14692.misc (limited to 'synapse') diff --git a/changelog.d/14692.misc b/changelog.d/14692.misc new file mode 100644 index 0000000000..0edac253b7 --- /dev/null +++ b/changelog.d/14692.misc @@ -0,0 +1 @@ +Check that the SQLite database file exists before porting to PostgreSQL. \ No newline at end of file diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py index d850e54e17..c463b60b26 100755 --- a/synapse/_scripts/synapse_port_db.py +++ b/synapse/_scripts/synapse_port_db.py @@ -1307,7 +1307,7 @@ def main() -> None: sqlite_config = { "name": "sqlite3", "args": { - "database": args.sqlite_database, + "database": "file:{}?mode=rw".format(args.sqlite_database), "cp_min": 1, "cp_max": 1, "check_same_thread": False, -- cgit 1.5.1 From a52822d39c866b4d5e6d2a0176f29ae49bf3f8e9 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 23 Dec 2022 14:04:50 +0000 Subject: Log to-device msgids when we return them over /sync (#14724) --- changelog.d/14724.misc | 1 + synapse/handlers/sync.py | 20 +++++++++++++------- 2 files changed, 14 insertions(+), 7 deletions(-) create mode 100644 changelog.d/14724.misc (limited to 'synapse') diff --git a/changelog.d/14724.misc b/changelog.d/14724.misc new file mode 100644 index 0000000000..270e5ed188 --- /dev/null +++ b/changelog.d/14724.misc @@ -0,0 +1 @@ +If debug logging is enabled, log the `msgid`s of any to-device messages that are returned over `/sync`. diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 7d6a653747..4fa480262b 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -37,6 +37,7 @@ from synapse.api.presence import UserPresenceState from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import EventBase from synapse.handlers.relations import BundledAggregations +from synapse.logging import issue9533_logger from synapse.logging.context import current_context from synapse.logging.opentracing import ( SynapseTags, @@ -1623,13 +1624,18 @@ class SyncHandler: } ) - logger.debug( - "Returning %d to-device messages between %d and %d (current token: %d)", - len(messages), - since_stream_id, - stream_id, - now_token.to_device_key, - ) + if messages and issue9533_logger.isEnabledFor(logging.DEBUG): + issue9533_logger.debug( + "Returning to-device messages with stream_ids (%d, %d]; now: %d;" + " msgids: %s", + since_stream_id, + stream_id, + now_token.to_device_key, + [ + message["content"].get(EventContentFields.TO_DEVICE_MSGID) + for message in messages + ], + ) sync_result_builder.now_token = now_token.copy_and_replace( StreamKeyType.TO_DEVICE, stream_id ) -- cgit 1.5.1 From 3854d0f94947ddd5a9ee98198af8d7ae839962c9 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Wed, 28 Dec 2022 14:48:21 +0100 Subject: Add a `cached` helper to the module API (#14663) --- changelog.d/14663.feature | 1 + synapse/module_api/__init__.py | 40 +++++++++++++++++++++++++++++++++++++++- 2 files changed, 40 insertions(+), 1 deletion(-) create mode 100644 changelog.d/14663.feature (limited to 'synapse') diff --git a/changelog.d/14663.feature b/changelog.d/14663.feature new file mode 100644 index 0000000000..b03f3ee54e --- /dev/null +++ b/changelog.d/14663.feature @@ -0,0 +1 @@ +Add a `cached` function to `synapse.module_api` that returns a decorator to cache return values of functions. diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 0092a03c59..6f4a934b05 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -18,6 +18,7 @@ from typing import ( TYPE_CHECKING, Any, Callable, + Collection, Dict, Generator, Iterable, @@ -126,7 +127,7 @@ from synapse.types import ( from synapse.types.state import StateFilter from synapse.util import Clock from synapse.util.async_helpers import maybe_awaitable -from synapse.util.caches.descriptors import CachedFunction, cached +from synapse.util.caches.descriptors import CachedFunction, cached as _cached from synapse.util.frozenutils import freeze if TYPE_CHECKING: @@ -136,6 +137,7 @@ if TYPE_CHECKING: T = TypeVar("T") P = ParamSpec("P") +F = TypeVar("F", bound=Callable[..., Any]) """ This package defines the 'stable' API which can be used by extension modules which @@ -185,6 +187,42 @@ class UserIpAndAgent: last_seen: int +def cached( + *, + max_entries: int = 1000, + num_args: Optional[int] = None, + uncached_args: Optional[Collection[str]] = None, +) -> Callable[[F], CachedFunction[F]]: + """Returns a decorator that applies a memoizing cache around the function. This + decorator behaves similarly to functools.lru_cache. + + Example: + + @cached() + def foo('a', 'b'): + ... + + Added in Synapse v1.74.0. + + Args: + max_entries: The maximum number of entries in the cache. If the cache is full + and a new entry is added, the least recently accessed entry will be evicted + from the cache. + num_args: The number of positional arguments (excluding `self`) to use as cache + keys. Defaults to all named args of the function. + uncached_args: A list of argument names to not use as the cache key. (`self` is + always ignored.) Cannot be used with num_args. + + Returns: + A decorator that applies a memoizing cache around the function. + """ + return _cached( + max_entries=max_entries, + num_args=num_args, + uncached_args=uncached_args, + ) + + class ModuleApi: """A proxy object that gets passed to various plugin modules so they can register new users etc if necessary. -- cgit 1.5.1 From 044fa1a1de3c954f247a98c0ce8f734c675a5efb Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 29 Dec 2022 12:18:06 -0500 Subject: Actually use the picture_claim as configured in OIDC config. (#14751) Previously it was only using the default value ("picture") when fetching the picture from the user info. --- changelog.d/14751.bugfix | 1 + synapse/handlers/oidc.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 changelog.d/14751.bugfix (limited to 'synapse') diff --git a/changelog.d/14751.bugfix b/changelog.d/14751.bugfix new file mode 100644 index 0000000000..56ef852288 --- /dev/null +++ b/changelog.d/14751.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 1.73.0 where the `picture_claim` configured under `oidc_providers` was unused (the default value of `"picture"` was used instead). diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py index 03de6a4ba6..23fb00c9c9 100644 --- a/synapse/handlers/oidc.py +++ b/synapse/handlers/oidc.py @@ -1615,7 +1615,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): if email: emails.append(email) - picture = userinfo.get("picture") + picture = userinfo.get(self._config.picture_claim) return UserAttributeDict( localpart=localpart, -- cgit 1.5.1 From c4456114e1a5471bb61cb45605e782263dc8233c Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Sun, 1 Jan 2023 03:40:46 +0000 Subject: Add experimental support for MSC3391: deleting account data (#14714) --- changelog.d/14714.feature | 1 + .../complement/conf/workers-shared-extra.yaml.j2 | 2 + scripts-dev/complement.sh | 2 +- synapse/config/experimental.py | 3 + synapse/handlers/account_data.py | 111 ++++++++++- synapse/replication/http/account_data.py | 92 ++++++++- synapse/rest/client/account_data.py | 115 +++++++++++ synapse/storage/database.py | 33 +++- synapse/storage/databases/main/account_data.py | 219 +++++++++++++++++++-- 9 files changed, 547 insertions(+), 31 deletions(-) create mode 100644 changelog.d/14714.feature (limited to 'synapse') diff --git a/changelog.d/14714.feature b/changelog.d/14714.feature new file mode 100644 index 0000000000..5f3a20b7a7 --- /dev/null +++ b/changelog.d/14714.feature @@ -0,0 +1 @@ +Add experimental support for [MSC3391](https://github.com/matrix-org/matrix-spec-proposals/pull/3391) (removing account data). \ No newline at end of file diff --git a/docker/complement/conf/workers-shared-extra.yaml.j2 b/docker/complement/conf/workers-shared-extra.yaml.j2 index ca640c343b..cb839fed07 100644 --- a/docker/complement/conf/workers-shared-extra.yaml.j2 +++ b/docker/complement/conf/workers-shared-extra.yaml.j2 @@ -102,6 +102,8 @@ experimental_features: {% endif %} # Filtering /messages by relation type. msc3874_enabled: true + # Enable removing account data support + msc3391_enabled: true server_notices: system_mxid_localpart: _server diff --git a/scripts-dev/complement.sh b/scripts-dev/complement.sh index 8741ba3e34..51d1bac618 100755 --- a/scripts-dev/complement.sh +++ b/scripts-dev/complement.sh @@ -190,7 +190,7 @@ fi extra_test_args=() -test_tags="synapse_blacklist,msc3787,msc3874" +test_tags="synapse_blacklist,msc3787,msc3874,msc3391" # All environment variables starting with PASS_ will be shared. # (The prefix is stripped off before reaching the container.) diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 573fa0386f..0f3870bfe1 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -136,3 +136,6 @@ class ExperimentalConfig(Config): # Enable room version (and thus applicable push rules from MSC3931/3932) version_id = RoomVersions.MSC1767v10.identifier KNOWN_ROOM_VERSIONS[version_id] = RoomVersions.MSC1767v10 + + # MSC3391: Removing account data. + self.msc3391_enabled = experimental.get("msc3391_enabled", False) diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py index fc21d58001..aba7315cf7 100644 --- a/synapse/handlers/account_data.py +++ b/synapse/handlers/account_data.py @@ -17,10 +17,12 @@ import random from typing import TYPE_CHECKING, Awaitable, Callable, Collection, List, Optional, Tuple from synapse.replication.http.account_data import ( + ReplicationAddRoomAccountDataRestServlet, ReplicationAddTagRestServlet, + ReplicationAddUserAccountDataRestServlet, + ReplicationRemoveRoomAccountDataRestServlet, ReplicationRemoveTagRestServlet, - ReplicationRoomAccountDataRestServlet, - ReplicationUserAccountDataRestServlet, + ReplicationRemoveUserAccountDataRestServlet, ) from synapse.streams import EventSource from synapse.types import JsonDict, StreamKeyType, UserID @@ -41,8 +43,18 @@ class AccountDataHandler: self._instance_name = hs.get_instance_name() self._notifier = hs.get_notifier() - self._user_data_client = ReplicationUserAccountDataRestServlet.make_client(hs) - self._room_data_client = ReplicationRoomAccountDataRestServlet.make_client(hs) + self._add_user_data_client = ( + ReplicationAddUserAccountDataRestServlet.make_client(hs) + ) + self._remove_user_data_client = ( + ReplicationRemoveUserAccountDataRestServlet.make_client(hs) + ) + self._add_room_data_client = ( + ReplicationAddRoomAccountDataRestServlet.make_client(hs) + ) + self._remove_room_data_client = ( + ReplicationRemoveRoomAccountDataRestServlet.make_client(hs) + ) self._add_tag_client = ReplicationAddTagRestServlet.make_client(hs) self._remove_tag_client = ReplicationRemoveTagRestServlet.make_client(hs) self._account_data_writers = hs.config.worker.writers.account_data @@ -112,7 +124,7 @@ class AccountDataHandler: return max_stream_id else: - response = await self._room_data_client( + response = await self._add_room_data_client( instance_name=random.choice(self._account_data_writers), user_id=user_id, room_id=room_id, @@ -121,15 +133,59 @@ class AccountDataHandler: ) return response["max_stream_id"] + async def remove_account_data_for_room( + self, user_id: str, room_id: str, account_data_type: str + ) -> Optional[int]: + """ + Deletes the room account data for the given user and account data type. + + "Deleting" account data merely means setting the content of the account data + to an empty JSON object: {}. + + Args: + user_id: The user ID to remove room account data for. + room_id: The room ID to target. + account_data_type: The account data type to remove. + + Returns: + The maximum stream ID, or None if the room account data item did not exist. + """ + if self._instance_name in self._account_data_writers: + max_stream_id = await self._store.remove_account_data_for_room( + user_id, room_id, account_data_type + ) + if max_stream_id is None: + # The referenced account data did not exist, so no delete occurred. + return None + + self._notifier.on_new_event( + StreamKeyType.ACCOUNT_DATA, max_stream_id, users=[user_id] + ) + + # Notify Synapse modules that the content of the type has changed to an + # empty dictionary. + await self._notify_modules(user_id, room_id, account_data_type, {}) + + return max_stream_id + else: + response = await self._remove_room_data_client( + instance_name=random.choice(self._account_data_writers), + user_id=user_id, + room_id=room_id, + account_data_type=account_data_type, + content={}, + ) + return response["max_stream_id"] + async def add_account_data_for_user( self, user_id: str, account_data_type: str, content: JsonDict ) -> int: """Add some global account_data for a user. Args: - user_id: The user to add a tag for. + user_id: The user to add some account data for. account_data_type: The type of account_data to add. - content: A json object to associate with the tag. + content: The content json dictionary. Returns: The maximum stream ID. @@ -148,7 +204,7 @@ class AccountDataHandler: return max_stream_id else: - response = await self._user_data_client( + response = await self._add_user_data_client( instance_name=random.choice(self._account_data_writers), user_id=user_id, account_data_type=account_data_type, @@ -156,6 +212,45 @@ class AccountDataHandler: ) return response["max_stream_id"] + async def remove_account_data_for_user( + self, user_id: str, account_data_type: str + ) -> Optional[int]: + """Removes a piece of global account_data for a user. + + Args: + user_id: The user to remove account data for. + account_data_type: The type of account_data to remove. + + Returns: + The maximum stream ID, or None if the room account data item did not exist. + """ + + if self._instance_name in self._account_data_writers: + max_stream_id = await self._store.remove_account_data_for_user( + user_id, account_data_type + ) + if max_stream_id is None: + # The referenced account data did not exist, so no delete occurred. + return None + + self._notifier.on_new_event( + StreamKeyType.ACCOUNT_DATA, max_stream_id, users=[user_id] + ) + + # Notify Synapse modules that the content of the type has changed to an + # empty dictionary. + await self._notify_modules(user_id, None, account_data_type, {}) + + return max_stream_id + else: + response = await self._remove_user_data_client( + instance_name=random.choice(self._account_data_writers), + user_id=user_id, + account_data_type=account_data_type, + content={}, + ) + return response["max_stream_id"] + async def add_tag_to_room( self, user_id: str, room_id: str, tag: str, content: JsonDict ) -> int: diff --git a/synapse/replication/http/account_data.py b/synapse/replication/http/account_data.py index 310f609153..0edc95977b 100644 --- a/synapse/replication/http/account_data.py +++ b/synapse/replication/http/account_data.py @@ -28,7 +28,7 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class ReplicationUserAccountDataRestServlet(ReplicationEndpoint): +class ReplicationAddUserAccountDataRestServlet(ReplicationEndpoint): """Add user account data on the appropriate account data worker. Request format: @@ -49,7 +49,6 @@ class ReplicationUserAccountDataRestServlet(ReplicationEndpoint): super().__init__(hs) self.handler = hs.get_account_data_handler() - self.clock = hs.get_clock() @staticmethod async def _serialize_payload( # type: ignore[override] @@ -73,7 +72,45 @@ class ReplicationUserAccountDataRestServlet(ReplicationEndpoint): return 200, {"max_stream_id": max_stream_id} -class ReplicationRoomAccountDataRestServlet(ReplicationEndpoint): +class ReplicationRemoveUserAccountDataRestServlet(ReplicationEndpoint): + """Remove user account data on the appropriate account data worker. + + Request format: + + POST /_synapse/replication/remove_user_account_data/:user_id/:type + + { + "content": { ... }, + } + + """ + + NAME = "remove_user_account_data" + PATH_ARGS = ("user_id", "account_data_type") + CACHE = False + + def __init__(self, hs: "HomeServer"): + super().__init__(hs) + + self.handler = hs.get_account_data_handler() + + @staticmethod + async def _serialize_payload( # type: ignore[override] + user_id: str, account_data_type: str + ) -> JsonDict: + return {} + + async def _handle_request( # type: ignore[override] + self, request: Request, user_id: str, account_data_type: str + ) -> Tuple[int, JsonDict]: + max_stream_id = await self.handler.remove_account_data_for_user( + user_id, account_data_type + ) + + return 200, {"max_stream_id": max_stream_id} + + +class ReplicationAddRoomAccountDataRestServlet(ReplicationEndpoint): """Add room account data on the appropriate account data worker. Request format: @@ -94,7 +131,6 @@ class ReplicationRoomAccountDataRestServlet(ReplicationEndpoint): super().__init__(hs) self.handler = hs.get_account_data_handler() - self.clock = hs.get_clock() @staticmethod async def _serialize_payload( # type: ignore[override] @@ -118,6 +154,44 @@ class ReplicationRoomAccountDataRestServlet(ReplicationEndpoint): return 200, {"max_stream_id": max_stream_id} +class ReplicationRemoveRoomAccountDataRestServlet(ReplicationEndpoint): + """Remove room account data on the appropriate account data worker. + + Request format: + + POST /_synapse/replication/remove_room_account_data/:user_id/:room_id/:account_data_type + + { + "content": { ... }, + } + + """ + + NAME = "remove_room_account_data" + PATH_ARGS = ("user_id", "room_id", "account_data_type") + CACHE = False + + def __init__(self, hs: "HomeServer"): + super().__init__(hs) + + self.handler = hs.get_account_data_handler() + + @staticmethod + async def _serialize_payload( # type: ignore[override] + user_id: str, room_id: str, account_data_type: str, content: JsonDict + ) -> JsonDict: + return {} + + async def _handle_request( # type: ignore[override] + self, request: Request, user_id: str, room_id: str, account_data_type: str + ) -> Tuple[int, JsonDict]: + max_stream_id = await self.handler.remove_account_data_for_room( + user_id, room_id, account_data_type + ) + + return 200, {"max_stream_id": max_stream_id} + + class ReplicationAddTagRestServlet(ReplicationEndpoint): """Add tag on the appropriate account data worker. @@ -139,7 +213,6 @@ class ReplicationAddTagRestServlet(ReplicationEndpoint): super().__init__(hs) self.handler = hs.get_account_data_handler() - self.clock = hs.get_clock() @staticmethod async def _serialize_payload( # type: ignore[override] @@ -186,7 +259,6 @@ class ReplicationRemoveTagRestServlet(ReplicationEndpoint): super().__init__(hs) self.handler = hs.get_account_data_handler() - self.clock = hs.get_clock() @staticmethod async def _serialize_payload(user_id: str, room_id: str, tag: str) -> JsonDict: # type: ignore[override] @@ -206,7 +278,11 @@ class ReplicationRemoveTagRestServlet(ReplicationEndpoint): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: - ReplicationUserAccountDataRestServlet(hs).register(http_server) - ReplicationRoomAccountDataRestServlet(hs).register(http_server) + ReplicationAddUserAccountDataRestServlet(hs).register(http_server) + ReplicationAddRoomAccountDataRestServlet(hs).register(http_server) ReplicationAddTagRestServlet(hs).register(http_server) ReplicationRemoveTagRestServlet(hs).register(http_server) + + if hs.config.experimental.msc3391_enabled: + ReplicationRemoveUserAccountDataRestServlet(hs).register(http_server) + ReplicationRemoveRoomAccountDataRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/account_data.py b/synapse/rest/client/account_data.py index f13970b898..e805196fec 100644 --- a/synapse/rest/client/account_data.py +++ b/synapse/rest/client/account_data.py @@ -41,6 +41,7 @@ class AccountDataServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() + self._hs = hs self.auth = hs.get_auth() self.store = hs.get_datastores().main self.handler = hs.get_account_data_handler() @@ -54,6 +55,16 @@ class AccountDataServlet(RestServlet): body = parse_json_object_from_request(request) + # If experimental support for MSC3391 is enabled, then providing an empty dict + # as the value for an account data type should be functionally equivalent to + # calling the DELETE method on the same type. + if self._hs.config.experimental.msc3391_enabled: + if body == {}: + await self.handler.remove_account_data_for_user( + user_id, account_data_type + ) + return 200, {} + await self.handler.add_account_data_for_user(user_id, account_data_type, body) return 200, {} @@ -72,9 +83,48 @@ class AccountDataServlet(RestServlet): if event is None: raise NotFoundError("Account data not found") + # If experimental support for MSC3391 is enabled, then this endpoint should + # return a 404 if the content for an account data type is an empty dict. + if self._hs.config.experimental.msc3391_enabled and event == {}: + raise NotFoundError("Account data not found") + return 200, event +class UnstableAccountDataServlet(RestServlet): + """ + Contains an unstable endpoint for removing user account data, as specified by + MSC3391. If that MSC is accepted, this code should have unstable prefixes removed + and become incorporated into AccountDataServlet above. + """ + + PATTERNS = client_patterns( + "/org.matrix.msc3391/user/(?P[^/]*)" + "/account_data/(?P[^/]*)", + unstable=True, + releases=(), + ) + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.auth = hs.get_auth() + self.handler = hs.get_account_data_handler() + + async def on_DELETE( + self, + request: SynapseRequest, + user_id: str, + account_data_type: str, + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + if user_id != requester.user.to_string(): + raise AuthError(403, "Cannot delete account data for other users.") + + await self.handler.remove_account_data_for_user(user_id, account_data_type) + + return 200, {} + + class RoomAccountDataServlet(RestServlet): """ PUT /user/{user_id}/rooms/{room_id}/account_data/{account_dataType} HTTP/1.1 @@ -89,6 +139,7 @@ class RoomAccountDataServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() + self._hs = hs self.auth = hs.get_auth() self.store = hs.get_datastores().main self.handler = hs.get_account_data_handler() @@ -121,6 +172,16 @@ class RoomAccountDataServlet(RestServlet): Codes.BAD_JSON, ) + # If experimental support for MSC3391 is enabled, then providing an empty dict + # as the value for an account data type should be functionally equivalent to + # calling the DELETE method on the same type. + if self._hs.config.experimental.msc3391_enabled: + if body == {}: + await self.handler.remove_account_data_for_room( + user_id, room_id, account_data_type + ) + return 200, {} + await self.handler.add_account_data_to_room( user_id, room_id, account_data_type, body ) @@ -152,9 +213,63 @@ class RoomAccountDataServlet(RestServlet): if event is None: raise NotFoundError("Room account data not found") + # If experimental support for MSC3391 is enabled, then this endpoint should + # return a 404 if the content for an account data type is an empty dict. + if self._hs.config.experimental.msc3391_enabled and event == {}: + raise NotFoundError("Room account data not found") + return 200, event +class UnstableRoomAccountDataServlet(RestServlet): + """ + Contains an unstable endpoint for removing room account data, as specified by + MSC3391. If that MSC is accepted, this code should have unstable prefixes removed + and become incorporated into RoomAccountDataServlet above. + """ + + PATTERNS = client_patterns( + "/org.matrix.msc3391/user/(?P[^/]*)" + "/rooms/(?P[^/]*)" + "/account_data/(?P[^/]*)", + unstable=True, + releases=(), + ) + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.auth = hs.get_auth() + self.handler = hs.get_account_data_handler() + + async def on_DELETE( + self, + request: SynapseRequest, + user_id: str, + room_id: str, + account_data_type: str, + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + if user_id != requester.user.to_string(): + raise AuthError(403, "Cannot delete account data for other users.") + + if not RoomID.is_valid(room_id): + raise SynapseError( + 400, + f"{room_id} is not a valid room ID", + Codes.INVALID_PARAM, + ) + + await self.handler.remove_account_data_for_room( + user_id, room_id, account_data_type + ) + + return 200, {} + + def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: AccountDataServlet(hs).register(http_server) RoomAccountDataServlet(hs).register(http_server) + + if hs.config.experimental.msc3391_enabled: + UnstableAccountDataServlet(hs).register(http_server) + UnstableRoomAccountDataServlet(hs).register(http_server) diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 0b29e67b94..88479a16db 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -1762,7 +1762,8 @@ class DatabasePool: desc: description of the transaction, for logging and metrics Returns: - A list of dictionaries. + A list of dictionaries, one per result row, each a mapping between the + column names from `retcols` and that column's value for the row. """ return await self.runInteraction( desc, @@ -1791,6 +1792,10 @@ class DatabasePool: column names and values to select the rows with, or None to not apply a WHERE clause. retcols: the names of the columns to return + + Returns: + A list of dictionaries, one per result row, each a mapping between the + column names from `retcols` and that column's value for the row. """ if keyvalues: sql = "SELECT %s FROM %s WHERE %s" % ( @@ -1898,6 +1903,19 @@ class DatabasePool: updatevalues: Dict[str, Any], desc: str, ) -> int: + """ + Update rows in the given database table. + If the given keyvalues don't match anything, nothing will be updated. + + Args: + table: The database table to update. + keyvalues: A mapping of column name to value to match rows on. + updatevalues: A mapping of column name to value to replace in any matched rows. + desc: description of the transaction, for logging and metrics. + + Returns: + The number of rows that were updated. Will be 0 if no matching rows were found. + """ return await self.runInteraction( desc, self.simple_update_txn, table, keyvalues, updatevalues ) @@ -1909,6 +1927,19 @@ class DatabasePool: keyvalues: Dict[str, Any], updatevalues: Dict[str, Any], ) -> int: + """ + Update rows in the given database table. + If the given keyvalues don't match anything, nothing will be updated. + + Args: + txn: The database transaction object. + table: The database table to update. + keyvalues: A mapping of column name to value to match rows on. + updatevalues: A mapping of column name to value to replace in any matched rows. + + Returns: + The number of rows that were updated. Will be 0 if no matching rows were found. + """ if keyvalues: where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys()) else: diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index 07908c41d9..e59776f434 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -123,7 +123,11 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) async def get_account_data_for_user( self, user_id: str ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]: - """Get all the client account_data for a user. + """ + Get all the client account_data for a user. + + If experimental MSC3391 support is enabled, any entries with an empty + content body are excluded; as this means they have been deleted. Args: user_id: The user to get the account_data for. @@ -135,27 +139,48 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) def get_account_data_for_user_txn( txn: LoggingTransaction, ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]: - rows = self.db_pool.simple_select_list_txn( - txn, - "account_data", - {"user_id": user_id}, - ["account_data_type", "content"], - ) + # The 'content != '{}' condition below prevents us from using + # `simple_select_list_txn` here, as it doesn't support conditions + # other than 'equals'. + sql = """ + SELECT account_data_type, content FROM account_data + WHERE user_id = ? + """ + + # If experimental MSC3391 support is enabled, then account data entries + # with an empty content are considered "deleted". So skip adding them to + # the results. + if self.hs.config.experimental.msc3391_enabled: + sql += " AND content != '{}'" + + txn.execute(sql, (user_id,)) + rows = self.db_pool.cursor_to_dict(txn) global_account_data = { row["account_data_type"]: db_to_json(row["content"]) for row in rows } - rows = self.db_pool.simple_select_list_txn( - txn, - "room_account_data", - {"user_id": user_id}, - ["room_id", "account_data_type", "content"], - ) + # The 'content != '{}' condition below prevents us from using + # `simple_select_list_txn` here, as it doesn't support conditions + # other than 'equals'. + sql = """ + SELECT room_id, account_data_type, content FROM room_account_data + WHERE user_id = ? + """ + + # If experimental MSC3391 support is enabled, then account data entries + # with an empty content are considered "deleted". So skip adding them to + # the results. + if self.hs.config.experimental.msc3391_enabled: + sql += " AND content != '{}'" + + txn.execute(sql, (user_id,)) + rows = self.db_pool.cursor_to_dict(txn) by_room: Dict[str, Dict[str, JsonDict]] = {} for row in rows: room_data = by_room.setdefault(row["room_id"], {}) + room_data[row["account_data_type"]] = db_to_json(row["content"]) return global_account_data, by_room @@ -469,6 +494,72 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) return self._account_data_id_gen.get_current_token() + async def remove_account_data_for_room( + self, user_id: str, room_id: str, account_data_type: str + ) -> Optional[int]: + """Delete the room account data for the user of a given type. + + Args: + user_id: The user to remove account_data for. + room_id: The room ID to scope the request to. + account_data_type: The account data type to delete. + + Returns: + The maximum stream position, or None if there was no matching room account + data to delete. + """ + assert self._can_write_to_account_data + assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator) + + def _remove_account_data_for_room_txn( + txn: LoggingTransaction, next_id: int + ) -> bool: + """ + Args: + txn: The transaction object. + next_id: The stream_id to update any existing rows to. + + Returns: + True if an entry in room_account_data had its content set to '{}', + otherwise False. This informs callers of whether there actually was an + existing room account data entry to delete, or if the call was a no-op. + """ + # We can't use `simple_update` as it doesn't have the ability to specify + # where clauses other than '=', which we need for `content != '{}'` below. + sql = """ + UPDATE room_account_data + SET stream_id = ?, content = '{}' + WHERE user_id = ? + AND room_id = ? + AND account_data_type = ? + AND content != '{}' + """ + txn.execute( + sql, + (next_id, user_id, room_id, account_data_type), + ) + # Return true if any rows were updated. + return txn.rowcount != 0 + + async with self._account_data_id_gen.get_next() as next_id: + row_updated = await self.db_pool.runInteraction( + "remove_account_data_for_room", + _remove_account_data_for_room_txn, + next_id, + ) + + if not row_updated: + return None + + self._account_data_stream_cache.entity_has_changed(user_id, next_id) + self.get_account_data_for_user.invalidate((user_id,)) + self.get_account_data_for_room.invalidate((user_id, room_id)) + self.get_account_data_for_room_and_type.prefill( + (user_id, room_id, account_data_type), {} + ) + + return self._account_data_id_gen.get_current_token() + async def add_account_data_for_user( self, user_id: str, account_data_type: str, content: JsonDict ) -> int: @@ -569,6 +660,108 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) self._invalidate_cache_and_stream(txn, self.ignored_by, (ignored_user_id,)) self._invalidate_cache_and_stream(txn, self.ignored_users, (user_id,)) + async def remove_account_data_for_user( + self, + user_id: str, + account_data_type: str, + ) -> Optional[int]: + """ + Delete a single piece of user account data by type. + + A "delete" is performed by updating a potentially existing row in the + "account_data" database table for (user_id, account_data_type) and + setting its content to "{}". + + Args: + user_id: The user ID to modify the account data of. + account_data_type: The type to remove. + + Returns: + The maximum stream position, or None if there was no matching account data + to delete. + """ + assert self._can_write_to_account_data + assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator) + + def _remove_account_data_for_user_txn( + txn: LoggingTransaction, next_id: int + ) -> bool: + """ + Args: + txn: The transaction object. + next_id: The stream_id to update any existing rows to. + + Returns: + True if an entry in account_data had its content set to '{}', otherwise + False. This informs callers of whether there actually was an existing + account data entry to delete, or if the call was a no-op. + """ + # We can't use `simple_update` as it doesn't have the ability to specify + # where clauses other than '=', which we need for `content != '{}'` below. + sql = """ + UPDATE account_data + SET stream_id = ?, content = '{}' + WHERE user_id = ? + AND account_data_type = ? + AND content != '{}' + """ + txn.execute(sql, (next_id, user_id, account_data_type)) + if txn.rowcount == 0: + # We didn't update any rows. This means that there was no matching room + # account data entry to delete in the first place. + return False + + # Ignored users get denormalized into a separate table as an optimisation. + if account_data_type == AccountDataTypes.IGNORED_USER_LIST: + # If this method was called with the ignored users account data type, we + # simply delete all ignored users. + + # First pull all the users that this user ignores. + previously_ignored_users = set( + self.db_pool.simple_select_onecol_txn( + txn, + table="ignored_users", + keyvalues={"ignorer_user_id": user_id}, + retcol="ignored_user_id", + ) + ) + + # Then delete them from the database. + self.db_pool.simple_delete_txn( + txn, + table="ignored_users", + keyvalues={"ignorer_user_id": user_id}, + ) + + # Invalidate the cache for ignored users which were removed. + for ignored_user_id in previously_ignored_users: + self._invalidate_cache_and_stream( + txn, self.ignored_by, (ignored_user_id,) + ) + + # Invalidate for this user the cache tracking ignored users. + self._invalidate_cache_and_stream(txn, self.ignored_users, (user_id,)) + + return True + + async with self._account_data_id_gen.get_next() as next_id: + row_updated = await self.db_pool.runInteraction( + "remove_account_data_for_user", + _remove_account_data_for_user_txn, + next_id, + ) + + if not row_updated: + return None + + self._account_data_stream_cache.entity_has_changed(user_id, next_id) + self.get_account_data_for_user.invalidate((user_id,)) + self.get_global_account_data_by_type_for_user.prefill( + (user_id, account_data_type), {} + ) + + return self._account_data_id_gen.get_current_token() + async def purge_account_data_for_user(self, user_id: str) -> None: """ Removes ALL the account data for a user. -- cgit 1.5.1 From db1cfe9c80a707995fcad8f3faa839acb247068a Mon Sep 17 00:00:00 2001 From: Nick Mills-Barrett Date: Wed, 4 Jan 2023 11:49:26 +0000 Subject: Update all stream IDs after processing replication rows (#14723) This creates a new store method, `process_replication_position` that is called after `process_replication_rows`. By moving stream ID advances here this guarantees any relevant cache invalidations will have been applied before the stream is advanced. This avoids race conditions where Python switches between threads mid way through processing the `process_replication_rows` method where stream IDs may be advanced before caches are invalidated due to class resolution ordering. See this comment/issue for further discussion: https://github.com/matrix-org/synapse/issues/14158#issuecomment-1344048703 --- changelog.d/14723.bugfix | 1 + synapse/replication/tcp/client.py | 3 +++ synapse/storage/_base.py | 17 ++++++++++++++++- synapse/storage/databases/main/account_data.py | 14 ++++++++++---- synapse/storage/databases/main/cache.py | 11 ++++++++--- synapse/storage/databases/main/deviceinbox.py | 7 +++++++ synapse/storage/databases/main/devices.py | 11 +++++++++-- synapse/storage/databases/main/events_worker.py | 15 ++++++++++----- synapse/storage/databases/main/presence.py | 8 +++++++- synapse/storage/databases/main/push_rule.py | 7 +++++++ synapse/storage/databases/main/pusher.py | 6 +++--- synapse/storage/databases/main/receipts.py | 7 +++++++ synapse/storage/databases/main/tags.py | 8 +++++++- 13 files changed, 95 insertions(+), 20 deletions(-) create mode 100644 changelog.d/14723.bugfix (limited to 'synapse') diff --git a/changelog.d/14723.bugfix b/changelog.d/14723.bugfix new file mode 100644 index 0000000000..e1f89cee35 --- /dev/null +++ b/changelog.d/14723.bugfix @@ -0,0 +1 @@ +Ensure stream IDs are always updated after caches get invalidated with workers. Contributed by Nick @ Beeper (@fizzadar). diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 658d89210d..b5e40da533 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -152,6 +152,9 @@ class ReplicationDataHandler: rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row. """ self.store.process_replication_rows(stream_name, instance_name, token, rows) + # NOTE: this must be called after process_replication_rows to ensure any + # cache invalidations are first handled before any stream ID advances. + self.store.process_replication_position(stream_name, instance_name, token) if self.send_handler: await self.send_handler.process_replication_rows(stream_name, token, rows) diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 69abf6fa87..41d9111019 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -57,7 +57,22 @@ class SQLBaseStore(metaclass=ABCMeta): token: int, rows: Iterable[Any], ) -> None: - pass + """ + Used by storage classes to invalidate caches based on incoming replication data. These + must not update any ID generators, use `process_replication_position`. + """ + + def process_replication_position( # noqa: B027 (no-op by design) + self, + stream_name: str, + instance_name: str, + token: int, + ) -> None: + """ + Used by storage classes to advance ID generators based on incoming replication data. This + is called after process_replication_rows such that caches are invalidated before any token + positions advance. + """ def _invalidate_state_caches( self, room_id: str, members_changed: Collection[str] diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index e59776f434..86032897f5 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -436,10 +436,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) token: int, rows: Iterable[Any], ) -> None: - if stream_name == TagAccountDataStream.NAME: - self._account_data_id_gen.advance(instance_name, token) - elif stream_name == AccountDataStream.NAME: - self._account_data_id_gen.advance(instance_name, token) + if stream_name == AccountDataStream.NAME: for row in rows: if not row.room_id: self.get_global_account_data_by_type_for_user.invalidate( @@ -454,6 +451,15 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) super().process_replication_rows(stream_name, instance_name, token, rows) + def process_replication_position( + self, stream_name: str, instance_name: str, token: int + ) -> None: + if stream_name == TagAccountDataStream.NAME: + self._account_data_id_gen.advance(instance_name, token) + elif stream_name == AccountDataStream.NAME: + self._account_data_id_gen.advance(instance_name, token) + super().process_replication_position(stream_name, instance_name, token) + async def add_account_data_to_room( self, user_id: str, room_id: str, account_data_type: str, content: JsonDict ) -> int: diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index a58668a380..2179a8bf59 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -164,9 +164,6 @@ class CacheInvalidationWorkerStore(SQLBaseStore): backfilled=True, ) elif stream_name == CachesStream.NAME: - if self._cache_id_gen: - self._cache_id_gen.advance(instance_name, token) - for row in rows: if row.cache_func == CURRENT_STATE_CACHE_NAME: if row.keys is None: @@ -182,6 +179,14 @@ class CacheInvalidationWorkerStore(SQLBaseStore): super().process_replication_rows(stream_name, instance_name, token, rows) + def process_replication_position( + self, stream_name: str, instance_name: str, token: int + ) -> None: + if stream_name == CachesStream.NAME: + if self._cache_id_gen: + self._cache_id_gen.advance(instance_name, token) + super().process_replication_position(stream_name, instance_name, token) + def _process_event_stream_row(self, token: int, row: EventsStreamRow) -> None: data = row.data diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index 48a54d9cb8..713be91c5d 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -157,6 +157,13 @@ class DeviceInboxWorkerStore(SQLBaseStore): ) return super().process_replication_rows(stream_name, instance_name, token, rows) + def process_replication_position( + self, stream_name: str, instance_name: str, token: int + ) -> None: + if stream_name == ToDeviceStream.NAME: + self._device_inbox_id_gen.advance(instance_name, token) + super().process_replication_position(stream_name, instance_name, token) + def get_to_device_stream_token(self) -> int: return self._device_inbox_id_gen.get_current_token() diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index a5bb4d404e..db877e3f13 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -162,14 +162,21 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any] ) -> None: if stream_name == DeviceListsStream.NAME: - self._device_list_id_gen.advance(instance_name, token) self._invalidate_caches_for_devices(token, rows) elif stream_name == UserSignatureStream.NAME: - self._device_list_id_gen.advance(instance_name, token) for row in rows: self._user_signature_stream_cache.entity_has_changed(row.user_id, token) return super().process_replication_rows(stream_name, instance_name, token, rows) + def process_replication_position( + self, stream_name: str, instance_name: str, token: int + ) -> None: + if stream_name == DeviceListsStream.NAME: + self._device_list_id_gen.advance(instance_name, token) + elif stream_name == UserSignatureStream.NAME: + self._device_list_id_gen.advance(instance_name, token) + super().process_replication_position(stream_name, instance_name, token) + def _invalidate_caches_for_devices( self, token: int, rows: Iterable[DeviceListsStream.DeviceListsStreamRow] ) -> None: diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 761b15a815..d150fa8a94 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -388,11 +388,7 @@ class EventsWorkerStore(SQLBaseStore): token: int, rows: Iterable[Any], ) -> None: - if stream_name == EventsStream.NAME: - self._stream_id_gen.advance(instance_name, token) - elif stream_name == BackfillStream.NAME: - self._backfill_id_gen.advance(instance_name, -token) - elif stream_name == UnPartialStatedEventStream.NAME: + if stream_name == UnPartialStatedEventStream.NAME: for row in rows: assert isinstance(row, UnPartialStatedEventStreamRow) @@ -405,6 +401,15 @@ class EventsWorkerStore(SQLBaseStore): super().process_replication_rows(stream_name, instance_name, token, rows) + def process_replication_position( + self, stream_name: str, instance_name: str, token: int + ) -> None: + if stream_name == EventsStream.NAME: + self._stream_id_gen.advance(instance_name, token) + elif stream_name == BackfillStream.NAME: + self._backfill_id_gen.advance(instance_name, -token) + super().process_replication_position(stream_name, instance_name, token) + async def have_censored_event(self, event_id: str) -> bool: """Check if an event has been censored, i.e. if the content of the event has been erased from the database due to a redaction. diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py index 9769a18a9d..7b60815043 100644 --- a/synapse/storage/databases/main/presence.py +++ b/synapse/storage/databases/main/presence.py @@ -439,8 +439,14 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore) rows: Iterable[Any], ) -> None: if stream_name == PresenceStream.NAME: - self._presence_id_gen.advance(instance_name, token) for row in rows: self.presence_stream_cache.entity_has_changed(row.user_id, token) self._get_presence_for_user.invalidate((row.user_id,)) return super().process_replication_rows(stream_name, instance_name, token, rows) + + def process_replication_position( + self, stream_name: str, instance_name: str, token: int + ) -> None: + if stream_name == PresenceStream.NAME: + self._presence_id_gen.advance(instance_name, token) + super().process_replication_position(stream_name, instance_name, token) diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index d4c64c46ad..d4e4b777da 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -154,6 +154,13 @@ class PushRulesWorkerStore( self.push_rules_stream_cache.entity_has_changed(row.user_id, token) return super().process_replication_rows(stream_name, instance_name, token, rows) + def process_replication_position( + self, stream_name: str, instance_name: str, token: int + ) -> None: + if stream_name == PushRulesStream.NAME: + self._push_rules_stream_id_gen.advance(instance_name, token) + super().process_replication_position(stream_name, instance_name, token) + @cached(max_entries=5000) async def get_push_rules_for_user(self, user_id: str) -> FilteredPushRules: rows = await self.db_pool.simple_select_list( diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py index 40fd781a6a..7f24a3b6ec 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py @@ -111,12 +111,12 @@ class PusherWorkerStore(SQLBaseStore): def get_pushers_stream_token(self) -> int: return self._pushers_id_gen.get_current_token() - def process_replication_rows( - self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any] + def process_replication_position( + self, stream_name: str, instance_name: str, token: int ) -> None: if stream_name == PushersStream.NAME: self._pushers_id_gen.advance(instance_name, token) - return super().process_replication_rows(stream_name, instance_name, token, rows) + super().process_replication_position(stream_name, instance_name, token) async def get_pushers_by_app_id_and_pushkey( self, app_id: str, pushkey: str diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index e06725f69c..86f5bce5f0 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -588,6 +588,13 @@ class ReceiptsWorkerStore(SQLBaseStore): return super().process_replication_rows(stream_name, instance_name, token, rows) + def process_replication_position( + self, stream_name: str, instance_name: str, token: int + ) -> None: + if stream_name == ReceiptsStream.NAME: + self._receipts_id_gen.advance(instance_name, token) + super().process_replication_position(stream_name, instance_name, token) + def _insert_linearized_receipt_txn( self, txn: LoggingTransaction, diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py index b0f5de67a3..e23c927e02 100644 --- a/synapse/storage/databases/main/tags.py +++ b/synapse/storage/databases/main/tags.py @@ -300,13 +300,19 @@ class TagsWorkerStore(AccountDataWorkerStore): rows: Iterable[Any], ) -> None: if stream_name == TagAccountDataStream.NAME: - self._account_data_id_gen.advance(instance_name, token) for row in rows: self.get_tags_for_user.invalidate((row.user_id,)) self._account_data_stream_cache.entity_has_changed(row.user_id, token) super().process_replication_rows(stream_name, instance_name, token, rows) + def process_replication_position( + self, stream_name: str, instance_name: str, token: int + ) -> None: + if stream_name == TagAccountDataStream.NAME: + self._account_data_id_gen.advance(instance_name, token) + super().process_replication_position(stream_name, instance_name, token) + class TagsStore(TagsWorkerStore): pass -- cgit 1.5.1 From 906dfaa2cf5a79ed9c18529b1a370ffd49c0204e Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 4 Jan 2023 08:26:10 -0500 Subject: Support non-OpenID compliant user info endpoints (#14753) OpenID specifies the format of the user info endpoint and some OAuth 2.0 IdPs do not follow it, e.g. NextCloud and Twitter. This adds subject_template and picture_template options to the default mapping provider for more flexibility in matching those user info responses. --- changelog.d/14753.feature | 1 + docs/usage/configuration/config_documentation.md | 18 ++++++++++++++ synapse/handlers/oidc.py | 31 ++++++++++++++++++------ 3 files changed, 42 insertions(+), 8 deletions(-) create mode 100644 changelog.d/14753.feature (limited to 'synapse') diff --git a/changelog.d/14753.feature b/changelog.d/14753.feature new file mode 100644 index 0000000000..38b4d6af4b --- /dev/null +++ b/changelog.d/14753.feature @@ -0,0 +1 @@ +Support non-OpenID compliant userinfo claims for subject and picture. diff --git a/docs/usage/configuration/config_documentation.md b/docs/usage/configuration/config_documentation.md index 67e0acc910..23f9dcbea2 100644 --- a/docs/usage/configuration/config_documentation.md +++ b/docs/usage/configuration/config_documentation.md @@ -3098,10 +3098,26 @@ Options for each entry include: For the default provider, the following settings are available: + * `subject_template`: Jinja2 template for a unique identifier for the user. + Defaults to `{{ user.sub }}`, which OpenID Connect compliant providers should provide. + + This replaces and overrides `subject_claim`. + * `subject_claim`: name of the claim containing a unique identifier for the user. Defaults to 'sub', which OpenID Connect compliant providers should provide. + *Deprecated in Synapse v1.75.0.* + + * `picture_template`: Jinja2 template for an url for the user's profile picture. + Defaults to `{{ user.picture }}`, which OpenID Connect compliant providers should + provide and has to refer to a direct image file such as PNG, JPEG, or GIF image file. + + This replaces and overrides `picture_claim`. + + Currently only supported in monolithic (single-process) server configurations + where the media repository runs within the Synapse process. + * `picture_claim`: name of the claim containing an url for the user's profile picture. Defaults to 'picture', which OpenID Connect compliant providers should provide and has to refer to a direct image file such as PNG, JPEG, or GIF image file. @@ -3109,6 +3125,8 @@ Options for each entry include: Currently only supported in monolithic (single-process) server configurations where the media repository runs within the Synapse process. + *Deprecated in Synapse v1.75.0.* + * `localpart_template`: Jinja2 template for the localpart of the MXID. If this is not set, the user will be prompted to choose their own username (see the documentation for the `sso_auth_account_details.html` diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py index 23fb00c9c9..24e1cec5b6 100644 --- a/synapse/handlers/oidc.py +++ b/synapse/handlers/oidc.py @@ -1520,8 +1520,8 @@ env.filters.update( @attr.s(slots=True, frozen=True, auto_attribs=True) class JinjaOidcMappingConfig: - subject_claim: str - picture_claim: str + subject_template: Template + picture_template: Template localpart_template: Optional[Template] display_name_template: Optional[Template] email_template: Optional[Template] @@ -1540,8 +1540,23 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): @staticmethod def parse_config(config: dict) -> JinjaOidcMappingConfig: - subject_claim = config.get("subject_claim", "sub") - picture_claim = config.get("picture_claim", "picture") + def parse_template_config_with_claim( + option_name: str, default_claim: str + ) -> Template: + template_name = f"{option_name}_template" + template = config.get(template_name) + if not template: + # Convert the legacy subject_claim into a template. + claim = config.get(f"{option_name}_claim", default_claim) + template = "{{ user.%s }}" % (claim,) + + try: + return env.from_string(template) + except Exception as e: + raise ConfigError("invalid jinja template", path=[template_name]) from e + + subject_template = parse_template_config_with_claim("subject", "sub") + picture_template = parse_template_config_with_claim("picture", "picture") def parse_template_config(option_name: str) -> Optional[Template]: if option_name not in config: @@ -1574,8 +1589,8 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): raise ConfigError("must be a bool", path=["confirm_localpart"]) return JinjaOidcMappingConfig( - subject_claim=subject_claim, - picture_claim=picture_claim, + subject_template=subject_template, + picture_template=picture_template, localpart_template=localpart_template, display_name_template=display_name_template, email_template=email_template, @@ -1584,7 +1599,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): ) def get_remote_user_id(self, userinfo: UserInfo) -> str: - return userinfo[self._config.subject_claim] + return self._config.subject_template.render(user=userinfo).strip() async def map_user_attributes( self, userinfo: UserInfo, token: Token, failures: int @@ -1615,7 +1630,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): if email: emails.append(email) - picture = userinfo.get(self._config.picture_claim) + picture = self._config.picture_template.render(user=userinfo).strip() return UserAttributeDict( localpart=localpart, -- cgit 1.5.1 From 630d0aeaf607b4016e67895d81b0402a5dfcc769 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 4 Jan 2023 14:58:08 -0500 Subject: Support RFC7636 PKCE in the OAuth 2.0 flow. (#14750) PKCE can protect against certain attacks and is enabled by default. Support can be controlled manually by setting the pkce_method of each oidc_providers entry to 'auto' (default), 'always', or 'never'. This is required by Twitter OAuth 2.0 support. --- changelog.d/14750.feature | 1 + docs/usage/configuration/config_documentation.md | 7 +- synapse/config/oidc.py | 6 + synapse/handlers/oidc.py | 54 ++++++-- synapse/util/macaroons.py | 7 ++ tests/handlers/test_oidc.py | 152 +++++++++++++++++++++-- tests/util/test_macaroons.py | 1 + 7 files changed, 212 insertions(+), 16 deletions(-) create mode 100644 changelog.d/14750.feature (limited to 'synapse') diff --git a/changelog.d/14750.feature b/changelog.d/14750.feature new file mode 100644 index 0000000000..cfed64ee80 --- /dev/null +++ b/changelog.d/14750.feature @@ -0,0 +1 @@ +Support [RFC7636](https://datatracker.ietf.org/doc/html/rfc7636) Proof Key for Code Exchange for OAuth single sign-on. diff --git a/docs/usage/configuration/config_documentation.md b/docs/usage/configuration/config_documentation.md index 23f9dcbea2..ec8403c7e9 100644 --- a/docs/usage/configuration/config_documentation.md +++ b/docs/usage/configuration/config_documentation.md @@ -3053,8 +3053,13 @@ Options for each entry include: values are `client_secret_basic` (default), `client_secret_post` and `none`. +* `pkce_method`: Whether to use proof key for code exchange when requesting + and exchanging the token. Valid values are: `auto`, `always`, or `never`. Defaults + to `auto`, which uses PKCE if supported during metadata discovery. Set to `always` + to force enable PKCE or `never` to force disable PKCE. + * `scopes`: list of scopes to request. This should normally include the "openid" - scope. Defaults to ["openid"]. + scope. Defaults to `["openid"]`. * `authorization_endpoint`: the oauth2 authorization endpoint. Required if provider discovery is disabled. diff --git a/synapse/config/oidc.py b/synapse/config/oidc.py index 0bd83f4010..df8c422043 100644 --- a/synapse/config/oidc.py +++ b/synapse/config/oidc.py @@ -117,6 +117,7 @@ OIDC_PROVIDER_CONFIG_SCHEMA = { # to avoid importing authlib here. "enum": ["client_secret_basic", "client_secret_post", "none"], }, + "pkce_method": {"type": "string", "enum": ["auto", "always", "never"]}, "scopes": {"type": "array", "items": {"type": "string"}}, "authorization_endpoint": {"type": "string"}, "token_endpoint": {"type": "string"}, @@ -289,6 +290,7 @@ def _parse_oidc_config_dict( client_secret=oidc_config.get("client_secret"), client_secret_jwt_key=client_secret_jwt_key, client_auth_method=oidc_config.get("client_auth_method", "client_secret_basic"), + pkce_method=oidc_config.get("pkce_method", "auto"), scopes=oidc_config.get("scopes", ["openid"]), authorization_endpoint=oidc_config.get("authorization_endpoint"), token_endpoint=oidc_config.get("token_endpoint"), @@ -357,6 +359,10 @@ class OidcProviderConfig: # 'none'. client_auth_method: str + # Whether to enable PKCE when exchanging the authorization & token. + # Valid values are 'auto', 'always', and 'never'. + pkce_method: str + # list of scopes to request scopes: Collection[str] diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py index 24e1cec5b6..0fc829acf7 100644 --- a/synapse/handlers/oidc.py +++ b/synapse/handlers/oidc.py @@ -36,6 +36,7 @@ from authlib.jose import JsonWebToken, JWTClaims from authlib.jose.errors import InvalidClaimError, JoseError, MissingClaimError from authlib.oauth2.auth import ClientAuth from authlib.oauth2.rfc6749.parameters import prepare_grant_uri +from authlib.oauth2.rfc7636.challenge import create_s256_code_challenge from authlib.oidc.core import CodeIDToken, UserInfo from authlib.oidc.discovery import OpenIDProviderMetadata, get_well_known_url from jinja2 import Environment, Template @@ -475,6 +476,16 @@ class OidcProvider: ) ) + # If PKCE support is advertised ensure the wanted method is available. + if m.get("code_challenge_methods_supported") is not None: + m.validate_code_challenge_methods_supported() + if "S256" not in m["code_challenge_methods_supported"]: + raise ValueError( + '"S256" not in "code_challenge_methods_supported" ({supported!r})'.format( + supported=m["code_challenge_methods_supported"], + ) + ) + if m.get("response_types_supported") is not None: m.validate_response_types_supported() @@ -602,6 +613,11 @@ class OidcProvider: if self._config.jwks_uri: metadata["jwks_uri"] = self._config.jwks_uri + if self._config.pkce_method == "always": + metadata["code_challenge_methods_supported"] = ["S256"] + elif self._config.pkce_method == "never": + metadata.pop("code_challenge_methods_supported", None) + self._validate_metadata(metadata) return metadata @@ -653,7 +669,7 @@ class OidcProvider: return jwk_set - async def _exchange_code(self, code: str) -> Token: + async def _exchange_code(self, code: str, code_verifier: str) -> Token: """Exchange an authorization code for a token. This calls the ``token_endpoint`` with the authorization code we @@ -666,6 +682,7 @@ class OidcProvider: Args: code: The authorization code we got from the callback. + code_verifier: The PKCE code verifier to send, blank if unused. Returns: A dict containing various tokens. @@ -696,6 +713,8 @@ class OidcProvider: "code": code, "redirect_uri": self._callback_url, } + if code_verifier: + args["code_verifier"] = code_verifier body = urlencode(args, True) # Fill the body/headers with credentials @@ -914,11 +933,14 @@ class OidcProvider: - ``scope``: the list of scopes set in ``oidc_config.scopes`` - ``state``: a random string - ``nonce``: a random string + - ``code_challenge``: a RFC7636 code challenge (if PKCE is supported) - In addition generating a redirect URL, we are setting a cookie with - a signed macaroon token containing the state, the nonce and the - client_redirect_url params. Those are then checked when the client - comes back from the provider. + In addition to generating a redirect URL, we are setting a cookie with + a signed macaroon token containing the state, the nonce, the + client_redirect_url, and (optionally) the code_verifier params. The state, + nonce, and client_redirect_url are then checked when the client comes back + from the provider. The code_verifier is passed back to the server during + the token exchange and compared to the code_challenge sent in this request. Args: request: the incoming request from the browser. @@ -935,10 +957,25 @@ class OidcProvider: state = generate_token() nonce = generate_token() + code_verifier = "" if not client_redirect_url: client_redirect_url = b"" + metadata = await self.load_metadata() + + # Automatically enable PKCE if it is supported. + extra_grant_values = {} + if metadata.get("code_challenge_methods_supported"): + code_verifier = generate_token(48) + + # Note that we verified the server supports S256 earlier (in + # OidcProvider._validate_metadata). + extra_grant_values = { + "code_challenge_method": "S256", + "code_challenge": create_s256_code_challenge(code_verifier), + } + cookie = self._macaroon_generaton.generate_oidc_session_token( state=state, session_data=OidcSessionData( @@ -946,6 +983,7 @@ class OidcProvider: nonce=nonce, client_redirect_url=client_redirect_url.decode(), ui_auth_session_id=ui_auth_session_id or "", + code_verifier=code_verifier, ), ) @@ -966,7 +1004,6 @@ class OidcProvider: ) ) - metadata = await self.load_metadata() authorization_endpoint = metadata.get("authorization_endpoint") return prepare_grant_uri( authorization_endpoint, @@ -976,6 +1013,7 @@ class OidcProvider: scope=self._scopes, state=state, nonce=nonce, + **extra_grant_values, ) async def handle_oidc_callback( @@ -1003,7 +1041,9 @@ class OidcProvider: # Exchange the code with the provider try: logger.debug("Exchanging OAuth2 code for a token") - token = await self._exchange_code(code) + token = await self._exchange_code( + code, code_verifier=session_data.code_verifier + ) except OidcError as e: logger.warning("Could not exchange OAuth2 code: %s", e) self._sso_handler.render_error(request, e.error, e.error_description) diff --git a/synapse/util/macaroons.py b/synapse/util/macaroons.py index 5df03d3ddc..644c341e8c 100644 --- a/synapse/util/macaroons.py +++ b/synapse/util/macaroons.py @@ -110,6 +110,9 @@ class OidcSessionData: ui_auth_session_id: str """The session ID of the ongoing UI Auth ("" if this is a login)""" + code_verifier: str + """The random string used in the RFC7636 code challenge ("" if PKCE is not being used).""" + class MacaroonGenerator: def __init__(self, clock: Clock, location: str, secret_key: bytes): @@ -187,6 +190,7 @@ class MacaroonGenerator: macaroon.add_first_party_caveat( f"ui_auth_session_id = {session_data.ui_auth_session_id}" ) + macaroon.add_first_party_caveat(f"code_verifier = {session_data.code_verifier}") macaroon.add_first_party_caveat(f"time < {expiry}") return macaroon.serialize() @@ -278,6 +282,7 @@ class MacaroonGenerator: v.satisfy_general(lambda c: c.startswith("idp_id = ")) v.satisfy_general(lambda c: c.startswith("client_redirect_url = ")) v.satisfy_general(lambda c: c.startswith("ui_auth_session_id = ")) + v.satisfy_general(lambda c: c.startswith("code_verifier = ")) satisfy_expiry(v, self._clock.time_msec) v.verify(macaroon, self._secret_key) @@ -287,11 +292,13 @@ class MacaroonGenerator: idp_id = get_value_from_macaroon(macaroon, "idp_id") client_redirect_url = get_value_from_macaroon(macaroon, "client_redirect_url") ui_auth_session_id = get_value_from_macaroon(macaroon, "ui_auth_session_id") + code_verifier = get_value_from_macaroon(macaroon, "code_verifier") return OidcSessionData( nonce=nonce, idp_id=idp_id, client_redirect_url=client_redirect_url, ui_auth_session_id=ui_auth_session_id, + code_verifier=code_verifier, ) def _generate_base_macaroon(self, type: MacaroonType) -> pymacaroons.Macaroon: diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index 49a1842b5c..adddbd002f 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -396,6 +396,7 @@ class OidcHandlerTestCase(HomeserverTestCase): self.assertEqual(params["client_id"], [CLIENT_ID]) self.assertEqual(len(params["state"]), 1) self.assertEqual(len(params["nonce"]), 1) + self.assertNotIn("code_challenge", params) # Check what is in the cookies self.assertEqual(len(req.cookies), 2) # two cookies @@ -411,12 +412,117 @@ class OidcHandlerTestCase(HomeserverTestCase): macaroon = pymacaroons.Macaroon.deserialize(cookie) state = get_value_from_macaroon(macaroon, "state") nonce = get_value_from_macaroon(macaroon, "nonce") + code_verifier = get_value_from_macaroon(macaroon, "code_verifier") redirect = get_value_from_macaroon(macaroon, "client_redirect_url") self.assertEqual(params["state"], [state]) self.assertEqual(params["nonce"], [nonce]) + self.assertEqual(code_verifier, "") self.assertEqual(redirect, "http://client/redirect") + @override_config({"oidc_config": DEFAULT_CONFIG}) + def test_redirect_request_with_code_challenge(self) -> None: + """The redirect request has the right arguments & generates a valid session cookie.""" + req = Mock(spec=["cookies"]) + req.cookies = [] + + with self.metadata_edit({"code_challenge_methods_supported": ["S256"]}): + url = urlparse( + self.get_success( + self.provider.handle_redirect_request( + req, b"http://client/redirect" + ) + ) + ) + + # Ensure the code_challenge param is added to the redirect. + params = parse_qs(url.query) + self.assertEqual(len(params["code_challenge"]), 1) + + # Check what is in the cookies + self.assertEqual(len(req.cookies), 2) # two cookies + cookie_header = req.cookies[0] + + # The cookie name and path don't really matter, just that it has to be coherent + # between the callback & redirect handlers. + parts = [p.strip() for p in cookie_header.split(b";")] + self.assertIn(b"Path=/_synapse/client/oidc", parts) + name, cookie = parts[0].split(b"=") + self.assertEqual(name, b"oidc_session") + + # Ensure the code_verifier is set in the cookie. + macaroon = pymacaroons.Macaroon.deserialize(cookie) + code_verifier = get_value_from_macaroon(macaroon, "code_verifier") + self.assertNotEqual(code_verifier, "") + + @override_config({"oidc_config": {**DEFAULT_CONFIG, "pkce_method": "always"}}) + def test_redirect_request_with_forced_code_challenge(self) -> None: + """The redirect request has the right arguments & generates a valid session cookie.""" + req = Mock(spec=["cookies"]) + req.cookies = [] + + url = urlparse( + self.get_success( + self.provider.handle_redirect_request(req, b"http://client/redirect") + ) + ) + + # Ensure the code_challenge param is added to the redirect. + params = parse_qs(url.query) + self.assertEqual(len(params["code_challenge"]), 1) + + # Check what is in the cookies + self.assertEqual(len(req.cookies), 2) # two cookies + cookie_header = req.cookies[0] + + # The cookie name and path don't really matter, just that it has to be coherent + # between the callback & redirect handlers. + parts = [p.strip() for p in cookie_header.split(b";")] + self.assertIn(b"Path=/_synapse/client/oidc", parts) + name, cookie = parts[0].split(b"=") + self.assertEqual(name, b"oidc_session") + + # Ensure the code_verifier is set in the cookie. + macaroon = pymacaroons.Macaroon.deserialize(cookie) + code_verifier = get_value_from_macaroon(macaroon, "code_verifier") + self.assertNotEqual(code_verifier, "") + + @override_config({"oidc_config": {**DEFAULT_CONFIG, "pkce_method": "never"}}) + def test_redirect_request_with_disabled_code_challenge(self) -> None: + """The redirect request has the right arguments & generates a valid session cookie.""" + req = Mock(spec=["cookies"]) + req.cookies = [] + + # The metadata should state that PKCE is enabled. + with self.metadata_edit({"code_challenge_methods_supported": ["S256"]}): + url = urlparse( + self.get_success( + self.provider.handle_redirect_request( + req, b"http://client/redirect" + ) + ) + ) + + # Ensure the code_challenge param is added to the redirect. + params = parse_qs(url.query) + self.assertNotIn("code_challenge", params) + + # Check what is in the cookies + self.assertEqual(len(req.cookies), 2) # two cookies + cookie_header = req.cookies[0] + + # The cookie name and path don't really matter, just that it has to be coherent + # between the callback & redirect handlers. + parts = [p.strip() for p in cookie_header.split(b";")] + self.assertIn(b"Path=/_synapse/client/oidc", parts) + name, cookie = parts[0].split(b"=") + self.assertEqual(name, b"oidc_session") + + # Ensure the code_verifier is blank in the cookie. + macaroon = pymacaroons.Macaroon.deserialize(cookie) + code_verifier = get_value_from_macaroon(macaroon, "code_verifier") + self.assertEqual(code_verifier, "") + @override_config({"oidc_config": DEFAULT_CONFIG}) def test_callback_error(self) -> None: """Errors from the provider returned in the callback are displayed.""" @@ -601,7 +707,7 @@ class OidcHandlerTestCase(HomeserverTestCase): payload=token ) code = "code" - ret = self.get_success(self.provider._exchange_code(code)) + ret = self.get_success(self.provider._exchange_code(code, code_verifier="")) kwargs = self.fake_server.request.call_args[1] self.assertEqual(ret, token) @@ -615,13 +721,34 @@ class OidcHandlerTestCase(HomeserverTestCase): self.assertEqual(args["client_secret"], [CLIENT_SECRET]) self.assertEqual(args["redirect_uri"], [CALLBACK_URL]) + # Test providing a code verifier. + code_verifier = "code_verifier" + ret = self.get_success( + self.provider._exchange_code(code, code_verifier=code_verifier) + ) + kwargs = self.fake_server.request.call_args[1] + + self.assertEqual(ret, token) + self.assertEqual(kwargs["method"], "POST") + self.assertEqual(kwargs["uri"], self.fake_server.token_endpoint) + + args = parse_qs(kwargs["data"].decode("utf-8")) + self.assertEqual(args["grant_type"], ["authorization_code"]) + self.assertEqual(args["code"], [code]) + self.assertEqual(args["client_id"], [CLIENT_ID]) + self.assertEqual(args["client_secret"], [CLIENT_SECRET]) + self.assertEqual(args["redirect_uri"], [CALLBACK_URL]) + self.assertEqual(args["code_verifier"], [code_verifier]) + # Test error handling self.fake_server.post_token_handler.return_value = FakeResponse.json( code=400, payload={"error": "foo", "error_description": "bar"} ) from synapse.handlers.oidc import OidcError - exc = self.get_failure(self.provider._exchange_code(code), OidcError) + exc = self.get_failure( + self.provider._exchange_code(code, code_verifier=""), OidcError + ) self.assertEqual(exc.value.error, "foo") self.assertEqual(exc.value.error_description, "bar") @@ -629,7 +756,9 @@ class OidcHandlerTestCase(HomeserverTestCase): self.fake_server.post_token_handler.return_value = FakeResponse( code=500, body=b"Not JSON" ) - exc = self.get_failure(self.provider._exchange_code(code), OidcError) + exc = self.get_failure( + self.provider._exchange_code(code, code_verifier=""), OidcError + ) self.assertEqual(exc.value.error, "server_error") # Internal server error with JSON body @@ -637,21 +766,27 @@ class OidcHandlerTestCase(HomeserverTestCase): code=500, payload={"error": "internal_server_error"} ) - exc = self.get_failure(self.provider._exchange_code(code), OidcError) + exc = self.get_failure( + self.provider._exchange_code(code, code_verifier=""), OidcError + ) self.assertEqual(exc.value.error, "internal_server_error") # 4xx error without "error" field self.fake_server.post_token_handler.return_value = FakeResponse.json( code=400, payload={} ) - exc = self.get_failure(self.provider._exchange_code(code), OidcError) + exc = self.get_failure( + self.provider._exchange_code(code, code_verifier=""), OidcError + ) self.assertEqual(exc.value.error, "server_error") # 2xx error with "error" field self.fake_server.post_token_handler.return_value = FakeResponse.json( code=200, payload={"error": "some_error"} ) - exc = self.get_failure(self.provider._exchange_code(code), OidcError) + exc = self.get_failure( + self.provider._exchange_code(code, code_verifier=""), OidcError + ) self.assertEqual(exc.value.error, "some_error") @override_config( @@ -688,7 +823,7 @@ class OidcHandlerTestCase(HomeserverTestCase): # timestamps. self.reactor.advance(1000) start_time = self.reactor.seconds() - ret = self.get_success(self.provider._exchange_code(code)) + ret = self.get_success(self.provider._exchange_code(code, code_verifier="")) self.assertEqual(ret, token) @@ -739,7 +874,7 @@ class OidcHandlerTestCase(HomeserverTestCase): payload=token ) code = "code" - ret = self.get_success(self.provider._exchange_code(code)) + ret = self.get_success(self.provider._exchange_code(code, code_verifier="")) self.assertEqual(ret, token) @@ -1203,6 +1338,7 @@ class OidcHandlerTestCase(HomeserverTestCase): nonce=nonce, client_redirect_url=client_redirect_url, ui_auth_session_id=ui_auth_session_id, + code_verifier="", ), ) diff --git a/tests/util/test_macaroons.py b/tests/util/test_macaroons.py index f68377a05a..e56ec2c860 100644 --- a/tests/util/test_macaroons.py +++ b/tests/util/test_macaroons.py @@ -92,6 +92,7 @@ class MacaroonGeneratorTestCase(TestCase): nonce="nonce", client_redirect_url="https://example.com/", ui_auth_session_id="", + code_verifier="", ) token = self.macaroon_generator.generate_oidc_session_token( state, session_data, duration_in_ms=2 * 60 * 1000 -- cgit 1.5.1 From 5e0888076fea8c70ab84114e1c261dd46330c1d6 Mon Sep 17 00:00:00 2001 From: Jeyachandran Rathnam Date: Mon, 9 Jan 2023 06:12:03 -0500 Subject: Disable sending confirmation email when 3pid is disabled #14682 (#14725) * Fixes #12277 :Disable sending confirmation email when 3pid is disabled * Fix test_add_email_if_disabled test case to reflect changes to enable_3pid_changes flag * Add changelog file * Rename newsfragment. Co-authored-by: Patrick Cloke --- changelog.d/14725.misc | 1 + synapse/rest/client/account.py | 5 +++++ tests/rest/client/test_account.py | 30 +++++------------------------- 3 files changed, 11 insertions(+), 25 deletions(-) create mode 100644 changelog.d/14725.misc (limited to 'synapse') diff --git a/changelog.d/14725.misc b/changelog.d/14725.misc new file mode 100644 index 0000000000..a86c4f8c05 --- /dev/null +++ b/changelog.d/14725.misc @@ -0,0 +1 @@ +Disable sending confirmation email when 3pid is disabled. diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py index b4b92f0c99..4373c73662 100644 --- a/synapse/rest/client/account.py +++ b/synapse/rest/client/account.py @@ -338,6 +338,11 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): ) async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + if not self.hs.config.registration.enable_3pid_changes: + raise SynapseError( + 400, "3PID changes are disabled on this server", Codes.FORBIDDEN + ) + if not self.config.email.can_verify_email: logger.warning( "Adding emails have been disabled due to lack of an email config" diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py index c1a7fb2f8a..88f255c9ee 100644 --- a/tests/rest/client/test_account.py +++ b/tests/rest/client/test_account.py @@ -690,41 +690,21 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): self.hs.config.registration.enable_3pid_changes = False client_secret = "foobar" - session_id = self._request_token(self.email, client_secret) - - self.assertEqual(len(self.email_attempts), 1) - link = self._get_link_from_email() - - self._validate_token(link) - channel = self.make_request( "POST", - b"/_matrix/client/unstable/account/3pid/add", + b"/_matrix/client/unstable/account/3pid/email/requestToken", { "client_secret": client_secret, - "sid": session_id, - "auth": { - "type": "m.login.password", - "user": self.user_id, - "password": "test", - }, + "email": "test@example.com", + "send_attempt": 1, }, - access_token=self.user_id_tok, ) + self.assertEqual( HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] ) - self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - # Get user - channel = self.make_request( - "GET", - self.url_3pid, - access_token=self.user_id_tok, - ) - - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) - self.assertFalse(channel.json_body["threepids"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_delete_email(self) -> None: """Test deleting an email from profile""" -- cgit 1.5.1 From 7e582a25f8f350df29d7d83ca902bdb522d1bbaf Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 9 Jan 2023 08:43:50 -0500 Subject: Improve /sync performance of when passing filters with empty arrays. (#14786) This has two related changes: * It enables fast-path processing for an empty filter (`[]`) which was previously only used for wildcard not-filters (`["*"]`). * It special cases a `/sync` filter with no-rooms to skip all room processing, previously we would partially skip processing, but would generally still calculate intermediate values for each room which were then unused. Future changes might consider further optimizations: * Skip calculating per-room account data when all rooms are filtered (currently this is thrown away). * Make similar improvements to other endpoints which support filters. --- changelog.d/14786.feature | 1 + synapse/api/filtering.py | 13 ++++++++----- synapse/handlers/search.py | 2 +- synapse/handlers/sync.py | 14 +++++++++++--- 4 files changed, 21 insertions(+), 9 deletions(-) create mode 100644 changelog.d/14786.feature (limited to 'synapse') diff --git a/changelog.d/14786.feature b/changelog.d/14786.feature new file mode 100644 index 0000000000..008d61ab03 --- /dev/null +++ b/changelog.d/14786.feature @@ -0,0 +1 @@ +Improve performance of `/sync` when filtering all rooms, message types, or senders. diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index a9888381b4..2b5af264b4 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -283,6 +283,9 @@ class FilterCollection: await self._room_filter.filter(events) ) + def blocks_all_rooms(self) -> bool: + return self._room_filter.filters_all_rooms() + def blocks_all_presence(self) -> bool: return ( self._presence_filter.filters_all_types() @@ -351,13 +354,13 @@ class Filter: self.not_rel_types = filter_json.get("org.matrix.msc3874.not_rel_types", []) def filters_all_types(self) -> bool: - return "*" in self.not_types + return self.types == [] or "*" in self.not_types def filters_all_senders(self) -> bool: - return "*" in self.not_senders + return self.senders == [] or "*" in self.not_senders def filters_all_rooms(self) -> bool: - return "*" in self.not_rooms + return self.rooms == [] or "*" in self.not_rooms def _check(self, event: FilterEvent) -> bool: """Checks whether the filter matches the given event. @@ -450,8 +453,8 @@ class Filter: if any(map(match_func, disallowed_values)): return False - # Other the event does not match at least one of the allowed values, - # reject it. + # Otherwise if the event does not match at least one of the allowed + # values, reject it. allowed_values = getattr(self, name) if allowed_values is not None: if not any(map(match_func, allowed_values)): diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index 33115ce488..40f4635c4e 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -275,7 +275,7 @@ class SearchHandler: ) room_ids = {r.room_id for r in rooms} - # If doing a subset of all rooms seearch, check if any of the rooms + # If doing a subset of all rooms search, check if any of the rooms # are from an upgraded room, and search their contents as well if search_filter.rooms: historical_room_ids: List[str] = [] diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 4fa480262b..6942e06c77 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -1403,11 +1403,14 @@ class SyncHandler: logger.debug("Fetching room data") - res = await self._generate_sync_entry_for_rooms( + ( + newly_joined_rooms, + newly_joined_or_invited_or_knocked_users, + newly_left_rooms, + newly_left_users, + ) = await self._generate_sync_entry_for_rooms( sync_result_builder, account_data_by_room ) - newly_joined_rooms, newly_joined_or_invited_or_knocked_users, _, _ = res - _, _, newly_left_rooms, newly_left_users = res block_all_presence_data = ( since_token is None and sync_config.filter_collection.blocks_all_presence() @@ -1789,6 +1792,11 @@ class SyncHandler: - newly_left_rooms - newly_left_users """ + + # If the request doesn't care about rooms then nothing to do! + if sync_result_builder.sync_config.filter_collection.blocks_all_rooms(): + return set(), set(), set(), set() + since_token = sync_result_builder.since_token # 1. Start by fetching all ephemeral events in rooms we've joined (if required). -- cgit 1.5.1 From babeeb4e7a6f5b5c643b837bf724d674805546f6 Mon Sep 17 00:00:00 2001 From: Jeyachandran Rathnam Date: Mon, 9 Jan 2023 09:22:02 -0500 Subject: Unescape HTML entities in oEmbed titles. (#14781) It doesn't seem valid that HTML entities should appear in the title field of oEmbed responses, but a popular WordPress plug-in seems to do it. There should not be harm in unescaping these. --- changelog.d/14781.misc | 1 + synapse/rest/media/v1/oembed.py | 15 +++++++++------ tests/rest/media/v1/test_oembed.py | 10 ++++++++++ 3 files changed, 20 insertions(+), 6 deletions(-) create mode 100644 changelog.d/14781.misc (limited to 'synapse') diff --git a/changelog.d/14781.misc b/changelog.d/14781.misc new file mode 100644 index 0000000000..04f565b410 --- /dev/null +++ b/changelog.d/14781.misc @@ -0,0 +1 @@ +Unescape HTML entities in URL preview titles making use of oEmbed responses. diff --git a/synapse/rest/media/v1/oembed.py b/synapse/rest/media/v1/oembed.py index 827afd868d..a3738a6250 100644 --- a/synapse/rest/media/v1/oembed.py +++ b/synapse/rest/media/v1/oembed.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import html import logging import urllib.parse from typing import TYPE_CHECKING, List, Optional @@ -161,7 +162,9 @@ class OEmbedProvider: title = oembed.get("title") if title and isinstance(title, str): - open_graph_response["og:title"] = title + # A common WordPress plug-in seems to incorrectly escape entities + # in the oEmbed response. + open_graph_response["og:title"] = html.unescape(title) author_name = oembed.get("author_name") if not isinstance(author_name, str): @@ -180,9 +183,9 @@ class OEmbedProvider: # Process each type separately. oembed_type = oembed.get("type") if oembed_type == "rich": - html = oembed.get("html") - if isinstance(html, str): - calc_description_and_urls(open_graph_response, html) + html_str = oembed.get("html") + if isinstance(html_str, str): + calc_description_and_urls(open_graph_response, html_str) elif oembed_type == "photo": # If this is a photo, use the full image, not the thumbnail. @@ -192,8 +195,8 @@ class OEmbedProvider: elif oembed_type == "video": open_graph_response["og:type"] = "video.other" - html = oembed.get("html") - if html and isinstance(html, str): + html_str = oembed.get("html") + if html_str and isinstance(html_str, str): calc_description_and_urls(open_graph_response, oembed["html"]) for size in ("width", "height"): val = oembed.get(size) diff --git a/tests/rest/media/v1/test_oembed.py b/tests/rest/media/v1/test_oembed.py index 319ae8b1cc..3f7f1dbab9 100644 --- a/tests/rest/media/v1/test_oembed.py +++ b/tests/rest/media/v1/test_oembed.py @@ -150,3 +150,13 @@ class OEmbedTests(HomeserverTestCase): result = self.parse_response({"type": "link"}) self.assertIn("og:type", result.open_graph_result) self.assertEqual(result.open_graph_result["og:type"], "website") + + def test_title_html_entities(self) -> None: + """Test HTML entities in title""" + result = self.parse_response( + {"title": "Why JSON isn’t a Good Configuration Language"} + ) + self.assertEqual( + result.open_graph_result["og:title"], + "Why JSON isn’t a Good Configuration Language", + ) -- cgit 1.5.1 From 58d2adc3da6a988452dbb9c6c4202a5ea19c4ca9 Mon Sep 17 00:00:00 2001 From: Jeyachandran Rathnam Date: Mon, 9 Jan 2023 12:17:24 -0500 Subject: Remove undocumented device from pushrules (#14727) * Remove undocumented device from pushrules * Add changelog * Update changelog.d/14727.misc * Rename 14727.misc to 14727.bugfix Co-authored-by: David Robertson --- changelog.d/14727.bugfix | 1 + synapse/push/clientformat.py | 5 +---- 2 files changed, 2 insertions(+), 4 deletions(-) create mode 100644 changelog.d/14727.bugfix (limited to 'synapse') diff --git a/changelog.d/14727.bugfix b/changelog.d/14727.bugfix new file mode 100644 index 0000000000..25079496e4 --- /dev/null +++ b/changelog.d/14727.bugfix @@ -0,0 +1 @@ +Remove the unspecced `device` field from `/pushrules` responses. diff --git a/synapse/push/clientformat.py b/synapse/push/clientformat.py index 622a1e35c5..bb76c169c6 100644 --- a/synapse/push/clientformat.py +++ b/synapse/push/clientformat.py @@ -26,10 +26,7 @@ def format_push_rules_for_user( """Converts a list of rawrules and a enabled map into nested dictionaries to match the Matrix client-server format for push rules""" - rules: Dict[str, Dict[str, List[Dict[str, Any]]]] = { - "global": {}, - "device": {}, - } + rules: Dict[str, Dict[str, List[Dict[str, Any]]]] = {"global": {}} rules["global"] = _add_empty_priority_class_arrays(rules["global"]) -- cgit 1.5.1 From ba4ea7d13ffae53644b206222af95a5171faa27c Mon Sep 17 00:00:00 2001 From: reivilibre Date: Tue, 10 Jan 2023 11:17:59 +0000 Subject: Batch up replication requests to request the resyncing of remote users's devices. (#14716) --- changelog.d/14716.misc | 1 + synapse/handlers/device.py | 124 +++++++++++++++++++++++------- synapse/handlers/devicemessage.py | 2 +- synapse/handlers/e2e_keys.py | 93 +++++++++++++--------- synapse/handlers/federation_event.py | 2 +- synapse/replication/http/devices.py | 74 +++++++++++++++++- synapse/storage/databases/main/devices.py | 30 ++++++-- synapse/types/__init__.py | 4 + synapse/util/async_helpers.py | 55 ++++++++++++- 9 files changed, 306 insertions(+), 79 deletions(-) create mode 100644 changelog.d/14716.misc (limited to 'synapse') diff --git a/changelog.d/14716.misc b/changelog.d/14716.misc new file mode 100644 index 0000000000..ef9522e01d --- /dev/null +++ b/changelog.d/14716.misc @@ -0,0 +1 @@ +Batch up replication requests to request the resyncing of remote users's devices. \ No newline at end of file diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index d4750a32e6..89864e1119 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from http import HTTPStatus from typing import ( TYPE_CHECKING, Any, @@ -33,6 +34,7 @@ from synapse.api.errors import ( Codes, FederationDeniedError, HttpResponseException, + InvalidAPICallError, RequestSendFailed, SynapseError, ) @@ -45,6 +47,7 @@ from synapse.types import ( JsonDict, StreamKeyType, StreamToken, + UserID, get_domain_from_id, get_verify_key_from_cross_signing_key, ) @@ -893,12 +896,47 @@ class DeviceListWorkerUpdater: def __init__(self, hs: "HomeServer"): from synapse.replication.http.devices import ( + ReplicationMultiUserDevicesResyncRestServlet, ReplicationUserDevicesResyncRestServlet, ) self._user_device_resync_client = ( ReplicationUserDevicesResyncRestServlet.make_client(hs) ) + self._multi_user_device_resync_client = ( + ReplicationMultiUserDevicesResyncRestServlet.make_client(hs) + ) + + async def multi_user_device_resync( + self, user_ids: List[str], mark_failed_as_stale: bool = True + ) -> Dict[str, Optional[JsonDict]]: + """ + Like `user_device_resync` but operates on multiple users **from the same origin** + at once. + + Returns: + Dict from User ID to the same Dict as `user_device_resync`. + """ + # mark_failed_as_stale is not sent. Ensure this doesn't break expectations. + assert mark_failed_as_stale + + if not user_ids: + # Shortcut empty requests + return {} + + try: + return await self._multi_user_device_resync_client(user_ids=user_ids) + except SynapseError as err: + if not ( + err.code == HTTPStatus.NOT_FOUND and err.errcode == Codes.UNRECOGNIZED + ): + raise + + # Fall back to single requests + result: Dict[str, Optional[JsonDict]] = {} + for user_id in user_ids: + result[user_id] = await self._user_device_resync_client(user_id=user_id) + return result async def user_device_resync( self, user_id: str, mark_failed_as_stale: bool = True @@ -913,8 +951,10 @@ class DeviceListWorkerUpdater: A dict with device info as under the "devices" in the result of this request: https://matrix.org/docs/spec/server_server/r0.1.2#get-matrix-federation-v1-user-devices-userid + None when we weren't able to fetch the device info for some reason, + e.g. due to a connection problem. """ - return await self._user_device_resync_client(user_id=user_id) + return (await self.multi_user_device_resync([user_id]))[user_id] class DeviceListUpdater(DeviceListWorkerUpdater): @@ -1160,19 +1200,66 @@ class DeviceListUpdater(DeviceListWorkerUpdater): # Allow future calls to retry resyncinc out of sync device lists. self._resync_retry_in_progress = False + async def multi_user_device_resync( + self, user_ids: List[str], mark_failed_as_stale: bool = True + ) -> Dict[str, Optional[JsonDict]]: + """ + Like `user_device_resync` but operates on multiple users **from the same origin** + at once. + + Returns: + Dict from User ID to the same Dict as `user_device_resync`. + """ + if not user_ids: + return {} + + origins = {UserID.from_string(user_id).domain for user_id in user_ids} + + if len(origins) != 1: + raise InvalidAPICallError(f"Only one origin permitted, got {origins!r}") + + result = {} + failed = set() + # TODO(Perf): Actually batch these up + for user_id in user_ids: + user_result, user_failed = await self._user_device_resync_returning_failed( + user_id + ) + result[user_id] = user_result + if user_failed: + failed.add(user_id) + + if mark_failed_as_stale: + await self.store.mark_remote_users_device_caches_as_stale(failed) + + return result + async def user_device_resync( self, user_id: str, mark_failed_as_stale: bool = True ) -> Optional[JsonDict]: + result, failed = await self._user_device_resync_returning_failed(user_id) + + if failed and mark_failed_as_stale: + # Mark the remote user's device list as stale so we know we need to retry + # it later. + await self.store.mark_remote_users_device_caches_as_stale((user_id,)) + + return result + + async def _user_device_resync_returning_failed( + self, user_id: str + ) -> Tuple[Optional[JsonDict], bool]: """Fetches all devices for a user and updates the device cache with them. Args: user_id: The user's id whose device_list will be updated. - mark_failed_as_stale: Whether to mark the user's device list as stale - if the attempt to resync failed. Returns: - A dict with device info as under the "devices" in the result of this - request: - https://matrix.org/docs/spec/server_server/r0.1.2#get-matrix-federation-v1-user-devices-userid + - A dict with device info as under the "devices" in the result of this + request: + https://matrix.org/docs/spec/server_server/r0.1.2#get-matrix-federation-v1-user-devices-userid + None when we weren't able to fetch the device info for some reason, + e.g. due to a connection problem. + - True iff the resync failed and the device list should be marked as stale. """ logger.debug("Attempting to resync the device list for %s", user_id) log_kv({"message": "Doing resync to update device list."}) @@ -1181,12 +1268,7 @@ class DeviceListUpdater(DeviceListWorkerUpdater): try: result = await self.federation.query_user_devices(origin, user_id) except NotRetryingDestination: - if mark_failed_as_stale: - # Mark the remote user's device list as stale so we know we need to retry - # it later. - await self.store.mark_remote_user_device_cache_as_stale(user_id) - - return None + return None, True except (RequestSendFailed, HttpResponseException) as e: logger.warning( "Failed to handle device list update for %s: %s", @@ -1194,23 +1276,18 @@ class DeviceListUpdater(DeviceListWorkerUpdater): e, ) - if mark_failed_as_stale: - # Mark the remote user's device list as stale so we know we need to retry - # it later. - await self.store.mark_remote_user_device_cache_as_stale(user_id) - # We abort on exceptions rather than accepting the update # as otherwise synapse will 'forget' that its device list # is out of date. If we bail then we will retry the resync # next time we get a device list update for this user_id. # This makes it more likely that the device lists will # eventually become consistent. - return None + return None, True except FederationDeniedError as e: set_tag("error", True) log_kv({"reason": "FederationDeniedError"}) logger.info(e) - return None + return None, False except Exception as e: set_tag("error", True) log_kv( @@ -1218,12 +1295,7 @@ class DeviceListUpdater(DeviceListWorkerUpdater): ) logger.exception("Failed to handle device list update for %s", user_id) - if mark_failed_as_stale: - # Mark the remote user's device list as stale so we know we need to retry - # it later. - await self.store.mark_remote_user_device_cache_as_stale(user_id) - - return None + return None, True log_kv({"result": result}) stream_id = result["stream_id"] devices = result["devices"] @@ -1305,7 +1377,7 @@ class DeviceListUpdater(DeviceListWorkerUpdater): # point. self._seen_updates[user_id] = {stream_id} - return result + return result, False async def process_cross_signing_key_update( self, diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py index 75e89850f5..00c403db49 100644 --- a/synapse/handlers/devicemessage.py +++ b/synapse/handlers/devicemessage.py @@ -195,7 +195,7 @@ class DeviceMessageHandler: sender_user_id, unknown_devices, ) - await self.store.mark_remote_user_device_cache_as_stale(sender_user_id) + await self.store.mark_remote_users_device_caches_as_stale((sender_user_id,)) # Immediately attempt a resync in the background run_in_background(self._user_device_resync, user_id=sender_user_id) diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 5fe102e2f2..d2188ca08f 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -36,8 +36,8 @@ from synapse.types import ( get_domain_from_id, get_verify_key_from_cross_signing_key, ) -from synapse.util import json_decoder, unwrapFirstError -from synapse.util.async_helpers import Linearizer, delay_cancellation +from synapse.util import json_decoder +from synapse.util.async_helpers import Linearizer, concurrently_execute from synapse.util.cancellation import cancellable from synapse.util.retryutils import NotRetryingDestination @@ -238,24 +238,28 @@ class E2eKeysHandler: # Now fetch any devices that we don't have in our cache # TODO It might make sense to propagate cancellations into the # deferreds which are querying remote homeservers. - await make_deferred_yieldable( - delay_cancellation( - defer.gatherResults( - [ - run_in_background( - self._query_devices_for_destination, - results, - cross_signing_keys, - failures, - destination, - queries, - timeout, - ) - for destination, queries in remote_queries_not_in_cache.items() - ], - consumeErrors=True, - ).addErrback(unwrapFirstError) + logger.debug( + "%d destinations to query devices for", len(remote_queries_not_in_cache) + ) + + async def _query( + destination_queries: Tuple[str, Dict[str, Iterable[str]]] + ) -> None: + destination, queries = destination_queries + return await self._query_devices_for_destination( + results, + cross_signing_keys, + failures, + destination, + queries, + timeout, ) + + await concurrently_execute( + _query, + remote_queries_not_in_cache.items(), + 10, + delay_cancellation=True, ) ret = {"device_keys": results, "failures": failures} @@ -300,28 +304,41 @@ class E2eKeysHandler: # queries. We use the more efficient batched query_client_keys for all # remaining users user_ids_updated = [] - for (user_id, device_list) in destination_query.items(): - if user_id in user_ids_updated: - continue - if device_list: - continue + # Perform a user device resync for each user only once and only as long as: + # - they have an empty device_list + # - they are in some rooms that this server can see + users_to_resync_devices = { + user_id + for (user_id, device_list) in destination_query.items() + if (not device_list) and (await self.store.get_rooms_for_user(user_id)) + } - room_ids = await self.store.get_rooms_for_user(user_id) - if not room_ids: - continue + logger.debug( + "%d users to resync devices for from destination %s", + len(users_to_resync_devices), + destination, + ) - # We've decided we're sharing a room with this user and should - # probably be tracking their device lists. However, we haven't - # done an initial sync on the device list so we do it now. - try: - resync_results = ( - await self.device_handler.device_list_updater.user_device_resync( - user_id - ) + try: + user_resync_results = ( + await self.device_handler.device_list_updater.multi_user_device_resync( + list(users_to_resync_devices) ) + ) + for user_id in users_to_resync_devices: + resync_results = user_resync_results[user_id] + if resync_results is None: - raise ValueError("Device resync failed") + # TODO: It's weird that we'll store a failure against a + # destination, yet continue processing users from that + # destination. + # We might want to consider changing this, but for now + # I'm leaving it as I found it. + failures[destination] = _exception_to_failure( + ValueError(f"Device resync failed for {user_id!r}") + ) + continue # Add the device keys to the results. user_devices = resync_results["devices"] @@ -339,8 +356,8 @@ class E2eKeysHandler: if self_signing_key: cross_signing_keys["self_signing_keys"][user_id] = self_signing_key - except Exception as e: - failures[destination] = _exception_to_failure(e) + except Exception as e: + failures[destination] = _exception_to_failure(e) if len(destination_query) == len(user_ids_updated): # We've updated all the users in the query and we do not need to diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 31df7f55cc..6df000faaf 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -1423,7 +1423,7 @@ class FederationEventHandler: """ try: - await self._store.mark_remote_user_device_cache_as_stale(sender) + await self._store.mark_remote_users_device_caches_as_stale((sender,)) # Immediately attempt a resync in the background if self._config.worker.worker_app: diff --git a/synapse/replication/http/devices.py b/synapse/replication/http/devices.py index 7c4941c3d3..ea5c08e6cf 100644 --- a/synapse/replication/http/devices.py +++ b/synapse/replication/http/devices.py @@ -13,12 +13,13 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Optional, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple from twisted.web.server import Request from synapse.http.server import HttpServer from synapse.http.servlet import parse_json_object_from_request +from synapse.logging.opentracing import active_span from synapse.replication.http._base import ReplicationEndpoint from synapse.types import JsonDict @@ -84,6 +85,76 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint): return 200, user_devices +class ReplicationMultiUserDevicesResyncRestServlet(ReplicationEndpoint): + """Ask master to resync the device list for multiple users from the same + remote server by contacting their server. + + This must happen on master so that the results can be correctly cached in + the database and streamed to workers. + + Request format: + + POST /_synapse/replication/multi_user_device_resync + + { + "user_ids": ["@alice:example.org", "@bob:example.org", ...] + } + + Response is roughly equivalent to ` /_matrix/federation/v1/user/devices/:user_id` + response, but there is a map from user ID to response, e.g.: + + { + "@alice:example.org": { + "devices": [ + { + "device_id": "JLAFKJWSCS", + "keys": { ... }, + "device_display_name": "Alice's Mobile Phone" + } + ] + }, + ... + } + """ + + NAME = "multi_user_device_resync" + PATH_ARGS = () + CACHE = False + + def __init__(self, hs: "HomeServer"): + super().__init__(hs) + + from synapse.handlers.device import DeviceHandler + + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.device_list_updater = handler.device_list_updater + + self.store = hs.get_datastores().main + self.clock = hs.get_clock() + + @staticmethod + async def _serialize_payload(user_ids: List[str]) -> JsonDict: # type: ignore[override] + return {"user_ids": user_ids} + + async def _handle_request( # type: ignore[override] + self, request: Request + ) -> Tuple[int, Dict[str, Optional[JsonDict]]]: + content = parse_json_object_from_request(request) + user_ids: List[str] = content["user_ids"] + + logger.info("Resync for %r", user_ids) + span = active_span() + if span: + span.set_tag("user_ids", f"{user_ids!r}") + + multi_user_devices = await self.device_list_updater.multi_user_device_resync( + user_ids + ) + + return 200, multi_user_devices + + class ReplicationUploadKeysForUserRestServlet(ReplicationEndpoint): """Ask master to upload keys for the user and send them out over federation to update other servers. @@ -151,4 +222,5 @@ class ReplicationUploadKeysForUserRestServlet(ReplicationEndpoint): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: ReplicationUserDevicesResyncRestServlet(hs).register(http_server) + ReplicationMultiUserDevicesResyncRestServlet(hs).register(http_server) ReplicationUploadKeysForUserRestServlet(hs).register(http_server) diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index db877e3f13..b067664473 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -54,7 +54,7 @@ from synapse.storage.util.id_generators import ( AbstractStreamIdTracker, StreamIdGenerator, ) -from synapse.types import JsonDict, get_verify_key_from_cross_signing_key +from synapse.types import JsonDict, StrCollection, get_verify_key_from_cross_signing_key from synapse.util import json_decoder, json_encoder from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.lrucache import LruCache @@ -1069,16 +1069,30 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): return {row["user_id"] for row in rows} - async def mark_remote_user_device_cache_as_stale(self, user_id: str) -> None: + async def mark_remote_users_device_caches_as_stale( + self, user_ids: StrCollection + ) -> None: """Records that the server has reason to believe the cache of the devices for the remote users is out of date. """ - await self.db_pool.simple_upsert( - table="device_lists_remote_resync", - keyvalues={"user_id": user_id}, - values={}, - insertion_values={"added_ts": self._clock.time_msec()}, - desc="mark_remote_user_device_cache_as_stale", + + def _mark_remote_users_device_caches_as_stale_txn( + txn: LoggingTransaction, + ) -> None: + # TODO add insertion_values support to simple_upsert_many and use + # that! + for user_id in user_ids: + self.db_pool.simple_upsert_txn( + txn, + table="device_lists_remote_resync", + keyvalues={"user_id": user_id}, + values={}, + insertion_values={"added_ts": self._clock.time_msec()}, + ) + + await self.db_pool.runInteraction( + "mark_remote_users_device_caches_as_stale", + _mark_remote_users_device_caches_as_stale_txn, ) async def mark_remote_user_device_cache_as_valid(self, user_id: str) -> None: diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py index f2d436ddc3..0c725eb967 100644 --- a/synapse/types/__init__.py +++ b/synapse/types/__init__.py @@ -77,6 +77,10 @@ JsonMapping = Mapping[str, Any] # A JSON-serialisable object. JsonSerializable = object +# Collection[str] that does not include str itself; str being a Sequence[str] +# is very misleading and results in bugs. +StrCollection = Union[Tuple[str, ...], List[str], Set[str]] + # Note that this seems to require inheriting *directly* from Interface in order # for mypy-zope to realize it is an interface. diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index d24c4f68c4..01e3cd46f6 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -205,7 +205,10 @@ T = TypeVar("T") async def concurrently_execute( - func: Callable[[T], Any], args: Iterable[T], limit: int + func: Callable[[T], Any], + args: Iterable[T], + limit: int, + delay_cancellation: bool = False, ) -> None: """Executes the function with each argument concurrently while limiting the number of concurrent executions. @@ -215,6 +218,8 @@ async def concurrently_execute( args: List of arguments to pass to func, each invocation of func gets a single argument. limit: Maximum number of conccurent executions. + delay_cancellation: Whether to delay cancellation until after the invocations + have finished. Returns: None, when all function invocations have finished. The return values @@ -233,9 +238,16 @@ async def concurrently_execute( # We use `itertools.islice` to handle the case where the number of args is # less than the limit, avoiding needlessly spawning unnecessary background # tasks. - await yieldable_gather_results( - _concurrently_execute_inner, (value for value in itertools.islice(it, limit)) - ) + if delay_cancellation: + await yieldable_gather_results_delaying_cancellation( + _concurrently_execute_inner, + (value for value in itertools.islice(it, limit)), + ) + else: + await yieldable_gather_results( + _concurrently_execute_inner, + (value for value in itertools.islice(it, limit)), + ) P = ParamSpec("P") @@ -292,6 +304,41 @@ async def yieldable_gather_results( raise dfe.subFailure.value from None +async def yieldable_gather_results_delaying_cancellation( + func: Callable[Concatenate[T, P], Awaitable[R]], + iter: Iterable[T], + *args: P.args, + **kwargs: P.kwargs, +) -> List[R]: + """Executes the function with each argument concurrently. + Cancellation is delayed until after all the results have been gathered. + + See `yieldable_gather_results`. + + Args: + func: Function to execute that returns a Deferred + iter: An iterable that yields items that get passed as the first + argument to the function + *args: Arguments to be passed to each call to func + **kwargs: Keyword arguments to be passed to each call to func + + Returns + A list containing the results of the function + """ + try: + return await make_deferred_yieldable( + delay_cancellation( + defer.gatherResults( + [run_in_background(func, item, *args, **kwargs) for item in iter], # type: ignore[arg-type] + consumeErrors=True, + ) + ) + ) + except defer.FirstError as dfe: + assert isinstance(dfe.subFailure.value, BaseException) + raise dfe.subFailure.value from None + + T1 = TypeVar("T1") T2 = TypeVar("T2") T3 = TypeVar("T3") -- cgit 1.5.1