diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py
index 26b8711851..b80630c5d3 100644
--- a/synapse/api/ratelimiting.py
+++ b/synapse/api/ratelimiting.py
@@ -236,9 +236,8 @@ class Ratelimiter:
requester: The requester that is doing the action, if any.
key: An arbitrary key used to classify an action. Defaults to the
requester's user ID.
- n_actions: The number of times the user wants to do this action. If the user
- cannot do all of the actions, the user's action count is not incremented
- at all.
+ n_actions: The number of times the user performed the action. May be negative
+ to "refund" the rate limit.
_time_now_s: The current time. Optional, defaults to the current time according
to self.clock. Only used by tests.
"""
diff --git a/synapse/handlers/sliding_sync.py b/synapse/handlers/sliding_sync.py
index 8e2f751c02..8e6c2fb860 100644
--- a/synapse/handlers/sliding_sync.py
+++ b/synapse/handlers/sliding_sync.py
@@ -18,18 +18,13 @@
#
#
import logging
+from itertools import chain
from typing import TYPE_CHECKING, Any, Dict, Final, List, Optional, Set, Tuple
import attr
from immutabledict import immutabledict
-from synapse.api.constants import (
- AccountDataTypes,
- Direction,
- EventContentFields,
- EventTypes,
- Membership,
-)
+from synapse.api.constants import AccountDataTypes, Direction, EventTypes, Membership
from synapse.events import EventBase
from synapse.events.utils import strip_event
from synapse.handlers.relations import BundledAggregations
@@ -464,6 +459,7 @@ class SlidingSyncHandler:
membership_state_keys = room_sync_config.required_state_map.get(
EventTypes.Member
)
+ # Also see `StateFilter.must_await_full_state(...)` for comparison
lazy_loading = (
membership_state_keys is not None
and len(membership_state_keys) == 1
@@ -540,11 +536,15 @@ class SlidingSyncHandler:
rooms[room_id] = room_sync_result
+ extensions = await self.get_extensions_response(
+ sync_config=sync_config, to_token=to_token
+ )
+
return SlidingSyncResult(
next_pos=to_token,
lists=lists,
rooms=rooms,
- extensions={},
+ extensions=extensions,
)
async def get_sync_room_ids_for_user(
@@ -953,11 +953,15 @@ class SlidingSyncHandler:
# provided in the list. `None` is a valid type for rooms which do not have a
# room type.
if filters.room_types is not None or filters.not_room_types is not None:
- # Make a copy so we don't run into an error: `Set changed size during
- # iteration`, when we filter out and remove items
- for room_id in filtered_room_id_set.copy():
- create_event = await self.store.get_create_event_for_room(room_id)
- room_type = create_event.content.get(EventContentFields.ROOM_TYPE)
+ room_to_type = await self.store.bulk_get_room_type(
+ {
+ room_id
+ for room_id in filtered_room_id_set
+ # We only know the room types for joined rooms
+ if sync_room_map[room_id].membership == Membership.JOIN
+ }
+ )
+ for room_id, room_type in room_to_type.items():
if (
filters.room_types is not None
and room_type not in filters.room_types
@@ -1202,7 +1206,7 @@ class SlidingSyncHandler:
# Figure out any stripped state events for invite/knocks. This allows the
# potential joiner to identify the room.
- stripped_state: List[JsonDict] = []
+ stripped_state: Optional[List[JsonDict]] = None
if room_membership_for_user_at_to_token.membership in (
Membership.INVITE,
Membership.KNOCK,
@@ -1239,7 +1243,7 @@ class SlidingSyncHandler:
# updates.
initial = True
- # Fetch the required state for the room
+ # Fetch the `required_state` for the room
#
# No `required_state` for invite/knock rooms (just `stripped_state`)
#
@@ -1247,13 +1251,15 @@ class SlidingSyncHandler:
# of membership. Currently, we have to make this optional because
# `invite`/`knock` rooms only have `stripped_state`. See
# https://github.com/matrix-org/matrix-spec-proposals/pull/3575#discussion_r1653045932
+ #
+ # Calculate the `StateFilter` based on the `required_state` for the room
room_state: Optional[StateMap[EventBase]] = None
+ required_room_state: Optional[StateMap[EventBase]] = None
if room_membership_for_user_at_to_token.membership not in (
Membership.INVITE,
Membership.KNOCK,
):
- # Calculate the `StateFilter` based on the `required_state` for the room
- state_filter: Optional[StateFilter] = StateFilter.none()
+ required_state_filter = StateFilter.none()
# If we have a double wildcard ("*", "*") in the `required_state`, we need
# to fetch all state for the room
#
@@ -1276,7 +1282,7 @@ class SlidingSyncHandler:
if StateValues.WILDCARD in room_sync_config.required_state_map.get(
StateValues.WILDCARD, set()
):
- state_filter = StateFilter.all()
+ required_state_filter = StateFilter.all()
# TODO: `StateFilter` currently doesn't support wildcard event types. We're
# currently working around this by returning all state to the client but it
# would be nice to fetch less from the database and return just what the
@@ -1285,7 +1291,7 @@ class SlidingSyncHandler:
room_sync_config.required_state_map.get(StateValues.WILDCARD)
is not None
):
- state_filter = StateFilter.all()
+ required_state_filter = StateFilter.all()
else:
required_state_types: List[Tuple[str, Optional[str]]] = []
for (
@@ -1317,51 +1323,88 @@ class SlidingSyncHandler:
else:
required_state_types.append((state_type, state_key))
- state_filter = StateFilter.from_types(required_state_types)
-
- # We can skip fetching state if we don't need any
- if state_filter != StateFilter.none():
- # We can return all of the state that was requested if we're doing an
- # initial sync
- if initial:
- # People shouldn't see past their leave/ban event
- if room_membership_for_user_at_to_token.membership in (
- Membership.LEAVE,
- Membership.BAN,
- ):
- room_state = await self.storage_controllers.state.get_state_at(
- room_id,
- stream_position=to_token.copy_and_replace(
- StreamKeyType.ROOM,
- room_membership_for_user_at_to_token.event_pos.to_room_stream_token(),
- ),
- state_filter=state_filter,
- # Partially-stated rooms should have all state events except for
- # the membership events and since we've already excluded
- # partially-stated rooms unless `required_state` only has
- # `["m.room.member", "$LAZY"]` for membership, we should be able
- # to retrieve everything requested. Plus we don't want to block
- # the whole sync waiting for this one room.
- await_full_state=False,
- )
- # Otherwise, we can get the latest current state in the room
- else:
- room_state = await self.storage_controllers.state.get_current_state(
- room_id,
- state_filter,
- # Partially-stated rooms should have all state events except for
- # the membership events and since we've already excluded
- # partially-stated rooms unless `required_state` only has
- # `["m.room.member", "$LAZY"]` for membership, we should be able
- # to retrieve everything requested. Plus we don't want to block
- # the whole sync waiting for this one room.
- await_full_state=False,
- )
- # TODO: Query `current_state_delta_stream` and reverse/rewind back to the `to_token`
+ required_state_filter = StateFilter.from_types(required_state_types)
+
+ # We need this base set of info for the response so let's just fetch it along
+ # with the `required_state` for the room
+ META_ROOM_STATE = [(EventTypes.Name, ""), (EventTypes.RoomAvatar, "")]
+ state_filter = StateFilter(
+ types=StateFilter.from_types(
+ chain(META_ROOM_STATE, required_state_filter.to_types())
+ ).types,
+ include_others=required_state_filter.include_others,
+ )
+
+ # We can return all of the state that was requested if this was the first
+ # time we've sent the room down this connection.
+ if initial:
+ # People shouldn't see past their leave/ban event
+ if room_membership_for_user_at_to_token.membership in (
+ Membership.LEAVE,
+ Membership.BAN,
+ ):
+ room_state = await self.storage_controllers.state.get_state_at(
+ room_id,
+ stream_position=to_token.copy_and_replace(
+ StreamKeyType.ROOM,
+ room_membership_for_user_at_to_token.event_pos.to_room_stream_token(),
+ ),
+ state_filter=state_filter,
+ # Partially-stated rooms should have all state events except for
+ # remote membership events. Since we've already excluded
+ # partially-stated rooms unless `required_state` only has
+ # `["m.room.member", "$LAZY"]` for membership, we should be able to
+ # retrieve everything requested. When we're lazy-loading, if there
+ # are some remote senders in the timeline, we should also have their
+ # membership event because we had to auth that timeline event. Plus
+ # we don't want to block the whole sync waiting for this one room.
+ await_full_state=False,
+ )
+ # Otherwise, we can get the latest current state in the room
else:
- # TODO: Once we can figure out if we've sent a room down this connection before,
- # we can return updates instead of the full required state.
- raise NotImplementedError()
+ room_state = await self.storage_controllers.state.get_current_state(
+ room_id,
+ state_filter,
+ # Partially-stated rooms should have all state events except for
+ # remote membership events. Since we've already excluded
+ # partially-stated rooms unless `required_state` only has
+ # `["m.room.member", "$LAZY"]` for membership, we should be able to
+ # retrieve everything requested. When we're lazy-loading, if there
+ # are some remote senders in the timeline, we should also have their
+ # membership event because we had to auth that timeline event. Plus
+ # we don't want to block the whole sync waiting for this one room.
+ await_full_state=False,
+ )
+ # TODO: Query `current_state_delta_stream` and reverse/rewind back to the `to_token`
+ else:
+ # TODO: Once we can figure out if we've sent a room down this connection before,
+ # we can return updates instead of the full required state.
+ raise NotImplementedError()
+
+ if required_state_filter != StateFilter.none():
+ required_room_state = required_state_filter.filter_state(room_state)
+
+ # Find the room name and avatar from the state
+ room_name: Optional[str] = None
+ room_avatar: Optional[str] = None
+ if room_state is not None:
+ name_event = room_state.get((EventTypes.Name, ""))
+ if name_event is not None:
+ room_name = name_event.content.get("name")
+
+ avatar_event = room_state.get((EventTypes.RoomAvatar, ""))
+ if avatar_event is not None:
+ room_avatar = avatar_event.content.get("url")
+ elif stripped_state is not None:
+ for event in stripped_state:
+ if event["type"] == EventTypes.Name:
+ room_name = event.get("content", {}).get("name")
+ elif event["type"] == EventTypes.RoomAvatar:
+ room_avatar = event.get("content", {}).get("url")
+
+ # Found everything so we can stop looking
+ if room_name is not None and room_avatar is not None:
+ break
# Figure out the last bump event in the room
last_bump_event_result = (
@@ -1378,16 +1421,16 @@ class SlidingSyncHandler:
bump_stamp = bump_event_pos.stream
return SlidingSyncResult.RoomResult(
- # TODO: Dummy value
- name=None,
- # TODO: Dummy value
- avatar=None,
+ name=room_name,
+ avatar=room_avatar,
# TODO: Dummy value
heroes=None,
# TODO: Dummy value
is_dm=False,
initial=initial,
- required_state=list(room_state.values()) if room_state else None,
+ required_state=(
+ list(required_room_state.values()) if required_room_state else None
+ ),
timeline_events=timeline_events,
bundled_aggregations=bundled_aggregations,
stripped_state=stripped_state,
@@ -1404,3 +1447,100 @@ class SlidingSyncHandler:
notification_count=0,
highlight_count=0,
)
+
+ async def get_extensions_response(
+ self,
+ sync_config: SlidingSyncConfig,
+ to_token: StreamToken,
+ ) -> SlidingSyncResult.Extensions:
+ """Handle extension requests.
+
+ Args:
+ sync_config: Sync configuration
+ to_token: The point in the stream to sync up to.
+ """
+
+ if sync_config.extensions is None:
+ return SlidingSyncResult.Extensions()
+
+ to_device_response = None
+ if sync_config.extensions.to_device:
+ to_device_response = await self.get_to_device_extensions_response(
+ sync_config=sync_config,
+ to_device_request=sync_config.extensions.to_device,
+ to_token=to_token,
+ )
+
+ return SlidingSyncResult.Extensions(to_device=to_device_response)
+
+ async def get_to_device_extensions_response(
+ self,
+ sync_config: SlidingSyncConfig,
+ to_device_request: SlidingSyncConfig.Extensions.ToDeviceExtension,
+ to_token: StreamToken,
+ ) -> SlidingSyncResult.Extensions.ToDeviceExtension:
+ """Handle to-device extension (MSC3885)
+
+ Args:
+ sync_config: Sync configuration
+ to_device_request: The to-device extension from the request
+ to_token: The point in the stream to sync up to.
+ """
+
+ user_id = sync_config.user.to_string()
+ device_id = sync_config.device_id
+
+ # Check that this request has a valid device ID (not all requests have
+ # to belong to a device, and so device_id is None), and that the
+ # extension is enabled.
+ if device_id is None or not to_device_request.enabled:
+ return SlidingSyncResult.Extensions.ToDeviceExtension(
+ next_batch=f"{to_token.to_device_key}",
+ events=[],
+ )
+
+ since_stream_id = 0
+ if to_device_request.since is not None:
+ # We've already validated this is an int.
+ since_stream_id = int(to_device_request.since)
+
+ if to_token.to_device_key < since_stream_id:
+ # The since token is ahead of our current token, so we return an
+ # empty response.
+ logger.warning(
+ "Got to-device.since from the future. since token: %r is ahead of our current to_device stream position: %r",
+ since_stream_id,
+ to_token.to_device_key,
+ )
+ return SlidingSyncResult.Extensions.ToDeviceExtension(
+ next_batch=to_device_request.since,
+ events=[],
+ )
+
+ # Delete everything before the given since token, as we know the
+ # device must have received them.
+ deleted = await self.store.delete_messages_for_device(
+ user_id=user_id,
+ device_id=device_id,
+ up_to_stream_id=since_stream_id,
+ )
+
+ logger.debug(
+ "Deleted %d to-device messages up to %d for %s",
+ deleted,
+ since_stream_id,
+ user_id,
+ )
+
+ messages, stream_id = await self.store.get_messages_for_device(
+ user_id=user_id,
+ device_id=device_id,
+ from_stream_id=since_stream_id,
+ to_stream_id=to_token.to_device_key,
+ limit=min(to_device_request.limit, 100), # Limit to at most 100 events
+ )
+
+ return SlidingSyncResult.Extensions.ToDeviceExtension(
+ next_batch=f"{stream_id}",
+ events=messages,
+ )
diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py
index 13aed1dc85..94d5faf9f7 100644
--- a/synapse/rest/client/sync.py
+++ b/synapse/rest/client/sync.py
@@ -942,7 +942,9 @@ class SlidingSyncRestServlet(RestServlet):
response["rooms"] = await self.encode_rooms(
requester, sliding_sync_result.rooms
)
- response["extensions"] = {} # TODO: sliding_sync_result.extensions
+ response["extensions"] = await self.encode_extensions(
+ requester, sliding_sync_result.extensions
+ )
return response
@@ -1054,6 +1056,19 @@ class SlidingSyncRestServlet(RestServlet):
return serialized_rooms
+ async def encode_extensions(
+ self, requester: Requester, extensions: SlidingSyncResult.Extensions
+ ) -> JsonDict:
+ result = {}
+
+ if extensions.to_device is not None:
+ result["to_device"] = {
+ "next_batch": extensions.to_device.next_batch,
+ "events": extensions.to_device.events,
+ }
+
+ return result
+
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
SyncRestServlet(hs).register(http_server)
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index b2a67aff89..5188b2f7a4 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -41,7 +41,7 @@ from typing import (
import attr
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import EventContentFields, EventTypes, Membership
from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
from synapse.events import EventBase
@@ -298,6 +298,56 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
create_event = await self.get_event(create_id)
return create_event
+ @cached(max_entries=10000)
+ async def get_room_type(self, room_id: str) -> Optional[str]:
+ """Get the room type for a given room. The server must be joined to the
+ given room.
+ """
+
+ row = await self.db_pool.simple_select_one(
+ table="room_stats_state",
+ keyvalues={"room_id": room_id},
+ retcols=("room_type",),
+ allow_none=True,
+ desc="get_room_type",
+ )
+
+ if row is not None:
+ return row[0]
+
+ # If we haven't updated `room_stats_state` with the room yet, query the
+ # create event directly.
+ create_event = await self.get_create_event_for_room(room_id)
+ room_type = create_event.content.get(EventContentFields.ROOM_TYPE)
+ return room_type
+
+ @cachedList(cached_method_name="get_room_type", list_name="room_ids")
+ async def bulk_get_room_type(
+ self, room_ids: Set[str]
+ ) -> Mapping[str, Optional[str]]:
+ """Bulk fetch room types for the given rooms, the server must be in all
+ the rooms given.
+ """
+
+ rows = await self.db_pool.simple_select_many_batch(
+ table="room_stats_state",
+ column="room_id",
+ iterable=room_ids,
+ retcols=("room_id", "room_type"),
+ desc="bulk_get_room_type",
+ )
+
+ # If we haven't updated `room_stats_state` with the room yet, query the
+ # create events directly. This should happen only rarely so we don't
+ # mind if we do this in a loop.
+ results = dict(rows)
+ for room_id in room_ids - results.keys():
+ create_event = await self.get_create_event_for_room(room_id)
+ room_type = create_event.content.get(EventContentFields.ROOM_TYPE)
+ results[room_id] = room_type
+
+ return results
+
@cached(max_entries=100000, iterable=True)
async def get_partial_current_state_ids(self, room_id: str) -> StateMap[str]:
"""Get the current state event ids for a room based on the
diff --git a/synapse/types/handlers/__init__.py b/synapse/types/handlers/__init__.py
index 43dcdf20dd..a8a3a8f242 100644
--- a/synapse/types/handlers/__init__.py
+++ b/synapse/types/handlers/__init__.py
@@ -18,7 +18,7 @@
#
#
from enum import Enum
-from typing import TYPE_CHECKING, Dict, Final, List, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, Final, List, Optional, Sequence, Tuple
import attr
from typing_extensions import TypedDict
@@ -252,10 +252,39 @@ class SlidingSyncResult:
count: int
ops: List[Operation]
+ @attr.s(slots=True, frozen=True, auto_attribs=True)
+ class Extensions:
+ """Responses for extensions
+
+ Attributes:
+ to_device: The to-device extension (MSC3885)
+ """
+
+ @attr.s(slots=True, frozen=True, auto_attribs=True)
+ class ToDeviceExtension:
+ """The to-device extension (MSC3885)
+
+ Attributes:
+ next_batch: The to-device stream token the client should use
+ to get more results
+ events: A list of to-device messages for the client
+ """
+
+ next_batch: str
+ events: Sequence[JsonMapping]
+
+ def __bool__(self) -> bool:
+ return bool(self.events)
+
+ to_device: Optional[ToDeviceExtension] = None
+
+ def __bool__(self) -> bool:
+ return bool(self.to_device)
+
next_pos: StreamToken
lists: Dict[str, SlidingWindowList]
rooms: Dict[str, RoomResult]
- extensions: JsonMapping
+ extensions: Extensions
def __bool__(self) -> bool:
"""Make the result appear empty if there are no updates. This is used
@@ -271,5 +300,5 @@ class SlidingSyncResult:
next_pos=next_pos,
lists={},
rooms={},
- extensions={},
+ extensions=SlidingSyncResult.Extensions(),
)
diff --git a/synapse/types/rest/client/__init__.py b/synapse/types/rest/client/__init__.py
index 55f6b44053..1e8fe76c99 100644
--- a/synapse/types/rest/client/__init__.py
+++ b/synapse/types/rest/client/__init__.py
@@ -276,10 +276,48 @@ class SlidingSyncBody(RequestBodyModel):
class RoomSubscription(CommonRoomParameters):
pass
- class Extension(RequestBodyModel):
- enabled: Optional[StrictBool] = False
- lists: Optional[List[StrictStr]] = None
- rooms: Optional[List[StrictStr]] = None
+ class Extensions(RequestBodyModel):
+ """The extensions section of the request.
+
+ Extensions MUST have an `enabled` flag which defaults to `false`. If a client
+ sends an unknown extension name, the server MUST ignore it (or else backwards
+ compatibility between clients and servers is broken when a newer client tries to
+ communicate with an older server).
+ """
+
+ class ToDeviceExtension(RequestBodyModel):
+ """The to-device extension (MSC3885)
+
+ Attributes:
+ enabled
+ limit: Maximum number of to-device messages to return
+ since: The `next_batch` from the previous sync response
+ """
+
+ enabled: Optional[StrictBool] = False
+ limit: StrictInt = 100
+ since: Optional[StrictStr] = None
+
+ @validator("since")
+ def since_token_check(
+ cls, value: Optional[StrictStr]
+ ) -> Optional[StrictStr]:
+ # `since` comes in as an opaque string token but we know that it's just
+ # an integer representing the position in the device inbox stream. We
+ # want to pre-validate it to make sure it works fine in downstream code.
+ if value is None:
+ return value
+
+ try:
+ int(value)
+ except ValueError:
+ raise ValueError(
+ "'extensions.to_device.since' is invalid (should look like an int)"
+ )
+
+ return value
+
+ to_device: Optional[ToDeviceExtension] = None
# mypy workaround via https://github.com/pydantic/pydantic/issues/156#issuecomment-1130883884
if TYPE_CHECKING:
@@ -287,7 +325,7 @@ class SlidingSyncBody(RequestBodyModel):
else:
lists: Optional[Dict[constr(max_length=64, strict=True), SlidingSyncList]] = None # type: ignore[valid-type]
room_subscriptions: Optional[Dict[StrictStr, RoomSubscription]] = None
- extensions: Optional[Dict[StrictStr, Extension]] = None
+ extensions: Optional[Extensions] = None
@validator("lists")
def lists_length_check(
|