From 1d8b80b3346b31a297668e093fb813d9ce7a1b48 Mon Sep 17 00:00:00 2001 From: reivilibre Date: Fri, 26 Nov 2021 14:27:14 +0000 Subject: Support expiry of refresh tokens and expiry of the overall session when refresh tokens are in use. (#11425) --- synapse/handlers/auth.py | 90 ++++++++++++++++++++++++++++++++++++++------ synapse/handlers/register.py | 44 ++++++++++++++++++---- 2 files changed, 115 insertions(+), 19 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 4b66a9862f..4d9c4e5834 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -18,6 +18,7 @@ import time import unicodedata import urllib.parse from binascii import crc32 +from http import HTTPStatus from typing import ( TYPE_CHECKING, Any, @@ -756,53 +757,109 @@ class AuthHandler: async def refresh_token( self, refresh_token: str, - valid_until_ms: Optional[int], - ) -> Tuple[str, str]: + access_token_valid_until_ms: Optional[int], + refresh_token_valid_until_ms: Optional[int], + ) -> Tuple[str, str, Optional[int]]: """ Consumes a refresh token and generate both a new access token and a new refresh token from it. The consumed refresh token is considered invalid after the first use of the new access token or the new refresh token. + The lifetime of both the access token and refresh token will be capped so that they + do not exceed the session's ultimate expiry time, if applicable. + Args: refresh_token: The token to consume. - valid_until_ms: The expiration timestamp of the new access token. - + access_token_valid_until_ms: The expiration timestamp of the new access token. + None if the access token does not expire. + refresh_token_valid_until_ms: The expiration timestamp of the new refresh token. + None if the refresh token does not expire. Returns: - A tuple containing the new access token and refresh token + A tuple containing: + - the new access token + - the new refresh token + - the actual expiry time of the access token, which may be earlier than + `access_token_valid_until_ms`. """ # Verify the token signature first before looking up the token if not self._verify_refresh_token(refresh_token): - raise SynapseError(401, "invalid refresh token", Codes.UNKNOWN_TOKEN) + raise SynapseError( + HTTPStatus.UNAUTHORIZED, "invalid refresh token", Codes.UNKNOWN_TOKEN + ) existing_token = await self.store.lookup_refresh_token(refresh_token) if existing_token is None: - raise SynapseError(401, "refresh token does not exist", Codes.UNKNOWN_TOKEN) + raise SynapseError( + HTTPStatus.UNAUTHORIZED, + "refresh token does not exist", + Codes.UNKNOWN_TOKEN, + ) if ( existing_token.has_next_access_token_been_used or existing_token.has_next_refresh_token_been_refreshed ): raise SynapseError( - 403, "refresh token isn't valid anymore", Codes.FORBIDDEN + HTTPStatus.FORBIDDEN, + "refresh token isn't valid anymore", + Codes.FORBIDDEN, + ) + + now_ms = self._clock.time_msec() + + if existing_token.expiry_ts is not None and existing_token.expiry_ts < now_ms: + + raise SynapseError( + HTTPStatus.FORBIDDEN, + "The supplied refresh token has expired", + Codes.FORBIDDEN, ) + if existing_token.ultimate_session_expiry_ts is not None: + # This session has a bounded lifetime, even across refreshes. 
+ + if access_token_valid_until_ms is not None: + access_token_valid_until_ms = min( + access_token_valid_until_ms, + existing_token.ultimate_session_expiry_ts, + ) + else: + access_token_valid_until_ms = existing_token.ultimate_session_expiry_ts + + if refresh_token_valid_until_ms is not None: + refresh_token_valid_until_ms = min( + refresh_token_valid_until_ms, + existing_token.ultimate_session_expiry_ts, + ) + else: + refresh_token_valid_until_ms = existing_token.ultimate_session_expiry_ts + if existing_token.ultimate_session_expiry_ts < now_ms: + raise SynapseError( + HTTPStatus.FORBIDDEN, + "The session has expired and can no longer be refreshed", + Codes.FORBIDDEN, + ) + ( new_refresh_token, new_refresh_token_id, ) = await self.create_refresh_token_for_user_id( - user_id=existing_token.user_id, device_id=existing_token.device_id + user_id=existing_token.user_id, + device_id=existing_token.device_id, + expiry_ts=refresh_token_valid_until_ms, + ultimate_session_expiry_ts=existing_token.ultimate_session_expiry_ts, ) access_token = await self.create_access_token_for_user_id( user_id=existing_token.user_id, device_id=existing_token.device_id, - valid_until_ms=valid_until_ms, + valid_until_ms=access_token_valid_until_ms, refresh_token_id=new_refresh_token_id, ) await self.store.replace_refresh_token( existing_token.token_id, new_refresh_token_id ) - return access_token, new_refresh_token + return access_token, new_refresh_token, access_token_valid_until_ms def _verify_refresh_token(self, token: str) -> bool: """ @@ -836,6 +893,8 @@ class AuthHandler: self, user_id: str, device_id: str, + expiry_ts: Optional[int], + ultimate_session_expiry_ts: Optional[int], ) -> Tuple[str, int]: """ Creates a new refresh token for the user with the given user ID. @@ -843,6 +902,13 @@ class AuthHandler: Args: user_id: canonical user ID device_id: the device ID to associate with the token. + expiry_ts (milliseconds since the epoch): Time after which the + refresh token cannot be used. + If None, the refresh token never expires until it has been used. + ultimate_session_expiry_ts (milliseconds since the epoch): + Time at which the session will end and can not be extended any + further. + If None, the session can be refreshed indefinitely. Returns: The newly created refresh token and its ID in the database @@ -852,6 +918,8 @@ class AuthHandler: user_id=user_id, token=refresh_token, device_id=device_id, + expiry_ts=expiry_ts, + ultimate_session_expiry_ts=ultimate_session_expiry_ts, ) return refresh_token, refresh_token_id diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 448a36108e..8136ae264d 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -119,6 +119,7 @@ class RegistrationHandler: self.refreshable_access_token_lifetime = ( hs.config.registration.refreshable_access_token_lifetime ) + self.refresh_token_lifetime = hs.config.registration.refresh_token_lifetime init_counters_for_auth_provider("") @@ -793,13 +794,13 @@ class RegistrationHandler: class and RegisterDeviceReplicationServlet. 
""" assert not self.hs.config.worker.worker_app - valid_until_ms = None + access_token_expiry = None if self.session_lifetime is not None: if is_guest: raise Exception( "session_lifetime is not currently implemented for guest access" ) - valid_until_ms = self.clock.time_msec() + self.session_lifetime + access_token_expiry = self.clock.time_msec() + self.session_lifetime refresh_token = None refresh_token_id = None @@ -808,25 +809,52 @@ class RegistrationHandler: user_id, device_id, initial_display_name ) if is_guest: - assert valid_until_ms is None + assert access_token_expiry is None access_token = self.macaroon_gen.generate_guest_access_token(user_id) else: if should_issue_refresh_token: + now_ms = self.clock.time_msec() + + # Set the expiry time of the refreshable access token + access_token_expiry = now_ms + self.refreshable_access_token_lifetime + + # Set the refresh token expiry time (if configured) + refresh_token_expiry = None + if self.refresh_token_lifetime is not None: + refresh_token_expiry = now_ms + self.refresh_token_lifetime + + # Set an ultimate session expiry time (if configured) + ultimate_session_expiry_ts = None + if self.session_lifetime is not None: + ultimate_session_expiry_ts = now_ms + self.session_lifetime + + # Also ensure that the issued tokens don't outlive the + # session. + # (It would be weird to configure a homeserver with a shorter + # session lifetime than token lifetime, but may as well handle + # it.) + access_token_expiry = min( + access_token_expiry, ultimate_session_expiry_ts + ) + if refresh_token_expiry is not None: + refresh_token_expiry = min( + refresh_token_expiry, ultimate_session_expiry_ts + ) + ( refresh_token, refresh_token_id, ) = await self._auth_handler.create_refresh_token_for_user_id( user_id, device_id=registered_device_id, - ) - valid_until_ms = ( - self.clock.time_msec() + self.refreshable_access_token_lifetime + expiry_ts=refresh_token_expiry, + ultimate_session_expiry_ts=ultimate_session_expiry_ts, ) access_token = await self._auth_handler.create_access_token_for_user_id( user_id, device_id=registered_device_id, - valid_until_ms=valid_until_ms, + valid_until_ms=access_token_expiry, is_appservice_ghost=is_appservice_ghost, refresh_token_id=refresh_token_id, ) @@ -834,7 +862,7 @@ class RegistrationHandler: return { "device_id": registered_device_id, "access_token": access_token, - "valid_until_ms": valid_until_ms, + "valid_until_ms": access_token_expiry, "refresh_token": refresh_token, } -- cgit 1.5.1 From a82b90ab32629108fc69439a8cd38d025a406019 Mon Sep 17 00:00:00 2001 From: reivilibre Date: Mon, 29 Nov 2021 13:34:14 +0000 Subject: Add type annotations to some of the configuration surrounding refresh tokens. (#11428) --- changelog.d/11428.misc | 1 + synapse/config/registration.py | 7 +++++-- synapse/handlers/register.py | 5 +++++ 3 files changed, 11 insertions(+), 2 deletions(-) create mode 100644 changelog.d/11428.misc (limited to 'synapse/handlers') diff --git a/changelog.d/11428.misc b/changelog.d/11428.misc new file mode 100644 index 0000000000..2f814fa5fb --- /dev/null +++ b/changelog.d/11428.misc @@ -0,0 +1 @@ +Add type annotations to some of the configuration surrounding refresh tokens. \ No newline at end of file diff --git a/synapse/config/registration.py b/synapse/config/registration.py index 5e21548060..1ddad7cb70 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional from synapse.api.constants import RoomCreationPreset from synapse.config._base import Config, ConfigError @@ -123,12 +124,14 @@ class RegistrationConfig(Config): refreshable_access_token_lifetime = self.parse_duration( refreshable_access_token_lifetime ) - self.refreshable_access_token_lifetime = refreshable_access_token_lifetime + self.refreshable_access_token_lifetime: Optional[ + int + ] = refreshable_access_token_lifetime refresh_token_lifetime = config.get("refresh_token_lifetime") if refresh_token_lifetime is not None: refresh_token_lifetime = self.parse_duration(refresh_token_lifetime) - self.refresh_token_lifetime = refresh_token_lifetime + self.refresh_token_lifetime: Optional[int] = refresh_token_lifetime # The fallback template used for authenticating using a registration token self.registration_token_template = self.read_template("registration_token.html") diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 8136ae264d..24ca11b924 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -813,6 +813,11 @@ class RegistrationHandler: access_token = self.macaroon_gen.generate_guest_access_token(user_id) else: if should_issue_refresh_token: + # A refreshable access token lifetime must be configured + # since we're told to issue a refresh token (the caller checks + # that this value is set before setting this flag). + assert self.refreshable_access_token_lifetime is not None + now_ms = self.clock.time_msec() # Set the expiry time of the refreshable access token -- cgit 1.5.1 From a4521ce0a8d252e77ca8bd261ecf40ba67511a31 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 29 Nov 2021 14:32:20 -0500 Subject: Support the stable /hierarchy endpoint from MSC2946 (#11329) This also makes additional updates where the implementation had drifted from the approved MSC. Unstable endpoints will be removed at a later data. --- .github/workflows/tests.yml | 2 +- changelog.d/11329.feature | 1 + docs/workers.md | 4 +- scripts-dev/complement.sh | 2 +- synapse/app/homeserver.py | 1 + synapse/federation/federation_client.py | 31 ++++++-- synapse/federation/transport/client.py | 22 +++++- synapse/federation/transport/server/federation.py | 6 +- synapse/handlers/room_summary.py | 14 +++- synapse/rest/client/room.py | 8 +- tests/handlers/test_room_summary.py | 94 ++++++++++++++++------- 11 files changed, 134 insertions(+), 51 deletions(-) create mode 100644 changelog.d/11329.feature (limited to 'synapse/handlers') diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 3bee4ee9f3..21c9ee7823 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -374,7 +374,7 @@ jobs: working-directory: complement/dockerfiles # Run Complement - - run: go test -v -tags synapse_blacklist,msc2403,msc2946 ./tests/... + - run: go test -v -tags synapse_blacklist,msc2403 ./tests/... env: COMPLEMENT_BASE_IMAGE: complement-synapse:latest working-directory: complement diff --git a/changelog.d/11329.feature b/changelog.d/11329.feature new file mode 100644 index 0000000000..7e0efb3b00 --- /dev/null +++ b/changelog.d/11329.feature @@ -0,0 +1 @@ +Support the stable API endpoints for [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946): the room `/hierarchy` endpoint. 
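# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the patch above): how a client could call
# the stable hierarchy endpoint that this commit enables. The v1 path and the
# `suggested_only`/`max_depth` parameters come from the servlet changes below;
# the homeserver URL, access token, room ID and the use of `requests` are
# assumptions made only for this example.
from urllib.parse import quote

import requests

HOMESERVER = "https://matrix.example.org"  # assumption: your homeserver base URL
ACCESS_TOKEN = "syt_..."                   # assumption: a valid access token
SPACE_ID = "!space:example.org"            # assumption: a space the user may view

resp = requests.get(
    f"{HOMESERVER}/_matrix/client/v1/rooms/{quote(SPACE_ID)}/hierarchy",
    headers={"Authorization": f"Bearer {ACCESS_TOKEN}"},
    params={"suggested_only": "false", "max_depth": 3},
)
resp.raise_for_status()
page = resp.json()
for room in page.get("rooms", []):
    print(room.get("room_id"), room.get("name"))
# If the response includes a "next_batch" token, MSC2946 says to pass it back
# as the `from` query parameter to fetch the next page.
# ---------------------------------------------------------------------------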
diff --git a/docs/workers.md b/docs/workers.md index 17c8bfeef6..fd83e2ddeb 100644 --- a/docs/workers.md +++ b/docs/workers.md @@ -210,7 +210,7 @@ expressions: ^/_matrix/federation/v1/get_groups_publicised$ ^/_matrix/key/v2/query ^/_matrix/federation/unstable/org.matrix.msc2946/spaces/ - ^/_matrix/federation/unstable/org.matrix.msc2946/hierarchy/ + ^/_matrix/federation/(v1|unstable/org.matrix.msc2946)/hierarchy/ # Inbound federation transaction request ^/_matrix/federation/v1/send/ @@ -223,7 +223,7 @@ expressions: ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/members$ ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/state$ ^/_matrix/client/unstable/org.matrix.msc2946/rooms/.*/spaces$ - ^/_matrix/client/unstable/org.matrix.msc2946/rooms/.*/hierarchy$ + ^/_matrix/client/(v1|unstable/org.matrix.msc2946)/rooms/.*/hierarchy$ ^/_matrix/client/unstable/im.nheko.summary/rooms/.*/summary$ ^/_matrix/client/(api/v1|r0|v3|unstable)/account/3pid$ ^/_matrix/client/(api/v1|r0|v3|unstable)/devices$ diff --git a/scripts-dev/complement.sh b/scripts-dev/complement.sh index 972244f4c9..53295b58fc 100755 --- a/scripts-dev/complement.sh +++ b/scripts-dev/complement.sh @@ -65,4 +65,4 @@ if [[ -n "$1" ]]; then fi # Run the tests! -go test -v -tags synapse_blacklist,msc2946,msc2403 -count=1 $EXTRA_COMPLEMENT_ARGS ./tests/... +go test -v -tags synapse_blacklist,msc2403 -count=1 $EXTRA_COMPLEMENT_ARGS ./tests/... diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 7e09530ad2..52541faab2 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -194,6 +194,7 @@ class SynapseHomeServer(HomeServer): { "/_matrix/client/api/v1": client_resource, "/_matrix/client/r0": client_resource, + "/_matrix/client/v1": client_resource, "/_matrix/client/v3": client_resource, "/_matrix/client/unstable": client_resource, "/_matrix/client/v2_alpha": client_resource, diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 3b85b135e0..bc3f96c1fc 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -1395,11 +1395,28 @@ class FederationClient(FederationBase): async def send_request( destination: str, ) -> Tuple[JsonDict, Sequence[JsonDict], Sequence[str]]: - res = await self.transport_layer.get_room_hierarchy( - destination=destination, - room_id=room_id, - suggested_only=suggested_only, - ) + try: + res = await self.transport_layer.get_room_hierarchy( + destination=destination, + room_id=room_id, + suggested_only=suggested_only, + ) + except HttpResponseException as e: + # If an error is received that is due to an unrecognised endpoint, + # fallback to the unstable endpoint. Otherwise consider it a + # legitmate error and raise. + if not self._is_unknown_endpoint(e): + raise + + logger.debug( + "Couldn't fetch room hierarchy with the v1 API, falling back to the unstable API" + ) + + res = await self.transport_layer.get_room_hierarchy_unstable( + destination=destination, + room_id=room_id, + suggested_only=suggested_only, + ) room = res.get("room") if not isinstance(room, dict): @@ -1449,6 +1466,10 @@ class FederationClient(FederationBase): if e.code != 502: raise + logger.debug( + "Couldn't fetch room hierarchy, falling back to the spaces API" + ) + # Fallback to the old federation API and translate the results if # no servers implement the new API. 
# diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 0fea221165..fe29bcfd4b 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -1192,10 +1192,24 @@ class TransportLayerClient: ) async def get_room_hierarchy( - self, - destination: str, - room_id: str, - suggested_only: bool, + self, destination: str, room_id: str, suggested_only: bool + ) -> JsonDict: + """ + Args: + destination: The remote server + room_id: The room ID to ask about. + suggested_only: if True, only suggested rooms will be returned + """ + path = _create_v1_path("/hierarchy/%s", room_id) + + return await self.client.get_json( + destination=destination, + path=path, + args={"suggested_only": "true" if suggested_only else "false"}, + ) + + async def get_room_hierarchy_unstable( + self, destination: str, room_id: str, suggested_only: bool ) -> JsonDict: """ Args: diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py index 2fdf6cc99e..66e915228c 100644 --- a/synapse/federation/transport/server/federation.py +++ b/synapse/federation/transport/server/federation.py @@ -611,7 +611,6 @@ class FederationSpaceSummaryServlet(BaseFederationServlet): class FederationRoomHierarchyServlet(BaseFederationServlet): - PREFIX = FEDERATION_UNSTABLE_PREFIX + "/org.matrix.msc2946" PATH = "/hierarchy/(?P[^/]*)" def __init__( @@ -637,6 +636,10 @@ class FederationRoomHierarchyServlet(BaseFederationServlet): ) +class FederationRoomHierarchyUnstableServlet(FederationRoomHierarchyServlet): + PREFIX = FEDERATION_UNSTABLE_PREFIX + "/org.matrix.msc2946" + + class RoomComplexityServlet(BaseFederationServlet): """ Indicates to other servers how complex (and therefore likely @@ -701,6 +704,7 @@ FEDERATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = ( RoomComplexityServlet, FederationSpaceSummaryServlet, FederationRoomHierarchyServlet, + FederationRoomHierarchyUnstableServlet, FederationV1SendKnockServlet, FederationMakeKnockServlet, ) diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py index 8181cc0b52..b2cfe537df 100644 --- a/synapse/handlers/room_summary.py +++ b/synapse/handlers/room_summary.py @@ -36,8 +36,9 @@ from synapse.api.errors import ( SynapseError, UnsupportedRoomVersionError, ) +from synapse.api.ratelimiting import Ratelimiter from synapse.events import EventBase -from synapse.types import JsonDict +from synapse.types import JsonDict, Requester from synapse.util.caches.response_cache import ResponseCache if TYPE_CHECKING: @@ -93,6 +94,9 @@ class RoomSummaryHandler: self._event_serializer = hs.get_event_client_serializer() self._server_name = hs.hostname self._federation_client = hs.get_federation_client() + self._ratelimiter = Ratelimiter( + store=self._store, clock=hs.get_clock(), rate_hz=5, burst_count=10 + ) # If a user tries to fetch the same page multiple times in quick succession, # only process the first attempt and return its result to subsequent requests. @@ -249,7 +253,7 @@ class RoomSummaryHandler: async def get_room_hierarchy( self, - requester: str, + requester: Requester, requested_room_id: str, suggested_only: bool = False, max_depth: Optional[int] = None, @@ -276,6 +280,8 @@ class RoomSummaryHandler: Returns: The JSON hierarchy dictionary. 
""" + await self._ratelimiter.ratelimit(requester) + # If a user tries to fetch the same page multiple times in quick succession, # only process the first attempt and return its result to subsequent requests. # @@ -283,7 +289,7 @@ class RoomSummaryHandler: # to process multiple requests for the same page will result in errors. return await self._pagination_response_cache.wrap( ( - requester, + requester.user.to_string(), requested_room_id, suggested_only, max_depth, @@ -291,7 +297,7 @@ class RoomSummaryHandler: from_token, ), self._get_room_hierarchy, - requester, + requester.user.to_string(), requested_room_id, suggested_only, max_depth, diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 955d4e8641..73d0f7c950 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -1138,12 +1138,12 @@ class RoomSpaceSummaryRestServlet(RestServlet): class RoomHierarchyRestServlet(RestServlet): - PATTERNS = ( + PATTERNS = [ re.compile( - "^/_matrix/client/unstable/org.matrix.msc2946" + "^/_matrix/client/(v1|unstable/org.matrix.msc2946)" "/rooms/(?P[^/]*)/hierarchy$" ), - ) + ] def __init__(self, hs: "HomeServer"): super().__init__() @@ -1168,7 +1168,7 @@ class RoomHierarchyRestServlet(RestServlet): ) return 200, await self._room_summary_handler.get_room_hierarchy( - requester.user.to_string(), + requester, room_id, suggested_only=parse_boolean(request, "suggested_only", default=False), max_depth=max_depth, diff --git a/tests/handlers/test_room_summary.py b/tests/handlers/test_room_summary.py index 7b95844b55..e5a6a6c747 100644 --- a/tests/handlers/test_room_summary.py +++ b/tests/handlers/test_room_summary.py @@ -32,7 +32,7 @@ from synapse.handlers.room_summary import _child_events_comparison_key, _RoomEnt from synapse.rest import admin from synapse.rest.client import login, room from synapse.server import HomeServer -from synapse.types import JsonDict, UserID +from synapse.types import JsonDict, UserID, create_requester from tests import unittest @@ -249,7 +249,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): self._assert_rooms(result, expected) result = self.get_success( - self.handler.get_room_hierarchy(self.user, self.space) + self.handler.get_room_hierarchy(create_requester(self.user), self.space) ) self._assert_hierarchy(result, expected) @@ -263,7 +263,9 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): expected = [(self.space, [self.room]), (self.room, ())] self._assert_rooms(result, expected) - result = self.get_success(self.handler.get_room_hierarchy(user2, self.space)) + result = self.get_success( + self.handler.get_room_hierarchy(create_requester(user2), self.space) + ) self._assert_hierarchy(result, expected) # If the space is made invite-only, it should no longer be viewable. @@ -274,7 +276,10 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): tok=self.token, ) self.get_failure(self.handler.get_space_summary(user2, self.space), AuthError) - self.get_failure(self.handler.get_room_hierarchy(user2, self.space), AuthError) + self.get_failure( + self.handler.get_room_hierarchy(create_requester(user2), self.space), + AuthError, + ) # If the space is made world-readable it should return a result. 
self.helper.send_state( @@ -286,7 +291,9 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): result = self.get_success(self.handler.get_space_summary(user2, self.space)) self._assert_rooms(result, expected) - result = self.get_success(self.handler.get_room_hierarchy(user2, self.space)) + result = self.get_success( + self.handler.get_room_hierarchy(create_requester(user2), self.space) + ) self._assert_hierarchy(result, expected) # Make it not world-readable again and confirm it results in an error. @@ -297,7 +304,10 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): tok=self.token, ) self.get_failure(self.handler.get_space_summary(user2, self.space), AuthError) - self.get_failure(self.handler.get_room_hierarchy(user2, self.space), AuthError) + self.get_failure( + self.handler.get_room_hierarchy(create_requester(user2), self.space), + AuthError, + ) # Join the space and results should be returned. self.helper.invite(self.space, targ=user2, tok=self.token) @@ -305,7 +315,9 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): result = self.get_success(self.handler.get_space_summary(user2, self.space)) self._assert_rooms(result, expected) - result = self.get_success(self.handler.get_room_hierarchy(user2, self.space)) + result = self.get_success( + self.handler.get_room_hierarchy(create_requester(user2), self.space) + ) self._assert_hierarchy(result, expected) # Attempting to view an unknown room returns the same error. @@ -314,7 +326,9 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): AuthError, ) self.get_failure( - self.handler.get_room_hierarchy(user2, "#not-a-space:" + self.hs.hostname), + self.handler.get_room_hierarchy( + create_requester(user2), "#not-a-space:" + self.hs.hostname + ), AuthError, ) @@ -322,10 +336,10 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): """In-flight room hierarchy requests are deduplicated.""" # Run two `get_room_hierarchy` calls up until they block. deferred1 = ensureDeferred( - self.handler.get_room_hierarchy(self.user, self.space) + self.handler.get_room_hierarchy(create_requester(self.user), self.space) ) deferred2 = ensureDeferred( - self.handler.get_room_hierarchy(self.user, self.space) + self.handler.get_room_hierarchy(create_requester(self.user), self.space) ) # Complete the two calls. @@ -340,7 +354,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): # A subsequent `get_room_hierarchy` call should not reuse the result. result3 = self.get_success( - self.handler.get_room_hierarchy(self.user, self.space) + self.handler.get_room_hierarchy(create_requester(self.user), self.space) ) self._assert_hierarchy(result3, expected) self.assertIsNot(result1, result3) @@ -359,9 +373,11 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): # Run two `get_room_hierarchy` calls for different users up until they block. deferred1 = ensureDeferred( - self.handler.get_room_hierarchy(self.user, self.space) + self.handler.get_room_hierarchy(create_requester(self.user), self.space) + ) + deferred2 = ensureDeferred( + self.handler.get_room_hierarchy(create_requester(user2), self.space) ) - deferred2 = ensureDeferred(self.handler.get_room_hierarchy(user2, self.space)) # Complete the two calls. 
result1 = self.get_success(deferred1) @@ -465,7 +481,9 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): ] self._assert_rooms(result, expected) - result = self.get_success(self.handler.get_room_hierarchy(user2, self.space)) + result = self.get_success( + self.handler.get_room_hierarchy(create_requester(user2), self.space) + ) self._assert_hierarchy(result, expected) def test_complex_space(self): @@ -507,7 +525,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): self._assert_rooms(result, expected) result = self.get_success( - self.handler.get_room_hierarchy(self.user, self.space) + self.handler.get_room_hierarchy(create_requester(self.user), self.space) ) self._assert_hierarchy(result, expected) @@ -522,7 +540,9 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): room_ids.append(self.room) result = self.get_success( - self.handler.get_room_hierarchy(self.user, self.space, limit=7) + self.handler.get_room_hierarchy( + create_requester(self.user), self.space, limit=7 + ) ) # The result should have the space and all of the links, plus some of the # rooms and a pagination token. @@ -534,7 +554,10 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): # Check the next page. result = self.get_success( self.handler.get_room_hierarchy( - self.user, self.space, limit=5, from_token=result["next_batch"] + create_requester(self.user), + self.space, + limit=5, + from_token=result["next_batch"], ) ) # The result should have the space and the room in it, along with a link @@ -554,20 +577,22 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): room_ids.append(self.room) result = self.get_success( - self.handler.get_room_hierarchy(self.user, self.space, limit=7) + self.handler.get_room_hierarchy( + create_requester(self.user), self.space, limit=7 + ) ) self.assertIn("next_batch", result) # Changing the room ID, suggested-only, or max-depth causes an error. self.get_failure( self.handler.get_room_hierarchy( - self.user, self.room, from_token=result["next_batch"] + create_requester(self.user), self.room, from_token=result["next_batch"] ), SynapseError, ) self.get_failure( self.handler.get_room_hierarchy( - self.user, + create_requester(self.user), self.space, suggested_only=True, from_token=result["next_batch"], @@ -576,14 +601,19 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): ) self.get_failure( self.handler.get_room_hierarchy( - self.user, self.space, max_depth=0, from_token=result["next_batch"] + create_requester(self.user), + self.space, + max_depth=0, + from_token=result["next_batch"], ), SynapseError, ) # An invalid token is ignored. self.get_failure( - self.handler.get_room_hierarchy(self.user, self.space, from_token="foo"), + self.handler.get_room_hierarchy( + create_requester(self.user), self.space, from_token="foo" + ), SynapseError, ) @@ -609,14 +639,18 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): # Test just the space itself. result = self.get_success( - self.handler.get_room_hierarchy(self.user, self.space, max_depth=0) + self.handler.get_room_hierarchy( + create_requester(self.user), self.space, max_depth=0 + ) ) expected: List[Tuple[str, Iterable[str]]] = [(spaces[0], [rooms[0], spaces[1]])] self._assert_hierarchy(result, expected) # A single additional layer. 
result = self.get_success( - self.handler.get_room_hierarchy(self.user, self.space, max_depth=1) + self.handler.get_room_hierarchy( + create_requester(self.user), self.space, max_depth=1 + ) ) expected += [ (rooms[0], ()), @@ -626,7 +660,9 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): # A few layers. result = self.get_success( - self.handler.get_room_hierarchy(self.user, self.space, max_depth=3) + self.handler.get_room_hierarchy( + create_requester(self.user), self.space, max_depth=3 + ) ) expected += [ (rooms[1], ()), @@ -657,7 +693,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): self._assert_rooms(result, expected) result = self.get_success( - self.handler.get_room_hierarchy(self.user, self.space) + self.handler.get_room_hierarchy(create_requester(self.user), self.space) ) self._assert_hierarchy(result, expected) @@ -739,7 +775,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): new=summarize_remote_room_hierarchy, ): result = self.get_success( - self.handler.get_room_hierarchy(self.user, self.space) + self.handler.get_room_hierarchy(create_requester(self.user), self.space) ) self._assert_hierarchy(result, expected) @@ -906,7 +942,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): new=summarize_remote_room_hierarchy, ): result = self.get_success( - self.handler.get_room_hierarchy(self.user, self.space) + self.handler.get_room_hierarchy(create_requester(self.user), self.space) ) self._assert_hierarchy(result, expected) @@ -964,7 +1000,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): new=summarize_remote_room_hierarchy, ): result = self.get_success( - self.handler.get_room_hierarchy(self.user, self.space) + self.handler.get_room_hierarchy(create_requester(self.user), self.space) ) self._assert_hierarchy(result, expected) -- cgit 1.5.1 From a6f1a3abecf8e8fd3e1bff439a06b853df18f194 Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Thu, 2 Dec 2021 01:02:20 -0600 Subject: Add MSC3030 experimental client and federation API endpoints to get the closest event to a given timestamp (#9445) MSC3030: https://github.com/matrix-org/matrix-doc/pull/3030 Client API endpoint. This will also go and fetch from the federation API endpoint if unable to find an event locally or we found an extremity with possibly a closer event we don't know about. ``` GET /_matrix/client/unstable/org.matrix.msc3030/rooms//timestamp_to_event?ts=&dir= { "event_id": ... "origin_server_ts": ... } ``` Federation API endpoint: ``` GET /_matrix/federation/unstable/org.matrix.msc3030/timestamp_to_event/?ts=&dir= { "event_id": ... "origin_server_ts": ... 
} ``` Co-authored-by: Erik Johnston --- changelog.d/9445.feature | 1 + synapse/config/experimental.py | 3 + synapse/federation/federation_client.py | 77 +++++++++ synapse/federation/federation_server.py | 43 +++++ synapse/federation/transport/client.py | 36 ++++ synapse/federation/transport/server/__init__.py | 12 +- synapse/federation/transport/server/federation.py | 41 +++++ synapse/handlers/federation.py | 61 +++---- synapse/handlers/room.py | 144 ++++++++++++++++ synapse/http/servlet.py | 29 ++++ synapse/rest/client/room.py | 58 +++++++ synapse/server.py | 5 + synapse/storage/databases/main/events_worker.py | 195 ++++++++++++++++++++++ 13 files changed, 674 insertions(+), 31 deletions(-) create mode 100644 changelog.d/9445.feature (limited to 'synapse/handlers') diff --git a/changelog.d/9445.feature b/changelog.d/9445.feature new file mode 100644 index 0000000000..6d12eea71f --- /dev/null +++ b/changelog.d/9445.feature @@ -0,0 +1 @@ +Add [MSC3030](https://github.com/matrix-org/matrix-doc/pull/3030) experimental client and federation API endpoints to get the closest event to a given timestamp. diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 8b098ad48d..d78a15097c 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -46,3 +46,6 @@ class ExperimentalConfig(Config): # MSC3266 (room summary api) self.msc3266_enabled: bool = experimental.get("msc3266_enabled", False) + + # MSC3030 (Jump to date API endpoint) + self.msc3030_enabled: bool = experimental.get("msc3030_enabled", False) diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index bc3f96c1fc..be1423da24 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -1517,6 +1517,83 @@ class FederationClient(FederationBase): self._get_room_hierarchy_cache[(room_id, suggested_only)] = result return result + async def timestamp_to_event( + self, destination: str, room_id: str, timestamp: int, direction: str + ) -> "TimestampToEventResponse": + """ + Calls a remote federating server at `destination` asking for their + closest event to the given timestamp in the given direction. Also + validates the response to always return the expected keys or raises an + error. + + Args: + destination: Domain name of the remote homeserver + room_id: Room to fetch the event from + timestamp: The point in time (inclusive) we should navigate from in + the given direction to find the closest event. + direction: ["f"|"b"] to indicate whether we should navigate forward + or backward from the given timestamp to find the closest event. 
+ + Returns: + A parsed TimestampToEventResponse including the closest event_id + and origin_server_ts + + Raises: + Various exceptions when the request fails + InvalidResponseError when the response does not have the correct + keys or wrong types + """ + remote_response = await self.transport_layer.timestamp_to_event( + destination, room_id, timestamp, direction + ) + + if not isinstance(remote_response, dict): + raise InvalidResponseError( + "Response must be a JSON dictionary but received %r" % remote_response + ) + + try: + return TimestampToEventResponse.from_json_dict(remote_response) + except ValueError as e: + raise InvalidResponseError(str(e)) + + +@attr.s(frozen=True, slots=True, auto_attribs=True) +class TimestampToEventResponse: + """Typed response dictionary for the federation /timestamp_to_event endpoint""" + + event_id: str + origin_server_ts: int + + # the raw data, including the above keys + data: JsonDict + + @classmethod + def from_json_dict(cls, d: JsonDict) -> "TimestampToEventResponse": + """Parsed response from the federation /timestamp_to_event endpoint + + Args: + d: JSON object response to be parsed + + Raises: + ValueError if d does not the correct keys or they are the wrong types + """ + + event_id = d.get("event_id") + if not isinstance(event_id, str): + raise ValueError( + "Invalid response: 'event_id' must be a str but received %r" % event_id + ) + + origin_server_ts = d.get("origin_server_ts") + if not isinstance(origin_server_ts, int): + raise ValueError( + "Invalid response: 'origin_server_ts' must be a int but received %r" + % origin_server_ts + ) + + return cls(event_id, origin_server_ts, d) + @attr.s(frozen=True, slots=True, auto_attribs=True) class FederationSpaceSummaryEventResult: diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 8fbc75aa65..cce85526e7 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -110,6 +110,7 @@ class FederationServer(FederationBase): super().__init__(hs) self.handler = hs.get_federation_handler() + self.storage = hs.get_storage() self._federation_event_handler = hs.get_federation_event_handler() self.state = hs.get_state_handler() self._event_auth_handler = hs.get_event_auth_handler() @@ -200,6 +201,48 @@ class FederationServer(FederationBase): return 200, res + async def on_timestamp_to_event_request( + self, origin: str, room_id: str, timestamp: int, direction: str + ) -> Tuple[int, Dict[str, Any]]: + """When we receive a federated `/timestamp_to_event` request, + handle all of the logic for validating and fetching the event. + + Args: + origin: The server we received the event from + room_id: Room to fetch the event from + timestamp: The point in time (inclusive) we should navigate from in + the given direction to find the closest event. + direction: ["f"|"b"] to indicate whether we should navigate forward + or backward from the given timestamp to find the closest event. + + Returns: + Tuple indicating the response status code and dictionary response + body including `event_id`. 
+ """ + with (await self._server_linearizer.queue((origin, room_id))): + origin_host, _ = parse_server_name(origin) + await self.check_server_matches_acl(origin_host, room_id) + + # We only try to fetch data from the local database + event_id = await self.store.get_event_id_for_timestamp( + room_id, timestamp, direction + ) + if event_id: + event = await self.store.get_event( + event_id, allow_none=False, allow_rejected=False + ) + + return 200, { + "event_id": event_id, + "origin_server_ts": event.origin_server_ts, + } + + raise SynapseError( + 404, + "Unable to find event from %s in direction %s" % (timestamp, direction), + errcode=Codes.NOT_FOUND, + ) + async def on_incoming_transaction( self, origin: str, diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index fe29bcfd4b..d1f4be641d 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -148,6 +148,42 @@ class TransportLayerClient: destination, path=path, args=args, try_trailing_slash_on_400=True ) + @log_function + async def timestamp_to_event( + self, destination: str, room_id: str, timestamp: int, direction: str + ) -> Union[JsonDict, List]: + """ + Calls a remote federating server at `destination` asking for their + closest event to the given timestamp in the given direction. + + Args: + destination: Domain name of the remote homeserver + room_id: Room to fetch the event from + timestamp: The point in time (inclusive) we should navigate from in + the given direction to find the closest event. + direction: ["f"|"b"] to indicate whether we should navigate forward + or backward from the given timestamp to find the closest event. + + Returns: + Response dict received from the remote homeserver. + + Raises: + Various exceptions when the request fails + """ + path = _create_path( + FEDERATION_UNSTABLE_PREFIX, + "/org.matrix.msc3030/timestamp_to_event/%s", + room_id, + ) + + args = {"ts": [str(timestamp)], "dir": [direction]} + + remote_response = await self.client.get_json( + destination, path=path, args=args, try_trailing_slash_on_400=True + ) + + return remote_response + @log_function async def send_transaction( self, diff --git a/synapse/federation/transport/server/__init__.py b/synapse/federation/transport/server/__init__.py index c32539bf5a..abcb8728f5 100644 --- a/synapse/federation/transport/server/__init__.py +++ b/synapse/federation/transport/server/__init__.py @@ -22,7 +22,10 @@ from synapse.federation.transport.server._base import ( Authenticator, BaseFederationServlet, ) -from synapse.federation.transport.server.federation import FEDERATION_SERVLET_CLASSES +from synapse.federation.transport.server.federation import ( + FEDERATION_SERVLET_CLASSES, + FederationTimestampLookupServlet, +) from synapse.federation.transport.server.groups_local import GROUP_LOCAL_SERVLET_CLASSES from synapse.federation.transport.server.groups_server import ( GROUP_SERVER_SERVLET_CLASSES, @@ -324,6 +327,13 @@ def register_servlets( ) for servletclass in DEFAULT_SERVLET_GROUPS[servlet_group]: + # Only allow the `/timestamp_to_event` servlet if msc3030 is enabled + if ( + servletclass == FederationTimestampLookupServlet + and not hs.config.experimental.msc3030_enabled + ): + continue + servletclass( hs=hs, authenticator=authenticator, diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py index 66e915228c..77bfd88ad0 100644 --- a/synapse/federation/transport/server/federation.py +++ 
b/synapse/federation/transport/server/federation.py @@ -174,6 +174,46 @@ class FederationBackfillServlet(BaseFederationServerServlet): return await self.handler.on_backfill_request(origin, room_id, versions, limit) +class FederationTimestampLookupServlet(BaseFederationServerServlet): + """ + API endpoint to fetch the `event_id` of the closest event to the given + timestamp (`ts` query parameter) in the given direction (`dir` query + parameter). + + Useful for other homeservers when they're unable to find an event locally. + + `ts` is a timestamp in milliseconds where we will find the closest event in + the given direction. + + `dir` can be `f` or `b` to indicate forwards and backwards in time from the + given timestamp. + + GET /_matrix/federation/unstable/org.matrix.msc3030/timestamp_to_event/?ts=&dir= + { + "event_id": ... + } + """ + + PATH = "/timestamp_to_event/(?P[^/]*)/?" + PREFIX = FEDERATION_UNSTABLE_PREFIX + "/org.matrix.msc3030" + + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + room_id: str, + ) -> Tuple[int, JsonDict]: + timestamp = parse_integer_from_args(query, "ts", required=True) + direction = parse_string_from_args( + query, "dir", default="f", allowed_values=["f", "b"], required=True + ) + + return await self.handler.on_timestamp_to_event_request( + origin, room_id, timestamp, direction + ) + + class FederationQueryServlet(BaseFederationServerServlet): PATH = "/query/(?P[^/]*)" @@ -683,6 +723,7 @@ FEDERATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = ( FederationStateV1Servlet, FederationStateIdsServlet, FederationBackfillServlet, + FederationTimestampLookupServlet, FederationQueryServlet, FederationMakeJoinServlet, FederationMakeLeaveServlet, diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 3112cc88b1..1ea837d082 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -68,6 +68,37 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +def get_domains_from_state(state: StateMap[EventBase]) -> List[Tuple[str, int]]: + """Get joined domains from state + + Args: + state: State map from type/state key to event. + + Returns: + Returns a list of servers with the lowest depth of their joins. + Sorted by lowest depth first. + """ + joined_users = [ + (state_key, int(event.depth)) + for (e_type, state_key), event in state.items() + if e_type == EventTypes.Member and event.membership == Membership.JOIN + ] + + joined_domains: Dict[str, int] = {} + for u, d in joined_users: + try: + dom = get_domain_from_id(u) + old_d = joined_domains.get(dom) + if old_d: + joined_domains[dom] = min(d, old_d) + else: + joined_domains[dom] = d + except Exception: + pass + + return sorted(joined_domains.items(), key=lambda d: d[1]) + + class FederationHandler: """Handles general incoming federation requests @@ -268,36 +299,6 @@ class FederationHandler: curr_state = await self.state_handler.get_current_state(room_id) - def get_domains_from_state(state: StateMap[EventBase]) -> List[Tuple[str, int]]: - """Get joined domains from state - - Args: - state: State map from type/state key to event. - - Returns: - Returns a list of servers with the lowest depth of their joins. - Sorted by lowest depth first. 
- """ - joined_users = [ - (state_key, int(event.depth)) - for (e_type, state_key), event in state.items() - if e_type == EventTypes.Member and event.membership == Membership.JOIN - ] - - joined_domains: Dict[str, int] = {} - for u, d in joined_users: - try: - dom = get_domain_from_id(u) - old_d = joined_domains.get(dom) - if old_d: - joined_domains[dom] = min(d, old_d) - else: - joined_domains[dom] = d - except Exception: - pass - - return sorted(joined_domains.items(), key=lambda d: d[1]) - curr_domains = get_domains_from_state(curr_state) likely_domains = [ diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 88053f9869..2bcdf32dcc 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -46,6 +46,7 @@ from synapse.api.constants import ( from synapse.api.errors import ( AuthError, Codes, + HttpResponseException, LimitExceededError, NotFoundError, StoreError, @@ -56,6 +57,8 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion from synapse.event_auth import validate_event_for_room_version from synapse.events import EventBase from synapse.events.utils import copy_power_levels_contents +from synapse.federation.federation_client import InvalidResponseError +from synapse.handlers.federation import get_domains_from_state from synapse.rest.admin._base import assert_user_is_admin from synapse.storage.state import StateFilter from synapse.streams import EventSource @@ -1220,6 +1223,147 @@ class RoomContextHandler: return results +class TimestampLookupHandler: + def __init__(self, hs: "HomeServer"): + self.server_name = hs.hostname + self.store = hs.get_datastore() + self.state_handler = hs.get_state_handler() + self.federation_client = hs.get_federation_client() + + async def get_event_for_timestamp( + self, + requester: Requester, + room_id: str, + timestamp: int, + direction: str, + ) -> Tuple[str, int]: + """Find the closest event to the given timestamp in the given direction. + If we can't find an event locally or the event we have locally is next to a gap, + it will ask other federated homeservers for an event. + + Args: + requester: The user making the request according to the access token + room_id: Room to fetch the event from + timestamp: The point in time (inclusive) we should navigate from in + the given direction to find the closest event. + direction: ["f"|"b"] to indicate whether we should navigate forward + or backward from the given timestamp to find the closest event. + + Returns: + A tuple containing the `event_id` closest to the given timestamp in + the given direction and the `origin_server_ts`. + + Raises: + SynapseError if unable to find any event locally in the given direction + """ + + local_event_id = await self.store.get_event_id_for_timestamp( + room_id, timestamp, direction + ) + logger.debug( + "get_event_for_timestamp: locally, we found event_id=%s closest to timestamp=%s", + local_event_id, + timestamp, + ) + + # Check for gaps in the history where events could be hiding in between + # the timestamp given and the event we were able to find locally + is_event_next_to_backward_gap = False + is_event_next_to_forward_gap = False + if local_event_id: + local_event = await self.store.get_event( + local_event_id, allow_none=False, allow_rejected=False + ) + + if direction == "f": + # We only need to check for a backward gap if we're looking forwards + # to ensure there is nothing in between. 
+ is_event_next_to_backward_gap = ( + await self.store.is_event_next_to_backward_gap(local_event) + ) + elif direction == "b": + # We only need to check for a forward gap if we're looking backwards + # to ensure there is nothing in between + is_event_next_to_forward_gap = ( + await self.store.is_event_next_to_forward_gap(local_event) + ) + + # If we found a gap, we should probably ask another homeserver first + # about more history in between + if ( + not local_event_id + or is_event_next_to_backward_gap + or is_event_next_to_forward_gap + ): + logger.debug( + "get_event_for_timestamp: locally, we found event_id=%s closest to timestamp=%s which is next to a gap in event history so we're asking other homeservers first", + local_event_id, + timestamp, + ) + + # Find other homeservers from the given state in the room + curr_state = await self.state_handler.get_current_state(room_id) + curr_domains = get_domains_from_state(curr_state) + likely_domains = [ + domain for domain, depth in curr_domains if domain != self.server_name + ] + + # Loop through each homeserver candidate until we get a succesful response + for domain in likely_domains: + try: + remote_response = await self.federation_client.timestamp_to_event( + domain, room_id, timestamp, direction + ) + logger.debug( + "get_event_for_timestamp: response from domain(%s)=%s", + domain, + remote_response, + ) + + # TODO: Do we want to persist this as an extremity? + # TODO: I think ideally, we would try to backfill from + # this event and run this whole + # `get_event_for_timestamp` function again to make sure + # they didn't give us an event from their gappy history. + remote_event_id = remote_response.event_id + origin_server_ts = remote_response.origin_server_ts + + # Only return the remote event if it's closer than the local event + if not local_event or ( + abs(origin_server_ts - timestamp) + < abs(local_event.origin_server_ts - timestamp) + ): + return remote_event_id, origin_server_ts + except (HttpResponseException, InvalidResponseError) as ex: + # Let's not put a high priority on some other homeserver + # failing to respond or giving a random response + logger.debug( + "Failed to fetch /timestamp_to_event from %s because of exception(%s) %s args=%s", + domain, + type(ex).__name__, + ex, + ex.args, + ) + except Exception as ex: + # But we do want to see some exceptions in our code + logger.warning( + "Failed to fetch /timestamp_to_event from %s because of exception(%s) %s args=%s", + domain, + type(ex).__name__, + ex, + ex.args, + ) + + if not local_event_id: + raise SynapseError( + 404, + "Unable to find event from %s in direction %s" % (timestamp, direction), + errcode=Codes.NOT_FOUND, + ) + + return local_event_id, local_event.origin_server_ts + + class RoomEventSource(EventSource[RoomStreamToken, EventBase]): def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index 91ba93372c..6dd9b9ad03 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -79,6 +79,35 @@ def parse_integer( return parse_integer_from_args(args, name, default, required) +@overload +def parse_integer_from_args( + args: Mapping[bytes, Sequence[bytes]], + name: str, + default: Optional[int] = None, +) -> Optional[int]: + ... + + +@overload +def parse_integer_from_args( + args: Mapping[bytes, Sequence[bytes]], + name: str, + *, + required: Literal[True], +) -> int: + ... 
+ + +@overload +def parse_integer_from_args( + args: Mapping[bytes, Sequence[bytes]], + name: str, + default: Optional[int] = None, + required: bool = False, +) -> Optional[int]: + ... + + def parse_integer_from_args( args: Mapping[bytes, Sequence[bytes]], name: str, diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 99f303c88e..3598967be0 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -1070,6 +1070,62 @@ def register_txn_path( ) +class TimestampLookupRestServlet(RestServlet): + """ + API endpoint to fetch the `event_id` of the closest event to the given + timestamp (`ts` query parameter) in the given direction (`dir` query + parameter). + + Useful for cases like jump to date so you can start paginating messages from + a given date in the archive. + + `ts` is a timestamp in milliseconds where we will find the closest event in + the given direction. + + `dir` can be `f` or `b` to indicate forwards and backwards in time from the + given timestamp. + + GET /_matrix/client/unstable/org.matrix.msc3030/rooms//timestamp_to_event?ts=&dir= + { + "event_id": ... + } + """ + + PATTERNS = ( + re.compile( + "^/_matrix/client/unstable/org.matrix.msc3030" + "/rooms/(?P[^/]*)/timestamp_to_event$" + ), + ) + + def __init__(self, hs: "HomeServer"): + super().__init__() + self._auth = hs.get_auth() + self._store = hs.get_datastore() + self.timestamp_lookup_handler = hs.get_timestamp_lookup_handler() + + async def on_GET( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: + requester = await self._auth.get_user_by_req(request) + await self._auth.check_user_in_room(room_id, requester.user.to_string()) + + timestamp = parse_integer(request, "ts", required=True) + direction = parse_string(request, "dir", default="f", allowed_values=["f", "b"]) + + ( + event_id, + origin_server_ts, + ) = await self.timestamp_lookup_handler.get_event_for_timestamp( + requester, room_id, timestamp, direction + ) + + return 200, { + "event_id": event_id, + "origin_server_ts": origin_server_ts, + } + + class RoomSpaceSummaryRestServlet(RestServlet): PATTERNS = ( re.compile( @@ -1239,6 +1295,8 @@ def register_servlets( RoomAliasListServlet(hs).register(http_server) SearchRestServlet(hs).register(http_server) RoomCreateRestServlet(hs).register(http_server) + if hs.config.experimental.msc3030_enabled: + TimestampLookupRestServlet(hs).register(http_server) # Some servlets only get registered for the main process. 
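# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the patch above): calling the unstable
# MSC3030 endpoint registered here. The path and the `ts`/`dir` query
# parameters are taken from TimestampLookupRestServlet's docstring; the
# homeserver URL, access token, room ID and the use of `requests` are
# assumptions made only for this example.
import time
from urllib.parse import quote

import requests

HOMESERVER = "https://matrix.example.org"  # assumption: your homeserver base URL
ACCESS_TOKEN = "syt_..."                   # assumption: a valid access token
ROOM_ID = "!archive:example.org"           # assumption: a room the user has joined

ts_ms = int(time.time() * 1000) - 30 * 24 * 60 * 60 * 1000  # ~30 days ago
resp = requests.get(
    f"{HOMESERVER}/_matrix/client/unstable/org.matrix.msc3030"
    f"/rooms/{quote(ROOM_ID)}/timestamp_to_event",
    headers={"Authorization": f"Bearer {ACCESS_TOKEN}"},
    # dir="f" finds the closest event at or after ts; "b" looks backwards.
    params={"ts": ts_ms, "dir": "f"},
)
resp.raise_for_status()
print(resp.json())  # e.g. {"event_id": "$...", "origin_server_ts": 1638400000000}
# ---------------------------------------------------------------------------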
if not is_worker: diff --git a/synapse/server.py b/synapse/server.py index 877eba6c08..185e40e4da 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -97,6 +97,7 @@ from synapse.handlers.room import ( RoomContextHandler, RoomCreationHandler, RoomShutdownHandler, + TimestampLookupHandler, ) from synapse.handlers.room_batch import RoomBatchHandler from synapse.handlers.room_list import RoomListHandler @@ -728,6 +729,10 @@ class HomeServer(metaclass=abc.ABCMeta): def get_room_context_handler(self) -> RoomContextHandler: return RoomContextHandler(self) + @cache_in_self + def get_timestamp_lookup_handler(self) -> TimestampLookupHandler: + return TimestampLookupHandler(self) + @cache_in_self def get_registration_handler(self) -> RegistrationHandler: return RegistrationHandler(self) diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 4cefc0a07e..fd19674f93 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -1762,3 +1762,198 @@ class EventsWorkerStore(SQLBaseStore): "_cleanup_old_transaction_ids", _cleanup_old_transaction_ids_txn, ) + + async def is_event_next_to_backward_gap(self, event: EventBase) -> bool: + """Check if the given event is next to a backward gap of missing events. + A(False)--->B(False)--->C(True)---> + + Args: + room_id: room where the event lives + event_id: event to check + + Returns: + Boolean indicating whether it's an extremity + """ + + def is_event_next_to_backward_gap_txn(txn: LoggingTransaction) -> bool: + # If the event in question has any of its prev_events listed as a + # backward extremity, it's next to a gap. + # + # We can't just check the backward edges in `event_edges` because + # when we persist events, we will also record the prev_events as + # edges to the event in question regardless of whether we have those + # prev_events yet. We need to check whether those prev_events are + # backward extremities, also known as gaps, that need to be + # backfilled. + backward_extremity_query = """ + SELECT 1 FROM event_backward_extremities + WHERE + room_id = ? + AND %s + LIMIT 1 + """ + + # If the event in question is a backward extremity or has any of its + # prev_events listed as a backward extremity, it's next to a + # backward gap. + clause, args = make_in_list_sql_clause( + self.database_engine, + "event_id", + [event.event_id] + list(event.prev_event_ids()), + ) + + txn.execute(backward_extremity_query % (clause,), [event.room_id] + args) + backward_extremities = txn.fetchall() + + # We consider any backward extremity as a backward gap + if len(backward_extremities): + return True + + return False + + return await self.db_pool.runInteraction( + "is_event_next_to_backward_gap_txn", + is_event_next_to_backward_gap_txn, + ) + + async def is_event_next_to_forward_gap(self, event: EventBase) -> bool: + """Check if the given event is next to a forward gap of missing events. + The gap in front of the latest events is not considered a gap. + A(False)--->B(False)--->C(False)---> + A(False)--->B(False)---> --->D(True)--->E(False) + + Args: + room_id: room where the event lives + event_id: event to check + + Returns: + Boolean indicating whether it's an extremity + """ + + def is_event_next_to_gap_txn(txn: LoggingTransaction) -> bool: + # If the event in question is a forward extremity, we will just + # consider any potential forward gap as not a gap since it's one of + # the latest events in the room. 
+ # + # `event_forward_extremities` does not include backfilled or outlier + # events so we can't rely on it to find forward gaps. We can only + # use it to determine whether a message is the latest in the room. + # + # We can't combine this query with the `forward_edge_query` below + # because if the event in question has no forward edges (isn't + # referenced by any other event's prev_events) but is in + # `event_forward_extremities`, we don't want to return 0 rows and + # say it's next to a gap. + forward_extremity_query = """ + SELECT 1 FROM event_forward_extremities + WHERE + room_id = ? + AND event_id = ? + LIMIT 1 + """ + + # Check to see whether the event in question is already referenced + # by another event. If we don't see any edges, we're next to a + # forward gap. + forward_edge_query = """ + SELECT 1 FROM event_edges + /* Check to make sure the event referencing our event in question is not rejected */ + LEFT JOIN rejections ON event_edges.event_id == rejections.event_id + WHERE + event_edges.room_id = ? + AND event_edges.prev_event_id = ? + /* It's not a valid edge if the event referencing our event in + * question is rejected. + */ + AND rejections.event_id IS NULL + LIMIT 1 + """ + + # We consider any forward extremity as the latest in the room and + # not a forward gap. + # + # To expand, even though there is technically a gap at the front of + # the room where the forward extremities are, we consider those the + # latest messages in the room so asking other homeservers for more + # is useless. The new latest messages will just be federated as + # usual. + txn.execute(forward_extremity_query, (event.room_id, event.event_id)) + forward_extremities = txn.fetchall() + if len(forward_extremities): + return False + + # If there are no forward edges to the event in question (another + # event hasn't referenced this event in their prev_events), then we + # assume there is a forward gap in the history. + txn.execute(forward_edge_query, (event.room_id, event.event_id)) + forward_edges = txn.fetchall() + if not len(forward_edges): + return True + + return False + + return await self.db_pool.runInteraction( + "is_event_next_to_gap_txn", + is_event_next_to_gap_txn, + ) + + async def get_event_id_for_timestamp( + self, room_id: str, timestamp: int, direction: str + ) -> Optional[str]: + """Find the closest event to the given timestamp in the given direction. + + Args: + room_id: Room to fetch the event from + timestamp: The point in time (inclusive) we should navigate from in + the given direction to find the closest event. + direction: ["f"|"b"] to indicate whether we should navigate forward + or backward from the given timestamp to find the closest event. + + Returns: + The closest event_id otherwise None if we can't find any event in + the given direction. + """ + + sql_template = """ + SELECT event_id FROM events + LEFT JOIN rejections USING (event_id) + WHERE + origin_server_ts %s ? + AND room_id = ? + /* Make sure event is not rejected */ + AND rejections.event_id IS NULL + ORDER BY origin_server_ts %s + LIMIT 1; + """ + + def get_event_id_for_timestamp_txn(txn: LoggingTransaction) -> Optional[str]: + if direction == "b": + # Find closest event *before* a given timestamp. We use descending + # (which gives values largest to smallest) because we want the + # largest possible timestamp *before* the given timestamp. + comparison_operator = "<=" + order = "DESC" + else: + # Find closest event *after* a given timestamp. 
We use ascending + # (which gives values smallest to largest) because we want the + # closest possible timestamp *after* the given timestamp. + comparison_operator = ">=" + order = "ASC" + + txn.execute( + sql_template % (comparison_operator, order), (timestamp, room_id) + ) + row = txn.fetchone() + if row: + (event_id,) = row + return event_id + + return None + + if direction not in ("f", "b"): + raise ValueError("Unknown direction: %s" % (direction,)) + + return await self.db_pool.runInteraction( + "get_event_id_for_timestamp_txn", + get_event_id_for_timestamp_txn, + ) -- cgit 1.5.1 From d26808dd854006bd26a2366c675428ce0737238c Mon Sep 17 00:00:00 2001 From: David Robertson Date: Thu, 2 Dec 2021 20:58:32 +0000 Subject: Comments on the /sync tentacles (#11494) This mainly consists of docstrings and inline comments. There are one or two type annotations and variable renames thrown in while I was here. Co-authored-by: Patrick Cloke --- changelog.d/11494.misc | 1 + synapse/handlers/sync.py | 156 +++++++++++++++++++++++-------- synapse/storage/databases/main/stream.py | 15 ++- 3 files changed, 129 insertions(+), 43 deletions(-) create mode 100644 changelog.d/11494.misc (limited to 'synapse/handlers') diff --git a/changelog.d/11494.misc b/changelog.d/11494.misc new file mode 100644 index 0000000000..7afd7d3a07 --- /dev/null +++ b/changelog.d/11494.misc @@ -0,0 +1 @@ +Add comments to various parts of the `/sync` handler. \ No newline at end of file diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 891435c14d..53d4627147 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -334,6 +334,19 @@ class SyncHandler: full_state: bool, cache_context: ResponseCacheContext[SyncRequestKey], ) -> SyncResult: + """The start of the machinery that produces a /sync response. + + See https://spec.matrix.org/v1.1/client-server-api/#syncing for full details. + + This method does high-level bookkeeping: + - tracking the kind of sync in the logging context + - deleting any to_device messages whose delivery has been acknowledged. + - deciding if we should dispatch an instant or delayed response + - marking the sync as being lazily loaded, if appropriate + + Computing the body of the response begins in the next method, + `current_sync_for_user`. + """ if since_token is None: sync_type = "initial_sync" elif full_state: @@ -363,7 +376,7 @@ class SyncHandler: sync_config, since_token, full_state=full_state ) else: - + # Otherwise, we wait for something to happen and report it to the user. async def current_sync_callback( before_token: StreamToken, after_token: StreamToken ) -> SyncResult: @@ -402,7 +415,12 @@ class SyncHandler: since_token: Optional[StreamToken] = None, full_state: bool = False, ) -> SyncResult: - """Get the sync for client needed to match what the server has now.""" + """Generates the response body of a sync result, represented as a SyncResult. + + This is a wrapper around `generate_sync_result` which starts an open tracing + span to track the sync. See `generate_sync_result` for the next part of your + indoctrination. + """ with start_active_span("current_sync_for_user"): log_kv({"since_token": since_token}) sync_result = await self.generate_sync_result( @@ -560,7 +578,7 @@ class SyncHandler: # that have happened since `since_key` up to `end_key`, so we # can just use `get_room_events_stream_for_room`. # Otherwise, we want to return the last N events in the room - # in toplogical ordering. + # in topological ordering. 
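Stepping back to the `wait_for_sync_for_user` flow described above: when nothing is pending, the handler parks the request until either new activity arrives or the client's timeout expires, and only then computes the (possibly empty) response. Reduced to its essence, and ignoring the Notifier machinery Synapse actually uses, the pattern is roughly:

    import asyncio

    async def long_poll_sync(compute_response, something_happened: asyncio.Event, timeout_s: float):
        # Return straight away if there is already activity; otherwise wait
        # for activity or for the timeout, whichever comes first.
        if not something_happened.is_set():
            try:
                await asyncio.wait_for(something_happened.wait(), timeout_s)
            except asyncio.TimeoutError:
                pass
        return await compute_response()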
if since_key: events, end_key = await self.store.get_room_events_stream_for_room( room_id, @@ -1042,7 +1060,18 @@ class SyncHandler: since_token: Optional[StreamToken] = None, full_state: bool = False, ) -> SyncResult: - """Generates a sync result.""" + """Generates the response body of a sync result. + + This is represented by a `SyncResult` struct, which is built from small pieces + using a `SyncResultBuilder`. See also + https://spec.matrix.org/v1.1/client-server-api/#get_matrixclientv3sync + the `sync_result_builder` is passed as a mutable ("inout") parameter to various + helper functions. These retrieve and process the data which forms the sync body, + often writing to the `sync_result_builder` to store their output. + + At the end, we transfer data from the `sync_result_builder` to a new `SyncResult` + instance to signify that the sync calculation is complete. + """ # NB: The now_token gets changed by some of the generate_sync_* methods, # this is due to some of the underlying streams not supporting the ability # to query up to a given point. @@ -1344,14 +1373,22 @@ class SyncHandler: async def _generate_sync_entry_for_account_data( self, sync_result_builder: "SyncResultBuilder" ) -> Dict[str, Dict[str, JsonDict]]: - """Generates the account data portion of the sync response. Populates - `sync_result_builder` with the result. + """Generates the account data portion of the sync response. + + Account data (called "Client Config" in the spec) can be set either globally + or for a specific room. Account data consists of a list of events which + accumulate state, much like a room. + + This function retrieves global and per-room account data. The former is written + to the given `sync_result_builder`. The latter is returned directly, to be + later written to the `sync_result_builder` on a room-by-room basis. Args: sync_result_builder Returns: - A dictionary containing the per room account data. + A dictionary whose keys (room ids) map to the per room account data for that + room. """ sync_config = sync_result_builder.sync_config user_id = sync_result_builder.sync_config.user.to_string() @@ -1359,7 +1396,7 @@ class SyncHandler: if since_token and not sync_result_builder.full_state: ( - account_data, + global_account_data, account_data_by_room, ) = await self.store.get_updated_account_data_for_user( user_id, since_token.account_data_key @@ -1370,23 +1407,23 @@ class SyncHandler: ) if push_rules_changed: - account_data["m.push_rules"] = await self.push_rules_for_user( + global_account_data["m.push_rules"] = await self.push_rules_for_user( sync_config.user ) else: ( - account_data, + global_account_data, account_data_by_room, ) = await self.store.get_account_data_for_user(sync_config.user.to_string()) - account_data["m.push_rules"] = await self.push_rules_for_user( + global_account_data["m.push_rules"] = await self.push_rules_for_user( sync_config.user ) account_data_for_user = await sync_config.filter_collection.filter_account_data( [ {"type": account_data_type, "content": content} - for account_data_type, content in account_data.items() + for account_data_type, content in global_account_data.items() ] ) @@ -1460,15 +1497,22 @@ class SyncHandler: """Generates the rooms portion of the sync response. Populates the `sync_result_builder` with the result. + In the response that reaches the client, rooms are divided into four categories: + `invite`, `join`, `knock`, `leave`. These aren't the same as the four sets of + room ids returned by this function. 
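As a reminder of where these pieces end up on the wire, the relevant parts of a /sync response body look roughly as follows. The values are invented for illustration; the client-server spec is the authoritative reference for the shape.

    sync_body_sketch = {
        # Global account data, produced by _generate_sync_entry_for_account_data.
        "account_data": {
            "events": [{"type": "m.push_rules", "content": {"global": {}}}],
        },
        "rooms": {
            # The four categories the client sees.
            "invite": {},
            "knock": {},
            "leave": {},
            "join": {
                "!room:example.org": {
                    "timeline": {"events": [], "limited": False, "prev_batch": "t0_0"},
                    "state": {"events": []},
                    # Per-room account data, merged in room by room from the
                    # dictionary returned above.
                    "account_data": {
                        "events": [
                            {"type": "m.fully_read", "content": {"event_id": "$x"}}
                        ]
                    },
                }
            },
        },
    }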
+ Args: sync_result_builder account_data_by_room: Dictionary of per room account data Returns: - Returns a 4-tuple of - `(newly_joined_rooms, newly_joined_or_invited_users, - newly_left_rooms, newly_left_users)` + Returns a 4-tuple whose entries are: + - newly_joined_rooms + - newly_joined_or_invited_or_knocked_users + - newly_left_rooms + - newly_left_users """ + # Start by fetching all ephemeral events in rooms we've joined (if required). user_id = sync_result_builder.sync_config.user.to_string() block_all_room_ephemeral = ( sync_result_builder.since_token is None @@ -1590,6 +1634,8 @@ class SyncHandler: ) -> bool: """Returns whether there may be any new events that should be sent down the sync. Returns True if there are. + + Does not modify the `sync_result_builder`. """ user_id = sync_result_builder.sync_config.user.to_string() since_token = sync_result_builder.since_token @@ -1597,12 +1643,13 @@ class SyncHandler: assert since_token - # Get a list of membership change events that have happened. - rooms_changed = await self.store.get_membership_changes_for_user( + # Get a list of membership change events that have happened to the user + # requesting the sync. + membership_changes = await self.store.get_membership_changes_for_user( user_id, since_token.room_key, now_token.room_key ) - if rooms_changed: + if membership_changes: return True stream_id = since_token.room_key.stream @@ -1614,7 +1661,25 @@ class SyncHandler: async def _get_rooms_changed( self, sync_result_builder: "SyncResultBuilder", ignored_users: FrozenSet[str] ) -> _RoomChanges: - """Gets the the changes that have happened since the last sync.""" + """Determine the changes in rooms to report to the user. + + Ideally, we want to report all events whose stream ordering `s` lies in the + range `since_token < s <= now_token`, where the two tokens are read from the + sync_result_builder. + + If there are too many events in that range to report, things get complicated. + In this situation we return a truncated list of the most recent events, and + indicate in the response that there is a "gap" of omitted events. Additionally: + + - we include a "state_delta", to describe the changes in state over the gap, + - we include all membership events applying to the user making the request, + even those in the gap. + + See the spec for the rationale: + https://spec.matrix.org/v1.1/client-server-api/#syncing + + The sync_result_builder is not modified by this function. + """ user_id = sync_result_builder.sync_config.user.to_string() since_token = sync_result_builder.since_token now_token = sync_result_builder.now_token @@ -1622,21 +1687,36 @@ class SyncHandler: assert since_token - # Get a list of membership change events that have happened. - rooms_changed = await self.store.get_membership_changes_for_user( + # The spec + # https://spec.matrix.org/v1.1/client-server-api/#get_matrixclientv3sync + # notes that membership events need special consideration: + # + # > When a sync is limited, the server MUST return membership events for events + # > in the gap (between since and the start of the returned timeline), regardless + # > as to whether or not they are redundant. + # + # We fetch such events here, but we only seem to use them for categorising rooms + # as newly joined, newly left, invited or knocked. + # TODO: we've already called this function and ran this query in + # _have_rooms_changed. We could keep the results in memory to avoid a + # second query, at the cost of more complicated source code. 
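The categorisation performed by the loop below can be pictured with a small standalone sketch. It is deliberately simplified: real membership changes are EventBase objects, and the real logic also has to handle invites, knocks, bans and the messy joined-left-joined-again cases mentioned in the comment.

    from collections import defaultdict

    # (room_id, membership) pairs in stream order, standing in for the
    # membership change events fetched above.
    changes = [
        ("!a:example.org", "join"),
        ("!b:example.org", "join"),
        ("!b:example.org", "leave"),
        ("!c:example.org", "leave"),
    ]

    by_room = defaultdict(list)
    for room_id, membership in changes:
        by_room[room_id].append(membership)

    # Only the *last* membership in the window decides which bucket the room
    # lands in for this sync response.
    newly_joined = [r for r, ms in by_room.items() if ms[-1] == "join"]
    newly_left = [r for r, ms in by_room.items() if ms[-1] == "leave"]

    print(newly_joined)  # ['!a:example.org']
    print(newly_left)    # ['!b:example.org', '!c:example.org']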
+ membership_change_events = await self.store.get_membership_changes_for_user( user_id, since_token.room_key, now_token.room_key ) mem_change_events_by_room_id: Dict[str, List[EventBase]] = {} - for event in rooms_changed: + for event in membership_change_events: mem_change_events_by_room_id.setdefault(event.room_id, []).append(event) - newly_joined_rooms = [] - newly_left_rooms = [] - room_entries = [] - invited = [] - knocked = [] + newly_joined_rooms: List[str] = [] + newly_left_rooms: List[str] = [] + room_entries: List[RoomSyncResultBuilder] = [] + invited: List[InvitedSyncResult] = [] + knocked: List[KnockedSyncResult] = [] for room_id, events in mem_change_events_by_room_id.items(): + # The body of this loop will add this room to at least one of the five lists + # above. Things get messy if you've e.g. joined, left, joined then left the + # room all in the same sync period. logger.debug( "Membership changes in %s: [%s]", room_id, @@ -1781,7 +1861,9 @@ class SyncHandler: timeline_limit = sync_config.filter_collection.timeline_limit() - # Get all events for rooms we're currently joined to. + # Get all events since the `from_key` in rooms we're currently joined to. + # If there are too many, we get the most recent events only. This leaves + # a "gap" in the timeline, as described by the spec for /sync. room_to_events = await self.store.get_room_events_stream_for_rooms( room_ids=sync_result_builder.joined_room_ids, from_key=since_token.room_key, @@ -1842,6 +1924,10 @@ class SyncHandler: ) -> _RoomChanges: """Returns entries for all rooms for the user. + Like `_get_rooms_changed`, but assumes the `since_token` is `None`. + + This function does not modify the sync_result_builder. + Args: sync_result_builder ignored_users: Set of users ignored by user. @@ -1853,16 +1939,9 @@ class SyncHandler: now_token = sync_result_builder.now_token sync_config = sync_result_builder.sync_config - membership_list = ( - Membership.INVITE, - Membership.KNOCK, - Membership.JOIN, - Membership.LEAVE, - Membership.BAN, - ) - room_list = await self.store.get_rooms_for_local_user_where_membership_is( - user_id=user_id, membership_list=membership_list + user_id=user_id, + membership_list=Membership.LIST, ) room_entries = [] @@ -2212,8 +2291,7 @@ def _calculate_state( # to only include membership events for the senders in the timeline. # In practice, we can do this by removing them from the p_ids list, # which is the list of relevant state we know we have already sent to the client. - # see https://github.com/matrix-org/synapse/pull/2970 - # /files/efcdacad7d1b7f52f879179701c7e0d9b763511f#r204732809 + # see https://github.com/matrix-org/synapse/pull/2970/files/efcdacad7d1b7f52f879179701c7e0d9b763511f#r204732809 if lazy_load_members: p_ids.difference_update( diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 42dc807d17..57aab55259 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -497,7 +497,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): oldest `limit` events. Returns: - The list of events (in ascending order) and the token from the start + The list of events (in ascending stream order) and the token from the start of the chunk of events returned. 
""" if from_key == to_key: @@ -510,7 +510,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): if not has_changed: return [], from_key - def f(txn): + def f(txn: LoggingTransaction) -> List[_EventDictReturn]: # To handle tokens with a non-empty instance_map we fetch more # results than necessary and then filter down min_from_id = from_key.stream @@ -565,6 +565,13 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): async def get_membership_changes_for_user( self, user_id: str, from_key: RoomStreamToken, to_key: RoomStreamToken ) -> List[EventBase]: + """Fetch membership events for a given user. + + All such events whose stream ordering `s` lies in the range + `from_key < s <= to_key` are returned. Events are ordered by ascending stream + order. + """ + # Start by ruling out cases where a DB query is not necessary. if from_key == to_key: return [] @@ -575,7 +582,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): if not has_changed: return [] - def f(txn): + def f(txn: LoggingTransaction) -> List[_EventDictReturn]: # To handle tokens with a non-empty instance_map we fetch more # results than necessary and then filter down min_from_id = from_key.stream @@ -634,7 +641,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): Returns: A list of events and a token pointing to the start of the returned - events. The events returned are in ascending order. + events. The events returned are in ascending topological order. """ rows, token = await self.get_recent_event_ids_for_room( -- cgit 1.5.1 From 637df95de63196033a6da4a6e286e1d58ea517b6 Mon Sep 17 00:00:00 2001 From: reivilibre Date: Fri, 3 Dec 2021 16:42:44 +0000 Subject: Support configuring the lifetime of non-refreshable access tokens separately to refreshable access tokens. (#11445) --- changelog.d/11445.feature | 1 + synapse/config/registration.py | 49 ++++++++++++++++++++ synapse/handlers/register.py | 20 ++++++-- tests/config/test_registration_config.py | 78 ++++++++++++++++++++++++++++++++ tests/rest/client/test_auth.py | 76 +++++++++++++++++++++++++++++++ 5 files changed, 221 insertions(+), 3 deletions(-) create mode 100644 changelog.d/11445.feature create mode 100644 tests/config/test_registration_config.py (limited to 'synapse/handlers') diff --git a/changelog.d/11445.feature b/changelog.d/11445.feature new file mode 100644 index 0000000000..211a722b65 --- /dev/null +++ b/changelog.d/11445.feature @@ -0,0 +1 @@ +Support configuring the lifetime of non-refreshable access tokens separately to refreshable access tokens. \ No newline at end of file diff --git a/synapse/config/registration.py b/synapse/config/registration.py index 47853199f4..68a4985398 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -130,11 +130,60 @@ class RegistrationConfig(Config): int ] = refreshable_access_token_lifetime + if ( + self.session_lifetime is not None + and "refreshable_access_token_lifetime" in config + ): + if self.session_lifetime < self.refreshable_access_token_lifetime: + raise ConfigError( + "Both `session_lifetime` and `refreshable_access_token_lifetime` " + "configuration options have been set, but `refreshable_access_token_lifetime` " + " exceeds `session_lifetime`!" + ) + + # The `nonrefreshable_access_token_lifetime` applies for tokens that can NOT be + # refreshed using a refresh token. 
+ # If it is None, then these tokens last for the entire length of the session, + # which is infinite by default. + # The intention behind this configuration option is to help with requiring + # all clients to use refresh tokens, if the homeserver administrator requires. + nonrefreshable_access_token_lifetime = config.get( + "nonrefreshable_access_token_lifetime", + None, + ) + if nonrefreshable_access_token_lifetime is not None: + nonrefreshable_access_token_lifetime = self.parse_duration( + nonrefreshable_access_token_lifetime + ) + self.nonrefreshable_access_token_lifetime = nonrefreshable_access_token_lifetime + + if ( + self.session_lifetime is not None + and self.nonrefreshable_access_token_lifetime is not None + ): + if self.session_lifetime < self.nonrefreshable_access_token_lifetime: + raise ConfigError( + "Both `session_lifetime` and `nonrefreshable_access_token_lifetime` " + "configuration options have been set, but `nonrefreshable_access_token_lifetime` " + " exceeds `session_lifetime`!" + ) + refresh_token_lifetime = config.get("refresh_token_lifetime") if refresh_token_lifetime is not None: refresh_token_lifetime = self.parse_duration(refresh_token_lifetime) self.refresh_token_lifetime: Optional[int] = refresh_token_lifetime + if ( + self.session_lifetime is not None + and self.refresh_token_lifetime is not None + ): + if self.session_lifetime < self.refresh_token_lifetime: + raise ConfigError( + "Both `session_lifetime` and `refresh_token_lifetime` " + "configuration options have been set, but `refresh_token_lifetime` " + " exceeds `session_lifetime`!" + ) + # The fallback template used for authenticating using a registration token self.registration_token_template = self.read_template("registration_token.html") diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 24ca11b924..b14ddd8267 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -1,4 +1,5 @@ # Copyright 2014 - 2016 OpenMarket Ltd +# Copyright 2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -116,6 +117,9 @@ class RegistrationHandler: self.pusher_pool = hs.get_pusherpool() self.session_lifetime = hs.config.registration.session_lifetime + self.nonrefreshable_access_token_lifetime = ( + hs.config.registration.nonrefreshable_access_token_lifetime + ) self.refreshable_access_token_lifetime = ( hs.config.registration.refreshable_access_token_lifetime ) @@ -794,13 +798,25 @@ class RegistrationHandler: class and RegisterDeviceReplicationServlet. """ assert not self.hs.config.worker.worker_app + now_ms = self.clock.time_msec() access_token_expiry = None if self.session_lifetime is not None: if is_guest: raise Exception( "session_lifetime is not currently implemented for guest access" ) - access_token_expiry = self.clock.time_msec() + self.session_lifetime + access_token_expiry = now_ms + self.session_lifetime + + if self.nonrefreshable_access_token_lifetime is not None: + if access_token_expiry is not None: + # Don't allow the non-refreshable access token to outlive the + # session. + access_token_expiry = min( + now_ms + self.nonrefreshable_access_token_lifetime, + access_token_expiry, + ) + else: + access_token_expiry = now_ms + self.nonrefreshable_access_token_lifetime refresh_token = None refresh_token_id = None @@ -818,8 +834,6 @@ class RegistrationHandler: # that this value is set before setting this flag). 
assert self.refreshable_access_token_lifetime is not None - now_ms = self.clock.time_msec() - # Set the expiry time of the refreshable access token access_token_expiry = now_ms + self.refreshable_access_token_lifetime diff --git a/tests/config/test_registration_config.py b/tests/config/test_registration_config.py new file mode 100644 index 0000000000..17a84d20d8 --- /dev/null +++ b/tests/config/test_registration_config.py @@ -0,0 +1,78 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from synapse.config import ConfigError +from synapse.config.homeserver import HomeServerConfig + +from tests.unittest import TestCase +from tests.utils import default_config + + +class RegistrationConfigTestCase(TestCase): + def test_session_lifetime_must_not_be_exceeded_by_smaller_lifetimes(self): + """ + session_lifetime should logically be larger than, or at least as large as, + all the different token lifetimes. + Test that the user is faced with configuration errors if they make it + smaller, as that configuration doesn't make sense. + """ + config_dict = default_config("test") + + # First test all the error conditions + with self.assertRaises(ConfigError): + HomeServerConfig().parse_config_dict( + { + "session_lifetime": "30m", + "nonrefreshable_access_token_lifetime": "31m", + **config_dict, + } + ) + + with self.assertRaises(ConfigError): + HomeServerConfig().parse_config_dict( + { + "session_lifetime": "30m", + "refreshable_access_token_lifetime": "31m", + **config_dict, + } + ) + + with self.assertRaises(ConfigError): + HomeServerConfig().parse_config_dict( + { + "session_lifetime": "30m", + "refresh_token_lifetime": "31m", + **config_dict, + } + ) + + # Then test all the fine conditions + HomeServerConfig().parse_config_dict( + { + "session_lifetime": "31m", + "nonrefreshable_access_token_lifetime": "31m", + **config_dict, + } + ) + + HomeServerConfig().parse_config_dict( + { + "session_lifetime": "31m", + "refreshable_access_token_lifetime": "31m", + **config_dict, + } + ) + + HomeServerConfig().parse_config_dict( + {"session_lifetime": "31m", "refresh_token_lifetime": "31m", **config_dict} + ) diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py index d8a94f4c12..7239e1a1b5 100644 --- a/tests/rest/client/test_auth.py +++ b/tests/rest/client/test_auth.py @@ -524,6 +524,19 @@ class RefreshAuthTests(unittest.HomeserverTestCase): {"refresh_token": refresh_token}, ) + def is_access_token_valid(self, access_token) -> bool: + """ + Checks whether an access token is valid, returning whether it is or not. + """ + code = self.make_request( + "GET", "/_matrix/client/v3/account/whoami", access_token=access_token + ).code + + # Either 200 or 401 is what we get back; anything else is a bug. + assert code in {HTTPStatus.OK, HTTPStatus.UNAUTHORIZED} + + return code == HTTPStatus.OK + def test_login_issue_refresh_token(self): """ A login response should include a refresh_token only if asked. 
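For context, the contract these tests exercise is the MSC2918 refresh flow: a login response carries `expires_in_ms` alongside a `refresh_token`, and the client is expected to exchange the refresh token for a fresh pair of tokens before the access token lapses. A rough client-side sketch of the bookkeeping (the token values and the ten-second safety margin are arbitrary placeholders):

    import time

    def plan_refresh(login_response: dict, safety_margin_ms: int = 10_000) -> float:
        # Wall-clock time at which the client should refresh its tokens.
        return time.time() + (login_response["expires_in_ms"] - safety_margin_ms) / 1000

    login_response = {
        "access_token": "placeholder_access_token",
        "refresh_token": "placeholder_refresh_token",
        "expires_in_ms": 60_000,
    }

    refresh_at = plan_refresh(login_response)
    if time.time() >= refresh_at:
        # Time to call the MSC2918 refresh endpoint with the refresh token and
        # replace both stored tokens with the ones in its response.
        pass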
@@ -671,6 +684,69 @@ class RefreshAuthTests(unittest.HomeserverTestCase): HTTPStatus.UNAUTHORIZED, ) + @override_config( + { + "refreshable_access_token_lifetime": "1m", + "nonrefreshable_access_token_lifetime": "10m", + } + ) + def test_different_expiry_for_refreshable_and_nonrefreshable_access_tokens(self): + """ + Tests that the expiry times for refreshable and non-refreshable access + tokens can be different. + """ + body = { + "type": "m.login.password", + "user": "test", + "password": self.user_pass, + } + login_response1 = self.make_request( + "POST", + "/_matrix/client/r0/login", + {"org.matrix.msc2918.refresh_token": True, **body}, + ) + self.assertEqual(login_response1.code, 200, login_response1.result) + self.assertApproximates( + login_response1.json_body["expires_in_ms"], 60 * 1000, 100 + ) + refreshable_access_token = login_response1.json_body["access_token"] + + login_response2 = self.make_request( + "POST", + "/_matrix/client/r0/login", + body, + ) + self.assertEqual(login_response2.code, 200, login_response2.result) + nonrefreshable_access_token = login_response2.json_body["access_token"] + + # Advance 59 seconds in the future (just shy of 1 minute, the time of expiry) + self.reactor.advance(59.0) + + # Both tokens should still be valid. + self.assertTrue(self.is_access_token_valid(refreshable_access_token)) + self.assertTrue(self.is_access_token_valid(nonrefreshable_access_token)) + + # Advance to 61 s (just past 1 minute, the time of expiry) + self.reactor.advance(2.0) + + # Only the non-refreshable token is still valid. + self.assertFalse(self.is_access_token_valid(refreshable_access_token)) + self.assertTrue(self.is_access_token_valid(nonrefreshable_access_token)) + + # Advance to 599 s (just shy of 10 minutes, the time of expiry) + self.reactor.advance(599.0 - 61.0) + + # It's still the case that only the non-refreshable token is still valid. + self.assertFalse(self.is_access_token_valid(refreshable_access_token)) + self.assertTrue(self.is_access_token_valid(nonrefreshable_access_token)) + + # Advance to 601 s (just past 10 minutes, the time of expiry) + self.reactor.advance(2.0) + + # Now neither token is valid. + self.assertFalse(self.is_access_token_valid(refreshable_access_token)) + self.assertFalse(self.is_access_token_valid(nonrefreshable_access_token)) + @override_config( {"refreshable_access_token_lifetime": "1m", "refresh_token_lifetime": "2m"} ) -- cgit 1.5.1 From 494ebd7347ba52d702802fba4c3bb13e7bfbc2cf Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 6 Dec 2021 10:51:15 -0500 Subject: Include bundled aggregations in /sync and related fixes (#11478) Due to updates to MSC2675 this includes a few fixes: * Include bundled aggregations for /sync. * Do not include bundled aggregations for /initialSync and /events. * Do not bundle aggregations for state events. * Clarifies comments and variable names. 
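Before the diff itself, it may help to recall what a bundled aggregation looks like on the wire: the server folds summaries of related events into `unsigned.m.relations` on the parent event. A sketch of the shape, with invented values (the tests further down assert the real thing):

    event_with_bundled_aggregations = {
        "type": "m.room.message",
        "event_id": "$parent",
        "content": {"msgtype": "m.text", "body": "hello"},
        "unsigned": {
            "m.relations": {
                # Reaction counts, grouped by key.
                "m.annotation": {
                    "chunk": [{"type": "m.reaction", "key": "a", "count": 2}]
                },
                # Events that reference the parent.
                "m.reference": {"chunk": [{"event_id": "$reply"}]},
                # Thread summary: size plus the latest thread event, itself
                # serialized without further bundling.
                "m.thread": {"count": 2, "latest_event": {"event_id": "$latest"}},
            }
        },
    }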
--- changelog.d/11478.bugfix | 1 + synapse/events/utils.py | 58 ++++++++++------ synapse/handlers/events.py | 5 +- synapse/handlers/initial_sync.py | 30 ++++++-- synapse/handlers/message.py | 8 +-- synapse/rest/admin/rooms.py | 13 +--- synapse/rest/client/relations.py | 9 ++- synapse/rest/client/room.py | 5 +- synapse/rest/client/sync.py | 6 +- tests/rest/client/test_relations.py | 135 +++++++++++++++++++++++++----------- 10 files changed, 169 insertions(+), 101 deletions(-) create mode 100644 changelog.d/11478.bugfix (limited to 'synapse/handlers') diff --git a/changelog.d/11478.bugfix b/changelog.d/11478.bugfix new file mode 100644 index 0000000000..5f02636f50 --- /dev/null +++ b/changelog.d/11478.bugfix @@ -0,0 +1 @@ +Include bundled relation aggregations during a limited `/sync` request, per [MSC2675](https://github.com/matrix-org/matrix-doc/pull/2675). diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 05219a9dd0..84ef69df67 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -306,6 +306,7 @@ def format_event_for_client_v2_without_room_id(d: JsonDict) -> JsonDict: def serialize_event( e: Union[JsonDict, EventBase], time_now_ms: int, + *, as_client_event: bool = True, event_format: Callable[[JsonDict], JsonDict] = format_event_for_client_v1, token_id: Optional[str] = None, @@ -393,7 +394,8 @@ class EventClientSerializer: self, event: Union[JsonDict, EventBase], time_now: int, - bundle_relations: bool = True, + *, + bundle_aggregations: bool = True, **kwargs: Any, ) -> JsonDict: """Serializes a single event. @@ -401,8 +403,9 @@ class EventClientSerializer: Args: event: The event being serialized. time_now: The current time in milliseconds - bundle_relations: Whether to include the bundled relations for this - event. + bundle_aggregations: Whether to include the bundled aggregations for this + event. Only applies to non-state events. (State events never include + bundled aggregations.) **kwargs: Arguments to pass to `serialize_event` Returns: @@ -414,20 +417,27 @@ class EventClientSerializer: serialized_event = serialize_event(event, time_now, **kwargs) - # If MSC1849 is enabled then we need to look if there are any relations - # we need to bundle in with the event. - # Do not bundle relations if the event has been redacted - if not event.internal_metadata.is_redacted() and ( - self._msc1849_enabled and bundle_relations + # Check if there are any bundled aggregations to include with the event. + # + # Do not bundle aggregations if any of the following at true: + # + # * Support is disabled via the configuration or the caller. + # * The event is a state event. + # * The event has been redacted. + if ( + self._msc1849_enabled + and bundle_aggregations + and not event.is_state() + and not event.internal_metadata.is_redacted() ): - await self._injected_bundled_relations(event, time_now, serialized_event) + await self._injected_bundled_aggregations(event, time_now, serialized_event) return serialized_event - async def _injected_bundled_relations( + async def _injected_bundled_aggregations( self, event: EventBase, time_now: int, serialized_event: JsonDict ) -> None: - """Potentially injects bundled relations into the unsigned portion of the serialized event. + """Potentially injects bundled aggregations into the unsigned portion of the serialized event. Args: event: The event being serialized. @@ -435,7 +445,7 @@ class EventClientSerializer: serialized_event: The serialized event which may be modified. 
""" - # Do not bundle relations for an event which represents an edit or an + # Do not bundle aggregations for an event which represents an edit or an # annotation. It does not make sense for them to have related events. relates_to = event.content.get("m.relates_to") if isinstance(relates_to, (dict, frozendict)): @@ -445,18 +455,18 @@ class EventClientSerializer: event_id = event.event_id - # The bundled relations to include. - relations = {} + # The bundled aggregations to include. + aggregations = {} annotations = await self.store.get_aggregation_groups_for_event(event_id) if annotations.chunk: - relations[RelationTypes.ANNOTATION] = annotations.to_dict() + aggregations[RelationTypes.ANNOTATION] = annotations.to_dict() references = await self.store.get_relations_for_event( event_id, RelationTypes.REFERENCE, direction="f" ) if references.chunk: - relations[RelationTypes.REFERENCE] = references.to_dict() + aggregations[RelationTypes.REFERENCE] = references.to_dict() edit = None if event.type == EventTypes.Message: @@ -482,7 +492,7 @@ class EventClientSerializer: else: serialized_event["content"].pop("m.relates_to", None) - relations[RelationTypes.REPLACE] = { + aggregations[RelationTypes.REPLACE] = { "event_id": edit.event_id, "origin_server_ts": edit.origin_server_ts, "sender": edit.sender, @@ -495,17 +505,19 @@ class EventClientSerializer: latest_thread_event, ) = await self.store.get_thread_summary(event_id) if latest_thread_event: - relations[RelationTypes.THREAD] = { - # Don't bundle relations as this could recurse forever. + aggregations[RelationTypes.THREAD] = { + # Don't bundle aggregations as this could recurse forever. "latest_event": await self.serialize_event( - latest_thread_event, time_now, bundle_relations=False + latest_thread_event, time_now, bundle_aggregations=False ), "count": thread_count, } - # If any bundled relations were found, include them. - if relations: - serialized_event["unsigned"].setdefault("m.relations", {}).update(relations) + # If any bundled aggregations were found, include them. + if aggregations: + serialized_event["unsigned"].setdefault("m.relations", {}).update( + aggregations + ) async def serialize_events( self, events: Iterable[Union[JsonDict, EventBase]], time_now: int, **kwargs: Any diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index b4ff935546..32b0254c5f 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -122,9 +122,8 @@ class EventStreamHandler: events, time_now, as_client_event=as_client_event, - # We don't bundle "live" events, as otherwise clients - # will end up double counting annotations. - bundle_relations=False, + # Don't bundle aggregations as this is a deprecated API. + bundle_aggregations=False, ) chunk = { diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index d4e4556155..9cd21e7f2b 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -165,7 +165,11 @@ class InitialSyncHandler: invite_event = await self.store.get_event(event.event_id) d["invite"] = await self._event_serializer.serialize_event( - invite_event, time_now, as_client_event + invite_event, + time_now, + # Don't bundle aggregations as this is a deprecated API. 
+ bundle_aggregations=False, + as_client_event=as_client_event, ) rooms_ret.append(d) @@ -216,7 +220,11 @@ class InitialSyncHandler: d["messages"] = { "chunk": ( await self._event_serializer.serialize_events( - messages, time_now=time_now, as_client_event=as_client_event + messages, + time_now=time_now, + # Don't bundle aggregations as this is a deprecated API. + bundle_aggregations=False, + as_client_event=as_client_event, ) ), "start": await start_token.to_string(self.store), @@ -226,6 +234,8 @@ class InitialSyncHandler: d["state"] = await self._event_serializer.serialize_events( current_state.values(), time_now=time_now, + # Don't bundle aggregations as this is a deprecated API. + bundle_aggregations=False, as_client_event=as_client_event, ) @@ -366,14 +376,18 @@ class InitialSyncHandler: "room_id": room_id, "messages": { "chunk": ( - await self._event_serializer.serialize_events(messages, time_now) + # Don't bundle aggregations as this is a deprecated API. + await self._event_serializer.serialize_events( + messages, time_now, bundle_aggregations=False + ) ), "start": await start_token.to_string(self.store), "end": await end_token.to_string(self.store), }, "state": ( + # Don't bundle aggregations as this is a deprecated API. await self._event_serializer.serialize_events( - room_state.values(), time_now + room_state.values(), time_now, bundle_aggregations=False ) ), "presence": [], @@ -392,8 +406,9 @@ class InitialSyncHandler: # TODO: These concurrently time_now = self.clock.time_msec() + # Don't bundle aggregations as this is a deprecated API. state = await self._event_serializer.serialize_events( - current_state.values(), time_now + current_state.values(), time_now, bundle_aggregations=False ) now_token = self.hs.get_event_sources().get_current_token() @@ -467,7 +482,10 @@ class InitialSyncHandler: "room_id": room_id, "messages": { "chunk": ( - await self._event_serializer.serialize_events(messages, time_now) + # Don't bundle aggregations as this is a deprecated API. + await self._event_serializer.serialize_events( + messages, time_now, bundle_aggregations=False + ) ), "start": await start_token.to_string(self.store), "end": await end_token.to_string(self.store), diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 95b4fad3c6..87f671708c 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -247,13 +247,7 @@ class MessageHandler: room_state = room_state_events[membership_event_id] now = self.clock.time_msec() - events = await self._event_serializer.serialize_events( - room_state.values(), - now, - # We don't bother bundling aggregations in when asked for state - # events, as clients won't use them. - bundle_relations=False, - ) + events = await self._event_serializer.serialize_events(room_state.values(), now) return events async def get_joined_members(self, requester: Requester, room_id: str) -> dict: diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 6bbc5510f0..669ab44a45 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -449,13 +449,7 @@ class RoomStateRestServlet(RestServlet): event_ids = await self.store.get_current_state_ids(room_id) events = await self.store.get_events(event_ids.values()) now = self.clock.time_msec() - room_state = await self._event_serializer.serialize_events( - events.values(), - now, - # We don't bother bundling aggregations in when asked for state - # events, as clients won't use them. 
- bundle_relations=False, - ) + room_state = await self._event_serializer.serialize_events(events.values(), now) ret = {"state": room_state} return HTTPStatus.OK, ret @@ -789,10 +783,7 @@ class RoomEventContextServlet(RestServlet): results["events_after"], time_now ) results["state"] = await self._event_serializer.serialize_events( - results["state"], - time_now, - # No need to bundle aggregations for state events - bundle_relations=False, + results["state"], time_now ) return HTTPStatus.OK, results diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py index b1a3304849..fc4e6921c5 100644 --- a/synapse/rest/client/relations.py +++ b/synapse/rest/client/relations.py @@ -224,14 +224,13 @@ class RelationPaginationServlet(RestServlet): ) now = self.clock.time_msec() - # We set bundle_relations to False when retrieving the original - # event because we want the content before relations were applied to - # it. + # Do not bundle aggregations when retrieving the original event because + # we want the content before relations are applied to it. original_event = await self._event_serializer.serialize_event( - event, now, bundle_relations=False + event, now, bundle_aggregations=False ) # The relations returned for the requested event do include their - # bundled relations. + # bundled aggregations. serialized_events = await self._event_serializer.serialize_events(events, now) return_value = pagination_chunk.to_dict() diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 3598967be0..f48e2e6ca2 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -716,10 +716,7 @@ class RoomEventContextServlet(RestServlet): results["events_after"], time_now ) results["state"] = await self._event_serializer.serialize_events( - results["state"], - time_now, - # No need to bundle aggregations for state events - bundle_relations=False, + results["state"], time_now ) return 200, results diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index b6a2485732..88e4f5e063 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -520,9 +520,9 @@ class SyncRestServlet(RestServlet): return self._event_serializer.serialize_events( events, time_now=time_now, - # We don't bundle "live" events, as otherwise clients - # will end up double counting annotations. - bundle_relations=False, + # Don't bother to bundle aggregations if the timeline is unlimited, + # as clients will have all the necessary information. 
+ bundle_aggregations=room.timeline.limited, token_id=token_id, event_format=event_formatter, only_event_fields=only_fields, diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index b494da5138..397c12c2a6 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -19,7 +19,7 @@ from typing import Dict, List, Optional, Tuple from synapse.api.constants import EventTypes, RelationTypes from synapse.rest import admin -from synapse.rest.client import login, register, relations, room +from synapse.rest.client import login, register, relations, room, sync from tests import unittest from tests.server import FakeChannel @@ -29,6 +29,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): servlets = [ relations.register_servlets, room.register_servlets, + sync.register_servlets, login.register_servlets, register.register_servlets, admin.register_servlets_for_client_rest_resource, @@ -454,11 +455,9 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertEquals(400, channel.code, channel.json_body) @unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) - def test_aggregation_get_event(self): - """Test that annotations, references, and threads get correctly bundled when - getting the parent event. - """ - + def test_bundled_aggregations(self): + """Test that annotations, references, and threads get correctly bundled.""" + # Setup by sending a variety of relations. channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") self.assertEquals(200, channel.code, channel.json_body) @@ -485,49 +484,107 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertEquals(200, channel.code, channel.json_body) thread_2 = channel.json_body["event_id"] - channel = self.make_request( - "GET", - "/rooms/%s/event/%s" % (self.room, self.parent_id), - access_token=self.user_token, - ) - self.assertEquals(200, channel.code, channel.json_body) + def assert_bundle(actual): + """Assert the expected values of the bundled aggregations.""" - self.assertEquals( - channel.json_body["unsigned"].get("m.relations"), - { - RelationTypes.ANNOTATION: { + # Ensure the fields are as expected. + self.assertCountEqual( + actual.keys(), + ( + RelationTypes.ANNOTATION, + RelationTypes.REFERENCE, + RelationTypes.THREAD, + ), + ) + + # Check the values of each field. + self.assertEquals( + { "chunk": [ {"type": "m.reaction", "key": "a", "count": 2}, {"type": "m.reaction", "key": "b", "count": 1}, ] }, - RelationTypes.REFERENCE: { - "chunk": [{"event_id": reply_1}, {"event_id": reply_2}] - }, - RelationTypes.THREAD: { - "count": 2, - "latest_event": { - "age": 100, - "content": { - "m.relates_to": { - "event_id": self.parent_id, - "rel_type": RelationTypes.THREAD, - } - }, - "event_id": thread_2, - "origin_server_ts": 1600, - "room_id": self.room, - "sender": self.user_id, - "type": "m.room.test", - "unsigned": {"age": 100}, - "user_id": self.user_id, + actual[RelationTypes.ANNOTATION], + ) + + self.assertEquals( + {"chunk": [{"event_id": reply_1}, {"event_id": reply_2}]}, + actual[RelationTypes.REFERENCE], + ) + + self.assertEquals( + 2, + actual[RelationTypes.THREAD].get("count"), + ) + # The latest thread event has some fields that don't matter. 
+ self.assert_dict( + { + "content": { + "m.relates_to": { + "event_id": self.parent_id, + "rel_type": RelationTypes.THREAD, + } }, + "event_id": thread_2, + "room_id": self.room, + "sender": self.user_id, + "type": "m.room.test", + "user_id": self.user_id, }, - }, + actual[RelationTypes.THREAD].get("latest_event"), + ) + + def _find_and_assert_event(events): + """ + Find the parent event in a chunk of events and assert that it has the proper bundled aggregations. + """ + for event in events: + if event["event_id"] == self.parent_id: + break + else: + raise AssertionError(f"Event {self.parent_id} not found in chunk") + assert_bundle(event["unsigned"].get("m.relations")) + + # Request the event directly. + channel = self.make_request( + "GET", + f"/rooms/{self.room}/event/{self.parent_id}", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + assert_bundle(channel.json_body["unsigned"].get("m.relations")) + + # Request the room messages. + channel = self.make_request( + "GET", + f"/rooms/{self.room}/messages?dir=b", + access_token=self.user_token, ) + self.assertEquals(200, channel.code, channel.json_body) + _find_and_assert_event(channel.json_body["chunk"]) + + # Request the room context. + channel = self.make_request( + "GET", + f"/rooms/{self.room}/context/{self.parent_id}", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + assert_bundle(channel.json_body["event"]["unsigned"].get("m.relations")) + + # Request sync. + channel = self.make_request("GET", "/sync", access_token=self.user_token) + self.assertEquals(200, channel.code, channel.json_body) + room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"] + self.assertTrue(room_timeline["limited"]) + _find_and_assert_event(room_timeline["events"]) + + # Note that /relations is tested separately in test_aggregation_get_event_for_thread + # since it needs different data configured. def test_aggregation_get_event_for_annotation(self): - """Test that annotations do not get bundled relations included + """Test that annotations do not get bundled aggregations included when directly requested. """ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") @@ -549,7 +606,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertIsNone(channel.json_body["unsigned"].get("m.relations")) def test_aggregation_get_event_for_thread(self): - """Test that threads get bundled relations included when directly requested.""" + """Test that threads get bundled aggregations included when directly requested.""" channel = self._send_relation(RelationTypes.THREAD, "m.room.test") self.assertEquals(200, channel.code, channel.json_body) thread_id = channel.json_body["event_id"] -- cgit 1.5.1 From a15a893df8428395df7cb95b729431575001c38a Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Mon, 6 Dec 2021 18:43:06 +0100 Subject: Save the OIDC session ID (sid) with the device on login (#11482) As a step towards allowing back-channel logout for OIDC. 
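The mechanism used to carry the IdP session through the login flow is an extra first-party caveat on the short-term login token macaroon, mirroring the existing `auth_provider_id` caveat. A minimal sketch with pymacaroons; the key and the caveat values are placeholders:

    import pymacaroons

    key = "placeholder-macaroon-secret-key"

    macaroon = pymacaroons.Macaroon(location="example.org", identifier="key", key=key)
    macaroon.add_first_party_caveat("gen = 1")
    macaroon.add_first_party_caveat("type = login")
    macaroon.add_first_party_caveat("user_id = @alice:example.org")
    macaroon.add_first_party_caveat("auth_provider_id = oidc-myprovider")
    # New in this change: remember the IdP session ID (the OIDC `sid` claim).
    macaroon.add_first_party_caveat("auth_provider_session_id = abcdef123456")
    token = macaroon.serialize()

    # Reading the value back when the login token is redeemed.
    parsed = pymacaroons.Macaroon.deserialize(token)
    prefix = "auth_provider_session_id = "
    sids = [
        c.caveat_id[len(prefix):]
        for c in parsed.caveats
        if c.caveat_id.startswith(prefix)
    ]
    print(sids)  # ['abcdef123456']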
--- changelog.d/11482.misc | 1 + synapse/handlers/auth.py | 34 +++++- synapse/handlers/device.py | 8 ++ synapse/handlers/oidc.py | 58 +++++---- synapse/handlers/register.py | 15 ++- synapse/handlers/sso.py | 4 + synapse/module_api/__init__.py | 2 + synapse/replication/http/login.py | 8 ++ synapse/rest/client/login.py | 7 +- synapse/storage/databases/main/devices.py | 50 +++++++- .../delta/65/11_devices_auth_provider_session.sql | 27 +++++ tests/handlers/test_auth.py | 6 +- tests/handlers/test_cas.py | 40 +++++- tests/handlers/test_oidc.py | 135 ++++++++++++++++++--- tests/handlers/test_saml.py | 40 +++++- 15 files changed, 370 insertions(+), 65 deletions(-) create mode 100644 changelog.d/11482.misc create mode 100644 synapse/storage/schema/main/delta/65/11_devices_auth_provider_session.sql (limited to 'synapse/handlers') diff --git a/changelog.d/11482.misc b/changelog.d/11482.misc new file mode 100644 index 0000000000..e78662988f --- /dev/null +++ b/changelog.d/11482.misc @@ -0,0 +1 @@ +Save the OpenID Connect session ID on login. diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 4d9c4e5834..61607cf2ba 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -39,6 +39,7 @@ import attr import bcrypt import pymacaroons import unpaddedbase64 +from pymacaroons.exceptions import MacaroonVerificationFailedException from twisted.web.server import Request @@ -182,8 +183,11 @@ class LoginTokenAttributes: user_id = attr.ib(type=str) - # the SSO Identity Provider that the user authenticated with, to get this token auth_provider_id = attr.ib(type=str) + """The SSO Identity Provider that the user authenticated with, to get this token.""" + + auth_provider_session_id = attr.ib(type=Optional[str]) + """The session ID advertised by the SSO Identity Provider.""" class AuthHandler: @@ -1650,6 +1654,7 @@ class AuthHandler: client_redirect_url: str, extra_attributes: Optional[JsonDict] = None, new_user: bool = False, + auth_provider_session_id: Optional[str] = None, ) -> None: """Having figured out a mxid for this user, complete the HTTP request @@ -1665,6 +1670,7 @@ class AuthHandler: during successful login. Must be JSON serializable. new_user: True if we should use wording appropriate to a user who has just registered. + auth_provider_session_id: The session ID from the SSO IdP received during login. """ # If the account has been deactivated, do not proceed with the login # flow. @@ -1685,6 +1691,7 @@ class AuthHandler: extra_attributes, new_user=new_user, user_profile_data=profile, + auth_provider_session_id=auth_provider_session_id, ) def _complete_sso_login( @@ -1696,6 +1703,7 @@ class AuthHandler: extra_attributes: Optional[JsonDict] = None, new_user: bool = False, user_profile_data: Optional[ProfileInfo] = None, + auth_provider_session_id: Optional[str] = None, ) -> None: """ The synchronous portion of complete_sso_login. @@ -1717,7 +1725,9 @@ class AuthHandler: # Create a login token login_token = self.macaroon_gen.generate_short_term_login_token( - registered_user_id, auth_provider_id=auth_provider_id + registered_user_id, + auth_provider_id=auth_provider_id, + auth_provider_session_id=auth_provider_session_id, ) # Append the login token to the original redirect URL (i.e. 
with its query @@ -1822,6 +1832,7 @@ class MacaroonGenerator: self, user_id: str, auth_provider_id: str, + auth_provider_session_id: Optional[str] = None, duration_in_ms: int = (2 * 60 * 1000), ) -> str: macaroon = self._generate_base_macaroon(user_id) @@ -1830,6 +1841,10 @@ class MacaroonGenerator: expiry = now + duration_in_ms macaroon.add_first_party_caveat("time < %d" % (expiry,)) macaroon.add_first_party_caveat("auth_provider_id = %s" % (auth_provider_id,)) + if auth_provider_session_id is not None: + macaroon.add_first_party_caveat( + "auth_provider_session_id = %s" % (auth_provider_session_id,) + ) return macaroon.serialize() def verify_short_term_login_token(self, token: str) -> LoginTokenAttributes: @@ -1851,15 +1866,28 @@ class MacaroonGenerator: user_id = get_value_from_macaroon(macaroon, "user_id") auth_provider_id = get_value_from_macaroon(macaroon, "auth_provider_id") + auth_provider_session_id: Optional[str] = None + try: + auth_provider_session_id = get_value_from_macaroon( + macaroon, "auth_provider_session_id" + ) + except MacaroonVerificationFailedException: + pass + v = pymacaroons.Verifier() v.satisfy_exact("gen = 1") v.satisfy_exact("type = login") v.satisfy_general(lambda c: c.startswith("user_id = ")) v.satisfy_general(lambda c: c.startswith("auth_provider_id = ")) + v.satisfy_general(lambda c: c.startswith("auth_provider_session_id = ")) satisfy_expiry(v, self.hs.get_clock().time_msec) v.verify(macaroon, self.hs.config.key.macaroon_secret_key) - return LoginTokenAttributes(user_id=user_id, auth_provider_id=auth_provider_id) + return LoginTokenAttributes( + user_id=user_id, + auth_provider_id=auth_provider_id, + auth_provider_session_id=auth_provider_session_id, + ) def generate_delete_pusher_token(self, user_id: str) -> str: macaroon = self._generate_base_macaroon(user_id) diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 68b446eb66..82ee11e921 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -301,6 +301,8 @@ class DeviceHandler(DeviceWorkerHandler): user_id: str, device_id: Optional[str], initial_device_display_name: Optional[str] = None, + auth_provider_id: Optional[str] = None, + auth_provider_session_id: Optional[str] = None, ) -> str: """ If the given device has not been registered, register it with the @@ -312,6 +314,8 @@ class DeviceHandler(DeviceWorkerHandler): user_id: @user:id device_id: device id supplied by client initial_device_display_name: device display name from client + auth_provider_id: The SSO IdP the user used, if any. + auth_provider_session_id: The session ID (sid) got from the SSO IdP. 
Returns: device id (generated if none was supplied) """ @@ -323,6 +327,8 @@ class DeviceHandler(DeviceWorkerHandler): user_id=user_id, device_id=device_id, initial_device_display_name=initial_device_display_name, + auth_provider_id=auth_provider_id, + auth_provider_session_id=auth_provider_session_id, ) if new_device: await self.notify_device_update(user_id, [device_id]) @@ -337,6 +343,8 @@ class DeviceHandler(DeviceWorkerHandler): user_id=user_id, device_id=new_device_id, initial_device_display_name=initial_device_display_name, + auth_provider_id=auth_provider_id, + auth_provider_session_id=auth_provider_session_id, ) if new_device: await self.notify_device_update(user_id, [new_device_id]) diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py index 3665d91513..deb3539751 100644 --- a/synapse/handlers/oidc.py +++ b/synapse/handlers/oidc.py @@ -23,7 +23,7 @@ from authlib.common.security import generate_token from authlib.jose import JsonWebToken, jwt from authlib.oauth2.auth import ClientAuth from authlib.oauth2.rfc6749.parameters import prepare_grant_uri -from authlib.oidc.core import CodeIDToken, ImplicitIDToken, UserInfo +from authlib.oidc.core import CodeIDToken, UserInfo from authlib.oidc.discovery import OpenIDProviderMetadata, get_well_known_url from jinja2 import Environment, Template from pymacaroons.exceptions import ( @@ -117,7 +117,8 @@ class OidcHandler: for idp_id, p in self._providers.items(): try: await p.load_metadata() - await p.load_jwks() + if not p._uses_userinfo: + await p.load_jwks() except Exception as e: raise Exception( "Error while initialising OIDC provider %r" % (idp_id,) @@ -498,10 +499,6 @@ class OidcProvider: return await self._jwks.get() async def _load_jwks(self) -> JWKS: - if self._uses_userinfo: - # We're not using jwt signing, return an empty jwk set - return {"keys": []} - metadata = await self.load_metadata() # Load the JWKS using the `jwks_uri` metadata. @@ -663,7 +660,7 @@ class OidcProvider: return UserInfo(resp) - async def _parse_id_token(self, token: Token, nonce: str) -> UserInfo: + async def _parse_id_token(self, token: Token, nonce: str) -> CodeIDToken: """Return an instance of UserInfo from token's ``id_token``. Args: @@ -673,7 +670,7 @@ class OidcProvider: request. This value should match the one inside the token. Returns: - An object representing the user. + The decoded claims in the ID token. """ metadata = await self.load_metadata() claims_params = { @@ -684,9 +681,6 @@ class OidcProvider: # If we got an `access_token`, there should be an `at_hash` claim # in the `id_token` that we can check against. 
claims_params["access_token"] = token["access_token"] - claims_cls = CodeIDToken - else: - claims_cls = ImplicitIDToken alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"]) jwt = JsonWebToken(alg_values) @@ -703,7 +697,7 @@ class OidcProvider: claims = jwt.decode( id_token, key=jwk_set, - claims_cls=claims_cls, + claims_cls=CodeIDToken, claims_options=claim_options, claims_params=claims_params, ) @@ -713,7 +707,7 @@ class OidcProvider: claims = jwt.decode( id_token, key=jwk_set, - claims_cls=claims_cls, + claims_cls=CodeIDToken, claims_options=claim_options, claims_params=claims_params, ) @@ -721,7 +715,8 @@ class OidcProvider: logger.debug("Decoded id_token JWT %r; validating", claims) claims.validate(leeway=120) # allows 2 min of clock skew - return UserInfo(claims) + + return claims async def handle_redirect_request( self, @@ -837,8 +832,22 @@ class OidcProvider: logger.debug("Successfully obtained OAuth2 token data: %r", token) - # Now that we have a token, get the userinfo, either by decoding the - # `id_token` or by fetching the `userinfo_endpoint`. + # If there is an id_token, it should be validated, regardless of the + # userinfo endpoint is used or not. + if token.get("id_token") is not None: + try: + id_token = await self._parse_id_token(token, nonce=session_data.nonce) + sid = id_token.get("sid") + except Exception as e: + logger.exception("Invalid id_token") + self._sso_handler.render_error(request, "invalid_token", str(e)) + return + else: + id_token = None + sid = None + + # Now that we have a token, get the userinfo either from the `id_token` + # claims or by fetching the `userinfo_endpoint`. if self._uses_userinfo: try: userinfo = await self._fetch_userinfo(token) @@ -846,13 +855,14 @@ class OidcProvider: logger.exception("Could not fetch userinfo") self._sso_handler.render_error(request, "fetch_error", str(e)) return + elif id_token is not None: + userinfo = UserInfo(id_token) else: - try: - userinfo = await self._parse_id_token(token, nonce=session_data.nonce) - except Exception as e: - logger.exception("Invalid id_token") - self._sso_handler.render_error(request, "invalid_token", str(e)) - return + logger.error("Missing id_token in token response") + self._sso_handler.render_error( + request, "invalid_token", "Missing id_token in token response" + ) + return # first check if we're doing a UIA if session_data.ui_auth_session_id: @@ -884,7 +894,7 @@ class OidcProvider: # Call the mapper to register/login the user try: await self._complete_oidc_login( - userinfo, token, request, session_data.client_redirect_url + userinfo, token, request, session_data.client_redirect_url, sid ) except MappingException as e: logger.exception("Could not map user") @@ -896,6 +906,7 @@ class OidcProvider: token: Token, request: SynapseRequest, client_redirect_url: str, + sid: Optional[str], ) -> None: """Given a UserInfo response, complete the login flow @@ -1008,6 +1019,7 @@ class OidcProvider: oidc_response_to_user_attributes, grandfather_existing_users, extra_attributes, + auth_provider_session_id=sid, ) def _remote_id_from_userinfo(self, userinfo: UserInfo) -> str: diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index b14ddd8267..f08a516a75 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -746,6 +746,7 @@ class RegistrationHandler: is_appservice_ghost: bool = False, auth_provider_id: Optional[str] = None, should_issue_refresh_token: bool = False, + auth_provider_session_id: Optional[str] = None, ) -> Tuple[str, 
str, Optional[int], Optional[str]]: """Register a device for a user and generate an access token. @@ -756,9 +757,9 @@ class RegistrationHandler: device_id: The device ID to check, or None to generate a new one. initial_display_name: An optional display name for the device. is_guest: Whether this is a guest account - auth_provider_id: The SSO IdP the user used, if any (just used for the - prometheus metrics). + auth_provider_id: The SSO IdP the user used, if any. should_issue_refresh_token: Whether it should also issue a refresh token + auth_provider_session_id: The session ID received during login from the SSO IdP. Returns: Tuple of device ID, access token, access token expiration time and refresh token """ @@ -769,6 +770,8 @@ class RegistrationHandler: is_guest=is_guest, is_appservice_ghost=is_appservice_ghost, should_issue_refresh_token=should_issue_refresh_token, + auth_provider_id=auth_provider_id, + auth_provider_session_id=auth_provider_session_id, ) login_counter.labels( @@ -791,6 +794,8 @@ class RegistrationHandler: is_guest: bool = False, is_appservice_ghost: bool = False, should_issue_refresh_token: bool = False, + auth_provider_id: Optional[str] = None, + auth_provider_session_id: Optional[str] = None, ) -> LoginDict: """Helper for register_device @@ -822,7 +827,11 @@ class RegistrationHandler: refresh_token_id = None registered_device_id = await self.device_handler.check_device_registered( - user_id, device_id, initial_display_name + user_id, + device_id, + initial_display_name, + auth_provider_id=auth_provider_id, + auth_provider_session_id=auth_provider_session_id, ) if is_guest: assert access_token_expiry is None diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index 49fde01cf0..65c27bc64a 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -365,6 +365,7 @@ class SsoHandler: sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]], grandfather_existing_users: Callable[[], Awaitable[Optional[str]]], extra_login_attributes: Optional[JsonDict] = None, + auth_provider_session_id: Optional[str] = None, ) -> None: """ Given an SSO ID, retrieve the user ID for it and possibly register the user. @@ -415,6 +416,8 @@ class SsoHandler: extra_login_attributes: An optional dictionary of extra attributes to be provided to the client in the login response. + auth_provider_session_id: An optional session ID from the IdP. + Raises: MappingException if there was a problem mapping the response to a user. 
RedirectException: if the mapping provider needs to redirect the user @@ -490,6 +493,7 @@ class SsoHandler: client_redirect_url, extra_login_attributes, new_user=new_user, + auth_provider_session_id=auth_provider_session_id, ) async def _call_attribute_mapper( diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index a8154168be..6bfb4b8d1b 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -626,6 +626,7 @@ class ModuleApi: user_id: str, duration_in_ms: int = (2 * 60 * 1000), auth_provider_id: str = "", + auth_provider_session_id: Optional[str] = None, ) -> str: """Generate a login token suitable for m.login.token authentication @@ -643,6 +644,7 @@ class ModuleApi: return self._hs.get_macaroon_generator().generate_short_term_login_token( user_id, auth_provider_id, + auth_provider_session_id, duration_in_ms, ) diff --git a/synapse/replication/http/login.py b/synapse/replication/http/login.py index 0db419ea57..daacc34cea 100644 --- a/synapse/replication/http/login.py +++ b/synapse/replication/http/login.py @@ -46,6 +46,8 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint): is_guest, is_appservice_ghost, should_issue_refresh_token, + auth_provider_id, + auth_provider_session_id, ): """ Args: @@ -63,6 +65,8 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint): "is_guest": is_guest, "is_appservice_ghost": is_appservice_ghost, "should_issue_refresh_token": should_issue_refresh_token, + "auth_provider_id": auth_provider_id, + "auth_provider_session_id": auth_provider_session_id, } async def _handle_request(self, request, user_id): @@ -73,6 +77,8 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint): is_guest = content["is_guest"] is_appservice_ghost = content["is_appservice_ghost"] should_issue_refresh_token = content["should_issue_refresh_token"] + auth_provider_id = content["auth_provider_id"] + auth_provider_session_id = content["auth_provider_session_id"] res = await self.registration_handler.register_device_inner( user_id, @@ -81,6 +87,8 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint): is_guest, is_appservice_ghost=is_appservice_ghost, should_issue_refresh_token=should_issue_refresh_token, + auth_provider_id=auth_provider_id, + auth_provider_session_id=auth_provider_session_id, ) return 200, res diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py index a66ee4fb3d..1b23fa18cf 100644 --- a/synapse/rest/client/login.py +++ b/synapse/rest/client/login.py @@ -303,6 +303,7 @@ class LoginRestServlet(RestServlet): ratelimit: bool = True, auth_provider_id: Optional[str] = None, should_issue_refresh_token: bool = False, + auth_provider_session_id: Optional[str] = None, ) -> LoginResponse: """Called when we've successfully authed the user and now need to actually login them in (e.g. create devices). This gets called on @@ -318,10 +319,10 @@ class LoginRestServlet(RestServlet): create_non_existent_users: Whether to create the user if they don't exist. Defaults to False. ratelimit: Whether to ratelimit the login request. - auth_provider_id: The SSO IdP the user used, if any (just used for the - prometheus metrics). + auth_provider_id: The SSO IdP the user used, if any. should_issue_refresh_token: True if this login should issue a refresh token alongside the access token. + auth_provider_session_id: The session ID got during login from the SSO IdP. Returns: result: Dictionary of account information after successful login. 
@@ -354,6 +355,7 @@ class LoginRestServlet(RestServlet): initial_display_name, auth_provider_id=auth_provider_id, should_issue_refresh_token=should_issue_refresh_token, + auth_provider_session_id=auth_provider_session_id, ) result = LoginResponse( @@ -399,6 +401,7 @@ class LoginRestServlet(RestServlet): self.auth_handler._sso_login_callback, auth_provider_id=res.auth_provider_id, should_issue_refresh_token=should_issue_refresh_token, + auth_provider_session_id=res.auth_provider_session_id, ) async def _do_jwt_login( diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 9ccc66e589..d5a4a661cd 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -139,6 +139,27 @@ class DeviceWorkerStore(SQLBaseStore): return {d["device_id"]: d for d in devices} + async def get_devices_by_auth_provider_session_id( + self, auth_provider_id: str, auth_provider_session_id: str + ) -> List[Dict[str, Any]]: + """Retrieve the list of devices associated with a SSO IdP session ID. + + Args: + auth_provider_id: The SSO IdP ID as defined in the server config + auth_provider_session_id: The session ID within the IdP + Returns: + A list of dicts containing the device_id and the user_id of each device + """ + return await self.db_pool.simple_select_list( + table="device_auth_providers", + keyvalues={ + "auth_provider_id": auth_provider_id, + "auth_provider_session_id": auth_provider_session_id, + }, + retcols=("user_id", "device_id"), + desc="get_devices_by_auth_provider_session_id", + ) + @trace async def get_device_updates_by_remote( self, destination: str, from_stream_id: int, limit: int @@ -1070,7 +1091,12 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): ) async def store_device( - self, user_id: str, device_id: str, initial_device_display_name: Optional[str] + self, + user_id: str, + device_id: str, + initial_device_display_name: Optional[str], + auth_provider_id: Optional[str] = None, + auth_provider_session_id: Optional[str] = None, ) -> bool: """Ensure the given device is known; add it to the store if not @@ -1079,6 +1105,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): device_id: id of device initial_device_display_name: initial displayname of the device. Ignored if device exists. + auth_provider_id: The SSO IdP the user used, if any. + auth_provider_session_id: The session ID (sid) got from a OIDC login. Returns: Whether the device was inserted or an existing device existed with that ID. 
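For illustration, a minimal sketch of how the `get_devices_by_auth_provider_session_id` lookup added a few hunks above might be consumed, for example by a hypothetical handler that expires every device tied to a given IdP session. Only the storage method comes from this patch; the helper name and the `delete_devices` call are assumptions made for the example.

    # Hypothetical consumer of the new lookup; not part of this patch.
    async def expire_idp_session(store, device_handler, idp_id: str, sid: str) -> None:
        # Find every device that was registered through this IdP session.
        devices = await store.get_devices_by_auth_provider_session_id(
            auth_provider_id=idp_id,
            auth_provider_session_id=sid,
        )
        # Log each of them out; delete_devices is assumed to exist on the device handler.
        for device in devices:
            await device_handler.delete_devices(device["user_id"], [device["device_id"]])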
@@ -1115,6 +1143,18 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): if hidden: raise StoreError(400, "The device ID is in use", Codes.FORBIDDEN) + if auth_provider_id and auth_provider_session_id: + await self.db_pool.simple_insert( + "device_auth_providers", + values={ + "user_id": user_id, + "device_id": device_id, + "auth_provider_id": auth_provider_id, + "auth_provider_session_id": auth_provider_session_id, + }, + desc="store_device_auth_provider", + ) + self.device_id_exists_cache.set(key, True) return inserted except StoreError: @@ -1168,6 +1208,14 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): keyvalues={"user_id": user_id}, ) + self.db_pool.simple_delete_many_txn( + txn, + table="device_auth_providers", + column="device_id", + values=device_ids, + keyvalues={"user_id": user_id}, + ) + await self.db_pool.runInteraction("delete_devices", _delete_devices_txn) for device_id in device_ids: self.device_id_exists_cache.invalidate((user_id, device_id)) diff --git a/synapse/storage/schema/main/delta/65/11_devices_auth_provider_session.sql b/synapse/storage/schema/main/delta/65/11_devices_auth_provider_session.sql new file mode 100644 index 0000000000..a65bfb520d --- /dev/null +++ b/synapse/storage/schema/main/delta/65/11_devices_auth_provider_session.sql @@ -0,0 +1,27 @@ +/* Copyright 2021 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +-- Track the auth provider used by each login as well as the session ID +CREATE TABLE device_auth_providers ( + user_id TEXT NOT NULL, + device_id TEXT NOT NULL, + auth_provider_id TEXT NOT NULL, + auth_provider_session_id TEXT NOT NULL +); + +CREATE INDEX device_auth_providers_devices + ON device_auth_providers (user_id, device_id); +CREATE INDEX device_auth_providers_sessions + ON device_auth_providers (auth_provider_id, auth_provider_session_id); diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index 72e176da75..03b8b8615c 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -71,7 +71,7 @@ class AuthTestCase(unittest.HomeserverTestCase): def test_short_term_login_token_gives_user_id(self): token = self.macaroon_generator.generate_short_term_login_token( - self.user1, "", 5000 + self.user1, "", duration_in_ms=5000 ) res = self.get_success(self.auth_handler.validate_short_term_login_token(token)) self.assertEqual(self.user1, res.user_id) @@ -94,7 +94,7 @@ class AuthTestCase(unittest.HomeserverTestCase): def test_short_term_login_token_cannot_replace_user_id(self): token = self.macaroon_generator.generate_short_term_login_token( - self.user1, "", 5000 + self.user1, "", duration_in_ms=5000 ) macaroon = pymacaroons.Macaroon.deserialize(token) @@ -213,6 +213,6 @@ class AuthTestCase(unittest.HomeserverTestCase): def _get_macaroon(self): token = self.macaroon_generator.generate_short_term_login_token( - self.user1, "", 5000 + self.user1, "", duration_in_ms=5000 ) return pymacaroons.Macaroon.deserialize(token) diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py index b625995d12..8705ff8943 100644 --- a/tests/handlers/test_cas.py +++ b/tests/handlers/test_cas.py @@ -66,7 +66,13 @@ class CasHandlerTestCase(HomeserverTestCase): # check that the auth handler got called as expected auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", "cas", request, "redirect_uri", None, new_user=True + "@test_user:test", + "cas", + request, + "redirect_uri", + None, + new_user=True, + auth_provider_session_id=None, ) def test_map_cas_user_to_existing_user(self): @@ -89,7 +95,13 @@ class CasHandlerTestCase(HomeserverTestCase): # check that the auth handler got called as expected auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", "cas", request, "redirect_uri", None, new_user=False + "@test_user:test", + "cas", + request, + "redirect_uri", + None, + new_user=False, + auth_provider_session_id=None, ) # Subsequent calls should map to the same mxid. 
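For illustration, a short sketch of how the new session-id caveat round-trips through the macaroon generator, based on the `generate_short_term_login_token` and `verify_short_term_login_token` hunks earlier in this patch. The `macaroon_generator` variable stands in for `hs.get_macaroon_generator()` on a running homeserver, and the concrete values are made up.

    # Issue a short-term login token that carries the IdP session id as a caveat.
    token = macaroon_generator.generate_short_term_login_token(
        "@alice:example.com",
        auth_provider_id="oidc",
        auth_provider_session_id="abcdefgh",
        duration_in_ms=2 * 60 * 1000,
    )
    # Verification hands the session id back on LoginTokenAttributes.
    attrs = macaroon_generator.verify_short_term_login_token(token)
    assert attrs.auth_provider_session_id == "abcdefgh"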
@@ -98,7 +110,13 @@ class CasHandlerTestCase(HomeserverTestCase): self.handler._handle_cas_response(request, cas_response, "redirect_uri", "") ) auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", "cas", request, "redirect_uri", None, new_user=False + "@test_user:test", + "cas", + request, + "redirect_uri", + None, + new_user=False, + auth_provider_session_id=None, ) def test_map_cas_user_to_invalid_localpart(self): @@ -116,7 +134,13 @@ class CasHandlerTestCase(HomeserverTestCase): # check that the auth handler got called as expected auth_handler.complete_sso_login.assert_called_once_with( - "@f=c3=b6=c3=b6:test", "cas", request, "redirect_uri", None, new_user=True + "@f=c3=b6=c3=b6:test", + "cas", + request, + "redirect_uri", + None, + new_user=True, + auth_provider_session_id=None, ) @override_config( @@ -160,7 +184,13 @@ class CasHandlerTestCase(HomeserverTestCase): # check that the auth handler got called as expected auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", "cas", request, "redirect_uri", None, new_user=True + "@test_user:test", + "cas", + request, + "redirect_uri", + None, + new_user=True, + auth_provider_session_id=None, ) diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index a25c89bd5b..cfe3de5266 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -252,13 +252,6 @@ class OidcHandlerTestCase(HomeserverTestCase): with patch.object(self.provider, "load_metadata", patched_load_metadata): self.get_failure(self.provider.load_jwks(force=True), RuntimeError) - # Return empty key set if JWKS are not used - self.provider._scopes = [] # not asking the openid scope - self.http_client.get_json.reset_mock() - jwks = self.get_success(self.provider.load_jwks(force=True)) - self.http_client.get_json.assert_not_called() - self.assertEqual(jwks, {"keys": []}) - @override_config({"oidc_config": DEFAULT_CONFIG}) def test_validate_config(self): """Provider metadatas are extensively validated.""" @@ -455,7 +448,13 @@ class OidcHandlerTestCase(HomeserverTestCase): self.get_success(self.handler.handle_oidc_callback(request)) auth_handler.complete_sso_login.assert_called_once_with( - expected_user_id, "oidc", request, client_redirect_url, None, new_user=True + expected_user_id, + "oidc", + request, + client_redirect_url, + None, + new_user=True, + auth_provider_session_id=None, ) self.provider._exchange_code.assert_called_once_with(code) self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce) @@ -482,17 +481,58 @@ class OidcHandlerTestCase(HomeserverTestCase): self.provider._fetch_userinfo.reset_mock() # With userinfo fetching - self.provider._scopes = [] # do not ask the "openid" scope + self.provider._user_profile_method = "userinfo_endpoint" + token = { + "type": "bearer", + "access_token": "access_token", + } + self.provider._exchange_code = simple_async_mock(return_value=token) self.get_success(self.handler.handle_oidc_callback(request)) auth_handler.complete_sso_login.assert_called_once_with( - expected_user_id, "oidc", request, client_redirect_url, None, new_user=False + expected_user_id, + "oidc", + request, + client_redirect_url, + None, + new_user=False, + auth_provider_session_id=None, ) self.provider._exchange_code.assert_called_once_with(code) self.provider._parse_id_token.assert_not_called() self.provider._fetch_userinfo.assert_called_once_with(token) self.render_error.assert_not_called() + # With an ID token, userinfo fetching and sid in the ID token + 
self.provider._user_profile_method = "userinfo_endpoint" + token = { + "type": "bearer", + "access_token": "access_token", + "id_token": "id_token", + } + id_token = { + "sid": "abcdefgh", + } + self.provider._parse_id_token = simple_async_mock(return_value=id_token) + self.provider._exchange_code = simple_async_mock(return_value=token) + auth_handler.complete_sso_login.reset_mock() + self.provider._fetch_userinfo.reset_mock() + self.get_success(self.handler.handle_oidc_callback(request)) + + auth_handler.complete_sso_login.assert_called_once_with( + expected_user_id, + "oidc", + request, + client_redirect_url, + None, + new_user=False, + auth_provider_session_id=id_token["sid"], + ) + self.provider._exchange_code.assert_called_once_with(code) + self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce) + self.provider._fetch_userinfo.assert_called_once_with(token) + self.render_error.assert_not_called() + # Handle userinfo fetching error self.provider._fetch_userinfo = simple_async_mock(raises=Exception()) self.get_success(self.handler.handle_oidc_callback(request)) @@ -776,6 +816,7 @@ class OidcHandlerTestCase(HomeserverTestCase): client_redirect_url, {"phone": "1234567"}, new_user=True, + auth_provider_session_id=None, ) @override_config({"oidc_config": DEFAULT_CONFIG}) @@ -790,7 +831,13 @@ class OidcHandlerTestCase(HomeserverTestCase): } self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", "oidc", ANY, ANY, None, new_user=True + "@test_user:test", + "oidc", + ANY, + ANY, + None, + new_user=True, + auth_provider_session_id=None, ) auth_handler.complete_sso_login.reset_mock() @@ -801,7 +848,13 @@ class OidcHandlerTestCase(HomeserverTestCase): } self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) auth_handler.complete_sso_login.assert_called_once_with( - "@test_user_2:test", "oidc", ANY, ANY, None, new_user=True + "@test_user_2:test", + "oidc", + ANY, + ANY, + None, + new_user=True, + auth_provider_session_id=None, ) auth_handler.complete_sso_login.reset_mock() @@ -838,14 +891,26 @@ class OidcHandlerTestCase(HomeserverTestCase): } self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) auth_handler.complete_sso_login.assert_called_once_with( - user.to_string(), "oidc", ANY, ANY, None, new_user=False + user.to_string(), + "oidc", + ANY, + ANY, + None, + new_user=False, + auth_provider_session_id=None, ) auth_handler.complete_sso_login.reset_mock() # Subsequent calls should map to the same mxid. 
self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) auth_handler.complete_sso_login.assert_called_once_with( - user.to_string(), "oidc", ANY, ANY, None, new_user=False + user.to_string(), + "oidc", + ANY, + ANY, + None, + new_user=False, + auth_provider_session_id=None, ) auth_handler.complete_sso_login.reset_mock() @@ -860,7 +925,13 @@ class OidcHandlerTestCase(HomeserverTestCase): } self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) auth_handler.complete_sso_login.assert_called_once_with( - user.to_string(), "oidc", ANY, ANY, None, new_user=False + user.to_string(), + "oidc", + ANY, + ANY, + None, + new_user=False, + auth_provider_session_id=None, ) auth_handler.complete_sso_login.reset_mock() @@ -896,7 +967,13 @@ class OidcHandlerTestCase(HomeserverTestCase): self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) auth_handler.complete_sso_login.assert_called_once_with( - "@TEST_USER_2:test", "oidc", ANY, ANY, None, new_user=False + "@TEST_USER_2:test", + "oidc", + ANY, + ANY, + None, + new_user=False, + auth_provider_session_id=None, ) @override_config({"oidc_config": DEFAULT_CONFIG}) @@ -934,7 +1011,13 @@ class OidcHandlerTestCase(HomeserverTestCase): # test_user is already taken, so test_user1 gets registered instead. auth_handler.complete_sso_login.assert_called_once_with( - "@test_user1:test", "oidc", ANY, ANY, None, new_user=True + "@test_user1:test", + "oidc", + ANY, + ANY, + None, + new_user=True, + auth_provider_session_id=None, ) auth_handler.complete_sso_login.reset_mock() @@ -1018,7 +1101,13 @@ class OidcHandlerTestCase(HomeserverTestCase): # check that the auth handler got called as expected auth_handler.complete_sso_login.assert_called_once_with( - "@tester:test", "oidc", ANY, ANY, None, new_user=True + "@tester:test", + "oidc", + ANY, + ANY, + None, + new_user=True, + auth_provider_session_id=None, ) @override_config( @@ -1043,7 +1132,13 @@ class OidcHandlerTestCase(HomeserverTestCase): # check that the auth handler got called as expected auth_handler.complete_sso_login.assert_called_once_with( - "@tester:test", "oidc", ANY, ANY, None, new_user=True + "@tester:test", + "oidc", + ANY, + ANY, + None, + new_user=True, + auth_provider_session_id=None, ) @override_config( @@ -1156,7 +1251,7 @@ async def _make_callback_with_userinfo( handler = hs.get_oidc_handler() provider = handler._providers["oidc"] - provider._exchange_code = simple_async_mock(return_value={}) + provider._exchange_code = simple_async_mock(return_value={"id_token": ""}) provider._parse_id_token = simple_async_mock(return_value=userinfo) provider._fetch_userinfo = simple_async_mock(return_value=userinfo) diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py index 8cfc184fef..50551aa6e3 100644 --- a/tests/handlers/test_saml.py +++ b/tests/handlers/test_saml.py @@ -130,7 +130,13 @@ class SamlHandlerTestCase(HomeserverTestCase): # check that the auth handler got called as expected auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", "saml", request, "redirect_uri", None, new_user=True + "@test_user:test", + "saml", + request, + "redirect_uri", + None, + new_user=True, + auth_provider_session_id=None, ) @override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}}) @@ -156,7 +162,13 @@ class SamlHandlerTestCase(HomeserverTestCase): # check that the auth handler got called as expected auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", "saml", request, "", None, new_user=False + 
"@test_user:test", + "saml", + request, + "", + None, + new_user=False, + auth_provider_session_id=None, ) # Subsequent calls should map to the same mxid. @@ -165,7 +177,13 @@ class SamlHandlerTestCase(HomeserverTestCase): self.handler._handle_authn_response(request, saml_response, "") ) auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", "saml", request, "", None, new_user=False + "@test_user:test", + "saml", + request, + "", + None, + new_user=False, + auth_provider_session_id=None, ) def test_map_saml_response_to_invalid_localpart(self): @@ -213,7 +231,13 @@ class SamlHandlerTestCase(HomeserverTestCase): # test_user is already taken, so test_user1 gets registered instead. auth_handler.complete_sso_login.assert_called_once_with( - "@test_user1:test", "saml", request, "", None, new_user=True + "@test_user1:test", + "saml", + request, + "", + None, + new_user=True, + auth_provider_session_id=None, ) auth_handler.complete_sso_login.reset_mock() @@ -309,7 +333,13 @@ class SamlHandlerTestCase(HomeserverTestCase): # check that the auth handler got called as expected auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", "saml", request, "redirect_uri", None, new_user=True + "@test_user:test", + "saml", + request, + "redirect_uri", + None, + new_user=True, + auth_provider_session_id=None, ) -- cgit 1.5.1 From 9c55dedc8c4484e6269451a8c3c10b3e314aeb4a Mon Sep 17 00:00:00 2001 From: David Robertson Date: Tue, 7 Dec 2021 11:24:31 +0000 Subject: Correctly ignore invites from ignored users (#11511) --- changelog.d/11511.bugfix | 1 + synapse/handlers/sync.py | 11 ++++++----- 2 files changed, 7 insertions(+), 5 deletions(-) create mode 100644 changelog.d/11511.bugfix (limited to 'synapse/handlers') diff --git a/changelog.d/11511.bugfix b/changelog.d/11511.bugfix new file mode 100644 index 0000000000..9c287357b8 --- /dev/null +++ b/changelog.d/11511.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where invites from ignored users were included in incremental syncs. \ No newline at end of file diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 53d4627147..97d5a26e20 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -1771,6 +1771,7 @@ class SyncHandler: if not non_joins: continue + last_non_join = non_joins[-1] # Check if we have left the room. This can either be because we were # joined before *or* that we since joined and then left. @@ -1792,18 +1793,18 @@ class SyncHandler: newly_left_rooms.append(room_id) # Only bother if we're still currently invited - should_invite = non_joins[-1].membership == Membership.INVITE + should_invite = last_non_join.membership == Membership.INVITE if should_invite: - if event.sender not in ignored_users: - invite_room_sync = InvitedSyncResult(room_id, invite=non_joins[-1]) + if last_non_join.sender not in ignored_users: + invite_room_sync = InvitedSyncResult(room_id, invite=last_non_join) if invite_room_sync: invited.append(invite_room_sync) # Only bother if our latest membership in the room is knock (and we haven't # been accepted/rejected in the meantime). 
- should_knock = non_joins[-1].membership == Membership.KNOCK + should_knock = last_non_join.membership == Membership.KNOCK if should_knock: - knock_room_sync = KnockedSyncResult(room_id, knock=non_joins[-1]) + knock_room_sync = KnockedSyncResult(room_id, knock=last_non_join) if knock_room_sync: knocked.append(knock_room_sync) -- cgit 1.5.1 From b1ecd19c5d19815b69e425d80f442bf2877cab76 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 7 Dec 2021 11:37:54 +0000 Subject: Fix 'delete room' admin api to work on incomplete rooms (#11523) If, for some reason, we don't have the create event, we should still be able to purge a room. --- changelog.d/11523.feature | 1 + synapse/handlers/pagination.py | 3 --- synapse/handlers/room.py | 21 +++++++-------------- synapse/rest/admin/rooms.py | 3 --- tests/rest/admin/test_room.py | 42 +++++++++++++++++++++++++----------------- 5 files changed, 33 insertions(+), 37 deletions(-) create mode 100644 changelog.d/11523.feature (limited to 'synapse/handlers') diff --git a/changelog.d/11523.feature b/changelog.d/11523.feature new file mode 100644 index 0000000000..ecac7f9db9 --- /dev/null +++ b/changelog.d/11523.feature @@ -0,0 +1 @@ +Extend the "delete room" admin api to work correctly on rooms which have previously been partially deleted. diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index cd64142735..4f42438053 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -406,9 +406,6 @@ class PaginationHandler: force: set true to skip checking for joined users. """ with await self.pagination_lock.write(room_id): - # check we know about the room - await self.store.get_room_version_id(room_id) - # first check that we have no users in this room if not force: joined = await self.store.is_host_joined(room_id, self._server_name) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 2bcdf32dcc..ead2198e14 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -1535,20 +1535,13 @@ class RoomShutdownHandler: await self.store.block_room(room_id, requester_user_id) if not await self.store.get_room(room_id): - if block: - # We allow you to block an unknown room. - return { - "kicked_users": [], - "failed_to_kick_users": [], - "local_aliases": [], - "new_room_id": None, - } - else: - # But if you don't want to preventatively block another room, - # this function can't do anything useful. - raise NotFoundError( - "Cannot shut down room: unknown room id %s" % (room_id,) - ) + # if we don't know about the room, there is nothing left to do. 
+ return { + "kicked_users": [], + "failed_to_kick_users": [], + "local_aliases": [], + "new_room_id": None, + } if new_room_user_id is not None: if not self.hs.is_mine_id(new_room_user_id): diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 669ab44a45..829e86675a 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -106,9 +106,6 @@ class RoomRestV2Servlet(RestServlet): HTTPStatus.BAD_REQUEST, "%s is not a legal room ID" % (room_id,) ) - if not await self._store.get_room(room_id): - raise NotFoundError("Unknown room id %s" % (room_id,)) - delete_id = self._pagination_handler.start_shutdown_and_purge_room( room_id=room_id, new_room_user_id=content.get("new_room_user_id"), diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index d3858e460d..22f9aa6234 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -83,7 +83,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): def test_room_does_not_exist(self): """ - Check that unknown rooms/server return error HTTPStatus.NOT_FOUND. + Check that unknown rooms/server return 200 """ url = "/_synapse/admin/v1/rooms/%s" % "!unknown:test" @@ -94,8 +94,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) - self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) def test_room_is_not_valid(self): """ @@ -508,27 +507,36 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - @parameterized.expand( - [ - ("DELETE", "/_synapse/admin/v2/rooms/%s"), - ("GET", "/_synapse/admin/v2/rooms/%s/delete_status"), - ("GET", "/_synapse/admin/v2/rooms/delete_status/%s"), - ] - ) - def test_room_does_not_exist(self, method: str, url: str): - """ - Check that unknown rooms/server return error HTTPStatus.NOT_FOUND. + def test_room_does_not_exist(self): """ + Check that unknown rooms/server return 200 + This is important, as it allows incomplete vestiges of rooms to be cleared up + even if the create event/etc is missing. 
+ """ + room_id = "!unknown:test" channel = self.make_request( - method, - url % "!unknown:test", + "DELETE", + f"/_synapse/admin/v2/rooms/{room_id}", content={}, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) - self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertIn("delete_id", channel.json_body) + delete_id = channel.json_body["delete_id"] + + # get status + channel = self.make_request( + "GET", + f"/_synapse/admin/v2/rooms/{room_id}/delete_status", + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(1, len(channel.json_body["results"])) + self.assertEqual("complete", channel.json_body["results"][0]["status"]) + self.assertEqual(delete_id, channel.json_body["results"][0]["delete_id"]) @parameterized.expand( [ -- cgit 1.5.1 From 2a3ec6facf79f6aae011d9fb6f9ed5e43c7b6bec Mon Sep 17 00:00:00 2001 From: David Robertson Date: Tue, 7 Dec 2021 12:34:38 +0000 Subject: Correctly register shutdown handler for presence workers (#11518) Fixes #11517 --- changelog.d/11518.bugfix | 1 + synapse/handlers/presence.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 changelog.d/11518.bugfix (limited to 'synapse/handlers') diff --git a/changelog.d/11518.bugfix b/changelog.d/11518.bugfix new file mode 100644 index 0000000000..3c27382c7a --- /dev/null +++ b/changelog.d/11518.bugfix @@ -0,0 +1 @@ +Fix a regression in Synapse 1.48.0 where presence workers would not clear their presence updates over replication on shutdown. diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 3df872c578..454d06c973 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -421,7 +421,7 @@ class WorkerPresenceHandler(BasePresenceHandler): self._on_shutdown, ) - def _on_shutdown(self) -> None: + async def _on_shutdown(self) -> None: if self._presence_enabled: self.hs.get_tcp_replication().send_command( ClearUserSyncsCommand(self.instance_id) -- cgit 1.5.1 From 14d593f72d10b4d8cb67e3288bb3131ee30ccf59 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Tue, 7 Dec 2021 12:42:05 +0000 Subject: Refactors in `_generate_sync_entry_for_rooms` (#11515) * Move sync_token up to the top * Pull out _get_ignored_users * Try to signpost the body of `_generate_sync_entry_for_rooms` * Pull out _calculate_user_changes Co-authored-by: Patrick Cloke --- changelog.d/11494.misc | 2 +- changelog.d/11515.misc | 1 + synapse/handlers/sync.py | 122 ++++++++++++++++++++++++++++++----------------- 3 files changed, 79 insertions(+), 46 deletions(-) create mode 100644 changelog.d/11515.misc (limited to 'synapse/handlers') diff --git a/changelog.d/11494.misc b/changelog.d/11494.misc index 7afd7d3a07..52cf26a4b6 100644 --- a/changelog.d/11494.misc +++ b/changelog.d/11494.misc @@ -1 +1 @@ -Add comments to various parts of the `/sync` handler. \ No newline at end of file +Refactor various parts of the `/sync` handler. \ No newline at end of file diff --git a/changelog.d/11515.misc b/changelog.d/11515.misc new file mode 100644 index 0000000000..7f9d8cd57f --- /dev/null +++ b/changelog.d/11515.misc @@ -0,0 +1 @@ +Refactor various parts of the `/sync` handler. 
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 97d5a26e20..f3039c3c3f 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -1506,16 +1506,22 @@ class SyncHandler: account_data_by_room: Dictionary of per room account data Returns: - Returns a 4-tuple whose entries are: + Returns a 4-tuple describing rooms the user has joined or left, and users who've + joined or left rooms any rooms the user is in. This gets used later in + `_generate_sync_entry_for_device_list`. + + Its entries are: - newly_joined_rooms - newly_joined_or_invited_or_knocked_users - newly_left_rooms - newly_left_users """ - # Start by fetching all ephemeral events in rooms we've joined (if required). + since_token = sync_result_builder.since_token + + # 1. Start by fetching all ephemeral events in rooms we've joined (if required). user_id = sync_result_builder.sync_config.user.to_string() block_all_room_ephemeral = ( - sync_result_builder.since_token is None + since_token is None and sync_result_builder.sync_config.filter_collection.blocks_all_room_ephemeral() ) @@ -1529,9 +1535,8 @@ class SyncHandler: ) sync_result_builder.now_token = now_token - # We check up front if anything has changed, if it hasn't then there is + # 2. We check up front if anything has changed, if it hasn't then there is # no point in going further. - since_token = sync_result_builder.since_token if not sync_result_builder.full_state: if since_token and not ephemeral_by_room and not account_data_by_room: have_changed = await self._have_rooms_changed(sync_result_builder) @@ -1544,20 +1549,8 @@ class SyncHandler: logger.debug("no-oping sync") return set(), set(), set(), set() - ignored_account_data = ( - await self.store.get_global_account_data_by_type_for_user( - AccountDataTypes.IGNORED_USER_LIST, user_id=user_id - ) - ) - - # If there is ignored users account data and it matches the proper type, - # then use it. - ignored_users: FrozenSet[str] = frozenset() - if ignored_account_data: - ignored_users_data = ignored_account_data.get("ignored_users", {}) - if isinstance(ignored_users_data, dict): - ignored_users = frozenset(ignored_users_data.keys()) - + # 3. Work out which rooms need reporting in the sync response. + ignored_users = await self._get_ignored_users(user_id) if since_token: room_changes = await self._get_rooms_changed( sync_result_builder, ignored_users @@ -1567,7 +1560,6 @@ class SyncHandler: ) else: room_changes = await self._get_all_rooms(sync_result_builder, ignored_users) - tags_by_room = await self.store.get_tags_for_user(user_id) log_kv({"rooms_changed": len(room_changes.room_entries)}) @@ -1578,6 +1570,8 @@ class SyncHandler: newly_joined_rooms = room_changes.newly_joined_rooms newly_left_rooms = room_changes.newly_left_rooms + # 4. We need to apply further processing to `room_entries` (rooms considered + # joined or archived). 
async def handle_room_entries(room_entry: "RoomSyncResultBuilder") -> None: logger.debug("Generating room entry for %s", room_entry.room_id) await self._generate_room_entry( @@ -1596,31 +1590,13 @@ class SyncHandler: sync_result_builder.invited.extend(invited) sync_result_builder.knocked.extend(knocked) - # Now we want to get any newly joined, invited or knocking users - newly_joined_or_invited_or_knocked_users = set() - newly_left_users = set() - if since_token: - for joined_sync in sync_result_builder.joined: - it = itertools.chain( - joined_sync.timeline.events, joined_sync.state.values() - ) - for event in it: - if event.type == EventTypes.Member: - if ( - event.membership == Membership.JOIN - or event.membership == Membership.INVITE - or event.membership == Membership.KNOCK - ): - newly_joined_or_invited_or_knocked_users.add( - event.state_key - ) - else: - prev_content = event.unsigned.get("prev_content", {}) - prev_membership = prev_content.get("membership", None) - if prev_membership == Membership.JOIN: - newly_left_users.add(event.state_key) - - newly_left_users -= newly_joined_or_invited_or_knocked_users + # 5. Work out which users have joined or left rooms we're in. We use this + # to build the device_list part of the sync response in + # `_generate_sync_entry_for_device_list`. + ( + newly_joined_or_invited_or_knocked_users, + newly_left_users, + ) = sync_result_builder.calculate_user_changes() return ( set(newly_joined_rooms), @@ -1629,6 +1605,29 @@ class SyncHandler: newly_left_users, ) + async def _get_ignored_users(self, user_id: str) -> FrozenSet[str]: + """Retrieve the users ignored by the given user from their global account_data. + + Returns an empty set if + - there is no global account_data entry for ignored_users + - there is such an entry, but it's not a JSON object. + """ + # TODO: Can we `SELECT ignored_user_id FROM ignored_users WHERE ignorer_user_id=?;` instead? + ignored_account_data = ( + await self.store.get_global_account_data_by_type_for_user( + AccountDataTypes.IGNORED_USER_LIST, user_id=user_id + ) + ) + + # If there is ignored users account data and it matches the proper type, + # then use it. + ignored_users: FrozenSet[str] = frozenset() + if ignored_account_data: + ignored_users_data = ignored_account_data.get("ignored_users", {}) + if isinstance(ignored_users_data, dict): + ignored_users = frozenset(ignored_users_data.keys()) + return ignored_users + async def _have_rooms_changed( self, sync_result_builder: "SyncResultBuilder" ) -> bool: @@ -2341,6 +2340,39 @@ class SyncResultBuilder: groups: Optional[GroupsSyncResult] = None to_device: List[JsonDict] = attr.Factory(list) + def calculate_user_changes(self) -> Tuple[Set[str], Set[str]]: + """Work out which other users have joined or left rooms we are joined to. + + This data only is only useful for an incremental sync. + + The SyncResultBuilder is not modified by this function. 
+ """ + newly_joined_or_invited_or_knocked_users = set() + newly_left_users = set() + if self.since_token: + for joined_sync in self.joined: + it = itertools.chain( + joined_sync.timeline.events, joined_sync.state.values() + ) + for event in it: + if event.type == EventTypes.Member: + if ( + event.membership == Membership.JOIN + or event.membership == Membership.INVITE + or event.membership == Membership.KNOCK + ): + newly_joined_or_invited_or_knocked_users.add( + event.state_key + ) + else: + prev_content = event.unsigned.get("prev_content", {}) + prev_membership = prev_content.get("membership", None) + if prev_membership == Membership.JOIN: + newly_left_users.add(event.state_key) + + newly_left_users -= newly_joined_or_invited_or_knocked_users + return newly_joined_or_invited_or_knocked_users, newly_left_users + @attr.s(slots=True, auto_attribs=True) class RoomSyncResultBuilder: -- cgit 1.5.1