summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2024-07-11 16:05:11 +0100
committerErik Johnston <erik@matrix.org>2024-07-11 16:05:11 +0100
commitdd50e9e86f7c4f7c6981373eb5bf5c2515b1058c (patch)
tree1d9e3b74e2370898223da4fa6bb463b0456f4130 /synapse
parentMerge remote-tracking branch 'origin/release-v1.111' into matrix-org-hotfixes (diff)
parentFix filtering room types on remote rooms (#17434) (diff)
downloadsynapse-dd50e9e86f7c4f7c6981373eb5bf5c2515b1058c.tar.xz
Merge remote-tracking branch 'origin/develop' into matrix-org-hotfixes
Diffstat (limited to 'synapse')
-rw-r--r--synapse/api/ratelimiting.py5
-rw-r--r--synapse/handlers/sliding_sync.py276
-rw-r--r--synapse/rest/client/sync.py17
-rw-r--r--synapse/storage/databases/main/state.py52
-rw-r--r--synapse/types/handlers/__init__.py35
-rw-r--r--synapse/types/rest/client/__init__.py48
6 files changed, 352 insertions, 81 deletions
diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py

index 26b8711851..b80630c5d3 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py
@@ -236,9 +236,8 @@ class Ratelimiter: requester: The requester that is doing the action, if any. key: An arbitrary key used to classify an action. Defaults to the requester's user ID. - n_actions: The number of times the user wants to do this action. If the user - cannot do all of the actions, the user's action count is not incremented - at all. + n_actions: The number of times the user performed the action. May be negative + to "refund" the rate limit. _time_now_s: The current time. Optional, defaults to the current time according to self.clock. Only used by tests. """ diff --git a/synapse/handlers/sliding_sync.py b/synapse/handlers/sliding_sync.py
index 8e2f751c02..8e6c2fb860 100644 --- a/synapse/handlers/sliding_sync.py +++ b/synapse/handlers/sliding_sync.py
@@ -18,18 +18,13 @@ # # import logging +from itertools import chain from typing import TYPE_CHECKING, Any, Dict, Final, List, Optional, Set, Tuple import attr from immutabledict import immutabledict -from synapse.api.constants import ( - AccountDataTypes, - Direction, - EventContentFields, - EventTypes, - Membership, -) +from synapse.api.constants import AccountDataTypes, Direction, EventTypes, Membership from synapse.events import EventBase from synapse.events.utils import strip_event from synapse.handlers.relations import BundledAggregations @@ -464,6 +459,7 @@ class SlidingSyncHandler: membership_state_keys = room_sync_config.required_state_map.get( EventTypes.Member ) + # Also see `StateFilter.must_await_full_state(...)` for comparison lazy_loading = ( membership_state_keys is not None and len(membership_state_keys) == 1 @@ -540,11 +536,15 @@ class SlidingSyncHandler: rooms[room_id] = room_sync_result + extensions = await self.get_extensions_response( + sync_config=sync_config, to_token=to_token + ) + return SlidingSyncResult( next_pos=to_token, lists=lists, rooms=rooms, - extensions={}, + extensions=extensions, ) async def get_sync_room_ids_for_user( @@ -953,11 +953,15 @@ class SlidingSyncHandler: # provided in the list. `None` is a valid type for rooms which do not have a # room type. if filters.room_types is not None or filters.not_room_types is not None: - # Make a copy so we don't run into an error: `Set changed size during - # iteration`, when we filter out and remove items - for room_id in filtered_room_id_set.copy(): - create_event = await self.store.get_create_event_for_room(room_id) - room_type = create_event.content.get(EventContentFields.ROOM_TYPE) + room_to_type = await self.store.bulk_get_room_type( + { + room_id + for room_id in filtered_room_id_set + # We only know the room types for joined rooms + if sync_room_map[room_id].membership == Membership.JOIN + } + ) + for room_id, room_type in room_to_type.items(): if ( filters.room_types is not None and room_type not in filters.room_types @@ -1202,7 +1206,7 @@ class SlidingSyncHandler: # Figure out any stripped state events for invite/knocks. This allows the # potential joiner to identify the room. - stripped_state: List[JsonDict] = [] + stripped_state: Optional[List[JsonDict]] = None if room_membership_for_user_at_to_token.membership in ( Membership.INVITE, Membership.KNOCK, @@ -1239,7 +1243,7 @@ class SlidingSyncHandler: # updates. initial = True - # Fetch the required state for the room + # Fetch the `required_state` for the room # # No `required_state` for invite/knock rooms (just `stripped_state`) # @@ -1247,13 +1251,15 @@ class SlidingSyncHandler: # of membership. Currently, we have to make this optional because # `invite`/`knock` rooms only have `stripped_state`. See # https://github.com/matrix-org/matrix-spec-proposals/pull/3575#discussion_r1653045932 + # + # Calculate the `StateFilter` based on the `required_state` for the room room_state: Optional[StateMap[EventBase]] = None + required_room_state: Optional[StateMap[EventBase]] = None if room_membership_for_user_at_to_token.membership not in ( Membership.INVITE, Membership.KNOCK, ): - # Calculate the `StateFilter` based on the `required_state` for the room - state_filter: Optional[StateFilter] = StateFilter.none() + required_state_filter = StateFilter.none() # If we have a double wildcard ("*", "*") in the `required_state`, we need # to fetch all state for the room # @@ -1276,7 +1282,7 @@ class SlidingSyncHandler: if StateValues.WILDCARD in room_sync_config.required_state_map.get( StateValues.WILDCARD, set() ): - state_filter = StateFilter.all() + required_state_filter = StateFilter.all() # TODO: `StateFilter` currently doesn't support wildcard event types. We're # currently working around this by returning all state to the client but it # would be nice to fetch less from the database and return just what the @@ -1285,7 +1291,7 @@ class SlidingSyncHandler: room_sync_config.required_state_map.get(StateValues.WILDCARD) is not None ): - state_filter = StateFilter.all() + required_state_filter = StateFilter.all() else: required_state_types: List[Tuple[str, Optional[str]]] = [] for ( @@ -1317,51 +1323,88 @@ class SlidingSyncHandler: else: required_state_types.append((state_type, state_key)) - state_filter = StateFilter.from_types(required_state_types) - - # We can skip fetching state if we don't need any - if state_filter != StateFilter.none(): - # We can return all of the state that was requested if we're doing an - # initial sync - if initial: - # People shouldn't see past their leave/ban event - if room_membership_for_user_at_to_token.membership in ( - Membership.LEAVE, - Membership.BAN, - ): - room_state = await self.storage_controllers.state.get_state_at( - room_id, - stream_position=to_token.copy_and_replace( - StreamKeyType.ROOM, - room_membership_for_user_at_to_token.event_pos.to_room_stream_token(), - ), - state_filter=state_filter, - # Partially-stated rooms should have all state events except for - # the membership events and since we've already excluded - # partially-stated rooms unless `required_state` only has - # `["m.room.member", "$LAZY"]` for membership, we should be able - # to retrieve everything requested. Plus we don't want to block - # the whole sync waiting for this one room. - await_full_state=False, - ) - # Otherwise, we can get the latest current state in the room - else: - room_state = await self.storage_controllers.state.get_current_state( - room_id, - state_filter, - # Partially-stated rooms should have all state events except for - # the membership events and since we've already excluded - # partially-stated rooms unless `required_state` only has - # `["m.room.member", "$LAZY"]` for membership, we should be able - # to retrieve everything requested. Plus we don't want to block - # the whole sync waiting for this one room. - await_full_state=False, - ) - # TODO: Query `current_state_delta_stream` and reverse/rewind back to the `to_token` + required_state_filter = StateFilter.from_types(required_state_types) + + # We need this base set of info for the response so let's just fetch it along + # with the `required_state` for the room + META_ROOM_STATE = [(EventTypes.Name, ""), (EventTypes.RoomAvatar, "")] + state_filter = StateFilter( + types=StateFilter.from_types( + chain(META_ROOM_STATE, required_state_filter.to_types()) + ).types, + include_others=required_state_filter.include_others, + ) + + # We can return all of the state that was requested if this was the first + # time we've sent the room down this connection. + if initial: + # People shouldn't see past their leave/ban event + if room_membership_for_user_at_to_token.membership in ( + Membership.LEAVE, + Membership.BAN, + ): + room_state = await self.storage_controllers.state.get_state_at( + room_id, + stream_position=to_token.copy_and_replace( + StreamKeyType.ROOM, + room_membership_for_user_at_to_token.event_pos.to_room_stream_token(), + ), + state_filter=state_filter, + # Partially-stated rooms should have all state events except for + # remote membership events. Since we've already excluded + # partially-stated rooms unless `required_state` only has + # `["m.room.member", "$LAZY"]` for membership, we should be able to + # retrieve everything requested. When we're lazy-loading, if there + # are some remote senders in the timeline, we should also have their + # membership event because we had to auth that timeline event. Plus + # we don't want to block the whole sync waiting for this one room. + await_full_state=False, + ) + # Otherwise, we can get the latest current state in the room else: - # TODO: Once we can figure out if we've sent a room down this connection before, - # we can return updates instead of the full required state. - raise NotImplementedError() + room_state = await self.storage_controllers.state.get_current_state( + room_id, + state_filter, + # Partially-stated rooms should have all state events except for + # remote membership events. Since we've already excluded + # partially-stated rooms unless `required_state` only has + # `["m.room.member", "$LAZY"]` for membership, we should be able to + # retrieve everything requested. When we're lazy-loading, if there + # are some remote senders in the timeline, we should also have their + # membership event because we had to auth that timeline event. Plus + # we don't want to block the whole sync waiting for this one room. + await_full_state=False, + ) + # TODO: Query `current_state_delta_stream` and reverse/rewind back to the `to_token` + else: + # TODO: Once we can figure out if we've sent a room down this connection before, + # we can return updates instead of the full required state. + raise NotImplementedError() + + if required_state_filter != StateFilter.none(): + required_room_state = required_state_filter.filter_state(room_state) + + # Find the room name and avatar from the state + room_name: Optional[str] = None + room_avatar: Optional[str] = None + if room_state is not None: + name_event = room_state.get((EventTypes.Name, "")) + if name_event is not None: + room_name = name_event.content.get("name") + + avatar_event = room_state.get((EventTypes.RoomAvatar, "")) + if avatar_event is not None: + room_avatar = avatar_event.content.get("url") + elif stripped_state is not None: + for event in stripped_state: + if event["type"] == EventTypes.Name: + room_name = event.get("content", {}).get("name") + elif event["type"] == EventTypes.RoomAvatar: + room_avatar = event.get("content", {}).get("url") + + # Found everything so we can stop looking + if room_name is not None and room_avatar is not None: + break # Figure out the last bump event in the room last_bump_event_result = ( @@ -1378,16 +1421,16 @@ class SlidingSyncHandler: bump_stamp = bump_event_pos.stream return SlidingSyncResult.RoomResult( - # TODO: Dummy value - name=None, - # TODO: Dummy value - avatar=None, + name=room_name, + avatar=room_avatar, # TODO: Dummy value heroes=None, # TODO: Dummy value is_dm=False, initial=initial, - required_state=list(room_state.values()) if room_state else None, + required_state=( + list(required_room_state.values()) if required_room_state else None + ), timeline_events=timeline_events, bundled_aggregations=bundled_aggregations, stripped_state=stripped_state, @@ -1404,3 +1447,100 @@ class SlidingSyncHandler: notification_count=0, highlight_count=0, ) + + async def get_extensions_response( + self, + sync_config: SlidingSyncConfig, + to_token: StreamToken, + ) -> SlidingSyncResult.Extensions: + """Handle extension requests. + + Args: + sync_config: Sync configuration + to_token: The point in the stream to sync up to. + """ + + if sync_config.extensions is None: + return SlidingSyncResult.Extensions() + + to_device_response = None + if sync_config.extensions.to_device: + to_device_response = await self.get_to_device_extensions_response( + sync_config=sync_config, + to_device_request=sync_config.extensions.to_device, + to_token=to_token, + ) + + return SlidingSyncResult.Extensions(to_device=to_device_response) + + async def get_to_device_extensions_response( + self, + sync_config: SlidingSyncConfig, + to_device_request: SlidingSyncConfig.Extensions.ToDeviceExtension, + to_token: StreamToken, + ) -> SlidingSyncResult.Extensions.ToDeviceExtension: + """Handle to-device extension (MSC3885) + + Args: + sync_config: Sync configuration + to_device_request: The to-device extension from the request + to_token: The point in the stream to sync up to. + """ + + user_id = sync_config.user.to_string() + device_id = sync_config.device_id + + # Check that this request has a valid device ID (not all requests have + # to belong to a device, and so device_id is None), and that the + # extension is enabled. + if device_id is None or not to_device_request.enabled: + return SlidingSyncResult.Extensions.ToDeviceExtension( + next_batch=f"{to_token.to_device_key}", + events=[], + ) + + since_stream_id = 0 + if to_device_request.since is not None: + # We've already validated this is an int. + since_stream_id = int(to_device_request.since) + + if to_token.to_device_key < since_stream_id: + # The since token is ahead of our current token, so we return an + # empty response. + logger.warning( + "Got to-device.since from the future. since token: %r is ahead of our current to_device stream position: %r", + since_stream_id, + to_token.to_device_key, + ) + return SlidingSyncResult.Extensions.ToDeviceExtension( + next_batch=to_device_request.since, + events=[], + ) + + # Delete everything before the given since token, as we know the + # device must have received them. + deleted = await self.store.delete_messages_for_device( + user_id=user_id, + device_id=device_id, + up_to_stream_id=since_stream_id, + ) + + logger.debug( + "Deleted %d to-device messages up to %d for %s", + deleted, + since_stream_id, + user_id, + ) + + messages, stream_id = await self.store.get_messages_for_device( + user_id=user_id, + device_id=device_id, + from_stream_id=since_stream_id, + to_stream_id=to_token.to_device_key, + limit=min(to_device_request.limit, 100), # Limit to at most 100 events + ) + + return SlidingSyncResult.Extensions.ToDeviceExtension( + next_batch=f"{stream_id}", + events=messages, + ) diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py
index 13aed1dc85..94d5faf9f7 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py
@@ -942,7 +942,9 @@ class SlidingSyncRestServlet(RestServlet): response["rooms"] = await self.encode_rooms( requester, sliding_sync_result.rooms ) - response["extensions"] = {} # TODO: sliding_sync_result.extensions + response["extensions"] = await self.encode_extensions( + requester, sliding_sync_result.extensions + ) return response @@ -1054,6 +1056,19 @@ class SlidingSyncRestServlet(RestServlet): return serialized_rooms + async def encode_extensions( + self, requester: Requester, extensions: SlidingSyncResult.Extensions + ) -> JsonDict: + result = {} + + if extensions.to_device is not None: + result["to_device"] = { + "next_batch": extensions.to_device.next_batch, + "events": extensions.to_device.events, + } + + return result + def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: SyncRestServlet(hs).register(http_server) diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index b2a67aff89..5188b2f7a4 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py
@@ -41,7 +41,7 @@ from typing import ( import attr -from synapse.api.constants import EventTypes, Membership +from synapse.api.constants import EventContentFields, EventTypes, Membership from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion from synapse.events import EventBase @@ -298,6 +298,56 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): create_event = await self.get_event(create_id) return create_event + @cached(max_entries=10000) + async def get_room_type(self, room_id: str) -> Optional[str]: + """Get the room type for a given room. The server must be joined to the + given room. + """ + + row = await self.db_pool.simple_select_one( + table="room_stats_state", + keyvalues={"room_id": room_id}, + retcols=("room_type",), + allow_none=True, + desc="get_room_type", + ) + + if row is not None: + return row[0] + + # If we haven't updated `room_stats_state` with the room yet, query the + # create event directly. + create_event = await self.get_create_event_for_room(room_id) + room_type = create_event.content.get(EventContentFields.ROOM_TYPE) + return room_type + + @cachedList(cached_method_name="get_room_type", list_name="room_ids") + async def bulk_get_room_type( + self, room_ids: Set[str] + ) -> Mapping[str, Optional[str]]: + """Bulk fetch room types for the given rooms, the server must be in all + the rooms given. + """ + + rows = await self.db_pool.simple_select_many_batch( + table="room_stats_state", + column="room_id", + iterable=room_ids, + retcols=("room_id", "room_type"), + desc="bulk_get_room_type", + ) + + # If we haven't updated `room_stats_state` with the room yet, query the + # create events directly. This should happen only rarely so we don't + # mind if we do this in a loop. + results = dict(rows) + for room_id in room_ids - results.keys(): + create_event = await self.get_create_event_for_room(room_id) + room_type = create_event.content.get(EventContentFields.ROOM_TYPE) + results[room_id] = room_type + + return results + @cached(max_entries=100000, iterable=True) async def get_partial_current_state_ids(self, room_id: str) -> StateMap[str]: """Get the current state event ids for a room based on the diff --git a/synapse/types/handlers/__init__.py b/synapse/types/handlers/__init__.py
index 43dcdf20dd..a8a3a8f242 100644 --- a/synapse/types/handlers/__init__.py +++ b/synapse/types/handlers/__init__.py
@@ -18,7 +18,7 @@ # # from enum import Enum -from typing import TYPE_CHECKING, Dict, Final, List, Optional, Tuple +from typing import TYPE_CHECKING, Dict, Final, List, Optional, Sequence, Tuple import attr from typing_extensions import TypedDict @@ -252,10 +252,39 @@ class SlidingSyncResult: count: int ops: List[Operation] + @attr.s(slots=True, frozen=True, auto_attribs=True) + class Extensions: + """Responses for extensions + + Attributes: + to_device: The to-device extension (MSC3885) + """ + + @attr.s(slots=True, frozen=True, auto_attribs=True) + class ToDeviceExtension: + """The to-device extension (MSC3885) + + Attributes: + next_batch: The to-device stream token the client should use + to get more results + events: A list of to-device messages for the client + """ + + next_batch: str + events: Sequence[JsonMapping] + + def __bool__(self) -> bool: + return bool(self.events) + + to_device: Optional[ToDeviceExtension] = None + + def __bool__(self) -> bool: + return bool(self.to_device) + next_pos: StreamToken lists: Dict[str, SlidingWindowList] rooms: Dict[str, RoomResult] - extensions: JsonMapping + extensions: Extensions def __bool__(self) -> bool: """Make the result appear empty if there are no updates. This is used @@ -271,5 +300,5 @@ class SlidingSyncResult: next_pos=next_pos, lists={}, rooms={}, - extensions={}, + extensions=SlidingSyncResult.Extensions(), ) diff --git a/synapse/types/rest/client/__init__.py b/synapse/types/rest/client/__init__.py
index 55f6b44053..1e8fe76c99 100644 --- a/synapse/types/rest/client/__init__.py +++ b/synapse/types/rest/client/__init__.py
@@ -276,10 +276,48 @@ class SlidingSyncBody(RequestBodyModel): class RoomSubscription(CommonRoomParameters): pass - class Extension(RequestBodyModel): - enabled: Optional[StrictBool] = False - lists: Optional[List[StrictStr]] = None - rooms: Optional[List[StrictStr]] = None + class Extensions(RequestBodyModel): + """The extensions section of the request. + + Extensions MUST have an `enabled` flag which defaults to `false`. If a client + sends an unknown extension name, the server MUST ignore it (or else backwards + compatibility between clients and servers is broken when a newer client tries to + communicate with an older server). + """ + + class ToDeviceExtension(RequestBodyModel): + """The to-device extension (MSC3885) + + Attributes: + enabled + limit: Maximum number of to-device messages to return + since: The `next_batch` from the previous sync response + """ + + enabled: Optional[StrictBool] = False + limit: StrictInt = 100 + since: Optional[StrictStr] = None + + @validator("since") + def since_token_check( + cls, value: Optional[StrictStr] + ) -> Optional[StrictStr]: + # `since` comes in as an opaque string token but we know that it's just + # an integer representing the position in the device inbox stream. We + # want to pre-validate it to make sure it works fine in downstream code. + if value is None: + return value + + try: + int(value) + except ValueError: + raise ValueError( + "'extensions.to_device.since' is invalid (should look like an int)" + ) + + return value + + to_device: Optional[ToDeviceExtension] = None # mypy workaround via https://github.com/pydantic/pydantic/issues/156#issuecomment-1130883884 if TYPE_CHECKING: @@ -287,7 +325,7 @@ class SlidingSyncBody(RequestBodyModel): else: lists: Optional[Dict[constr(max_length=64, strict=True), SlidingSyncList]] = None # type: ignore[valid-type] room_subscriptions: Optional[Dict[StrictStr, RoomSubscription]] = None - extensions: Optional[Dict[StrictStr, Extension]] = None + extensions: Optional[Extensions] = None @validator("lists") def lists_length_check(