From a86b2f6837f0a067b0a014fbf5140e8773b8da2e Mon Sep 17 00:00:00 2001 From: Shay Date: Tue, 11 Oct 2022 11:18:45 -0700 Subject: Fix a bug where redactions were not being sent over federation if we did not have the original event. (#13813) --- tests/handlers/test_appservice.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) (limited to 'tests/handlers') diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index af24c4984d..7e4570f990 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -76,9 +76,13 @@ class AppServiceHandlerTestCase(unittest.TestCase): event = Mock( sender="@someone:anywhere", type="m.room.message", room_id="!foo:bar" ) - self.mock_store.get_all_new_events_stream.side_effect = [ - make_awaitable((0, [], {})), - make_awaitable((1, [event], {event.event_id: 0})), + self.mock_store.get_all_new_event_ids_stream.side_effect = [ + make_awaitable((0, {})), + make_awaitable((1, {event.event_id: 0})), + ] + self.mock_store.get_events_as_list.side_effect = [ + make_awaitable([]), + make_awaitable([event]), ] self.handler.notify_interested_services(RoomStreamToken(None, 1)) @@ -95,10 +99,10 @@ class AppServiceHandlerTestCase(unittest.TestCase): event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar") self.mock_as_api.query_user.return_value = make_awaitable(True) - self.mock_store.get_all_new_events_stream.side_effect = [ - make_awaitable((0, [event], {event.event_id: 0})), + self.mock_store.get_all_new_event_ids_stream.side_effect = [ + make_awaitable((0, {event.event_id: 0})), ] - + self.mock_store.get_events_as_list.side_effect = [make_awaitable([event])] self.handler.notify_interested_services(RoomStreamToken(None, 0)) self.mock_as_api.query_user.assert_called_once_with(services[0], user_id) @@ -112,7 +116,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar") self.mock_as_api.query_user.return_value = make_awaitable(True) - self.mock_store.get_all_new_events_stream.side_effect = [ + self.mock_store.get_all_new_event_ids_stream.side_effect = [ make_awaitable((0, [event], {event.event_id: 0})), ] -- cgit 1.5.1 From 40bb37eb27e1841754a297ac1277748de7f6c1cb Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Sat, 15 Oct 2022 00:36:49 -0500 Subject: Stop getting missing `prev_events` after we already know their signature is invalid (#13816) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit While https://github.com/matrix-org/synapse/pull/13635 stops us from doing the slow thing after we've already done it once, this PR stops us from doing one of the slow things in the first place. Related to - https://github.com/matrix-org/synapse/issues/13622 - https://github.com/matrix-org/synapse/pull/13635 - https://github.com/matrix-org/synapse/issues/13676 Part of https://github.com/matrix-org/synapse/issues/13356 Follow-up to https://github.com/matrix-org/synapse/pull/13815 which tracks event signature failures. With this PR, we avoid the call to the costly `_get_state_ids_after_missing_prev_event` because the earlier signature failure already counts as a failed pull attempt, and we now filter events based on that backoff before calling `_get_state_ids_after_missing_prev_event`. For example, this will save us 156s out of the 185s total that this `matrix.org` `/messages` request took.
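For reference, here is a minimal sketch of the exponential backoff window involved (illustrative only: the two constant names mirror the ones referenced in `event_federation.py` below, but the helper functions are hypothetical and the concrete values are assumptions, with the one-hour step inferred from the "2hr, 4hr, 8hr, 16hr" comment and the tests later in this series):

    # Hypothetical helpers illustrating the backoff arithmetic; not Synapse code.
    BACKFILL_EVENT_EXPONENTIAL_BACKOFF_STEP_MILLISECONDS = 60 * 60 * 1000  # assumed: 1 hour
    BACKFILL_EVENT_EXPONENTIAL_BACKOFF_MAXIMUM_DOUBLING_STEPS = 8  # assumed cap

    def backoff_window_ms(num_attempts: int) -> int:
        # 2hr after the first failed attempt, then 4hr, 8hr, 16hr, ... up to the cap.
        doubling_steps = min(
            num_attempts, BACKFILL_EVENT_EXPONENTIAL_BACKOFF_MAXIMUM_DOUBLING_STEPS
        )
        return (2**doubling_steps) * BACKFILL_EVENT_EXPONENTIAL_BACKOFF_STEP_MILLISECONDS

    def should_skip_pull(now_ms: int, last_attempt_ts: int, num_attempts: int) -> bool:
        # An event is filtered out while we are still inside its backoff window.
        return now_ms < last_attempt_ts + backoff_window_ms(num_attempts)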
If you want to see the full Jaeger trace of this, you can drag and drop this `trace.json` into your own Jaeger, https://gist.github.com/MadLittleMods/4b12d0d0afe88c2f65ffcc907306b761 To explain this exact scenario around `/messages` -> backfill, we call `/backfill` and first check the signatures of the 100 events. We see bad signatures for `$luA4l7QHhf_jadH3mI-AyFqho0U2Q-IXXUbGSMq6h6M` and `$zuOn2Rd2vsC7SUia3Hp3r6JSkSFKcc5j3QTTqW_0jDw` (both member events). Then we process the 98 events remaining that have valid signatures, but one of the events references `$luA4l7QHhf_jadH3mI-AyFqho0U2Q-IXXUbGSMq6h6M` as a `prev_event`. So we have to do the whole `_get_state_ids_after_missing_prev_event` rigmarole, which pulls in those same events, which fail again because the signatures are still invalid. - `backfill` - `outgoing-federation-request` `/backfill` - `_check_sigs_and_hash_and_fetch` - `_check_sigs_and_hash_and_fetch_one` for each event received over backfill - ❗ `$luA4l7QHhf_jadH3mI-AyFqho0U2Q-IXXUbGSMq6h6M` fails with `Signature on retrieved event was invalid.`: `unable to verify signature for sender domain xxx: 401: Failed to find any key to satisfy: _FetchKeyRequest(...)` - ❗ `$zuOn2Rd2vsC7SUia3Hp3r6JSkSFKcc5j3QTTqW_0jDw` fails with `Signature on retrieved event was invalid.`: `unable to verify signature for sender domain xxx: 401: Failed to find any key to satisfy: _FetchKeyRequest(...)` - `_process_pulled_events` - `_process_pulled_event` for each validated event - ❗ Event `$Q0iMdqtz3IJYfZQU2Xk2WjB5NDF8Gg8cFSYYyKQgKJ0` references `$luA4l7QHhf_jadH3mI-AyFqho0U2Q-IXXUbGSMq6h6M` as a `prev_event` which is missing so we try to get it - `_get_state_ids_after_missing_prev_event` - `outgoing-federation-request` `/state_ids` - ❗ `get_pdu` for `$luA4l7QHhf_jadH3mI-AyFqho0U2Q-IXXUbGSMq6h6M` which fails the signature check again - ❗ `get_pdu` for `$zuOn2Rd2vsC7SUia3Hp3r6JSkSFKcc5j3QTTqW_0jDw` which fails the signature check --- changelog.d/13816.feature | 1 + synapse/api/errors.py | 21 +++ synapse/handlers/federation.py | 16 ++ synapse/handlers/federation_event.py | 31 ++++ synapse/storage/databases/main/event_federation.py | 54 ++++++ tests/handlers/test_federation_event.py | 201 ++++++++++++++++++++- tests/storage/test_event_federation.py | 64 +++++++ 7 files changed, 386 insertions(+), 2 deletions(-) create mode 100644 changelog.d/13816.feature (limited to 'tests/handlers') diff --git a/changelog.d/13816.feature b/changelog.d/13816.feature new file mode 100644 index 0000000000..5eaa936b08 --- /dev/null +++ b/changelog.d/13816.feature @@ -0,0 +1 @@ +Stop fetching missing `prev_events` after we already know their signature is invalid. diff --git a/synapse/api/errors.py b/synapse/api/errors.py index c606207569..e0873b1913 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -640,6 +640,27 @@ class FederationError(RuntimeError): } +class FederationPullAttemptBackoffError(RuntimeError): + """ + Raised to indicate that we are deliberately not attempting to pull the given + event over federation because we've already done so recently and are backing off. + + Attributes: + event_ids: The event IDs which we are refusing to pull + message: A custom error message that gives more context + """ + + def __init__(self, event_ids: List[str], message: Optional[str]): + self.event_ids = event_ids + + if message: + error_message = message + else: + error_message = f"Not attempting to pull event_ids={self.event_ids} because we already tried to pull them recently (backing off)."
+ + super().__init__(error_message) + + class HttpResponseException(CodeMessageException): """ Represents an HTTP-level failure of an outbound request diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 44e70c6c3c..5f7e0a1f79 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -45,6 +45,7 @@ from synapse.api.errors import ( Codes, FederationDeniedError, FederationError, + FederationPullAttemptBackoffError, HttpResponseException, LimitExceededError, NotFoundError, @@ -1720,7 +1721,22 @@ class FederationHandler: destination, event ) break + except FederationPullAttemptBackoffError as exc: + # Log a warning about why we failed to process the event (the error message + # for `FederationPullAttemptBackoffError` is pretty good) + logger.warning("_sync_partial_state_room: %s", exc) + # We do not record a failed pull attempt when we back off fetching a missing + # `prev_event` because not being able to fetch the `prev_events` just means + # we won't be able to de-outlier the pulled event. But we can still use an + # `outlier` in the state/auth chain for another event. So we shouldn't stop + # a downstream event from trying to pull it. + # + # This avoids a cascade of backoff for all events in the DAG downstream from + # one event backing off upstream. except FederationError as e: + # TODO: We should `record_event_failed_pull_attempt` here, + # see https://github.com/matrix-org/synapse/issues/13700 + if attempt == len(destinations) - 1: # We have tried every remote server for this event. Give up. # TODO(faster_joins) giving up isn't the right thing to do diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index f382961099..4300e8dd40 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -44,6 +44,7 @@ from synapse.api.errors import ( AuthError, Codes, FederationError, + FederationPullAttemptBackoffError, HttpResponseException, RequestSendFailed, SynapseError, @@ -567,6 +568,9 @@ class FederationEventHandler: event: partial-state event to be de-partial-stated Raises: + FederationPullAttemptBackoffError if we are deliberately not attempting + to pull the given event over federation because we've already done so + recently and are backing off. FederationError if we fail to request state from the remote server. """ logger.info("Updating state for %s", event.event_id) @@ -901,6 +905,18 @@ class FederationEventHandler: context, backfilled=backfilled, ) + except FederationPullAttemptBackoffError as exc: + # Log a warning about why we failed to process the event (the error message + # for `FederationPullAttemptBackoffError` is pretty good) + logger.warning("_process_pulled_event: %s", exc) + # We do not record a failed pull attempt when we back off fetching a missing + # `prev_event` because not being able to fetch the `prev_events` just means + # we won't be able to de-outlier the pulled event. But we can still use an + # `outlier` in the state/auth chain for another event. So we shouldn't stop + # a downstream event from trying to pull it. + # + # This avoids a cascade of backoff for all events in the DAG downstream from + # one event backing off upstream. except FederationError as e: await self._store.record_event_failed_pull_attempt( event.room_id, event_id, str(e) @@ -947,6 +963,9 @@ class FederationEventHandler: The event context.
Raises: + FederationPullAttemptBackoffError if we are deliberately not attempting + to pull the given event over federation because we've already done so + recently and are backing off. FederationError if we fail to get the state from the remote server after any missing `prev_event`s. """ @@ -957,6 +976,18 @@ class FederationEventHandler: seen = await self._store.have_events_in_timeline(prevs) missing_prevs = prevs - seen + # If we've already recently attempted to pull this missing event, don't + # try it again so soon. Since we have to fetch all of the prev_events, we can + # bail early here if we find any to ignore. + prevs_to_ignore = await self._store.get_event_ids_to_not_pull_from_backoff( + room_id, missing_prevs + ) + if len(prevs_to_ignore) > 0: + raise FederationPullAttemptBackoffError( + event_ids=prevs_to_ignore, + message=f"While computing context for event={event_id}, not attempting to pull missing prev_event={prevs_to_ignore[0]} because we already tried to pull recently (backing off).", + ) + if not missing_prevs: return await self._state_handler.compute_event_context(event) diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 6b9a629edd..309a4ba664 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -1501,6 +1501,12 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas event_id: The event that failed to be fetched or processed cause: The error message or reason that we failed to pull the event """ + logger.debug( + "record_event_failed_pull_attempt room_id=%s, event_id=%s, cause=%s", + room_id, + event_id, + cause, + ) await self.db_pool.runInteraction( "record_event_failed_pull_attempt", self._record_event_failed_pull_attempt_upsert_txn, @@ -1530,6 +1536,54 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas txn.execute(sql, (room_id, event_id, 1, self._clock.time_msec(), cause)) + @trace + async def get_event_ids_to_not_pull_from_backoff( + self, + room_id: str, + event_ids: Collection[str], + ) -> List[str]: + """ + Filter the given events down to ones that we've recently failed to pull. Uses + exponential backoff. + + Args: + room_id: The room that the events belong to + event_ids: A list of events to filter down + + Returns: + List of event_ids that we should not attempt to pull + """ + event_failed_pull_attempts = await self.db_pool.simple_select_many_batch( + table="event_failed_pull_attempts", + column="event_id", + iterable=event_ids, + keyvalues={}, + retcols=( + "event_id", + "last_attempt_ts", + "num_attempts", + ), + desc="get_event_ids_to_not_pull_from_backoff", + ) + + current_time = self._clock.time_msec() + return [ + event_failed_pull_attempt["event_id"] + for event_failed_pull_attempt in event_failed_pull_attempts + # Exponential back-off (up to the upper bound) so we don't try to + # pull the same event over and over. e.g. 2hr, 4hr, 8hr, 16hr, etc.
+ if current_time + < event_failed_pull_attempt["last_attempt_ts"] + + ( + 2 + ** min( + event_failed_pull_attempt["num_attempts"], + BACKFILL_EVENT_EXPONENTIAL_BACKOFF_MAXIMUM_DOUBLING_STEPS, + ) + ) + * BACKFILL_EVENT_EXPONENTIAL_BACKOFF_STEP_MILLISECONDS + ] + async def get_missing_events( self, room_id: str, diff --git a/tests/handlers/test_federation_event.py b/tests/handlers/test_federation_event.py index 918010cddb..e448cb1901 100644 --- a/tests/handlers/test_federation_event.py +++ b/tests/handlers/test_federation_event.py @@ -14,7 +14,7 @@ from typing import Optional from unittest import mock -from synapse.api.errors import AuthError +from synapse.api.errors import AuthError, StoreError from synapse.api.room_versions import RoomVersion from synapse.event_auth import ( check_state_dependent_auth_rules, @@ -43,7 +43,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): def make_homeserver(self, reactor, clock): # mock out the federation transport client self.mock_federation_transport_client = mock.Mock( - spec=["get_room_state_ids", "get_room_state", "get_event"] + spec=["get_room_state_ids", "get_room_state", "get_event", "backfill"] ) return super().setup_test_homeserver( federation_transport_client=self.mock_federation_transport_client ) @@ -459,6 +459,203 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): ) self.assertIsNotNone(persisted, "pulled event was not persisted at all") + def test_backfill_signature_failure_does_not_fetch_same_prev_event_later( + self, + ) -> None: + """ + Test to make sure we back off and don't try to fetch a missing prev_event when we + already know it has an invalid signature from checking the signatures of all of + the events in the backfill response. + """ + OTHER_USER = f"@user:{self.OTHER_SERVER_NAME}" + main_store = self.hs.get_datastores().main + + # Create the room + user_id = self.register_user("kermit", "test") + tok = self.login("kermit", "test") + room_id = self.helper.create_room_as(room_creator=user_id, tok=tok) + room_version = self.get_success(main_store.get_room_version(room_id)) + + # Allow the remote user to send state events + self.helper.send_state( + room_id, + "m.room.power_levels", + {"events_default": 0, "state_default": 0}, + tok=tok, + ) + + # Add the remote user to the room + member_event = self.get_success( + event_injection.inject_member_event(self.hs, room_id, OTHER_USER, "join") + ) + + initial_state_map = self.get_success( + main_store.get_partial_current_state_ids(room_id) + ) + + auth_event_ids = [ + initial_state_map[("m.room.create", "")], + initial_state_map[("m.room.power_levels", "")], + member_event.event_id, + ] + + # We purposely don't run `add_hashes_and_signatures_from_other_server` + # over this because we want the signature check to fail. + pulled_event_without_signatures = make_event_from_dict( + { + "type": "test_regular_type", + "room_id": room_id, + "sender": OTHER_USER, + "prev_events": [member_event.event_id], + "auth_events": auth_event_ids, + "origin_server_ts": 1, + "depth": 12, + "content": {"body": "pulled_event_without_signatures"}, + }, + room_version, + ) + + # Create a regular event that should pass except for the + # `pulled_event_without_signatures` in the `prev_event`.
+ pulled_event = make_event_from_dict( + self.add_hashes_and_signatures_from_other_server( + { + "type": "test_regular_type", + "room_id": room_id, + "sender": OTHER_USER, + "prev_events": [ + member_event.event_id, + pulled_event_without_signatures.event_id, + ], + "auth_events": auth_event_ids, + "origin_server_ts": 1, + "depth": 12, + "content": {"body": "pulled_event"}, + } + ), + room_version, + ) + + # We expect an outbound request to /backfill, so stub that out + self.mock_federation_transport_client.backfill.return_value = make_awaitable( + { + "origin": self.OTHER_SERVER_NAME, + "origin_server_ts": 123, + "pdus": [ + # This is one of the important aspects of this test: we include + # `pulled_event_without_signatures` so it fails the signature check + # when we filter the backfill response down to events which + # have valid signatures in + # `_check_sigs_and_hash_for_pulled_events_and_fetch` + pulled_event_without_signatures.get_pdu_json(), + # Then later when we process this valid signature event, when we + # fetch the missing `prev_event`s, we want to make sure that we + # back off and don't try to fetch `pulled_event_without_signatures` + # again since we know it just had an invalid signature. + pulled_event.get_pdu_json(), + ], + } + ) + + # Keep track of the count and make sure we don't make any of these requests + event_endpoint_requested_count = 0 + room_state_ids_endpoint_requested_count = 0 + room_state_endpoint_requested_count = 0 + + async def get_event( + destination: str, event_id: str, timeout: Optional[int] = None + ) -> None: + nonlocal event_endpoint_requested_count + event_endpoint_requested_count += 1 + + async def get_room_state_ids( + destination: str, room_id: str, event_id: str + ) -> None: + nonlocal room_state_ids_endpoint_requested_count + room_state_ids_endpoint_requested_count += 1 + + async def get_room_state( + room_version: RoomVersion, destination: str, room_id: str, event_id: str + ) -> None: + nonlocal room_state_endpoint_requested_count + room_state_endpoint_requested_count += 1 + + # We don't expect an outbound request to `/event`, `/state_ids`, or `/state` in + # the happy path but if the logic is sneaking around what we expect, stub that + # out so we can detect that failure + self.mock_federation_transport_client.get_event.side_effect = get_event + self.mock_federation_transport_client.get_room_state_ids.side_effect = ( + get_room_state_ids + ) + self.mock_federation_transport_client.get_room_state.side_effect = ( + get_room_state + ) + + # The function under test: try to backfill and process the pulled event + with LoggingContext("test"): + self.get_success( + self.hs.get_federation_event_handler().backfill( + self.OTHER_SERVER_NAME, + room_id, + limit=1, + extremities=["$some_extremity"], + ) + ) + + if event_endpoint_requested_count > 0: + self.fail( + "We don't expect an outbound request to /event in the happy path but if " + "the logic is sneaking around what we expect, make sure to fail the test. " + "We don't expect it because the signature failure should cause us to back off " + "and not ask about pulled_event_without_signatures=" + f"{pulled_event_without_signatures.event_id} again" + ) + + if room_state_ids_endpoint_requested_count > 0: + self.fail( + "We don't expect an outbound request to /state_ids in the happy path but if " + "the logic is sneaking around what we expect, make sure to fail the test. 
" + "We don't expect it because the signature failure should cause us to backoff " + "and not asking about pulled_event_without_signatures=" + f"{pulled_event_without_signatures.event_id} again" + ) + + if room_state_endpoint_requested_count > 0: + self.fail( + "We don't expect an outbound request to /state in the happy path but if " + "the logic is sneaking around what we expect, make sure to fail the test. " + "We don't expect it because the signature failure should cause us to backoff " + "and not asking about pulled_event_without_signatures=" + f"{pulled_event_without_signatures.event_id} again" + ) + + # Make sure we only recorded a single failure which corresponds to the signature + # failure initially in `_check_sigs_and_hash_for_pulled_events_and_fetch` before + # we process all of the pulled events. + backfill_num_attempts_for_event_without_signatures = self.get_success( + main_store.db_pool.simple_select_one_onecol( + table="event_failed_pull_attempts", + keyvalues={"event_id": pulled_event_without_signatures.event_id}, + retcol="num_attempts", + ) + ) + self.assertEqual(backfill_num_attempts_for_event_without_signatures, 1) + + # And make sure we didn't record a failure for the event that has the missing + # prev_event because we don't want to cause a cascade of failures. Not being + # able to fetch the `prev_events` just means we won't be able to de-outlier the + # pulled event. But we can still use an `outlier` in the state/auth chain for + # another event. So we shouldn't stop a downstream event from trying to pull it. + self.get_failure( + main_store.db_pool.simple_select_one_onecol( + table="event_failed_pull_attempts", + keyvalues={"event_id": pulled_event.event_id}, + retcol="num_attempts", + ), + # StoreError: 404: No row found + StoreError, + ) + def test_process_pulled_event_with_rejected_missing_state(self) -> None: """Ensure that we correctly handle pulled events with missing state containing a rejected state event diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index 59b8910907..853db930d6 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -27,6 +27,8 @@ from synapse.api.room_versions import ( RoomVersion, ) from synapse.events import _EventInternalMetadata +from synapse.rest import admin +from synapse.rest.client import login, room from synapse.server import HomeServer from synapse.storage.database import LoggingTransaction from synapse.types import JsonDict @@ -43,6 +45,12 @@ class _BackfillSetupInfo: class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main @@ -1122,6 +1130,62 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points] self.assertEqual(backfill_event_ids, ["insertion_eventA"]) + def test_get_event_ids_to_not_pull_from_backoff( + self, + ): + """ + Test to make sure only event IDs we should backoff from are returned. 
+ """ + # Create the room + user_id = self.register_user("alice", "test") + tok = self.login("alice", "test") + room_id = self.helper.create_room_as(room_creator=user_id, tok=tok) + + self.get_success( + self.store.record_event_failed_pull_attempt( + room_id, "$failed_event_id", "fake cause" + ) + ) + + event_ids_to_backoff = self.get_success( + self.store.get_event_ids_to_not_pull_from_backoff( + room_id=room_id, event_ids=["$failed_event_id", "$normal_event_id"] + ) + ) + + self.assertEqual(event_ids_to_backoff, ["$failed_event_id"]) + + def test_get_event_ids_to_not_pull_from_backoff_retry_after_backoff_duration( + self, + ): + """ + Test to make sure no event IDs are returned after the backoff duration has + elapsed. + """ + # Create the room + user_id = self.register_user("alice", "test") + tok = self.login("alice", "test") + room_id = self.helper.create_room_as(room_creator=user_id, tok=tok) + + self.get_success( + self.store.record_event_failed_pull_attempt( + room_id, "$failed_event_id", "fake cause" + ) + ) + + # Now advance time by 2 hours so we wait long enough for the single failed + # attempt (2^1 hours). + self.reactor.advance(datetime.timedelta(hours=2).total_seconds()) + + event_ids_to_backoff = self.get_success( + self.store.get_event_ids_to_not_pull_from_backoff( + room_id=room_id, event_ids=["$failed_event_id", "$normal_event_id"] + ) + ) + # Since this function only returns events we should backoff from, time has + # elapsed past the backoff range so there is no events to backoff from. + self.assertEqual(event_ids_to_backoff, []) + @attr.s class FakeEvent: -- cgit 1.5.1 From dc02d9f8c54576d4b41ce51a2704fdd43b582d66 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Tue, 18 Oct 2022 10:33:35 +0100 Subject: Avoid checking the event cache when backfilling events (#14164) --- changelog.d/14164.bugfix | 1 + synapse/handlers/federation_event.py | 47 ++++++++--- synapse/storage/databases/main/events_worker.py | 2 +- tests/handlers/test_federation.py | 105 +++++++++++++++++++++++- 4 files changed, 140 insertions(+), 15 deletions(-) create mode 100644 changelog.d/14164.bugfix (limited to 'tests/handlers') diff --git a/changelog.d/14164.bugfix b/changelog.d/14164.bugfix new file mode 100644 index 0000000000..aed4d9e386 --- /dev/null +++ b/changelog.d/14164.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in 1.30.0 where purging and rejoining a room without restarting in-between would result in a broken room. \ No newline at end of file diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 4300e8dd40..06e41b5cc0 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -798,9 +798,42 @@ class FederationEventHandler: ], ) + # Check if we already any of these have these events. + # Note: we currently make a lookup in the database directly here rather than + # checking the event cache, due to: + # https://github.com/matrix-org/synapse/issues/13476 + existing_events_map = await self._store._get_events_from_db( + [event.event_id for event in events] + ) + + new_events = [] + for event in events: + event_id = event.event_id + + # If we've already seen this event ID... + if event_id in existing_events_map: + existing_event = existing_events_map[event_id] + + # ...and the event itself was not previously stored as an outlier... + if not existing_event.event.internal_metadata.is_outlier(): + # ...then there's no need to persist it. We have it already. 
+ logger.info( + "_process_pulled_event: Ignoring received event %s which we " + "have already seen", + event.event_id, + ) + continue + + # While we have seen this event before, it was stored as an outlier. + # We'll now persist it as a non-outlier. + logger.info("De-outliering event %s", event_id) + + # Continue on with the events that are new to us. + new_events.append(event) + # We want to sort these by depth so we process them and # tell clients about them in order. - sorted_events = sorted(events, key=lambda x: x.depth) + sorted_events = sorted(new_events, key=lambda x: x.depth) for ev in sorted_events: with nested_logging_context(ev.event_id): await self._process_pulled_event(origin, ev, backfilled=backfilled) @@ -852,18 +885,6 @@ class FederationEventHandler: event_id = event.event_id - existing = await self._store.get_event( - event_id, allow_none=True, allow_rejected=True - ) - if existing: - if not existing.internal_metadata.is_outlier(): - logger.info( - "_process_pulled_event: Ignoring received event %s which we have already seen", - event_id, - ) - return - logger.info("De-outliering event %s", event_id) - try: self._sanity_check_event(event) except SynapseError as err: diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index cfd4780add..7bc7f2f33e 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -374,7 +374,7 @@ class EventsWorkerStore(SQLBaseStore): If there is a mismatch, behave as per allow_none. Returns: - The event, or None if the event was not found. + The event, or None if the event was not found and allow_none is `True`. """ if not isinstance(event_id, str): raise TypeError("Invalid event event_id %r" % (event_id,)) diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index 745750b1d7..d00c69c229 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -19,7 +19,13 @@ from unittest.mock import Mock, patch from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import EventTypes -from synapse.api.errors import AuthError, Codes, LimitExceededError, SynapseError +from synapse.api.errors import ( + AuthError, + Codes, + LimitExceededError, + NotFoundError, + SynapseError, +) from synapse.api.room_versions import RoomVersions from synapse.events import EventBase, make_event_from_dict from synapse.federation.federation_base import event_from_pdu_json @@ -28,6 +34,7 @@ from synapse.logging.context import LoggingContext, run_in_background from synapse.rest import admin from synapse.rest.client import login, room from synapse.server import HomeServer +from synapse.storage.databases.main.events_worker import EventCacheEntry from synapse.util import Clock from synapse.util.stringutils import random_string @@ -322,6 +329,102 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase): ) self.get_success(d) + def test_backfill_ignores_known_events(self) -> None: + """ + Tests that events that we already know about are ignored when backfilling. 
+ """ + # Set up users + user_id = self.register_user("kermit", "test") + tok = self.login("kermit", "test") + + other_server = "otherserver" + other_user = "@otheruser:" + other_server + + # Create a room to backfill events into + room_id = self.helper.create_room_as(room_creator=user_id, tok=tok) + room_version = self.get_success(self.store.get_room_version(room_id)) + + # Build an event to backfill + event = event_from_pdu_json( + { + "type": EventTypes.Message, + "content": {"body": "hello world", "msgtype": "m.text"}, + "room_id": room_id, + "sender": other_user, + "depth": 32, + "prev_events": [], + "auth_events": [], + "origin_server_ts": self.clock.time_msec(), + }, + room_version, + ) + + # Ensure the event is not already in the DB + self.get_failure( + self.store.get_event(event.event_id), + NotFoundError, + ) + + # Backfill the event and check that it has entered the DB. + + # We mock out the FederationClient.backfill method, to pretend that a remote + # server has returned our fake event. + federation_client_backfill_mock = Mock(return_value=make_awaitable([event])) + self.hs.get_federation_client().backfill = federation_client_backfill_mock + + # We also mock the persist method with a side effect of itself. This allows us + # to track when it has been called while preserving its function. + persist_events_and_notify_mock = Mock( + side_effect=self.hs.get_federation_event_handler().persist_events_and_notify + ) + self.hs.get_federation_event_handler().persist_events_and_notify = ( + persist_events_and_notify_mock + ) + + # Small side-tangent. We populate the event cache with the event, even though + # it is not yet in the DB. This is an invalid scenario that can currently occur + # due to not properly invalidating the event cache. + # See https://github.com/matrix-org/synapse/issues/13476. + # + # As a result, backfill should not rely on the event cache to check whether + # we already have an event in the DB. + # TODO: Remove this bit when the event cache is properly invalidated. + cache_entry = EventCacheEntry( + event=event, + redacted_event=None, + ) + self.store._get_event_cache.set_local((event.event_id,), cache_entry) + + # We now call FederationEventHandler.backfill (a separate method) to trigger + # a backfill request. It should receive the fake event. + self.get_success( + self.hs.get_federation_event_handler().backfill( + other_user, + room_id, + limit=10, + extremities=[], + ) + ) + + # Check that our fake event was persisted. + persist_events_and_notify_mock.assert_called_once() + persist_events_and_notify_mock.reset_mock() + + # Now we repeat the backfill, having the homeserver receive the fake event + # again. + self.get_success( + self.hs.get_federation_event_handler().backfill( + other_user, + room_id, + limit=10, + extremities=[], + ), + ) + + # This time, we expect no event persistence to have occurred, as we already + # have this event. + persist_events_and_notify_mock.assert_not_called() + @unittest.override_config( {"rc_invites": {"per_user": {"per_second": 0.5, "burst_count": 3}}} ) -- cgit 1.5.1 From 9192d74b0bf2f87b00d3e106a18baa9ce27acda1 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Tue, 25 Oct 2022 16:25:02 +0200 Subject: Refactor OIDC tests to better mimic an actual OIDC provider. (#13910) This implements a fake OIDC server, which intercepts calls to the HTTP client. Improves accuracy of tests by covering more internal methods. One particular example was the ID token validation, which previously mocked. 
This uncovered an incorrect dependency: Synapse actually requires at least authlib 0.15.1, not 0.14.0. --- changelog.d/13910.misc | 1 + pyproject.toml | 2 +- synapse/handlers/oidc.py | 15 +- tests/federation/test_federation_client.py | 36 +- tests/handlers/test_oidc.py | 580 +++++++++++++---------------- tests/rest/client/test_auth.py | 32 +- tests/rest/client/test_login.py | 40 +- tests/rest/client/utils.py | 136 +++---- tests/test_utils/__init__.py | 40 +- tests/test_utils/oidc.py | 325 ++++++++++++++++ 10 files changed, 747 insertions(+), 460 deletions(-) create mode 100644 changelog.d/13910.misc create mode 100644 tests/test_utils/oidc.py (limited to 'tests/handlers') diff --git a/changelog.d/13910.misc b/changelog.d/13910.misc new file mode 100644 index 0000000000..e906952aab --- /dev/null +++ b/changelog.d/13910.misc @@ -0,0 +1 @@ +Refactor OIDC tests to better mimic an actual OIDC provider. diff --git a/pyproject.toml b/pyproject.toml index 6ebac41ed1..7e0feb75aa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -192,7 +192,7 @@ psycopg2 = { version = ">=2.8", markers = "platform_python_implementation != 'Py psycopg2cffi = { version = ">=2.8", markers = "platform_python_implementation == 'PyPy'", optional = true } psycopg2cffi-compat = { version = "==1.1", markers = "platform_python_implementation == 'PyPy'", optional = true } pysaml2 = { version = ">=4.5.0", optional = true } -authlib = { version = ">=0.14.0", optional = true } +authlib = { version = ">=0.15.1", optional = true } # systemd-python is necessary for logging to the systemd journal via # `systemd.journal.JournalHandler`, as is documented in # `contrib/systemd/log_config.yaml`. diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py index d7a8226900..9759daf043 100644 --- a/synapse/handlers/oidc.py +++ b/synapse/handlers/oidc.py @@ -275,6 +275,7 @@ class OidcProvider: provider: OidcProviderConfig, ): self._store = hs.get_datastores().main + self._clock = hs.get_clock() self._macaroon_generaton = macaroon_generator @@ -673,6 +674,13 @@ class OidcProvider: Returns: The decoded claims in the ID token. """ + id_token = token.get("id_token") + logger.debug("Attempting to decode JWT id_token %r", id_token) + + # That has theoretically already been checked by the caller, so even though + # assertions are not enabled in production, this is mainly here to appease mypy + assert id_token is not None + metadata = await self.load_metadata() claims_params = { "nonce": nonce, @@ -688,9 +696,6 @@ class OidcProvider: claim_options = {"iss": {"values": [metadata["issuer"]]}} - id_token = token["id_token"] - logger.debug("Attempting to decode JWT id_token %r", id_token) - # Try to decode the keys in cache first, then retry by forcing the keys # to be reloaded jwk_set = await self.load_jwks() @@ -715,7 +720,9 @@ class OidcProvider: logger.debug("Decoded id_token JWT %r; validating", claims) - claims.validate(leeway=120) # allows 2 min of clock skew + claims.validate( + now=self._clock.time(), leeway=120 + ) # allows 2 min of clock skew return claims diff --git a/tests/federation/test_federation_client.py b/tests/federation/test_federation_client.py index a538215931..51d3bb8fff 100644 --- a/tests/federation/test_federation_client.py +++ b/tests/federation/test_federation_client.py @@ -12,13 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import json from unittest import mock import twisted.web.client from twisted.internet import defer -from twisted.internet.protocol import Protocol -from twisted.python.failure import Failure from twisted.test.proto_helpers import MemoryReactor from synapse.api.room_versions import RoomVersions @@ -26,10 +23,9 @@ from synapse.events import EventBase from synapse.rest import admin from synapse.rest.client import login, room from synapse.server import HomeServer -from synapse.types import JsonDict from synapse.util import Clock -from tests.test_utils import event_injection +from tests.test_utils import FakeResponse, event_injection from tests.unittest import FederatingHomeserverTestCase @@ -98,8 +94,8 @@ class FederationClientTest(FederatingHomeserverTestCase): # mock up the response, and have the agent return it self._mock_agent.request.side_effect = lambda *args, **kwargs: defer.succeed( - _mock_response( - { + FakeResponse.json( + payload={ "pdus": [ create_event_dict, member_event_dict, @@ -208,8 +204,8 @@ class FederationClientTest(FederatingHomeserverTestCase): # mock up the response, and have the agent return it self._mock_agent.request.side_effect = lambda *args, **kwargs: defer.succeed( - _mock_response( - { + FakeResponse.json( + payload={ "origin": "yet.another.server", "origin_server_ts": 900, "pdus": [ @@ -269,8 +265,8 @@ class FederationClientTest(FederatingHomeserverTestCase): # We expect an outbound request to /backfill, so stub that out self._mock_agent.request.side_effect = lambda *args, **kwargs: defer.succeed( - _mock_response( - { + FakeResponse.json( + payload={ "origin": "yet.another.server", "origin_server_ts": 900, # Mimic the other server returning our new `pulled_event` @@ -305,21 +301,3 @@ class FederationClientTest(FederatingHomeserverTestCase): # This is 2 because it failed once from `self.OTHER_SERVER_NAME` and the # other from "yet.another.server" self.assertEqual(backfill_num_attempts, 2) - - -def _mock_response(resp: JsonDict): - body = json.dumps(resp).encode("utf-8") - - def deliver_body(p: Protocol): - p.dataReceived(body) - p.connectionLost(Failure(twisted.web.client.ResponseDone())) - - response = mock.Mock( - code=200, - phrase=b"OK", - headers=twisted.web.client.Headers({"content-Type": ["application/json"]}), - length=len(body), - deliverBody=deliver_body, - ) - mock.seal(response) - return response diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index e6cd3af7b7..5955410524 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -11,9 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import json import os -from typing import Any, Dict +from typing import Any, Dict, Tuple from unittest.mock import ANY, Mock, patch from urllib.parse import parse_qs, urlparse @@ -22,12 +21,15 @@ import pymacaroons from twisted.test.proto_helpers import MemoryReactor from synapse.handlers.sso import MappingException +from synapse.http.site import SynapseRequest from synapse.server import HomeServer -from synapse.types import JsonDict, UserID +from synapse.types import UserID from synapse.util import Clock -from synapse.util.macaroons import OidcSessionData, get_value_from_macaroon +from synapse.util.macaroons import get_value_from_macaroon +from synapse.util.stringutils import random_string from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock +from tests.test_utils.oidc import FakeAuthorizationGrant, FakeOidcServer from tests.unittest import HomeserverTestCase, override_config try: @@ -46,12 +48,6 @@ BASE_URL = "https://synapse/" CALLBACK_URL = BASE_URL + "_synapse/client/oidc/callback" SCOPES = ["openid"] -AUTHORIZATION_ENDPOINT = ISSUER + "authorize" -TOKEN_ENDPOINT = ISSUER + "token" -USERINFO_ENDPOINT = ISSUER + "userinfo" -WELL_KNOWN = ISSUER + ".well-known/openid-configuration" -JWKS_URI = ISSUER + ".well-known/jwks.json" - # config for common cases DEFAULT_CONFIG = { "enabled": True, @@ -66,9 +62,9 @@ DEFAULT_CONFIG = { EXPLICIT_ENDPOINT_CONFIG = { **DEFAULT_CONFIG, "discover": False, - "authorization_endpoint": AUTHORIZATION_ENDPOINT, - "token_endpoint": TOKEN_ENDPOINT, - "jwks_uri": JWKS_URI, + "authorization_endpoint": ISSUER + "authorize", + "token_endpoint": ISSUER + "token", + "jwks_uri": ISSUER + "jwks", } @@ -102,27 +98,6 @@ class TestMappingProviderFailures(TestMappingProvider): } -async def get_json(url: str) -> JsonDict: - # Mock get_json calls to handle jwks & oidc discovery endpoints - if url == WELL_KNOWN: - # Minimal discovery document, as defined in OpenID.Discovery - # https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata - return { - "issuer": ISSUER, - "authorization_endpoint": AUTHORIZATION_ENDPOINT, - "token_endpoint": TOKEN_ENDPOINT, - "jwks_uri": JWKS_URI, - "userinfo_endpoint": USERINFO_ENDPOINT, - "response_types_supported": ["code"], - "subject_types_supported": ["public"], - "id_token_signing_alg_values_supported": ["RS256"], - } - elif url == JWKS_URI: - return {"keys": []} - - return {} - - def _key_file_path() -> str: """path to a file containing the private half of a test key""" @@ -159,11 +134,11 @@ class OidcHandlerTestCase(HomeserverTestCase): return config def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - self.http_client = Mock(spec=["get_json"]) - self.http_client.get_json.side_effect = get_json - self.http_client.user_agent = b"Synapse Test" + self.fake_server = FakeOidcServer(clock=clock, issuer=ISSUER) - hs = self.setup_test_homeserver(proxied_http_client=self.http_client) + hs = self.setup_test_homeserver() + self.hs_patcher = self.fake_server.patch_homeserver(hs=hs) + self.hs_patcher.start() self.handler = hs.get_oidc_handler() self.provider = self.handler._providers["oidc"] @@ -175,18 +150,51 @@ class OidcHandlerTestCase(HomeserverTestCase): # Reduce the number of attempts when generating MXIDs. sso_handler._MAP_USERNAME_RETRIES = 3 + auth_handler = hs.get_auth_handler() + # Mock the complete SSO login method. 
+ self.complete_sso_login = simple_async_mock() + auth_handler.complete_sso_login = self.complete_sso_login # type: ignore[assignment] + return hs + def tearDown(self) -> None: + self.hs_patcher.stop() + return super().tearDown() + + def reset_mocks(self): + """Reset all the Mocks.""" + self.fake_server.reset_mocks() + self.render_error.reset_mock() + self.complete_sso_login.reset_mock() + def metadata_edit(self, values): """Modify the result that will be returned by the well-known query""" - async def patched_get_json(uri): - res = await get_json(uri) - if uri == WELL_KNOWN: - res.update(values) - return res + metadata = self.fake_server.get_metadata() + metadata.update(values) + return patch.object(self.fake_server, "get_metadata", return_value=metadata) - return patch.object(self.http_client, "get_json", patched_get_json) + def start_authorization( + self, + userinfo: dict, + client_redirect_url: str = "http://client/redirect", + scope: str = "openid", + with_sid: bool = False, + ) -> Tuple[SynapseRequest, FakeAuthorizationGrant]: + """Start an authorization request, and get the callback request back.""" + nonce = random_string(10) + state = random_string(10) + + code, grant = self.fake_server.start_authorization( + userinfo=userinfo, + scope=scope, + client_id=self.provider._client_auth.client_id, + redirect_uri=self.provider._callback_url, + nonce=nonce, + with_sid=with_sid, + ) + session = self._generate_oidc_session_token(state, nonce, client_redirect_url) + return _build_callback_request(code, state, session), grant def assertRenderedError(self, error, error_description=None): self.render_error.assert_called_once() @@ -210,52 +218,54 @@ class OidcHandlerTestCase(HomeserverTestCase): """The handler should discover the endpoints from OIDC discovery document.""" # This would throw if some metadata were invalid metadata = self.get_success(self.provider.load_metadata()) - self.http_client.get_json.assert_called_once_with(WELL_KNOWN) + self.fake_server.get_metadata_handler.assert_called_once() - self.assertEqual(metadata.issuer, ISSUER) - self.assertEqual(metadata.authorization_endpoint, AUTHORIZATION_ENDPOINT) - self.assertEqual(metadata.token_endpoint, TOKEN_ENDPOINT) - self.assertEqual(metadata.jwks_uri, JWKS_URI) - # FIXME: it seems like authlib does not have that defined in its metadata models - # self.assertEqual(metadata.userinfo_endpoint, USERINFO_ENDPOINT) + self.assertEqual(metadata.issuer, self.fake_server.issuer) + self.assertEqual( + metadata.authorization_endpoint, + self.fake_server.authorization_endpoint, + ) + self.assertEqual(metadata.token_endpoint, self.fake_server.token_endpoint) + self.assertEqual(metadata.jwks_uri, self.fake_server.jwks_uri) + # It seems like authlib does not have that defined in its metadata models + self.assertEqual( + metadata.get("userinfo_endpoint"), + self.fake_server.userinfo_endpoint, + ) # subsequent calls should be cached - self.http_client.reset_mock() + self.reset_mocks() self.get_success(self.provider.load_metadata()) - self.http_client.get_json.assert_not_called() + self.fake_server.get_metadata_handler.assert_not_called() @override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG}) def test_no_discovery(self) -> None: """When discovery is disabled, it should not try to load from discovery document.""" self.get_success(self.provider.load_metadata()) - self.http_client.get_json.assert_not_called() + self.fake_server.get_metadata_handler.assert_not_called() - @override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG}) + 
@override_config({"oidc_config": DEFAULT_CONFIG}) def test_load_jwks(self) -> None: """JWKS loading is done once (then cached) if used.""" jwks = self.get_success(self.provider.load_jwks()) - self.http_client.get_json.assert_called_once_with(JWKS_URI) - self.assertEqual(jwks, {"keys": []}) + self.fake_server.get_jwks_handler.assert_called_once() + self.assertEqual(jwks, self.fake_server.get_jwks()) # subsequent calls should be cached… - self.http_client.reset_mock() + self.reset_mocks() self.get_success(self.provider.load_jwks()) - self.http_client.get_json.assert_not_called() + self.fake_server.get_jwks_handler.assert_not_called() # …unless forced - self.http_client.reset_mock() + self.reset_mocks() self.get_success(self.provider.load_jwks(force=True)) - self.http_client.get_json.assert_called_once_with(JWKS_URI) + self.fake_server.get_jwks_handler.assert_called_once() - # Throw if the JWKS uri is missing - original = self.provider.load_metadata - - async def patched_load_metadata(): - m = (await original()).copy() - m.update({"jwks_uri": None}) - return m - - with patch.object(self.provider, "load_metadata", patched_load_metadata): + with self.metadata_edit({"jwks_uri": None}): + # If we don't do this, the load_metadata call will throw because of the + # missing jwks_uri + self.provider._user_profile_method = "userinfo_endpoint" + self.get_success(self.provider.load_metadata(force=True)) self.get_failure(self.provider.load_jwks(force=True), RuntimeError) @override_config({"oidc_config": DEFAULT_CONFIG}) @@ -359,7 +369,7 @@ class OidcHandlerTestCase(HomeserverTestCase): self.provider.handle_redirect_request(req, b"http://client/redirect") ) ) - auth_endpoint = urlparse(AUTHORIZATION_ENDPOINT) + auth_endpoint = urlparse(self.fake_server.authorization_endpoint) self.assertEqual(url.scheme, auth_endpoint.scheme) self.assertEqual(url.netloc, auth_endpoint.netloc) @@ -424,48 +434,34 @@ class OidcHandlerTestCase(HomeserverTestCase): with self.assertRaises(AttributeError): _ = mapping_provider.get_extra_attributes - token = { - "type": "bearer", - "id_token": "id_token", - "access_token": "access_token", - } username = "bar" userinfo = { "sub": "foo", "username": username, } expected_user_id = "@%s:%s" % (username, self.hs.hostname) - self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment] - self.provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment] - self.provider._fetch_userinfo = simple_async_mock(return_value=userinfo) # type: ignore[assignment] - auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() - code = "code" - state = "state" - nonce = "nonce" client_redirect_url = "http://client/redirect" - ip_address = "10.0.0.1" - session = self._generate_oidc_session_token(state, nonce, client_redirect_url) - request = _build_callback_request(code, state, session, ip_address=ip_address) - + request, _ = self.start_authorization( + userinfo, client_redirect_url=client_redirect_url + ) self.get_success(self.handler.handle_oidc_callback(request)) - auth_handler.complete_sso_login.assert_called_once_with( + self.complete_sso_login.assert_called_once_with( expected_user_id, - "oidc", + self.provider.idp_id, request, client_redirect_url, None, new_user=True, auth_provider_session_id=None, ) - self.provider._exchange_code.assert_called_once_with(code) - self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce) - self.provider._fetch_userinfo.assert_not_called() + 
self.fake_server.post_token_handler.assert_called_once() + self.fake_server.get_userinfo_handler.assert_not_called() self.render_error.assert_not_called() # Handle mapping errors + request, _ = self.start_authorization(userinfo) with patch.object( self.provider, "_remote_id_from_userinfo", @@ -475,81 +471,63 @@ class OidcHandlerTestCase(HomeserverTestCase): self.assertRenderedError("mapping_error") # Handle ID token errors - self.provider._parse_id_token = simple_async_mock(raises=Exception()) # type: ignore[assignment] - self.get_success(self.handler.handle_oidc_callback(request)) + request, _ = self.start_authorization(userinfo) + with self.fake_server.id_token_override({"iss": "https://bad.issuer/"}): + self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("invalid_token") - auth_handler.complete_sso_login.reset_mock() - self.provider._exchange_code.reset_mock() - self.provider._parse_id_token.reset_mock() - self.provider._fetch_userinfo.reset_mock() + self.reset_mocks() # With userinfo fetching self.provider._user_profile_method = "userinfo_endpoint" - token = { - "type": "bearer", - "access_token": "access_token", - } - self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment] + # Without the "openid" scope, the FakeProvider does not generate an id_token + request, _ = self.start_authorization(userinfo, scope="") self.get_success(self.handler.handle_oidc_callback(request)) - auth_handler.complete_sso_login.assert_called_once_with( + self.complete_sso_login.assert_called_once_with( expected_user_id, - "oidc", + self.provider.idp_id, request, - client_redirect_url, + ANY, None, new_user=False, auth_provider_session_id=None, ) - self.provider._exchange_code.assert_called_once_with(code) - self.provider._parse_id_token.assert_not_called() - self.provider._fetch_userinfo.assert_called_once_with(token) + self.fake_server.post_token_handler.assert_called_once() + self.fake_server.get_userinfo_handler.assert_called_once() self.render_error.assert_not_called() + self.reset_mocks() + # With an ID token, userinfo fetching and sid in the ID token self.provider._user_profile_method = "userinfo_endpoint" - token = { - "type": "bearer", - "access_token": "access_token", - "id_token": "id_token", - } - id_token = { - "sid": "abcdefgh", - } - self.provider._parse_id_token = simple_async_mock(return_value=id_token) # type: ignore[assignment] - self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment] - auth_handler.complete_sso_login.reset_mock() - self.provider._fetch_userinfo.reset_mock() + request, grant = self.start_authorization(userinfo, with_sid=True) + self.assertIsNotNone(grant.sid) self.get_success(self.handler.handle_oidc_callback(request)) - auth_handler.complete_sso_login.assert_called_once_with( + self.complete_sso_login.assert_called_once_with( expected_user_id, - "oidc", + self.provider.idp_id, request, - client_redirect_url, + ANY, None, new_user=False, - auth_provider_session_id=id_token["sid"], + auth_provider_session_id=grant.sid, ) - self.provider._exchange_code.assert_called_once_with(code) - self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce) - self.provider._fetch_userinfo.assert_called_once_with(token) + self.fake_server.post_token_handler.assert_called_once() + self.fake_server.get_userinfo_handler.assert_called_once() self.render_error.assert_not_called() # Handle userinfo fetching error - self.provider._fetch_userinfo = 
simple_async_mock(raises=Exception()) # type: ignore[assignment] - self.get_success(self.handler.handle_oidc_callback(request)) + request, _ = self.start_authorization(userinfo) + with self.fake_server.buggy_endpoint(userinfo=True): + self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("fetch_error") - # Handle code exchange failure - from synapse.handlers.oidc import OidcError - - self.provider._exchange_code = simple_async_mock( # type: ignore[assignment] - raises=OidcError("invalid_request") - ) - self.get_success(self.handler.handle_oidc_callback(request)) - self.assertRenderedError("invalid_request") + request, _ = self.start_authorization(userinfo) + with self.fake_server.buggy_endpoint(token=True): + self.get_success(self.handler.handle_oidc_callback(request)) + self.assertRenderedError("server_error") @override_config({"oidc_config": DEFAULT_CONFIG}) def test_callback_session(self) -> None: @@ -599,18 +577,22 @@ class OidcHandlerTestCase(HomeserverTestCase): ) def test_exchange_code(self) -> None: """Code exchange behaves correctly and handles various error scenarios.""" - token = {"type": "bearer"} - token_json = json.dumps(token).encode("utf-8") - self.http_client.request = simple_async_mock( - return_value=FakeResponse(code=200, phrase=b"OK", body=token_json) + token = { + "type": "Bearer", + "access_token": "aabbcc", + } + + self.fake_server.post_token_handler.side_effect = None + self.fake_server.post_token_handler.return_value = FakeResponse.json( + payload=token ) code = "code" ret = self.get_success(self.provider._exchange_code(code)) - kwargs = self.http_client.request.call_args[1] + kwargs = self.fake_server.request.call_args[1] self.assertEqual(ret, token) self.assertEqual(kwargs["method"], "POST") - self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT) + self.assertEqual(kwargs["uri"], self.fake_server.token_endpoint) args = parse_qs(kwargs["data"].decode("utf-8")) self.assertEqual(args["grant_type"], ["authorization_code"]) @@ -620,12 +602,8 @@ class OidcHandlerTestCase(HomeserverTestCase): self.assertEqual(args["redirect_uri"], [CALLBACK_URL]) # Test error handling - self.http_client.request = simple_async_mock( - return_value=FakeResponse( - code=400, - phrase=b"Bad Request", - body=b'{"error": "foo", "error_description": "bar"}', - ) + self.fake_server.post_token_handler.return_value = FakeResponse.json( + code=400, payload={"error": "foo", "error_description": "bar"} ) from synapse.handlers.oidc import OidcError @@ -634,46 +612,30 @@ class OidcHandlerTestCase(HomeserverTestCase): self.assertEqual(exc.value.error_description, "bar") # Internal server error with no JSON body - self.http_client.request = simple_async_mock( - return_value=FakeResponse( - code=500, - phrase=b"Internal Server Error", - body=b"Not JSON", - ) + self.fake_server.post_token_handler.return_value = FakeResponse( + code=500, body=b"Not JSON" ) exc = self.get_failure(self.provider._exchange_code(code), OidcError) self.assertEqual(exc.value.error, "server_error") # Internal server error with JSON body - self.http_client.request = simple_async_mock( - return_value=FakeResponse( - code=500, - phrase=b"Internal Server Error", - body=b'{"error": "internal_server_error"}', - ) + self.fake_server.post_token_handler.return_value = FakeResponse.json( + code=500, payload={"error": "internal_server_error"} ) exc = self.get_failure(self.provider._exchange_code(code), OidcError) self.assertEqual(exc.value.error, "internal_server_error") # 4xx error without "error" field - 
self.http_client.request = simple_async_mock( - return_value=FakeResponse( - code=400, - phrase=b"Bad request", - body=b"{}", - ) + self.fake_server.post_token_handler.return_value = FakeResponse.json( + code=400, payload={} ) exc = self.get_failure(self.provider._exchange_code(code), OidcError) self.assertEqual(exc.value.error, "server_error") # 2xx error with "error" field - self.http_client.request = simple_async_mock( - return_value=FakeResponse( - code=200, - phrase=b"OK", - body=b'{"error": "some_error"}', - ) + self.fake_server.post_token_handler.return_value = FakeResponse.json( + code=200, payload={"error": "some_error"} ) exc = self.get_failure(self.provider._exchange_code(code), OidcError) self.assertEqual(exc.value.error, "some_error") @@ -697,11 +659,14 @@ class OidcHandlerTestCase(HomeserverTestCase): """Test that code exchange works with a JWK client secret.""" from authlib.jose import jwt - token = {"type": "bearer"} - self.http_client.request = simple_async_mock( - return_value=FakeResponse( - code=200, phrase=b"OK", body=json.dumps(token).encode("utf-8") - ) + token = { + "type": "Bearer", + "access_token": "aabbcc", + } + + self.fake_server.post_token_handler.side_effect = None + self.fake_server.post_token_handler.return_value = FakeResponse.json( + payload=token ) code = "code" @@ -714,9 +679,9 @@ class OidcHandlerTestCase(HomeserverTestCase): self.assertEqual(ret, token) # the request should have hit the token endpoint - kwargs = self.http_client.request.call_args[1] + kwargs = self.fake_server.request.call_args[1] self.assertEqual(kwargs["method"], "POST") - self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT) + self.assertEqual(kwargs["uri"], self.fake_server.token_endpoint) # the client secret provided to the token endpoint should be a JWT which can be # checked with the public key @@ -750,11 +715,14 @@ class OidcHandlerTestCase(HomeserverTestCase): ) def test_exchange_code_no_auth(self) -> None: """Test that code exchange works with no client secret.""" - token = {"type": "bearer"} - self.http_client.request = simple_async_mock( - return_value=FakeResponse( - code=200, phrase=b"OK", body=json.dumps(token).encode("utf-8") - ) + token = { + "type": "Bearer", + "access_token": "aabbcc", + } + + self.fake_server.post_token_handler.side_effect = None + self.fake_server.post_token_handler.return_value = FakeResponse.json( + payload=token ) code = "code" ret = self.get_success(self.provider._exchange_code(code)) self.assertEqual(ret, token) # the request should have hit the token endpoint - kwargs = self.http_client.request.call_args[1] + kwargs = self.fake_server.request.call_args[1] self.assertEqual(kwargs["method"], "POST") - self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT) + self.assertEqual(kwargs["uri"], self.fake_server.token_endpoint) # check the POSTed data args = parse_qs(kwargs["data"].decode("utf-8")) @@ -787,37 +755,19 @@ class OidcHandlerTestCase(HomeserverTestCase): """ Login while using a mapping provider that implements get_extra_attributes.
""" - token = { - "type": "bearer", - "id_token": "id_token", - "access_token": "access_token", - } userinfo = { "sub": "foo", "username": "foo", "phone": "1234567", } - self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment] - self.provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment] - auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() - - state = "state" - client_redirect_url = "http://client/redirect" - session = self._generate_oidc_session_token( - state=state, - nonce="nonce", - client_redirect_url=client_redirect_url, - ) - request = _build_callback_request("code", state, session) - + request, _ = self.start_authorization(userinfo) self.get_success(self.handler.handle_oidc_callback(request)) - auth_handler.complete_sso_login.assert_called_once_with( + self.complete_sso_login.assert_called_once_with( "@foo:test", - "oidc", + self.provider.idp_id, request, - client_redirect_url, + ANY, {"phone": "1234567"}, new_user=True, auth_provider_session_id=None, @@ -826,41 +776,40 @@ class OidcHandlerTestCase(HomeserverTestCase): @override_config({"oidc_config": DEFAULT_CONFIG}) def test_map_userinfo_to_user(self) -> None: """Ensure that mapping the userinfo returned from a provider to an MXID works properly.""" - auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() - userinfo: dict = { "sub": "test_user", "username": "test_user", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_called_once_with( + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_called_once_with( "@test_user:test", - "oidc", - ANY, + self.provider.idp_id, + request, ANY, None, new_user=True, auth_provider_session_id=None, ) - auth_handler.complete_sso_login.reset_mock() + self.reset_mocks() # Some providers return an integer ID. userinfo = { "sub": 1234, "username": "test_user_2", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_called_once_with( + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_called_once_with( "@test_user_2:test", - "oidc", - ANY, + self.provider.idp_id, + request, ANY, None, new_user=True, auth_provider_session_id=None, ) - auth_handler.complete_sso_login.reset_mock() + self.reset_mocks() # Test if the mxid is already taken store = self.hs.get_datastores().main @@ -869,8 +818,9 @@ class OidcHandlerTestCase(HomeserverTestCase): store.register_user(user_id=user3.to_string(), password_hash=None) ) userinfo = {"sub": "test3", "username": "test_user_3"} - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_not_called() + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_not_called() self.assertRenderedError( "mapping_error", "Mapping provider does not support de-duplicating Matrix IDs", @@ -885,38 +835,37 @@ class OidcHandlerTestCase(HomeserverTestCase): store.register_user(user_id=user.to_string(), password_hash=None) ) - auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() - # Map a user via SSO. 
userinfo = { "sub": "test", "username": "test_user", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_called_once_with( + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_called_once_with( user.to_string(), - "oidc", - ANY, + self.provider.idp_id, + request, ANY, None, new_user=False, auth_provider_session_id=None, ) - auth_handler.complete_sso_login.reset_mock() + self.reset_mocks() # Subsequent calls should map to the same mxid. - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_called_once_with( + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_called_once_with( user.to_string(), - "oidc", - ANY, + self.provider.idp_id, + request, ANY, None, new_user=False, auth_provider_session_id=None, ) - auth_handler.complete_sso_login.reset_mock() + self.reset_mocks() # Note that a second SSO user can be mapped to the same Matrix ID. (This # requires a unique sub, but something that maps to the same matrix ID, @@ -927,17 +876,18 @@ class OidcHandlerTestCase(HomeserverTestCase): "sub": "test1", "username": "test_user", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_called_once_with( + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_called_once_with( user.to_string(), - "oidc", - ANY, + self.provider.idp_id, + request, ANY, None, new_user=False, auth_provider_session_id=None, ) - auth_handler.complete_sso_login.reset_mock() + self.reset_mocks() # Register some non-exact matching cases. 
user2 = UserID.from_string("@TEST_user_2:test") @@ -954,8 +904,9 @@ class OidcHandlerTestCase(HomeserverTestCase): "sub": "test2", "username": "TEST_USER_2", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_not_called() + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_not_called() args = self.assertRenderedError("mapping_error") self.assertTrue( args[2].startswith( @@ -969,11 +920,12 @@ class OidcHandlerTestCase(HomeserverTestCase): store.register_user(user_id=user2.to_string(), password_hash=None) ) - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_called_once_with( + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_called_once_with( "@TEST_USER_2:test", - "oidc", - ANY, + self.provider.idp_id, + request, ANY, None, new_user=False, @@ -983,9 +935,9 @@ class OidcHandlerTestCase(HomeserverTestCase): @override_config({"oidc_config": DEFAULT_CONFIG}) def test_map_userinfo_to_invalid_localpart(self) -> None: """If the mapping provider generates an invalid localpart it should be rejected.""" - self.get_success( - _make_callback_with_userinfo(self.hs, {"sub": "test2", "username": "föö"}) - ) + userinfo = {"sub": "test2", "username": "föö"} + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("mapping_error", "localpart is invalid: föö") @override_config( @@ -1000,9 +952,6 @@ class OidcHandlerTestCase(HomeserverTestCase): ) def test_map_userinfo_to_user_retries(self) -> None: """The mapping provider can retry generating an MXID if the MXID is already in use.""" - auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() - store = self.hs.get_datastores().main self.get_success( store.register_user(user_id="@test_user:test", password_hash=None) @@ -1011,19 +960,20 @@ class OidcHandlerTestCase(HomeserverTestCase): "sub": "test", "username": "test_user", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) # test_user is already taken, so test_user1 gets registered instead. - auth_handler.complete_sso_login.assert_called_once_with( + self.complete_sso_login.assert_called_once_with( "@test_user1:test", - "oidc", - ANY, + self.provider.idp_id, + request, ANY, None, new_user=True, auth_provider_session_id=None, ) - auth_handler.complete_sso_login.reset_mock() + self.reset_mocks() # Register all of the potential mxids for a particular OIDC username. 
self.get_success( @@ -1039,8 +989,9 @@ class OidcHandlerTestCase(HomeserverTestCase): "sub": "tester", "username": "tester", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_not_called() + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_not_called() self.assertRenderedError( "mapping_error", "Unable to generate a Matrix ID from the SSO response" ) @@ -1052,7 +1003,8 @@ class OidcHandlerTestCase(HomeserverTestCase): "sub": "tester", "username": "", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("mapping_error", "localpart is invalid: ") @override_config( @@ -1071,7 +1023,8 @@ class OidcHandlerTestCase(HomeserverTestCase): "sub": "tester", "username": None, } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("mapping_error", "localpart is invalid: ") @override_config( @@ -1084,16 +1037,14 @@ class OidcHandlerTestCase(HomeserverTestCase): ) def test_attribute_requirements(self) -> None: """The required attributes must be met from the OIDC userinfo response.""" - auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() - # userinfo lacking "test": "foobar" attribute should fail. userinfo = { "sub": "tester", "username": "tester", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_not_called() + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_not_called() # userinfo with "test": "foobar" attribute should succeed. userinfo = { @@ -1101,13 +1052,14 @@ class OidcHandlerTestCase(HomeserverTestCase): "username": "tester", "test": "foobar", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) # check that the auth handler got called as expected - auth_handler.complete_sso_login.assert_called_once_with( + self.complete_sso_login.assert_called_once_with( "@tester:test", - "oidc", - ANY, + self.provider.idp_id, + request, ANY, None, new_user=True, @@ -1124,21 +1076,20 @@ class OidcHandlerTestCase(HomeserverTestCase): ) def test_attribute_requirements_contains(self) -> None: """Test that auth succeeds if userinfo attribute CONTAINS required value""" - auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() # userinfo with "test": ["foobar", "foo", "bar"] attribute should succeed. 
userinfo = { "sub": "tester", "username": "tester", "test": ["foobar", "foo", "bar"], } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) # check that the auth handler got called as expected - auth_handler.complete_sso_login.assert_called_once_with( + self.complete_sso_login.assert_called_once_with( "@tester:test", - "oidc", - ANY, + self.provider.idp_id, + request, ANY, None, new_user=True, @@ -1158,16 +1109,15 @@ class OidcHandlerTestCase(HomeserverTestCase): Test that auth fails if attributes exist but don't match, or are non-string values. """ - auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() # userinfo with "test": "not_foobar" attribute should fail userinfo: dict = { "sub": "tester", "username": "tester", "test": "not_foobar", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_not_called() + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_not_called() # userinfo with "test": ["foo", "bar"] attribute should fail userinfo = { @@ -1175,8 +1125,9 @@ class OidcHandlerTestCase(HomeserverTestCase): "username": "tester", "test": ["foo", "bar"], } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_not_called() + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_not_called() # userinfo with "test": False attribute should fail # this is largely just to ensure we don't crash here @@ -1185,8 +1136,9 @@ class OidcHandlerTestCase(HomeserverTestCase): "username": "tester", "test": False, } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_not_called() + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_not_called() # userinfo with "test": None attribute should fail # a value of None breaks the OIDC spec, but it's important to not crash here @@ -1195,8 +1147,9 @@ class OidcHandlerTestCase(HomeserverTestCase): "username": "tester", "test": None, } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_not_called() + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_not_called() # userinfo with "test": 1 attribute should fail # this is largely just to ensure we don't crash here @@ -1205,8 +1158,9 @@ class OidcHandlerTestCase(HomeserverTestCase): "username": "tester", "test": 1, } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_not_called() + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_not_called() # userinfo with "test": 3.14 attribute should fail # this is largely just to ensure we don't crash here @@ -1215,8 +1169,9 @@ class OidcHandlerTestCase(HomeserverTestCase): "username": "tester", "test": 3.14, } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_not_called() + request, _ = 
self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_not_called() def _generate_oidc_session_token( self, @@ -1230,7 +1185,7 @@ class OidcHandlerTestCase(HomeserverTestCase): return self.handler._macaroon_generator.generate_oidc_session_token( state=state, session_data=OidcSessionData( - idp_id="oidc", + idp_id=self.provider.idp_id, nonce=nonce, client_redirect_url=client_redirect_url, ui_auth_session_id=ui_auth_session_id, @@ -1238,41 +1193,6 @@ class OidcHandlerTestCase(HomeserverTestCase): ) -async def _make_callback_with_userinfo( - hs: HomeServer, userinfo: dict, client_redirect_url: str = "http://client/redirect" -) -> None: - """Mock up an OIDC callback with the given userinfo dict - - We'll pull out the OIDC handler from the homeserver, stub out a couple of methods, - and poke in the userinfo dict as if it were the response to an OIDC userinfo call. - - Args: - hs: the HomeServer impl to send the callback to. - userinfo: the OIDC userinfo dict - client_redirect_url: the URL to redirect to on success. - """ - - handler = hs.get_oidc_handler() - provider = handler._providers["oidc"] - provider._exchange_code = simple_async_mock(return_value={"id_token": ""}) # type: ignore[assignment] - provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment] - provider._fetch_userinfo = simple_async_mock(return_value=userinfo) # type: ignore[assignment] - - state = "state" - session = handler._macaroon_generator.generate_oidc_session_token( - state=state, - session_data=OidcSessionData( - idp_id="oidc", - nonce="nonce", - client_redirect_url=client_redirect_url, - ui_auth_session_id="", - ), - ) - request = _build_callback_request("code", state, session) - - await handler.handle_oidc_callback(request) - - def _build_callback_request( code: str, state: str, diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py index 090cef5216..ebf653d018 100644 --- a/tests/rest/client/test_auth.py +++ b/tests/rest/client/test_auth.py @@ -465,9 +465,11 @@ class UIAuthTests(unittest.HomeserverTestCase): * checking that the original operation succeeds """ + fake_oidc_server = self.helper.fake_oidc_server() + # log the user in remote_user_id = UserID.from_string(self.user).localpart - login_resp = self.helper.login_via_oidc(remote_user_id) + login_resp, _ = self.helper.login_via_oidc(fake_oidc_server, remote_user_id) self.assertEqual(login_resp["user_id"], self.user) # initiate a UI Auth process by attempting to delete the device @@ -481,8 +483,8 @@ class UIAuthTests(unittest.HomeserverTestCase): # run the UIA-via-SSO flow session_id = channel.json_body["session"] - channel = self.helper.auth_via_oidc( - {"sub": remote_user_id}, ui_auth_session_id=session_id + channel, _ = self.helper.auth_via_oidc( + fake_oidc_server, {"sub": remote_user_id}, ui_auth_session_id=session_id ) # that should serve a confirmation page @@ -499,7 +501,8 @@ class UIAuthTests(unittest.HomeserverTestCase): @skip_unless(HAS_OIDC, "requires OIDC") @override_config({"oidc_config": TEST_OIDC_CONFIG}) def test_does_not_offer_password_for_sso_user(self) -> None: - login_resp = self.helper.login_via_oidc("username") + fake_oidc_server = self.helper.fake_oidc_server() + login_resp, _ = self.helper.login_via_oidc(fake_oidc_server, "username") user_tok = login_resp["access_token"] device_id = login_resp["device_id"] @@ -522,7 +525,10 @@ class UIAuthTests(unittest.HomeserverTestCase): @override_config({"oidc_config": 
TEST_OIDC_CONFIG}) def test_offers_both_flows_for_upgraded_user(self) -> None: """A user that had a password and then logged in with SSO should get both flows""" - login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart) + fake_oidc_server = self.helper.fake_oidc_server() + login_resp, _ = self.helper.login_via_oidc( + fake_oidc_server, UserID.from_string(self.user).localpart + ) self.assertEqual(login_resp["user_id"], self.user) channel = self.delete_device( @@ -539,8 +545,13 @@ class UIAuthTests(unittest.HomeserverTestCase): @override_config({"oidc_config": TEST_OIDC_CONFIG}) def test_ui_auth_fails_for_incorrect_sso_user(self) -> None: """If the user tries to authenticate with the wrong SSO user, they get an error""" + + fake_oidc_server = self.helper.fake_oidc_server() + # log the user in - login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart) + login_resp, _ = self.helper.login_via_oidc( + fake_oidc_server, UserID.from_string(self.user).localpart + ) self.assertEqual(login_resp["user_id"], self.user) # start a UI Auth flow by attempting to delete a device @@ -553,8 +564,8 @@ class UIAuthTests(unittest.HomeserverTestCase): session_id = channel.json_body["session"] # do the OIDC auth, but auth as the wrong user - channel = self.helper.auth_via_oidc( - {"sub": "wrong_user"}, ui_auth_session_id=session_id + channel, _ = self.helper.auth_via_oidc( + fake_oidc_server, {"sub": "wrong_user"}, ui_auth_session_id=session_id ) # that should return a failure message @@ -584,7 +595,10 @@ class UIAuthTests(unittest.HomeserverTestCase): """Tests that if we register a user via SSO while requiring approval for new accounts, we still raise the correct error before logging the user in. """ - login_resp = self.helper.login_via_oidc("username", expected_status=403) + fake_oidc_server = self.helper.fake_oidc_server() + login_resp, _ = self.helper.login_via_oidc( + fake_oidc_server, "username", expected_status=403 + ) self.assertEqual(login_resp["errcode"], Codes.USER_AWAITING_APPROVAL) self.assertEqual( diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py index e801ba8c8b..ff5baa9f0a 100644 --- a/tests/rest/client/test_login.py +++ b/tests/rest/client/test_login.py @@ -36,7 +36,7 @@ from synapse.util import Clock from tests import unittest from tests.handlers.test_oidc import HAS_OIDC from tests.handlers.test_saml import has_saml2 -from tests.rest.client.utils import TEST_OIDC_AUTH_ENDPOINT, TEST_OIDC_CONFIG +from tests.rest.client.utils import TEST_OIDC_CONFIG from tests.server import FakeChannel from tests.test_utils.html_parsers import TestHtmlParser from tests.unittest import HomeserverTestCase, override_config, skip_unless @@ -612,13 +612,16 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): def test_login_via_oidc(self) -> None: """If OIDC is chosen, should redirect to the OIDC auth endpoint""" - # pick the default OIDC provider - channel = self.make_request( - "GET", - "/_synapse/client/pick_idp?redirectUrl=" - + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL) - + "&idp=oidc", - ) + fake_oidc_server = self.helper.fake_oidc_server() + + with fake_oidc_server.patch_homeserver(hs=self.hs): + # pick the default OIDC provider + channel = self.make_request( + "GET", + "/_synapse/client/pick_idp?redirectUrl=" + + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL) + + "&idp=oidc", + ) self.assertEqual(channel.code, 302, channel.result) location_headers = channel.headers.getRawHeaders("Location") assert location_headers @@ 
-626,7 +629,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1) # it should redirect us to the auth page of the OIDC server - self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT) + self.assertEqual(oidc_uri_path, fake_oidc_server.authorization_endpoint) # ... and should have set a cookie including the redirect url cookie_headers = channel.headers.getRawHeaders("Set-Cookie") @@ -643,7 +646,9 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): TEST_CLIENT_REDIRECT_URL, ) - channel = self.helper.complete_oidc_auth(oidc_uri, cookies, {"sub": "user1"}) + channel, _ = self.helper.complete_oidc_auth( + fake_oidc_server, oidc_uri, cookies, {"sub": "user1"} + ) # that should serve a confirmation page self.assertEqual(channel.code, 200, channel.result) @@ -693,7 +698,10 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): def test_client_idp_redirect_to_oidc(self) -> None: """If the client picks a known IdP, redirect to it""" - channel = self._make_sso_redirect_request("oidc") + fake_oidc_server = self.helper.fake_oidc_server() + + with fake_oidc_server.patch_homeserver(hs=self.hs): + channel = self._make_sso_redirect_request("oidc") self.assertEqual(channel.code, 302, channel.result) location_headers = channel.headers.getRawHeaders("Location") assert location_headers @@ -701,7 +709,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1) # it should redirect us to the auth page of the OIDC server - self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT) + self.assertEqual(oidc_uri_path, fake_oidc_server.authorization_endpoint) def _make_sso_redirect_request(self, idp_prov: Optional[str] = None) -> FakeChannel: """Send a request to /_matrix/client/r0/login/sso/redirect @@ -1280,9 +1288,13 @@ class UsernamePickerTestCase(HomeserverTestCase): def test_username_picker(self) -> None: """Test the happy path of a username picker flow.""" + fake_oidc_server = self.helper.fake_oidc_server() + # do the start of the login flow - channel = self.helper.auth_via_oidc( - {"sub": "tester", "displayname": "Jonny"}, TEST_CLIENT_REDIRECT_URL + channel, _ = self.helper.auth_via_oidc( + fake_oidc_server, + {"sub": "tester", "displayname": "Jonny"}, + TEST_CLIENT_REDIRECT_URL, ) # that should redirect to the username picker diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py index c249a42bb6..967d229223 100644 --- a/tests/rest/client/utils.py +++ b/tests/rest/client/utils.py @@ -31,7 +31,6 @@ from typing import ( Tuple, overload, ) -from unittest.mock import patch from urllib.parse import urlencode import attr @@ -46,8 +45,19 @@ from synapse.server import HomeServer from synapse.types import JsonDict from tests.server import FakeChannel, FakeSite, make_request -from tests.test_utils import FakeResponse from tests.test_utils.html_parsers import TestHtmlParser +from tests.test_utils.oidc import FakeAuthorizationGrant, FakeOidcServer + +# an 'oidc_config' suitable for login_via_oidc. +TEST_OIDC_ISSUER = "https://issuer.test/" +TEST_OIDC_CONFIG = { + "enabled": True, + "issuer": TEST_OIDC_ISSUER, + "client_id": "test-client-id", + "client_secret": "test-client-secret", + "scopes": ["openid"], + "user_mapping_provider": {"config": {"localpart_template": "{{ user.sub }}"}}, +} @attr.s(auto_attribs=True) @@ -543,12 +553,28 @@ class RestHelper: return channel.json_body + def fake_oidc_server(self, issuer: str = TEST_OIDC_ISSUER) -> FakeOidcServer: + """Create a ``FakeOidcServer``.
+ + This can be used in conjunction with ``login_via_oidc``:: + + fake_oidc_server = self.helper.fake_oidc_server() + login_data, _ = self.helper.login_via_oidc(fake_oidc_server, "user") + """ + + return FakeOidcServer( + clock=self.hs.get_clock(), + issuer=issuer, + ) + def login_via_oidc( self, + fake_server: FakeOidcServer, remote_user_id: str, + with_sid: bool = False, expected_status: int = 200, - ) -> JsonDict: - """Log in via OIDC + ) -> Tuple[JsonDict, FakeAuthorizationGrant]: + """Log in (as a new user) via OIDC Returns the result of the final token login, along with the FakeAuthorizationGrant for the session. @@ -560,7 +586,10 @@ class RestHelper: the normal places. """ client_redirect_url = "https://x" - channel = self.auth_via_oidc({"sub": remote_user_id}, client_redirect_url) + userinfo = {"sub": remote_user_id} + channel, grant = self.auth_via_oidc( + fake_server, userinfo, client_redirect_url, with_sid=with_sid + ) # expect a confirmation page assert channel.code == HTTPStatus.OK, channel.result @@ -585,14 +614,16 @@ class RestHelper: assert ( channel.code == expected_status ), f"unexpected status in response: {channel.code}" - return channel.json_body + return channel.json_body, grant def auth_via_oidc( self, + fake_server: FakeOidcServer, user_info_dict: JsonDict, client_redirect_url: Optional[str] = None, ui_auth_session_id: Optional[str] = None, - ) -> FakeChannel: + with_sid: bool = False, + ) -> Tuple[FakeChannel, FakeAuthorizationGrant]: """Perform an OIDC authentication flow via a mock OIDC provider. This can be used for either login or user-interactive auth. @@ -616,6 +647,7 @@ the login redirect endpoint ui_auth_session_id: if set, we will perform a UI Auth flow. The session id of the UI auth. + with_sid: if True, generates a random `sid` (OIDC session ID) Returns: A FakeChannel containing the result of calling the OIDC callback endpoint, and the FakeAuthorizationGrant created on the fake server. @@ -625,14 +657,15 @@ cookies: Dict[str, str] = {} - # if we're doing a ui auth, hit the ui auth redirect endpoint - if ui_auth_session_id: - # can't set the client redirect url for UI Auth - assert client_redirect_url is None - oauth_uri = self.initiate_sso_ui_auth(ui_auth_session_id, cookies) - else: - # otherwise, hit the login redirect endpoint - oauth_uri = self.initiate_sso_login(client_redirect_url, cookies) + with fake_server.patch_homeserver(hs=self.hs): + # if we're doing a ui auth, hit the ui auth redirect endpoint + if ui_auth_session_id: + # can't set the client redirect url for UI Auth + assert client_redirect_url is None + oauth_uri = self.initiate_sso_ui_auth(ui_auth_session_id, cookies) + else: + # otherwise, hit the login redirect endpoint + oauth_uri = self.initiate_sso_login(client_redirect_url, cookies) # we now have a URI for the OIDC IdP, but we skip that and go straight # back to synapse's OIDC callback resource. However, we do need the "state" # that synapse passes to the client.
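        # (Synapse later matches this "state" value against the session cookie
        # when the callback endpoint is hit.)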
oauth_uri_path, _ = oauth_uri.split("?", 1) - assert oauth_uri_path == TEST_OIDC_AUTH_ENDPOINT, ( + assert oauth_uri_path == fake_server.authorization_endpoint, ( "unexpected SSO URI " + oauth_uri_path ) - return self.complete_oidc_auth(oauth_uri, cookies, user_info_dict) + return self.complete_oidc_auth( + fake_server, oauth_uri, cookies, user_info_dict, with_sid=with_sid + ) def complete_oidc_auth( self, + fake_server: FakeOidcServer, oauth_uri: str, cookies: Mapping[str, str], user_info_dict: JsonDict, - ) -> FakeChannel: + with_sid: bool = False, + ) -> Tuple[FakeChannel, FakeAuthorizationGrant]: """Mock out an OIDC authentication flow Assumes that an OIDC auth has been initiated by one of initiate_sso_login or initiate_sso_ui_auth. Requires the OIDC callback resource to be mounted at the normal place. Args: + fake_server: the fake OIDC server with which the auth should be done oauth_uri: the OIDC URI returned by synapse's redirect endpoint (ie, from initiate_sso_login or initiate_sso_ui_auth). cookies: the cookies set by synapse's redirect endpoint, which will be sent back to the callback endpoint. user_info_dict: the remote userinfo that the OIDC provider should present. Typically this should be '{"sub": "<remote user id>"}'. + with_sid: if True, generates a random `sid` (OIDC session ID) Returns: A FakeChannel containing the result of calling the OIDC callback endpoint, and the FakeAuthorizationGrant for the session. """ _, oauth_uri_qs = oauth_uri.split("?", 1) params = urllib.parse.parse_qs(oauth_uri_qs) + + code, grant = fake_server.start_authorization( + scope=params["scope"][0], + userinfo=user_info_dict, + client_id=params["client_id"][0], + redirect_uri=params["redirect_uri"][0], + nonce=params["nonce"][0], + with_sid=with_sid, + ) + state = params["state"][0] + callback_uri = "%s?%s" % ( urllib.parse.urlparse(params["redirect_uri"][0]).path, - urllib.parse.urlencode({"state": params["state"][0], "code": "TEST_CODE"}), + urllib.parse.urlencode({"state": state, "code": code}), ) - # before we hit the callback uri, stub out some methods in the http client so - # that we don't have to handle full HTTPS requests. - # (expected url, json response) pairs, in the order we expect them. - expected_requests = [ - # first we get a hit to the token endpoint, which we tell to return - # a dummy OIDC access token - (TEST_OIDC_TOKEN_ENDPOINT, {"access_token": "TEST"}), - # and then one to the user_info endpoint, which returns our remote user id. - (TEST_OIDC_USERINFO_ENDPOINT, user_info_dict), - ] - - async def mock_req( - method: str, - uri: str, - data: Optional[dict] = None, - headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None, - ): - (expected_uri, resp_obj) = expected_requests.pop(0) - assert uri == expected_uri - resp = FakeResponse( - code=HTTPStatus.OK, - phrase=b"OK", - body=json.dumps(resp_obj).encode("utf-8"), - ) - return resp - - with patch.object(self.hs.get_proxied_http_client(), "request", mock_req): + with fake_server.patch_homeserver(hs=self.hs): # now hit the callback URI with the right params and a made-up code channel = make_request( self.hs.get_reactor(), @@ -715,7 +739,7 @@ ("Cookie", "%s=%s" % (k, v)) for (k, v) in cookies.items() ], ) - return channel + return channel, grant def initiate_sso_login( self, client_redirect_url: Optional[str], cookies: MutableMapping[str, str] @@ -806,21 +830,3 @@ class RestHelper: assert len(p.links) == 1, "not exactly one link in confirmation page" oauth_uri = p.links[0] return oauth_uri - - -# an 'oidc_config' suitable for login_via_oidc.
-TEST_OIDC_AUTH_ENDPOINT = "https://issuer.test/auth" -TEST_OIDC_TOKEN_ENDPOINT = "https://issuer.test/token" -TEST_OIDC_USERINFO_ENDPOINT = "https://issuer.test/userinfo" -TEST_OIDC_CONFIG = { - "enabled": True, - "discover": False, - "issuer": "https://issuer.test", - "client_id": "test-client-id", - "client_secret": "test-client-secret", - "scopes": ["profile"], - "authorization_endpoint": TEST_OIDC_AUTH_ENDPOINT, - "token_endpoint": TEST_OIDC_TOKEN_ENDPOINT, - "userinfo_endpoint": TEST_OIDC_USERINFO_ENDPOINT, - "user_mapping_provider": {"config": {"localpart_template": "{{ user.sub }}"}}, -} diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py index 0d0d6faf0d..e62ebcc6a5 100644 --- a/tests/test_utils/__init__.py +++ b/tests/test_utils/__init__.py @@ -15,17 +15,24 @@ """ Utilities for running the unit tests """ +import json import sys import warnings from asyncio import Future from binascii import unhexlify -from typing import Awaitable, Callable, TypeVar +from typing import Awaitable, Callable, Tuple, TypeVar from unittest.mock import Mock import attr +import zope.interface from twisted.python.failure import Failure from twisted.web.client import ResponseDone +from twisted.web.http import RESPONSES +from twisted.web.http_headers import Headers +from twisted.web.iweb import IResponse + +from synapse.types import JsonDict TV = TypeVar("TV") @@ -97,27 +104,44 @@ def simple_async_mock(return_value=None, raises=None) -> Mock: return Mock(side_effect=cb) -@attr.s -class FakeResponse: +# Type ignore: it does not fully implement IResponse, but is good enough for tests +@zope.interface.implementer(IResponse) +@attr.s(slots=True, frozen=True, auto_attribs=True) +class FakeResponse: # type: ignore[misc] """A fake twisted.web.IResponse object there is a similar class at treq.test.test_response, but it lacks a `phrase` attribute, and didn't support deliverBody until recently. """ - # HTTP response code - code = attr.ib(type=int) + version: Tuple[bytes, int, int] = (b"HTTP", 1, 1) - # HTTP response phrase (eg b'OK' for a 200) - phrase = attr.ib(type=bytes) + # HTTP response code + code: int = 200 # body of the response - body = attr.ib(type=bytes) + body: bytes = b"" + + headers: Headers = attr.Factory(Headers) + + @property + def phrase(self): + return RESPONSES.get(self.code, b"Unknown Status") + + @property + def length(self): + return len(self.body) def deliverBody(self, protocol): protocol.dataReceived(self.body) protocol.connectionLost(Failure(ResponseDone())) + @classmethod + def json(cls, *, code: int = 200, payload: JsonDict) -> "FakeResponse": + headers = Headers({"Content-Type": ["application/json"]}) + body = json.dumps(payload).encode("utf-8") + return cls(code=code, body=body, headers=headers) + # A small image used in some tests. # diff --git a/tests/test_utils/oidc.py b/tests/test_utils/oidc.py new file mode 100644 index 0000000000..de134bbc89 --- /dev/null +++ b/tests/test_utils/oidc.py @@ -0,0 +1,325 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + + +import json +from typing import Any, Dict, List, Optional, Tuple +from unittest.mock import Mock, patch +from urllib.parse import parse_qs + +import attr + +from twisted.web.http_headers import Headers +from twisted.web.iweb import IResponse + +from synapse.server import HomeServer +from synapse.util import Clock +from synapse.util.stringutils import random_string + +from tests.test_utils import FakeResponse + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class FakeAuthorizationGrant: + userinfo: dict + client_id: str + redirect_uri: str + scope: str + nonce: Optional[str] + sid: Optional[str] + + +class FakeOidcServer: + """A fake OpenID Connect Provider.""" + + # All methods here are mocks, so we can track when they are called, and override + # their values + request: Mock + get_jwks_handler: Mock + get_metadata_handler: Mock + get_userinfo_handler: Mock + post_token_handler: Mock + + def __init__(self, clock: Clock, issuer: str): + from authlib.jose import ECKey, KeySet + + self._clock = clock + self.issuer = issuer + + self.request = Mock(side_effect=self._request) + self.get_jwks_handler = Mock(side_effect=self._get_jwks_handler) + self.get_metadata_handler = Mock(side_effect=self._get_metadata_handler) + self.get_userinfo_handler = Mock(side_effect=self._get_userinfo_handler) + self.post_token_handler = Mock(side_effect=self._post_token_handler) + + # A code -> grant mapping + self._authorization_grants: Dict[str, FakeAuthorizationGrant] = {} + # An access token -> grant mapping + self._sessions: Dict[str, FakeAuthorizationGrant] = {} + + # We generate here an ECDSA key with the P-256 curve (ES256 algorithm) used for + # signing JWTs. ECDSA keys are really quick to generate compared to RSA. + self._key = ECKey.generate_key(crv="P-256", is_private=True) + self._jwks = KeySet([ECKey.import_key(self._key.as_pem(is_private=False))]) + + self._id_token_overrides: Dict[str, Any] = {} + + def reset_mocks(self): + self.request.reset_mock() + self.get_jwks_handler.reset_mock() + self.get_metadata_handler.reset_mock() + self.get_userinfo_handler.reset_mock() + self.post_token_handler.reset_mock() + + def patch_homeserver(self, hs: HomeServer): + """Patch the ``HomeServer`` HTTP client to handle requests through the ``FakeOidcServer``. 
+ + This patch should be used whenever the HS is expected to perform requests to the + OIDC provider, e.g.:: + + fake_oidc_server = self.helper.fake_oidc_server() + with fake_oidc_server.patch_homeserver(hs): + self.make_request("GET", "/_matrix/client/r0/login/sso/redirect") + """ + return patch.object(hs.get_proxied_http_client(), "request", self.request) + + @property + def authorization_endpoint(self) -> str: + return self.issuer + "authorize" + + @property + def token_endpoint(self) -> str: + return self.issuer + "token" + + @property + def userinfo_endpoint(self) -> str: + return self.issuer + "userinfo" + + @property + def metadata_endpoint(self) -> str: + return self.issuer + ".well-known/openid-configuration" + + @property + def jwks_uri(self) -> str: + return self.issuer + "jwks" + + def get_metadata(self) -> dict: + return { + "issuer": self.issuer, + "authorization_endpoint": self.authorization_endpoint, + "token_endpoint": self.token_endpoint, + "jwks_uri": self.jwks_uri, + "userinfo_endpoint": self.userinfo_endpoint, + "response_types_supported": ["code"], + "subject_types_supported": ["public"], + "id_token_signing_alg_values_supported": ["ES256"], + } + + def get_jwks(self) -> dict: + return self._jwks.as_dict() + + def get_userinfo(self, access_token: str) -> Optional[dict]: + """Given an access token, get the userinfo of the associated session.""" + session = self._sessions.get(access_token, None) + if session is None: + return None + return session.userinfo + + def _sign(self, payload: dict) -> str: + from authlib.jose import JsonWebSignature + + jws = JsonWebSignature() + kid = self.get_jwks()["keys"][0]["kid"] + protected = {"alg": "ES256", "kid": kid} + json_payload = json.dumps(payload) + return jws.serialize_compact(protected, json_payload, self._key).decode("utf-8") + + def generate_id_token(self, grant: FakeAuthorizationGrant) -> str: + now = self._clock.time() + id_token = { + **grant.userinfo, + "iss": self.issuer, + "aud": grant.client_id, + "iat": now, + "nbf": now, + "exp": now + 600, + } + + if grant.nonce is not None: + id_token["nonce"] = grant.nonce + + if grant.sid is not None: + id_token["sid"] = grant.sid + + id_token.update(self._id_token_overrides) + + return self._sign(id_token) + + def id_token_override(self, overrides: dict): + """Temporarily patch the ID token generated by the token endpoint.""" + return patch.object(self, "_id_token_overrides", overrides) + + def start_authorization( + self, + client_id: str, + scope: str, + redirect_uri: str, + userinfo: dict, + nonce: Optional[str] = None, + with_sid: bool = False, + ) -> Tuple[str, FakeAuthorizationGrant]: + """Start an authorization request, and get back the code to use on the token endpoint.""" + code = random_string(10) + sid = None + if with_sid: + sid = random_string(10) + + grant = FakeAuthorizationGrant( + userinfo=userinfo, + scope=scope, + redirect_uri=redirect_uri, + nonce=nonce, + client_id=client_id, + sid=sid, + ) + self._authorization_grants[code] = grant + + return code, grant + + def exchange_code(self, code: str) -> Optional[Dict[str, Any]]: + grant = self._authorization_grants.pop(code, None) + if grant is None: + return None + + access_token = random_string(10) + self._sessions[access_token] = grant + + token = { + "token_type": "Bearer", + "access_token": access_token, + "expires_in": 3600, + "scope": grant.scope, + } + + if "openid" in grant.scope: + token["id_token"] = self.generate_id_token(grant) + + return dict(token) + + def buggy_endpoint( + self, + *,
jwks: bool = False, + metadata: bool = False, + token: bool = False, + userinfo: bool = False, + ): + """A context which makes a set of endpoints return a 500 error. + + Args: + jwks: If True, makes the JWKS endpoint return a 500 error. + metadata: If True, makes the OIDC Discovery endpoint return a 500 error. + token: If True, makes the token endpoint return a 500 error. + userinfo: If True, makes the userinfo endpoint return a 500 error. + """ + buggy = FakeResponse(code=500, body=b"Internal server error") + + patches = {} + if jwks: + patches["get_jwks_handler"] = Mock(return_value=buggy) + if metadata: + patches["get_metadata_handler"] = Mock(return_value=buggy) + if token: + patches["post_token_handler"] = Mock(return_value=buggy) + if userinfo: + patches["get_userinfo_handler"] = Mock(return_value=buggy) + + return patch.multiple(self, **patches) + + async def _request( + self, + method: str, + uri: str, + data: Optional[bytes] = None, + headers: Optional[Headers] = None, + ) -> IResponse: + """The override of the SimpleHttpClient#request() method""" + access_token: Optional[str] = None + + if headers is None: + headers = Headers() + + # Try to find the access token in the headers if any + auth_headers = headers.getRawHeaders(b"Authorization") + if auth_headers: + parts = auth_headers[0].split(b" ") + if parts[0] == b"Bearer" and len(parts) == 2: + access_token = parts[1].decode("ascii") + + if method == "POST": + # If the method is POST, assume it has an url-encoded body + if data is None or headers.getRawHeaders(b"Content-Type") != [ + b"application/x-www-form-urlencoded" + ]: + return FakeResponse.json(code=400, payload={"error": "invalid_request"}) + + params = parse_qs(data.decode("utf-8")) + + if uri == self.token_endpoint: + # Even though this endpoint should be protected, this does not check + # for client authentication. We're not checking it for simplicity, + # and because client authentication is tested in other standalone tests. 
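+                # (A real provider would authenticate the client here, e.g. via
+                # HTTP Basic auth or `client_secret_post` form fields, before
+                # honouring the authorization code.)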
+ return self.post_token_handler(params) + + elif method == "GET": + if uri == self.jwks_uri: + return self.get_jwks_handler() + elif uri == self.metadata_endpoint: + return self.get_metadata_handler() + elif uri == self.userinfo_endpoint: + return self.get_userinfo_handler(access_token=access_token) + + return FakeResponse(code=404, body=b"404 not found") + + # Request handlers + def _get_jwks_handler(self) -> IResponse: + """Handles requests to the JWKS URI.""" + return FakeResponse.json(payload=self.get_jwks()) + + def _get_metadata_handler(self) -> IResponse: + """Handles requests to the OIDC well-known document.""" + return FakeResponse.json(payload=self.get_metadata()) + + def _get_userinfo_handler(self, access_token: Optional[str]) -> IResponse: + """Handles requests to the userinfo endpoint.""" + if access_token is None: + return FakeResponse(code=401) + user_info = self.get_userinfo(access_token) + if user_info is None: + return FakeResponse(code=401) + + return FakeResponse.json(payload=user_info) + + def _post_token_handler(self, params: Dict[str, List[str]]) -> IResponse: + """Handles requests to the token endpoint.""" + code = params.get("code", []) + + if len(code) != 1: + return FakeResponse.json(code=400, payload={"error": "invalid_request"}) + + grant = self.exchange_code(code=code[0]) + if grant is None: + return FakeResponse.json(code=400, payload={"error": "invalid_grant"}) + + return FakeResponse.json(payload=grant) -- cgit 1.5.1 From 8756d5c87efc5637da55c9e21d2a4eb2369ba693 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Wed, 26 Oct 2022 12:45:41 +0200 Subject: Save login tokens in database (#13844) * Save login tokens in database Signed-off-by: Quentin Gliech * Add upgrade notes * Track login token reuse in a Prometheus metric Signed-off-by: Quentin Gliech --- changelog.d/13844.misc | 1 + docs/upgrade.md | 9 ++ synapse/handlers/auth.py | 64 +++++++-- synapse/module_api/__init__.py | 41 +----- synapse/rest/client/login.py | 3 +- synapse/rest/client/login_token_request.py | 5 +- synapse/storage/databases/main/registration.py | 156 ++++++++++++++++++++- .../schema/main/delta/73/10login_tokens.sql | 35 +++++ synapse/util/macaroons.py | 87 +----------- tests/handlers/test_auth.py | 135 ++++++++++-------- tests/util/test_macaroons.py | 28 ---- 11 files changed, 337 insertions(+), 227 deletions(-) create mode 100644 changelog.d/13844.misc create mode 100644 synapse/storage/schema/main/delta/73/10login_tokens.sql (limited to 'tests/handlers') diff --git a/changelog.d/13844.misc b/changelog.d/13844.misc new file mode 100644 index 0000000000..66f4414df7 --- /dev/null +++ b/changelog.d/13844.misc @@ -0,0 +1 @@ +Save login tokens in database and prevent login token reuse. diff --git a/docs/upgrade.md b/docs/upgrade.md index b81385b191..78c34d0c15 100644 --- a/docs/upgrade.md +++ b/docs/upgrade.md @@ -88,6 +88,15 @@ process, for example: dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb ``` +# Upgrading to v1.71.0 + +## Removal of the `generate_short_term_login_token` module API method + +As announced with the release of [Synapse 1.69.0](#deprecation-of-the-generate_short_term_login_token-module-api-method), the deprecated `generate_short_term_login_token` module method has been removed. + +Modules relying on it can instead use the `create_login_token` method. 
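+
+For example, a module that issues short-term login tokens can be migrated like
+this (a minimal sketch: the module class and its wiring are hypothetical, and
+note that `create_login_token` is a coroutine, unlike the removed method):
+
+```python
+class MyTokenIssuer:
+    def __init__(self, config: dict, api):
+        self._api = api
+
+    async def issue_token(self, user_id: str) -> str:
+        # Previously: self._api.generate_short_term_login_token(user_id)
+        return await self._api.create_login_token(
+            user_id,
+            duration_in_ms=2 * 60 * 1000,  # the default lifetime, as before
+        )
+```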
+ + # Upgrading to v1.69.0 ## Changes to the receipts replication streams diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index f5f0e0e7a7..8b9ef25d29 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -38,6 +38,7 @@ from typing import ( import attr import bcrypt import unpaddedbase64 +from prometheus_client import Counter from twisted.internet.defer import CancelledError from twisted.web.server import Request @@ -48,6 +49,7 @@ from synapse.api.errors import ( Codes, InteractiveAuthIncompleteError, LoginError, + NotFoundError, StoreError, SynapseError, UserDeactivatedError, @@ -63,10 +65,14 @@ from synapse.http.server import finish_request, respond_with_html from synapse.http.site import SynapseRequest from synapse.logging.context import defer_to_thread from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.storage.databases.main.registration import ( + LoginTokenExpired, + LoginTokenLookupResult, + LoginTokenReused, +) from synapse.types import JsonDict, Requester, UserID from synapse.util import stringutils as stringutils from synapse.util.async_helpers import delay_cancellation, maybe_awaitable -from synapse.util.macaroons import LoginTokenAttributes from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.stringutils import base62_encode from synapse.util.threepids import canonicalise_email @@ -80,6 +86,12 @@ logger = logging.getLogger(__name__) INVALID_USERNAME_OR_PASSWORD = "Invalid username or password" +invalid_login_token_counter = Counter( + "synapse_user_login_invalid_login_tokens", + "Counts the number of rejected m.login.token on /login", + ["reason"], +) + def convert_client_dict_legacy_fields_to_identifier( submission: JsonDict, @@ -883,6 +895,25 @@ class AuthHandler: return True + async def create_login_token_for_user_id( + self, + user_id: str, + duration_ms: int = (2 * 60 * 1000), + auth_provider_id: Optional[str] = None, + auth_provider_session_id: Optional[str] = None, + ) -> str: + login_token = self.generate_login_token() + now = self._clock.time_msec() + expiry_ts = now + duration_ms + await self.store.add_login_token_to_user( + user_id=user_id, + token=login_token, + expiry_ts=expiry_ts, + auth_provider_id=auth_provider_id, + auth_provider_session_id=auth_provider_session_id, + ) + return login_token + async def create_refresh_token_for_user_id( self, user_id: str, @@ -1401,6 +1432,18 @@ class AuthHandler: return None return user_id + def generate_login_token(self) -> str: + """Generates an opaque string, for use as a short-term login token""" + + # we use the following format for login tokens: + # syl_<random string>_<base62 crc check> + + random_string = stringutils.random_string(20) + base = f"syl_{random_string}" + + crc = base62_encode(crc32(base.encode("ascii")), minwidth=6) + return f"{base}_{crc}" + def generate_access_token(self, for_user: UserID) -> str: """Generates an opaque string, for use as an access token""" @@ -1427,16 +1470,17 @@ class AuthHandler: crc = base62_encode(crc32(base.encode("ascii")), minwidth=6) return f"{base}_{crc}" - async def validate_short_term_login_token( - self, login_token: str - ) -> LoginTokenAttributes: + async def consume_login_token(self, login_token: str) -> LoginTokenLookupResult: try: - res = self.macaroon_gen.verify_short_term_login_token(login_token) - except Exception: - raise AuthError(403, "Invalid login token", errcode=Codes.FORBIDDEN) + return await self.store.consume_login_token(login_token) + except LoginTokenExpired:
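+            # Each of these failure modes increments the labelled
+            # `invalid_login_token_counter` declared at the top of this module.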
invalid_login_token_counter.labels("expired").inc() + except LoginTokenReused: + invalid_login_token_counter.labels("reused").inc() + except NotFoundError: + invalid_login_token_counter.labels("not found").inc() - await self.auth_blocking.check_auth_blocking(res.user_id) - return res + raise AuthError(403, "Invalid login token", errcode=Codes.FORBIDDEN) async def delete_access_token(self, access_token: str) -> None: """Invalidate a single access token @@ -1711,7 +1755,7 @@ class AuthHandler: ) # Create a login token - login_token = self.macaroon_gen.generate_short_term_login_token( + login_token = await self.create_login_token_for_user_id( registered_user_id, auth_provider_id=auth_provider_id, auth_provider_session_id=auth_provider_session_id, diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 6a6ae208d1..30e689d00d 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -771,50 +771,11 @@ class ModuleApi: auth_provider_session_id: The session ID got during login from the SSO IdP, if any. """ - # The deprecated `generate_short_term_login_token` method defaulted to an empty - # string for the `auth_provider_id` because of how the underlying macaroon was - # generated. This will change to a proper NULL-able field when the tokens get - # moved to the database. - return self._hs.get_macaroon_generator().generate_short_term_login_token( + return await self._hs.get_auth_handler().create_login_token_for_user_id( user_id, - auth_provider_id or "", - auth_provider_session_id, duration_in_ms, - ) - - def generate_short_term_login_token( - self, - user_id: str, - duration_in_ms: int = (2 * 60 * 1000), - auth_provider_id: str = "", - auth_provider_session_id: Optional[str] = None, - ) -> str: - """Generate a login token suitable for m.login.token authentication - - Added in Synapse v1.9.0. - - This was deprecated in Synapse v1.69.0 in favor of create_login_token, and will - be removed in Synapse 1.71.0. - - Args: - user_id: gives the ID of the user that the token is for - - duration_in_ms: the time that the token will be valid for - - auth_provider_id: the ID of the SSO IdP that the user used to authenticate - to get this token, if any. This is encoded in the token so that - /login can report stats on number of successful logins by IdP. - """ - logger.warn( - "A module configured on this server uses ModuleApi.generate_short_term_login_token(), " - "which is deprecated in favor of ModuleApi.create_login_token(), and will be removed in " - "Synapse 1.71.0", - ) - return self._hs.get_macaroon_generator().generate_short_term_login_token( - user_id, auth_provider_id, auth_provider_session_id, - duration_in_ms, ) @defer.inlineCallbacks diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py index f554586ac3..7774f1967d 100644 --- a/synapse/rest/client/login.py +++ b/synapse/rest/client/login.py @@ -436,8 +436,7 @@ class LoginRestServlet(RestServlet): The body of the JSON response. 
""" token = login_submission["token"] - auth_handler = self.auth_handler - res = await auth_handler.validate_short_term_login_token(token) + res = await self.auth_handler.consume_login_token(token) return await self._complete_login( res.user_id, diff --git a/synapse/rest/client/login_token_request.py b/synapse/rest/client/login_token_request.py index 277b20fb63..43ea21d5e6 100644 --- a/synapse/rest/client/login_token_request.py +++ b/synapse/rest/client/login_token_request.py @@ -57,7 +57,6 @@ class LoginTokenRequestServlet(RestServlet): self.store = hs.get_datastores().main self.clock = hs.get_clock() self.server_name = hs.config.server.server_name - self.macaroon_gen = hs.get_macaroon_generator() self.auth_handler = hs.get_auth_handler() self.token_timeout = hs.config.experimental.msc3882_token_timeout self.ui_auth = hs.config.experimental.msc3882_ui_auth @@ -76,10 +75,10 @@ class LoginTokenRequestServlet(RestServlet): can_skip_ui_auth=False, # Don't allow skipping of UI auth ) - login_token = self.macaroon_gen.generate_short_term_login_token( + login_token = await self.auth_handler.create_login_token_for_user_id( user_id=requester.user.to_string(), auth_provider_id="org.matrix.msc3882.login_token_request", - duration_in_ms=self.token_timeout, + duration_ms=self.token_timeout, ) return ( diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 2996d6bb4d..0255295317 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -21,7 +21,13 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast import attr from synapse.api.constants import UserTypes -from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError +from synapse.api.errors import ( + Codes, + NotFoundError, + StoreError, + SynapseError, + ThreepidValidationError, +) from synapse.config.homeserver import HomeServerConfig from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage.database import ( @@ -50,6 +56,14 @@ class ExternalIDReuseException(Exception): because this external id is given to an other user.""" +class LoginTokenExpired(Exception): + """Exception if the login token sent expired""" + + +class LoginTokenReused(Exception): + """Exception if the login token sent was already used""" + + @attr.s(frozen=True, slots=True, auto_attribs=True) class TokenLookupResult: """Result of looking up an access token. @@ -115,6 +129,20 @@ class RefreshTokenLookupResult: If None, the session can be refreshed indefinitely.""" +@attr.s(auto_attribs=True, frozen=True, slots=True) +class LoginTokenLookupResult: + """Result of looking up a login token.""" + + user_id: str + """The user this token belongs to.""" + + auth_provider_id: Optional[str] + """The SSO Identity Provider that the user authenticated with, to get this token.""" + + auth_provider_session_id: Optional[str] + """The session ID advertised by the SSO Identity Provider.""" + + class RegistrationWorkerStore(CacheInvalidationWorkerStore): def __init__( self, @@ -1789,6 +1817,109 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): "replace_refresh_token", _replace_refresh_token_txn ) + async def add_login_token_to_user( + self, + user_id: str, + token: str, + expiry_ts: int, + auth_provider_id: Optional[str], + auth_provider_session_id: Optional[str], + ) -> None: + """Adds a short-term login token for the given user. + + Args: + user_id: The user ID. 
+            token: The new login token to add.
+            expiry_ts (milliseconds since the epoch): Time after which the login token
+                cannot be used.
+            auth_provider_id: The SSO Identity Provider that the user authenticated with
+                to get this token, if any.
+            auth_provider_session_id: The session ID advertised by the SSO Identity
+                Provider, if any.
+        """
+        await self.db_pool.simple_insert(
+            "login_tokens",
+            {
+                "token": token,
+                "user_id": user_id,
+                "expiry_ts": expiry_ts,
+                "auth_provider_id": auth_provider_id,
+                "auth_provider_session_id": auth_provider_session_id,
+            },
+            desc="add_login_token_to_user",
+        )
+
+    def _consume_login_token(
+        self,
+        txn: LoggingTransaction,
+        token: str,
+        ts: int,
+    ) -> LoginTokenLookupResult:
+        values = self.db_pool.simple_select_one_txn(
+            txn,
+            "login_tokens",
+            keyvalues={"token": token},
+            retcols=(
+                "user_id",
+                "expiry_ts",
+                "used_ts",
+                "auth_provider_id",
+                "auth_provider_session_id",
+            ),
+            allow_none=True,
+        )
+
+        if values is None:
+            raise NotFoundError()
+
+        self.db_pool.simple_update_one_txn(
+            txn,
+            "login_tokens",
+            keyvalues={"token": token},
+            updatevalues={"used_ts": ts},
+        )
+        user_id = values["user_id"]
+        expiry_ts = values["expiry_ts"]
+        used_ts = values["used_ts"]
+        auth_provider_id = values["auth_provider_id"]
+        auth_provider_session_id = values["auth_provider_session_id"]
+
+        # Token was already used
+        if used_ts is not None:
+            raise LoginTokenReused()
+
+        # Token expired
+        if ts > int(expiry_ts):
+            raise LoginTokenExpired()
+
+        return LoginTokenLookupResult(
+            user_id=user_id,
+            auth_provider_id=auth_provider_id,
+            auth_provider_session_id=auth_provider_session_id,
+        )
+
+    async def consume_login_token(self, token: str) -> LoginTokenLookupResult:
+        """Look up a login token and consume it.
+
+        Args:
+            token: The login token.
+
+        Returns:
+            The data stored with that token, including the `user_id`.
+
+        Raises:
+            NotFoundError if the login token was not found in the database
+            LoginTokenExpired if the login token has expired
+            LoginTokenReused if the login token was already used
+        """
+        return await self.db_pool.runInteraction(
+            "consume_login_token",
+            self._consume_login_token,
+            token,
+            self._clock.time_msec(),
+        )
+
     @cached()
     async def is_guest(self, user_id: str) -> bool:
         res = await self.db_pool.simple_select_one_onecol(
@@ -2019,6 +2150,12 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
             and hs.config.experimental.msc3866.require_approval_for_new_accounts
         )
 
+        # Create a background job for removing expired login tokens
+        if hs.config.worker.run_background_tasks:
+            self._clock.looping_call(
+                self._delete_expired_login_tokens, THIRTY_MINUTES_IN_MS
+            )
+
     async def add_access_token_to_user(
         self,
         user_id: str,
@@ -2617,6 +2754,23 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
             approved,
         )
 
+    @wrap_as_background_process("delete_expired_login_tokens")
+    async def _delete_expired_login_tokens(self) -> None:
+        """Remove login tokens with expiry dates that have passed."""
+
+        def _delete_expired_login_tokens_txn(txn: LoggingTransaction, ts: int) -> None:
+            sql = "DELETE FROM login_tokens WHERE expiry_ts <= ?"
+            txn.execute(sql, (ts,))
+
+        # We keep the expired tokens for an extra 5 minutes so we can measure how
+        # often a token is used after its expiry
+        now = self._clock.time_msec()
+        await self.db_pool.runInteraction(
+            "delete_expired_login_tokens",
+            _delete_expired_login_tokens_txn,
+            now - (5 * 60 * 1000),
+        )
+
 
 def find_max_generated_user_id_localpart(cur: Cursor) -> int:
     """
diff --git a/synapse/storage/schema/main/delta/73/10login_tokens.sql b/synapse/storage/schema/main/delta/73/10login_tokens.sql
new file mode 100644
index 0000000000..a39b7bcece
--- /dev/null
+++ b/synapse/storage/schema/main/delta/73/10login_tokens.sql
@@ -0,0 +1,35 @@
+/*
+ * Copyright 2022 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Login tokens are short-lived tokens that are used for the m.login.token
+-- login method, mainly during SSO logins
+CREATE TABLE login_tokens (
+    token TEXT PRIMARY KEY,
+    user_id TEXT NOT NULL,
+    expiry_ts BIGINT NOT NULL,
+    used_ts BIGINT,
+    auth_provider_id TEXT,
+    auth_provider_session_id TEXT
+);
+
+-- We sometimes query tokens by the session ID we got from the IdP
+CREATE INDEX login_tokens_auth_provider_idx
+    ON login_tokens (auth_provider_id, auth_provider_session_id);
+
+-- We delete tokens by their expiration time
+CREATE INDEX login_tokens_expiry_time_idx
+    ON login_tokens (expiry_ts);
+
diff --git a/synapse/util/macaroons.py b/synapse/util/macaroons.py
index df77edcce2..5df03d3ddc 100644
--- a/synapse/util/macaroons.py
+++ b/synapse/util/macaroons.py
@@ -24,7 +24,7 @@ from typing_extensions import Literal
 
 from synapse.util import Clock, stringutils
 
-MacaroonType = Literal["access", "delete_pusher", "session", "login"]
+MacaroonType = Literal["access", "delete_pusher", "session"]
 
 
 def get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str:
@@ -111,19 +111,6 @@ class OidcSessionData:
     """The session ID of the ongoing UI Auth ("" if this is a login)"""
 
 
-@attr.s(slots=True, frozen=True, auto_attribs=True)
-class LoginTokenAttributes:
-    """Data we store in a short-term login token"""
-
-    user_id: str
-
-    auth_provider_id: str
-    """The SSO Identity Provider that the user authenticated with, to get this token."""
-
-    auth_provider_session_id: Optional[str]
-    """The session ID advertised by the SSO Identity Provider."""
-
-
 class MacaroonGenerator:
     def __init__(self, clock: Clock, location: str, secret_key: bytes):
         self._clock = clock
@@ -165,35 +152,6 @@ class MacaroonGenerator:
         macaroon.add_first_party_caveat(f"pushkey = {pushkey}")
         return macaroon.serialize()
 
-    def generate_short_term_login_token(
-        self,
-        user_id: str,
-        auth_provider_id: str,
-        auth_provider_session_id: Optional[str] = None,
-        duration_in_ms: int = (2 * 60 * 1000),
-    ) -> str:
-        """Generate a short-term login token used during SSO logins
-
-        Args:
-            user_id: The user for which the token is valid.
-            auth_provider_id: The SSO IdP the user used.
- auth_provider_session_id: The session ID got during login from the SSO IdP. - - Returns: - A signed token valid for using as a ``m.login.token`` token. - """ - now = self._clock.time_msec() - expiry = now + duration_in_ms - macaroon = self._generate_base_macaroon("login") - macaroon.add_first_party_caveat(f"user_id = {user_id}") - macaroon.add_first_party_caveat(f"time < {expiry}") - macaroon.add_first_party_caveat(f"auth_provider_id = {auth_provider_id}") - if auth_provider_session_id is not None: - macaroon.add_first_party_caveat( - f"auth_provider_session_id = {auth_provider_session_id}" - ) - return macaroon.serialize() - def generate_oidc_session_token( self, state: str, @@ -233,49 +191,6 @@ class MacaroonGenerator: return macaroon.serialize() - def verify_short_term_login_token(self, token: str) -> LoginTokenAttributes: - """Verify a short-term-login macaroon - - Checks that the given token is a valid, unexpired short-term-login token - minted by this server. - - Args: - token: The login token to verify. - - Returns: - A set of attributes carried by this token, including the - ``user_id`` and informations about the SSO IDP used during that - login. - - Raises: - MacaroonVerificationFailedException if the verification failed - """ - macaroon = pymacaroons.Macaroon.deserialize(token) - - v = self._base_verifier("login") - v.satisfy_general(lambda c: c.startswith("user_id = ")) - v.satisfy_general(lambda c: c.startswith("auth_provider_id = ")) - v.satisfy_general(lambda c: c.startswith("auth_provider_session_id = ")) - satisfy_expiry(v, self._clock.time_msec) - v.verify(macaroon, self._secret_key) - - user_id = get_value_from_macaroon(macaroon, "user_id") - auth_provider_id = get_value_from_macaroon(macaroon, "auth_provider_id") - - auth_provider_session_id: Optional[str] = None - try: - auth_provider_session_id = get_value_from_macaroon( - macaroon, "auth_provider_session_id" - ) - except MacaroonVerificationFailedException: - pass - - return LoginTokenAttributes( - user_id=user_id, - auth_provider_id=auth_provider_id, - auth_provider_session_id=auth_provider_session_id, - ) - def verify_guest_token(self, token: str) -> str: """Verify a guest access token macaroon diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index 7106799d44..036dbbc45b 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from typing import Optional from unittest.mock import Mock import pymacaroons @@ -19,6 +20,7 @@ from twisted.test.proto_helpers import MemoryReactor from synapse.api.errors import AuthError, ResourceLimitError from synapse.rest import admin +from synapse.rest.client import login from synapse.server import HomeServer from synapse.util import Clock @@ -29,6 +31,7 @@ from tests.test_utils import make_awaitable class AuthTestCase(unittest.HomeserverTestCase): servlets = [ admin.register_servlets, + login.register_servlets, ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: @@ -46,6 +49,23 @@ class AuthTestCase(unittest.HomeserverTestCase): self.user1 = self.register_user("a_user", "pass") + def token_login(self, token: str) -> Optional[str]: + body = { + "type": "m.login.token", + "token": token, + } + + channel = self.make_request( + "POST", + "/_matrix/client/v3/login", + body, + ) + + if channel.code == 200: + return channel.json_body["user_id"] + + return None + def test_macaroon_caveats(self) -> None: token = self.macaroon_generator.generate_guest_access_token("a_user") macaroon = pymacaroons.Macaroon.deserialize(token) @@ -73,49 +93,62 @@ class AuthTestCase(unittest.HomeserverTestCase): v.satisfy_general(verify_guest) v.verify(macaroon, self.hs.config.key.macaroon_secret_key) - def test_short_term_login_token_gives_user_id(self) -> None: - token = self.macaroon_generator.generate_short_term_login_token( - self.user1, "", duration_in_ms=5000 + def test_login_token_gives_user_id(self) -> None: + token = self.get_success( + self.auth_handler.create_login_token_for_user_id( + self.user1, + duration_ms=(5 * 1000), + ) ) - res = self.get_success(self.auth_handler.validate_short_term_login_token(token)) + + res = self.get_success(self.auth_handler.consume_login_token(token)) self.assertEqual(self.user1, res.user_id) - self.assertEqual("", res.auth_provider_id) + self.assertEqual(None, res.auth_provider_id) - # when we advance the clock, the token should be rejected - self.reactor.advance(6) - self.get_failure( - self.auth_handler.validate_short_term_login_token(token), - AuthError, + def test_login_token_reuse_fails(self) -> None: + token = self.get_success( + self.auth_handler.create_login_token_for_user_id( + self.user1, + duration_ms=(5 * 1000), + ) ) - def test_short_term_login_token_gives_auth_provider(self) -> None: - token = self.macaroon_generator.generate_short_term_login_token( - self.user1, auth_provider_id="my_idp" - ) - res = self.get_success(self.auth_handler.validate_short_term_login_token(token)) - self.assertEqual(self.user1, res.user_id) - self.assertEqual("my_idp", res.auth_provider_id) + self.get_success(self.auth_handler.consume_login_token(token)) - def test_short_term_login_token_cannot_replace_user_id(self) -> None: - token = self.macaroon_generator.generate_short_term_login_token( - self.user1, "", duration_in_ms=5000 + self.get_failure( + self.auth_handler.consume_login_token(token), + AuthError, ) - macaroon = pymacaroons.Macaroon.deserialize(token) - res = self.get_success( - self.auth_handler.validate_short_term_login_token(macaroon.serialize()) + def test_login_token_expires(self) -> None: + token = self.get_success( + self.auth_handler.create_login_token_for_user_id( + self.user1, + duration_ms=(5 * 1000), + ) ) - self.assertEqual(self.user1, res.user_id) - - # add another "user_id" caveat, which might allow us to override the - # user_id. 
- macaroon.add_first_party_caveat("user_id = b_user") + # when we advance the clock, the token should be rejected + self.reactor.advance(6) self.get_failure( - self.auth_handler.validate_short_term_login_token(macaroon.serialize()), + self.auth_handler.consume_login_token(token), AuthError, ) + def test_login_token_gives_auth_provider(self) -> None: + token = self.get_success( + self.auth_handler.create_login_token_for_user_id( + self.user1, + auth_provider_id="my_idp", + auth_provider_session_id="11-22-33-44", + duration_ms=(5 * 1000), + ) + ) + res = self.get_success(self.auth_handler.consume_login_token(token)) + self.assertEqual(self.user1, res.user_id) + self.assertEqual("my_idp", res.auth_provider_id) + self.assertEqual("11-22-33-44", res.auth_provider_session_id) + def test_mau_limits_disabled(self) -> None: self.auth_blocking._limit_usage_by_mau = False # Ensure does not throw exception @@ -125,12 +158,12 @@ class AuthTestCase(unittest.HomeserverTestCase): ) ) - self.get_success( - self.auth_handler.validate_short_term_login_token( - self._get_macaroon().serialize() - ) + token = self.get_success( + self.auth_handler.create_login_token_for_user_id(self.user1) ) + self.assertIsNotNone(self.token_login(token)) + def test_mau_limits_exceeded_large(self) -> None: self.auth_blocking._limit_usage_by_mau = True self.hs.get_datastores().main.get_monthly_active_count = Mock( @@ -147,12 +180,10 @@ class AuthTestCase(unittest.HomeserverTestCase): self.hs.get_datastores().main.get_monthly_active_count = Mock( return_value=make_awaitable(self.large_number_of_users) ) - self.get_failure( - self.auth_handler.validate_short_term_login_token( - self._get_macaroon().serialize() - ), - ResourceLimitError, + token = self.get_success( + self.auth_handler.create_login_token_for_user_id(self.user1) ) + self.assertIsNone(self.token_login(token)) def test_mau_limits_parity(self) -> None: # Ensure we're not at the unix epoch. 
@@ -171,12 +202,10 @@ class AuthTestCase(unittest.HomeserverTestCase):
             ),
             ResourceLimitError,
         )
-        self.get_failure(
-            self.auth_handler.validate_short_term_login_token(
-                self._get_macaroon().serialize()
-            ),
-            ResourceLimitError,
+        token = self.get_success(
+            self.auth_handler.create_login_token_for_user_id(self.user1)
         )
+        self.assertIsNone(self.token_login(token))
 
         # If in monthly active cohort
         self.hs.get_datastores().main.user_last_seen_monthly_active = Mock(
@@ -187,11 +216,10 @@ class AuthTestCase(unittest.HomeserverTestCase):
                 self.user1, device_id=None, valid_until_ms=None
            )
         )
-        self.get_success(
-            self.auth_handler.validate_short_term_login_token(
-                self._get_macaroon().serialize()
-            )
+        token = self.get_success(
+            self.auth_handler.create_login_token_for_user_id(self.user1)
         )
+        self.assertIsNotNone(self.token_login(token))
 
     def test_mau_limits_not_exceeded(self) -> None:
         self.auth_blocking._limit_usage_by_mau = True
@@ -209,14 +237,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
         self.hs.get_datastores().main.get_monthly_active_count = Mock(
             return_value=make_awaitable(self.small_number_of_users)
         )
-        self.get_success(
-            self.auth_handler.validate_short_term_login_token(
-                self._get_macaroon().serialize()
-            )
-        )
-
-    def _get_macaroon(self) -> pymacaroons.Macaroon:
-        token = self.macaroon_generator.generate_short_term_login_token(
-            self.user1, "", duration_in_ms=5000
+        token = self.get_success(
+            self.auth_handler.create_login_token_for_user_id(self.user1)
         )
-        return pymacaroons.Macaroon.deserialize(token)
+        self.assertIsNotNone(self.token_login(token))
diff --git a/tests/util/test_macaroons.py b/tests/util/test_macaroons.py
index 32125f7bb7..40754a4711 100644
--- a/tests/util/test_macaroons.py
+++ b/tests/util/test_macaroons.py
@@ -84,34 +84,6 @@ class MacaroonGeneratorTestCase(TestCase):
         )
         self.assertEqual(user_id, "@user:tesths")
 
-    def test_short_term_login_token(self):
-        """Test the generation and verification of short-term login tokens"""
-        token = self.macaroon_generator.generate_short_term_login_token(
-            user_id="@user:tesths",
-            auth_provider_id="oidc",
-            auth_provider_session_id="sid",
-            duration_in_ms=2 * 60 * 1000,
-        )
-
-        info = self.macaroon_generator.verify_short_term_login_token(token)
-        self.assertEqual(info.user_id, "@user:tesths")
-        self.assertEqual(info.auth_provider_id, "oidc")
-        self.assertEqual(info.auth_provider_session_id, "sid")
-
-        # Raises with another secret key
-        with self.assertRaises(MacaroonVerificationFailedException):
-            self.other_macaroon_generator.verify_short_term_login_token(token)
-
-        # Wait a minute
-        self.reactor.pump([60])
-        # Shouldn't raise
-        self.macaroon_generator.verify_short_term_login_token(token)
-        # Wait another minute
-        self.reactor.pump([60])
-        # Should raise since it expired
-        with self.assertRaises(MacaroonVerificationFailedException):
-            self.macaroon_generator.verify_short_term_login_token(token)
-
     def test_oidc_session_token(self):
         """Test the generation and verification of OIDC session cookies"""
         state = "arandomstate"
-- cgit 1.5.1


From 0cfbb3513152b8360155c2d75df50e06ea861fa4 Mon Sep 17 00:00:00 2001
From: Ashish Kumar
Date: Wed, 26 Oct 2022 18:51:23 +0400
Subject: fix broken avatar checks when server_name contains a port (#13927)

Fixes check_avatar_size_and_mime_type() so that avatars can be updated on
homeservers running on non-default ports, which were previously mistaken
for remote homeservers while validating the avatar's size and MIME type.
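For illustration, a minimal standalone sketch of the host/port reassembly the
patch performs (only the host/port/media-id split done by
parse_and_validate_mxc_uri is taken from the real code; the helper name below
is illustrative):

    from typing import Optional

    def rebuild_server_name(host: str, port: Optional[int]) -> str:
        # An explicit port has to be re-attached: a homeserver named
        # "test:8888" parses into host "test" and port 8888, and the bare
        # host never compares equal to the local server_name, so the
        # server's own media was wrongly treated as remote.
        if port is not None:
            return host + ":" + str(port)
        return host

    assert rebuild_server_name("test", 8888) == "test:8888"
    assert rebuild_server_name("matrix.org", None) == "matrix.org"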
Signed-off-by: Ashish Kumar ashfame@users.noreply.github.com --- changelog.d/13927.bugfix | 1 + synapse/handlers/profile.py | 6 +++++- tests/handlers/test_profile.py | 49 ++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 55 insertions(+), 1 deletion(-) create mode 100644 changelog.d/13927.bugfix (limited to 'tests/handlers') diff --git a/changelog.d/13927.bugfix b/changelog.d/13927.bugfix new file mode 100644 index 0000000000..119cd128e7 --- /dev/null +++ b/changelog.d/13927.bugfix @@ -0,0 +1 @@ +Fix a bug which prevented setting an avatar on homeservers which have an explicit port in their `server_name` and have `max_avatar_size` and/or `allowed_avatar_mimetypes` configuration. Contributed by @ashfame. diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index d8ff5289b5..4bf9a047a3 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -307,7 +307,11 @@ class ProfileHandler: if not self.max_avatar_size and not self.allowed_avatar_mimetypes: return True - server_name, _, media_id = parse_and_validate_mxc_uri(mxc) + host, port, media_id = parse_and_validate_mxc_uri(mxc) + if port is not None: + server_name = host + ":" + str(port) + else: + server_name = host if server_name == self.server_name: media_info = await self.store.get_local_media(media_id) diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index f88c725a42..675aa023ac 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -14,6 +14,8 @@ from typing import Any, Awaitable, Callable, Dict from unittest.mock import Mock +from parameterized import parameterized + from twisted.test.proto_helpers import MemoryReactor import synapse.types @@ -327,6 +329,53 @@ class ProfileTestCase(unittest.HomeserverTestCase): ) self.assertFalse(res) + @unittest.override_config( + {"server_name": "test:8888", "allowed_avatar_mimetypes": ["image/png"]} + ) + def test_avatar_constraint_on_local_server_with_port(self): + """Test that avatar metadata is correctly fetched when the media is on a local + server and the server has an explicit port. + + (This was previously a bug) + """ + local_server_name = self.hs.config.server.server_name + media_id = "local" + local_mxc = f"mxc://{local_server_name}/{media_id}" + + # mock up the existence of the avatar file + self._setup_local_files({media_id: {"mimetype": "image/png"}}) + + # and now check that check_avatar_size_and_mime_type is happy + self.assertTrue( + self.get_success(self.handler.check_avatar_size_and_mime_type(local_mxc)) + ) + + @parameterized.expand([("remote",), ("remote:1234",)]) + @unittest.override_config({"allowed_avatar_mimetypes": ["image/png"]}) + def test_check_avatar_on_remote_server(self, remote_server_name: str) -> None: + """Test that avatar metadata is correctly fetched from a remote server""" + media_id = "remote" + remote_mxc = f"mxc://{remote_server_name}/{media_id}" + + # if the media is remote, check_avatar_size_and_mime_type just checks the + # media cache, so we don't need to instantiate a real remote server. It is + # sufficient to poke an entry into the db. 
+        self.get_success(
+            self.hs.get_datastores().main.store_cached_remote_media(
+                media_id=media_id,
+                media_type="image/png",
+                media_length=50,
+                origin=remote_server_name,
+                time_now_ms=self.clock.time_msec(),
+                upload_name=None,
+                filesystem_id="xyz",
+            )
+        )
+
+        self.assertTrue(
+            self.get_success(self.handler.check_avatar_size_and_mime_type(remote_mxc))
+        )
+
     def _setup_local_files(self, names_and_props: Dict[str, Dict[str, Any]]):
         """Stores metadata about files in the database.
-- cgit 1.5.1


From aa70556699e649f46f51a198fb104eecdc0d311b Mon Sep 17 00:00:00 2001
From: Eric Eastwood
Date: Thu, 27 Oct 2022 13:29:23 -0500
Subject: Check appservice user interest against the local users instead of all
 users (`get_users_in_room` mis-use) (#13958)

---
 changelog.d/13958.bugfix                     |   1 +
 docs/upgrade.md                              |  19 ++++
 synapse/appservice/__init__.py               |  16 ++-
 synapse/storage/databases/main/appservice.py |  17 ++-
 synapse/storage/databases/main/roommember.py |   3 +
 tests/appservice/test_appservice.py          |  10 +-
 tests/handlers/test_appservice.py            | 162 ++++++++++++++++++++++++++-
 7 files changed, 214 insertions(+), 14 deletions(-)
 create mode 100644 changelog.d/13958.bugfix

(limited to 'tests/handlers')

diff --git a/changelog.d/13958.bugfix b/changelog.d/13958.bugfix
new file mode 100644
index 0000000000..f9f651bfdc
--- /dev/null
+++ b/changelog.d/13958.bugfix
@@ -0,0 +1 @@
+Check appservice user interest against the local users instead of all users in the room to align with [MSC3905](https://github.com/matrix-org/matrix-spec-proposals/pull/3905).
diff --git a/docs/upgrade.md b/docs/upgrade.md
index 78c34d0c15..f095bbc3a6 100644
--- a/docs/upgrade.md
+++ b/docs/upgrade.md
@@ -97,6 +97,25 @@ As announced with the release of [Synapse 1.69.0](#deprecation-of-the-generate_s
 Modules relying on it can instead use the `create_login_token` method.
 
 
+## Changes to the events received by application services (interest)
+
+To align with spec (changed in
+[MSC3905](https://github.com/matrix-org/matrix-spec-proposals/pull/3905)), Synapse now
+only considers local users to be interesting. In other words, the `users` namespace
+regex is only applied against local users of the homeserver.
+
+Please note that this probably doesn't affect the expected behavior of your application
+service, since an interesting local user in a room still means all messages in the room
+(from local or remote users) will still be considered interesting. And matching a room
+with the `rooms` or `aliases` namespace regex will still consider all events sent in the
+room to be interesting to the application service.
+
+If one of your application service's `users` regexes was intended to match a remote
+user, it will no longer match as you expect. The behavioral mismatch between matching
+all local users and some remote users is why the spec was changed/clarified and this
+caveat is no longer supported.
+
+
 # Upgrading to v1.69.0
 
 ## Changes to the receipts replication streams
diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py
index 0dfa00df44..500bdde3a9 100644
--- a/synapse/appservice/__init__.py
+++ b/synapse/appservice/__init__.py
@@ -172,12 +172,24 @@ class ApplicationService:
         Returns:
             True if this service would like to know about this room.
         """
-        member_list = await store.get_users_in_room(
+        # We can use `get_local_users_in_room(...)` here because an application service
+        # can only be interested in local users of the server it's on (ignore any remote
+        # users that might match the user namespace regex).
+ # + # In the future, we can consider re-using + # `store.get_app_service_users_in_room` which is very similar to this + # function but has a slightly worse performance than this because we + # have an early escape-hatch if we find a single user that the + # appservice is interested in. The juice would be worth the squeeze if + # `store.get_app_service_users_in_room` was used in more places besides + # an experimental MSC. But for now we can avoid doing more work and + # barely using it later. + local_user_ids = await store.get_local_users_in_room( room_id, on_invalidate=cache_context.invalidate ) # check joined member events - for user_id in member_list: + for user_id in local_user_ids: if self.is_interested_in_user(user_id): return True return False diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index 64b70a7b28..63046c0527 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -157,10 +157,23 @@ class ApplicationServiceWorkerStore(RoomMemberWorkerStore): app_service: "ApplicationService", cache_context: _CacheContext, ) -> List[str]: - users_in_room = await self.get_users_in_room( + """ + Get all users in a room that the appservice controls. + + Args: + room_id: The room to check in. + app_service: The application service to check interest/control against + + Returns: + List of user IDs that the appservice controls. + """ + # We can use `get_local_users_in_room(...)` here because an application service + # can only be interested in local users of the server it's on (ignore any remote + # users that might match the user namespace regex). + local_users_in_room = await self.get_local_users_in_room( room_id, on_invalidate=cache_context.invalidate ) - return list(filter(app_service.is_interested_in_user, users_in_room)) + return list(filter(app_service.is_interested_in_user, local_users_in_room)) class ApplicationServiceStore(ApplicationServiceWorkerStore): diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index ab708b0ba5..e56a13f21e 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -152,6 +152,9 @@ class RoomMemberWorkerStore(EventsWorkerStore): the forward extremities of those rooms will exclude most members. We may also calculate room state incorrectly for such rooms and believe that a member is or is not in the room when the opposite is true. + + Note: If you only care about users in the room local to the homeserver, use + `get_local_users_in_room(...)` instead which will be more performant. 
""" return await self.db_pool.simple_select_onecol( table="current_state_events", diff --git a/tests/appservice/test_appservice.py b/tests/appservice/test_appservice.py index 3018d3fc6f..d4dccfc2f0 100644 --- a/tests/appservice/test_appservice.py +++ b/tests/appservice/test_appservice.py @@ -43,7 +43,7 @@ class ApplicationServiceTestCase(unittest.TestCase): self.store = Mock() self.store.get_aliases_for_room = simple_async_mock([]) - self.store.get_users_in_room = simple_async_mock([]) + self.store.get_local_users_in_room = simple_async_mock([]) @defer.inlineCallbacks def test_regex_user_id_prefix_match(self): @@ -129,7 +129,7 @@ class ApplicationServiceTestCase(unittest.TestCase): self.store.get_aliases_for_room = simple_async_mock( ["#irc_foobar:matrix.org", "#athing:matrix.org"] ) - self.store.get_users_in_room = simple_async_mock([]) + self.store.get_local_users_in_room = simple_async_mock([]) self.assertTrue( ( yield defer.ensureDeferred( @@ -184,7 +184,7 @@ class ApplicationServiceTestCase(unittest.TestCase): self.store.get_aliases_for_room = simple_async_mock( ["#xmpp_foobar:matrix.org", "#athing:matrix.org"] ) - self.store.get_users_in_room = simple_async_mock([]) + self.store.get_local_users_in_room = simple_async_mock([]) self.assertFalse( ( yield defer.ensureDeferred( @@ -203,7 +203,7 @@ class ApplicationServiceTestCase(unittest.TestCase): self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) self.event.sender = "@irc_foobar:matrix.org" self.store.get_aliases_for_room = simple_async_mock(["#irc_barfoo:matrix.org"]) - self.store.get_users_in_room = simple_async_mock([]) + self.store.get_local_users_in_room = simple_async_mock([]) self.assertTrue( ( yield defer.ensureDeferred( @@ -236,7 +236,7 @@ class ApplicationServiceTestCase(unittest.TestCase): def test_member_list_match(self): self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) # Note that @irc_fo:here is the AS user. 
-        self.store.get_users_in_room = simple_async_mock(
+        self.store.get_local_users_in_room = simple_async_mock(
             ["@alice:here", "@irc_fo:here", "@bob:here"]
         )
         self.store.get_aliases_for_room = simple_async_mock([])
diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index 7e4570f990..144e49d0fd 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -22,7 +22,7 @@ from twisted.test.proto_helpers import MemoryReactor
 
 import synapse.rest.admin
 import synapse.storage
-from synapse.api.constants import EduTypes
+from synapse.api.constants import EduTypes, EventTypes
 from synapse.appservice import (
     ApplicationService,
     TransactionOneTimeKeyCounts,
@@ -36,7 +36,7 @@ from synapse.util import Clock
 from synapse.util.stringutils import random_string
 
 from tests import unittest
-from tests.test_utils import make_awaitable, simple_async_mock
+from tests.test_utils import event_injection, make_awaitable, simple_async_mock
 from tests.unittest import override_config
 from tests.utils import MockClock
 
@@ -390,15 +390,16 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
         receipts.register_servlets,
     ]
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer):
+        self.hs = hs
         # Mock the ApplicationServiceScheduler's _TransactionController's send method so that
         # we can track any outgoing ephemeral events
         self.send_mock = simple_async_mock()
-        hs.get_application_service_handler().scheduler.txn_ctrl.send = self.send_mock
+        hs.get_application_service_handler().scheduler.txn_ctrl.send = self.send_mock  # type: ignore[assignment]
 
         # Mock out application services, and allow defining our own in tests
         self._services: List[ApplicationService] = []
-        self.hs.get_datastores().main.get_app_services = Mock(
+        self.hs.get_datastores().main.get_app_services = Mock(  # type: ignore[assignment]
             return_value=self._services
         )
 
@@ -416,6 +417,157 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
             "exclusive_as_user", "password", self.exclusive_as_user_device_id
         )
 
+    def _notify_interested_services(self):
+        # This is normally set in `notify_interested_services` but we need to call the
+        # internal async version so the reactor gets pushed to completion.
+        self.hs.get_application_service_handler().current_max += 1
+        self.get_success(
+            self.hs.get_application_service_handler()._notify_interested_services(
+                RoomStreamToken(
+                    None, self.hs.get_application_service_handler().current_max
+                )
+            )
+        )
+
+    @parameterized.expand(
+        [
+            ("@local_as_user:test", True),
+            # Defining remote users in an application service user namespace regex is a
+            # footgun since the appservice might assume that it'll receive all events
+            # sent by that remote user, but it will only receive events in rooms that
+            # are shared with a local user. So we just remove this footgun possibility
+            # entirely and we won't notify the application service based on remote
+            # users.
+            ("@remote_as_user:remote", False),
+        ]
+    )
+    def test_match_interesting_room_members(
+        self, interesting_user: str, should_notify: bool
+    ):
+        """
+        Test to make sure that an interesting user (local or remote) in the room is
+        notified as expected when someone else in the room sends a message.
+ """ + # Register an application service that's interested in the `interesting_user` + interested_appservice = self._register_application_service( + namespaces={ + ApplicationService.NS_USERS: [ + { + "regex": interesting_user, + "exclusive": False, + }, + ], + }, + ) + + # Create a room + alice = self.register_user("alice", "pass") + alice_access_token = self.login("alice", "pass") + room_id = self.helper.create_room_as(room_creator=alice, tok=alice_access_token) + + # Join the interesting user to the room + self.get_success( + event_injection.inject_member_event( + self.hs, room_id, interesting_user, "join" + ) + ) + # Kick the appservice into checking this membership event to get the event out + # of the way + self._notify_interested_services() + # We don't care about the interesting user join event (this test is making sure + # the next thing works) + self.send_mock.reset_mock() + + # Send a message from an uninteresting user + self.helper.send_event( + room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "message from uninteresting user", + }, + tok=alice_access_token, + ) + # Kick the appservice into checking this new event + self._notify_interested_services() + + if should_notify: + self.send_mock.assert_called_once() + ( + service, + events, + _ephemeral, + _to_device_messages, + _otks, + _fbks, + _device_list_summary, + ) = self.send_mock.call_args[0] + + # Even though the message came from an uninteresting user, it should still + # notify us because the interesting user is joined to the room where the + # message was sent. + self.assertEqual(service, interested_appservice) + self.assertEqual(events[0]["type"], "m.room.message") + self.assertEqual(events[0]["sender"], alice) + else: + self.send_mock.assert_not_called() + + def test_application_services_receive_events_sent_by_interesting_local_user(self): + """ + Test to make sure that a messages sent from a local user can be interesting and + picked up by the appservice. + """ + # Register an application service that's interested in all local users + interested_appservice = self._register_application_service( + namespaces={ + ApplicationService.NS_USERS: [ + { + "regex": ".*", + "exclusive": False, + }, + ], + }, + ) + + # Create a room + alice = self.register_user("alice", "pass") + alice_access_token = self.login("alice", "pass") + room_id = self.helper.create_room_as(room_creator=alice, tok=alice_access_token) + + # We don't care about interesting events before this (this test is making sure + # the next thing works) + self.send_mock.reset_mock() + + # Send a message from the interesting local user + self.helper.send_event( + room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "message from interesting local user", + }, + tok=alice_access_token, + ) + # Kick the appservice into checking this new event + self._notify_interested_services() + + self.send_mock.assert_called_once() + ( + service, + events, + _ephemeral, + _to_device_messages, + _otks, + _fbks, + _device_list_summary, + ) = self.send_mock.call_args[0] + + # Events sent from an interesting local user should also be picked up as + # interesting to the appservice. + self.assertEqual(service, interested_appservice) + self.assertEqual(events[0]["type"], "m.room.message") + self.assertEqual(events[0]["sender"], alice) + def test_sending_read_receipt_batches_to_application_services(self): """Tests that a large batch of read receipts are sent correctly to interested application services. 
-- cgit 1.5.1 From 618e4ab81b70e37bdb8e9224bd84fcfe4b15bdea Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Wed, 16 Nov 2022 15:25:35 +0000 Subject: Fix an invalid comparison of `UserPresenceState` to `str` (#14393) --- changelog.d/14393.bugfix | 1 + synapse/handlers/presence.py | 2 +- tests/handlers/test_presence.py | 41 +++++++++++++++++++++++++++++++++++------ tests/module_api/test_api.py | 3 +++ tests/replication/_base.py | 7 ++++++- 5 files changed, 46 insertions(+), 8 deletions(-) create mode 100644 changelog.d/14393.bugfix (limited to 'tests/handlers') diff --git a/changelog.d/14393.bugfix b/changelog.d/14393.bugfix new file mode 100644 index 0000000000..97177bc62f --- /dev/null +++ b/changelog.d/14393.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in 1.58.0 where a user with presence state 'org.matrix.msc3026.busy' would mistakenly be set to 'online' when calling `/sync` or `/events` on a worker process. \ No newline at end of file diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index b7bc787636..cf08737d11 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -478,7 +478,7 @@ class WorkerPresenceHandler(BasePresenceHandler): return _NullContextManager() prev_state = await self.current_state_for_user(user_id) - if prev_state != PresenceState.BUSY: + if prev_state.state != PresenceState.BUSY: # We set state here but pass ignore_status_msg = True as we don't want to # cause the status message to be cleared. # Note that this causes last_active_ts to be incremented which is not diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index c96dc6caf2..c5981ff965 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -15,6 +15,7 @@ from typing import Optional from unittest.mock import Mock, call +from parameterized import parameterized from signedjson.key import generate_signing_key from synapse.api.constants import EventTypes, Membership, PresenceState @@ -37,6 +38,7 @@ from synapse.rest.client import room from synapse.types import UserID, get_domain_from_id from tests import unittest +from tests.replication._base import BaseMultiWorkerStreamTestCase class PresenceUpdateTestCase(unittest.HomeserverTestCase): @@ -505,7 +507,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): self.assertEqual(state, new_state) -class PresenceHandlerTestCase(unittest.HomeserverTestCase): +class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase): def prepare(self, reactor, clock, hs): self.presence_handler = hs.get_presence_handler() self.clock = hs.get_clock() @@ -716,20 +718,47 @@ class PresenceHandlerTestCase(unittest.HomeserverTestCase): # our status message should be the same as it was before self.assertEqual(state.status_msg, status_msg) - def test_set_presence_from_syncing_keeps_busy(self): - """Test that presence set by syncing doesn't affect busy status""" - # while this isn't the default - self.presence_handler._busy_presence_enabled = True + @parameterized.expand([(False,), (True,)]) + @unittest.override_config( + { + "experimental_features": { + "msc3026_enabled": True, + }, + } + ) + def test_set_presence_from_syncing_keeps_busy(self, test_with_workers: bool): + """Test that presence set by syncing doesn't affect busy status + Args: + test_with_workers: If True, check the presence state of the user by calling + /sync against a worker, rather than the main process. + """ user_id = "@test:server" status_msg = "I'm busy!" 
+ # By default, we call /sync against the main process. + worker_to_sync_against = self.hs + if test_with_workers: + # Create a worker and use it to handle /sync traffic instead. + # This is used to test that presence changes get replicated from workers + # to the main process correctly. + worker_to_sync_against = self.make_worker_hs( + "synapse.app.generic_worker", {"worker_name": "presence_writer"} + ) + + # Set presence to BUSY self._set_presencestate_with_status_msg(user_id, PresenceState.BUSY, status_msg) + # Perform a sync with a presence state other than busy. This should NOT change + # our presence status; we only change from busy if we explicitly set it via + # /presence/*. self.get_success( - self.presence_handler.user_syncing(user_id, True, PresenceState.ONLINE) + worker_to_sync_against.get_presence_handler().user_syncing( + user_id, True, PresenceState.ONLINE + ) ) + # Check against the main process that the user's presence did not change. state = self.get_success( self.presence_handler.get_state(UserID.from_string(user_id)) ) diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py index 02cef6f876..058ca57e55 100644 --- a/tests/module_api/test_api.py +++ b/tests/module_api/test_api.py @@ -778,8 +778,11 @@ def _test_sending_local_online_presence_to_local_user( worker process. The test users will still sync with the main process. The purpose of testing with a worker is to check whether a Synapse module running on a worker can inform other workers/ the main process that they should include additional presence when a user next syncs. + If this argument is True, `test_case` MUST be an instance of BaseMultiWorkerStreamTestCase. """ if test_with_workers: + assert isinstance(test_case, BaseMultiWorkerStreamTestCase) + # Create a worker process to make module_api calls against worker_hs = test_case.make_worker_hs( "synapse.app.generic_worker", {"worker_name": "presence_writer"} diff --git a/tests/replication/_base.py b/tests/replication/_base.py index 121f3d8d65..3029a16dda 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -542,8 +542,13 @@ class FakeRedisPubSubProtocol(Protocol): self.send("OK") elif command == b"GET": self.send(None) + + # Connection keep-alives. + elif command == b"PING": + self.send("PONG") + else: - raise Exception("Unknown command") + raise Exception(f"Unknown command: {command}") def send(self, msg): """Send a message back to the client.""" -- cgit 1.5.1 From 6d47b7e32589e816eb766446cc1ff19ea73fc7c1 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 22 Nov 2022 14:08:04 -0500 Subject: Add a type hint for `get_device_handler()` and fix incorrect types. (#14055) This was the last untyped handler from the HomeServer object. Since it was being treated as Any (and thus unchecked) it was being used incorrectly in a few places. 
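The recurring pattern in this patch, in miniature (a sketch of the idea rather
than the real Synapse classes): get_device_handler() is annotated as returning
the worker base class, and call sites that need main-process-only functionality
narrow the type with an isinstance assertion instead of relying on Any:

    class DeviceWorkerHandler:
        def get_devices_by_user(self, user_id: str) -> list:
            return []

    class DeviceHandler(DeviceWorkerHandler):
        # Only the main process may mutate devices.
        def delete_devices(self, user_id: str, device_ids: list) -> None:
            pass

    def get_device_handler() -> DeviceWorkerHandler:
        # Returns the full DeviceHandler on the main process and a bare
        # DeviceWorkerHandler on workers.
        return DeviceHandler()

    handler = get_device_handler()
    # mypy narrows `handler` to DeviceHandler past this point, and on a worker
    # the assert fails fast instead of silently miscalling a missing method.
    assert isinstance(handler, DeviceHandler)
    handler.delete_devices("@user:example.com", ["ABCDEFGH"])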
--- changelog.d/14055.misc | 1 + synapse/handlers/deactivate_account.py | 4 +++ synapse/handlers/device.py | 65 ++++++++++++++++++++++++++-------- synapse/handlers/e2e_keys.py | 61 ++++++++++++++++--------------- synapse/handlers/register.py | 4 +++ synapse/handlers/set_password.py | 6 +++- synapse/handlers/sso.py | 9 +++++ synapse/module_api/__init__.py | 10 +++++- synapse/replication/http/devices.py | 11 ++++-- synapse/rest/admin/__init__.py | 26 ++++++++------ synapse/rest/admin/devices.py | 13 +++++-- synapse/rest/client/devices.py | 17 ++++++--- synapse/rest/client/logout.py | 9 +++-- synapse/server.py | 2 +- tests/handlers/test_device.py | 19 ++++++---- tests/rest/admin/test_device.py | 5 ++- 16 files changed, 185 insertions(+), 77 deletions(-) create mode 100644 changelog.d/14055.misc (limited to 'tests/handlers') diff --git a/changelog.d/14055.misc b/changelog.d/14055.misc new file mode 100644 index 0000000000..02980bc528 --- /dev/null +++ b/changelog.d/14055.misc @@ -0,0 +1 @@ +Add missing type hints to `HomeServer`. diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py index 816e1a6d79..d74d135c0c 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py @@ -16,6 +16,7 @@ import logging from typing import TYPE_CHECKING, Optional from synapse.api.errors import SynapseError +from synapse.handlers.device import DeviceHandler from synapse.metrics.background_process_metrics import run_as_background_process from synapse.types import Codes, Requester, UserID, create_requester @@ -76,6 +77,9 @@ class DeactivateAccountHandler: True if identity server supports removing threepids, otherwise False. """ + # This can only be called on the main process. + assert isinstance(self._device_handler, DeviceHandler) + # Check if this user can be deactivated if not await self._third_party_rules.check_can_deactivate_user( user_id, by_admin diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index da3ddafeae..b1e55e1b9e 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -65,6 +65,8 @@ DELETE_STALE_DEVICES_INTERVAL_MS = 24 * 60 * 60 * 1000 class DeviceWorkerHandler: + device_list_updater: "DeviceListWorkerUpdater" + def __init__(self, hs: "HomeServer"): self.clock = hs.get_clock() self.hs = hs @@ -76,6 +78,8 @@ class DeviceWorkerHandler: self.server_name = hs.hostname self._msc3852_enabled = hs.config.experimental.msc3852_enabled + self.device_list_updater = DeviceListWorkerUpdater(hs) + @trace async def get_devices_by_user(self, user_id: str) -> List[JsonDict]: """ @@ -99,6 +103,19 @@ class DeviceWorkerHandler: log_kv(device_map) return devices + async def get_dehydrated_device( + self, user_id: str + ) -> Optional[Tuple[str, JsonDict]]: + """Retrieve the information for a dehydrated device. + + Args: + user_id: the user whose dehydrated device we are looking for + Returns: + a tuple whose first item is the device ID, and the second item is + the dehydrated device information + """ + return await self.store.get_dehydrated_device(user_id) + @trace async def get_device(self, user_id: str, device_id: str) -> JsonDict: """Retrieve the given device @@ -127,7 +144,7 @@ class DeviceWorkerHandler: @cancellable async def get_device_changes_in_shared_rooms( self, user_id: str, room_ids: Collection[str], from_token: StreamToken - ) -> Collection[str]: + ) -> Set[str]: """Get the set of users whose devices have changed who share a room with the given user. 
""" @@ -320,6 +337,8 @@ class DeviceWorkerHandler: class DeviceHandler(DeviceWorkerHandler): + device_list_updater: "DeviceListUpdater" + def __init__(self, hs: "HomeServer"): super().__init__(hs) @@ -606,19 +625,6 @@ class DeviceHandler(DeviceWorkerHandler): await self.delete_devices(user_id, [old_device_id]) return device_id - async def get_dehydrated_device( - self, user_id: str - ) -> Optional[Tuple[str, JsonDict]]: - """Retrieve the information for a dehydrated device. - - Args: - user_id: the user whose dehydrated device we are looking for - Returns: - a tuple whose first item is the device ID, and the second item is - the dehydrated device information - """ - return await self.store.get_dehydrated_device(user_id) - async def rehydrate_device( self, user_id: str, access_token: str, device_id: str ) -> dict: @@ -882,7 +888,36 @@ def _update_device_from_client_ips( ) -class DeviceListUpdater: +class DeviceListWorkerUpdater: + "Handles incoming device list updates from federation and contacts the main process over replication" + + def __init__(self, hs: "HomeServer"): + from synapse.replication.http.devices import ( + ReplicationUserDevicesResyncRestServlet, + ) + + self._user_device_resync_client = ( + ReplicationUserDevicesResyncRestServlet.make_client(hs) + ) + + async def user_device_resync( + self, user_id: str, mark_failed_as_stale: bool = True + ) -> Optional[JsonDict]: + """Fetches all devices for a user and updates the device cache with them. + + Args: + user_id: The user's id whose device_list will be updated. + mark_failed_as_stale: Whether to mark the user's device list as stale + if the attempt to resync failed. + Returns: + A dict with device info as under the "devices" in the result of this + request: + https://matrix.org/docs/spec/server_server/r0.1.2#get-matrix-federation-v1-user-devices-userid + """ + return await self._user_device_resync_client(user_id=user_id) + + +class DeviceListUpdater(DeviceListWorkerUpdater): "Handles incoming device list updates from federation and updates the DB" def __init__(self, hs: "HomeServer", device_handler: DeviceHandler): diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index bf1221f523..5fe102e2f2 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -27,9 +27,9 @@ from twisted.internet import defer from synapse.api.constants import EduTypes from synapse.api.errors import CodeMessageException, Codes, NotFoundError, SynapseError +from synapse.handlers.device import DeviceHandler from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace -from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet from synapse.types import ( JsonDict, UserID, @@ -56,27 +56,23 @@ class E2eKeysHandler: self.is_mine = hs.is_mine self.clock = hs.get_clock() - self._edu_updater = SigningKeyEduUpdater(hs, self) - federation_registry = hs.get_federation_registry() - self._is_master = hs.config.worker.worker_app is None - if not self._is_master: - self._user_device_resync_client = ( - ReplicationUserDevicesResyncRestServlet.make_client(hs) - ) - else: + is_master = hs.config.worker.worker_app is None + if is_master: + edu_updater = SigningKeyEduUpdater(hs) + # Only register this edu handler on master as it requires writing # device updates to the db federation_registry.register_edu_handler( EduTypes.SIGNING_KEY_UPDATE, - self._edu_updater.incoming_signing_key_update, + 
edu_updater.incoming_signing_key_update, ) # also handle the unstable version # FIXME: remove this when enough servers have upgraded federation_registry.register_edu_handler( EduTypes.UNSTABLE_SIGNING_KEY_UPDATE, - self._edu_updater.incoming_signing_key_update, + edu_updater.incoming_signing_key_update, ) # doesn't really work as part of the generic query API, because the @@ -319,14 +315,13 @@ class E2eKeysHandler: # probably be tracking their device lists. However, we haven't # done an initial sync on the device list so we do it now. try: - if self._is_master: - resync_results = await self.device_handler.device_list_updater.user_device_resync( + resync_results = ( + await self.device_handler.device_list_updater.user_device_resync( user_id ) - else: - resync_results = await self._user_device_resync_client( - user_id=user_id - ) + ) + if resync_results is None: + raise ValueError("Device resync failed") # Add the device keys to the results. user_devices = resync_results["devices"] @@ -605,6 +600,8 @@ class E2eKeysHandler: async def upload_keys_for_user( self, user_id: str, device_id: str, keys: JsonDict ) -> JsonDict: + # This can only be called from the main process. + assert isinstance(self.device_handler, DeviceHandler) time_now = self.clock.time_msec() @@ -732,6 +729,8 @@ class E2eKeysHandler: user_id: the user uploading the keys keys: the signing keys """ + # This can only be called from the main process. + assert isinstance(self.device_handler, DeviceHandler) # if a master key is uploaded, then check it. Otherwise, load the # stored master key, to check signatures on other keys @@ -823,6 +822,9 @@ class E2eKeysHandler: Raises: SynapseError: if the signatures dict is not valid. """ + # This can only be called from the main process. + assert isinstance(self.device_handler, DeviceHandler) + failures = {} # signatures to be stored. Each item will be a SignatureListItem @@ -1200,6 +1202,9 @@ class E2eKeysHandler: A tuple of the retrieved key content, the key's ID and the matching VerifyKey. If the key cannot be retrieved, all values in the tuple will instead be None. """ + # This can only be called from the main process. 
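+        # (Worker processes only construct the DeviceWorkerHandler base class,
+        # which cannot write device updates, hence the narrowing assert below.)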
+ assert isinstance(self.device_handler, DeviceHandler) + try: remote_result = await self.federation.query_user_devices( user.domain, user.to_string() @@ -1396,11 +1401,14 @@ class SignatureListItem: class SigningKeyEduUpdater: """Handles incoming signing key updates from federation and updates the DB""" - def __init__(self, hs: "HomeServer", e2e_keys_handler: E2eKeysHandler): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastores().main self.federation = hs.get_federation_client() self.clock = hs.get_clock() - self.e2e_keys_handler = e2e_keys_handler + + device_handler = hs.get_device_handler() + assert isinstance(device_handler, DeviceHandler) + self._device_handler = device_handler self._remote_edu_linearizer = Linearizer(name="remote_signing_key") @@ -1445,9 +1453,6 @@ class SigningKeyEduUpdater: user_id: the user whose updates we are processing """ - device_handler = self.e2e_keys_handler.device_handler - device_list_updater = device_handler.device_list_updater - async with self._remote_edu_linearizer.queue(user_id): pending_updates = self._pending_updates.pop(user_id, []) if not pending_updates: @@ -1459,13 +1464,11 @@ class SigningKeyEduUpdater: logger.info("pending updates: %r", pending_updates) for master_key, self_signing_key in pending_updates: - new_device_ids = ( - await device_list_updater.process_cross_signing_key_update( - user_id, - master_key, - self_signing_key, - ) + new_device_ids = await self._device_handler.device_list_updater.process_cross_signing_key_update( + user_id, + master_key, + self_signing_key, ) device_ids = device_ids + new_device_ids - await device_handler.notify_device_update(user_id, device_ids) + await self._device_handler.notify_device_update(user_id, device_ids) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index ca1c7a1866..6307fa9c5d 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -38,6 +38,7 @@ from synapse.api.errors import ( ) from synapse.appservice import ApplicationService from synapse.config.server import is_threepid_reserved +from synapse.handlers.device import DeviceHandler from synapse.http.servlet import assert_params_in_dict from synapse.replication.http.login import RegisterDeviceReplicationServlet from synapse.replication.http.register import ( @@ -841,6 +842,9 @@ class RegistrationHandler: refresh_token = None refresh_token_id = None + # This can only run on the main process. + assert isinstance(self.device_handler, DeviceHandler) + registered_device_id = await self.device_handler.check_device_registered( user_id, device_id, diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py index 73861bbd40..bd9d0bb34b 100644 --- a/synapse/handlers/set_password.py +++ b/synapse/handlers/set_password.py @@ -15,6 +15,7 @@ import logging from typing import TYPE_CHECKING, Optional from synapse.api.errors import Codes, StoreError, SynapseError +from synapse.handlers.device import DeviceHandler from synapse.types import Requester if TYPE_CHECKING: @@ -29,7 +30,10 @@ class SetPasswordHandler: def __init__(self, hs: "HomeServer"): self.store = hs.get_datastores().main self._auth_handler = hs.get_auth_handler() - self._device_handler = hs.get_device_handler() + # This can only be instantiated on the main process. 
+ device_handler = hs.get_device_handler() + assert isinstance(device_handler, DeviceHandler) + self._device_handler = device_handler async def set_password( self, diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index 749d7e93b0..e1c0bff1b2 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -37,6 +37,7 @@ from twisted.web.server import Request from synapse.api.constants import LoginType from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError from synapse.config.sso import SsoAttributeRequirement +from synapse.handlers.device import DeviceHandler from synapse.handlers.register import init_counters_for_auth_provider from synapse.handlers.ui_auth import UIAuthSessionDataConstants from synapse.http import get_request_user_agent @@ -1035,6 +1036,8 @@ class SsoHandler: ) -> None: """Revoke any devices and in-flight logins tied to a provider session. + Can only be called from the main process. + Args: auth_provider_id: A unique identifier for this SSO provider, e.g. "oidc" or "saml". @@ -1042,6 +1045,12 @@ class SsoHandler: expected_user_id: The user we're expecting to logout. If set, it will ignore sessions belonging to other users and log an error. """ + + # It is expected that this is the main process. + assert isinstance( + self._device_handler, DeviceHandler + ), "revoking SSO sessions can only be called on the main process" + # Invalidate any running user-mapping sessions to_delete = [] for session_id, session in self._username_mapping_sessions.items(): diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 1adc1fd64f..96a661177a 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -86,6 +86,7 @@ from synapse.handlers.auth import ( ON_LOGGED_OUT_CALLBACK, AuthHandler, ) +from synapse.handlers.device import DeviceHandler from synapse.handlers.push_rules import RuleSpec, check_actions from synapse.http.client import SimpleHttpClient from synapse.http.server import ( @@ -207,6 +208,7 @@ class ModuleApi: self._registration_handler = hs.get_registration_handler() self._send_email_handler = hs.get_send_email_handler() self._push_rules_handler = hs.get_push_rules_handler() + self._device_handler = hs.get_device_handler() self.custom_template_dir = hs.config.server.custom_template_directory try: @@ -784,6 +786,8 @@ class ModuleApi: ) -> Generator["defer.Deferred[Any]", Any, None]: """Invalidate an access token for a user + Can only be called from the main process. + Added in Synapse v0.25.0. Args: @@ -796,6 +800,10 @@ class ModuleApi: Raises: synapse.api.errors.AuthError: the access token is invalid """ + assert isinstance( + self._device_handler, DeviceHandler + ), "invalidate_access_token can only be called on the main process" + # see if the access token corresponds to a device user_info = yield defer.ensureDeferred( self._auth.get_user_by_access_token(access_token) @@ -805,7 +813,7 @@ class ModuleApi: if device_id: # delete the device, which will also delete its access tokens yield defer.ensureDeferred( - self._hs.get_device_handler().delete_devices(user_id, [device_id]) + self._device_handler.delete_devices(user_id, [device_id]) ) else: # no associated device. Just delete the access token. diff --git a/synapse/replication/http/devices.py b/synapse/replication/http/devices.py index c21629def8..7c4941c3d3 100644 --- a/synapse/replication/http/devices.py +++ b/synapse/replication/http/devices.py @@ -13,7 +13,7 @@ # limitations under the License. 
import logging -from typing import TYPE_CHECKING, Tuple +from typing import TYPE_CHECKING, Optional, Tuple from twisted.web.server import Request @@ -63,7 +63,12 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.device_list_updater = hs.get_device_handler().device_list_updater + from synapse.handlers.device import DeviceHandler + + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.device_list_updater = handler.device_list_updater + self.store = hs.get_datastores().main self.clock = hs.get_clock() @@ -73,7 +78,7 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint): async def _handle_request( # type: ignore[override] self, request: Request, user_id: str - ) -> Tuple[int, JsonDict]: + ) -> Tuple[int, Optional[JsonDict]]: user_devices = await self.device_list_updater.user_device_resync(user_id) return 200, user_devices diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index c62ea22116..fb73886df0 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -238,6 +238,10 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: """ Register all the admin servlets. """ + # Admin servlets aren't registered on workers. + if hs.config.worker.worker_app is not None: + return + register_servlets_for_client_rest_resource(hs, http_server) BlockRoomRestServlet(hs).register(http_server) ListRoomRestServlet(hs).register(http_server) @@ -254,9 +258,6 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: UserTokenRestServlet(hs).register(http_server) UserRestServletV2(hs).register(http_server) UsersRestServletV2(hs).register(http_server) - DeviceRestServlet(hs).register(http_server) - DevicesRestServlet(hs).register(http_server) - DeleteDevicesRestServlet(hs).register(http_server) UserMediaStatisticsRestServlet(hs).register(http_server) EventReportDetailRestServlet(hs).register(http_server) EventReportsRestServlet(hs).register(http_server) @@ -280,12 +281,13 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: UserByExternalId(hs).register(http_server) UserByThreePid(hs).register(http_server) - # Some servlets only get registered for the main process. - if hs.config.worker.worker_app is None: - SendServerNoticeServlet(hs).register(http_server) - BackgroundUpdateEnabledRestServlet(hs).register(http_server) - BackgroundUpdateRestServlet(hs).register(http_server) - BackgroundUpdateStartJobRestServlet(hs).register(http_server) + DeviceRestServlet(hs).register(http_server) + DevicesRestServlet(hs).register(http_server) + DeleteDevicesRestServlet(hs).register(http_server) + SendServerNoticeServlet(hs).register(http_server) + BackgroundUpdateEnabledRestServlet(hs).register(http_server) + BackgroundUpdateRestServlet(hs).register(http_server) + BackgroundUpdateStartJobRestServlet(hs).register(http_server) def register_servlets_for_client_rest_resource( @@ -294,9 +296,11 @@ def register_servlets_for_client_rest_resource( """Register only the servlets which need to be exposed on /_matrix/client/xxx""" WhoisRestServlet(hs).register(http_server) PurgeHistoryStatusRestServlet(hs).register(http_server) - DeactivateAccountRestServlet(hs).register(http_server) PurgeHistoryRestServlet(hs).register(http_server) - ResetPasswordRestServlet(hs).register(http_server) + # The following resources can only be run on the main process. 
+ if hs.config.worker.worker_app is None: + DeactivateAccountRestServlet(hs).register(http_server) + ResetPasswordRestServlet(hs).register(http_server) SearchUsersRestServlet(hs).register(http_server) UserRegisterServlet(hs).register(http_server) AccountValidityRenewServlet(hs).register(http_server) diff --git a/synapse/rest/admin/devices.py b/synapse/rest/admin/devices.py index d934880102..3b2f2d9abb 100644 --- a/synapse/rest/admin/devices.py +++ b/synapse/rest/admin/devices.py @@ -16,6 +16,7 @@ from http import HTTPStatus from typing import TYPE_CHECKING, Tuple from synapse.api.errors import NotFoundError, SynapseError +from synapse.handlers.device import DeviceHandler from synapse.http.servlet import ( RestServlet, assert_params_in_dict, @@ -43,7 +44,9 @@ class DeviceRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() - self.device_handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.device_handler = handler self.store = hs.get_datastores().main self.is_mine = hs.is_mine @@ -112,7 +115,9 @@ class DevicesRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() - self.device_handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.device_handler = handler self.store = hs.get_datastores().main self.is_mine = hs.is_mine @@ -143,7 +148,9 @@ class DeleteDevicesRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() - self.device_handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.device_handler = handler self.store = hs.get_datastores().main self.is_mine = hs.is_mine diff --git a/synapse/rest/client/devices.py b/synapse/rest/client/devices.py index 8f3cbd4ea2..69b803f9f8 100644 --- a/synapse/rest/client/devices.py +++ b/synapse/rest/client/devices.py @@ -20,6 +20,7 @@ from pydantic import Extra, StrictStr from synapse.api import errors from synapse.api.errors import NotFoundError +from synapse.handlers.device import DeviceHandler from synapse.http.server import HttpServer from synapse.http.servlet import ( RestServlet, @@ -80,7 +81,9 @@ class DeleteDevicesRestServlet(RestServlet): super().__init__() self.hs = hs self.auth = hs.get_auth() - self.device_handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.device_handler = handler self.auth_handler = hs.get_auth_handler() class PostBody(RequestBodyModel): @@ -125,7 +128,9 @@ class DeviceRestServlet(RestServlet): super().__init__() self.hs = hs self.auth = hs.get_auth() - self.device_handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.device_handler = handler self.auth_handler = hs.get_auth_handler() self._msc3852_enabled = hs.config.experimental.msc3852_enabled @@ -256,7 +261,9 @@ class DehydratedDeviceServlet(RestServlet): super().__init__() self.hs = hs self.auth = hs.get_auth() - self.device_handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.device_handler = handler async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) @@ -313,7 +320,9 @@ class ClaimDehydratedDeviceServlet(RestServlet): super().__init__() self.hs = hs self.auth = hs.get_auth() - 
self.device_handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.device_handler = handler class PostBody(RequestBodyModel): device_id: StrictStr diff --git a/synapse/rest/client/logout.py b/synapse/rest/client/logout.py index 23dfa4518f..6d34625ad5 100644 --- a/synapse/rest/client/logout.py +++ b/synapse/rest/client/logout.py @@ -15,6 +15,7 @@ import logging from typing import TYPE_CHECKING, Tuple +from synapse.handlers.device import DeviceHandler from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet from synapse.http.site import SynapseRequest @@ -34,7 +35,9 @@ class LogoutRestServlet(RestServlet): super().__init__() self.auth = hs.get_auth() self._auth_handler = hs.get_auth_handler() - self._device_handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self._device_handler = handler async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_expired=True) @@ -59,7 +62,9 @@ class LogoutAllRestServlet(RestServlet): super().__init__() self.auth = hs.get_auth() self._auth_handler = hs.get_auth_handler() - self._device_handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self._device_handler = handler async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_expired=True) diff --git a/synapse/server.py b/synapse/server.py index f0a60d0056..5baae2325e 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -510,7 +510,7 @@ class HomeServer(metaclass=abc.ABCMeta): ) @cache_in_self - def get_device_handler(self): + def get_device_handler(self) -> DeviceWorkerHandler: if self.config.worker.worker_app: return DeviceWorkerHandler(self) else: diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index b8b465d35b..ce7525e29c 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -19,7 +19,7 @@ from typing import Optional from twisted.test.proto_helpers import MemoryReactor from synapse.api.errors import NotFoundError, SynapseError -from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN +from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN, DeviceHandler from synapse.server import HomeServer from synapse.util import Clock @@ -32,7 +32,9 @@ user2 = "@theresa:bbb" class DeviceTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: hs = self.setup_test_homeserver("server", federation_http_client=None) - self.handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.handler = handler self.store = hs.get_datastores().main return hs @@ -61,6 +63,7 @@ class DeviceTestCase(unittest.HomeserverTestCase): self.assertEqual(res, "fco") dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco")) + assert dev is not None self.assertEqual(dev["display_name"], "display name") def test_device_is_preserved_if_exists(self) -> None: @@ -83,6 +86,7 @@ class DeviceTestCase(unittest.HomeserverTestCase): self.assertEqual(res2, "fco") dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco")) + assert dev is not None self.assertEqual(dev["display_name"], "display name") def test_device_id_is_made_up_if_unspecified(self) -> None: @@ -95,6 
+99,7 @@ class DeviceTestCase(unittest.HomeserverTestCase): ) dev = self.get_success(self.handler.store.get_device("@theresa:foo", device_id)) + assert dev is not None self.assertEqual(dev["display_name"], "display") def test_get_devices_by_user(self) -> None: @@ -264,7 +269,9 @@ class DeviceTestCase(unittest.HomeserverTestCase): class DehydrationTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: hs = self.setup_test_homeserver("server", federation_http_client=None) - self.handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.handler = handler self.registration = hs.get_registration_handler() self.auth = hs.get_auth() self.store = hs.get_datastores().main @@ -284,9 +291,9 @@ class DehydrationTestCase(unittest.HomeserverTestCase): ) ) - retrieved_device_id, device_data = self.get_success( - self.handler.get_dehydrated_device(user_id=user_id) - ) + result = self.get_success(self.handler.get_dehydrated_device(user_id=user_id)) + assert result is not None + retrieved_device_id, device_data = result self.assertEqual(retrieved_device_id, stored_dehydrated_device_id) self.assertEqual(device_data, {"device_data": {"foo": "bar"}}) diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py index d52aee8f92..03f2112b07 100644 --- a/tests/rest/admin/test_device.py +++ b/tests/rest/admin/test_device.py @@ -19,6 +19,7 @@ from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin from synapse.api.errors import Codes +from synapse.handlers.device import DeviceHandler from synapse.rest.client import login from synapse.server import HomeServer from synapse.util import Clock @@ -34,7 +35,9 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.handler = handler self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") -- cgit 1.5.1 From 09de2aecb05cb46e0513396e2675b24c8beedb68 Mon Sep 17 00:00:00 2001 From: Ashish Kumar Date: Fri, 25 Nov 2022 19:16:50 +0400 Subject: Add support for handling avatar with SSO login (#13917) This commit adds support for handling a provided avatar picture URL when logging in via SSO. Signed-off-by: Ashish Kumar Fixes #9357. --- changelog.d/13917.feature | 1 + docs/usage/configuration/config_documentation.md | 9 +- mypy.ini | 4 +- synapse/handlers/oidc.py | 7 ++ synapse/handlers/sso.py | 111 +++++++++++++++++ tests/handlers/test_sso.py | 145 +++++++++++++++++++++++ 6 files changed, 275 insertions(+), 2 deletions(-) create mode 100644 changelog.d/13917.feature create mode 100644 tests/handlers/test_sso.py (limited to 'tests/handlers') diff --git a/changelog.d/13917.feature b/changelog.d/13917.feature new file mode 100644 index 0000000000..4eb942ab38 --- /dev/null +++ b/changelog.d/13917.feature @@ -0,0 +1 @@ +Adds support for handling avatar in SSO login. Contributed by @ashfame. 
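For orientation before the per-file diffs: when the mapping provider supplies a `picture` attribute, the SSO handler hands it to the new `set_avatar` method once the user has been registered or mapped. A condensed sketch of that hand-off (simplified from the `synapse/handlers/sso.py` diff below, which contains the full implementation):

    # Sketch only: simplified from the real call sites in SsoHandler.
    # `attributes` is the mapped user attributes returned by the identity
    # provider; `registered_user_id` is the MXID just registered/mapped.
    if attributes.picture:
        # Best-effort: downloads the image, stores it in the local media
        # repository, and sets it as the user's avatar. Failures are
        # logged and swallowed rather than failing the login.
        await self.set_avatar(registered_user_id, attributes.picture)
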
diff --git a/docs/usage/configuration/config_documentation.md b/docs/usage/configuration/config_documentation.md index fae2771fad..749af12aac 100644 --- a/docs/usage/configuration/config_documentation.md +++ b/docs/usage/configuration/config_documentation.md @@ -2968,10 +2968,17 @@ Options for each entry include: For the default provider, the following settings are available: - * subject_claim: name of the claim containing a unique identifier + * `subject_claim`: name of the claim containing a unique identifier for the user. Defaults to 'sub', which OpenID Connect compliant providers should provide. + * `picture_claim`: name of the claim containing an url for the user's profile picture. + Defaults to 'picture', which OpenID Connect compliant providers should provide + and has to refer to a direct image file such as PNG, JPEG, or GIF image file. + + Currently only supported in monolithic (single-process) server configurations + where the media repository runs within the Synapse process. + * `localpart_template`: Jinja2 template for the localpart of the MXID. If this is not set, the user will be prompted to choose their own username (see the documentation for the `sso_auth_account_details.html` diff --git a/mypy.ini b/mypy.ini index 25b3c93748..0b6e7df267 100644 --- a/mypy.ini +++ b/mypy.ini @@ -119,6 +119,9 @@ disallow_untyped_defs = True [mypy-tests.storage.test_profile] disallow_untyped_defs = True +[mypy-tests.handlers.test_sso] +disallow_untyped_defs = True + [mypy-tests.storage.test_user_directory] disallow_untyped_defs = True @@ -137,7 +140,6 @@ disallow_untyped_defs = False [mypy-tests.utils] disallow_untyped_defs = True - ;; Dependencies without annotations ;; Before ignoring a module, check to see if type stubs are available. ;; The `typeshed` project maintains stubs here: diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py index 41c675f408..03de6a4ba6 100644 --- a/synapse/handlers/oidc.py +++ b/synapse/handlers/oidc.py @@ -1435,6 +1435,7 @@ class UserAttributeDict(TypedDict): localpart: Optional[str] confirm_localpart: bool display_name: Optional[str] + picture: Optional[str] # may be omitted by older `OidcMappingProviders` emails: List[str] @@ -1520,6 +1521,7 @@ env.filters.update( @attr.s(slots=True, frozen=True, auto_attribs=True) class JinjaOidcMappingConfig: subject_claim: str + picture_claim: str localpart_template: Optional[Template] display_name_template: Optional[Template] email_template: Optional[Template] @@ -1539,6 +1541,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): @staticmethod def parse_config(config: dict) -> JinjaOidcMappingConfig: subject_claim = config.get("subject_claim", "sub") + picture_claim = config.get("picture_claim", "picture") def parse_template_config(option_name: str) -> Optional[Template]: if option_name not in config: @@ -1572,6 +1575,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): return JinjaOidcMappingConfig( subject_claim=subject_claim, + picture_claim=picture_claim, localpart_template=localpart_template, display_name_template=display_name_template, email_template=email_template, @@ -1611,10 +1615,13 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): if email: emails.append(email) + picture = userinfo.get("picture") + return UserAttributeDict( localpart=localpart, display_name=display_name, emails=emails, + picture=picture, confirm_localpart=self._config.confirm_localpart, ) diff --git a/synapse/handlers/sso.py 
b/synapse/handlers/sso.py index e1c0bff1b2..44e70fc4b8 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. import abc +import hashlib +import io import logging from typing import ( TYPE_CHECKING, @@ -138,6 +140,7 @@ class UserAttributes: localpart: Optional[str] confirm_localpart: bool = False display_name: Optional[str] = None + picture: Optional[str] = None emails: Collection[str] = attr.Factory(list) @@ -196,6 +199,10 @@ class SsoHandler: self._error_template = hs.config.sso.sso_error_template self._bad_user_template = hs.config.sso.sso_auth_bad_user_template self._profile_handler = hs.get_profile_handler() + self._media_repo = ( + hs.get_media_repository() if hs.config.media.can_load_media_repo else None + ) + self._http_client = hs.get_proxied_blacklisted_http_client() # The following template is shown after a successful user interactive # authentication session. It tells the user they can close the window. @@ -495,6 +502,8 @@ class SsoHandler: await self._profile_handler.set_displayname( user_id_obj, requester, attributes.display_name, True ) + if attributes.picture: + await self.set_avatar(user_id, attributes.picture) await self._auth_handler.complete_sso_login( user_id, @@ -703,8 +712,110 @@ class SsoHandler: await self._store.record_user_external_id( auth_provider_id, remote_user_id, registered_user_id ) + + # Set avatar, if available + if attributes.picture: + await self.set_avatar(registered_user_id, attributes.picture) + return registered_user_id + async def set_avatar(self, user_id: str, picture_https_url: str) -> bool: + """Set avatar of the user. + + This downloads the image file from the URL provided, stores that in + the media repository and then sets the avatar on the user's profile. + + It can detect if the same image is being saved again and bails early by storing + the hash of the file in the `upload_name` of the avatar image. + + Currently, it only supports server configurations which run the media repository + within the same process. + + It silently fails and logs a warning by raising an exception and catching it + internally if: + * it is unable to fetch the image itself (non 200 status code) or + * the image supplied is bigger than max allowed size or + * the image type is not one of the allowed image types. + + Args: + user_id: matrix user ID in the form @localpart:domain as a string. + + picture_https_url: HTTPS url for the picture image file. + + Returns: `True` if the user's avatar has been successfully set to the image at + `picture_https_url`. 
+ """ + if self._media_repo is None: + logger.info( + "failed to set user avatar because out-of-process media repositories " + "are not supported yet " + ) + return False + + try: + uid = UserID.from_string(user_id) + + def is_allowed_mime_type(content_type: str) -> bool: + if ( + self._profile_handler.allowed_avatar_mimetypes + and content_type + not in self._profile_handler.allowed_avatar_mimetypes + ): + return False + return True + + # download picture, enforcing size limit & mime type check + picture = io.BytesIO() + + content_length, headers, uri, code = await self._http_client.get_file( + url=picture_https_url, + output_stream=picture, + max_size=self._profile_handler.max_avatar_size, + is_allowed_content_type=is_allowed_mime_type, + ) + + if code != 200: + raise Exception( + "GET request to download sso avatar image returned {}".format(code) + ) + + # upload name includes hash of the image file's content so that we can + # easily check if it requires an update or not, the next time user logs in + upload_name = "sso_avatar_" + hashlib.sha256(picture.read()).hexdigest() + + # bail if user already has the same avatar + profile = await self._profile_handler.get_profile(user_id) + if profile["avatar_url"] is not None: + server_name = profile["avatar_url"].split("/")[-2] + media_id = profile["avatar_url"].split("/")[-1] + if server_name == self._server_name: + media = await self._media_repo.store.get_local_media(media_id) + if media is not None and upload_name == media["upload_name"]: + logger.info("skipping saving the user avatar") + return True + + # store it in media repository + avatar_mxc_url = await self._media_repo.create_content( + media_type=headers[b"Content-Type"][0].decode("utf-8"), + upload_name=upload_name, + content=picture, + content_length=content_length, + auth_user=uid, + ) + + # save it as user avatar + await self._profile_handler.set_avatar_url( + uid, + create_requester(uid), + str(avatar_mxc_url), + ) + + logger.info("successfully saved the user avatar") + return True + except Exception: + logger.warning("failed to save the user avatar") + return False + async def complete_sso_ui_auth_request( self, auth_provider_id: str, diff --git a/tests/handlers/test_sso.py b/tests/handlers/test_sso.py new file mode 100644 index 0000000000..137deab138 --- /dev/null +++ b/tests/handlers/test_sso.py @@ -0,0 +1,145 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from http import HTTPStatus +from typing import BinaryIO, Callable, Dict, List, Optional, Tuple +from unittest.mock import Mock + +from twisted.test.proto_helpers import MemoryReactor +from twisted.web.http_headers import Headers + +from synapse.api.errors import Codes, SynapseError +from synapse.http.client import RawHeaders +from synapse.server import HomeServer +from synapse.util import Clock + +from tests import unittest +from tests.test_utils import SMALL_PNG, FakeResponse + + +class TestSSOHandler(unittest.HomeserverTestCase): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + self.http_client = Mock(spec=["get_file"]) + self.http_client.get_file.side_effect = mock_get_file + self.http_client.user_agent = b"Synapse Test" + hs = self.setup_test_homeserver( + proxied_blacklisted_http_client=self.http_client + ) + return hs + + async def test_set_avatar(self) -> None: + """Tests successfully setting the avatar of a newly created user""" + handler = self.hs.get_sso_handler() + + # Create a new user to set avatar for + reg_handler = self.hs.get_registration_handler() + user_id = self.get_success(reg_handler.register_user(approved=True)) + + self.assertTrue( + self.get_success(handler.set_avatar(user_id, "http://my.server/me.png")) + ) + + # Ensure avatar is set on this newly created user, + # so no need to compare for the exact image + profile_handler = self.hs.get_profile_handler() + profile = self.get_success(profile_handler.get_profile(user_id)) + self.assertIsNot(profile["avatar_url"], None) + + @unittest.override_config({"max_avatar_size": 1}) + async def test_set_avatar_too_big_image(self) -> None: + """Tests that saving an avatar fails when it is too big""" + handler = self.hs.get_sso_handler() + + # any random user works since image check is supposed to fail + user_id = "@sso-user:test" + + self.assertFalse( + self.get_success(handler.set_avatar(user_id, "http://my.server/me.png")) + ) + + @unittest.override_config({"allowed_avatar_mimetypes": ["image/jpeg"]}) + async def test_set_avatar_incorrect_mime_type(self) -> None: + """Tests that saving an avatar fails when its mime type is not allowed""" + handler = self.hs.get_sso_handler() + + # any random user works since image check is supposed to fail + user_id = "@sso-user:test" + + self.assertFalse( + self.get_success(handler.set_avatar(user_id, "http://my.server/me.png")) + ) + + async def test_skip_saving_avatar_when_not_changed(self) -> None: + """Tests whether saving of avatar correctly skips if the avatar hasn't + changed""" + handler = self.hs.get_sso_handler() + + # Create a new user to set avatar for + reg_handler = self.hs.get_registration_handler() + user_id = self.get_success(reg_handler.register_user(approved=True)) + + # set avatar for the first time, should be a success + self.assertTrue( + self.get_success(handler.set_avatar(user_id, "http://my.server/me.png")) + ) + + # get avatar picture for comparison after another attempt + profile_handler = self.hs.get_profile_handler() + profile = self.get_success(profile_handler.get_profile(user_id)) + url_to_match = profile["avatar_url"] + + # set same avatar for the second time, should be a success + self.assertTrue( + self.get_success(handler.set_avatar(user_id, "http://my.server/me.png")) + ) + + # compare avatar picture's url from previous step + profile = self.get_success(profile_handler.get_profile(user_id)) + self.assertEqual(profile["avatar_url"], url_to_match) + + +async def mock_get_file( + url: str, + output_stream: BinaryIO, + 
max_size: Optional[int] = None, + headers: Optional[RawHeaders] = None, + is_allowed_content_type: Optional[Callable[[str], bool]] = None, +) -> Tuple[int, Dict[bytes, List[bytes]], str, int]: + + fake_response = FakeResponse(code=404) + if url == "http://my.server/me.png": + fake_response = FakeResponse( + code=200, + headers=Headers( + {"Content-Type": ["image/png"], "Content-Length": [str(len(SMALL_PNG))]} + ), + body=SMALL_PNG, + ) + + if max_size is not None and max_size < len(SMALL_PNG): + raise SynapseError( + HTTPStatus.BAD_GATEWAY, + "Requested file is too large > %r bytes" % (max_size,), + Codes.TOO_LARGE, + ) + + if is_allowed_content_type and not is_allowed_content_type("image/png"): + raise SynapseError( + HTTPStatus.BAD_GATEWAY, + ( + "Requested file's content type not allowed for this operation: %s" + % "image/png" + ), + ) + + output_stream.write(fake_response.body) + + return len(SMALL_PNG), {b"Content-Type": [b"image/png"]}, "", 200 -- cgit 1.5.1 From 1183c372fa9da01b2667f1b83dab958dad432c68 Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Mon, 28 Nov 2022 11:17:29 -0500 Subject: Use `device_one_time_keys_count` to match MSC3202 (#14565) * Use `device_one_time_keys_count` to match MSC3202 Rename the `device_one_time_key_counts` key in responses to `device_one_time_keys_count` to match the name specified by MSC3202. Also change related variable/class names for consistency. Signed-off-by: Andrew Ferrazzutti * Update changelog.d/14565.misc * Revert name change for `one_time_key_counts` key as this is a different key altogether from `device_one_time_keys_count`, which is used for `/sync` instead of appservice transactions. Signed-off-by: Andrew Ferrazzutti --- changelog.d/14565.misc | 1 + synapse/appservice/__init__.py | 10 +++++----- synapse/appservice/api.py | 11 +++++++---- synapse/appservice/scheduler.py | 16 ++++++++-------- synapse/handlers/sync.py | 6 +++--- synapse/storage/databases/main/appservice.py | 10 +++++----- synapse/storage/databases/main/end_to_end_keys.py | 8 ++++---- tests/appservice/test_scheduler.py | 6 +++--- tests/handlers/test_appservice.py | 4 ++-- 9 files changed, 38 insertions(+), 34 deletions(-) create mode 100644 changelog.d/14565.misc (limited to 'tests/handlers') diff --git a/changelog.d/14565.misc b/changelog.d/14565.misc new file mode 100644 index 0000000000..19a62b036c --- /dev/null +++ b/changelog.d/14565.misc @@ -0,0 +1 @@ +In application service transactions that include the experimental `org.matrix.msc3202.device_one_time_key_counts` key, include a duplicate key of `org.matrix.msc3202.device_one_time_keys_count` to match the name proposed by [MSC3202](https://github.com/matrix-org/matrix-spec-proposals/blob/travis/msc/otk-dl-appservice/proposals/3202-encrypted-appservices.md). 
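To make the duplication concrete: for application services that opt into the MSC3202 transaction extensions, the transaction body now carries the same counts under both the deprecated key and the new one, as the `synapse/appservice/api.py` diff below shows. An illustrative payload fragment (the user ID, device ID and counts are invented for the example):

    # Illustrative appservice transaction body fragment after this change.
    one_time_keys_count = {
        "@bot:example.org": {"DEVICEID": {"signed_curve25519": 50}},
    }
    body = {
        # ... "events", "ephemeral", to-device messages, etc. ...
        # Deprecated key, still sent for backwards compatibility:
        "org.matrix.msc3202.device_one_time_key_counts": one_time_keys_count,
        # New key matching the name proposed by MSC3202:
        "org.matrix.msc3202.device_one_time_keys_count": one_time_keys_count,
    }
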
diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index 500bdde3a9..bf4e6c629b 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -32,9 +32,9 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -# Type for the `device_one_time_key_counts` field in an appservice transaction +# Type for the `device_one_time_keys_count` field in an appservice transaction # user ID -> {device ID -> {algorithm -> count}} -TransactionOneTimeKeyCounts = Dict[str, Dict[str, Dict[str, int]]] +TransactionOneTimeKeysCount = Dict[str, Dict[str, Dict[str, int]]] # Type for the `device_unused_fallback_key_types` field in an appservice transaction # user ID -> {device ID -> [algorithm]} @@ -376,7 +376,7 @@ class AppServiceTransaction: events: List[EventBase], ephemeral: List[JsonDict], to_device_messages: List[JsonDict], - one_time_key_counts: TransactionOneTimeKeyCounts, + one_time_keys_count: TransactionOneTimeKeysCount, unused_fallback_keys: TransactionUnusedFallbackKeys, device_list_summary: DeviceListUpdates, ): @@ -385,7 +385,7 @@ class AppServiceTransaction: self.events = events self.ephemeral = ephemeral self.to_device_messages = to_device_messages - self.one_time_key_counts = one_time_key_counts + self.one_time_keys_count = one_time_keys_count self.unused_fallback_keys = unused_fallback_keys self.device_list_summary = device_list_summary @@ -402,7 +402,7 @@ class AppServiceTransaction: events=self.events, ephemeral=self.ephemeral, to_device_messages=self.to_device_messages, - one_time_key_counts=self.one_time_key_counts, + one_time_keys_count=self.one_time_keys_count, unused_fallback_keys=self.unused_fallback_keys, device_list_summary=self.device_list_summary, txn_id=self.id, diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index 60774b240d..edafd433cd 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -23,7 +23,7 @@ from synapse.api.constants import EventTypes, Membership, ThirdPartyEntityKind from synapse.api.errors import CodeMessageException from synapse.appservice import ( ApplicationService, - TransactionOneTimeKeyCounts, + TransactionOneTimeKeysCount, TransactionUnusedFallbackKeys, ) from synapse.events import EventBase @@ -262,7 +262,7 @@ class ApplicationServiceApi(SimpleHttpClient): events: List[EventBase], ephemeral: List[JsonDict], to_device_messages: List[JsonDict], - one_time_key_counts: TransactionOneTimeKeyCounts, + one_time_keys_count: TransactionOneTimeKeysCount, unused_fallback_keys: TransactionUnusedFallbackKeys, device_list_summary: DeviceListUpdates, txn_id: Optional[int] = None, @@ -310,10 +310,13 @@ class ApplicationServiceApi(SimpleHttpClient): # TODO: Update to stable prefixes once MSC3202 completes FCP merge if service.msc3202_transaction_extensions: - if one_time_key_counts: + if one_time_keys_count: body[ "org.matrix.msc3202.device_one_time_key_counts" - ] = one_time_key_counts + ] = one_time_keys_count + body[ + "org.matrix.msc3202.device_one_time_keys_count" + ] = one_time_keys_count if unused_fallback_keys: body[ "org.matrix.msc3202.device_unused_fallback_key_types" diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py index 430ffbcd1f..7b562795a3 100644 --- a/synapse/appservice/scheduler.py +++ b/synapse/appservice/scheduler.py @@ -64,7 +64,7 @@ from typing import ( from synapse.appservice import ( ApplicationService, ApplicationServiceState, - TransactionOneTimeKeyCounts, + TransactionOneTimeKeysCount, TransactionUnusedFallbackKeys, ) from 
synapse.appservice.api import ApplicationServiceApi @@ -258,7 +258,7 @@ class _ServiceQueuer: ): return - one_time_key_counts: Optional[TransactionOneTimeKeyCounts] = None + one_time_keys_count: Optional[TransactionOneTimeKeysCount] = None unused_fallback_keys: Optional[TransactionUnusedFallbackKeys] = None if ( @@ -269,7 +269,7 @@ class _ServiceQueuer: # for the users which are mentioned in this transaction, # as well as the appservice's sender. ( - one_time_key_counts, + one_time_keys_count, unused_fallback_keys, ) = await self._compute_msc3202_otk_counts_and_fallback_keys( service, events, ephemeral, to_device_messages_to_send @@ -281,7 +281,7 @@ class _ServiceQueuer: events, ephemeral, to_device_messages_to_send, - one_time_key_counts, + one_time_keys_count, unused_fallback_keys, device_list_summary, ) @@ -296,7 +296,7 @@ class _ServiceQueuer: events: Iterable[EventBase], ephemerals: Iterable[JsonDict], to_device_messages: Iterable[JsonDict], - ) -> Tuple[TransactionOneTimeKeyCounts, TransactionUnusedFallbackKeys]: + ) -> Tuple[TransactionOneTimeKeysCount, TransactionUnusedFallbackKeys]: """ Given a list of the events, ephemeral messages and to-device messages, - first computes a list of application services users that may have @@ -367,7 +367,7 @@ class _TransactionController: events: List[EventBase], ephemeral: Optional[List[JsonDict]] = None, to_device_messages: Optional[List[JsonDict]] = None, - one_time_key_counts: Optional[TransactionOneTimeKeyCounts] = None, + one_time_keys_count: Optional[TransactionOneTimeKeysCount] = None, unused_fallback_keys: Optional[TransactionUnusedFallbackKeys] = None, device_list_summary: Optional[DeviceListUpdates] = None, ) -> None: @@ -380,7 +380,7 @@ class _TransactionController: events: The persistent events to include in the transaction. ephemeral: The ephemeral events to include in the transaction. to_device_messages: The to-device messages to include in the transaction. - one_time_key_counts: Counts of remaining one-time keys for relevant + one_time_keys_count: Counts of remaining one-time keys for relevant appservice devices in the transaction. unused_fallback_keys: Lists of unused fallback keys for relevant appservice devices in the transaction. 
@@ -397,7 +397,7 @@ class _TransactionController: events=events, ephemeral=ephemeral or [], to_device_messages=to_device_messages or [], - one_time_key_counts=one_time_key_counts or {}, + one_time_keys_count=one_time_keys_count or {}, unused_fallback_keys=unused_fallback_keys or {}, device_list_summary=device_list_summary or DeviceListUpdates(), ) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 259456b55d..c8858b22dd 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -1426,14 +1426,14 @@ class SyncHandler: logger.debug("Fetching OTK data") device_id = sync_config.device_id - one_time_key_counts: JsonDict = {} + one_time_keys_count: JsonDict = {} unused_fallback_key_types: List[str] = [] if device_id: # TODO: We should have a way to let clients differentiate between the states of: # * no change in OTK count since the provided since token # * the server has zero OTKs left for this device # Spec issue: https://github.com/matrix-org/matrix-doc/issues/3298 - one_time_key_counts = await self.store.count_e2e_one_time_keys( + one_time_keys_count = await self.store.count_e2e_one_time_keys( user_id, device_id ) unused_fallback_key_types = ( @@ -1463,7 +1463,7 @@ class SyncHandler: archived=sync_result_builder.archived, to_device=sync_result_builder.to_device, device_lists=device_lists, - device_one_time_keys_count=one_time_key_counts, + device_one_time_keys_count=one_time_keys_count, device_unused_fallback_key_types=unused_fallback_key_types, next_batch=sync_result_builder.now_token, ) diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index 25da0c56c5..c2c8018ee2 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -20,7 +20,7 @@ from synapse.appservice import ( ApplicationService, ApplicationServiceState, AppServiceTransaction, - TransactionOneTimeKeyCounts, + TransactionOneTimeKeysCount, TransactionUnusedFallbackKeys, ) from synapse.config.appservice import load_appservices @@ -260,7 +260,7 @@ class ApplicationServiceTransactionWorkerStore( events: List[EventBase], ephemeral: List[JsonDict], to_device_messages: List[JsonDict], - one_time_key_counts: TransactionOneTimeKeyCounts, + one_time_keys_count: TransactionOneTimeKeysCount, unused_fallback_keys: TransactionUnusedFallbackKeys, device_list_summary: DeviceListUpdates, ) -> AppServiceTransaction: @@ -273,7 +273,7 @@ class ApplicationServiceTransactionWorkerStore( events: A list of persistent events to put in the transaction. ephemeral: A list of ephemeral events to put in the transaction. to_device_messages: A list of to-device messages to put in the transaction. - one_time_key_counts: Counts of remaining one-time keys for relevant + one_time_keys_count: Counts of remaining one-time keys for relevant appservice devices in the transaction. unused_fallback_keys: Lists of unused fallback keys for relevant appservice devices in the transaction. 
@@ -299,7 +299,7 @@ class ApplicationServiceTransactionWorkerStore( events=events, ephemeral=ephemeral, to_device_messages=to_device_messages, - one_time_key_counts=one_time_key_counts, + one_time_keys_count=one_time_keys_count, unused_fallback_keys=unused_fallback_keys, device_list_summary=device_list_summary, ) @@ -379,7 +379,7 @@ class ApplicationServiceTransactionWorkerStore( events=events, ephemeral=[], to_device_messages=[], - one_time_key_counts={}, + one_time_keys_count={}, unused_fallback_keys={}, device_list_summary=DeviceListUpdates(), ) diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index cf33e73e2b..643c47d608 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -33,7 +33,7 @@ from typing_extensions import Literal from synapse.api.constants import DeviceKeyAlgorithms from synapse.appservice import ( - TransactionOneTimeKeyCounts, + TransactionOneTimeKeysCount, TransactionUnusedFallbackKeys, ) from synapse.logging.opentracing import log_kv, set_tag, trace @@ -514,7 +514,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker async def count_bulk_e2e_one_time_keys_for_as( self, user_ids: Collection[str] - ) -> TransactionOneTimeKeyCounts: + ) -> TransactionOneTimeKeysCount: """ Counts, in bulk, the one-time keys for all the users specified. Intended to be used by application services for populating OTK counts in @@ -528,7 +528,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker def _count_bulk_e2e_one_time_keys_txn( txn: LoggingTransaction, - ) -> TransactionOneTimeKeyCounts: + ) -> TransactionOneTimeKeysCount: user_in_where_clause, user_parameters = make_in_list_sql_clause( self.database_engine, "user_id", user_ids ) @@ -541,7 +541,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker """ txn.execute(sql, user_parameters) - result: TransactionOneTimeKeyCounts = {} + result: TransactionOneTimeKeysCount = {} for user_id, device_id, algorithm, count in txn: # We deliberately construct empty dictionaries for diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py index 0b22afdc75..0a1ae83a2b 100644 --- a/tests/appservice/test_scheduler.py +++ b/tests/appservice/test_scheduler.py @@ -69,7 +69,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): events=events, ephemeral=[], to_device_messages=[], # txn made and saved - one_time_key_counts={}, + one_time_keys_count={}, unused_fallback_keys={}, device_list_summary=DeviceListUpdates(), ) @@ -96,7 +96,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): events=events, ephemeral=[], to_device_messages=[], # txn made and saved - one_time_key_counts={}, + one_time_keys_count={}, unused_fallback_keys={}, device_list_summary=DeviceListUpdates(), ) @@ -125,7 +125,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): events=events, ephemeral=[], to_device_messages=[], - one_time_key_counts={}, + one_time_keys_count={}, unused_fallback_keys={}, device_list_summary=DeviceListUpdates(), ) diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index 144e49d0fd..9ed26d87a7 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -25,7 +25,7 @@ import synapse.storage from synapse.api.constants import EduTypes, EventTypes from synapse.appservice 
import (
     ApplicationService,
-    TransactionOneTimeKeyCounts,
+    TransactionOneTimeKeysCount,
     TransactionUnusedFallbackKeys,
 )
 from synapse.handlers.appservice import ApplicationServicesHandler
@@ -1123,7 +1123,7 @@ class ApplicationServicesHandlerOtkCountsTestCase(unittest.HomeserverTestCase):
         # Capture what was sent as an AS transaction.
         self.send_mock.assert_called()
         last_args, _last_kwargs = self.send_mock.call_args
-        otks: Optional[TransactionOneTimeKeyCounts] = last_args[self.ARG_OTK_COUNTS]
+        otks: Optional[TransactionOneTimeKeysCount] = last_args[self.ARG_OTK_COUNTS]
         unused_fallbacks: Optional[TransactionUnusedFallbackKeys] = last_args[
             self.ARG_FALLBACK_KEYS
         ]
-- cgit 1.5.1


From c7e29ca277cf60bfdc488b93f4321b046fa6b46f Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Tue, 29 Nov 2022 10:36:41 +0000
Subject: POC delete stale non-e2e devices for users (#14038)

This should help reduce the number of devices that, e.g., simple bots which repeatedly log in rack up.

We only delete non-e2e devices, as they should be safe to delete, whereas if we delete e2e devices for a user we may accidentally break their ability to receive e2e keys for a message.

Co-authored-by: Patrick Cloke
Co-authored-by: Sean Quah <8349537+squahtx@users.noreply.github.com>
---
 changelog.d/14038.misc                    |  1 +
 synapse/handlers/device.py                | 13 +++++-
 synapse/storage/databases/main/devices.py | 67 ++++++++++++++++++++++++++++++-
 tests/handlers/test_device.py             |  2 +-
 tests/storage/test_client_ips.py          |  4 +-
 5 files changed, 83 insertions(+), 4 deletions(-)
 create mode 100644 changelog.d/14038.misc

(limited to 'tests/handlers')

diff --git a/changelog.d/14038.misc b/changelog.d/14038.misc
new file mode 100644
index 0000000000..f9bfc581ad
--- /dev/null
+++ b/changelog.d/14038.misc
@@ -0,0 +1 @@
+Prune user's old devices on login if they have too many.
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index b1e55e1b9e..7c4dd8cf5a 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -421,6 +421,9 @@ class DeviceHandler(DeviceWorkerHandler):

         self._check_device_name_length(initial_device_display_name)

+        # Prune the user's device list if they already have a lot of devices.
+        await self._prune_too_many_devices(user_id)
+
         if device_id is not None:
             new_device = await self.store.store_device(
                 user_id=user_id,
@@ -452,6 +455,14 @@ class DeviceHandler(DeviceWorkerHandler):

         raise errors.StoreError(500, "Couldn't generate a device ID.")

+    async def _prune_too_many_devices(self, user_id: str) -> None:
+        """Delete any excess old devices this user may have."""
+        device_ids = await self.store.check_too_many_devices_for_user(user_id)
+        if not device_ids:
+            return
+
+        await self.delete_devices(user_id, device_ids)
+
     async def _delete_stale_devices(self) -> None:
         """Background task that deletes devices which haven't been accessed for more than
         a configured time period.
@@ -481,7 +492,7 @@ class DeviceHandler(DeviceWorkerHandler): device_ids = [d for d in device_ids if d != except_device_id] await self.delete_devices(user_id, device_ids) - async def delete_devices(self, user_id: str, device_ids: List[str]) -> None: + async def delete_devices(self, user_id: str, device_ids: Collection[str]) -> None: """Delete several devices Args: diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 534f7fc04a..1e83c62753 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -1533,6 +1533,70 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): return rows + async def check_too_many_devices_for_user(self, user_id: str) -> Collection[str]: + """Check if the user has a lot of devices, and if so return the set of + devices we can prune. + + This does *not* return hidden devices or devices with E2E keys. + """ + + num_devices = await self.db_pool.simple_select_one_onecol( + table="devices", + keyvalues={"user_id": user_id, "hidden": False}, + retcol="COALESCE(COUNT(*), 0)", + desc="count_devices", + ) + + # We let users have up to ten devices without pruning. + if num_devices <= 10: + return () + + # We prune everything older than N days. + max_last_seen = self._clock.time_msec() - 14 * 24 * 60 * 60 * 1000 + + if num_devices > 50: + # If the user has more than 50 devices, then we chose a last seen + # that ensures we keep at most 50 devices. + sql = """ + SELECT last_seen FROM devices + WHERE + user_id = ? + AND NOT hidden + AND last_seen IS NOT NULL + AND key_json IS NULL + ORDER BY last_seen DESC + LIMIT 1 + OFFSET 50 + """ + + rows = await self.db_pool.execute( + "check_too_many_devices_for_user_last_seen", None, sql, (user_id,) + ) + if rows: + max_last_seen = max(rows[0][0], max_last_seen) + + # Now fetch the devices to delete. + sql = """ + SELECT DISTINCT device_id FROM devices + LEFT JOIN e2e_device_keys_json USING (user_id, device_id) + WHERE + user_id = ? + AND NOT hidden + AND last_seen < ? + AND key_json IS NULL + """ + + def check_too_many_devices_for_user_txn( + txn: LoggingTransaction, + ) -> Collection[str]: + txn.execute(sql, (user_id, max_last_seen)) + return {device_id for device_id, in txn} + + return await self.db_pool.runInteraction( + "check_too_many_devices_for_user", + check_too_many_devices_for_user_txn, + ) + class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): # Because we have write access, this will be a StreamIdGenerator @@ -1591,6 +1655,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): values={}, insertion_values={ "display_name": initial_device_display_name, + "last_seen": self._clock.time_msec(), "hidden": False, }, desc="store_device", @@ -1636,7 +1701,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): ) raise StoreError(500, "Problem storing device.") - async def delete_devices(self, user_id: str, device_ids: List[str]) -> None: + async def delete_devices(self, user_id: str, device_ids: Collection[str]) -> None: """Deletes several devices. 
Args: diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index ce7525e29c..a456bffd63 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -115,7 +115,7 @@ class DeviceTestCase(unittest.HomeserverTestCase): "device_id": "xyz", "display_name": "display 0", "last_seen_ip": None, - "last_seen_ts": None, + "last_seen_ts": 1000000, }, device_map["xyz"], ) diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index 49ad3c1324..a9af1babed 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -169,6 +169,8 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): ) ) + last_seen = self.clock.time_msec() + if after_persisting: # Trigger the storage loop self.reactor.advance(10) @@ -189,7 +191,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): "device_id": device_id, "ip": None, "user_agent": None, - "last_seen": None, + "last_seen": last_seen, }, ], ) -- cgit 1.5.1 From c29e2c630624beb0b5557aa0f7ccdcedbe62def1 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Tue, 29 Nov 2022 17:48:48 +0000 Subject: Revert "POC delete stale non-e2e devices for users (#14038)" (#14582) --- changelog.d/14582.bugfix | 1 + synapse/handlers/device.py | 13 +----- synapse/storage/databases/main/devices.py | 68 +------------------------------ tests/handlers/test_device.py | 2 +- tests/storage/test_client_ips.py | 4 +- 5 files changed, 5 insertions(+), 83 deletions(-) create mode 100644 changelog.d/14582.bugfix (limited to 'tests/handlers') diff --git a/changelog.d/14582.bugfix b/changelog.d/14582.bugfix new file mode 100644 index 0000000000..caad468e70 --- /dev/null +++ b/changelog.d/14582.bugfix @@ -0,0 +1 @@ +Fix a regression in Synapse 1.73.0rc1 where Synapse's main process would stop responding to HTTP requests when a user with a large number of devices logs in. diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 7c4dd8cf5a..b1e55e1b9e 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -421,9 +421,6 @@ class DeviceHandler(DeviceWorkerHandler): self._check_device_name_length(initial_device_display_name) - # Prune the user's device list if they already have a lot of devices. - await self._prune_too_many_devices(user_id) - if device_id is not None: new_device = await self.store.store_device( user_id=user_id, @@ -455,14 +452,6 @@ class DeviceHandler(DeviceWorkerHandler): raise errors.StoreError(500, "Couldn't generate a device ID.") - async def _prune_too_many_devices(self, user_id: str) -> None: - """Delete any excess old devices this user may have.""" - device_ids = await self.store.check_too_many_devices_for_user(user_id) - if not device_ids: - return - - await self.delete_devices(user_id, device_ids) - async def _delete_stale_devices(self) -> None: """Background task that deletes devices which haven't been accessed for more than a configured time period. 
@@ -492,7 +481,7 @@ class DeviceHandler(DeviceWorkerHandler): device_ids = [d for d in device_ids if d != except_device_id] await self.delete_devices(user_id, device_ids) - async def delete_devices(self, user_id: str, device_ids: Collection[str]) -> None: + async def delete_devices(self, user_id: str, device_ids: List[str]) -> None: """Delete several devices Args: diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 0378035cff..534f7fc04a 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -1533,71 +1533,6 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): return rows - async def check_too_many_devices_for_user(self, user_id: str) -> Collection[str]: - """Check if the user has a lot of devices, and if so return the set of - devices we can prune. - - This does *not* return hidden devices or devices with E2E keys. - """ - - num_devices = await self.db_pool.simple_select_one_onecol( - table="devices", - keyvalues={"user_id": user_id, "hidden": False}, - retcol="COALESCE(COUNT(*), 0)", - desc="count_devices", - ) - - # We let users have up to ten devices without pruning. - if num_devices <= 10: - return () - - # We prune everything older than N days. - max_last_seen = self._clock.time_msec() - 14 * 24 * 60 * 60 * 1000 - - if num_devices > 50: - # If the user has more than 50 devices, then we chose a last seen - # that ensures we keep at most 50 devices. - sql = """ - SELECT last_seen FROM devices - LEFT JOIN e2e_device_keys_json USING (user_id, device_id) - WHERE - user_id = ? - AND NOT hidden - AND last_seen IS NOT NULL - AND key_json IS NULL - ORDER BY last_seen DESC - LIMIT 1 - OFFSET 50 - """ - - rows = await self.db_pool.execute( - "check_too_many_devices_for_user_last_seen", None, sql, (user_id,) - ) - if rows: - max_last_seen = max(rows[0][0], max_last_seen) - - # Now fetch the devices to delete. - sql = """ - SELECT DISTINCT device_id FROM devices - LEFT JOIN e2e_device_keys_json USING (user_id, device_id) - WHERE - user_id = ? - AND NOT hidden - AND last_seen < ? - AND key_json IS NULL - """ - - def check_too_many_devices_for_user_txn( - txn: LoggingTransaction, - ) -> Collection[str]: - txn.execute(sql, (user_id, max_last_seen)) - return {device_id for device_id, in txn} - - return await self.db_pool.runInteraction( - "check_too_many_devices_for_user", - check_too_many_devices_for_user_txn, - ) - class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): # Because we have write access, this will be a StreamIdGenerator @@ -1656,7 +1591,6 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): values={}, insertion_values={ "display_name": initial_device_display_name, - "last_seen": self._clock.time_msec(), "hidden": False, }, desc="store_device", @@ -1702,7 +1636,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): ) raise StoreError(500, "Problem storing device.") - async def delete_devices(self, user_id: str, device_ids: Collection[str]) -> None: + async def delete_devices(self, user_id: str, device_ids: List[str]) -> None: """Deletes several devices. 
Args: diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index a456bffd63..ce7525e29c 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -115,7 +115,7 @@ class DeviceTestCase(unittest.HomeserverTestCase): "device_id": "xyz", "display_name": "display 0", "last_seen_ip": None, - "last_seen_ts": 1000000, + "last_seen_ts": None, }, device_map["xyz"], ) diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index a9af1babed..49ad3c1324 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -169,8 +169,6 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): ) ) - last_seen = self.clock.time_msec() - if after_persisting: # Trigger the storage loop self.reactor.advance(10) @@ -191,7 +189,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): "device_id": device_id, "ip": None, "user_agent": None, - "last_seen": last_seen, + "last_seen": None, }, ], ) -- cgit 1.5.1 From 854a6884d81c95297bf93badcddc00a4cab93418 Mon Sep 17 00:00:00 2001 From: realtyem Date: Thu, 1 Dec 2022 06:38:27 -0600 Subject: Modernize unit tests configuration settings for workers. (#14568) Use the newer foo_instances configuration instead of the deprecated flags to enable specific features (e.g. start_pushers). --- changelog.d/14568.misc | 1 + tests/events/test_presence_router.py | 16 ++++-- tests/federation/test_federation_catch_up.py | 21 +++++--- tests/federation/test_federation_sender.py | 27 +++++++++-- tests/handlers/test_presence.py | 3 +- tests/handlers/test_typing.py | 6 ++- tests/handlers/test_user_directory.py | 7 ++- tests/module_api/test_api.py | 3 +- tests/push/test_email.py | 1 - tests/push/test_http.py | 7 +-- tests/replication/_base.py | 2 +- tests/replication/tcp/streams/test_federation.py | 5 +- tests/replication/test_auth.py | 4 +- tests/replication/test_client_reader_shard.py | 14 +++--- tests/replication/test_federation_ack.py | 5 +- tests/replication/test_federation_sender_shard.py | 59 ++++++++++++++--------- tests/replication/test_pusher_shard.py | 15 ++---- tests/utils.py | 8 +-- 18 files changed, 123 insertions(+), 81 deletions(-) create mode 100644 changelog.d/14568.misc (limited to 'tests/handlers') diff --git a/changelog.d/14568.misc b/changelog.d/14568.misc new file mode 100644 index 0000000000..99973de1c1 --- /dev/null +++ b/changelog.d/14568.misc @@ -0,0 +1 @@ +Modernize unit tests configuration related to workers. diff --git a/tests/events/test_presence_router.py b/tests/events/test_presence_router.py index 685a9a6d52..b703e4472e 100644 --- a/tests/events/test_presence_router.py +++ b/tests/events/test_presence_router.py @@ -126,6 +126,13 @@ class PresenceRouterTestModule: class PresenceRouterTestCase(FederatingHomeserverTestCase): + """ + Test cases using a custom PresenceRouter + + By default in test cases, federation sending is disabled. This class re-enables it + for the main process by setting `federation_sender_instances` to None. 
+ """ + servlets = [ admin.register_servlets, login.register_servlets, @@ -150,6 +157,11 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase): self.sync_handler = self.hs.get_sync_handler() self.module_api = homeserver.get_module_api() + def default_config(self) -> JsonDict: + config = super().default_config() + config["federation_sender_instances"] = None + return config + @override_config( { "presence": { @@ -162,7 +174,6 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase): }, } }, - "send_federation": True, } ) def test_receiving_all_presence_legacy(self): @@ -180,7 +191,6 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase): }, }, ], - "send_federation": True, } ) def test_receiving_all_presence(self): @@ -290,7 +300,6 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase): }, } }, - "send_federation": True, } ) def test_send_local_online_presence_to_with_module_legacy(self): @@ -310,7 +319,6 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase): }, }, ], - "send_federation": True, } ) def test_send_local_online_presence_to_with_module(self): diff --git a/tests/federation/test_federation_catch_up.py b/tests/federation/test_federation_catch_up.py index 2873b4d430..b8fee72898 100644 --- a/tests/federation/test_federation_catch_up.py +++ b/tests/federation/test_federation_catch_up.py @@ -7,13 +7,21 @@ from synapse.federation.sender import PerDestinationQueue, TransactionManager from synapse.federation.units import Edu from synapse.rest import admin from synapse.rest.client import login, room +from synapse.types import JsonDict from synapse.util.retryutils import NotRetryingDestination from tests.test_utils import event_injection, make_awaitable -from tests.unittest import FederatingHomeserverTestCase, override_config +from tests.unittest import FederatingHomeserverTestCase class FederationCatchUpTestCases(FederatingHomeserverTestCase): + """ + Tests cases of catching up over federation. + + By default for test cases federation sending is disabled. This Test class has it + re-enabled for the main process. + """ + servlets = [ admin.register_servlets, room.register_servlets, @@ -42,6 +50,11 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase): self.record_transaction ) + def default_config(self) -> JsonDict: + config = super().default_config() + config["federation_sender_instances"] = None + return config + async def record_transaction(self, txn, json_cb): if self.is_online: data = json_cb() @@ -79,7 +92,6 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase): )[0] return {"event_id": event_id, "stream_ordering": stream_ordering} - @override_config({"send_federation": True}) def test_catch_up_destination_rooms_tracking(self): """ Tests that we populate the `destination_rooms` table as needed. @@ -105,7 +117,6 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase): self.assertEqual(row_2["event_id"], event_id_2) self.assertEqual(row_1["stream_ordering"], row_2["stream_ordering"] - 1) - @override_config({"send_federation": True}) def test_catch_up_last_successful_stream_ordering_tracking(self): """ Tests that we populate the `destination_rooms` table as needed. @@ -163,7 +174,6 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase): "Send succeeded but not marked as last_successful_stream_ordering", ) - @override_config({"send_federation": True}) # critical to federate def test_catch_up_from_blank_state(self): """ Runs an overall test of federation catch-up from scratch. 
@@ -260,7 +270,6 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase): return per_dest_queue, results_list - @override_config({"send_federation": True}) def test_catch_up_loop(self): """ Tests the behaviour of _catch_up_transmission_loop. @@ -325,7 +334,6 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase): event_5.internal_metadata.stream_ordering, ) - @override_config({"send_federation": True}) def test_catch_up_on_synapse_startup(self): """ Tests the behaviour of get_catch_up_outstanding_destinations and @@ -424,7 +432,6 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase): # - all destinations are woken exactly once; they appear once in woken. self.assertCountEqual(woken, server_names[:-1]) - @override_config({"send_federation": True}) def test_not_latest_event(self): """Test that we send the latest event in the room even if its not ours.""" diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index cbc99d30b9..8692d8190f 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -25,10 +25,17 @@ from synapse.rest.client import login from synapse.types import JsonDict, ReadReceipt from tests.test_utils import make_awaitable -from tests.unittest import HomeserverTestCase, override_config +from tests.unittest import HomeserverTestCase class FederationSenderReceiptsTestCases(HomeserverTestCase): + """ + Test federation sending to update receipts. + + By default for test cases federation sending is disabled. This Test class has it + re-enabled for the main process. + """ + def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver( federation_transport_client=Mock(spec=["send_transaction"]), @@ -44,7 +51,11 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): return hs - @override_config({"send_federation": True}) + def default_config(self) -> JsonDict: + config = super().default_config() + config["federation_sender_instances"] = None + return config + def test_send_receipts(self): mock_send_transaction = ( self.hs.get_federation_transport_client().send_transaction @@ -87,7 +98,6 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): ], ) - @override_config({"send_federation": True}) def test_send_receipts_thread(self): mock_send_transaction = ( self.hs.get_federation_transport_client().send_transaction @@ -164,7 +174,6 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): ], ) - @override_config({"send_federation": True}) def test_send_receipts_with_backoff(self): """Send two receipts in quick succession; the second should be flushed, but only after 20ms""" @@ -251,6 +260,13 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): class FederationSenderDevicesTestCases(HomeserverTestCase): + """ + Test federation sending to update devices. + + By default for test cases federation sending is disabled. This Test class has it + re-enabled for the main process. + """ + servlets = [ admin.register_servlets, login.register_servlets, @@ -265,7 +281,8 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): def default_config(self): c = super().default_config() - c["send_federation"] = True + # Enable federation sending on the main process. 
+ c["federation_sender_instances"] = None return c def prepare(self, reactor, clock, hs): diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index c5981ff965..584e7b8971 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -992,7 +992,8 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): def default_config(self): config = super().default_config() - config["send_federation"] = True + # Enable federation sending on the main process. + config["federation_sender_instances"] = None return config def prepare(self, reactor, clock, hs): diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 9c821b3042..efbb5a8dbb 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -200,7 +200,8 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): ], ) - @override_config({"send_federation": True}) + # Enable federation sending on the main process. + @override_config({"federation_sender_instances": None}) def test_started_typing_remote_send(self) -> None: self.room_members = [U_APPLE, U_ONION] @@ -305,7 +306,8 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.assertEqual(events[0], []) self.assertEqual(events[1], 0) - @override_config({"send_federation": True}) + # Enable federation sending on the main process. + @override_config({"federation_sender_instances": None}) def test_stopped_typing(self) -> None: self.room_members = [U_APPLE, U_BANANA, U_ONION] diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index 9e39cd97e5..75fc5a17a4 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -56,7 +56,8 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() - config["update_user_directory"] = True + # Re-enables updating the user directory, as that function is needed below. + config["update_user_directory_from_worker"] = None self.appservice = ApplicationService( token="i_am_an_app_service", @@ -1045,7 +1046,9 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() - config["update_user_directory"] = True + # Re-enables updating the user directory, as that function is needed below. It + # will be force disabled later + config["update_user_directory_from_worker"] = None hs = self.setup_test_homeserver(config=config) self.config = hs.config diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py index 058ca57e55..b0f3f4374d 100644 --- a/tests/module_api/test_api.py +++ b/tests/module_api/test_api.py @@ -336,7 +336,8 @@ class ModuleApiTestCase(HomeserverTestCase): # Test sending local online presence to users from the main process _test_sending_local_online_presence_to_local_user(self, test_with_workers=False) - @override_config({"send_federation": True}) + # Enable federation sending on the main process. 
+ @override_config({"federation_sender_instances": None}) def test_send_local_online_presence_to_federation(self): """Tests that send_local_presence_to_users sends local online presence to remote users.""" # Create a user who will send presence updates diff --git a/tests/push/test_email.py b/tests/push/test_email.py index fd14568f55..57b2f0536e 100644 --- a/tests/push/test_email.py +++ b/tests/push/test_email.py @@ -66,7 +66,6 @@ class EmailPusherTests(HomeserverTestCase): "riot_base_url": None, } config["public_baseurl"] = "http://aaa" - config["start_pushers"] = True hs = self.setup_test_homeserver(config=config) diff --git a/tests/push/test_http.py b/tests/push/test_http.py index b383b8401f..afaafe79aa 100644 --- a/tests/push/test_http.py +++ b/tests/push/test_http.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Optional, Tuple +from typing import List, Optional, Tuple from unittest.mock import Mock from twisted.internet.defer import Deferred @@ -41,11 +41,6 @@ class HTTPPusherTests(HomeserverTestCase): user_id = True hijack_auth = False - def default_config(self) -> Dict[str, Any]: - config = super().default_config() - config["start_pushers"] = True - return config - def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.push_attempts: List[Tuple[Deferred, str, dict]] = [] diff --git a/tests/replication/_base.py b/tests/replication/_base.py index 3029a16dda..6a7174b333 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -307,7 +307,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): stream to the master HS. Args: - worker_app: Type of worker, e.g. `synapse.app.federation_sender`. + worker_app: Type of worker, e.g. `synapse.app.generic_worker`. extra_config: Any extra config to use for this instances. **kwargs: Options that get passed to `self.setup_test_homeserver`, useful to e.g. pass some mocks for things like `federation_http_client` diff --git a/tests/replication/tcp/streams/test_federation.py b/tests/replication/tcp/streams/test_federation.py index ffec06a0d6..bcb82c9c80 100644 --- a/tests/replication/tcp/streams/test_federation.py +++ b/tests/replication/tcp/streams/test_federation.py @@ -22,9 +22,8 @@ class FederationStreamTestCase(BaseStreamTestCase): def _get_worker_hs_config(self) -> dict: # enable federation sending on the worker config = super()._get_worker_hs_config() - # TODO: make it so we don't need both of these - config["send_federation"] = False - config["worker_app"] = "synapse.app.federation_sender" + config["worker_name"] = "federation_sender1" + config["federation_sender_instances"] = ["federation_sender1"] return config def test_catchup(self): diff --git a/tests/replication/test_auth.py b/tests/replication/test_auth.py index 43a16bb141..5d7a89e0c7 100644 --- a/tests/replication/test_auth.py +++ b/tests/replication/test_auth.py @@ -38,7 +38,7 @@ class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase): def _get_worker_hs_config(self) -> dict: config = self.default_config() - config["worker_app"] = "synapse.app.client_reader" + config["worker_app"] = "synapse.app.generic_worker" config["worker_replication_host"] = "testserv" config["worker_replication_http_port"] = "8765" @@ -53,7 +53,7 @@ class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase): 4. Return the final request. 
""" - worker_hs = self.make_worker_hs("synapse.app.client_reader") + worker_hs = self.make_worker_hs("synapse.app.generic_worker") site = self._hs_to_site[worker_hs] channel_1 = make_request( diff --git a/tests/replication/test_client_reader_shard.py b/tests/replication/test_client_reader_shard.py index 995097d72c..eb5b376534 100644 --- a/tests/replication/test_client_reader_shard.py +++ b/tests/replication/test_client_reader_shard.py @@ -22,20 +22,20 @@ logger = logging.getLogger(__name__) class ClientReaderTestCase(BaseMultiWorkerStreamTestCase): - """Test using one or more client readers for registration.""" + """Test using one or more generic workers for registration.""" servlets = [register.register_servlets] def _get_worker_hs_config(self) -> dict: config = self.default_config() - config["worker_app"] = "synapse.app.client_reader" + config["worker_app"] = "synapse.app.generic_worker" config["worker_replication_host"] = "testserv" config["worker_replication_http_port"] = "8765" return config def test_register_single_worker(self): - """Test that registration works when using a single client reader worker.""" - worker_hs = self.make_worker_hs("synapse.app.client_reader") + """Test that registration works when using a single generic worker.""" + worker_hs = self.make_worker_hs("synapse.app.generic_worker") site = self._hs_to_site[worker_hs] channel_1 = make_request( @@ -64,9 +64,9 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase): self.assertEqual(channel_2.json_body["user_id"], "@user:test") def test_register_multi_worker(self): - """Test that registration works when using multiple client reader workers.""" - worker_hs_1 = self.make_worker_hs("synapse.app.client_reader") - worker_hs_2 = self.make_worker_hs("synapse.app.client_reader") + """Test that registration works when using multiple generic workers.""" + worker_hs_1 = self.make_worker_hs("synapse.app.generic_worker") + worker_hs_2 = self.make_worker_hs("synapse.app.generic_worker") site_1 = self._hs_to_site[worker_hs_1] channel_1 = make_request( diff --git a/tests/replication/test_federation_ack.py b/tests/replication/test_federation_ack.py index 26b8bd512a..63b1dd40b5 100644 --- a/tests/replication/test_federation_ack.py +++ b/tests/replication/test_federation_ack.py @@ -25,8 +25,9 @@ from tests.unittest import HomeserverTestCase class FederationAckTestCase(HomeserverTestCase): def default_config(self) -> dict: config = super().default_config() - config["worker_app"] = "synapse.app.federation_sender" - config["send_federation"] = False + config["worker_app"] = "synapse.app.generic_worker" + config["worker_name"] = "federation_sender1" + config["federation_sender_instances"] = ["federation_sender1"] return config def make_homeserver(self, reactor, clock): diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py index 6104a55aa1..c28073b8f7 100644 --- a/tests/replication/test_federation_sender_shard.py +++ b/tests/replication/test_federation_sender_shard.py @@ -27,17 +27,19 @@ logger = logging.getLogger(__name__) class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): + """ + Various tests for federation sending on workers. + + Federation sending is disabled by default, it will be enabled in each test by + updating 'federation_sender_instances'. 
+ """ + servlets = [ login.register_servlets, register_servlets_for_client_rest_resource, room.register_servlets, ] - def default_config(self): - conf = super().default_config() - conf["send_federation"] = False - return conf - def test_send_event_single_sender(self): """Test that using a single federation sender worker correctly sends a new event. @@ -46,8 +48,11 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): mock_client.put_json.return_value = make_awaitable({}) self.make_worker_hs( - "synapse.app.federation_sender", - {"send_federation": False}, + "synapse.app.generic_worker", + { + "worker_name": "federation_sender1", + "federation_sender_instances": ["federation_sender1"], + }, federation_http_client=mock_client, ) @@ -73,11 +78,13 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): mock_client1 = Mock(spec=["put_json"]) mock_client1.put_json.return_value = make_awaitable({}) self.make_worker_hs( - "synapse.app.federation_sender", + "synapse.app.generic_worker", { - "send_federation": True, - "worker_name": "sender1", - "federation_sender_instances": ["sender1", "sender2"], + "worker_name": "federation_sender1", + "federation_sender_instances": [ + "federation_sender1", + "federation_sender2", + ], }, federation_http_client=mock_client1, ) @@ -85,11 +92,13 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): mock_client2 = Mock(spec=["put_json"]) mock_client2.put_json.return_value = make_awaitable({}) self.make_worker_hs( - "synapse.app.federation_sender", + "synapse.app.generic_worker", { - "send_federation": True, - "worker_name": "sender2", - "federation_sender_instances": ["sender1", "sender2"], + "worker_name": "federation_sender2", + "federation_sender_instances": [ + "federation_sender1", + "federation_sender2", + ], }, federation_http_client=mock_client2, ) @@ -136,11 +145,13 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): mock_client1 = Mock(spec=["put_json"]) mock_client1.put_json.return_value = make_awaitable({}) self.make_worker_hs( - "synapse.app.federation_sender", + "synapse.app.generic_worker", { - "send_federation": True, - "worker_name": "sender1", - "federation_sender_instances": ["sender1", "sender2"], + "worker_name": "federation_sender1", + "federation_sender_instances": [ + "federation_sender1", + "federation_sender2", + ], }, federation_http_client=mock_client1, ) @@ -148,11 +159,13 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): mock_client2 = Mock(spec=["put_json"]) mock_client2.put_json.return_value = make_awaitable({}) self.make_worker_hs( - "synapse.app.federation_sender", + "synapse.app.generic_worker", { - "send_federation": True, - "worker_name": "sender2", - "federation_sender_instances": ["sender1", "sender2"], + "worker_name": "federation_sender2", + "federation_sender_instances": [ + "federation_sender1", + "federation_sender2", + ], }, federation_http_client=mock_client2, ) diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py index 59fea93e49..ca18ad6553 100644 --- a/tests/replication/test_pusher_shard.py +++ b/tests/replication/test_pusher_shard.py @@ -38,11 +38,6 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase): self.other_user_id = self.register_user("otheruser", "pass") self.other_access_token = self.login("otheruser", "pass") - def default_config(self): - conf = super().default_config() - conf["start_pushers"] = False - return conf - def _create_pusher_and_send_msg(self, localpart): # Create a user that will get 
push notifications user_id = self.register_user(localpart, "pass") @@ -92,8 +87,8 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase): ) self.make_worker_hs( - "synapse.app.pusher", - {"start_pushers": False}, + "synapse.app.generic_worker", + {"worker_name": "pusher1", "pusher_instances": ["pusher1"]}, proxied_blacklisted_http_client=http_client_mock, ) @@ -122,9 +117,8 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase): ) self.make_worker_hs( - "synapse.app.pusher", + "synapse.app.generic_worker", { - "start_pushers": True, "worker_name": "pusher1", "pusher_instances": ["pusher1", "pusher2"], }, @@ -137,9 +131,8 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase): ) self.make_worker_hs( - "synapse.app.pusher", + "synapse.app.generic_worker", { - "start_pushers": True, "worker_name": "pusher2", "pusher_instances": ["pusher1", "pusher2"], }, diff --git a/tests/utils.py b/tests/utils.py index 045a8b5fa7..d76bf9716a 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -125,7 +125,8 @@ def default_config( """ config_dict = { "server_name": name, - "send_federation": False, + # Setting this to an empty list turns off federation sending. + "federation_sender_instances": [], "media_store_path": "media", # the test signing key is just an arbitrary ed25519 key to keep the config # parser happy @@ -183,8 +184,9 @@ def default_config( # rooms will fail. "default_room_version": DEFAULT_ROOM_VERSION, # disable user directory updates, because they get done in the - # background, which upsets the test runner. - "update_user_directory": False, + # background, which upsets the test runner. Setting this to an + # (obviously) fake worker name disables updating the user directory. + "update_user_directory_from_worker": "does_not_exist_worker_name", "caches": {"global_factor": 1, "sync_response_cache_duration": 0}, "listeners": [{"port": 0, "type": "http"}], } -- cgit 1.5.1 From cb59e080627745d089d073d9dac276362d9abaf6 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 6 Dec 2022 09:52:55 +0000 Subject: Improve logging and opentracing for to-device message handling (#14598) A batch of changes intended to make it easier to trace to-device messages through the system. The intention here is that a client can set a property org.matrix.msgid in any to-device message it sends. That ID is then included in any tracing or logging related to the message. (Suggestions as to where this field should be documented welcome. I'm not enthusiastic about speccing it - it's very much an optional extra to help with debugging.) I've also generally improved the data we send to opentracing for these messages. --- changelog.d/14598.feature | 1 + synapse/api/constants.py | 3 + synapse/federation/sender/per_destination_queue.py | 2 +- synapse/handlers/appservice.py | 3 - synapse/handlers/devicemessage.py | 36 +++++---- synapse/handlers/sync.py | 26 ++++-- synapse/logging/opentracing.py | 11 ++- synapse/rest/client/sendtodevice.py | 1 - synapse/storage/databases/main/deviceinbox.py | 92 ++++++++++++++++++---- tests/handlers/test_appservice.py | 7 +- 10 files changed, 136 insertions(+), 46 deletions(-) create mode 100644 changelog.d/14598.feature (limited to 'tests/handlers') diff --git a/changelog.d/14598.feature b/changelog.d/14598.feature new file mode 100644 index 0000000000..88d561e286 --- /dev/null +++ b/changelog.d/14598.feature @@ -0,0 +1 @@ +Improve opentracing and logging for to-device message handling. 
\ No newline at end of file diff --git a/synapse/api/constants.py b/synapse/api/constants.py index bc04a0755b..89723d24fa 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -230,6 +230,9 @@ class EventContentFields: # The authorising user for joining a restricted room. AUTHORISING_USER: Final = "join_authorised_via_users_server" + # an unspecced field added to to-device messages to identify them uniquely-ish + TO_DEVICE_MSGID: Final = "org.matrix.msgid" + class RoomTypes: """Understood values of the room_type field of m.room.create events.""" diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index 5af2784f1e..ffc9d95ee7 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -641,7 +641,7 @@ class PerDestinationQueue: if not message_id: continue - set_tag(SynapseTags.TO_DEVICE_MESSAGE_ID, message_id) + set_tag(SynapseTags.TO_DEVICE_EDU_ID, message_id) edus = [ Edu( diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index f68027aaed..5d1d21cdc8 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -578,9 +578,6 @@ class ApplicationServicesHandler: device_id, ), messages in recipient_device_to_messages.items(): for message_json in messages: - # Remove 'message_id' from the to-device message, as it's an internal ID - message_json.pop("message_id", None) - message_payload.append( { "to_user_id": user_id, diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py index 444c08bc2e..75e89850f5 100644 --- a/synapse/handlers/devicemessage.py +++ b/synapse/handlers/devicemessage.py @@ -15,7 +15,7 @@ import logging from typing import TYPE_CHECKING, Any, Dict -from synapse.api.constants import EduTypes, ToDeviceEventTypes +from synapse.api.constants import EduTypes, EventContentFields, ToDeviceEventTypes from synapse.api.errors import SynapseError from synapse.api.ratelimiting import Ratelimiter from synapse.logging.context import run_in_background @@ -216,14 +216,24 @@ class DeviceMessageHandler: """ sender_user_id = requester.user.to_string() - message_id = random_string(16) - set_tag(SynapseTags.TO_DEVICE_MESSAGE_ID, message_id) - - log_kv({"number_of_to_device_messages": len(messages)}) - set_tag("sender", sender_user_id) + set_tag(SynapseTags.TO_DEVICE_TYPE, message_type) + set_tag(SynapseTags.TO_DEVICE_SENDER, sender_user_id) local_messages = {} remote_messages: Dict[str, Dict[str, Dict[str, JsonDict]]] = {} for user_id, by_device in messages.items(): + # add an opentracing log entry for each message + for device_id, message_content in by_device.items(): + log_kv( + { + "event": "send_to_device_message", + "user_id": user_id, + "device_id": device_id, + EventContentFields.TO_DEVICE_MSGID: message_content.get( + EventContentFields.TO_DEVICE_MSGID + ), + } + ) + # Ratelimit local cross-user key requests by the sending device. 
if ( message_type == ToDeviceEventTypes.RoomKeyRequest @@ -233,6 +243,7 @@ class DeviceMessageHandler: requester, (sender_user_id, requester.device_id) ) if not allowed: + log_kv({"message": f"dropping key requests to {user_id}"}) logger.info( "Dropping room_key_request from %s to %s due to rate limit", sender_user_id, @@ -247,18 +258,11 @@ class DeviceMessageHandler: "content": message_content, "type": message_type, "sender": sender_user_id, - "message_id": message_id, } for device_id, message_content in by_device.items() } if messages_by_device: local_messages[user_id] = messages_by_device - log_kv( - { - "user_id": user_id, - "device_id": list(messages_by_device), - } - ) else: destination = get_domain_from_id(user_id) remote_messages.setdefault(destination, {})[user_id] = by_device @@ -267,7 +271,11 @@ class DeviceMessageHandler: remote_edu_contents = {} for destination, messages in remote_messages.items(): - log_kv({"destination": destination}) + # The EDU contains a "message_id" property which is used for + # idempotence. Make up a random one. + message_id = random_string(16) + log_kv({"destination": destination, "message_id": message_id}) + remote_edu_contents[destination] = { "messages": messages, "sender": sender_user_id, diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 0b395a104d..dace9b606f 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -31,14 +31,20 @@ from typing import ( import attr from prometheus_client import Counter -from synapse.api.constants import EventTypes, Membership +from synapse.api.constants import EventContentFields, EventTypes, Membership from synapse.api.filtering import FilterCollection from synapse.api.presence import UserPresenceState from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import EventBase from synapse.handlers.relations import BundledAggregations from synapse.logging.context import current_context -from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, start_active_span +from synapse.logging.opentracing import ( + SynapseTags, + log_kv, + set_tag, + start_active_span, + trace, +) from synapse.push.clientformat import format_push_rules_for_user from synapse.storage.databases.main.event_push_actions import RoomNotifCounts from synapse.storage.databases.main.roommember import extract_heroes_from_room_summary @@ -1586,6 +1592,7 @@ class SyncHandler: else: return DeviceListUpdates() + @trace async def _generate_sync_entry_for_to_device( self, sync_result_builder: "SyncResultBuilder" ) -> None: @@ -1605,11 +1612,16 @@ class SyncHandler: ) for message in messages: - # We pop here as we shouldn't be sending the message ID down - # `/sync` - message_id = message.pop("message_id", None) - if message_id: - set_tag(SynapseTags.TO_DEVICE_MESSAGE_ID, message_id) + log_kv( + { + "event": "to_device_message", + "sender": message["sender"], + "type": message["type"], + EventContentFields.TO_DEVICE_MSGID: message["content"].get( + EventContentFields.TO_DEVICE_MSGID + ), + } + ) logger.debug( "Returning %d to-device messages between %d and %d (current token: %d)", diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index b69060854f..a705af8356 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -292,8 +292,15 @@ logger = logging.getLogger(__name__) class SynapseTags: - # The message ID of any to_device message processed - TO_DEVICE_MESSAGE_ID = "to_device.message_id" + # The message ID of any to_device EDU processed + 
TO_DEVICE_EDU_ID = "to_device.edu_id" + + # Details about to-device messages + TO_DEVICE_TYPE = "to_device.type" + TO_DEVICE_SENDER = "to_device.sender" + TO_DEVICE_RECIPIENT = "to_device.recipient" + TO_DEVICE_RECIPIENT_DEVICE = "to_device.recipient_device" + TO_DEVICE_MSGID = "to_device.msgid" # client-generated ID # Whether the sync response has new data to be returned to the client. SYNC_RESULT = "sync.new_data" diff --git a/synapse/rest/client/sendtodevice.py b/synapse/rest/client/sendtodevice.py index 46a8b03829..55d52f0b28 100644 --- a/synapse/rest/client/sendtodevice.py +++ b/synapse/rest/client/sendtodevice.py @@ -46,7 +46,6 @@ class SendToDeviceRestServlet(servlet.RestServlet): def on_PUT( self, request: SynapseRequest, message_type: str, txn_id: str ) -> Awaitable[Tuple[int, JsonDict]]: - set_tag("message_type", message_type) set_tag("txn_id", txn_id) return self.txns.fetch_or_execute_request( request, self._put, request, message_type, txn_id diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index 73c95ffb6f..48a54d9cb8 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -26,8 +26,15 @@ from typing import ( cast, ) +from synapse.api.constants import EventContentFields from synapse.logging import issue9533_logger -from synapse.logging.opentracing import log_kv, set_tag, trace +from synapse.logging.opentracing import ( + SynapseTags, + log_kv, + set_tag, + start_active_span, + trace, +) from synapse.replication.tcp.streams import ToDeviceStream from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import ( @@ -397,6 +404,17 @@ class DeviceInboxWorkerStore(SQLBaseStore): (recipient_user_id, recipient_device_id), [] ).append(message_dict) + # start a new span for each message, so that we can tag each separately + with start_active_span("get_to_device_message"): + set_tag(SynapseTags.TO_DEVICE_TYPE, message_dict["type"]) + set_tag(SynapseTags.TO_DEVICE_SENDER, message_dict["sender"]) + set_tag(SynapseTags.TO_DEVICE_RECIPIENT, recipient_user_id) + set_tag(SynapseTags.TO_DEVICE_RECIPIENT_DEVICE, recipient_device_id) + set_tag( + SynapseTags.TO_DEVICE_MSGID, + message_dict["content"].get(EventContentFields.TO_DEVICE_MSGID), + ) + if limit is not None and rowcount == limit: # We ended up bumping up against the message limit. There may be more messages # to retrieve. 
Return what we have, as well as the last stream position that @@ -678,12 +696,35 @@ class DeviceInboxWorkerStore(SQLBaseStore): ], ) - if remote_messages_by_destination: - issue9533_logger.debug( - "Queued outgoing to-device messages with stream_id %i for %s", - stream_id, - list(remote_messages_by_destination.keys()), - ) + for destination, edu in remote_messages_by_destination.items(): + if issue9533_logger.isEnabledFor(logging.DEBUG): + issue9533_logger.debug( + "Queued outgoing to-device messages with " + "stream_id %i, EDU message_id %s, type %s for %s: %s", + stream_id, + edu["message_id"], + edu["type"], + destination, + [ + f"{user_id}/{device_id} (msgid " + f"{msg.get(EventContentFields.TO_DEVICE_MSGID)})" + for (user_id, messages_by_device) in edu["messages"].items() + for (device_id, msg) in messages_by_device.items() + ], + ) + + for (user_id, messages_by_device) in edu["messages"].items(): + for (device_id, msg) in messages_by_device.items(): + with start_active_span("store_outgoing_to_device_message"): + set_tag(SynapseTags.TO_DEVICE_EDU_ID, edu["sender"]) + set_tag(SynapseTags.TO_DEVICE_EDU_ID, edu["message_id"]) + set_tag(SynapseTags.TO_DEVICE_TYPE, edu["type"]) + set_tag(SynapseTags.TO_DEVICE_RECIPIENT, user_id) + set_tag(SynapseTags.TO_DEVICE_RECIPIENT_DEVICE, device_id) + set_tag( + SynapseTags.TO_DEVICE_MSGID, + msg.get(EventContentFields.TO_DEVICE_MSGID), + ) async with self._device_inbox_id_gen.get_next() as stream_id: now_ms = self._clock.time_msec() @@ -801,7 +842,19 @@ class DeviceInboxWorkerStore(SQLBaseStore): # Only insert into the local inbox if the device exists on # this server device_id = row["device_id"] - message_json = json_encoder.encode(messages_by_device[device_id]) + + with start_active_span("serialise_to_device_message"): + msg = messages_by_device[device_id] + set_tag(SynapseTags.TO_DEVICE_TYPE, msg["type"]) + set_tag(SynapseTags.TO_DEVICE_SENDER, msg["sender"]) + set_tag(SynapseTags.TO_DEVICE_RECIPIENT, user_id) + set_tag(SynapseTags.TO_DEVICE_RECIPIENT_DEVICE, device_id) + set_tag( + SynapseTags.TO_DEVICE_MSGID, + msg["content"].get(EventContentFields.TO_DEVICE_MSGID), + ) + message_json = json_encoder.encode(msg) + messages_json_for_user[device_id] = message_json if messages_json_for_user: @@ -821,15 +874,20 @@ class DeviceInboxWorkerStore(SQLBaseStore): ], ) - issue9533_logger.debug( - "Stored to-device messages with stream_id %i for %s", - stream_id, - [ - (user_id, device_id) - for (user_id, messages_by_device) in local_by_user_then_device.items() - for device_id in messages_by_device.keys() - ], - ) + if issue9533_logger.isEnabledFor(logging.DEBUG): + issue9533_logger.debug( + "Stored to-device messages with stream_id %i: %s", + stream_id, + [ + f"{user_id}/{device_id} (msgid " + f"{msg['content'].get(EventContentFields.TO_DEVICE_MSGID)})" + for ( + user_id, + messages_by_device, + ) in messages_by_user_then_device.items() + for (device_id, msg) in messages_by_device.items() + ], + ) class DeviceInboxBackgroundUpdateStore(SQLBaseStore): diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index 9ed26d87a7..57bfbd7734 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -765,7 +765,12 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): fake_device_ids = [f"device_{num}" for num in range(number_of_messages - 1)] messages = { self.exclusive_as_user: { - device_id: to_device_message_content for device_id in fake_device_ids + device_id: { + "type": 
"test_to_device_message", + "sender": "@some:sender", + "content": to_device_message_content, + } + for device_id in fake_device_ids } } -- cgit 1.5.1 From c2de2ca63060324cf2f80ddf3289b0fd7a4d861b Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 9 Dec 2022 09:37:07 +0000 Subject: Delete stale non-e2e devices for users, take 2 (#14595) This should help reduce the number of devices e.g. simple bots the repeatedly login rack up. We only delete non-e2e devices as they should be safe to delete, whereas if we delete e2e devices for a user we may accidentally break their ability to receive e2e keys for a message. --- changelog.d/14595.misc | 1 + synapse/handlers/device.py | 31 +++++++++++- synapse/storage/databases/main/devices.py | 79 ++++++++++++++++++++++++++++++- tests/handlers/test_device.py | 2 +- tests/storage/test_client_ips.py | 4 +- 5 files changed, 113 insertions(+), 4 deletions(-) create mode 100644 changelog.d/14595.misc (limited to 'tests/handlers') diff --git a/changelog.d/14595.misc b/changelog.d/14595.misc new file mode 100644 index 0000000000..f9bfc581ad --- /dev/null +++ b/changelog.d/14595.misc @@ -0,0 +1 @@ +Prune user's old devices on login if they have too many. diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index d4750a32e6..7674c187ef 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -52,6 +52,7 @@ from synapse.util import stringutils from synapse.util.async_helpers import Linearizer from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.cancellation import cancellable +from synapse.util.iterutils import batch_iter from synapse.util.metrics import measure_func from synapse.util.retryutils import NotRetryingDestination @@ -421,6 +422,9 @@ class DeviceHandler(DeviceWorkerHandler): self._check_device_name_length(initial_device_display_name) + # Prune the user's device list if they already have a lot of devices. + await self._prune_too_many_devices(user_id) + if device_id is not None: new_device = await self.store.store_device( user_id=user_id, @@ -452,6 +456,31 @@ class DeviceHandler(DeviceWorkerHandler): raise errors.StoreError(500, "Couldn't generate a device ID.") + async def _prune_too_many_devices(self, user_id: str) -> None: + """Delete any excess old devices this user may have.""" + device_ids = await self.store.check_too_many_devices_for_user(user_id) + if not device_ids: + return + + # We don't want to block and try and delete tonnes of devices at once, + # so we cap the number of devices we delete synchronously. + first_batch, remaining_device_ids = device_ids[:10], device_ids[10:] + await self.delete_devices(user_id, first_batch) + + if not remaining_device_ids: + return + + # Now spawn a background loop that deletes the rest. + async def _prune_too_many_devices_loop() -> None: + for batch in batch_iter(remaining_device_ids, 10): + await self.delete_devices(user_id, batch) + + await self.clock.sleep(1) + + run_as_background_process( + "_prune_too_many_devices_loop", _prune_too_many_devices_loop + ) + async def _delete_stale_devices(self) -> None: """Background task that deletes devices which haven't been accessed for more than a configured time period. 
@@ -481,7 +510,7 @@ class DeviceHandler(DeviceWorkerHandler): device_ids = [d for d in device_ids if d != except_device_id] await self.delete_devices(user_id, device_ids) - async def delete_devices(self, user_id: str, device_ids: List[str]) -> None: + async def delete_devices(self, user_id: str, device_ids: Collection[str]) -> None: """Delete several devices Args: diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index a5bb4d404e..08ccd46a2b 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -1569,6 +1569,72 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): return rows + async def check_too_many_devices_for_user(self, user_id: str) -> List[str]: + """Check if the user has a lot of devices, and if so return the set of + devices we can prune. + + This does *not* return hidden devices or devices with E2E keys. + """ + + num_devices = await self.db_pool.simple_select_one_onecol( + table="devices", + keyvalues={"user_id": user_id, "hidden": False}, + retcol="COALESCE(COUNT(*), 0)", + desc="count_devices", + ) + + # We let users have up to ten devices without pruning. + if num_devices <= 10: + return [] + + # We prune everything older than N days. + max_last_seen = self._clock.time_msec() - 14 * 24 * 60 * 60 * 1000 + + if num_devices > 50: + # If the user has more than 50 devices, then we chose a last seen + # that ensures we keep at most 50 devices. + sql = """ + SELECT last_seen FROM devices + LEFT JOIN e2e_device_keys_json USING (user_id, device_id) + WHERE + user_id = ? + AND NOT hidden + AND last_seen IS NOT NULL + AND key_json IS NULL + ORDER BY last_seen DESC + LIMIT 1 + OFFSET 50 + """ + + rows = await self.db_pool.execute( + "check_too_many_devices_for_user_last_seen", None, sql, (user_id,) + ) + if rows: + max_last_seen = max(rows[0][0], max_last_seen) + + # Now fetch the devices to delete. + sql = """ + SELECT DISTINCT device_id FROM devices + LEFT JOIN e2e_device_keys_json USING (user_id, device_id) + WHERE + user_id = ? + AND NOT hidden + AND last_seen < ? + AND key_json IS NULL + ORDER BY last_seen + """ + + def check_too_many_devices_for_user_txn( + txn: LoggingTransaction, + ) -> List[str]: + txn.execute(sql, (user_id, max_last_seen)) + return [device_id for device_id, in txn] + + return await self.db_pool.runInteraction( + "check_too_many_devices_for_user", + check_too_many_devices_for_user_txn, + ) + class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): # Because we have write access, this will be a StreamIdGenerator @@ -1627,6 +1693,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): values={}, insertion_values={ "display_name": initial_device_display_name, + "last_seen": self._clock.time_msec(), "hidden": False, }, desc="store_device", @@ -1672,7 +1739,15 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): ) raise StoreError(500, "Problem storing device.") - async def delete_devices(self, user_id: str, device_ids: List[str]) -> None: + @cached(max_entries=0) + async def delete_device(self, user_id: str, device_id: str) -> None: + raise NotImplementedError() + + # Note: sometimes deleting rows out of `device_inbox` can take a long time, + # so we use a cache so that we deduplicate in flight requests to delete + # devices. + @cachedList(cached_method_name="delete_device", list_name="device_ids") + async def delete_devices(self, user_id: str, device_ids: Collection[str]) -> dict: """Deletes several devices. 
Args: @@ -1709,6 +1784,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): for device_id in device_ids: self.device_id_exists_cache.invalidate((user_id, device_id)) + return {} + async def update_device( self, user_id: str, device_id: str, new_display_name: Optional[str] = None ) -> None: diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index ce7525e29c..a456bffd63 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -115,7 +115,7 @@ class DeviceTestCase(unittest.HomeserverTestCase): "device_id": "xyz", "display_name": "display 0", "last_seen_ip": None, - "last_seen_ts": None, + "last_seen_ts": 1000000, }, device_map["xyz"], ) diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index 49ad3c1324..a9af1babed 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -169,6 +169,8 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): ) ) + last_seen = self.clock.time_msec() + if after_persisting: # Trigger the storage loop self.reactor.advance(10) @@ -189,7 +191,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): "device_id": device_id, "ip": None, "user_agent": None, - "last_seen": None, + "last_seen": last_seen, }, ], ) -- cgit 1.5.1 From 94bc21e69f89ad873ad7a0deb6d9c4ff3cb480ef Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 9 Dec 2022 13:31:32 +0000 Subject: Limit the number of devices we delete at once (#14649) --- changelog.d/14649.misc | 1 + synapse/handlers/device.py | 4 +++- synapse/storage/databases/main/devices.py | 11 ++++++++--- tests/handlers/test_device.py | 31 +++++++++++++++++++++++++++++++ 4 files changed, 43 insertions(+), 4 deletions(-) create mode 100644 changelog.d/14649.misc (limited to 'tests/handlers') diff --git a/changelog.d/14649.misc b/changelog.d/14649.misc new file mode 100644 index 0000000000..f9bfc581ad --- /dev/null +++ b/changelog.d/14649.misc @@ -0,0 +1 @@ +Prune user's old devices on login if they have too many. diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 7674c187ef..c935c7be90 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -458,10 +458,12 @@ class DeviceHandler(DeviceWorkerHandler): async def _prune_too_many_devices(self, user_id: str) -> None: """Delete any excess old devices this user may have.""" - device_ids = await self.store.check_too_many_devices_for_user(user_id) + device_ids = await self.store.check_too_many_devices_for_user(user_id, 100) if not device_ids: return + logger.info("Pruning %d old devices for user %s", len(device_ids), user_id) + # We don't want to block and try and delete tonnes of devices at once, # so we cap the number of devices we delete synchronously. first_batch, remaining_device_ids = device_ids[:10], device_ids[10:] diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 08ccd46a2b..95d4c0622d 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -1569,11 +1569,15 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): return rows - async def check_too_many_devices_for_user(self, user_id: str) -> List[str]: + async def check_too_many_devices_for_user( + self, user_id: str, limit: int + ) -> List[str]: """Check if the user has a lot of devices, and if so return the set of devices we can prune. This does *not* return hidden devices or devices with E2E keys. 
+ + Returns at most `limit` number of devices, ordered by last seen. """ num_devices = await self.db_pool.simple_select_one_onecol( @@ -1614,7 +1618,7 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): # Now fetch the devices to delete. sql = """ - SELECT DISTINCT device_id FROM devices + SELECT device_id FROM devices LEFT JOIN e2e_device_keys_json USING (user_id, device_id) WHERE user_id = ? @@ -1622,12 +1626,13 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): AND last_seen < ? AND key_json IS NULL ORDER BY last_seen + LIMIT ? """ def check_too_many_devices_for_user_txn( txn: LoggingTransaction, ) -> List[str]: - txn.execute(sql, (user_id, max_last_seen)) + txn.execute(sql, (user_id, max_last_seen, limit)) return [device_id for device_id, in txn] return await self.db_pool.runInteraction( diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index a456bffd63..e51cac9b33 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -20,6 +20,8 @@ from twisted.test.proto_helpers import MemoryReactor from synapse.api.errors import NotFoundError, SynapseError from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN, DeviceHandler +from synapse.rest import admin +from synapse.rest.client import account, login from synapse.server import HomeServer from synapse.util import Clock @@ -30,6 +32,12 @@ user2 = "@theresa:bbb" class DeviceTestCase(unittest.HomeserverTestCase): + servlets = [ + login.register_servlets, + admin.register_servlets, + account.register_servlets, + ] + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: hs = self.setup_test_homeserver("server", federation_http_client=None) handler = hs.get_device_handler() @@ -229,6 +237,29 @@ class DeviceTestCase(unittest.HomeserverTestCase): NotFoundError, ) + def test_login_delete_old_devices(self) -> None: + """Delete old devices if the user already has too many.""" + + user_id = self.register_user("user", "pass") + + # Create a bunch of devices + for _ in range(50): + self.login("user", "pass") + self.reactor.advance(1) + + # Advance the clock for ages (as we only delete old devices) + self.reactor.advance(60 * 60 * 24 * 300) + + # Log in again to start the pruning + self.login("user", "pass") + + # Give the background job time to do its thing + self.reactor.pump([1.0] * 100) + + # We should now only have the most recent device. + devices = self.get_success(self.handler.get_devices_by_user(user_id)) + self.assertEqual(len(devices), 1) + def _record_users(self) -> None: # check this works for both devices which have a recorded client_ip, # and those which don't. -- cgit 1.5.1 From 74b89c27613a34ec9b291ad3066db7ce0adff1db Mon Sep 17 00:00:00 2001 From: reivilibre Date: Mon, 12 Dec 2022 13:55:23 +0000 Subject: Revert the deletion of stale devices due to performance issues. 
(#14662) --- changelog.d/14595.misc | 1 - changelog.d/14649.misc | 1 - changelog.d/14662.removal | 1 + synapse/handlers/device.py | 33 +----------- synapse/storage/databases/main/devices.py | 84 +------------------------------ tests/handlers/test_device.py | 33 +----------- tests/storage/test_client_ips.py | 4 +- 7 files changed, 5 insertions(+), 152 deletions(-) delete mode 100644 changelog.d/14595.misc delete mode 100644 changelog.d/14649.misc create mode 100644 changelog.d/14662.removal (limited to 'tests/handlers') diff --git a/changelog.d/14595.misc b/changelog.d/14595.misc deleted file mode 100644 index f9bfc581ad..0000000000 --- a/changelog.d/14595.misc +++ /dev/null @@ -1 +0,0 @@ -Prune user's old devices on login if they have too many. diff --git a/changelog.d/14649.misc b/changelog.d/14649.misc deleted file mode 100644 index f9bfc581ad..0000000000 --- a/changelog.d/14649.misc +++ /dev/null @@ -1 +0,0 @@ -Prune user's old devices on login if they have too many. diff --git a/changelog.d/14662.removal b/changelog.d/14662.removal new file mode 100644 index 0000000000..19a387bbb4 --- /dev/null +++ b/changelog.d/14662.removal @@ -0,0 +1 @@ +(remove from changelog: unreleased) Revert the deletion of stale devices due to performance issues. \ No newline at end of file diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index c935c7be90..d4750a32e6 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -52,7 +52,6 @@ from synapse.util import stringutils from synapse.util.async_helpers import Linearizer from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.cancellation import cancellable -from synapse.util.iterutils import batch_iter from synapse.util.metrics import measure_func from synapse.util.retryutils import NotRetryingDestination @@ -422,9 +421,6 @@ class DeviceHandler(DeviceWorkerHandler): self._check_device_name_length(initial_device_display_name) - # Prune the user's device list if they already have a lot of devices. - await self._prune_too_many_devices(user_id) - if device_id is not None: new_device = await self.store.store_device( user_id=user_id, @@ -456,33 +452,6 @@ class DeviceHandler(DeviceWorkerHandler): raise errors.StoreError(500, "Couldn't generate a device ID.") - async def _prune_too_many_devices(self, user_id: str) -> None: - """Delete any excess old devices this user may have.""" - device_ids = await self.store.check_too_many_devices_for_user(user_id, 100) - if not device_ids: - return - - logger.info("Pruning %d old devices for user %s", len(device_ids), user_id) - - # We don't want to block and try and delete tonnes of devices at once, - # so we cap the number of devices we delete synchronously. - first_batch, remaining_device_ids = device_ids[:10], device_ids[10:] - await self.delete_devices(user_id, first_batch) - - if not remaining_device_ids: - return - - # Now spawn a background loop that deletes the rest. - async def _prune_too_many_devices_loop() -> None: - for batch in batch_iter(remaining_device_ids, 10): - await self.delete_devices(user_id, batch) - - await self.clock.sleep(1) - - run_as_background_process( - "_prune_too_many_devices_loop", _prune_too_many_devices_loop - ) - async def _delete_stale_devices(self) -> None: """Background task that deletes devices which haven't been accessed for more than a configured time period. 
@@ -512,7 +481,7 @@ class DeviceHandler(DeviceWorkerHandler): device_ids = [d for d in device_ids if d != except_device_id] await self.delete_devices(user_id, device_ids) - async def delete_devices(self, user_id: str, device_ids: Collection[str]) -> None: + async def delete_devices(self, user_id: str, device_ids: List[str]) -> None: """Delete several devices Args: diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 95d4c0622d..a5bb4d404e 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -1569,77 +1569,6 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): return rows - async def check_too_many_devices_for_user( - self, user_id: str, limit: int - ) -> List[str]: - """Check if the user has a lot of devices, and if so return the set of - devices we can prune. - - This does *not* return hidden devices or devices with E2E keys. - - Returns at most `limit` number of devices, ordered by last seen. - """ - - num_devices = await self.db_pool.simple_select_one_onecol( - table="devices", - keyvalues={"user_id": user_id, "hidden": False}, - retcol="COALESCE(COUNT(*), 0)", - desc="count_devices", - ) - - # We let users have up to ten devices without pruning. - if num_devices <= 10: - return [] - - # We prune everything older than N days. - max_last_seen = self._clock.time_msec() - 14 * 24 * 60 * 60 * 1000 - - if num_devices > 50: - # If the user has more than 50 devices, then we chose a last seen - # that ensures we keep at most 50 devices. - sql = """ - SELECT last_seen FROM devices - LEFT JOIN e2e_device_keys_json USING (user_id, device_id) - WHERE - user_id = ? - AND NOT hidden - AND last_seen IS NOT NULL - AND key_json IS NULL - ORDER BY last_seen DESC - LIMIT 1 - OFFSET 50 - """ - - rows = await self.db_pool.execute( - "check_too_many_devices_for_user_last_seen", None, sql, (user_id,) - ) - if rows: - max_last_seen = max(rows[0][0], max_last_seen) - - # Now fetch the devices to delete. - sql = """ - SELECT device_id FROM devices - LEFT JOIN e2e_device_keys_json USING (user_id, device_id) - WHERE - user_id = ? - AND NOT hidden - AND last_seen < ? - AND key_json IS NULL - ORDER BY last_seen - LIMIT ? - """ - - def check_too_many_devices_for_user_txn( - txn: LoggingTransaction, - ) -> List[str]: - txn.execute(sql, (user_id, max_last_seen, limit)) - return [device_id for device_id, in txn] - - return await self.db_pool.runInteraction( - "check_too_many_devices_for_user", - check_too_many_devices_for_user_txn, - ) - class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): # Because we have write access, this will be a StreamIdGenerator @@ -1698,7 +1627,6 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): values={}, insertion_values={ "display_name": initial_device_display_name, - "last_seen": self._clock.time_msec(), "hidden": False, }, desc="store_device", @@ -1744,15 +1672,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): ) raise StoreError(500, "Problem storing device.") - @cached(max_entries=0) - async def delete_device(self, user_id: str, device_id: str) -> None: - raise NotImplementedError() - - # Note: sometimes deleting rows out of `device_inbox` can take a long time, - # so we use a cache so that we deduplicate in flight requests to delete - # devices. 
- @cachedList(cached_method_name="delete_device", list_name="device_ids") - async def delete_devices(self, user_id: str, device_ids: Collection[str]) -> dict: + async def delete_devices(self, user_id: str, device_ids: List[str]) -> None: """Deletes several devices. Args: @@ -1789,8 +1709,6 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): for device_id in device_ids: self.device_id_exists_cache.invalidate((user_id, device_id)) - return {} - async def update_device( self, user_id: str, device_id: str, new_display_name: Optional[str] = None ) -> None: diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index e51cac9b33..ce7525e29c 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -20,8 +20,6 @@ from twisted.test.proto_helpers import MemoryReactor from synapse.api.errors import NotFoundError, SynapseError from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN, DeviceHandler -from synapse.rest import admin -from synapse.rest.client import account, login from synapse.server import HomeServer from synapse.util import Clock @@ -32,12 +30,6 @@ user2 = "@theresa:bbb" class DeviceTestCase(unittest.HomeserverTestCase): - servlets = [ - login.register_servlets, - admin.register_servlets, - account.register_servlets, - ] - def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: hs = self.setup_test_homeserver("server", federation_http_client=None) handler = hs.get_device_handler() @@ -123,7 +115,7 @@ class DeviceTestCase(unittest.HomeserverTestCase): "device_id": "xyz", "display_name": "display 0", "last_seen_ip": None, - "last_seen_ts": 1000000, + "last_seen_ts": None, }, device_map["xyz"], ) @@ -237,29 +229,6 @@ class DeviceTestCase(unittest.HomeserverTestCase): NotFoundError, ) - def test_login_delete_old_devices(self) -> None: - """Delete old devices if the user already has too many.""" - - user_id = self.register_user("user", "pass") - - # Create a bunch of devices - for _ in range(50): - self.login("user", "pass") - self.reactor.advance(1) - - # Advance the clock for ages (as we only delete old devices) - self.reactor.advance(60 * 60 * 24 * 300) - - # Log in again to start the pruning - self.login("user", "pass") - - # Give the background job time to do its thing - self.reactor.pump([1.0] * 100) - - # We should now only have the most recent device. - devices = self.get_success(self.handler.get_devices_by_user(user_id)) - self.assertEqual(len(devices), 1) - def _record_users(self) -> None: # check this works for both devices which have a recorded client_ip, # and those which don't. diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index 81e4e596e4..7f7f4ef892 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -170,8 +170,6 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): ) ) - last_seen = self.clock.time_msec() - if after_persisting: # Trigger the storage loop self.reactor.advance(10) @@ -192,7 +190,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): "device_id": device_id, "ip": None, "user_agent": None, - "last_seen": last_seen, + "last_seen": None, }, ], ) -- cgit 1.5.1