25 files changed, 1041 insertions, 134 deletions
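In short: this change teaches the experimental MSC3575 Sliding Sync endpoint to remember, per connection, which rooms it has already sent to a client, so that incremental requests can return only state updates for rooms the client has seen and full state only for rooms it has not. To carry that per-connection state across requests, the `pos` token gains a leading connection counter separated from the ordinary stream token by a '/'. Below is a minimal, illustrative sketch of how such a two-part token round-trips; SimpleSlidingSyncToken is a stand-in used only for this summary, while the real class is SlidingSyncStreamToken in synapse/types/__init__.py, whose from_string/to_string are async because resolving a StreamToken touches the database.

    # Stand-in for illustration; not part of the diff below.
    from dataclasses import dataclass

    @dataclass(frozen=True)
    class SimpleSlidingSyncToken:
        connection_token: int  # per-connection counter maintained by the server
        stream_token: str      # ordinary sync stream token, treated as opaque here

        @classmethod
        def from_string(cls, s: str) -> "SimpleSlidingSyncToken":
            # Split on the first '/' only; everything after it is the stream token.
            conn_str, stream_str = s.split("/", 1)
            return cls(connection_token=int(conn_str), stream_token=stream_str)

        def to_string(self) -> str:
            return f"{self.connection_token}/{self.stream_token}"

    token = SimpleSlidingSyncToken.from_string(
        "5/s2633508_17_338_6732159_1082514_541479_274711_265584_1_379"
    )
    assert token.connection_token == 5
    assert token.to_string().startswith("5/")

The diff itself follows.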
diff --git a/changelog.d/17447.feature b/changelog.d/17447.feature new file mode 100644 index 0000000000..6f80e298ae --- /dev/null +++ b/changelog.d/17447.feature @@ -0,0 +1 @@ +Track which rooms have been sent to clients in the experimental [MSC3575](https://github.com/matrix-org/matrix-spec-proposals/pull/3575) Sliding Sync `/sync` endpoint. diff --git a/changelog.d/17452.misc b/changelog.d/17452.misc new file mode 100644 index 0000000000..4fd07f617b --- /dev/null +++ b/changelog.d/17452.misc @@ -0,0 +1 @@ +Change sliding sync to use their own token format in preparation for storing per-connection state. diff --git a/pyproject.toml b/pyproject.toml index 0f040fc612..1f9ee2b944 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -201,8 +201,8 @@ netaddr = ">=0.7.18" # add a lower bound to the Jinja2 dependency. Jinja2 = ">=3.0" bleach = ">=1.4.3" -# We use `Self`, which were added in `typing-extensions` 4.0. -typing-extensions = ">=4.0" +# We use `assert_never`, which were added in `typing-extensions` 4.1. +typing-extensions = ">=4.1" # We enforce that we have a `cryptography` version that bundles an `openssl` # with the latest security patches. cryptography = ">=3.4.7" diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index ec35784c5f..c04039a573 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -125,7 +125,7 @@ class AdminHandler: # Get all rooms the user is in or has been in rooms = await self._store.get_rooms_for_local_user_where_membership_is( user_id, - membership_list=Membership.LIST, + membership_list=frozenset(Membership.LIST), ) # We only try and fetch events for rooms the user has been in. If diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py index 99f9f6e64a..f397911f28 100644 --- a/synapse/handlers/e2e_room_keys.py +++ b/synapse/handlers/e2e_room_keys.py @@ -34,7 +34,7 @@ from synapse.api.errors import ( from synapse.logging.opentracing import log_kv, trace from synapse.storage.databases.main.e2e_room_keys import RoomKey from synapse.types import JsonDict -from synapse.util.async_helpers import Linearizer +from synapse.util.async_helpers import ReadWriteLock if TYPE_CHECKING: from synapse.server import HomeServer @@ -58,7 +58,7 @@ class E2eRoomKeysHandler: # clients belonging to a user will receive and try to upload a new session at # roughly the same time. Also used to lock out uploads when the key is being # changed. - self._upload_linearizer = Linearizer("upload_room_keys_lock") + self._upload_lock = ReadWriteLock() @trace async def get_room_keys( @@ -89,7 +89,7 @@ class E2eRoomKeysHandler: # we deliberately take the lock to get keys so that changing the version # works atomically - async with self._upload_linearizer.queue(user_id): + async with self._upload_lock.read(user_id): # make sure the backup version exists try: await self.store.get_e2e_room_keys_version_info(user_id, version) @@ -132,7 +132,7 @@ class E2eRoomKeysHandler: """ # lock for consistency with uploading - async with self._upload_linearizer.queue(user_id): + async with self._upload_lock.write(user_id): # make sure the backup version exists try: version_info = await self.store.get_e2e_room_keys_version_info( @@ -193,7 +193,7 @@ class E2eRoomKeysHandler: # TODO: Validate the JSON to make sure it has the right keys. # XXX: perhaps we should use a finer grained lock here? 
- async with self._upload_linearizer.queue(user_id): + async with self._upload_lock.write(user_id): # Check that the version we're trying to upload is the current version try: version_info = await self.store.get_e2e_room_keys_version_info(user_id) @@ -355,7 +355,7 @@ class E2eRoomKeysHandler: # TODO: Validate the JSON to make sure it has the right keys. # lock everyone out until we've switched version - async with self._upload_linearizer.queue(user_id): + async with self._upload_lock.write(user_id): new_version = await self.store.create_e2e_room_keys_version( user_id, version_info ) @@ -382,7 +382,7 @@ class E2eRoomKeysHandler: } """ - async with self._upload_linearizer.queue(user_id): + async with self._upload_lock.read(user_id): try: res = await self.store.get_e2e_room_keys_version_info(user_id, version) except StoreError as e: @@ -407,7 +407,7 @@ class E2eRoomKeysHandler: NotFoundError: if this backup version doesn't exist """ - async with self._upload_linearizer.queue(user_id): + async with self._upload_lock.write(user_id): try: await self.store.delete_e2e_room_keys_version(user_id, version) except StoreError as e: @@ -437,7 +437,7 @@ class E2eRoomKeysHandler: raise SynapseError( 400, "Version in body does not match", Codes.INVALID_PARAM ) - async with self._upload_linearizer.queue(user_id): + async with self._upload_lock.write(user_id): try: old_info = await self.store.get_e2e_room_keys_version_info( user_id, version diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index bd3c87f5f4..7ab4f2b67d 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -135,7 +135,7 @@ class InitialSyncHandler: memberships.append(Membership.LEAVE) room_list = await self.store.get_rooms_for_local_user_where_membership_is( - user_id=user_id, membership_list=memberships + user_id=user_id, membership_list=frozenset(memberships) ) user = UserID.from_string(user_id) diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index a7d52fa648..8066132527 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -278,7 +278,7 @@ class SearchHandler: # TODO: Search through left rooms too rooms = await self.store.get_rooms_for_local_user_where_membership_is( requester.user.to_string(), - membership_list=[Membership.JOIN], + membership_list=(Membership.JOIN,), # membership_list=[Membership.JOIN, Membership.LEAVE, Membership.Ban], ) room_ids = {r.room_id for r in rooms} diff --git a/synapse/handlers/sliding_sync.py b/synapse/handlers/sliding_sync.py index 06d7d43f0b..b07b62a8fc 100644 --- a/synapse/handlers/sliding_sync.py +++ b/synapse/handlers/sliding_sync.py @@ -18,11 +18,13 @@ # # import logging +from enum import Enum from itertools import chain from typing import TYPE_CHECKING, Any, Dict, Final, List, Mapping, Optional, Set, Tuple import attr from immutabledict import immutabledict +from typing_extensions import assert_never from synapse.api.constants import AccountDataTypes, Direction, EventTypes, Membership from synapse.events import EventBase @@ -37,7 +39,9 @@ from synapse.types import ( PersistedEventPosition, Requester, RoomStreamToken, + SlidingSyncStreamToken, StateMap, + StrCollection, StreamKeyType, StreamToken, UserID, @@ -330,6 +334,8 @@ class StateValues: # `state_key`. 
LAZY: Final = "$LAZY" + ME: Final = "$ME" + class SlidingSyncHandler: def __init__(self, hs: "HomeServer"): @@ -342,11 +348,13 @@ class SlidingSyncHandler: self.relations_handler = hs.get_relations_handler() self.rooms_to_exclude_globally = hs.config.server.rooms_to_exclude_from_sync + self.connection_store = SlidingSyncConnectionStore() + async def wait_for_sync_for_user( self, requester: Requester, sync_config: SlidingSyncConfig, - from_token: Optional[StreamToken] = None, + from_token: Optional[SlidingSyncStreamToken] = None, timeout_ms: int = 0, ) -> SlidingSyncResult: """ @@ -381,7 +389,7 @@ class SlidingSyncHandler: # this returns false, it means we timed out waiting, and we should # just return an empty response. before_wait_ts = self.clock.time_msec() - if not await self.notifier.wait_for_stream_token(from_token): + if not await self.notifier.wait_for_stream_token(from_token.stream_token): logger.warning( "Timed out waiting for worker to catch up. Returning empty response" ) @@ -419,16 +427,17 @@ class SlidingSyncHandler: sync_config.user.to_string(), timeout_ms, current_sync_callback, - from_token=from_token, + from_token=from_token.stream_token, ) return result + @trace async def current_sync_for_user( self, sync_config: SlidingSyncConfig, to_token: StreamToken, - from_token: Optional[StreamToken] = None, + from_token: Optional[SlidingSyncStreamToken] = None, ) -> SlidingSyncResult: """ Generates the response body of a Sliding Sync result, represented as a @@ -449,6 +458,12 @@ class SlidingSyncHandler: # See https://github.com/matrix-org/matrix-doc/issues/1144 raise NotImplementedError() + await self.connection_store.mark_token_seen( + user_id, + conn_id=sync_config.connection_id(), + from_token=from_token, + ) + # Get all of the room IDs that the user should be able to see in the sync # response has_lists = sync_config.lists is not None and len(sync_config.lists) > 0 @@ -461,7 +476,7 @@ class SlidingSyncHandler: await self.get_room_membership_for_user_at_to_token( user=sync_config.user, to_token=to_token, - from_token=from_token, + from_token=from_token.stream_token if from_token else None, ) ) @@ -479,9 +494,14 @@ class SlidingSyncHandler: for list_key, list_config in sync_config.lists.items(): # Apply filters filtered_sync_room_map = sync_room_map - if list_config.filters is not None: + + if list_config.filters: + filtered_sync_room_map = await self.filter_rooms( - sync_config.user, sync_room_map, list_config.filters, to_token + sync_config.user, + filtered_sync_room_map, + list_config.filters, + to_token, ) # Sort the list @@ -504,7 +524,6 @@ class SlidingSyncHandler: # Also see `StateFilter.must_await_full_state(...)` for comparison lazy_loading = ( membership_state_keys is not None - and len(membership_state_keys) == 1 and StateValues.LAZY in membership_state_keys ) @@ -592,6 +611,31 @@ class SlidingSyncHandler: else: relevant_room_map[room_id] = room_sync_config + # Filter out rooms that haven't received updates and we've sent down + # previously. 
+ if from_token: + rooms_should_send = set() + for room_id in relevant_room_map: + status = await self.connection_store.have_sent_room( + user_id, + sync_config.connection_id(), + from_token.connection_token, + room_id, + ) + if status.status != HaveSentRoomFlag.LIVE: + rooms_should_send.add(room_id) + + # TODO: Also check current state delta stream + rooms_that_have_updates = ( + self.store._events_stream_cache.get_entities_changed( + relevant_room_map, from_token.stream_token.room_key.stream + ) + ) + rooms_should_send.update(rooms_that_have_updates) + relevant_room_map = { + r: c for r, c in relevant_room_map.items() if r in rooms_should_send + } + # Fetch room data rooms: Dict[str, SlidingSyncResult.RoomResult] = {} @@ -599,7 +643,7 @@ class SlidingSyncHandler: @tag_args async def handle_room(room_id: str) -> None: room_sync_result = await self.get_room_sync_data( - user=sync_config.user, + sync_config=sync_config, room_id=room_id, room_sync_config=relevant_room_map[room_id], room_membership_for_user_at_to_token=room_membership_for_user_map[ @@ -618,8 +662,21 @@ class SlidingSyncHandler: sync_config=sync_config, to_token=to_token ) + if has_lists or has_room_subscriptions: + connection_token = await self.connection_store.record_rooms( + user_id, + conn_id=sync_config.connection_id(), + from_token=from_token, + sent_room_ids=relevant_room_map.keys(), + unsent_room_ids=[], # TODO: We currently ssume that we have sent down all updates. + ) + elif from_token: + connection_token = from_token.connection_token + else: + connection_token = 0 + return SlidingSyncResult( - next_pos=to_token, + next_pos=SlidingSyncStreamToken(to_token, connection_token), lists=lists, rooms=rooms, extensions=extensions, @@ -656,13 +713,12 @@ class SlidingSyncHandler: # First grab a current snapshot rooms for the user # (also handles forgotten rooms) - room_for_user_list = await self.store.get_rooms_for_local_user_where_membership_is( - user_id=user_id, - # We want to fetch any kind of membership (joined and left rooms) in order - # to get the `event_pos` of the latest room membership event for the - # user. - membership_list=Membership.LIST, - excluded_rooms=self.rooms_to_exclude_globally, + room_for_user_list = ( + await self.store.get_rooms_for_local_user_where_membership_is( + user_id=user_id, + membership_list=frozenset(Membership.LIST), + excluded_rooms=self.rooms_to_exclude_globally, + ) ) # If the user has never joined any rooms before, we can just return an empty list @@ -1073,6 +1129,7 @@ class SlidingSyncHandler: # return None + @trace async def filter_rooms( self, user: UserID, @@ -1196,6 +1253,7 @@ class SlidingSyncHandler: # Assemble a new sync room map but only with the `filtered_room_id_set` return {room_id: sync_room_map[room_id] for room_id in filtered_room_id_set} + @trace async def sort_rooms( self, sync_room_map: Dict[str, _RoomMembershipForUser], @@ -1350,11 +1408,11 @@ class SlidingSyncHandler: async def get_room_sync_data( self, - user: UserID, + sync_config: SlidingSyncConfig, room_id: str, room_sync_config: RoomSyncConfig, room_membership_for_user_at_to_token: _RoomMembershipForUser, - from_token: Optional[StreamToken], + from_token: Optional[SlidingSyncStreamToken], to_token: StreamToken, ) -> SlidingSyncResult.RoomResult: """ @@ -1372,6 +1430,38 @@ class SlidingSyncHandler: from_token: The point in the stream to sync from. to_token: The point in the stream to sync up to. """ + user = sync_config.user + + # Determine whether we should limit the timeline to the token range. 
+ # + # We should return historical messages (before token range) in the + # following cases because we want clients to be able to show a basic + # screen of information: + # - Initial sync (because no `from_token` to limit us anyway) + # - When users `newly_joined` + # - For an incremental sync where we haven't sent it down this + # connection before + to_bound = None + initial = True + if from_token and not room_membership_for_user_at_to_token.newly_joined: + room_status = await self.connection_store.have_sent_room( + user_id=user.to_string(), + conn_id=sync_config.connection_id(), + connection_token=from_token.connection_token, + room_id=room_id, + ) + if room_status.status == HaveSentRoomFlag.LIVE: + to_bound = from_token.stream_token.room_key + initial = False + elif room_status.status == HaveSentRoomFlag.PREVIOUSLY: + assert room_status.last_token is not None + to_bound = room_status.last_token + initial = False + elif room_status.status == HaveSentRoomFlag.NEVER: + to_bound = None + initial = True + else: + assert_never(room_status.status) # Assemble the list of timeline events # @@ -1408,21 +1498,9 @@ class SlidingSyncHandler: room_membership_for_user_at_to_token.event_pos.to_room_stream_token() ) - # Determine whether we should limit the timeline to the token range. - # - # We should return historical messages (before token range) in the - # following cases because we want clients to be able to show a basic - # screen of information: - # - Initial sync (because no `from_token` to limit us anyway) - # - When users `newly_joined` - # - TODO: For an incremental sync where we haven't sent it down this - # connection before - to_bound = ( - from_token.room_key - if from_token is not None - and not room_membership_for_user_at_to_token.newly_joined - else None - ) + fiddled_timeline_limit = room_sync_config.timeline_limit + # if to_bound: + # fiddled_timeline_limit = max(fiddled_timeline_limit, 10) timeline_events, new_room_key = await self.store.paginate_room_events( room_id=room_id, @@ -1431,7 +1509,7 @@ class SlidingSyncHandler: direction=Direction.BACKWARDS, # We add one so we can determine if there are enough events to saturate # the limit or not (see `limited`) - limit=room_sync_config.timeline_limit + 1, + limit=fiddled_timeline_limit + 1, event_filter=None, ) @@ -1442,11 +1520,11 @@ class SlidingSyncHandler: # Determine our `limited` status based on the timeline. We do this before # filtering the events so we can accurately determine if there is more to # paginate even if we filter out some/all events. 
- if len(timeline_events) > room_sync_config.timeline_limit: + if len(timeline_events) > fiddled_timeline_limit: limited = True # Get rid of that extra "+ 1" event because we only used it to determine # if we hit the limit or not - timeline_events = timeline_events[-room_sync_config.timeline_limit :] + timeline_events = timeline_events[-fiddled_timeline_limit:] assert timeline_events[0].internal_metadata.stream_ordering new_room_key = RoomStreamToken( stream=timeline_events[0].internal_metadata.stream_ordering - 1 @@ -1485,7 +1563,9 @@ class SlidingSyncHandler: instance_name=timeline_event.internal_metadata.instance_name, stream=timeline_event.internal_metadata.stream_ordering, ) - if persisted_position.persisted_after(from_token.room_key): + if persisted_position.persisted_after( + from_token.stream_token.room_key + ): num_live += 1 else: # Since we're iterating over the timeline events in @@ -1542,12 +1622,6 @@ class SlidingSyncHandler: # indicate to the client that a state reset happened. Perhaps we should indicate # this by setting `initial: True` and empty `required_state`. - # TODO: Since we can't determine whether we've already sent a room down this - # Sliding Sync connection before (we plan to add this optimization in the - # future), we're always returning the requested room state instead of - # updates. - initial = True - # Check whether the room has a name set name_state_ids = await self.get_current_state_ids_at( room_id=room_id, @@ -1661,6 +1735,13 @@ class SlidingSyncHandler: # FIXME: We probably also care about invite, ban, kick, targets, etc # but the spec only mentions "senders". + elif ( + state_type == EventTypes.Member + and state_key == StateValues.ME + ): + required_state_types.append( + (EventTypes.Member, user.to_string()) + ) else: required_state_types.append((state_type, state_key)) @@ -1691,9 +1772,17 @@ class SlidingSyncHandler: to_token=to_token, ) else: - # TODO: Once we can figure out if we've sent a room down this connection before, - # we can return updates instead of the full required state. 
- raise NotImplementedError() + assert to_bound is not None + + deltas = await self.store.get_current_state_deltas_for_room( + room_id, to_bound, to_token.room_key + ) + # TODO: Filter room state before fetching events + # TODO: Handle state resets where event_id is None + events = await self.store.get_events( + [d.event_id for d in deltas if d.event_id] + ) + room_state = {(s.type, s.state_key): s for s in events.values()} required_room_state: StateMap[EventBase] = {} if required_state_filter != StateFilter.none(): @@ -1727,18 +1816,32 @@ class SlidingSyncHandler: ) # Figure out the last bump event in the room - last_bump_event_result = ( - await self.store.get_last_event_pos_in_room_before_stream_ordering( - room_id, to_token.room_key, event_types=DEFAULT_BUMP_EVENT_TYPES + last_bump_event_stream_ordering = None + if timeline_events: + for e in reversed(timeline_events): + if ( + e.type in DEFAULT_BUMP_EVENT_TYPES + and e.internal_metadata.stream_ordering > 0 + ): + last_bump_event_stream_ordering = ( + e.internal_metadata.stream_ordering + ) + break + + if last_bump_event_stream_ordering is None: + last_bump_event_result = ( + await self.store.get_last_event_pos_in_room_before_stream_ordering( + room_id, to_token.room_key, event_types=DEFAULT_BUMP_EVENT_TYPES + ) ) - ) + if last_bump_event_result is not None: + last_bump_event_stream_ordering = last_bump_event_result[1].stream # By default, just choose the membership event position bump_stamp = room_membership_for_user_at_to_token.event_pos.stream # But if we found a bump event, use that instead - if last_bump_event_result is not None: - _, bump_event_pos = last_bump_event_result - bump_stamp = bump_event_pos.stream + if last_bump_event_stream_ordering is not None: + bump_stamp = last_bump_event_stream_ordering return SlidingSyncResult.RoomResult( name=room_name, @@ -1807,7 +1910,7 @@ class SlidingSyncHandler: """ user_id = sync_config.user.to_string() - device_id = sync_config.device_id + device_id = sync_config.requester.device_id # Check that this request has a valid device ID (not all requests have # to belong to a device, and so device_id is None), and that the @@ -1844,12 +1947,13 @@ class SlidingSyncHandler: up_to_stream_id=since_stream_id, ) - logger.debug( - "Deleted %d to-device messages up to %d for %s", - deleted, - since_stream_id, - user_id, - ) + if deleted: + logger.debug( + "Deleted %d to-device messages up to %d for %s", + deleted, + since_stream_id, + user_id, + ) messages, stream_id = await self.store.get_messages_for_device( user_id=user_id, @@ -1863,3 +1967,198 @@ class SlidingSyncHandler: next_batch=f"{stream_id}", events=messages, ) + + +class HaveSentRoomFlag(Enum): + """Flag for whether we have sent the room down a sliding sync connection. + + The valid state changes here are: + NEVER -> LIVE + LIVE -> PREVIOUSLY + PREVIOUSLY -> LIVE + """ + + # The room has never been sent down (or we have forgotten we have sent it + # down). + NEVER = 1 + + # We have previously sent the room down, but there are updates that we + # haven't sent down. + PREVIOUSLY = 2 + + # We have sent the room down and the client has received all updates. + LIVE = 3 + + +@attr.s(auto_attribs=True, slots=True, frozen=True) +class HaveSentRoom: + """Whether we have sent the room down a sliding sync connection. + + Attributes: + status: Flag of if we have or haven't sent down the room + last_token: If the flag is `PREVIOUSLY` then this is non-null and + contains the last stream token of the last updates we sent down + the room, i.e. 
we still need to send everything since then to the + client. + """ + + status: HaveSentRoomFlag + last_token: Optional[RoomStreamToken] + + @staticmethod + def previously(last_token: RoomStreamToken) -> "HaveSentRoom": + """Constructor for `PREVIOUSLY` flag.""" + return HaveSentRoom(HaveSentRoomFlag.PREVIOUSLY, last_token) + + +HAVE_SENT_ROOM_NEVER = HaveSentRoom(HaveSentRoomFlag.NEVER, None) +HAVE_SENT_ROOM_LIVE = HaveSentRoom(HaveSentRoomFlag.LIVE, None) + + +@attr.s(auto_attribs=True) +class SlidingSyncConnectionStore: + """In-memory store of per-connection state, including what rooms we have + previously sent down a sliding sync connection. + + Note: This is NOT safe to run in a worker setup. + + The complication here is that we need to handle requests being resent, i.e. + if we sent down a room in a response that the client received, we must + consider the room *not* sent when we get the request again. + + This is handled by using an integer "token", which is returned to the client + as part of the sync token. For each connection we store a mapping from + tokens to the room states, and create a new entry when we send down new + rooms. + + Note that for any given sliding sync connection we will only store a maximum + of two different tokens: the previous token from the request and a new token + sent in the response. When we receive a request with a given token, we then + clear out all other entries with a different token. + + Attributes: + _connections: Mapping from `(user_id, conn_id)` to mapping of `token` + to mapping of room ID to `HaveSentRoom`. + """ + + # `(user_id, conn_id)` -> `token` -> `room_id` -> `HaveSentRoom` + _connections: Dict[Tuple[str, str], Dict[int, Dict[str, HaveSentRoom]]] = ( + attr.Factory(dict) + ) + + async def have_sent_room( + self, user_id: str, conn_id: str, connection_token: int, room_id: str + ) -> HaveSentRoom: + """Whether for the given user_id/conn_id/token, return whether we have + previously sent the room down + """ + + sync_statuses = self._connections.setdefault((user_id, conn_id), {}) + room_status = sync_statuses.get(connection_token, {}).get( + room_id, HAVE_SENT_ROOM_NEVER + ) + + return room_status + + async def record_rooms( + self, + user_id: str, + conn_id: str, + from_token: Optional[SlidingSyncStreamToken], + *, + sent_room_ids: StrCollection, + unsent_room_ids: StrCollection, + ) -> int: + """Record which rooms we have/haven't sent down in a new response + + Attributes: + user_id + conn_id + from_token: The since token from the request, if any + sent_room_ids: The set of room IDs that we have sent down as + part of this request (only needs to be ones we didn't + previously sent down). + unsent_room_ids: The set of room IDs that have had updates + since the `last_room_token`, but which were not included in + this request + """ + prev_connection_token = 0 + if from_token is not None: + prev_connection_token = from_token.connection_token + + # If there are no changes then this is a noop. + if not sent_room_ids and not unsent_room_ids: + return prev_connection_token + + sync_statuses = self._connections.setdefault((user_id, conn_id), {}) + + # Generate a new token, removing any existing entries in that token + # (which can happen if requests get resent). + new_store_token = prev_connection_token + 1 + sync_statuses.pop(new_store_token, None) + + # Copy over and update the room mappings. 
+ new_room_statuses = dict(sync_statuses.get(prev_connection_token, {})) + + # Whether we have updated the `new_room_statuses`, if we don't by the + # end we can treat this as a noop. + have_updated = False + for room_id in sent_room_ids: + new_room_statuses[room_id] = HAVE_SENT_ROOM_LIVE + have_updated = True + + # Whether we add/update the entries for unsent rooms depends on the + # existing entry: + # - LIVE: We have previously sent down everything up to + # `last_room_token, so we update the entry to be `PREVIOUSLY` with + # `last_room_token`. + # - PREVIOUSLY: We have previously sent down everything up to *a* + # given token, so we don't need to update the entry. + # - NEVER: We have never previously sent down the room, and we haven't + # sent anything down this time either so we leave it as NEVER. + + # Work out the new state for unsent rooms that were `LIVE`. + if from_token: + new_unsent_state = HaveSentRoom.previously(from_token.stream_token.room_key) + else: + new_unsent_state = HAVE_SENT_ROOM_NEVER + + for room_id in unsent_room_ids: + prev_state = new_room_statuses.get(room_id) + if prev_state is not None and prev_state.status == HaveSentRoomFlag.LIVE: + new_room_statuses[room_id] = new_unsent_state + have_updated = True + + if not have_updated: + return prev_connection_token + + sync_statuses[new_store_token] = new_room_statuses + + return new_store_token + + async def mark_token_seen( + self, + user_id: str, + conn_id: str, + from_token: Optional[SlidingSyncStreamToken], + ) -> None: + """We have received a request with the given token, so we can clear out + any other tokens associated with the connection. + + If there is no from token then we have started afresh, and so we delete + all tokens associated with the device. + """ + # Clear out any tokens for the connection that doesn't match the one + # from the request. 
+ + sync_statuses = self._connections.pop((user_id, conn_id), {}) + if from_token is None: + return + + sync_statuses = { + i: room_statuses + for i, room_statuses in sync_statuses.items() + if i == from_token.connection_token + } + if sync_statuses: + self._connections[(user_id, conn_id)] = sync_statuses diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index de227faec3..fefc35ecdb 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -2711,7 +2711,7 @@ class SyncHandler: room_list = await self.store.get_rooms_for_local_user_where_membership_is( user_id=user_id, - membership_list=Membership.LIST, + membership_list=frozenset(Membership.LIST), excluded_rooms=sync_result_builder.excluded_room_ids, ) diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index 1d8cbfdf00..7c91b15cef 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -52,9 +52,9 @@ from synapse.http.servlet import ( parse_string, ) from synapse.http.site import SynapseRequest -from synapse.logging.opentracing import trace_with_opname +from synapse.logging.opentracing import log_kv, set_tag, trace_with_opname from synapse.rest.admin.experimental_features import ExperimentalFeature -from synapse.types import JsonDict, Requester, StreamToken +from synapse.types import JsonDict, Requester, SlidingSyncStreamToken, StreamToken from synapse.types.rest.client import SlidingSyncBody from synapse.util import json_decoder from synapse.util.caches.lrucache import LruCache @@ -881,7 +881,6 @@ class SlidingSyncRestServlet(RestServlet): ) user = requester.user - device_id = requester.device_id timeout = parse_integer(request, "timeout", default=0) # Position in the stream @@ -889,7 +888,9 @@ class SlidingSyncRestServlet(RestServlet): from_token = None if from_token_string is not None: - from_token = await StreamToken.from_string(self.store, from_token_string) + from_token = await SlidingSyncStreamToken.from_string( + self.store, from_token_string + ) # TODO: We currently don't know whether we're going to use sticky params or # maybe some filters like sync v2 where they are built up once and referenced @@ -897,14 +898,19 @@ class SlidingSyncRestServlet(RestServlet): # in. body = parse_and_validate_json_object_from_request(request, SlidingSyncBody) logger.info("Sliding sync request: %r", body) + log_kv({"request_body": body}) + + if body.lists: + set_tag("sliding_sync.lists", True) sync_config = SlidingSyncConfig( user=user, - device_id=device_id, + requester=requester, # FIXME: Currently, we're just manually copying the fields from the - # `SlidingSyncBody` into the config. How can we gurantee into the future + # `SlidingSyncBody` into the config. How can we guarantee into the future # that we don't forget any? 
I would like something more structured like # `copy_attributes(from=body, to=config)` + conn_id=body.conn_id, lists=body.lists, room_subscriptions=body.room_subscriptions, extensions=body.extensions, @@ -984,7 +990,7 @@ class SlidingSyncRestServlet(RestServlet): serialized_rooms: Dict[str, JsonDict] = {} for room_id, room_result in rooms.items(): serialized_rooms[room_id] = { - "bump_stamp": room_result.bump_stamp, + "bump_stamp": abs(room_result.bump_stamp), "joined_count": room_result.joined_count, "invited_count": room_result.invited_count, "notification_count": room_result.notification_count, diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py index 001a290e87..e9cdc628d5 100644 --- a/synapse/server_notices/server_notices_manager.py +++ b/synapse/server_notices/server_notices_manager.py @@ -114,7 +114,7 @@ class ServerNoticesManager: return None rooms = await self._store.get_rooms_for_local_user_where_membership_is( - user_id, [Membership.INVITE, Membership.JOIN] + user_id, (Membership.INVITE, Membership.JOIN) ) for room in rooms: # it's worth noting that there is an asymmetry here in that we @@ -262,7 +262,7 @@ class ServerNoticesManager: # Check whether the user has already joined or been invited to this room. If # that's the case, there is no need to re-invite them. joined_rooms = await self._store.get_rooms_for_local_user_where_membership_is( - user_id, [Membership.INVITE, Membership.JOIN] + user_id, (Membership.INVITE, Membership.JOIN) ) for room in joined_rooms: if room.room_id == room_id: diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index 26b8e1a172..8c2c0c5ab0 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -309,6 +309,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore): if not backfilled: self._events_stream_cache.entity_has_changed(room_id, stream_ordering) # type: ignore[attr-defined] + self._attempt_to_invalidate_cache( + "get_max_stream_ordering_in_room", (room_id,) + ) if redacts: self._invalidate_local_get_event_cache(redacts) # type: ignore[attr-defined] diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 1f7acdb859..0c7c2f9306 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -551,7 +551,7 @@ class PersistEventsStore: # From this point onwards the events are only events that we haven't # seen before. - self._store_event_txn(txn, events_and_contexts=events_and_contexts) + self._store_event_txn(txn, room_id, events_and_contexts=events_and_contexts) if new_forward_extremities: self._update_forward_extremities_txn( @@ -1555,6 +1555,7 @@ class PersistEventsStore: def _store_event_txn( self, txn: LoggingTransaction, + room_id: str, events_and_contexts: Collection[Tuple[EventBase, EventContext]], ) -> None: """Insert new events into the event, event_json, redaction and @@ -1629,6 +1630,27 @@ class PersistEventsStore: ], ) + # Update the `sliding_sync_room_metadata` with the latest + # (non-backfilled, ie positive) stream ordering. + # + # We know this list is sorted and non-empty, so we just take the last + # one event. 
+ max_stream_ordering: int + for e, _ in events_and_contexts: + assert e.internal_metadata.stream_ordering is not None + max_stream_ordering = e.internal_metadata.stream_ordering + + if max_stream_ordering > 0: + self.db_pool.simple_upsert_txn( + txn, + table="sliding_sync_room_metadata", + keyvalues={"room_id": room_id}, + values={ + "instance_name": self._instance_name, + "last_stream_ordering": max_stream_ordering, + }, + ) + # If we're persisting an unredacted event we go and ensure # that we mark any redactions that reference this event as # requiring censoring. diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 80a4bf95f2..498a136543 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -228,6 +228,30 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): return row return bool(row[0]), bool(row[1]) + @cached(max_entries=10000) + async def get_room_type(self, room_id: str) -> Optional[str]: + # TODO: Upsert room_stats_state on room creation / initial join. + return await self.db_pool.simple_select_one_onecol( + table="room_stats_state", + keyvalues={"room_id": room_id}, + retcol="room_type", + allow_none=True, + desc="get_room_type", + ) + + @cachedList(cached_method_name="get_room_type", list_name="room_ids") + async def bulk_get_room_type( + self, room_ids: StrCollection + ) -> Mapping[str, Optional[str]]: + rows = await self.db_pool.simple_select_many_batch( + table="room_stats_state", + column="room_id", + iterable=room_ids, + retcols=("room_id", "room_type"), + desc="bulk_get_room_type", + ) + return dict(rows) + async def get_room_with_stats(self, room_id: str) -> Optional[RoomStats]: """Retrieve room with statistics. diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 640ab123f0..2e0e6afac5 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -385,7 +385,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): """ return await self.get_rooms_for_local_user_where_membership_is( - user_id, [Membership.INVITE] + user_id, (Membership.INVITE,) ) async def get_knocked_at_rooms_for_local_user( @@ -401,7 +401,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): """ return await self.get_rooms_for_local_user_where_membership_is( - user_id, [Membership.KNOCK] + user_id, (Membership.KNOCK,) ) async def get_invite_for_local_user_in_room( diff --git a/synapse/storage/databases/main/state_deltas.py b/synapse/storage/databases/main/state_deltas.py index 036972ac25..cd6cb2c7a9 100644 --- a/synapse/storage/databases/main/state_deltas.py +++ b/synapse/storage/databases/main/state_deltas.py @@ -26,6 +26,8 @@ import attr from synapse.storage._base import SQLBaseStore from synapse.storage.database import LoggingTransaction +from synapse.storage.databases.main.stream import _filter_results_by_stream +from synapse.types import RoomStreamToken from synapse.util.caches.stream_change_cache import StreamChangeCache logger = logging.getLogger(__name__) @@ -156,3 +158,39 @@ class StateDeltasStore(SQLBaseStore): "get_max_stream_id_in_current_state_deltas", self._get_max_stream_id_in_current_state_deltas_txn, ) + + async def get_current_state_deltas_for_room( + self, room_id: str, from_token: RoomStreamToken, to_token: RoomStreamToken + ) -> List[StateDelta]: + """Get the state deltas between that have happened between two + tokens.""" + + def 
get_current_state_deltas_for_room_txn( + txn: LoggingTransaction, + ) -> List[StateDelta]: + sql = """ + SELECT instance_name, stream_id, type, state_key, event_id, prev_event_id + FROM current_state_delta_stream + WHERE room_id = ? AND ? < stream_id AND stream_id <= ? + ORDER BY stream_id ASC + """ + txn.execute( + sql, (room_id, from_token.stream, to_token.get_max_stream_pos()) + ) + + return [ + StateDelta( + stream_id=row[1], + room_id=room_id, + event_type=row[2], + state_key=row[3], + event_id=row[4], + prev_event_id=row[5], + ) + for row in txn + if _filter_results_by_stream(from_token, to_token, row[0], row[1]) + ] + + return await self.db_pool.runInteraction( + "get_current_state_deltas_for_room", get_current_state_deltas_for_room_txn + ) diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index b119acda29..7df811e451 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -50,6 +50,7 @@ from typing import ( Dict, Iterable, List, + Mapping, Optional, Set, Tuple, @@ -78,8 +79,13 @@ from synapse.storage.database import ( from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine from synapse.storage.util.id_generators import MultiWriterIdGenerator -from synapse.types import PersistedEventPosition, RoomStreamToken, StrCollection -from synapse.util.caches.descriptors import cached +from synapse.types import ( + JsonDict, + PersistedEventPosition, + RoomStreamToken, + StrCollection, +) +from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.cancellation import cancellable from synapse.util.iterutils import batch_iter @@ -611,6 +617,10 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): self._stream_order_on_start = self.get_room_max_stream_ordering() self._min_stream_order_on_start = self.get_room_min_stream_ordering() + database.updates.register_background_update_handler( + "sliding_sync_room_metadata", self._sliding_sync_room_metadata_bg_update + ) + def get_room_max_stream_ordering(self) -> int: """Get the stream_ordering of regular events that we have committed up to @@ -1186,6 +1196,52 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return None + @cachedList( + cached_method_name="get_max_stream_ordering_in_room", + list_name="room_ids", + ) + async def get_max_stream_ordering_in_rooms( + self, room_ids: StrCollection + ) -> Mapping[str, Optional[PersistedEventPosition]]: + """Get the positions for the latest event in a room. + + A batched version of `get_max_stream_ordering_in_room`. + """ + rows = await self.db_pool.simple_select_many_batch( + table="sliding_sync_room_metadata", + column="room_id", + iterable=room_ids, + retcols=("room_id", "instance_name", "last_stream_ordering"), + desc="get_max_stream_ordering_in_rooms", + ) + + return { + room_id: PersistedEventPosition(instance_name, stream) + for room_id, instance_name, stream in rows + } + + @cached(max_entries=10000) + async def get_max_stream_ordering_in_room( + self, + room_id: str, + ) -> Optional[PersistedEventPosition]: + """Get the position for the latest event in a room. + + Note: this may be after the current token for the room stream on this + process (e.g. 
due to replication lag) + """ + row = await self.db_pool.simple_select_one( + table="sliding_sync_room_metadata", + retcols=("instance_name", "last_stream_ordering"), + keyvalues={"room_id": room_id}, + allow_none=True, + desc="get_max_stream_ordering_in_room", + ) + if not row: + return None + + return PersistedEventPosition(instance_name=row[0], stream=row[1]) + async def get_last_event_pos_in_room_before_stream_ordering( self, room_id: str, @@ -2073,3 +2129,88 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return RoomStreamToken(stream=last_position.stream - 1) return None + + async def _sliding_sync_room_metadata_bg_update( + self, progress: JsonDict, batch_size: int + ) -> int: + """Background update to fill out 'sliding_sync_room_metadata' table""" + previous_room = progress.get("previous_room", "") + + def _sliding_sync_room_metadata_bg_update_txn(txn: LoggingTransaction) -> int: + # Both these queries are just getting the most recent + # instance_name/stream ordering for the next N rooms. + if isinstance(self.database_engine, PostgresEngine): + sql = """ + SELECT room_id, instance_name, stream_ordering FROM rooms AS r, + LATERAL ( + SELECT instance_name, stream_ordering + FROM events WHERE events.room_id = r.room_id + ORDER BY stream_ordering DESC + LIMIT 1 + ) e + WHERE r.room_id > ? + ORDER BY r.room_id ASC + LIMIT ? + """ + else: + sql = """ + SELECT + room_id, + ( + SELECT instance_name + FROM events WHERE events.room_id = r.room_id + ORDER BY stream_ordering DESC + LIMIT 1 + ), + ( + SELECT stream_ordering + FROM events WHERE events.room_id = r.room_id + ORDER BY stream_ordering DESC + LIMIT 1 + ) + FROM rooms AS r + WHERE r.room_id > ? + ORDER BY r.room_id ASC + LIMIT ? + """ + + txn.execute(sql, (previous_room, batch_size)) + rows = txn.fetchall() + if not rows: + return 0 + + self.db_pool.simple_upsert_many_txn( + txn, + table="sliding_sync_room_metadata", + key_names=("room_id",), + key_values=[(room_id,) for room_id, _, _ in rows], + value_names=( + "instance_name", + "last_stream_ordering", + ), + value_values=[ + ( + instance_name or "master", + stream, + ) + for _, instance_name, stream in rows + ], + ) + + self.db_pool.updates._background_update_progress_txn( + txn, "sliding_sync_room_metadata", {"previous_room": rows[-1][0]} + ) + + return len(rows) + + rows = await self.db_pool.runInteraction( + "_sliding_sync_room_metadata_bg_update", + _sliding_sync_room_metadata_bg_update_txn, + ) + + if rows == 0: + await self.db_pool.updates._end_background_update( + "sliding_sync_room_metadata" + ) + + return rows diff --git a/synapse/storage/schema/main/delta/85/07_sliding_sync.sql b/synapse/storage/schema/main/delta/85/07_sliding_sync.sql new file mode 100644 index 0000000000..e8bc33ff40 --- /dev/null +++ b/synapse/storage/schema/main/delta/85/07_sliding_sync.sql @@ -0,0 +1,24 @@ +-- +-- This file is licensed under the Affero General Public License (AGPL) version 3. +-- +-- Copyright (C) 2024 New Vector, Ltd +-- +-- This program is free software: you can redistribute it and/or modify +-- it under the terms of the GNU Affero General Public License as +-- published by the Free Software Foundation, either version 3 of the +-- License, or (at your option) any later version. +-- +-- See the GNU Affero General Public License for more details: +-- <https://www.gnu.org/licenses/agpl-3.0.html>. + +-- A table that maps from room ID to metadata useful for sliding sync. 
+CREATE TABLE sliding_sync_room_metadata ( + room_id TEXT NOT NULL PRIMARY KEY, + + -- The instance_name / stream ordering of the last event in the room. + instance_name TEXT NOT NULL, + last_stream_ordering BIGINT NOT NULL +); + +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (8507, 'sliding_sync_room_metadata', '{}'); diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py index 3962ecc996..23ac1842f8 100644 --- a/synapse/types/__init__.py +++ b/synapse/types/__init__.py @@ -1138,6 +1138,43 @@ StreamToken.START = StreamToken( @attr.s(slots=True, frozen=True, auto_attribs=True) +class SlidingSyncStreamToken: + """The same as a `StreamToken`, but includes an extra field at the start for + the sliding sync connection token (separated by a '/'). This is used to + store per-connection state. + + This then looks something like: + 5/s2633508_17_338_6732159_1082514_541479_274711_265584_1_379 + """ + + stream_token: StreamToken + connection_token: int + + @staticmethod + @cancellable + async def from_string(store: "DataStore", string: str) -> "SlidingSyncStreamToken": + """Creates a SlidingSyncStreamToken from its textual representation.""" + try: + connection_token_str, stream_token_str = string.split("/", 1) + connection_token = int(connection_token_str) + stream_token = await StreamToken.from_string(store, stream_token_str) + + return SlidingSyncStreamToken( + stream_token=stream_token, + connection_token=connection_token, + ) + except CancelledError: + raise + except Exception: + raise SynapseError(400, "Invalid stream token") + + async def to_string(self, store: "DataStore") -> str: + """Serializes the token to a string""" + stream_token_str = await self.stream_token.to_string(store) + return f"{self.connection_token}/{stream_token_str}" + + +@attr.s(slots=True, frozen=True, auto_attribs=True) class PersistedPosition: """Position of a newly persisted row with instance that persisted it.""" diff --git a/synapse/types/handlers/__init__.py b/synapse/types/handlers/__init__.py index 409120470a..0c2ab13c93 100644 --- a/synapse/types/handlers/__init__.py +++ b/synapse/types/handlers/__init__.py @@ -31,7 +31,14 @@ else: from pydantic import Extra from synapse.events import EventBase -from synapse.types import JsonDict, JsonMapping, StreamToken, UserID +from synapse.types import ( + JsonDict, + JsonMapping, + Requester, + SlidingSyncStreamToken, + StreamToken, + UserID, +) from synapse.types.rest.client import SlidingSyncBody if TYPE_CHECKING: @@ -102,7 +109,7 @@ class SlidingSyncConfig(SlidingSyncBody): """ user: UserID - device_id: Optional[str] + requester: Requester # Pydantic config class Config: @@ -113,6 +120,31 @@ class SlidingSyncConfig(SlidingSyncBody): # Allow custom types like `UserID` to be used in the model arbitrary_types_allowed = True + def connection_id(self) -> str: + """Return a string identifier for this connection. May clash with + connection IDs from different users. + + This is generally a combination of device ID and conn_id. However, both + these two are optional (e.g. puppet access tokens don't have device + IDs), so this handles those edge cases. + """ + + # `conn_id` can be null, in which case we default to the empty string + # (if conn ID is empty then the client can't have multiple sync loops) + conn_id = self.conn_id or "" + + if self.requester.device_id: + return f"D/{self.requester.device_id}/{conn_id}" + + if self.requester.access_token_id: + # If we don't have a device, then the access token ID should be a + # stable ID. 
+ return f"A/{self.requester.access_token_id}/{conn_id}" + + # If we have neither then its likely an AS or some weird token. Either + # way we can just fail here. + raise Exception("Cannot use sliding sync with access token type") + class OperationType(Enum): """ @@ -287,7 +319,7 @@ class SlidingSyncResult: def __bool__(self) -> bool: return bool(self.to_device) - next_pos: StreamToken + next_pos: SlidingSyncStreamToken lists: Dict[str, SlidingWindowList] rooms: Dict[str, RoomResult] extensions: Extensions @@ -300,7 +332,7 @@ class SlidingSyncResult: return bool(self.lists or self.rooms or self.extensions) @staticmethod - def empty(next_pos: StreamToken) -> "SlidingSyncResult": + def empty(next_pos: SlidingSyncStreamToken) -> "SlidingSyncResult": "Return a new empty result" return SlidingSyncResult( next_pos=next_pos, diff --git a/synapse/types/rest/client/__init__.py b/synapse/types/rest/client/__init__.py index dbe37bc712..5be8cf5389 100644 --- a/synapse/types/rest/client/__init__.py +++ b/synapse/types/rest/client/__init__.py @@ -120,6 +120,9 @@ class SlidingSyncBody(RequestBodyModel): Sliding Sync API request body. Attributes: + conn_id: An optional string to identify this connection to the server. If this + is missing, only 1 sliding sync connection can be made to the server at + any one time. lists: Sliding window API. A map of list key to list information (:class:`SlidingSyncList`). Max lists: 100. The list keys should be arbitrary strings which the client is using to refer to the list. Keep this @@ -315,6 +318,8 @@ class SlidingSyncBody(RequestBodyModel): to_device: Optional[ToDeviceExtension] = None + conn_id: Optional[str] + # mypy workaround via https://github.com/pydantic/pydantic/issues/156#issuecomment-1130883884 if TYPE_CHECKING: lists: Optional[Dict[str, SlidingSyncList]] = None diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py index a85ea994de..77caab2489 100644 --- a/tests/rest/client/test_account.py +++ b/tests/rest/client/test_account.py @@ -535,7 +535,7 @@ class DeactivateTestCase(unittest.HomeserverTestCase): # Check that the membership of @invitee:test in the room is now "leave". 
memberships = self.get_success( store.get_rooms_for_local_user_where_membership_is( - invitee_id, [Membership.LEAVE] + invitee_id, (Membership.LEAVE,) ) ) self.assertEqual(len(memberships), 1, memberships) diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py index a008ee465b..2a27571929 100644 --- a/tests/rest/client/test_sync.py +++ b/tests/rest/client/test_sync.py @@ -50,7 +50,14 @@ from synapse.rest.client import ( sync, ) from synapse.server import HomeServer -from synapse.types import JsonDict, RoomStreamToken, StreamKeyType, StreamToken, UserID +from synapse.types import ( + JsonDict, + RoomStreamToken, + SlidingSyncStreamToken, + StreamKeyType, + StreamToken, + UserID, +) from synapse.util import Clock from tests import unittest @@ -1448,7 +1455,7 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase): ) future_position_token_serialized = self.get_success( - future_position_token.to_string(self.store) + SlidingSyncStreamToken(future_position_token, 0).to_string(self.store) ) # Make the Sliding Sync request @@ -2605,7 +2612,22 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase): room_id1, "activity before token2", tok=user2_tok ) - from_token = self.event_sources.get_current_token() + channel = self.make_request( + "POST", + self.sync_endpoint, + { + "lists": { + "foo-list": { + "ranges": [[0, 1]], + "required_state": [], + "timeline_limit": 4, + } + } + }, + access_token=user1_tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + from_token = channel.json_body["pos"] # Join the room after the `from_token` which will make us consider this room as # `newly_joined`. @@ -2627,8 +2649,7 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase): # Make an incremental Sliding Sync request (what we're trying to test) channel = self.make_request( "POST", - self.sync_endpoint - + f"?pos={self.get_success(from_token.to_string(self.store))}", + self.sync_endpoint + f"?pos={from_token}", { "lists": { "foo-list": { @@ -2814,7 +2835,22 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase): self.helper.send(room_id1, "activity after invite3", tok=user2_tok) self.helper.send(room_id1, "activity after invite4", tok=user2_tok) - from_token = self.event_sources.get_current_token() + channel = self.make_request( + "POST", + self.sync_endpoint, + { + "lists": { + "foo-list": { + "ranges": [[0, 1]], + "required_state": [], + "timeline_limit": 4, + } + } + }, + access_token=user1_tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + from_token = channel.json_body["pos"] self.helper.send(room_id1, "activity after token5", tok=user2_tok) self.helper.send(room_id1, "activity after toekn6", tok=user2_tok) @@ -2822,8 +2858,7 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase): # Make the Sliding Sync request channel = self.make_request( "POST", - self.sync_endpoint - + f"?pos={self.get_success(from_token.to_string(self.store))}", + self.sync_endpoint + f"?pos={from_token}", { "lists": { "foo-list": { @@ -3071,7 +3106,22 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase): self.helper.send(room_id1, "activity after invite3", tok=user2_tok) self.helper.send(room_id1, "activity after invite4", tok=user2_tok) - from_token = self.event_sources.get_current_token() + channel = self.make_request( + "POST", + self.sync_endpoint, + { + "lists": { + "foo-list": { + "ranges": [[0, 1]], + "required_state": [], + "timeline_limit": 4, + } + } + }, + access_token=user1_tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + from_token = 
channel.json_body["pos"] self.helper.send(room_id1, "activity after token5", tok=user2_tok) self.helper.send(room_id1, "activity after toekn6", tok=user2_tok) @@ -3079,8 +3129,7 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase): # Make the Sliding Sync request channel = self.make_request( "POST", - self.sync_endpoint - + f"?pos={self.get_success(from_token.to_string(self.store))}", + self.sync_endpoint + f"?pos={from_token}", { "lists": { "foo-list": { @@ -3236,7 +3285,22 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase): self.helper.send(room_id1, "activity before2", tok=user2_tok) self.helper.join(room_id1, user1_id, tok=user1_tok) - from_token = self.event_sources.get_current_token() + channel = self.make_request( + "POST", + self.sync_endpoint, + { + "lists": { + "foo-list": { + "ranges": [[0, 1]], + "required_state": [], + "timeline_limit": 4, + } + } + }, + access_token=user1_tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + from_token = channel.json_body["pos"] event_response3 = self.helper.send(room_id1, "activity after3", tok=user2_tok) event_response4 = self.helper.send(room_id1, "activity after4", tok=user2_tok) @@ -3252,8 +3316,7 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase): # Make the Sliding Sync request channel = self.make_request( "POST", - self.sync_endpoint - + f"?pos={self.get_success(from_token.to_string(self.store))}", + self.sync_endpoint + f"?pos={from_token}", { "lists": { "foo-list": { @@ -3313,15 +3376,29 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase): self.helper.send(room_id1, "activity after3", tok=user2_tok) - from_token = self.event_sources.get_current_token() + channel = self.make_request( + "POST", + self.sync_endpoint, + { + "lists": { + "foo-list": { + "ranges": [[0, 1]], + "required_state": [], + "timeline_limit": 4, + } + } + }, + access_token=user1_tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + from_token = channel.json_body["pos"] self.helper.send(room_id1, "activity after4", tok=user2_tok) # Make the Sliding Sync request channel = self.make_request( "POST", - self.sync_endpoint - + f"?pos={self.get_success(from_token.to_string(self.store))}", + self.sync_endpoint + f"?pos={from_token}", { "lists": { "foo-list": { @@ -3448,13 +3525,27 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase): room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok) self.helper.join(room_id1, user1_id, tok=user1_tok) - after_room_token = self.event_sources.get_current_token() + channel = self.make_request( + "POST", + self.sync_endpoint, + { + "lists": { + "foo-list": { + "ranges": [[0, 1]], + "required_state": [], + "timeline_limit": 4, + } + } + }, + access_token=user1_tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + after_room_token = channel.json_body["pos"] # Make the Sliding Sync request channel = self.make_request( "POST", - self.sync_endpoint - + f"?pos={self.get_success(after_room_token.to_string(self.store))}", + self.sync_endpoint + f"?pos={after_room_token}", { "lists": { "foo-list": { @@ -3473,22 +3564,9 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.code, 200, channel.json_body) - state_map = self.get_success( - self.storage_controllers.state.get_current_state(room_id1) - ) - - # The returned state doesn't change from initial to incremental sync. 
In the - # future, we will only return updates but only if we've sent the room down the + # We only return updates but only if we've sent the room down the # connection before. - self._assertRequiredStateIncludes( - channel.json_body["rooms"][room_id1]["required_state"], - { - state_map[(EventTypes.Create, "")], - state_map[(EventTypes.RoomHistoryVisibility, "")], - }, - exact=True, - ) - self.assertIsNone(channel.json_body["rooms"][room_id1].get("invite_state")) + self.assertNotIn(room_id1, channel.json_body["rooms"]) def test_rooms_required_state_wildcard(self) -> None: """ @@ -3726,7 +3804,22 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase): user3_id = self.register_user("user3", "pass") user3_tok = self.login(user3_id, "pass") - from_token = self.event_sources.get_current_token() + channel = self.make_request( + "POST", + self.sync_endpoint, + { + "lists": { + "foo-list": { + "ranges": [[0, 1]], + "required_state": [], + "timeline_limit": 4, + } + } + }, + access_token=user1_tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + from_token = channel.json_body["pos"] room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok) self.helper.join(room_id1, user1_id, tok=user1_tok) @@ -3764,8 +3857,7 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase): # Make the Sliding Sync request with lazy loading for the room members channel = self.make_request( "POST", - self.sync_endpoint - + f"?pos={self.get_success(from_token.to_string(self.store))}", + self.sync_endpoint + f"?pos={from_token}", { "lists": { "foo-list": { @@ -4227,6 +4319,187 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase): channel.json_body["rooms"].get(room_id1), channel.json_body["rooms"] ) + def test_incremental_sync_incremental_state(self) -> None: + """Test that we only get state updates in incremental sync for rooms + we've already seen. 
+ """ + + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + user2_id = self.register_user("user2", "pass") + user2_tok = self.login(user2_id, "pass") + + room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok) + self.helper.join(room_id1, user1_id, tok=user1_tok) + + # Make the Sliding Sync request + channel = self.make_request( + "POST", + self.sync_endpoint, + { + "lists": { + "foo-list": { + "ranges": [[0, 1]], + "required_state": [ + [EventTypes.Create, ""], + [EventTypes.RoomHistoryVisibility, ""], + # This one doesn't exist in the room + [EventTypes.Name, ""], + ], + "timeline_limit": 0, + } + } + }, + access_token=user1_tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + from_token = channel.json_body["pos"] + + state_map = self.get_success( + self.storage_controllers.state.get_current_state(room_id1) + ) + + self._assertRequiredStateIncludes( + channel.json_body["rooms"][room_id1]["required_state"], + { + state_map[(EventTypes.Create, "")], + state_map[(EventTypes.RoomHistoryVisibility, "")], + }, + exact=True, + ) + + # Send a state event + self.helper.send_state( + room_id1, EventTypes.Name, body={"name": "foo"}, tok=user2_tok + ) + + channel = self.make_request( + "POST", + self.sync_endpoint + f"?pos={from_token}", + { + "lists": { + "foo-list": { + "ranges": [[0, 1]], + "required_state": [ + [EventTypes.Create, ""], + [EventTypes.RoomHistoryVisibility, ""], + [EventTypes.Name, ""], + ], + "timeline_limit": 0, + } + } + }, + access_token=user1_tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + state_map = self.get_success( + self.storage_controllers.state.get_current_state(room_id1) + ) + + self._assertRequiredStateIncludes( + channel.json_body["rooms"][room_id1]["required_state"], + { + state_map[(EventTypes.Name, "")], + }, + exact=True, + ) + + def test_incremental_sync_full_state_new_room(self) -> None: + """Test that we get state all state in incremental sync for rooms that + we haven't seen before. + """ + + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + user2_id = self.register_user("user2", "pass") + user2_tok = self.login(user2_id, "pass") + + room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok) + self.helper.join(room_id1, user1_id, tok=user1_tok) + + room_id2 = self.helper.create_room_as(user2_id, tok=user2_tok) + self.helper.join(room_id2, user1_id, tok=user1_tok) + + # Make the Sliding Sync request, we'll only receive room_id2 + channel = self.make_request( + "POST", + self.sync_endpoint, + { + "lists": { + "foo-list": { + "ranges": [[0, 0]], + "required_state": [ + [EventTypes.Create, ""], + [EventTypes.RoomHistoryVisibility, ""], + # This one doesn't exist in the room + [EventTypes.Name, ""], + ], + "timeline_limit": 0, + } + } + }, + access_token=user1_tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + from_token = channel.json_body["pos"] + + state_map = self.get_success( + self.storage_controllers.state.get_current_state(room_id2) + ) + + self._assertRequiredStateIncludes( + channel.json_body["rooms"][room_id2]["required_state"], + { + state_map[(EventTypes.Create, "")], + state_map[(EventTypes.RoomHistoryVisibility, "")], + }, + exact=True, + ) + self.assertNotIn(room_id1, channel.json_body["rooms"]) + + # Send a state event in room 1 + self.helper.send_state( + room_id1, EventTypes.Name, body={"name": "foo"}, tok=user2_tok + ) + + # We should get room_id1 down sync, with the full state. 
+ channel = self.make_request( + "POST", + self.sync_endpoint + f"?pos={from_token}", + { + "lists": { + "foo-list": { + "ranges": [[0, 0]], + "required_state": [ + [EventTypes.Create, ""], + [EventTypes.RoomHistoryVisibility, ""], + [EventTypes.Name, ""], + ], + "timeline_limit": 0, + } + } + }, + access_token=user1_tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + state_map = self.get_success( + self.storage_controllers.state.get_current_state(room_id1) + ) + + self._assertRequiredStateIncludes( + channel.json_body["rooms"][room_id1]["required_state"], + { + state_map[(EventTypes.Create, "")], + state_map[(EventTypes.RoomHistoryVisibility, "")], + state_map[(EventTypes.Name, "")], + }, + exact=True, + ) + class SlidingSyncToDeviceExtensionTestCase(unittest.HomeserverTestCase): """Tests for the to-device sliding sync extension""" diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py index c4e216c308..037bbca1ba 100644 --- a/tests/storage/test_event_chain.py +++ b/tests/storage/test_event_chain.py @@ -440,6 +440,7 @@ class EventChainStoreTestCase(HomeserverTestCase): assert persist_events_store is not None persist_events_store._store_event_txn( txn, + events[0].room_id, [ (e, EventContext(self.hs.get_storage_controllers(), {})) for e in events diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index 418b556108..e2f19e25e3 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -68,7 +68,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): rooms_for_user = self.get_success( self.store.get_rooms_for_local_user_where_membership_is( - self.u_alice, [Membership.JOIN] + self.u_alice, (Membership.JOIN,) ) ) |
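For reference, the per-connection bookkeeping added in synapse/handlers/sliding_sync.py (HaveSentRoomFlag, HaveSentRoom and SlidingSyncConnectionStore) boils down to the small state machine sketched below. This is a simplified, synchronous illustration rather than the actual implementation: it drops the (user_id, conn_id) keying and uses plain tuples instead of the HaveSentRoom attrs class.

    from enum import Enum
    from typing import Dict, Iterable, Optional, Tuple

    class Flag(Enum):
        NEVER = 1        # never sent this room down the connection
        PREVIOUSLY = 2   # sent before, but there are updates we have not sent yet
        LIVE = 3         # the client has everything up to its token

    class SimpleConnectionStore:
        def __init__(self) -> None:
            # connection token -> room_id -> (flag, stream position we are live up to)
            self._by_token: Dict[int, Dict[str, Tuple[Flag, Optional[int]]]] = {}

        def have_sent_room(self, token: int, room_id: str) -> Tuple[Flag, Optional[int]]:
            return self._by_token.get(token, {}).get(room_id, (Flag.NEVER, None))

        def record_rooms(
            self,
            prev_token: int,
            stream_pos: int,
            sent: Iterable[str],
            unsent: Iterable[str],
        ) -> int:
            """Record a response; returns the new connection token to hand back."""
            statuses = dict(self._by_token.get(prev_token, {}))
            for room_id in sent:
                statuses[room_id] = (Flag.LIVE, None)
            for room_id in unsent:
                flag, _ = statuses.get(room_id, (Flag.NEVER, None))
                if flag is Flag.LIVE:
                    # We had sent everything up to the request's position, so
                    # remember that the client now misses anything after it.
                    statuses[room_id] = (Flag.PREVIOUSLY, stream_pos)
            new_token = prev_token + 1
            self._by_token[new_token] = statuses
            return new_token

        def mark_token_seen(self, token: Optional[int]) -> None:
            # The client has echoed `token` back, so any other entry (e.g. from a
            # response the client never received) can be dropped.
            self._by_token = {t: v for t, v in self._by_token.items() if t == token}

Concretely: a request carrying ?pos=5/<stream token> first calls mark_token_seen(5), consults have_sent_room(5, room_id) to choose between full state (NEVER), a delta since the recorded position (PREVIOUSLY) or a delta since the request's own token (LIVE), and then calls record_rooms(5, ...) so that the token 6 it returns in the next pos reflects what was actually sent.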