summary refs log tree commit diff
path: root/synapse/storage/databases/main/roommember.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/roommember.py')
-rw-r--r--synapse/storage/databases/main/roommember.py46
1 files changed, 26 insertions, 20 deletions
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index a8d224602a..8ada3cdac3 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -662,31 +662,37 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         if not user_ids:
             return set()
 
-        def _get_users_server_still_shares_room_with_txn(
-            txn: LoggingTransaction,
-        ) -> Set[str]:
-            sql = """
-                SELECT state_key FROM current_state_events
-                WHERE
-                    type = 'm.room.member'
-                    AND membership = 'join'
-                    AND %s
-                GROUP BY state_key
-            """
-
-            clause, args = make_in_list_sql_clause(
-                self.database_engine, "state_key", user_ids
-            )
+        return await self.db_pool.runInteraction(
+            "get_users_server_still_shares_room_with",
+            self.get_users_server_still_shares_room_with_txn,
+            user_ids,
+        )
 
-            txn.execute(sql % (clause,), args)
+    def get_users_server_still_shares_room_with_txn(
+        self,
+        txn: LoggingTransaction,
+        user_ids: Collection[str],
+    ) -> Set[str]:
+        if not user_ids:
+            return set()
 
-            return {row[0] for row in txn}
+        sql = """
+            SELECT state_key FROM current_state_events
+            WHERE
+                type = 'm.room.member'
+                AND membership = 'join'
+                AND %s
+            GROUP BY state_key
+        """
 
-        return await self.db_pool.runInteraction(
-            "get_users_server_still_shares_room_with",
-            _get_users_server_still_shares_room_with_txn,
+        clause, args = make_in_list_sql_clause(
+            self.database_engine, "state_key", user_ids
         )
 
+        txn.execute(sql % (clause,), args)
+
+        return {row[0] for row in txn}
+
     @cancellable
     async def get_rooms_for_user(
         self, user_id: str, on_invalidate: Optional[Callable[[], None]] = None