diff --git a/synapse/handlers/sliding_sync.py b/synapse/handlers/sliding_sync.py
index f1f6f30b95..3231574402 100644
--- a/synapse/handlers/sliding_sync.py
+++ b/synapse/handlers/sliding_sync.py
@@ -46,6 +46,7 @@ from synapse.storage.roommember import MemberSummary
from synapse.types import (
DeviceListUpdates,
JsonDict,
+ JsonMapping,
PersistedEventPosition,
Requester,
RoomStreamToken,
@@ -357,6 +358,7 @@ class SlidingSyncHandler:
self.event_sources = hs.get_event_sources()
self.relations_handler = hs.get_relations_handler()
self.device_handler = hs.get_device_handler()
+ self.push_rules_handler = hs.get_push_rules_handler()
self.rooms_to_exclude_globally = hs.config.server.rooms_to_exclude_from_sync
async def wait_for_sync_for_user(
@@ -628,6 +630,7 @@ class SlidingSyncHandler:
extensions = await self.get_extensions_response(
sync_config=sync_config,
+ lists=lists,
from_token=from_token,
to_token=to_token,
)
@@ -1797,6 +1800,7 @@ class SlidingSyncHandler:
async def get_extensions_response(
self,
sync_config: SlidingSyncConfig,
+ lists: Dict[str, SlidingSyncResult.SlidingWindowList],
to_token: StreamToken,
from_token: Optional[SlidingSyncStreamToken],
) -> SlidingSyncResult.Extensions:
@@ -1804,6 +1808,7 @@ class SlidingSyncHandler:
Args:
sync_config: Sync configuration
+ lists: Sliding window API. A map of list key to list results.
to_token: The point in the stream to sync up to.
from_token: The point in the stream to sync from.
"""
@@ -1828,9 +1833,20 @@ class SlidingSyncHandler:
from_token=from_token,
)
+ account_data_response = None
+ if sync_config.extensions.account_data is not None:
+ account_data_response = await self.get_account_data_extension_response(
+ sync_config=sync_config,
+ lists=lists,
+ account_data_request=sync_config.extensions.account_data,
+ to_token=to_token,
+ from_token=from_token,
+ )
+
return SlidingSyncResult.Extensions(
to_device=to_device_response,
e2ee=e2ee_response,
+ account_data=account_data_response,
)
async def get_to_device_extension_response(
@@ -1956,3 +1972,125 @@ class SlidingSyncHandler:
device_one_time_keys_count=device_one_time_keys_count,
device_unused_fallback_key_types=device_unused_fallback_key_types,
)
+
+ async def get_account_data_extension_response(
+ self,
+ sync_config: SlidingSyncConfig,
+ lists: Dict[str, SlidingSyncResult.SlidingWindowList],
+ account_data_request: SlidingSyncConfig.Extensions.AccountDataExtension,
+ to_token: StreamToken,
+ from_token: Optional[SlidingSyncStreamToken],
+ ) -> Optional[SlidingSyncResult.Extensions.AccountDataExtension]:
+ """Handle Account Data extension (MSC3959)
+
+ Args:
+ sync_config: Sync configuration
+ lists: Sliding window API. A map of list key to list results.
+ account_data_request: The account_data extension from the request
+ to_token: The point in the stream to sync up to.
+ from_token: The point in the stream to sync from.
+ """
+ user_id = sync_config.user.to_string()
+
+ # Skip if the extension is not enabled
+ if not account_data_request.enabled:
+ return None
+
+ global_account_data_map: Mapping[str, JsonMapping] = {}
+ if from_token is not None:
+ global_account_data_map = (
+ await self.store.get_updated_global_account_data_for_user(
+ user_id, from_token.stream_token.account_data_key
+ )
+ )
+
+ have_push_rules_changed = await self.store.have_push_rules_changed_for_user(
+ user_id, from_token.stream_token.push_rules_key
+ )
+ if have_push_rules_changed:
+ global_account_data_map = dict(global_account_data_map)
+ global_account_data_map[AccountDataTypes.PUSH_RULES] = (
+ await self.push_rules_handler.push_rules_for_user(sync_config.user)
+ )
+ else:
+ all_global_account_data = await self.store.get_global_account_data_for_user(
+ user_id
+ )
+
+ global_account_data_map = dict(all_global_account_data)
+ global_account_data_map[AccountDataTypes.PUSH_RULES] = (
+ await self.push_rules_handler.push_rules_for_user(sync_config.user)
+ )
+
+ # We only want to include account data for rooms that are already in the sliding
+ # sync response AND that were requested in the account data request.
+ relevant_room_ids: Set[str] = set()
+
+ # See what rooms from the room subscriptions we should get account data for
+ if (
+ account_data_request.rooms is not None
+ and sync_config.room_subscriptions is not None
+ ):
+ actual_room_ids = sync_config.room_subscriptions.keys()
+
+ for room_id in account_data_request.rooms:
+ # A wildcard means we process all rooms from the room subscriptions
+ if room_id == "*":
+ relevant_room_ids.update(sync_config.room_subscriptions.keys())
+ break
+
+ if room_id in actual_room_ids:
+ relevant_room_ids.add(room_id)
+
+ # See what rooms from the sliding window lists we should get account data for
+ if account_data_request.lists is not None:
+ for list_key in account_data_request.lists:
+ # Just some typing because we share the variable name in multiple places
+ actual_list: Optional[SlidingSyncResult.SlidingWindowList] = None
+
+ # A wildcard means we process rooms from all lists
+ if list_key == "*":
+ for actual_list in lists.values():
+ # We only expect a single SYNC operation for any list
+ assert len(actual_list.ops) == 1
+ sync_op = actual_list.ops[0]
+ assert sync_op.op == OperationType.SYNC
+
+ relevant_room_ids.update(sync_op.room_ids)
+
+ break
+
+ actual_list = lists.get(list_key)
+ if actual_list is not None:
+ # We only expect a single SYNC operation for any list
+ assert len(actual_list.ops) == 1
+ sync_op = actual_list.ops[0]
+ assert sync_op.op == OperationType.SYNC
+
+ relevant_room_ids.update(sync_op.room_ids)
+
+ # Fetch room account data
+ account_data_by_room_map: Mapping[str, Mapping[str, JsonMapping]] = {}
+ if len(relevant_room_ids) > 0:
+ if from_token is not None:
+ account_data_by_room_map = (
+ await self.store.get_updated_room_account_data_for_user(
+ user_id, from_token.stream_token.account_data_key
+ )
+ )
+ else:
+ account_data_by_room_map = (
+ await self.store.get_room_account_data_for_user(user_id)
+ )
+
+ # Filter down to the relevant rooms
+ account_data_by_room_map = {
+ room_id: account_data_map
+ for room_id, account_data_map in account_data_by_room_map.items()
+ if room_id in relevant_room_ids
+ }
+
+ return SlidingSyncResult.Extensions.AccountDataExtension(
+ global_account_data_map=global_account_data_map,
+ account_data_by_room_map=account_data_by_room_map,
+ )
diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py
index d72dfa2b10..7cf1f56435 100644
--- a/synapse/rest/client/sync.py
+++ b/synapse/rest/client/sync.py
@@ -929,7 +929,6 @@ class SlidingSyncRestServlet(RestServlet):
return 200, response_content
- # TODO: Is there a better way to encode things?
async def encode_response(
self,
requester: Requester,
@@ -1117,6 +1116,24 @@ class SlidingSyncRestServlet(RestServlet):
extensions.e2ee.device_list_updates.left
)
+ if extensions.account_data is not None:
+ serialized_extensions["account_data"] = {
+ # Same as the the top-level `account_data.events` field in Sync v2.
+ "global": [
+ {"type": account_data_type, "content": content}
+ for account_data_type, content in extensions.account_data.global_account_data_map.items()
+ ],
+ # Same as the joined room's account_data field in Sync v2, e.g the path
+ # `rooms.join["!foo:bar"].account_data.events`.
+ "rooms": {
+ room_id: [
+ {"type": account_data_type, "content": content}
+ for account_data_type, content in event_map.items()
+ ]
+ for room_id, event_map in extensions.account_data.account_data_by_room_map.items()
+ },
+ }
+
return serialized_extensions
diff --git a/synapse/types/handlers/__init__.py b/synapse/types/handlers/__init__.py
index 59eb0963ee..479222a18d 100644
--- a/synapse/types/handlers/__init__.py
+++ b/synapse/types/handlers/__init__.py
@@ -330,11 +330,31 @@ class SlidingSyncResult:
or self.device_unused_fallback_key_types
)
+ @attr.s(slots=True, frozen=True, auto_attribs=True)
+ class AccountDataExtension:
+ """The Account Data extension (MSC3959)
+
+ Attributes:
+ global_account_data_map: Mapping from `type` to `content` of global account
+ data events.
+ account_data_by_room_map: Mapping from room_id to mapping of `type` to
+ `content` of room account data events.
+ """
+
+ global_account_data_map: Mapping[str, JsonMapping]
+ account_data_by_room_map: Mapping[str, Mapping[str, JsonMapping]]
+
+ def __bool__(self) -> bool:
+ return bool(
+ self.global_account_data_map or self.account_data_by_room_map
+ )
+
to_device: Optional[ToDeviceExtension] = None
e2ee: Optional[E2eeExtension] = None
+ account_data: Optional[AccountDataExtension] = None
def __bool__(self) -> bool:
- return bool(self.to_device or self.e2ee)
+ return bool(self.to_device or self.e2ee or self.account_data)
next_pos: SlidingSyncStreamToken
lists: Dict[str, SlidingWindowList]
diff --git a/synapse/types/rest/client/__init__.py b/synapse/types/rest/client/__init__.py
index f3c45a0d6a..34e07ddac5 100644
--- a/synapse/types/rest/client/__init__.py
+++ b/synapse/types/rest/client/__init__.py
@@ -322,8 +322,26 @@ class SlidingSyncBody(RequestBodyModel):
enabled: Optional[StrictBool] = False
+ class AccountDataExtension(RequestBodyModel):
+ """The Account Data extension (MSC3959)
+
+ Attributes:
+ enabled
+ lists: List of list keys (from the Sliding Window API) to apply this
+ extension to.
+ rooms: List of room IDs (from the Room Subscription API) to apply this
+ extension to.
+ """
+
+ enabled: Optional[StrictBool] = False
+ # Process all lists defined in the Sliding Window API. (This is the default.)
+ lists: Optional[List[StrictStr]] = ["*"]
+ # Process all room subscriptions defined in the Room Subscription API. (This is the default.)
+ rooms: Optional[List[StrictStr]] = ["*"]
+
to_device: Optional[ToDeviceExtension] = None
e2ee: Optional[E2eeExtension] = None
+ account_data: Optional[AccountDataExtension] = None
# mypy workaround via https://github.com/pydantic/pydantic/issues/156#issuecomment-1130883884
if TYPE_CHECKING:
|