From a6f1a3abecf8e8fd3e1bff439a06b853df18f194 Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Thu, 2 Dec 2021 01:02:20 -0600 Subject: Add MSC3030 experimental client and federation API endpoints to get the closest event to a given timestamp (#9445) MSC3030: https://github.com/matrix-org/matrix-doc/pull/3030 Client API endpoint. This will also go and fetch from the federation API endpoint if unable to find an event locally or we found an extremity with possibly a closer event we don't know about. ``` GET /_matrix/client/unstable/org.matrix.msc3030/rooms//timestamp_to_event?ts=&dir= { "event_id": ... "origin_server_ts": ... } ``` Federation API endpoint: ``` GET /_matrix/federation/unstable/org.matrix.msc3030/timestamp_to_event/?ts=&dir= { "event_id": ... "origin_server_ts": ... } ``` Co-authored-by: Erik Johnston --- synapse/handlers/federation.py | 61 ++++++++--------- synapse/handlers/room.py | 144 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 175 insertions(+), 30 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 3112cc88b1..1ea837d082 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -68,6 +68,37 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +def get_domains_from_state(state: StateMap[EventBase]) -> List[Tuple[str, int]]: + """Get joined domains from state + + Args: + state: State map from type/state key to event. + + Returns: + Returns a list of servers with the lowest depth of their joins. + Sorted by lowest depth first. + """ + joined_users = [ + (state_key, int(event.depth)) + for (e_type, state_key), event in state.items() + if e_type == EventTypes.Member and event.membership == Membership.JOIN + ] + + joined_domains: Dict[str, int] = {} + for u, d in joined_users: + try: + dom = get_domain_from_id(u) + old_d = joined_domains.get(dom) + if old_d: + joined_domains[dom] = min(d, old_d) + else: + joined_domains[dom] = d + except Exception: + pass + + return sorted(joined_domains.items(), key=lambda d: d[1]) + + class FederationHandler: """Handles general incoming federation requests @@ -268,36 +299,6 @@ class FederationHandler: curr_state = await self.state_handler.get_current_state(room_id) - def get_domains_from_state(state: StateMap[EventBase]) -> List[Tuple[str, int]]: - """Get joined domains from state - - Args: - state: State map from type/state key to event. - - Returns: - Returns a list of servers with the lowest depth of their joins. - Sorted by lowest depth first. - """ - joined_users = [ - (state_key, int(event.depth)) - for (e_type, state_key), event in state.items() - if e_type == EventTypes.Member and event.membership == Membership.JOIN - ] - - joined_domains: Dict[str, int] = {} - for u, d in joined_users: - try: - dom = get_domain_from_id(u) - old_d = joined_domains.get(dom) - if old_d: - joined_domains[dom] = min(d, old_d) - else: - joined_domains[dom] = d - except Exception: - pass - - return sorted(joined_domains.items(), key=lambda d: d[1]) - curr_domains = get_domains_from_state(curr_state) likely_domains = [ diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 88053f9869..2bcdf32dcc 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -46,6 +46,7 @@ from synapse.api.constants import ( from synapse.api.errors import ( AuthError, Codes, + HttpResponseException, LimitExceededError, NotFoundError, StoreError, @@ -56,6 +57,8 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion from synapse.event_auth import validate_event_for_room_version from synapse.events import EventBase from synapse.events.utils import copy_power_levels_contents +from synapse.federation.federation_client import InvalidResponseError +from synapse.handlers.federation import get_domains_from_state from synapse.rest.admin._base import assert_user_is_admin from synapse.storage.state import StateFilter from synapse.streams import EventSource @@ -1220,6 +1223,147 @@ class RoomContextHandler: return results +class TimestampLookupHandler: + def __init__(self, hs: "HomeServer"): + self.server_name = hs.hostname + self.store = hs.get_datastore() + self.state_handler = hs.get_state_handler() + self.federation_client = hs.get_federation_client() + + async def get_event_for_timestamp( + self, + requester: Requester, + room_id: str, + timestamp: int, + direction: str, + ) -> Tuple[str, int]: + """Find the closest event to the given timestamp in the given direction. + If we can't find an event locally or the event we have locally is next to a gap, + it will ask other federated homeservers for an event. + + Args: + requester: The user making the request according to the access token + room_id: Room to fetch the event from + timestamp: The point in time (inclusive) we should navigate from in + the given direction to find the closest event. + direction: ["f"|"b"] to indicate whether we should navigate forward + or backward from the given timestamp to find the closest event. + + Returns: + A tuple containing the `event_id` closest to the given timestamp in + the given direction and the `origin_server_ts`. + + Raises: + SynapseError if unable to find any event locally in the given direction + """ + + local_event_id = await self.store.get_event_id_for_timestamp( + room_id, timestamp, direction + ) + logger.debug( + "get_event_for_timestamp: locally, we found event_id=%s closest to timestamp=%s", + local_event_id, + timestamp, + ) + + # Check for gaps in the history where events could be hiding in between + # the timestamp given and the event we were able to find locally + is_event_next_to_backward_gap = False + is_event_next_to_forward_gap = False + if local_event_id: + local_event = await self.store.get_event( + local_event_id, allow_none=False, allow_rejected=False + ) + + if direction == "f": + # We only need to check for a backward gap if we're looking forwards + # to ensure there is nothing in between. + is_event_next_to_backward_gap = ( + await self.store.is_event_next_to_backward_gap(local_event) + ) + elif direction == "b": + # We only need to check for a forward gap if we're looking backwards + # to ensure there is nothing in between + is_event_next_to_forward_gap = ( + await self.store.is_event_next_to_forward_gap(local_event) + ) + + # If we found a gap, we should probably ask another homeserver first + # about more history in between + if ( + not local_event_id + or is_event_next_to_backward_gap + or is_event_next_to_forward_gap + ): + logger.debug( + "get_event_for_timestamp: locally, we found event_id=%s closest to timestamp=%s which is next to a gap in event history so we're asking other homeservers first", + local_event_id, + timestamp, + ) + + # Find other homeservers from the given state in the room + curr_state = await self.state_handler.get_current_state(room_id) + curr_domains = get_domains_from_state(curr_state) + likely_domains = [ + domain for domain, depth in curr_domains if domain != self.server_name + ] + + # Loop through each homeserver candidate until we get a succesful response + for domain in likely_domains: + try: + remote_response = await self.federation_client.timestamp_to_event( + domain, room_id, timestamp, direction + ) + logger.debug( + "get_event_for_timestamp: response from domain(%s)=%s", + domain, + remote_response, + ) + + # TODO: Do we want to persist this as an extremity? + # TODO: I think ideally, we would try to backfill from + # this event and run this whole + # `get_event_for_timestamp` function again to make sure + # they didn't give us an event from their gappy history. + remote_event_id = remote_response.event_id + origin_server_ts = remote_response.origin_server_ts + + # Only return the remote event if it's closer than the local event + if not local_event or ( + abs(origin_server_ts - timestamp) + < abs(local_event.origin_server_ts - timestamp) + ): + return remote_event_id, origin_server_ts + except (HttpResponseException, InvalidResponseError) as ex: + # Let's not put a high priority on some other homeserver + # failing to respond or giving a random response + logger.debug( + "Failed to fetch /timestamp_to_event from %s because of exception(%s) %s args=%s", + domain, + type(ex).__name__, + ex, + ex.args, + ) + except Exception as ex: + # But we do want to see some exceptions in our code + logger.warning( + "Failed to fetch /timestamp_to_event from %s because of exception(%s) %s args=%s", + domain, + type(ex).__name__, + ex, + ex.args, + ) + + if not local_event_id: + raise SynapseError( + 404, + "Unable to find event from %s in direction %s" % (timestamp, direction), + errcode=Codes.NOT_FOUND, + ) + + return local_event_id, local_event.origin_server_ts + + class RoomEventSource(EventSource[RoomStreamToken, EventBase]): def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() -- cgit 1.5.1 From d26808dd854006bd26a2366c675428ce0737238c Mon Sep 17 00:00:00 2001 From: David Robertson Date: Thu, 2 Dec 2021 20:58:32 +0000 Subject: Comments on the /sync tentacles (#11494) This mainly consists of docstrings and inline comments. There are one or two type annotations and variable renames thrown in while I was here. Co-authored-by: Patrick Cloke --- changelog.d/11494.misc | 1 + synapse/handlers/sync.py | 156 +++++++++++++++++++++++-------- synapse/storage/databases/main/stream.py | 15 ++- 3 files changed, 129 insertions(+), 43 deletions(-) create mode 100644 changelog.d/11494.misc (limited to 'synapse/handlers') diff --git a/changelog.d/11494.misc b/changelog.d/11494.misc new file mode 100644 index 0000000000..7afd7d3a07 --- /dev/null +++ b/changelog.d/11494.misc @@ -0,0 +1 @@ +Add comments to various parts of the `/sync` handler. \ No newline at end of file diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 891435c14d..53d4627147 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -334,6 +334,19 @@ class SyncHandler: full_state: bool, cache_context: ResponseCacheContext[SyncRequestKey], ) -> SyncResult: + """The start of the machinery that produces a /sync response. + + See https://spec.matrix.org/v1.1/client-server-api/#syncing for full details. + + This method does high-level bookkeeping: + - tracking the kind of sync in the logging context + - deleting any to_device messages whose delivery has been acknowledged. + - deciding if we should dispatch an instant or delayed response + - marking the sync as being lazily loaded, if appropriate + + Computing the body of the response begins in the next method, + `current_sync_for_user`. + """ if since_token is None: sync_type = "initial_sync" elif full_state: @@ -363,7 +376,7 @@ class SyncHandler: sync_config, since_token, full_state=full_state ) else: - + # Otherwise, we wait for something to happen and report it to the user. async def current_sync_callback( before_token: StreamToken, after_token: StreamToken ) -> SyncResult: @@ -402,7 +415,12 @@ class SyncHandler: since_token: Optional[StreamToken] = None, full_state: bool = False, ) -> SyncResult: - """Get the sync for client needed to match what the server has now.""" + """Generates the response body of a sync result, represented as a SyncResult. + + This is a wrapper around `generate_sync_result` which starts an open tracing + span to track the sync. See `generate_sync_result` for the next part of your + indoctrination. + """ with start_active_span("current_sync_for_user"): log_kv({"since_token": since_token}) sync_result = await self.generate_sync_result( @@ -560,7 +578,7 @@ class SyncHandler: # that have happened since `since_key` up to `end_key`, so we # can just use `get_room_events_stream_for_room`. # Otherwise, we want to return the last N events in the room - # in toplogical ordering. + # in topological ordering. if since_key: events, end_key = await self.store.get_room_events_stream_for_room( room_id, @@ -1042,7 +1060,18 @@ class SyncHandler: since_token: Optional[StreamToken] = None, full_state: bool = False, ) -> SyncResult: - """Generates a sync result.""" + """Generates the response body of a sync result. + + This is represented by a `SyncResult` struct, which is built from small pieces + using a `SyncResultBuilder`. See also + https://spec.matrix.org/v1.1/client-server-api/#get_matrixclientv3sync + the `sync_result_builder` is passed as a mutable ("inout") parameter to various + helper functions. These retrieve and process the data which forms the sync body, + often writing to the `sync_result_builder` to store their output. + + At the end, we transfer data from the `sync_result_builder` to a new `SyncResult` + instance to signify that the sync calculation is complete. + """ # NB: The now_token gets changed by some of the generate_sync_* methods, # this is due to some of the underlying streams not supporting the ability # to query up to a given point. @@ -1344,14 +1373,22 @@ class SyncHandler: async def _generate_sync_entry_for_account_data( self, sync_result_builder: "SyncResultBuilder" ) -> Dict[str, Dict[str, JsonDict]]: - """Generates the account data portion of the sync response. Populates - `sync_result_builder` with the result. + """Generates the account data portion of the sync response. + + Account data (called "Client Config" in the spec) can be set either globally + or for a specific room. Account data consists of a list of events which + accumulate state, much like a room. + + This function retrieves global and per-room account data. The former is written + to the given `sync_result_builder`. The latter is returned directly, to be + later written to the `sync_result_builder` on a room-by-room basis. Args: sync_result_builder Returns: - A dictionary containing the per room account data. + A dictionary whose keys (room ids) map to the per room account data for that + room. """ sync_config = sync_result_builder.sync_config user_id = sync_result_builder.sync_config.user.to_string() @@ -1359,7 +1396,7 @@ class SyncHandler: if since_token and not sync_result_builder.full_state: ( - account_data, + global_account_data, account_data_by_room, ) = await self.store.get_updated_account_data_for_user( user_id, since_token.account_data_key @@ -1370,23 +1407,23 @@ class SyncHandler: ) if push_rules_changed: - account_data["m.push_rules"] = await self.push_rules_for_user( + global_account_data["m.push_rules"] = await self.push_rules_for_user( sync_config.user ) else: ( - account_data, + global_account_data, account_data_by_room, ) = await self.store.get_account_data_for_user(sync_config.user.to_string()) - account_data["m.push_rules"] = await self.push_rules_for_user( + global_account_data["m.push_rules"] = await self.push_rules_for_user( sync_config.user ) account_data_for_user = await sync_config.filter_collection.filter_account_data( [ {"type": account_data_type, "content": content} - for account_data_type, content in account_data.items() + for account_data_type, content in global_account_data.items() ] ) @@ -1460,15 +1497,22 @@ class SyncHandler: """Generates the rooms portion of the sync response. Populates the `sync_result_builder` with the result. + In the response that reaches the client, rooms are divided into four categories: + `invite`, `join`, `knock`, `leave`. These aren't the same as the four sets of + room ids returned by this function. + Args: sync_result_builder account_data_by_room: Dictionary of per room account data Returns: - Returns a 4-tuple of - `(newly_joined_rooms, newly_joined_or_invited_users, - newly_left_rooms, newly_left_users)` + Returns a 4-tuple whose entries are: + - newly_joined_rooms + - newly_joined_or_invited_or_knocked_users + - newly_left_rooms + - newly_left_users """ + # Start by fetching all ephemeral events in rooms we've joined (if required). user_id = sync_result_builder.sync_config.user.to_string() block_all_room_ephemeral = ( sync_result_builder.since_token is None @@ -1590,6 +1634,8 @@ class SyncHandler: ) -> bool: """Returns whether there may be any new events that should be sent down the sync. Returns True if there are. + + Does not modify the `sync_result_builder`. """ user_id = sync_result_builder.sync_config.user.to_string() since_token = sync_result_builder.since_token @@ -1597,12 +1643,13 @@ class SyncHandler: assert since_token - # Get a list of membership change events that have happened. - rooms_changed = await self.store.get_membership_changes_for_user( + # Get a list of membership change events that have happened to the user + # requesting the sync. + membership_changes = await self.store.get_membership_changes_for_user( user_id, since_token.room_key, now_token.room_key ) - if rooms_changed: + if membership_changes: return True stream_id = since_token.room_key.stream @@ -1614,7 +1661,25 @@ class SyncHandler: async def _get_rooms_changed( self, sync_result_builder: "SyncResultBuilder", ignored_users: FrozenSet[str] ) -> _RoomChanges: - """Gets the the changes that have happened since the last sync.""" + """Determine the changes in rooms to report to the user. + + Ideally, we want to report all events whose stream ordering `s` lies in the + range `since_token < s <= now_token`, where the two tokens are read from the + sync_result_builder. + + If there are too many events in that range to report, things get complicated. + In this situation we return a truncated list of the most recent events, and + indicate in the response that there is a "gap" of omitted events. Additionally: + + - we include a "state_delta", to describe the changes in state over the gap, + - we include all membership events applying to the user making the request, + even those in the gap. + + See the spec for the rationale: + https://spec.matrix.org/v1.1/client-server-api/#syncing + + The sync_result_builder is not modified by this function. + """ user_id = sync_result_builder.sync_config.user.to_string() since_token = sync_result_builder.since_token now_token = sync_result_builder.now_token @@ -1622,21 +1687,36 @@ class SyncHandler: assert since_token - # Get a list of membership change events that have happened. - rooms_changed = await self.store.get_membership_changes_for_user( + # The spec + # https://spec.matrix.org/v1.1/client-server-api/#get_matrixclientv3sync + # notes that membership events need special consideration: + # + # > When a sync is limited, the server MUST return membership events for events + # > in the gap (between since and the start of the returned timeline), regardless + # > as to whether or not they are redundant. + # + # We fetch such events here, but we only seem to use them for categorising rooms + # as newly joined, newly left, invited or knocked. + # TODO: we've already called this function and ran this query in + # _have_rooms_changed. We could keep the results in memory to avoid a + # second query, at the cost of more complicated source code. + membership_change_events = await self.store.get_membership_changes_for_user( user_id, since_token.room_key, now_token.room_key ) mem_change_events_by_room_id: Dict[str, List[EventBase]] = {} - for event in rooms_changed: + for event in membership_change_events: mem_change_events_by_room_id.setdefault(event.room_id, []).append(event) - newly_joined_rooms = [] - newly_left_rooms = [] - room_entries = [] - invited = [] - knocked = [] + newly_joined_rooms: List[str] = [] + newly_left_rooms: List[str] = [] + room_entries: List[RoomSyncResultBuilder] = [] + invited: List[InvitedSyncResult] = [] + knocked: List[KnockedSyncResult] = [] for room_id, events in mem_change_events_by_room_id.items(): + # The body of this loop will add this room to at least one of the five lists + # above. Things get messy if you've e.g. joined, left, joined then left the + # room all in the same sync period. logger.debug( "Membership changes in %s: [%s]", room_id, @@ -1781,7 +1861,9 @@ class SyncHandler: timeline_limit = sync_config.filter_collection.timeline_limit() - # Get all events for rooms we're currently joined to. + # Get all events since the `from_key` in rooms we're currently joined to. + # If there are too many, we get the most recent events only. This leaves + # a "gap" in the timeline, as described by the spec for /sync. room_to_events = await self.store.get_room_events_stream_for_rooms( room_ids=sync_result_builder.joined_room_ids, from_key=since_token.room_key, @@ -1842,6 +1924,10 @@ class SyncHandler: ) -> _RoomChanges: """Returns entries for all rooms for the user. + Like `_get_rooms_changed`, but assumes the `since_token` is `None`. + + This function does not modify the sync_result_builder. + Args: sync_result_builder ignored_users: Set of users ignored by user. @@ -1853,16 +1939,9 @@ class SyncHandler: now_token = sync_result_builder.now_token sync_config = sync_result_builder.sync_config - membership_list = ( - Membership.INVITE, - Membership.KNOCK, - Membership.JOIN, - Membership.LEAVE, - Membership.BAN, - ) - room_list = await self.store.get_rooms_for_local_user_where_membership_is( - user_id=user_id, membership_list=membership_list + user_id=user_id, + membership_list=Membership.LIST, ) room_entries = [] @@ -2212,8 +2291,7 @@ def _calculate_state( # to only include membership events for the senders in the timeline. # In practice, we can do this by removing them from the p_ids list, # which is the list of relevant state we know we have already sent to the client. - # see https://github.com/matrix-org/synapse/pull/2970 - # /files/efcdacad7d1b7f52f879179701c7e0d9b763511f#r204732809 + # see https://github.com/matrix-org/synapse/pull/2970/files/efcdacad7d1b7f52f879179701c7e0d9b763511f#r204732809 if lazy_load_members: p_ids.difference_update( diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 42dc807d17..57aab55259 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -497,7 +497,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): oldest `limit` events. Returns: - The list of events (in ascending order) and the token from the start + The list of events (in ascending stream order) and the token from the start of the chunk of events returned. """ if from_key == to_key: @@ -510,7 +510,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): if not has_changed: return [], from_key - def f(txn): + def f(txn: LoggingTransaction) -> List[_EventDictReturn]: # To handle tokens with a non-empty instance_map we fetch more # results than necessary and then filter down min_from_id = from_key.stream @@ -565,6 +565,13 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): async def get_membership_changes_for_user( self, user_id: str, from_key: RoomStreamToken, to_key: RoomStreamToken ) -> List[EventBase]: + """Fetch membership events for a given user. + + All such events whose stream ordering `s` lies in the range + `from_key < s <= to_key` are returned. Events are ordered by ascending stream + order. + """ + # Start by ruling out cases where a DB query is not necessary. if from_key == to_key: return [] @@ -575,7 +582,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): if not has_changed: return [] - def f(txn): + def f(txn: LoggingTransaction) -> List[_EventDictReturn]: # To handle tokens with a non-empty instance_map we fetch more # results than necessary and then filter down min_from_id = from_key.stream @@ -634,7 +641,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): Returns: A list of events and a token pointing to the start of the returned - events. The events returned are in ascending order. + events. The events returned are in ascending topological order. """ rows, token = await self.get_recent_event_ids_for_room( -- cgit 1.5.1 From 637df95de63196033a6da4a6e286e1d58ea517b6 Mon Sep 17 00:00:00 2001 From: reivilibre Date: Fri, 3 Dec 2021 16:42:44 +0000 Subject: Support configuring the lifetime of non-refreshable access tokens separately to refreshable access tokens. (#11445) --- changelog.d/11445.feature | 1 + synapse/config/registration.py | 49 ++++++++++++++++++++ synapse/handlers/register.py | 20 ++++++-- tests/config/test_registration_config.py | 78 ++++++++++++++++++++++++++++++++ tests/rest/client/test_auth.py | 76 +++++++++++++++++++++++++++++++ 5 files changed, 221 insertions(+), 3 deletions(-) create mode 100644 changelog.d/11445.feature create mode 100644 tests/config/test_registration_config.py (limited to 'synapse/handlers') diff --git a/changelog.d/11445.feature b/changelog.d/11445.feature new file mode 100644 index 0000000000..211a722b65 --- /dev/null +++ b/changelog.d/11445.feature @@ -0,0 +1 @@ +Support configuring the lifetime of non-refreshable access tokens separately to refreshable access tokens. \ No newline at end of file diff --git a/synapse/config/registration.py b/synapse/config/registration.py index 47853199f4..68a4985398 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -130,11 +130,60 @@ class RegistrationConfig(Config): int ] = refreshable_access_token_lifetime + if ( + self.session_lifetime is not None + and "refreshable_access_token_lifetime" in config + ): + if self.session_lifetime < self.refreshable_access_token_lifetime: + raise ConfigError( + "Both `session_lifetime` and `refreshable_access_token_lifetime` " + "configuration options have been set, but `refreshable_access_token_lifetime` " + " exceeds `session_lifetime`!" + ) + + # The `nonrefreshable_access_token_lifetime` applies for tokens that can NOT be + # refreshed using a refresh token. + # If it is None, then these tokens last for the entire length of the session, + # which is infinite by default. + # The intention behind this configuration option is to help with requiring + # all clients to use refresh tokens, if the homeserver administrator requires. + nonrefreshable_access_token_lifetime = config.get( + "nonrefreshable_access_token_lifetime", + None, + ) + if nonrefreshable_access_token_lifetime is not None: + nonrefreshable_access_token_lifetime = self.parse_duration( + nonrefreshable_access_token_lifetime + ) + self.nonrefreshable_access_token_lifetime = nonrefreshable_access_token_lifetime + + if ( + self.session_lifetime is not None + and self.nonrefreshable_access_token_lifetime is not None + ): + if self.session_lifetime < self.nonrefreshable_access_token_lifetime: + raise ConfigError( + "Both `session_lifetime` and `nonrefreshable_access_token_lifetime` " + "configuration options have been set, but `nonrefreshable_access_token_lifetime` " + " exceeds `session_lifetime`!" + ) + refresh_token_lifetime = config.get("refresh_token_lifetime") if refresh_token_lifetime is not None: refresh_token_lifetime = self.parse_duration(refresh_token_lifetime) self.refresh_token_lifetime: Optional[int] = refresh_token_lifetime + if ( + self.session_lifetime is not None + and self.refresh_token_lifetime is not None + ): + if self.session_lifetime < self.refresh_token_lifetime: + raise ConfigError( + "Both `session_lifetime` and `refresh_token_lifetime` " + "configuration options have been set, but `refresh_token_lifetime` " + " exceeds `session_lifetime`!" + ) + # The fallback template used for authenticating using a registration token self.registration_token_template = self.read_template("registration_token.html") diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 24ca11b924..b14ddd8267 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -1,4 +1,5 @@ # Copyright 2014 - 2016 OpenMarket Ltd +# Copyright 2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -116,6 +117,9 @@ class RegistrationHandler: self.pusher_pool = hs.get_pusherpool() self.session_lifetime = hs.config.registration.session_lifetime + self.nonrefreshable_access_token_lifetime = ( + hs.config.registration.nonrefreshable_access_token_lifetime + ) self.refreshable_access_token_lifetime = ( hs.config.registration.refreshable_access_token_lifetime ) @@ -794,13 +798,25 @@ class RegistrationHandler: class and RegisterDeviceReplicationServlet. """ assert not self.hs.config.worker.worker_app + now_ms = self.clock.time_msec() access_token_expiry = None if self.session_lifetime is not None: if is_guest: raise Exception( "session_lifetime is not currently implemented for guest access" ) - access_token_expiry = self.clock.time_msec() + self.session_lifetime + access_token_expiry = now_ms + self.session_lifetime + + if self.nonrefreshable_access_token_lifetime is not None: + if access_token_expiry is not None: + # Don't allow the non-refreshable access token to outlive the + # session. + access_token_expiry = min( + now_ms + self.nonrefreshable_access_token_lifetime, + access_token_expiry, + ) + else: + access_token_expiry = now_ms + self.nonrefreshable_access_token_lifetime refresh_token = None refresh_token_id = None @@ -818,8 +834,6 @@ class RegistrationHandler: # that this value is set before setting this flag). assert self.refreshable_access_token_lifetime is not None - now_ms = self.clock.time_msec() - # Set the expiry time of the refreshable access token access_token_expiry = now_ms + self.refreshable_access_token_lifetime diff --git a/tests/config/test_registration_config.py b/tests/config/test_registration_config.py new file mode 100644 index 0000000000..17a84d20d8 --- /dev/null +++ b/tests/config/test_registration_config.py @@ -0,0 +1,78 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from synapse.config import ConfigError +from synapse.config.homeserver import HomeServerConfig + +from tests.unittest import TestCase +from tests.utils import default_config + + +class RegistrationConfigTestCase(TestCase): + def test_session_lifetime_must_not_be_exceeded_by_smaller_lifetimes(self): + """ + session_lifetime should logically be larger than, or at least as large as, + all the different token lifetimes. + Test that the user is faced with configuration errors if they make it + smaller, as that configuration doesn't make sense. + """ + config_dict = default_config("test") + + # First test all the error conditions + with self.assertRaises(ConfigError): + HomeServerConfig().parse_config_dict( + { + "session_lifetime": "30m", + "nonrefreshable_access_token_lifetime": "31m", + **config_dict, + } + ) + + with self.assertRaises(ConfigError): + HomeServerConfig().parse_config_dict( + { + "session_lifetime": "30m", + "refreshable_access_token_lifetime": "31m", + **config_dict, + } + ) + + with self.assertRaises(ConfigError): + HomeServerConfig().parse_config_dict( + { + "session_lifetime": "30m", + "refresh_token_lifetime": "31m", + **config_dict, + } + ) + + # Then test all the fine conditions + HomeServerConfig().parse_config_dict( + { + "session_lifetime": "31m", + "nonrefreshable_access_token_lifetime": "31m", + **config_dict, + } + ) + + HomeServerConfig().parse_config_dict( + { + "session_lifetime": "31m", + "refreshable_access_token_lifetime": "31m", + **config_dict, + } + ) + + HomeServerConfig().parse_config_dict( + {"session_lifetime": "31m", "refresh_token_lifetime": "31m", **config_dict} + ) diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py index d8a94f4c12..7239e1a1b5 100644 --- a/tests/rest/client/test_auth.py +++ b/tests/rest/client/test_auth.py @@ -524,6 +524,19 @@ class RefreshAuthTests(unittest.HomeserverTestCase): {"refresh_token": refresh_token}, ) + def is_access_token_valid(self, access_token) -> bool: + """ + Checks whether an access token is valid, returning whether it is or not. + """ + code = self.make_request( + "GET", "/_matrix/client/v3/account/whoami", access_token=access_token + ).code + + # Either 200 or 401 is what we get back; anything else is a bug. + assert code in {HTTPStatus.OK, HTTPStatus.UNAUTHORIZED} + + return code == HTTPStatus.OK + def test_login_issue_refresh_token(self): """ A login response should include a refresh_token only if asked. @@ -671,6 +684,69 @@ class RefreshAuthTests(unittest.HomeserverTestCase): HTTPStatus.UNAUTHORIZED, ) + @override_config( + { + "refreshable_access_token_lifetime": "1m", + "nonrefreshable_access_token_lifetime": "10m", + } + ) + def test_different_expiry_for_refreshable_and_nonrefreshable_access_tokens(self): + """ + Tests that the expiry times for refreshable and non-refreshable access + tokens can be different. + """ + body = { + "type": "m.login.password", + "user": "test", + "password": self.user_pass, + } + login_response1 = self.make_request( + "POST", + "/_matrix/client/r0/login", + {"org.matrix.msc2918.refresh_token": True, **body}, + ) + self.assertEqual(login_response1.code, 200, login_response1.result) + self.assertApproximates( + login_response1.json_body["expires_in_ms"], 60 * 1000, 100 + ) + refreshable_access_token = login_response1.json_body["access_token"] + + login_response2 = self.make_request( + "POST", + "/_matrix/client/r0/login", + body, + ) + self.assertEqual(login_response2.code, 200, login_response2.result) + nonrefreshable_access_token = login_response2.json_body["access_token"] + + # Advance 59 seconds in the future (just shy of 1 minute, the time of expiry) + self.reactor.advance(59.0) + + # Both tokens should still be valid. + self.assertTrue(self.is_access_token_valid(refreshable_access_token)) + self.assertTrue(self.is_access_token_valid(nonrefreshable_access_token)) + + # Advance to 61 s (just past 1 minute, the time of expiry) + self.reactor.advance(2.0) + + # Only the non-refreshable token is still valid. + self.assertFalse(self.is_access_token_valid(refreshable_access_token)) + self.assertTrue(self.is_access_token_valid(nonrefreshable_access_token)) + + # Advance to 599 s (just shy of 10 minutes, the time of expiry) + self.reactor.advance(599.0 - 61.0) + + # It's still the case that only the non-refreshable token is still valid. + self.assertFalse(self.is_access_token_valid(refreshable_access_token)) + self.assertTrue(self.is_access_token_valid(nonrefreshable_access_token)) + + # Advance to 601 s (just past 10 minutes, the time of expiry) + self.reactor.advance(2.0) + + # Now neither token is valid. + self.assertFalse(self.is_access_token_valid(refreshable_access_token)) + self.assertFalse(self.is_access_token_valid(nonrefreshable_access_token)) + @override_config( {"refreshable_access_token_lifetime": "1m", "refresh_token_lifetime": "2m"} ) -- cgit 1.5.1 From 494ebd7347ba52d702802fba4c3bb13e7bfbc2cf Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 6 Dec 2021 10:51:15 -0500 Subject: Include bundled aggregations in /sync and related fixes (#11478) Due to updates to MSC2675 this includes a few fixes: * Include bundled aggregations for /sync. * Do not include bundled aggregations for /initialSync and /events. * Do not bundle aggregations for state events. * Clarifies comments and variable names. --- changelog.d/11478.bugfix | 1 + synapse/events/utils.py | 58 ++++++++++------ synapse/handlers/events.py | 5 +- synapse/handlers/initial_sync.py | 30 ++++++-- synapse/handlers/message.py | 8 +-- synapse/rest/admin/rooms.py | 13 +--- synapse/rest/client/relations.py | 9 ++- synapse/rest/client/room.py | 5 +- synapse/rest/client/sync.py | 6 +- tests/rest/client/test_relations.py | 135 +++++++++++++++++++++++++----------- 10 files changed, 169 insertions(+), 101 deletions(-) create mode 100644 changelog.d/11478.bugfix (limited to 'synapse/handlers') diff --git a/changelog.d/11478.bugfix b/changelog.d/11478.bugfix new file mode 100644 index 0000000000..5f02636f50 --- /dev/null +++ b/changelog.d/11478.bugfix @@ -0,0 +1 @@ +Include bundled relation aggregations during a limited `/sync` request, per [MSC2675](https://github.com/matrix-org/matrix-doc/pull/2675). diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 05219a9dd0..84ef69df67 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -306,6 +306,7 @@ def format_event_for_client_v2_without_room_id(d: JsonDict) -> JsonDict: def serialize_event( e: Union[JsonDict, EventBase], time_now_ms: int, + *, as_client_event: bool = True, event_format: Callable[[JsonDict], JsonDict] = format_event_for_client_v1, token_id: Optional[str] = None, @@ -393,7 +394,8 @@ class EventClientSerializer: self, event: Union[JsonDict, EventBase], time_now: int, - bundle_relations: bool = True, + *, + bundle_aggregations: bool = True, **kwargs: Any, ) -> JsonDict: """Serializes a single event. @@ -401,8 +403,9 @@ class EventClientSerializer: Args: event: The event being serialized. time_now: The current time in milliseconds - bundle_relations: Whether to include the bundled relations for this - event. + bundle_aggregations: Whether to include the bundled aggregations for this + event. Only applies to non-state events. (State events never include + bundled aggregations.) **kwargs: Arguments to pass to `serialize_event` Returns: @@ -414,20 +417,27 @@ class EventClientSerializer: serialized_event = serialize_event(event, time_now, **kwargs) - # If MSC1849 is enabled then we need to look if there are any relations - # we need to bundle in with the event. - # Do not bundle relations if the event has been redacted - if not event.internal_metadata.is_redacted() and ( - self._msc1849_enabled and bundle_relations + # Check if there are any bundled aggregations to include with the event. + # + # Do not bundle aggregations if any of the following at true: + # + # * Support is disabled via the configuration or the caller. + # * The event is a state event. + # * The event has been redacted. + if ( + self._msc1849_enabled + and bundle_aggregations + and not event.is_state() + and not event.internal_metadata.is_redacted() ): - await self._injected_bundled_relations(event, time_now, serialized_event) + await self._injected_bundled_aggregations(event, time_now, serialized_event) return serialized_event - async def _injected_bundled_relations( + async def _injected_bundled_aggregations( self, event: EventBase, time_now: int, serialized_event: JsonDict ) -> None: - """Potentially injects bundled relations into the unsigned portion of the serialized event. + """Potentially injects bundled aggregations into the unsigned portion of the serialized event. Args: event: The event being serialized. @@ -435,7 +445,7 @@ class EventClientSerializer: serialized_event: The serialized event which may be modified. """ - # Do not bundle relations for an event which represents an edit or an + # Do not bundle aggregations for an event which represents an edit or an # annotation. It does not make sense for them to have related events. relates_to = event.content.get("m.relates_to") if isinstance(relates_to, (dict, frozendict)): @@ -445,18 +455,18 @@ class EventClientSerializer: event_id = event.event_id - # The bundled relations to include. - relations = {} + # The bundled aggregations to include. + aggregations = {} annotations = await self.store.get_aggregation_groups_for_event(event_id) if annotations.chunk: - relations[RelationTypes.ANNOTATION] = annotations.to_dict() + aggregations[RelationTypes.ANNOTATION] = annotations.to_dict() references = await self.store.get_relations_for_event( event_id, RelationTypes.REFERENCE, direction="f" ) if references.chunk: - relations[RelationTypes.REFERENCE] = references.to_dict() + aggregations[RelationTypes.REFERENCE] = references.to_dict() edit = None if event.type == EventTypes.Message: @@ -482,7 +492,7 @@ class EventClientSerializer: else: serialized_event["content"].pop("m.relates_to", None) - relations[RelationTypes.REPLACE] = { + aggregations[RelationTypes.REPLACE] = { "event_id": edit.event_id, "origin_server_ts": edit.origin_server_ts, "sender": edit.sender, @@ -495,17 +505,19 @@ class EventClientSerializer: latest_thread_event, ) = await self.store.get_thread_summary(event_id) if latest_thread_event: - relations[RelationTypes.THREAD] = { - # Don't bundle relations as this could recurse forever. + aggregations[RelationTypes.THREAD] = { + # Don't bundle aggregations as this could recurse forever. "latest_event": await self.serialize_event( - latest_thread_event, time_now, bundle_relations=False + latest_thread_event, time_now, bundle_aggregations=False ), "count": thread_count, } - # If any bundled relations were found, include them. - if relations: - serialized_event["unsigned"].setdefault("m.relations", {}).update(relations) + # If any bundled aggregations were found, include them. + if aggregations: + serialized_event["unsigned"].setdefault("m.relations", {}).update( + aggregations + ) async def serialize_events( self, events: Iterable[Union[JsonDict, EventBase]], time_now: int, **kwargs: Any diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index b4ff935546..32b0254c5f 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -122,9 +122,8 @@ class EventStreamHandler: events, time_now, as_client_event=as_client_event, - # We don't bundle "live" events, as otherwise clients - # will end up double counting annotations. - bundle_relations=False, + # Don't bundle aggregations as this is a deprecated API. + bundle_aggregations=False, ) chunk = { diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index d4e4556155..9cd21e7f2b 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -165,7 +165,11 @@ class InitialSyncHandler: invite_event = await self.store.get_event(event.event_id) d["invite"] = await self._event_serializer.serialize_event( - invite_event, time_now, as_client_event + invite_event, + time_now, + # Don't bundle aggregations as this is a deprecated API. + bundle_aggregations=False, + as_client_event=as_client_event, ) rooms_ret.append(d) @@ -216,7 +220,11 @@ class InitialSyncHandler: d["messages"] = { "chunk": ( await self._event_serializer.serialize_events( - messages, time_now=time_now, as_client_event=as_client_event + messages, + time_now=time_now, + # Don't bundle aggregations as this is a deprecated API. + bundle_aggregations=False, + as_client_event=as_client_event, ) ), "start": await start_token.to_string(self.store), @@ -226,6 +234,8 @@ class InitialSyncHandler: d["state"] = await self._event_serializer.serialize_events( current_state.values(), time_now=time_now, + # Don't bundle aggregations as this is a deprecated API. + bundle_aggregations=False, as_client_event=as_client_event, ) @@ -366,14 +376,18 @@ class InitialSyncHandler: "room_id": room_id, "messages": { "chunk": ( - await self._event_serializer.serialize_events(messages, time_now) + # Don't bundle aggregations as this is a deprecated API. + await self._event_serializer.serialize_events( + messages, time_now, bundle_aggregations=False + ) ), "start": await start_token.to_string(self.store), "end": await end_token.to_string(self.store), }, "state": ( + # Don't bundle aggregations as this is a deprecated API. await self._event_serializer.serialize_events( - room_state.values(), time_now + room_state.values(), time_now, bundle_aggregations=False ) ), "presence": [], @@ -392,8 +406,9 @@ class InitialSyncHandler: # TODO: These concurrently time_now = self.clock.time_msec() + # Don't bundle aggregations as this is a deprecated API. state = await self._event_serializer.serialize_events( - current_state.values(), time_now + current_state.values(), time_now, bundle_aggregations=False ) now_token = self.hs.get_event_sources().get_current_token() @@ -467,7 +482,10 @@ class InitialSyncHandler: "room_id": room_id, "messages": { "chunk": ( - await self._event_serializer.serialize_events(messages, time_now) + # Don't bundle aggregations as this is a deprecated API. + await self._event_serializer.serialize_events( + messages, time_now, bundle_aggregations=False + ) ), "start": await start_token.to_string(self.store), "end": await end_token.to_string(self.store), diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 95b4fad3c6..87f671708c 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -247,13 +247,7 @@ class MessageHandler: room_state = room_state_events[membership_event_id] now = self.clock.time_msec() - events = await self._event_serializer.serialize_events( - room_state.values(), - now, - # We don't bother bundling aggregations in when asked for state - # events, as clients won't use them. - bundle_relations=False, - ) + events = await self._event_serializer.serialize_events(room_state.values(), now) return events async def get_joined_members(self, requester: Requester, room_id: str) -> dict: diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 6bbc5510f0..669ab44a45 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -449,13 +449,7 @@ class RoomStateRestServlet(RestServlet): event_ids = await self.store.get_current_state_ids(room_id) events = await self.store.get_events(event_ids.values()) now = self.clock.time_msec() - room_state = await self._event_serializer.serialize_events( - events.values(), - now, - # We don't bother bundling aggregations in when asked for state - # events, as clients won't use them. - bundle_relations=False, - ) + room_state = await self._event_serializer.serialize_events(events.values(), now) ret = {"state": room_state} return HTTPStatus.OK, ret @@ -789,10 +783,7 @@ class RoomEventContextServlet(RestServlet): results["events_after"], time_now ) results["state"] = await self._event_serializer.serialize_events( - results["state"], - time_now, - # No need to bundle aggregations for state events - bundle_relations=False, + results["state"], time_now ) return HTTPStatus.OK, results diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py index b1a3304849..fc4e6921c5 100644 --- a/synapse/rest/client/relations.py +++ b/synapse/rest/client/relations.py @@ -224,14 +224,13 @@ class RelationPaginationServlet(RestServlet): ) now = self.clock.time_msec() - # We set bundle_relations to False when retrieving the original - # event because we want the content before relations were applied to - # it. + # Do not bundle aggregations when retrieving the original event because + # we want the content before relations are applied to it. original_event = await self._event_serializer.serialize_event( - event, now, bundle_relations=False + event, now, bundle_aggregations=False ) # The relations returned for the requested event do include their - # bundled relations. + # bundled aggregations. serialized_events = await self._event_serializer.serialize_events(events, now) return_value = pagination_chunk.to_dict() diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 3598967be0..f48e2e6ca2 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -716,10 +716,7 @@ class RoomEventContextServlet(RestServlet): results["events_after"], time_now ) results["state"] = await self._event_serializer.serialize_events( - results["state"], - time_now, - # No need to bundle aggregations for state events - bundle_relations=False, + results["state"], time_now ) return 200, results diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index b6a2485732..88e4f5e063 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -520,9 +520,9 @@ class SyncRestServlet(RestServlet): return self._event_serializer.serialize_events( events, time_now=time_now, - # We don't bundle "live" events, as otherwise clients - # will end up double counting annotations. - bundle_relations=False, + # Don't bother to bundle aggregations if the timeline is unlimited, + # as clients will have all the necessary information. + bundle_aggregations=room.timeline.limited, token_id=token_id, event_format=event_formatter, only_event_fields=only_fields, diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index b494da5138..397c12c2a6 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -19,7 +19,7 @@ from typing import Dict, List, Optional, Tuple from synapse.api.constants import EventTypes, RelationTypes from synapse.rest import admin -from synapse.rest.client import login, register, relations, room +from synapse.rest.client import login, register, relations, room, sync from tests import unittest from tests.server import FakeChannel @@ -29,6 +29,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): servlets = [ relations.register_servlets, room.register_servlets, + sync.register_servlets, login.register_servlets, register.register_servlets, admin.register_servlets_for_client_rest_resource, @@ -454,11 +455,9 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertEquals(400, channel.code, channel.json_body) @unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) - def test_aggregation_get_event(self): - """Test that annotations, references, and threads get correctly bundled when - getting the parent event. - """ - + def test_bundled_aggregations(self): + """Test that annotations, references, and threads get correctly bundled.""" + # Setup by sending a variety of relations. channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") self.assertEquals(200, channel.code, channel.json_body) @@ -485,49 +484,107 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertEquals(200, channel.code, channel.json_body) thread_2 = channel.json_body["event_id"] - channel = self.make_request( - "GET", - "/rooms/%s/event/%s" % (self.room, self.parent_id), - access_token=self.user_token, - ) - self.assertEquals(200, channel.code, channel.json_body) + def assert_bundle(actual): + """Assert the expected values of the bundled aggregations.""" - self.assertEquals( - channel.json_body["unsigned"].get("m.relations"), - { - RelationTypes.ANNOTATION: { + # Ensure the fields are as expected. + self.assertCountEqual( + actual.keys(), + ( + RelationTypes.ANNOTATION, + RelationTypes.REFERENCE, + RelationTypes.THREAD, + ), + ) + + # Check the values of each field. + self.assertEquals( + { "chunk": [ {"type": "m.reaction", "key": "a", "count": 2}, {"type": "m.reaction", "key": "b", "count": 1}, ] }, - RelationTypes.REFERENCE: { - "chunk": [{"event_id": reply_1}, {"event_id": reply_2}] - }, - RelationTypes.THREAD: { - "count": 2, - "latest_event": { - "age": 100, - "content": { - "m.relates_to": { - "event_id": self.parent_id, - "rel_type": RelationTypes.THREAD, - } - }, - "event_id": thread_2, - "origin_server_ts": 1600, - "room_id": self.room, - "sender": self.user_id, - "type": "m.room.test", - "unsigned": {"age": 100}, - "user_id": self.user_id, + actual[RelationTypes.ANNOTATION], + ) + + self.assertEquals( + {"chunk": [{"event_id": reply_1}, {"event_id": reply_2}]}, + actual[RelationTypes.REFERENCE], + ) + + self.assertEquals( + 2, + actual[RelationTypes.THREAD].get("count"), + ) + # The latest thread event has some fields that don't matter. + self.assert_dict( + { + "content": { + "m.relates_to": { + "event_id": self.parent_id, + "rel_type": RelationTypes.THREAD, + } }, + "event_id": thread_2, + "room_id": self.room, + "sender": self.user_id, + "type": "m.room.test", + "user_id": self.user_id, }, - }, + actual[RelationTypes.THREAD].get("latest_event"), + ) + + def _find_and_assert_event(events): + """ + Find the parent event in a chunk of events and assert that it has the proper bundled aggregations. + """ + for event in events: + if event["event_id"] == self.parent_id: + break + else: + raise AssertionError(f"Event {self.parent_id} not found in chunk") + assert_bundle(event["unsigned"].get("m.relations")) + + # Request the event directly. + channel = self.make_request( + "GET", + f"/rooms/{self.room}/event/{self.parent_id}", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + assert_bundle(channel.json_body["unsigned"].get("m.relations")) + + # Request the room messages. + channel = self.make_request( + "GET", + f"/rooms/{self.room}/messages?dir=b", + access_token=self.user_token, ) + self.assertEquals(200, channel.code, channel.json_body) + _find_and_assert_event(channel.json_body["chunk"]) + + # Request the room context. + channel = self.make_request( + "GET", + f"/rooms/{self.room}/context/{self.parent_id}", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + assert_bundle(channel.json_body["event"]["unsigned"].get("m.relations")) + + # Request sync. + channel = self.make_request("GET", "/sync", access_token=self.user_token) + self.assertEquals(200, channel.code, channel.json_body) + room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"] + self.assertTrue(room_timeline["limited"]) + _find_and_assert_event(room_timeline["events"]) + + # Note that /relations is tested separately in test_aggregation_get_event_for_thread + # since it needs different data configured. def test_aggregation_get_event_for_annotation(self): - """Test that annotations do not get bundled relations included + """Test that annotations do not get bundled aggregations included when directly requested. """ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") @@ -549,7 +606,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertIsNone(channel.json_body["unsigned"].get("m.relations")) def test_aggregation_get_event_for_thread(self): - """Test that threads get bundled relations included when directly requested.""" + """Test that threads get bundled aggregations included when directly requested.""" channel = self._send_relation(RelationTypes.THREAD, "m.room.test") self.assertEquals(200, channel.code, channel.json_body) thread_id = channel.json_body["event_id"] -- cgit 1.5.1 From a15a893df8428395df7cb95b729431575001c38a Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Mon, 6 Dec 2021 18:43:06 +0100 Subject: Save the OIDC session ID (sid) with the device on login (#11482) As a step towards allowing back-channel logout for OIDC. --- changelog.d/11482.misc | 1 + synapse/handlers/auth.py | 34 +++++- synapse/handlers/device.py | 8 ++ synapse/handlers/oidc.py | 58 +++++---- synapse/handlers/register.py | 15 ++- synapse/handlers/sso.py | 4 + synapse/module_api/__init__.py | 2 + synapse/replication/http/login.py | 8 ++ synapse/rest/client/login.py | 7 +- synapse/storage/databases/main/devices.py | 50 +++++++- .../delta/65/11_devices_auth_provider_session.sql | 27 +++++ tests/handlers/test_auth.py | 6 +- tests/handlers/test_cas.py | 40 +++++- tests/handlers/test_oidc.py | 135 ++++++++++++++++++--- tests/handlers/test_saml.py | 40 +++++- 15 files changed, 370 insertions(+), 65 deletions(-) create mode 100644 changelog.d/11482.misc create mode 100644 synapse/storage/schema/main/delta/65/11_devices_auth_provider_session.sql (limited to 'synapse/handlers') diff --git a/changelog.d/11482.misc b/changelog.d/11482.misc new file mode 100644 index 0000000000..e78662988f --- /dev/null +++ b/changelog.d/11482.misc @@ -0,0 +1 @@ +Save the OpenID Connect session ID on login. diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 4d9c4e5834..61607cf2ba 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -39,6 +39,7 @@ import attr import bcrypt import pymacaroons import unpaddedbase64 +from pymacaroons.exceptions import MacaroonVerificationFailedException from twisted.web.server import Request @@ -182,8 +183,11 @@ class LoginTokenAttributes: user_id = attr.ib(type=str) - # the SSO Identity Provider that the user authenticated with, to get this token auth_provider_id = attr.ib(type=str) + """The SSO Identity Provider that the user authenticated with, to get this token.""" + + auth_provider_session_id = attr.ib(type=Optional[str]) + """The session ID advertised by the SSO Identity Provider.""" class AuthHandler: @@ -1650,6 +1654,7 @@ class AuthHandler: client_redirect_url: str, extra_attributes: Optional[JsonDict] = None, new_user: bool = False, + auth_provider_session_id: Optional[str] = None, ) -> None: """Having figured out a mxid for this user, complete the HTTP request @@ -1665,6 +1670,7 @@ class AuthHandler: during successful login. Must be JSON serializable. new_user: True if we should use wording appropriate to a user who has just registered. + auth_provider_session_id: The session ID from the SSO IdP received during login. """ # If the account has been deactivated, do not proceed with the login # flow. @@ -1685,6 +1691,7 @@ class AuthHandler: extra_attributes, new_user=new_user, user_profile_data=profile, + auth_provider_session_id=auth_provider_session_id, ) def _complete_sso_login( @@ -1696,6 +1703,7 @@ class AuthHandler: extra_attributes: Optional[JsonDict] = None, new_user: bool = False, user_profile_data: Optional[ProfileInfo] = None, + auth_provider_session_id: Optional[str] = None, ) -> None: """ The synchronous portion of complete_sso_login. @@ -1717,7 +1725,9 @@ class AuthHandler: # Create a login token login_token = self.macaroon_gen.generate_short_term_login_token( - registered_user_id, auth_provider_id=auth_provider_id + registered_user_id, + auth_provider_id=auth_provider_id, + auth_provider_session_id=auth_provider_session_id, ) # Append the login token to the original redirect URL (i.e. with its query @@ -1822,6 +1832,7 @@ class MacaroonGenerator: self, user_id: str, auth_provider_id: str, + auth_provider_session_id: Optional[str] = None, duration_in_ms: int = (2 * 60 * 1000), ) -> str: macaroon = self._generate_base_macaroon(user_id) @@ -1830,6 +1841,10 @@ class MacaroonGenerator: expiry = now + duration_in_ms macaroon.add_first_party_caveat("time < %d" % (expiry,)) macaroon.add_first_party_caveat("auth_provider_id = %s" % (auth_provider_id,)) + if auth_provider_session_id is not None: + macaroon.add_first_party_caveat( + "auth_provider_session_id = %s" % (auth_provider_session_id,) + ) return macaroon.serialize() def verify_short_term_login_token(self, token: str) -> LoginTokenAttributes: @@ -1851,15 +1866,28 @@ class MacaroonGenerator: user_id = get_value_from_macaroon(macaroon, "user_id") auth_provider_id = get_value_from_macaroon(macaroon, "auth_provider_id") + auth_provider_session_id: Optional[str] = None + try: + auth_provider_session_id = get_value_from_macaroon( + macaroon, "auth_provider_session_id" + ) + except MacaroonVerificationFailedException: + pass + v = pymacaroons.Verifier() v.satisfy_exact("gen = 1") v.satisfy_exact("type = login") v.satisfy_general(lambda c: c.startswith("user_id = ")) v.satisfy_general(lambda c: c.startswith("auth_provider_id = ")) + v.satisfy_general(lambda c: c.startswith("auth_provider_session_id = ")) satisfy_expiry(v, self.hs.get_clock().time_msec) v.verify(macaroon, self.hs.config.key.macaroon_secret_key) - return LoginTokenAttributes(user_id=user_id, auth_provider_id=auth_provider_id) + return LoginTokenAttributes( + user_id=user_id, + auth_provider_id=auth_provider_id, + auth_provider_session_id=auth_provider_session_id, + ) def generate_delete_pusher_token(self, user_id: str) -> str: macaroon = self._generate_base_macaroon(user_id) diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 68b446eb66..82ee11e921 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -301,6 +301,8 @@ class DeviceHandler(DeviceWorkerHandler): user_id: str, device_id: Optional[str], initial_device_display_name: Optional[str] = None, + auth_provider_id: Optional[str] = None, + auth_provider_session_id: Optional[str] = None, ) -> str: """ If the given device has not been registered, register it with the @@ -312,6 +314,8 @@ class DeviceHandler(DeviceWorkerHandler): user_id: @user:id device_id: device id supplied by client initial_device_display_name: device display name from client + auth_provider_id: The SSO IdP the user used, if any. + auth_provider_session_id: The session ID (sid) got from the SSO IdP. Returns: device id (generated if none was supplied) """ @@ -323,6 +327,8 @@ class DeviceHandler(DeviceWorkerHandler): user_id=user_id, device_id=device_id, initial_device_display_name=initial_device_display_name, + auth_provider_id=auth_provider_id, + auth_provider_session_id=auth_provider_session_id, ) if new_device: await self.notify_device_update(user_id, [device_id]) @@ -337,6 +343,8 @@ class DeviceHandler(DeviceWorkerHandler): user_id=user_id, device_id=new_device_id, initial_device_display_name=initial_device_display_name, + auth_provider_id=auth_provider_id, + auth_provider_session_id=auth_provider_session_id, ) if new_device: await self.notify_device_update(user_id, [new_device_id]) diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py index 3665d91513..deb3539751 100644 --- a/synapse/handlers/oidc.py +++ b/synapse/handlers/oidc.py @@ -23,7 +23,7 @@ from authlib.common.security import generate_token from authlib.jose import JsonWebToken, jwt from authlib.oauth2.auth import ClientAuth from authlib.oauth2.rfc6749.parameters import prepare_grant_uri -from authlib.oidc.core import CodeIDToken, ImplicitIDToken, UserInfo +from authlib.oidc.core import CodeIDToken, UserInfo from authlib.oidc.discovery import OpenIDProviderMetadata, get_well_known_url from jinja2 import Environment, Template from pymacaroons.exceptions import ( @@ -117,7 +117,8 @@ class OidcHandler: for idp_id, p in self._providers.items(): try: await p.load_metadata() - await p.load_jwks() + if not p._uses_userinfo: + await p.load_jwks() except Exception as e: raise Exception( "Error while initialising OIDC provider %r" % (idp_id,) @@ -498,10 +499,6 @@ class OidcProvider: return await self._jwks.get() async def _load_jwks(self) -> JWKS: - if self._uses_userinfo: - # We're not using jwt signing, return an empty jwk set - return {"keys": []} - metadata = await self.load_metadata() # Load the JWKS using the `jwks_uri` metadata. @@ -663,7 +660,7 @@ class OidcProvider: return UserInfo(resp) - async def _parse_id_token(self, token: Token, nonce: str) -> UserInfo: + async def _parse_id_token(self, token: Token, nonce: str) -> CodeIDToken: """Return an instance of UserInfo from token's ``id_token``. Args: @@ -673,7 +670,7 @@ class OidcProvider: request. This value should match the one inside the token. Returns: - An object representing the user. + The decoded claims in the ID token. """ metadata = await self.load_metadata() claims_params = { @@ -684,9 +681,6 @@ class OidcProvider: # If we got an `access_token`, there should be an `at_hash` claim # in the `id_token` that we can check against. claims_params["access_token"] = token["access_token"] - claims_cls = CodeIDToken - else: - claims_cls = ImplicitIDToken alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"]) jwt = JsonWebToken(alg_values) @@ -703,7 +697,7 @@ class OidcProvider: claims = jwt.decode( id_token, key=jwk_set, - claims_cls=claims_cls, + claims_cls=CodeIDToken, claims_options=claim_options, claims_params=claims_params, ) @@ -713,7 +707,7 @@ class OidcProvider: claims = jwt.decode( id_token, key=jwk_set, - claims_cls=claims_cls, + claims_cls=CodeIDToken, claims_options=claim_options, claims_params=claims_params, ) @@ -721,7 +715,8 @@ class OidcProvider: logger.debug("Decoded id_token JWT %r; validating", claims) claims.validate(leeway=120) # allows 2 min of clock skew - return UserInfo(claims) + + return claims async def handle_redirect_request( self, @@ -837,8 +832,22 @@ class OidcProvider: logger.debug("Successfully obtained OAuth2 token data: %r", token) - # Now that we have a token, get the userinfo, either by decoding the - # `id_token` or by fetching the `userinfo_endpoint`. + # If there is an id_token, it should be validated, regardless of the + # userinfo endpoint is used or not. + if token.get("id_token") is not None: + try: + id_token = await self._parse_id_token(token, nonce=session_data.nonce) + sid = id_token.get("sid") + except Exception as e: + logger.exception("Invalid id_token") + self._sso_handler.render_error(request, "invalid_token", str(e)) + return + else: + id_token = None + sid = None + + # Now that we have a token, get the userinfo either from the `id_token` + # claims or by fetching the `userinfo_endpoint`. if self._uses_userinfo: try: userinfo = await self._fetch_userinfo(token) @@ -846,13 +855,14 @@ class OidcProvider: logger.exception("Could not fetch userinfo") self._sso_handler.render_error(request, "fetch_error", str(e)) return + elif id_token is not None: + userinfo = UserInfo(id_token) else: - try: - userinfo = await self._parse_id_token(token, nonce=session_data.nonce) - except Exception as e: - logger.exception("Invalid id_token") - self._sso_handler.render_error(request, "invalid_token", str(e)) - return + logger.error("Missing id_token in token response") + self._sso_handler.render_error( + request, "invalid_token", "Missing id_token in token response" + ) + return # first check if we're doing a UIA if session_data.ui_auth_session_id: @@ -884,7 +894,7 @@ class OidcProvider: # Call the mapper to register/login the user try: await self._complete_oidc_login( - userinfo, token, request, session_data.client_redirect_url + userinfo, token, request, session_data.client_redirect_url, sid ) except MappingException as e: logger.exception("Could not map user") @@ -896,6 +906,7 @@ class OidcProvider: token: Token, request: SynapseRequest, client_redirect_url: str, + sid: Optional[str], ) -> None: """Given a UserInfo response, complete the login flow @@ -1008,6 +1019,7 @@ class OidcProvider: oidc_response_to_user_attributes, grandfather_existing_users, extra_attributes, + auth_provider_session_id=sid, ) def _remote_id_from_userinfo(self, userinfo: UserInfo) -> str: diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index b14ddd8267..f08a516a75 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -746,6 +746,7 @@ class RegistrationHandler: is_appservice_ghost: bool = False, auth_provider_id: Optional[str] = None, should_issue_refresh_token: bool = False, + auth_provider_session_id: Optional[str] = None, ) -> Tuple[str, str, Optional[int], Optional[str]]: """Register a device for a user and generate an access token. @@ -756,9 +757,9 @@ class RegistrationHandler: device_id: The device ID to check, or None to generate a new one. initial_display_name: An optional display name for the device. is_guest: Whether this is a guest account - auth_provider_id: The SSO IdP the user used, if any (just used for the - prometheus metrics). + auth_provider_id: The SSO IdP the user used, if any. should_issue_refresh_token: Whether it should also issue a refresh token + auth_provider_session_id: The session ID received during login from the SSO IdP. Returns: Tuple of device ID, access token, access token expiration time and refresh token """ @@ -769,6 +770,8 @@ class RegistrationHandler: is_guest=is_guest, is_appservice_ghost=is_appservice_ghost, should_issue_refresh_token=should_issue_refresh_token, + auth_provider_id=auth_provider_id, + auth_provider_session_id=auth_provider_session_id, ) login_counter.labels( @@ -791,6 +794,8 @@ class RegistrationHandler: is_guest: bool = False, is_appservice_ghost: bool = False, should_issue_refresh_token: bool = False, + auth_provider_id: Optional[str] = None, + auth_provider_session_id: Optional[str] = None, ) -> LoginDict: """Helper for register_device @@ -822,7 +827,11 @@ class RegistrationHandler: refresh_token_id = None registered_device_id = await self.device_handler.check_device_registered( - user_id, device_id, initial_display_name + user_id, + device_id, + initial_display_name, + auth_provider_id=auth_provider_id, + auth_provider_session_id=auth_provider_session_id, ) if is_guest: assert access_token_expiry is None diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index 49fde01cf0..65c27bc64a 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -365,6 +365,7 @@ class SsoHandler: sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]], grandfather_existing_users: Callable[[], Awaitable[Optional[str]]], extra_login_attributes: Optional[JsonDict] = None, + auth_provider_session_id: Optional[str] = None, ) -> None: """ Given an SSO ID, retrieve the user ID for it and possibly register the user. @@ -415,6 +416,8 @@ class SsoHandler: extra_login_attributes: An optional dictionary of extra attributes to be provided to the client in the login response. + auth_provider_session_id: An optional session ID from the IdP. + Raises: MappingException if there was a problem mapping the response to a user. RedirectException: if the mapping provider needs to redirect the user @@ -490,6 +493,7 @@ class SsoHandler: client_redirect_url, extra_login_attributes, new_user=new_user, + auth_provider_session_id=auth_provider_session_id, ) async def _call_attribute_mapper( diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index a8154168be..6bfb4b8d1b 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -626,6 +626,7 @@ class ModuleApi: user_id: str, duration_in_ms: int = (2 * 60 * 1000), auth_provider_id: str = "", + auth_provider_session_id: Optional[str] = None, ) -> str: """Generate a login token suitable for m.login.token authentication @@ -643,6 +644,7 @@ class ModuleApi: return self._hs.get_macaroon_generator().generate_short_term_login_token( user_id, auth_provider_id, + auth_provider_session_id, duration_in_ms, ) diff --git a/synapse/replication/http/login.py b/synapse/replication/http/login.py index 0db419ea57..daacc34cea 100644 --- a/synapse/replication/http/login.py +++ b/synapse/replication/http/login.py @@ -46,6 +46,8 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint): is_guest, is_appservice_ghost, should_issue_refresh_token, + auth_provider_id, + auth_provider_session_id, ): """ Args: @@ -63,6 +65,8 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint): "is_guest": is_guest, "is_appservice_ghost": is_appservice_ghost, "should_issue_refresh_token": should_issue_refresh_token, + "auth_provider_id": auth_provider_id, + "auth_provider_session_id": auth_provider_session_id, } async def _handle_request(self, request, user_id): @@ -73,6 +77,8 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint): is_guest = content["is_guest"] is_appservice_ghost = content["is_appservice_ghost"] should_issue_refresh_token = content["should_issue_refresh_token"] + auth_provider_id = content["auth_provider_id"] + auth_provider_session_id = content["auth_provider_session_id"] res = await self.registration_handler.register_device_inner( user_id, @@ -81,6 +87,8 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint): is_guest, is_appservice_ghost=is_appservice_ghost, should_issue_refresh_token=should_issue_refresh_token, + auth_provider_id=auth_provider_id, + auth_provider_session_id=auth_provider_session_id, ) return 200, res diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py index a66ee4fb3d..1b23fa18cf 100644 --- a/synapse/rest/client/login.py +++ b/synapse/rest/client/login.py @@ -303,6 +303,7 @@ class LoginRestServlet(RestServlet): ratelimit: bool = True, auth_provider_id: Optional[str] = None, should_issue_refresh_token: bool = False, + auth_provider_session_id: Optional[str] = None, ) -> LoginResponse: """Called when we've successfully authed the user and now need to actually login them in (e.g. create devices). This gets called on @@ -318,10 +319,10 @@ class LoginRestServlet(RestServlet): create_non_existent_users: Whether to create the user if they don't exist. Defaults to False. ratelimit: Whether to ratelimit the login request. - auth_provider_id: The SSO IdP the user used, if any (just used for the - prometheus metrics). + auth_provider_id: The SSO IdP the user used, if any. should_issue_refresh_token: True if this login should issue a refresh token alongside the access token. + auth_provider_session_id: The session ID got during login from the SSO IdP. Returns: result: Dictionary of account information after successful login. @@ -354,6 +355,7 @@ class LoginRestServlet(RestServlet): initial_display_name, auth_provider_id=auth_provider_id, should_issue_refresh_token=should_issue_refresh_token, + auth_provider_session_id=auth_provider_session_id, ) result = LoginResponse( @@ -399,6 +401,7 @@ class LoginRestServlet(RestServlet): self.auth_handler._sso_login_callback, auth_provider_id=res.auth_provider_id, should_issue_refresh_token=should_issue_refresh_token, + auth_provider_session_id=res.auth_provider_session_id, ) async def _do_jwt_login( diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 9ccc66e589..d5a4a661cd 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -139,6 +139,27 @@ class DeviceWorkerStore(SQLBaseStore): return {d["device_id"]: d for d in devices} + async def get_devices_by_auth_provider_session_id( + self, auth_provider_id: str, auth_provider_session_id: str + ) -> List[Dict[str, Any]]: + """Retrieve the list of devices associated with a SSO IdP session ID. + + Args: + auth_provider_id: The SSO IdP ID as defined in the server config + auth_provider_session_id: The session ID within the IdP + Returns: + A list of dicts containing the device_id and the user_id of each device + """ + return await self.db_pool.simple_select_list( + table="device_auth_providers", + keyvalues={ + "auth_provider_id": auth_provider_id, + "auth_provider_session_id": auth_provider_session_id, + }, + retcols=("user_id", "device_id"), + desc="get_devices_by_auth_provider_session_id", + ) + @trace async def get_device_updates_by_remote( self, destination: str, from_stream_id: int, limit: int @@ -1070,7 +1091,12 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): ) async def store_device( - self, user_id: str, device_id: str, initial_device_display_name: Optional[str] + self, + user_id: str, + device_id: str, + initial_device_display_name: Optional[str], + auth_provider_id: Optional[str] = None, + auth_provider_session_id: Optional[str] = None, ) -> bool: """Ensure the given device is known; add it to the store if not @@ -1079,6 +1105,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): device_id: id of device initial_device_display_name: initial displayname of the device. Ignored if device exists. + auth_provider_id: The SSO IdP the user used, if any. + auth_provider_session_id: The session ID (sid) got from a OIDC login. Returns: Whether the device was inserted or an existing device existed with that ID. @@ -1115,6 +1143,18 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): if hidden: raise StoreError(400, "The device ID is in use", Codes.FORBIDDEN) + if auth_provider_id and auth_provider_session_id: + await self.db_pool.simple_insert( + "device_auth_providers", + values={ + "user_id": user_id, + "device_id": device_id, + "auth_provider_id": auth_provider_id, + "auth_provider_session_id": auth_provider_session_id, + }, + desc="store_device_auth_provider", + ) + self.device_id_exists_cache.set(key, True) return inserted except StoreError: @@ -1168,6 +1208,14 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): keyvalues={"user_id": user_id}, ) + self.db_pool.simple_delete_many_txn( + txn, + table="device_auth_providers", + column="device_id", + values=device_ids, + keyvalues={"user_id": user_id}, + ) + await self.db_pool.runInteraction("delete_devices", _delete_devices_txn) for device_id in device_ids: self.device_id_exists_cache.invalidate((user_id, device_id)) diff --git a/synapse/storage/schema/main/delta/65/11_devices_auth_provider_session.sql b/synapse/storage/schema/main/delta/65/11_devices_auth_provider_session.sql new file mode 100644 index 0000000000..a65bfb520d --- /dev/null +++ b/synapse/storage/schema/main/delta/65/11_devices_auth_provider_session.sql @@ -0,0 +1,27 @@ +/* Copyright 2021 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Track the auth provider used by each login as well as the session ID +CREATE TABLE device_auth_providers ( + user_id TEXT NOT NULL, + device_id TEXT NOT NULL, + auth_provider_id TEXT NOT NULL, + auth_provider_session_id TEXT NOT NULL +); + +CREATE INDEX device_auth_providers_devices + ON device_auth_providers (user_id, device_id); +CREATE INDEX device_auth_providers_sessions + ON device_auth_providers (auth_provider_id, auth_provider_session_id); diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index 72e176da75..03b8b8615c 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -71,7 +71,7 @@ class AuthTestCase(unittest.HomeserverTestCase): def test_short_term_login_token_gives_user_id(self): token = self.macaroon_generator.generate_short_term_login_token( - self.user1, "", 5000 + self.user1, "", duration_in_ms=5000 ) res = self.get_success(self.auth_handler.validate_short_term_login_token(token)) self.assertEqual(self.user1, res.user_id) @@ -94,7 +94,7 @@ class AuthTestCase(unittest.HomeserverTestCase): def test_short_term_login_token_cannot_replace_user_id(self): token = self.macaroon_generator.generate_short_term_login_token( - self.user1, "", 5000 + self.user1, "", duration_in_ms=5000 ) macaroon = pymacaroons.Macaroon.deserialize(token) @@ -213,6 +213,6 @@ class AuthTestCase(unittest.HomeserverTestCase): def _get_macaroon(self): token = self.macaroon_generator.generate_short_term_login_token( - self.user1, "", 5000 + self.user1, "", duration_in_ms=5000 ) return pymacaroons.Macaroon.deserialize(token) diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py index b625995d12..8705ff8943 100644 --- a/tests/handlers/test_cas.py +++ b/tests/handlers/test_cas.py @@ -66,7 +66,13 @@ class CasHandlerTestCase(HomeserverTestCase): # check that the auth handler got called as expected auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", "cas", request, "redirect_uri", None, new_user=True + "@test_user:test", + "cas", + request, + "redirect_uri", + None, + new_user=True, + auth_provider_session_id=None, ) def test_map_cas_user_to_existing_user(self): @@ -89,7 +95,13 @@ class CasHandlerTestCase(HomeserverTestCase): # check that the auth handler got called as expected auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", "cas", request, "redirect_uri", None, new_user=False + "@test_user:test", + "cas", + request, + "redirect_uri", + None, + new_user=False, + auth_provider_session_id=None, ) # Subsequent calls should map to the same mxid. @@ -98,7 +110,13 @@ class CasHandlerTestCase(HomeserverTestCase): self.handler._handle_cas_response(request, cas_response, "redirect_uri", "") ) auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", "cas", request, "redirect_uri", None, new_user=False + "@test_user:test", + "cas", + request, + "redirect_uri", + None, + new_user=False, + auth_provider_session_id=None, ) def test_map_cas_user_to_invalid_localpart(self): @@ -116,7 +134,13 @@ class CasHandlerTestCase(HomeserverTestCase): # check that the auth handler got called as expected auth_handler.complete_sso_login.assert_called_once_with( - "@f=c3=b6=c3=b6:test", "cas", request, "redirect_uri", None, new_user=True + "@f=c3=b6=c3=b6:test", + "cas", + request, + "redirect_uri", + None, + new_user=True, + auth_provider_session_id=None, ) @override_config( @@ -160,7 +184,13 @@ class CasHandlerTestCase(HomeserverTestCase): # check that the auth handler got called as expected auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", "cas", request, "redirect_uri", None, new_user=True + "@test_user:test", + "cas", + request, + "redirect_uri", + None, + new_user=True, + auth_provider_session_id=None, ) diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index a25c89bd5b..cfe3de5266 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -252,13 +252,6 @@ class OidcHandlerTestCase(HomeserverTestCase): with patch.object(self.provider, "load_metadata", patched_load_metadata): self.get_failure(self.provider.load_jwks(force=True), RuntimeError) - # Return empty key set if JWKS are not used - self.provider._scopes = [] # not asking the openid scope - self.http_client.get_json.reset_mock() - jwks = self.get_success(self.provider.load_jwks(force=True)) - self.http_client.get_json.assert_not_called() - self.assertEqual(jwks, {"keys": []}) - @override_config({"oidc_config": DEFAULT_CONFIG}) def test_validate_config(self): """Provider metadatas are extensively validated.""" @@ -455,7 +448,13 @@ class OidcHandlerTestCase(HomeserverTestCase): self.get_success(self.handler.handle_oidc_callback(request)) auth_handler.complete_sso_login.assert_called_once_with( - expected_user_id, "oidc", request, client_redirect_url, None, new_user=True + expected_user_id, + "oidc", + request, + client_redirect_url, + None, + new_user=True, + auth_provider_session_id=None, ) self.provider._exchange_code.assert_called_once_with(code) self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce) @@ -482,17 +481,58 @@ class OidcHandlerTestCase(HomeserverTestCase): self.provider._fetch_userinfo.reset_mock() # With userinfo fetching - self.provider._scopes = [] # do not ask the "openid" scope + self.provider._user_profile_method = "userinfo_endpoint" + token = { + "type": "bearer", + "access_token": "access_token", + } + self.provider._exchange_code = simple_async_mock(return_value=token) self.get_success(self.handler.handle_oidc_callback(request)) auth_handler.complete_sso_login.assert_called_once_with( - expected_user_id, "oidc", request, client_redirect_url, None, new_user=False + expected_user_id, + "oidc", + request, + client_redirect_url, + None, + new_user=False, + auth_provider_session_id=None, ) self.provider._exchange_code.assert_called_once_with(code) self.provider._parse_id_token.assert_not_called() self.provider._fetch_userinfo.assert_called_once_with(token) self.render_error.assert_not_called() + # With an ID token, userinfo fetching and sid in the ID token + self.provider._user_profile_method = "userinfo_endpoint" + token = { + "type": "bearer", + "access_token": "access_token", + "id_token": "id_token", + } + id_token = { + "sid": "abcdefgh", + } + self.provider._parse_id_token = simple_async_mock(return_value=id_token) + self.provider._exchange_code = simple_async_mock(return_value=token) + auth_handler.complete_sso_login.reset_mock() + self.provider._fetch_userinfo.reset_mock() + self.get_success(self.handler.handle_oidc_callback(request)) + + auth_handler.complete_sso_login.assert_called_once_with( + expected_user_id, + "oidc", + request, + client_redirect_url, + None, + new_user=False, + auth_provider_session_id=id_token["sid"], + ) + self.provider._exchange_code.assert_called_once_with(code) + self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce) + self.provider._fetch_userinfo.assert_called_once_with(token) + self.render_error.assert_not_called() + # Handle userinfo fetching error self.provider._fetch_userinfo = simple_async_mock(raises=Exception()) self.get_success(self.handler.handle_oidc_callback(request)) @@ -776,6 +816,7 @@ class OidcHandlerTestCase(HomeserverTestCase): client_redirect_url, {"phone": "1234567"}, new_user=True, + auth_provider_session_id=None, ) @override_config({"oidc_config": DEFAULT_CONFIG}) @@ -790,7 +831,13 @@ class OidcHandlerTestCase(HomeserverTestCase): } self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", "oidc", ANY, ANY, None, new_user=True + "@test_user:test", + "oidc", + ANY, + ANY, + None, + new_user=True, + auth_provider_session_id=None, ) auth_handler.complete_sso_login.reset_mock() @@ -801,7 +848,13 @@ class OidcHandlerTestCase(HomeserverTestCase): } self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) auth_handler.complete_sso_login.assert_called_once_with( - "@test_user_2:test", "oidc", ANY, ANY, None, new_user=True + "@test_user_2:test", + "oidc", + ANY, + ANY, + None, + new_user=True, + auth_provider_session_id=None, ) auth_handler.complete_sso_login.reset_mock() @@ -838,14 +891,26 @@ class OidcHandlerTestCase(HomeserverTestCase): } self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) auth_handler.complete_sso_login.assert_called_once_with( - user.to_string(), "oidc", ANY, ANY, None, new_user=False + user.to_string(), + "oidc", + ANY, + ANY, + None, + new_user=False, + auth_provider_session_id=None, ) auth_handler.complete_sso_login.reset_mock() # Subsequent calls should map to the same mxid. self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) auth_handler.complete_sso_login.assert_called_once_with( - user.to_string(), "oidc", ANY, ANY, None, new_user=False + user.to_string(), + "oidc", + ANY, + ANY, + None, + new_user=False, + auth_provider_session_id=None, ) auth_handler.complete_sso_login.reset_mock() @@ -860,7 +925,13 @@ class OidcHandlerTestCase(HomeserverTestCase): } self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) auth_handler.complete_sso_login.assert_called_once_with( - user.to_string(), "oidc", ANY, ANY, None, new_user=False + user.to_string(), + "oidc", + ANY, + ANY, + None, + new_user=False, + auth_provider_session_id=None, ) auth_handler.complete_sso_login.reset_mock() @@ -896,7 +967,13 @@ class OidcHandlerTestCase(HomeserverTestCase): self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) auth_handler.complete_sso_login.assert_called_once_with( - "@TEST_USER_2:test", "oidc", ANY, ANY, None, new_user=False + "@TEST_USER_2:test", + "oidc", + ANY, + ANY, + None, + new_user=False, + auth_provider_session_id=None, ) @override_config({"oidc_config": DEFAULT_CONFIG}) @@ -934,7 +1011,13 @@ class OidcHandlerTestCase(HomeserverTestCase): # test_user is already taken, so test_user1 gets registered instead. auth_handler.complete_sso_login.assert_called_once_with( - "@test_user1:test", "oidc", ANY, ANY, None, new_user=True + "@test_user1:test", + "oidc", + ANY, + ANY, + None, + new_user=True, + auth_provider_session_id=None, ) auth_handler.complete_sso_login.reset_mock() @@ -1018,7 +1101,13 @@ class OidcHandlerTestCase(HomeserverTestCase): # check that the auth handler got called as expected auth_handler.complete_sso_login.assert_called_once_with( - "@tester:test", "oidc", ANY, ANY, None, new_user=True + "@tester:test", + "oidc", + ANY, + ANY, + None, + new_user=True, + auth_provider_session_id=None, ) @override_config( @@ -1043,7 +1132,13 @@ class OidcHandlerTestCase(HomeserverTestCase): # check that the auth handler got called as expected auth_handler.complete_sso_login.assert_called_once_with( - "@tester:test", "oidc", ANY, ANY, None, new_user=True + "@tester:test", + "oidc", + ANY, + ANY, + None, + new_user=True, + auth_provider_session_id=None, ) @override_config( @@ -1156,7 +1251,7 @@ async def _make_callback_with_userinfo( handler = hs.get_oidc_handler() provider = handler._providers["oidc"] - provider._exchange_code = simple_async_mock(return_value={}) + provider._exchange_code = simple_async_mock(return_value={"id_token": ""}) provider._parse_id_token = simple_async_mock(return_value=userinfo) provider._fetch_userinfo = simple_async_mock(return_value=userinfo) diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py index 8cfc184fef..50551aa6e3 100644 --- a/tests/handlers/test_saml.py +++ b/tests/handlers/test_saml.py @@ -130,7 +130,13 @@ class SamlHandlerTestCase(HomeserverTestCase): # check that the auth handler got called as expected auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", "saml", request, "redirect_uri", None, new_user=True + "@test_user:test", + "saml", + request, + "redirect_uri", + None, + new_user=True, + auth_provider_session_id=None, ) @override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}}) @@ -156,7 +162,13 @@ class SamlHandlerTestCase(HomeserverTestCase): # check that the auth handler got called as expected auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", "saml", request, "", None, new_user=False + "@test_user:test", + "saml", + request, + "", + None, + new_user=False, + auth_provider_session_id=None, ) # Subsequent calls should map to the same mxid. @@ -165,7 +177,13 @@ class SamlHandlerTestCase(HomeserverTestCase): self.handler._handle_authn_response(request, saml_response, "") ) auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", "saml", request, "", None, new_user=False + "@test_user:test", + "saml", + request, + "", + None, + new_user=False, + auth_provider_session_id=None, ) def test_map_saml_response_to_invalid_localpart(self): @@ -213,7 +231,13 @@ class SamlHandlerTestCase(HomeserverTestCase): # test_user is already taken, so test_user1 gets registered instead. auth_handler.complete_sso_login.assert_called_once_with( - "@test_user1:test", "saml", request, "", None, new_user=True + "@test_user1:test", + "saml", + request, + "", + None, + new_user=True, + auth_provider_session_id=None, ) auth_handler.complete_sso_login.reset_mock() @@ -309,7 +333,13 @@ class SamlHandlerTestCase(HomeserverTestCase): # check that the auth handler got called as expected auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", "saml", request, "redirect_uri", None, new_user=True + "@test_user:test", + "saml", + request, + "redirect_uri", + None, + new_user=True, + auth_provider_session_id=None, ) -- cgit 1.5.1 From 9c55dedc8c4484e6269451a8c3c10b3e314aeb4a Mon Sep 17 00:00:00 2001 From: David Robertson Date: Tue, 7 Dec 2021 11:24:31 +0000 Subject: Correctly ignore invites from ignored users (#11511) --- changelog.d/11511.bugfix | 1 + synapse/handlers/sync.py | 11 ++++++----- 2 files changed, 7 insertions(+), 5 deletions(-) create mode 100644 changelog.d/11511.bugfix (limited to 'synapse/handlers') diff --git a/changelog.d/11511.bugfix b/changelog.d/11511.bugfix new file mode 100644 index 0000000000..9c287357b8 --- /dev/null +++ b/changelog.d/11511.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where invites from ignored users were included in incremental syncs. \ No newline at end of file diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 53d4627147..97d5a26e20 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -1771,6 +1771,7 @@ class SyncHandler: if not non_joins: continue + last_non_join = non_joins[-1] # Check if we have left the room. This can either be because we were # joined before *or* that we since joined and then left. @@ -1792,18 +1793,18 @@ class SyncHandler: newly_left_rooms.append(room_id) # Only bother if we're still currently invited - should_invite = non_joins[-1].membership == Membership.INVITE + should_invite = last_non_join.membership == Membership.INVITE if should_invite: - if event.sender not in ignored_users: - invite_room_sync = InvitedSyncResult(room_id, invite=non_joins[-1]) + if last_non_join.sender not in ignored_users: + invite_room_sync = InvitedSyncResult(room_id, invite=last_non_join) if invite_room_sync: invited.append(invite_room_sync) # Only bother if our latest membership in the room is knock (and we haven't # been accepted/rejected in the meantime). - should_knock = non_joins[-1].membership == Membership.KNOCK + should_knock = last_non_join.membership == Membership.KNOCK if should_knock: - knock_room_sync = KnockedSyncResult(room_id, knock=non_joins[-1]) + knock_room_sync = KnockedSyncResult(room_id, knock=last_non_join) if knock_room_sync: knocked.append(knock_room_sync) -- cgit 1.5.1 From b1ecd19c5d19815b69e425d80f442bf2877cab76 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 7 Dec 2021 11:37:54 +0000 Subject: Fix 'delete room' admin api to work on incomplete rooms (#11523) If, for some reason, we don't have the create event, we should still be able to purge a room. --- changelog.d/11523.feature | 1 + synapse/handlers/pagination.py | 3 --- synapse/handlers/room.py | 21 +++++++-------------- synapse/rest/admin/rooms.py | 3 --- tests/rest/admin/test_room.py | 42 +++++++++++++++++++++++++----------------- 5 files changed, 33 insertions(+), 37 deletions(-) create mode 100644 changelog.d/11523.feature (limited to 'synapse/handlers') diff --git a/changelog.d/11523.feature b/changelog.d/11523.feature new file mode 100644 index 0000000000..ecac7f9db9 --- /dev/null +++ b/changelog.d/11523.feature @@ -0,0 +1 @@ +Extend the "delete room" admin api to work correctly on rooms which have previously been partially deleted. diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index cd64142735..4f42438053 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -406,9 +406,6 @@ class PaginationHandler: force: set true to skip checking for joined users. """ with await self.pagination_lock.write(room_id): - # check we know about the room - await self.store.get_room_version_id(room_id) - # first check that we have no users in this room if not force: joined = await self.store.is_host_joined(room_id, self._server_name) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 2bcdf32dcc..ead2198e14 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -1535,20 +1535,13 @@ class RoomShutdownHandler: await self.store.block_room(room_id, requester_user_id) if not await self.store.get_room(room_id): - if block: - # We allow you to block an unknown room. - return { - "kicked_users": [], - "failed_to_kick_users": [], - "local_aliases": [], - "new_room_id": None, - } - else: - # But if you don't want to preventatively block another room, - # this function can't do anything useful. - raise NotFoundError( - "Cannot shut down room: unknown room id %s" % (room_id,) - ) + # if we don't know about the room, there is nothing left to do. + return { + "kicked_users": [], + "failed_to_kick_users": [], + "local_aliases": [], + "new_room_id": None, + } if new_room_user_id is not None: if not self.hs.is_mine_id(new_room_user_id): diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 669ab44a45..829e86675a 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -106,9 +106,6 @@ class RoomRestV2Servlet(RestServlet): HTTPStatus.BAD_REQUEST, "%s is not a legal room ID" % (room_id,) ) - if not await self._store.get_room(room_id): - raise NotFoundError("Unknown room id %s" % (room_id,)) - delete_id = self._pagination_handler.start_shutdown_and_purge_room( room_id=room_id, new_room_user_id=content.get("new_room_user_id"), diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index d3858e460d..22f9aa6234 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -83,7 +83,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): def test_room_does_not_exist(self): """ - Check that unknown rooms/server return error HTTPStatus.NOT_FOUND. + Check that unknown rooms/server return 200 """ url = "/_synapse/admin/v1/rooms/%s" % "!unknown:test" @@ -94,8 +94,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) - self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) def test_room_is_not_valid(self): """ @@ -508,27 +507,36 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - @parameterized.expand( - [ - ("DELETE", "/_synapse/admin/v2/rooms/%s"), - ("GET", "/_synapse/admin/v2/rooms/%s/delete_status"), - ("GET", "/_synapse/admin/v2/rooms/delete_status/%s"), - ] - ) - def test_room_does_not_exist(self, method: str, url: str): - """ - Check that unknown rooms/server return error HTTPStatus.NOT_FOUND. + def test_room_does_not_exist(self): """ + Check that unknown rooms/server return 200 + This is important, as it allows incomplete vestiges of rooms to be cleared up + even if the create event/etc is missing. + """ + room_id = "!unknown:test" channel = self.make_request( - method, - url % "!unknown:test", + "DELETE", + f"/_synapse/admin/v2/rooms/{room_id}", content={}, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) - self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertIn("delete_id", channel.json_body) + delete_id = channel.json_body["delete_id"] + + # get status + channel = self.make_request( + "GET", + f"/_synapse/admin/v2/rooms/{room_id}/delete_status", + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(1, len(channel.json_body["results"])) + self.assertEqual("complete", channel.json_body["results"][0]["status"]) + self.assertEqual(delete_id, channel.json_body["results"][0]["delete_id"]) @parameterized.expand( [ -- cgit 1.5.1 From 2a3ec6facf79f6aae011d9fb6f9ed5e43c7b6bec Mon Sep 17 00:00:00 2001 From: David Robertson Date: Tue, 7 Dec 2021 12:34:38 +0000 Subject: Correctly register shutdown handler for presence workers (#11518) Fixes #11517 --- changelog.d/11518.bugfix | 1 + synapse/handlers/presence.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 changelog.d/11518.bugfix (limited to 'synapse/handlers') diff --git a/changelog.d/11518.bugfix b/changelog.d/11518.bugfix new file mode 100644 index 0000000000..3c27382c7a --- /dev/null +++ b/changelog.d/11518.bugfix @@ -0,0 +1 @@ +Fix a regression in Synapse 1.48.0 where presence workers would not clear their presence updates over replication on shutdown. diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 3df872c578..454d06c973 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -421,7 +421,7 @@ class WorkerPresenceHandler(BasePresenceHandler): self._on_shutdown, ) - def _on_shutdown(self) -> None: + async def _on_shutdown(self) -> None: if self._presence_enabled: self.hs.get_tcp_replication().send_command( ClearUserSyncsCommand(self.instance_id) -- cgit 1.5.1 From 14d593f72d10b4d8cb67e3288bb3131ee30ccf59 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Tue, 7 Dec 2021 12:42:05 +0000 Subject: Refactors in `_generate_sync_entry_for_rooms` (#11515) * Move sync_token up to the top * Pull out _get_ignored_users * Try to signpost the body of `_generate_sync_entry_for_rooms` * Pull out _calculate_user_changes Co-authored-by: Patrick Cloke --- changelog.d/11494.misc | 2 +- changelog.d/11515.misc | 1 + synapse/handlers/sync.py | 122 ++++++++++++++++++++++++++++++----------------- 3 files changed, 79 insertions(+), 46 deletions(-) create mode 100644 changelog.d/11515.misc (limited to 'synapse/handlers') diff --git a/changelog.d/11494.misc b/changelog.d/11494.misc index 7afd7d3a07..52cf26a4b6 100644 --- a/changelog.d/11494.misc +++ b/changelog.d/11494.misc @@ -1 +1 @@ -Add comments to various parts of the `/sync` handler. \ No newline at end of file +Refactor various parts of the `/sync` handler. \ No newline at end of file diff --git a/changelog.d/11515.misc b/changelog.d/11515.misc new file mode 100644 index 0000000000..7f9d8cd57f --- /dev/null +++ b/changelog.d/11515.misc @@ -0,0 +1 @@ +Refactor various parts of the `/sync` handler. diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 97d5a26e20..f3039c3c3f 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -1506,16 +1506,22 @@ class SyncHandler: account_data_by_room: Dictionary of per room account data Returns: - Returns a 4-tuple whose entries are: + Returns a 4-tuple describing rooms the user has joined or left, and users who've + joined or left rooms any rooms the user is in. This gets used later in + `_generate_sync_entry_for_device_list`. + + Its entries are: - newly_joined_rooms - newly_joined_or_invited_or_knocked_users - newly_left_rooms - newly_left_users """ - # Start by fetching all ephemeral events in rooms we've joined (if required). + since_token = sync_result_builder.since_token + + # 1. Start by fetching all ephemeral events in rooms we've joined (if required). user_id = sync_result_builder.sync_config.user.to_string() block_all_room_ephemeral = ( - sync_result_builder.since_token is None + since_token is None and sync_result_builder.sync_config.filter_collection.blocks_all_room_ephemeral() ) @@ -1529,9 +1535,8 @@ class SyncHandler: ) sync_result_builder.now_token = now_token - # We check up front if anything has changed, if it hasn't then there is + # 2. We check up front if anything has changed, if it hasn't then there is # no point in going further. - since_token = sync_result_builder.since_token if not sync_result_builder.full_state: if since_token and not ephemeral_by_room and not account_data_by_room: have_changed = await self._have_rooms_changed(sync_result_builder) @@ -1544,20 +1549,8 @@ class SyncHandler: logger.debug("no-oping sync") return set(), set(), set(), set() - ignored_account_data = ( - await self.store.get_global_account_data_by_type_for_user( - AccountDataTypes.IGNORED_USER_LIST, user_id=user_id - ) - ) - - # If there is ignored users account data and it matches the proper type, - # then use it. - ignored_users: FrozenSet[str] = frozenset() - if ignored_account_data: - ignored_users_data = ignored_account_data.get("ignored_users", {}) - if isinstance(ignored_users_data, dict): - ignored_users = frozenset(ignored_users_data.keys()) - + # 3. Work out which rooms need reporting in the sync response. + ignored_users = await self._get_ignored_users(user_id) if since_token: room_changes = await self._get_rooms_changed( sync_result_builder, ignored_users @@ -1567,7 +1560,6 @@ class SyncHandler: ) else: room_changes = await self._get_all_rooms(sync_result_builder, ignored_users) - tags_by_room = await self.store.get_tags_for_user(user_id) log_kv({"rooms_changed": len(room_changes.room_entries)}) @@ -1578,6 +1570,8 @@ class SyncHandler: newly_joined_rooms = room_changes.newly_joined_rooms newly_left_rooms = room_changes.newly_left_rooms + # 4. We need to apply further processing to `room_entries` (rooms considered + # joined or archived). async def handle_room_entries(room_entry: "RoomSyncResultBuilder") -> None: logger.debug("Generating room entry for %s", room_entry.room_id) await self._generate_room_entry( @@ -1596,31 +1590,13 @@ class SyncHandler: sync_result_builder.invited.extend(invited) sync_result_builder.knocked.extend(knocked) - # Now we want to get any newly joined, invited or knocking users - newly_joined_or_invited_or_knocked_users = set() - newly_left_users = set() - if since_token: - for joined_sync in sync_result_builder.joined: - it = itertools.chain( - joined_sync.timeline.events, joined_sync.state.values() - ) - for event in it: - if event.type == EventTypes.Member: - if ( - event.membership == Membership.JOIN - or event.membership == Membership.INVITE - or event.membership == Membership.KNOCK - ): - newly_joined_or_invited_or_knocked_users.add( - event.state_key - ) - else: - prev_content = event.unsigned.get("prev_content", {}) - prev_membership = prev_content.get("membership", None) - if prev_membership == Membership.JOIN: - newly_left_users.add(event.state_key) - - newly_left_users -= newly_joined_or_invited_or_knocked_users + # 5. Work out which users have joined or left rooms we're in. We use this + # to build the device_list part of the sync response in + # `_generate_sync_entry_for_device_list`. + ( + newly_joined_or_invited_or_knocked_users, + newly_left_users, + ) = sync_result_builder.calculate_user_changes() return ( set(newly_joined_rooms), @@ -1629,6 +1605,29 @@ class SyncHandler: newly_left_users, ) + async def _get_ignored_users(self, user_id: str) -> FrozenSet[str]: + """Retrieve the users ignored by the given user from their global account_data. + + Returns an empty set if + - there is no global account_data entry for ignored_users + - there is such an entry, but it's not a JSON object. + """ + # TODO: Can we `SELECT ignored_user_id FROM ignored_users WHERE ignorer_user_id=?;` instead? + ignored_account_data = ( + await self.store.get_global_account_data_by_type_for_user( + AccountDataTypes.IGNORED_USER_LIST, user_id=user_id + ) + ) + + # If there is ignored users account data and it matches the proper type, + # then use it. + ignored_users: FrozenSet[str] = frozenset() + if ignored_account_data: + ignored_users_data = ignored_account_data.get("ignored_users", {}) + if isinstance(ignored_users_data, dict): + ignored_users = frozenset(ignored_users_data.keys()) + return ignored_users + async def _have_rooms_changed( self, sync_result_builder: "SyncResultBuilder" ) -> bool: @@ -2341,6 +2340,39 @@ class SyncResultBuilder: groups: Optional[GroupsSyncResult] = None to_device: List[JsonDict] = attr.Factory(list) + def calculate_user_changes(self) -> Tuple[Set[str], Set[str]]: + """Work out which other users have joined or left rooms we are joined to. + + This data only is only useful for an incremental sync. + + The SyncResultBuilder is not modified by this function. + """ + newly_joined_or_invited_or_knocked_users = set() + newly_left_users = set() + if self.since_token: + for joined_sync in self.joined: + it = itertools.chain( + joined_sync.timeline.events, joined_sync.state.values() + ) + for event in it: + if event.type == EventTypes.Member: + if ( + event.membership == Membership.JOIN + or event.membership == Membership.INVITE + or event.membership == Membership.KNOCK + ): + newly_joined_or_invited_or_knocked_users.add( + event.state_key + ) + else: + prev_content = event.unsigned.get("prev_content", {}) + prev_membership = prev_content.get("membership", None) + if prev_membership == Membership.JOIN: + newly_left_users.add(event.state_key) + + newly_left_users -= newly_joined_or_invited_or_knocked_users + return newly_joined_or_invited_or_knocked_users, newly_left_users + @attr.s(slots=True, auto_attribs=True) class RoomSyncResultBuilder: -- cgit 1.5.1 From 8541809cb952ebf0da2a95dd93eccd5644dab49d Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Wed, 8 Dec 2021 05:01:38 -0500 Subject: Send and handle cross-signing messages using the stable prefix. (#10520) --- changelog.d/10520.misc | 1 + synapse/handlers/e2e_keys.py | 8 ++++++-- synapse/storage/databases/main/devices.py | 4 +++- tests/federation/test_federation_sender.py | 5 +++-- 4 files changed, 13 insertions(+), 5 deletions(-) create mode 100644 changelog.d/10520.misc (limited to 'synapse/handlers') diff --git a/changelog.d/10520.misc b/changelog.d/10520.misc new file mode 100644 index 0000000000..a911e165da --- /dev/null +++ b/changelog.d/10520.misc @@ -0,0 +1 @@ +Send and handle cross-signing messages using the stable prefix. diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 60c11e3d21..b2554bda04 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -65,8 +65,12 @@ class E2eKeysHandler: else: # Only register this edu handler on master as it requires writing # device updates to the db - # - # FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec + federation_registry.register_edu_handler( + "m.signing_key_update", + self._edu_updater.incoming_signing_key_update, + ) + # also handle the unstable version + # FIXME: remove this when enough servers have upgraded federation_registry.register_edu_handler( "org.matrix.signing_key_update", self._edu_updater.incoming_signing_key_update, diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index d5a4a661cd..838a2a6a3d 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -274,7 +274,9 @@ class DeviceWorkerStore(SQLBaseStore): # add the updated cross-signing keys to the results list for user_id, result in cross_signing_keys_by_user.items(): result["user_id"] = user_id - # FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec + results.append(("m.signing_key_update", result)) + # also send the unstable version + # FIXME: remove this when enough servers have upgraded results.append(("org.matrix.signing_key_update", result)) return now_stream_id, results diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index b457dad6d2..b2376e2db9 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -266,7 +266,8 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): ) # expect signing key update edu - self.assertEqual(len(self.edus), 1) + self.assertEqual(len(self.edus), 2) + self.assertEqual(self.edus.pop(0)["edu_type"], "m.signing_key_update") self.assertEqual(self.edus.pop(0)["edu_type"], "org.matrix.signing_key_update") # sign the devices @@ -491,7 +492,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): ) -> None: """Check that the txn has an EDU with a signing key update.""" edus = txn["edus"] - self.assertEqual(len(edus), 1) + self.assertEqual(len(edus), 2) def generate_and_upload_device_signing_key( self, user_id: str, device_id: str -- cgit 1.5.1