diff --git a/synapse/handlers/sliding_sync.py b/synapse/handlers/sliding_sync.py
index c615cc7c32..64b5acbe98 100644
--- a/synapse/handlers/sliding_sync.py
+++ b/synapse/handlers/sliding_sync.py
@@ -29,6 +29,7 @@ from typing import (
Callable,
Dict,
Final,
+ Generic,
List,
Literal,
Mapping,
@@ -37,6 +38,7 @@ from typing import (
Sequence,
Set,
Tuple,
+ TypeVar,
Union,
cast,
)
@@ -55,6 +57,7 @@ from synapse.api.constants import (
from synapse.api.errors import SlidingSyncUnknownPosition
from synapse.events import EventBase, StrippedStateEvent
from synapse.events.utils import parse_stripped_state_event, strip_event
+from synapse.handlers.receipts import ReceiptEventSource
from synapse.handlers.relations import BundledAggregations
from synapse.logging.opentracing import (
SynapseTags,
@@ -821,7 +824,7 @@ class SlidingSyncHandler:
async def handle_room(room_id: str) -> None:
room_sync_result = await self.get_room_sync_data(
sync_config=sync_config,
- per_connection_state=previous_connection_state,
+ previous_connection_state=previous_connection_state,
room_id=room_id,
room_sync_config=relevant_rooms_to_send_map[room_id],
room_membership_for_user_at_to_token=room_membership_for_user_map[
@@ -839,9 +842,13 @@ class SlidingSyncHandler:
with start_active_span("sliding_sync.generate_room_entries"):
await concurrently_execute(handle_room, relevant_rooms_to_send_map, 10)
+ new_connection_state = previous_connection_state.get_mutable()
+
extensions = await self.get_extensions_response(
sync_config=sync_config,
actual_lists=lists,
+ previous_connection_state=previous_connection_state,
+ new_connection_state=new_connection_state,
# We're purposely using `relevant_room_map` instead of
# `relevant_rooms_to_send_map` here. This needs to be all room_ids we could
# send regardless of whether they have an event update or not. The
@@ -854,8 +861,6 @@ class SlidingSyncHandler:
)
if has_lists or has_room_subscriptions:
- new_connection_state = previous_connection_state.get_mutable()
-
# We now calculate if any rooms outside the range have had updates,
# which we are not sending down.
#
@@ -886,7 +891,7 @@ class SlidingSyncHandler:
unsent_room_ids = list(missing_event_map_by_room)
new_connection_state.rooms.record_unsent_rooms(
- unsent_room_ids, from_token.stream_token
+ unsent_room_ids, from_token.stream_token.room_key
)
new_connection_state.rooms.record_sent_rooms(
@@ -896,7 +901,7 @@ class SlidingSyncHandler:
connection_position = await self.connection_store.record_new_state(
sync_config=sync_config,
from_token=from_token,
- per_connection_state=new_connection_state,
+ new_connection_state=new_connection_state,
)
elif from_token:
connection_position = from_token.connection_position
@@ -1949,7 +1954,7 @@ class SlidingSyncHandler:
async def get_room_sync_data(
self,
sync_config: SlidingSyncConfig,
- per_connection_state: "PerConnectionState",
+ previous_connection_state: "PerConnectionState",
room_id: str,
room_sync_config: RoomSyncConfig,
room_membership_for_user_at_to_token: _RoomMembershipForUser,
@@ -1997,7 +2002,7 @@ class SlidingSyncHandler:
from_bound = None
initial = True
if from_token and not room_membership_for_user_at_to_token.newly_joined:
- room_status = per_connection_state.rooms.have_sent_room(room_id)
+ room_status = previous_connection_state.rooms.have_sent_room(room_id)
if room_status.status == HaveSentRoomFlag.LIVE:
from_bound = from_token.stream_token.room_key
initial = False
@@ -2476,6 +2481,8 @@ class SlidingSyncHandler:
async def get_extensions_response(
self,
sync_config: SlidingSyncConfig,
+ previous_connection_state: "PerConnectionState",
+ new_connection_state: "MutablePerConnectionState",
actual_lists: Dict[str, SlidingSyncResult.SlidingWindowList],
actual_room_ids: Set[str],
actual_room_response_map: Dict[str, SlidingSyncResult.RoomResult],
@@ -2486,6 +2493,9 @@ class SlidingSyncHandler:
Args:
sync_config: Sync configuration
+ previous_connection_state: Snapshot of the current per-connection state
+ new_connection_state: A mutable copy of the per-connection
+ state, used to record updates to the state during this request.
actual_lists: Sliding window API. A map of list key to list results in the
Sliding Sync response.
actual_room_ids: The actual room IDs in the the Sliding Sync response.
@@ -2530,6 +2540,8 @@ class SlidingSyncHandler:
if sync_config.extensions.receipts is not None:
receipts_response = await self.get_receipts_extension_response(
sync_config=sync_config,
+ previous_connection_state=previous_connection_state,
+ new_connection_state=new_connection_state,
actual_lists=actual_lists,
actual_room_ids=actual_room_ids,
actual_room_response_map=actual_room_response_map,
@@ -2849,6 +2861,8 @@ class SlidingSyncHandler:
async def get_receipts_extension_response(
self,
sync_config: SlidingSyncConfig,
+ previous_connection_state: "PerConnectionState",
+ new_connection_state: "MutablePerConnectionState",
actual_lists: Dict[str, SlidingSyncResult.SlidingWindowList],
actual_room_ids: Set[str],
actual_room_response_map: Dict[str, SlidingSyncResult.RoomResult],
@@ -2860,6 +2874,9 @@ class SlidingSyncHandler:
Args:
sync_config: Sync configuration
+ previous_connection_state: The current per-connection state
+ new_connection_state: A mutable copy of the per-connection
+ state, used to record updates to the state.
actual_lists: Sliding window API. A map of list key to list results in the
Sliding Sync response.
actual_room_ids: The actual room IDs in the the Sliding Sync response.
@@ -2882,50 +2899,145 @@ class SlidingSyncHandler:
room_id_to_receipt_map: Dict[str, JsonMapping] = {}
if len(relevant_room_ids) > 0:
- # TODO: Take connection tracking into account so that when a room comes back
- # into range we can send the receipts that were missed.
- receipt_source = self.event_sources.sources.receipt
- receipts, _ = await receipt_source.get_new_events(
- user=sync_config.user,
- from_key=(
- from_token.stream_token.receipt_key
- if from_token
- else MultiWriterStreamToken(stream=0)
- ),
- to_key=to_token.receipt_key,
- # This is a dummy value and isn't used in the function
- limit=0,
- room_ids=relevant_room_ids,
- is_guest=False,
+ # We need to handle the different cases depending on if we have sent
+ # down receipts previously or not, so we split the relevant rooms
+ # up into different collections based on status.
+ live_rooms = set()
+ previously_rooms: Dict[str, MultiWriterStreamToken] = {}
+ initial_rooms = set()
+
+ for room_id in relevant_room_ids:
+ if not from_token:
+ initial_rooms.add(room_id)
+ continue
+
+ # If we're sending down the room from scratch again for some reason, we
+ # should always resend the receipts as well (regardless of if
+ # we've sent them down before). This is to mimic the behaviour
+ # of what happens on initial sync, where you get a chunk of
+ # timeline with all of the corresponding receipts for the events in the timeline.
+ room_result = actual_room_response_map.get(room_id)
+ if room_result is not None and room_result.initial:
+ initial_rooms.add(room_id)
+ continue
+
+ room_status = previous_connection_state.receipts.have_sent_room(room_id)
+ if room_status.status == HaveSentRoomFlag.LIVE:
+ live_rooms.add(room_id)
+ elif room_status.status == HaveSentRoomFlag.PREVIOUSLY:
+ assert room_status.last_token is not None
+ previously_rooms[room_id] = room_status.last_token
+ elif room_status.status == HaveSentRoomFlag.NEVER:
+ initial_rooms.add(room_id)
+ else:
+ assert_never(room_status.status)
+
+ # The set of receipts that we fetched. Private receipts need to be
+ # filtered out before returning.
+ fetched_receipts = []
+
+ # For live rooms we just fetch all receipts in those rooms since the
+ # `since` token.
+ if live_rooms:
+ assert from_token is not None
+ receipts = await self.store.get_linearized_receipts_for_rooms(
+ room_ids=live_rooms,
+ from_key=from_token.stream_token.receipt_key,
+ to_key=to_token.receipt_key,
+ )
+ fetched_receipts.extend(receipts)
+
+ # For rooms we've previously sent down, but aren't up to date, we
+ # need to use the from token from the room status.
+ if previously_rooms:
+ for room_id, receipt_token in previously_rooms.items():
+ # TODO: Limit the number of receipts we're about to send down
+ # for the room; if it's too many we should TODO
+ previously_receipts = (
+ await self.store.get_linearized_receipts_for_room(
+ room_id=room_id,
+ from_key=receipt_token,
+ to_key=to_token.receipt_key,
+ )
+ )
+ fetched_receipts.extend(previously_receipts)
+
+ # For rooms we haven't previously sent down, we could send all receipts
+ # from that room but we only want to include receipts for events
+ # in the timeline to avoid bloating and blowing up the sync response
+ # as the number of users in the room increases. (this behavior is part of the spec)
+ for room_id in initial_rooms:
+ room_result = actual_room_response_map.get(room_id)
+ if room_result is None:
+ continue
+
+ relevant_event_ids = [
+ event.event_id for event in room_result.timeline_events
+ ]
+
+ # TODO: In the future, it would be good to fetch less receipts
+ # out of the database in the first place but we would need to
+ # add a new `event_id` index to `receipts_linearized`.
+ initial_receipts = await self.store.get_linearized_receipts_for_room(
+ room_id=room_id,
+ to_key=to_token.receipt_key,
+ )
+
+ for receipt in initial_receipts:
+ content = {
+ event_id: content_value
+ for event_id, content_value in receipt["content"].items()
+ if event_id in relevant_event_ids
+ }
+ if content:
+ fetched_receipts.append(
+ {
+ "type": receipt["type"],
+ "room_id": receipt["room_id"],
+ "content": content,
+ }
+ )
+
+ fetched_receipts = ReceiptEventSource.filter_out_private_receipts(
+ fetched_receipts, sync_config.user.to_string()
)
- for receipt in receipts:
+ for receipt in fetched_receipts:
# These fields should exist for every receipt
room_id = receipt["room_id"]
type = receipt["type"]
content = receipt["content"]
- # For `inital: True` rooms, we only want to include receipts for events
- # in the timeline.
- room_result = actual_room_response_map.get(room_id)
- if room_result is not None:
- if room_result.initial:
- # TODO: In the future, it would be good to fetch less receipts
- # out of the database in the first place but we would need to
- # add a new `event_id` index to `receipts_linearized`.
- relevant_event_ids = [
- event.event_id for event in room_result.timeline_events
- ]
-
- assert isinstance(content, dict)
- content = {
- event_id: content_value
- for event_id, content_value in content.items()
- if event_id in relevant_event_ids
- }
-
room_id_to_receipt_map[room_id] = {"type": type, "content": content}
+ # Now we update the per-connection state to track which receipts we have
+ # and haven't sent down.
+ new_connection_state.receipts.record_sent_rooms(relevant_room_ids)
+
+ if from_token:
+ # Now find the set of rooms that may have receipts that we're not sending
+ # down. We only need to check rooms that we have previously returned
+ # receipts for (in `previous_connection_state`) because we only care about
+ # updating `LIVE` rooms to `PREVIOUSLY`. The `PREVIOUSLY` rooms will just
+ # stay pointing at their previous position so we don't need to waste time
+ # checking those and since we default to `NEVER`, rooms that were `NEVER`
+ # sent before don't need to be recorded as we'll handle them correctly when
+ # they come into range for the first time.
+ rooms_no_receipts = [
+ room_id
+ for room_id, room_status in previous_connection_state.receipts._statuses.items()
+ if room_status.status == HaveSentRoomFlag.LIVE
+ and room_id not in relevant_room_ids
+ ]
+ changed_rooms = await self.store.get_rooms_with_receipts_between(
+ rooms_no_receipts,
+ from_key=from_token.stream_token.receipt_key,
+ to_key=to_token.receipt_key,
+ )
+ new_connection_state.receipts.record_unsent_rooms(
+ changed_rooms, from_token.stream_token.receipt_key
+ )
+
return SlidingSyncResult.Extensions.ReceiptsExtension(
room_id_to_receipt_map=room_id_to_receipt_map,
)
@@ -3016,9 +3128,15 @@ class HaveSentRoomFlag(Enum):
LIVE = 3
+T = TypeVar("T")
+
+
@attr.s(auto_attribs=True, slots=True, frozen=True)
-class HaveSentRoom:
- """Whether we have sent the room down a sliding sync connection.
+class HaveSentRoom(Generic[T]):
+ """Whether we have sent the room data down a sliding sync connection.
+
+ We are generic over the type of token used, e.g. `RoomStreamToken` or
+ `MultiWriterStreamToken`.
Attributes:
status: Flag of if we have or haven't sent down the room
@@ -3029,54 +3147,58 @@ class HaveSentRoom:
"""
status: HaveSentRoomFlag
- last_token: Optional[RoomStreamToken]
+ last_token: Optional[T]
@staticmethod
- def previously(last_token: RoomStreamToken) -> "HaveSentRoom":
+ def live() -> "HaveSentRoom[T]":
+ return HaveSentRoom(HaveSentRoomFlag.LIVE, None)
+
+ @staticmethod
+ def previously(last_token: T) -> "HaveSentRoom[T]":
"""Constructor for `PREVIOUSLY` flag."""
return HaveSentRoom(HaveSentRoomFlag.PREVIOUSLY, last_token)
-
-HAVE_SENT_ROOM_NEVER = HaveSentRoom(HaveSentRoomFlag.NEVER, None)
-HAVE_SENT_ROOM_LIVE = HaveSentRoom(HaveSentRoomFlag.LIVE, None)
+ @staticmethod
+ def never() -> "HaveSentRoom[T]":
+ return HaveSentRoom(HaveSentRoomFlag.NEVER, None)
@attr.s(auto_attribs=True, slots=True, frozen=True)
-class RoomStatusMap:
+class RoomStatusMap(Generic[T]):
"""For a given stream, e.g. events, records what we have or have not sent
down for that stream in a given room."""
# `room_id` -> `HaveSentRoom`
- _statuses: Mapping[str, HaveSentRoom] = attr.Factory(dict)
+ _statuses: Mapping[str, HaveSentRoom[T]] = attr.Factory(dict)
- def have_sent_room(self, room_id: str) -> HaveSentRoom:
+ def have_sent_room(self, room_id: str) -> HaveSentRoom[T]:
"""Return whether we have previously sent the room down"""
- return self._statuses.get(room_id, HAVE_SENT_ROOM_NEVER)
+ return self._statuses.get(room_id, HaveSentRoom.never())
- def get_mutable(self) -> "MutableRoomStatusMap":
+ def get_mutable(self) -> "MutableRoomStatusMap[T]":
"""Get a mutable copy of this state."""
return MutableRoomStatusMap(
statuses=self._statuses,
)
- def copy(self) -> "RoomStatusMap":
+ def copy(self) -> "RoomStatusMap[T]":
"""Make a copy of the class. Useful for converting from a mutable to
immutable version."""
return RoomStatusMap(statuses=dict(self._statuses))
-class MutableRoomStatusMap(RoomStatusMap):
+class MutableRoomStatusMap(RoomStatusMap[T]):
"""A mutable version of `RoomStatusMap`"""
# We use a ChainMap here so that we can easily track what has been updated
# and what hasn't. Note that when we persist the per connection state this
# will get flattened to a normal dict (via calling `.copy()`)
- _statuses: typing.ChainMap[str, HaveSentRoom]
+ _statuses: typing.ChainMap[str, HaveSentRoom[T]]
def __init__(
self,
- statuses: Mapping[str, HaveSentRoom],
+ statuses: Mapping[str, HaveSentRoom[T]],
) -> None:
# ChainMap requires a mutable mapping, but we're not actually going to
# mutate it.
@@ -3086,22 +3208,20 @@ class MutableRoomStatusMap(RoomStatusMap):
statuses=ChainMap({}, statuses),
)
- def get_updates(self) -> Mapping[str, HaveSentRoom]:
+ def get_updates(self) -> Mapping[str, HaveSentRoom[T]]:
"""Return only the changes that were made"""
return self._statuses.maps[0]
def record_sent_rooms(self, room_ids: StrCollection) -> None:
"""Record that we have sent these rooms in the response"""
for room_id in room_ids:
- current_status = self._statuses.get(room_id, HAVE_SENT_ROOM_NEVER)
+ current_status = self._statuses.get(room_id, HaveSentRoom.never())
if current_status.status == HaveSentRoomFlag.LIVE:
continue
- self._statuses[room_id] = HAVE_SENT_ROOM_LIVE
+ self._statuses[room_id] = HaveSentRoom.live()
- def record_unsent_rooms(
- self, room_ids: StrCollection, from_token: StreamToken
- ) -> None:
+ def record_unsent_rooms(self, room_ids: StrCollection, from_token: T) -> None:
"""Record that we have not sent these rooms in the response, but there
have been updates.
"""
@@ -3116,33 +3236,42 @@ class MutableRoomStatusMap(RoomStatusMap):
# sent anything down this time either so we leave it as NEVER.
for room_id in room_ids:
- current_status = self._statuses.get(room_id, HAVE_SENT_ROOM_NEVER)
+ current_status = self._statuses.get(room_id, HaveSentRoom.never())
if current_status.status != HaveSentRoomFlag.LIVE:
continue
- self._statuses[room_id] = HaveSentRoom.previously(from_token.room_key)
+ self._statuses[room_id] = HaveSentRoom.previously(from_token)
@attr.s(auto_attribs=True)
class PerConnectionState:
- """The per-connection state. A snapshot of what we've sent down the connection before.
+ """The per-connection state. A snapshot of what we've sent down the
+ connection before.
- Currently, we track whether we've sent down various aspects of a given room before.
+ Currently, we track whether we've sent down various aspects of a given room
+ before.
- We use the `rooms` field to store the position in the events stream for each room that we've previously sent to the client before. On the next request that includes the room, we can then send only what's changed since that recorded position.
+ We use the `rooms` field to store the position in the events stream for each
+ room that we've previously sent to the client before. On the next request
+ that includes the room, we can then send only what's changed since that
+ recorded position.
- Same goes for the `receipts` field so we only need to send the new receipts since the last time you made a sync request.
+ Same goes for the `receipts` field so we only need to send the new receipts
+ since the last time you made a sync request.
Attributes:
rooms: The status of each room for the events stream.
+ receipts: The status of each room for the receipts stream.
"""
- rooms: RoomStatusMap = attr.Factory(RoomStatusMap)
+ rooms: RoomStatusMap[RoomStreamToken] = attr.Factory(RoomStatusMap)
+ receipts: RoomStatusMap[MultiWriterStreamToken] = attr.Factory(RoomStatusMap)
def get_mutable(self) -> "MutablePerConnectionState":
"""Get a mutable copy of this state."""
return MutablePerConnectionState(
rooms=self.rooms.get_mutable(),
+ receipts=self.receipts.get_mutable(),
)
@@ -3150,10 +3279,11 @@ class PerConnectionState:
class MutablePerConnectionState(PerConnectionState):
"""A mutable version of `PerConnectionState`"""
- rooms: MutableRoomStatusMap
+ rooms: MutableRoomStatusMap[RoomStreamToken]
+ receipts: MutableRoomStatusMap[MultiWriterStreamToken]
def has_updates(self) -> bool:
- return bool(self.rooms.get_updates())
+ return bool(self.rooms.get_updates()) or bool(self.receipts.get_updates())
@attr.s(auto_attribs=True)
@@ -3233,7 +3363,7 @@ class SlidingSyncConnectionStore:
self,
sync_config: SlidingSyncConfig,
from_token: Optional[SlidingSyncStreamToken],
- per_connection_state: MutablePerConnectionState,
+ new_connection_state: MutablePerConnectionState,
) -> int:
"""Record updated per-connection state, returning the connection
position associated with the new state.
@@ -3245,7 +3375,7 @@ class SlidingSyncConnectionStore:
if from_token is not None:
prev_connection_token = from_token.connection_position
- if not per_connection_state.has_updates():
+ if not new_connection_state.has_updates():
return prev_connection_token
conn_key = self._get_connection_key(sync_config)
@@ -3259,7 +3389,8 @@ class SlidingSyncConnectionStore:
# We copy the `MutablePerConnectionState` so that the inner `ChainMap`s
# don't grow forever.
sync_statuses[new_store_token] = PerConnectionState(
- rooms=per_connection_state.rooms.copy(),
+ rooms=new_connection_state.rooms.copy(),
+ receipts=new_connection_state.receipts.copy(),
)
return new_store_token
|