summary refs log tree commit diff
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2024-07-18 14:14:38 +0100
committerErik Johnston <erik@matrix.org>2024-07-18 14:14:38 +0100
commit605b358d4d719b81e52800c1bc53d9547905c180 (patch)
treea37fcb2fa670b88d798b46e4de22d15aed5f5de0
parentAdd context to conn_id (diff)
downloadsynapse-605b358d4d719b81e52800c1bc53d9547905c180.tar.xz
Refactor to avoid SyncConfig.connection_id()
-rw-r--r--synapse/handlers/sliding_sync.py65
-rw-r--r--synapse/types/handlers/__init__.py30
2 files changed, 48 insertions, 47 deletions
diff --git a/synapse/handlers/sliding_sync.py b/synapse/handlers/sliding_sync.py
index cb45af79ea..75603b6f75 100644
--- a/synapse/handlers/sliding_sync.py
+++ b/synapse/handlers/sliding_sync.py
@@ -453,8 +453,7 @@ class SlidingSyncHandler:
             raise NotImplementedError()
 
         await self.connection_store.mark_token_seen(
-            user_id,
-            conn_id=sync_config.connection_id(),
+            sync_config=sync_config,
             from_token=from_token,
         )
 
@@ -623,8 +622,7 @@ class SlidingSyncHandler:
 
         if has_lists or has_room_subscriptions:
             connection_token = await self.connection_store.record_rooms(
-                user_id,
-                conn_id=sync_config.connection_id(),
+                sync_config=sync_config,
                 from_token=from_token,
                 sent_room_ids=relevant_room_map.keys(),
                 # TODO: We need to calculate which rooms have had updates since the `from_token` but were not included in the `sent_room_ids`
@@ -1406,8 +1404,7 @@ class SlidingSyncHandler:
         initial = True
         if from_token and not room_membership_for_user_at_to_token.newly_joined:
             room_status = await self.connection_store.have_sent_room(
-                user_id=user.to_string(),
-                conn_id=sync_config.connection_id(),
+                sync_config=sync_config,
                 connection_token=from_token.connection_token,
                 room_id=room_id,
             )
@@ -1986,13 +1983,14 @@ class SlidingSyncConnectionStore:
     )
 
     async def have_sent_room(
-        self, user_id: str, conn_id: str, connection_token: int, room_id: str
+        self, sync_config: SlidingSyncConfig, connection_token: int, room_id: str
     ) -> HaveSentRoom:
         """For the given user_id/conn_id/token, return whether we have
         previously sent the room down
         """
 
-        sync_statuses = self._connections.setdefault((user_id, conn_id), {})
+        conn_key = self._get_connection_key(sync_config)
+        sync_statuses = self._connections.setdefault(conn_key, {})
         room_status = sync_statuses.get(connection_token, {}).get(
             room_id, HAVE_SENT_ROOM_NEVER
         )
@@ -2001,8 +1999,7 @@ class SlidingSyncConnectionStore:
 
     async def record_rooms(
         self,
-        user_id: str,
-        conn_id: str,
+        sync_config: SlidingSyncConfig,
         from_token: Optional[SlidingSyncStreamToken],
         *,
         sent_room_ids: StrCollection,
@@ -2011,8 +2008,7 @@ class SlidingSyncConnectionStore:
         """Record which rooms we have/haven't sent down in a new response
 
         Attributes:
-            user_id
-            conn_id
+            sync_config
             from_token: The since token from the request, if any
             sent_room_ids: The set of room IDs that we have sent down as
                 part of this request (only needs to be ones we didn't
@@ -2029,7 +2025,8 @@ class SlidingSyncConnectionStore:
         if not sent_room_ids and not unsent_room_ids:
             return prev_connection_token
 
-        sync_statuses = self._connections.setdefault((user_id, conn_id), {})
+        conn_key = self._get_connection_key(sync_config)
+        sync_statuses = self._connections.setdefault(conn_key, {})
 
         # Generate a new token, removing any existing entries in that token
         # (which can happen if requests get resent).
@@ -2077,8 +2074,7 @@ class SlidingSyncConnectionStore:
 
     async def mark_token_seen(
         self,
-        user_id: str,
-        conn_id: str,
+        sync_config: SlidingSyncConfig,
         from_token: Optional[SlidingSyncStreamToken],
     ) -> None:
         """We have received a request with the given token, so we can clear out
@@ -2090,7 +2086,8 @@ class SlidingSyncConnectionStore:
         # Clear out any tokens for the connection that doesn't match the one
         # from the request.
 
-        sync_statuses = self._connections.pop((user_id, conn_id), {})
+        conn_key = self._get_connection_key(sync_config)
+        sync_statuses = self._connections.pop(conn_key, {})
         if from_token is None:
             return
 
@@ -2100,4 +2097,38 @@ class SlidingSyncConnectionStore:
             if connection_token == from_token.connection_token
         }
         if sync_statuses:
-            self._connections[(user_id, conn_id)] = sync_statuses
+            self._connections[conn_key] = sync_statuses
+
+    @staticmethod
+    def _get_connection_key(sync_config: SlidingSyncConfig) -> Tuple[str, str]:
+        """Return a unique identifier for this connection.
+
+        The first part is simply the user ID.
+
+        The second part is generally a combination of device ID and conn_id.
+        However, both these two are optional (e.g. puppet access tokens don't
+        have device IDs), so this handles those edge cases.
+
+        We use this over the raw `conn_id` to avoid clashes between different
+        clients that use the same `conn_id`. Imagine a user uses a web client
+        that uses `conn_id: main_sync_loop` and an Android client that also has
+        a `conn_id: main_sync_loop`.
+        """
+
+        user_id = sync_config.user.to_string()
+
+        # If this is missing, only one sliding sync connection is allowed per
+        # given conn_id.
+        conn_id = sync_config.conn_id or ""
+
+        if sync_config.requester.device_id:
+            return (user_id, f"D/{sync_config.requester.device_id}/{conn_id}")
+
+        if sync_config.requester.access_token_id:
+            # If we don't have a device, then the access token ID should be a
+            # stable ID.
+            return (user_id, f"A/{sync_config.requester.access_token_id}/{conn_id}")
+
+        # If we have neither then its likely an AS or some weird token. Either
+        # way we can just fail here.
+        raise Exception("Cannot use sliding sync with access token type")
diff --git a/synapse/types/handlers/__init__.py b/synapse/types/handlers/__init__.py
index 6396be839f..1b9de129ba 100644
--- a/synapse/types/handlers/__init__.py
+++ b/synapse/types/handlers/__init__.py
@@ -120,36 +120,6 @@ class SlidingSyncConfig(SlidingSyncBody):
         # Allow custom types like `UserID` to be used in the model
         arbitrary_types_allowed = True
 
-    def connection_id(self) -> str:
-        """Return a string identifier for this connection. May clash with
-        connection IDs from different users.
-
-        This is generally a combination of device ID and conn_id. However, both
-        these two are optional (e.g. puppet access tokens don't have device
-        IDs), so this handles those edge cases.
-
-        We use this over the raw `conn_id` to avoid clashes between different
-        clients that use the same `conn_id`. Imagine a user uses a web client
-        that uses `conn_id: main_sync_loop` and an Android client that also has
-        a `conn_id: main_sync_loop`.
-        """
-
-        # If this is missing, only one sliding sync connection is allowed per
-        # given conn_id.
-        conn_id = self.conn_id or ""
-
-        if self.requester.device_id:
-            return f"D/{self.requester.device_id}/{conn_id}"
-
-        if self.requester.access_token_id:
-            # If we don't have a device, then the access token ID should be a
-            # stable ID.
-            return f"A/{self.requester.access_token_id}/{conn_id}"
-
-        # If we have neither then its likely an AS or some weird token. Either
-        # way we can just fail here.
-        raise Exception("Cannot use sliding sync with access token type")
-
 
 class OperationType(Enum):
     """