summary refs log tree commit diff
path: root/synapse/storage/databases
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases')
-rw-r--r--synapse/storage/databases/main/roommember.py21
-rw-r--r--synapse/storage/databases/main/stream.py30
2 files changed, 36 insertions, 15 deletions
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 3248da5356..98d09b3736 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -361,7 +361,10 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         return None
 
     async def get_rooms_for_local_user_where_membership_is(
-        self, user_id: str, membership_list: Collection[str]
+        self,
+        user_id: str,
+        membership_list: Collection[str],
+        excluded_rooms: Optional[List[str]] = None,
     ) -> List[RoomsForUser]:
         """Get all the rooms for this *local* user where the membership for this user
         matches one in the membership list.
@@ -372,6 +375,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             user_id: The user ID.
             membership_list: A list of synapse.api.constants.Membership
                 values which the user must be in.
+            excluded_rooms: A list of rooms to ignore.
 
         Returns:
             The RoomsForUser that the user matches the membership types.
@@ -386,12 +390,19 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             membership_list,
         )
 
-        # Now we filter out forgotten rooms
-        forgotten_rooms = await self.get_forgotten_rooms_for_user(user_id)
-        return [room for room in rooms if room.room_id not in forgotten_rooms]
+        # Now we filter out forgotten and excluded rooms
+        rooms_to_exclude: Set[str] = await self.get_forgotten_rooms_for_user(user_id)
+
+        if excluded_rooms is not None:
+            rooms_to_exclude.update(set(excluded_rooms))
+
+        return [room for room in rooms if room.room_id not in rooms_to_exclude]
 
     def _get_rooms_for_local_user_where_membership_is_txn(
-        self, txn, user_id: str, membership_list: List[str]
+        self,
+        txn,
+        user_id: str,
+        membership_list: List[str],
     ) -> List[RoomsForUser]:
         # Paranoia check.
         if not self.hs.is_mine_id(user_id):
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 39e1efe373..8e764790db 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -36,7 +36,7 @@ what sort order was used:
 """
 
 import logging
-from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Any, Collection, Dict, List, Optional, Set, Tuple
 
 import attr
 from frozendict import frozendict
@@ -585,7 +585,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         return ret, key
 
     async def get_membership_changes_for_user(
-        self, user_id: str, from_key: RoomStreamToken, to_key: RoomStreamToken
+        self,
+        user_id: str,
+        from_key: RoomStreamToken,
+        to_key: RoomStreamToken,
+        excluded_rooms: Optional[List[str]] = None,
     ) -> List[EventBase]:
         """Fetch membership events for a given user.
 
@@ -610,23 +614,29 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             min_from_id = from_key.stream
             max_to_id = to_key.get_max_stream_pos()
 
+            args: List[Any] = [user_id, min_from_id, max_to_id]
+
+            ignore_room_clause = ""
+            if excluded_rooms is not None and len(excluded_rooms) > 0:
+                ignore_room_clause = "AND e.room_id NOT IN (%s)" % ",".join(
+                    "?" for _ in excluded_rooms
+                )
+                args = args + excluded_rooms
+
             sql = """
                 SELECT m.event_id, instance_name, topological_ordering, stream_ordering
                 FROM events AS e, room_memberships AS m
                 WHERE e.event_id = m.event_id
                     AND m.user_id = ?
                     AND e.stream_ordering > ? AND e.stream_ordering <= ?
+                    %s
                 ORDER BY e.stream_ordering ASC
-            """
-            txn.execute(
-                sql,
-                (
-                    user_id,
-                    min_from_id,
-                    max_to_id,
-                ),
+            """ % (
+                ignore_room_clause,
             )
 
+            txn.execute(sql, args)
+
             rows = [
                 _EventDictReturn(event_id, None, stream_ordering)
                 for event_id, instance_name, topological_ordering, stream_ordering in txn