summary refs log tree commit diff
path: root/synapse/handlers/sync.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/sync.py')
-rw-r--r--synapse/handlers/sync.py87
1 files changed, 67 insertions, 20 deletions
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index b1c58ffdc8..7f2138d804 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -160,6 +160,16 @@ class InvitedSyncResult:
 
 
 @attr.s(slots=True, frozen=True)
+class KnockedSyncResult:
+    room_id = attr.ib(type=str)
+    knock = attr.ib(type=EventBase)
+
+    def __bool__(self) -> bool:
+        """Knocked rooms should always be reported to the client"""
+        return True
+
+
+@attr.s(slots=True, frozen=True)
 class GroupsSyncResult:
     join = attr.ib(type=JsonDict)
     invite = attr.ib(type=JsonDict)
@@ -192,6 +202,7 @@ class _RoomChanges:
 
     room_entries = attr.ib(type=List["RoomSyncResultBuilder"])
     invited = attr.ib(type=List[InvitedSyncResult])
+    knocked = attr.ib(type=List[KnockedSyncResult])
     newly_joined_rooms = attr.ib(type=List[str])
     newly_left_rooms = attr.ib(type=List[str])
 
@@ -205,6 +216,7 @@ class SyncResult:
         account_data: List of account_data events for the user.
         joined: JoinedSyncResult for each joined room.
         invited: InvitedSyncResult for each invited room.
+        knocked: KnockedSyncResult for each knocked on room.
         archived: ArchivedSyncResult for each archived room.
         to_device: List of direct messages for the device.
         device_lists: List of user_ids whose devices have changed
@@ -220,6 +232,7 @@ class SyncResult:
     account_data = attr.ib(type=List[JsonDict])
     joined = attr.ib(type=List[JoinedSyncResult])
     invited = attr.ib(type=List[InvitedSyncResult])
+    knocked = attr.ib(type=List[KnockedSyncResult])
     archived = attr.ib(type=List[ArchivedSyncResult])
     to_device = attr.ib(type=List[JsonDict])
     device_lists = attr.ib(type=DeviceLists)
@@ -236,6 +249,7 @@ class SyncResult:
             self.presence
             or self.joined
             or self.invited
+            or self.knocked
             or self.archived
             or self.account_data
             or self.to_device
@@ -1031,7 +1045,7 @@ class SyncHandler:
         res = await self._generate_sync_entry_for_rooms(
             sync_result_builder, account_data_by_room
         )
-        newly_joined_rooms, newly_joined_or_invited_users, _, _ = res
+        newly_joined_rooms, newly_joined_or_invited_or_knocked_users, _, _ = res
         _, _, newly_left_rooms, newly_left_users = res
 
         block_all_presence_data = (
@@ -1040,7 +1054,9 @@ class SyncHandler:
         if self.hs_config.use_presence and not block_all_presence_data:
             logger.debug("Fetching presence data")
             await self._generate_sync_entry_for_presence(
-                sync_result_builder, newly_joined_rooms, newly_joined_or_invited_users
+                sync_result_builder,
+                newly_joined_rooms,
+                newly_joined_or_invited_or_knocked_users,
             )
 
         logger.debug("Fetching to-device data")
@@ -1049,7 +1065,7 @@ class SyncHandler:
         device_lists = await self._generate_sync_entry_for_device_list(
             sync_result_builder,
             newly_joined_rooms=newly_joined_rooms,
-            newly_joined_or_invited_users=newly_joined_or_invited_users,
+            newly_joined_or_invited_or_knocked_users=newly_joined_or_invited_or_knocked_users,
             newly_left_rooms=newly_left_rooms,
             newly_left_users=newly_left_users,
         )
@@ -1083,6 +1099,7 @@ class SyncHandler:
             account_data=sync_result_builder.account_data,
             joined=sync_result_builder.joined,
             invited=sync_result_builder.invited,
+            knocked=sync_result_builder.knocked,
             archived=sync_result_builder.archived,
             to_device=sync_result_builder.to_device,
             device_lists=device_lists,
@@ -1142,7 +1159,7 @@ class SyncHandler:
         self,
         sync_result_builder: "SyncResultBuilder",
         newly_joined_rooms: Set[str],
-        newly_joined_or_invited_users: Set[str],
+        newly_joined_or_invited_or_knocked_users: Set[str],
         newly_left_rooms: Set[str],
         newly_left_users: Set[str],
     ) -> DeviceLists:
@@ -1151,8 +1168,9 @@ class SyncHandler:
         Args:
             sync_result_builder
             newly_joined_rooms: Set of rooms user has joined since previous sync
-            newly_joined_or_invited_users: Set of users that have joined or
-                been invited to a room since previous sync.
+            newly_joined_or_invited_or_knocked_users: Set of users that have joined,
+                been invited to a room or are knocking on a room since
+                previous sync.
             newly_left_rooms: Set of rooms user has left since previous sync
             newly_left_users: Set of users that have left a room we're in since
                 previous sync
@@ -1163,7 +1181,9 @@ class SyncHandler:
 
         # We're going to mutate these fields, so lets copy them rather than
         # assume they won't get used later.
-        newly_joined_or_invited_users = set(newly_joined_or_invited_users)
+        newly_joined_or_invited_or_knocked_users = set(
+            newly_joined_or_invited_or_knocked_users
+        )
         newly_left_users = set(newly_left_users)
 
         if since_token and since_token.device_list_key:
@@ -1202,11 +1222,11 @@ class SyncHandler:
             # Step 1b, check for newly joined rooms
             for room_id in newly_joined_rooms:
                 joined_users = await self.store.get_users_in_room(room_id)
-                newly_joined_or_invited_users.update(joined_users)
+                newly_joined_or_invited_or_knocked_users.update(joined_users)
 
             # TODO: Check that these users are actually new, i.e. either they
             # weren't in the previous sync *or* they left and rejoined.
-            users_that_have_changed.update(newly_joined_or_invited_users)
+            users_that_have_changed.update(newly_joined_or_invited_or_knocked_users)
 
             user_signatures_changed = (
                 await self.store.get_users_whose_signatures_changed(
@@ -1452,6 +1472,7 @@ class SyncHandler:
 
         room_entries = room_changes.room_entries
         invited = room_changes.invited
+        knocked = room_changes.knocked
         newly_joined_rooms = room_changes.newly_joined_rooms
         newly_left_rooms = room_changes.newly_left_rooms
 
@@ -1472,9 +1493,10 @@ class SyncHandler:
         await concurrently_execute(handle_room_entries, room_entries, 10)
 
         sync_result_builder.invited.extend(invited)
+        sync_result_builder.knocked.extend(knocked)
 
-        # Now we want to get any newly joined or invited users
-        newly_joined_or_invited_users = set()
+        # Now we want to get any newly joined, invited or knocking users
+        newly_joined_or_invited_or_knocked_users = set()
         newly_left_users = set()
         if since_token:
             for joined_sync in sync_result_builder.joined:
@@ -1486,19 +1508,22 @@ class SyncHandler:
                         if (
                             event.membership == Membership.JOIN
                             or event.membership == Membership.INVITE
+                            or event.membership == Membership.KNOCK
                         ):
-                            newly_joined_or_invited_users.add(event.state_key)
+                            newly_joined_or_invited_or_knocked_users.add(
+                                event.state_key
+                            )
                         else:
                             prev_content = event.unsigned.get("prev_content", {})
                             prev_membership = prev_content.get("membership", None)
                             if prev_membership == Membership.JOIN:
                                 newly_left_users.add(event.state_key)
 
-        newly_left_users -= newly_joined_or_invited_users
+        newly_left_users -= newly_joined_or_invited_or_knocked_users
 
         return (
             set(newly_joined_rooms),
-            newly_joined_or_invited_users,
+            newly_joined_or_invited_or_knocked_users,
             set(newly_left_rooms),
             newly_left_users,
         )
@@ -1553,6 +1578,7 @@ class SyncHandler:
         newly_left_rooms = []
         room_entries = []
         invited = []
+        knocked = []
         for room_id, events in mem_change_events_by_room_id.items():
             logger.debug(
                 "Membership changes in %s: [%s]",
@@ -1632,9 +1658,17 @@ class SyncHandler:
             should_invite = non_joins[-1].membership == Membership.INVITE
             if should_invite:
                 if event.sender not in ignored_users:
-                    room_sync = InvitedSyncResult(room_id, invite=non_joins[-1])
-                    if room_sync:
-                        invited.append(room_sync)
+                    invite_room_sync = InvitedSyncResult(room_id, invite=non_joins[-1])
+                    if invite_room_sync:
+                        invited.append(invite_room_sync)
+
+            # Only bother if our latest membership in the room is knock (and we haven't
+            # been accepted/rejected in the meantime).
+            should_knock = non_joins[-1].membership == Membership.KNOCK
+            if should_knock:
+                knock_room_sync = KnockedSyncResult(room_id, knock=non_joins[-1])
+                if knock_room_sync:
+                    knocked.append(knock_room_sync)
 
             # Always include leave/ban events. Just take the last one.
             # TODO: How do we handle ban -> leave in same batch?
@@ -1738,7 +1772,13 @@ class SyncHandler:
                 )
             room_entries.append(entry)
 
-        return _RoomChanges(room_entries, invited, newly_joined_rooms, newly_left_rooms)
+        return _RoomChanges(
+            room_entries,
+            invited,
+            knocked,
+            newly_joined_rooms,
+            newly_left_rooms,
+        )
 
     async def _get_all_rooms(
         self, sync_result_builder: "SyncResultBuilder", ignored_users: FrozenSet[str]
@@ -1758,6 +1798,7 @@ class SyncHandler:
 
         membership_list = (
             Membership.INVITE,
+            Membership.KNOCK,
             Membership.JOIN,
             Membership.LEAVE,
             Membership.BAN,
@@ -1769,6 +1810,7 @@ class SyncHandler:
 
         room_entries = []
         invited = []
+        knocked = []
 
         for event in room_list:
             if event.membership == Membership.JOIN:
@@ -1788,8 +1830,11 @@ class SyncHandler:
                     continue
                 invite = await self.store.get_event(event.event_id)
                 invited.append(InvitedSyncResult(room_id=event.room_id, invite=invite))
+            elif event.membership == Membership.KNOCK:
+                knock = await self.store.get_event(event.event_id)
+                knocked.append(KnockedSyncResult(room_id=event.room_id, knock=knock))
             elif event.membership in (Membership.LEAVE, Membership.BAN):
-                # Always send down rooms we were banned or kicked from.
+                # Always send down rooms we were banned from or kicked from.
                 if not sync_config.filter_collection.include_leave:
                     if event.membership == Membership.LEAVE:
                         if user_id == event.sender:
@@ -1810,7 +1855,7 @@ class SyncHandler:
                     )
                 )
 
-        return _RoomChanges(room_entries, invited, [], [])
+        return _RoomChanges(room_entries, invited, knocked, [], [])
 
     async def _generate_room_entry(
         self,
@@ -2101,6 +2146,7 @@ class SyncResultBuilder:
         account_data (list)
         joined (list[JoinedSyncResult])
         invited (list[InvitedSyncResult])
+        knocked (list[KnockedSyncResult])
         archived (list[ArchivedSyncResult])
         groups (GroupsSyncResult|None)
         to_device (list)
@@ -2116,6 +2162,7 @@ class SyncResultBuilder:
     account_data = attr.ib(type=List[JsonDict], default=attr.Factory(list))
     joined = attr.ib(type=List[JoinedSyncResult], default=attr.Factory(list))
     invited = attr.ib(type=List[InvitedSyncResult], default=attr.Factory(list))
+    knocked = attr.ib(type=List[KnockedSyncResult], default=attr.Factory(list))
     archived = attr.ib(type=List[ArchivedSyncResult], default=attr.Factory(list))
     groups = attr.ib(type=Optional[GroupsSyncResult], default=None)
     to_device = attr.ib(type=List[JsonDict], default=attr.Factory(list))