summary refs log tree commit diff
path: root/synapse/handlers
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers')
-rw-r--r--synapse/handlers/_base.py93
1 files changed, 54 insertions, 39 deletions
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index bb2c6733d5..2d1167296a 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -53,16 +53,54 @@ class BaseHandler(object):
         self.event_builder_factory = hs.get_event_builder_factory()
 
     @defer.inlineCallbacks
-    def _filter_events_for_client(self, user_id, events, is_guest=False):
-        # Assumes that user has at some point joined the room if not is_guest.
+    def _filter_events_for_clients(self, users, events):
+        """ Returns dict of user_id -> list of events that user is allowed to
+        see.
+        """
+        event_id_to_state = yield self.store.get_state_for_events(
+            frozenset(e.event_id for e in events),
+            types=(
+                (EventTypes.RoomHistoryVisibility, ""),
+                (EventTypes.Member, None),
+            )
+        )
+
+        forgotten = yield defer.gatherResults([
+            self.store.who_forgot_in_room(
+                room_id,
+            )
+            for room_id in frozenset(e.room_id for e in events)
+        ], consumeErrors=True)
+
+        # Set of membership event_ids that have been forgotten
+        event_id_forgotten = frozenset(
+            row["event_id"] for rows in forgotten for row in rows
+        )
+
+        def allowed(event, user_id, is_guest):
+            state = event_id_to_state[event.event_id]
+
+            visibility_event = state.get((EventTypes.RoomHistoryVisibility, ""), None)
+            if visibility_event:
+                visibility = visibility_event.content.get("history_visibility", "shared")
+            else:
+                visibility = "shared"
 
-        def allowed(event, membership, visibility):
             if visibility == "world_readable":
                 return True
 
             if is_guest:
                 return False
 
+            membership_event = state.get((EventTypes.Member, user_id), None)
+            if membership_event:
+                if membership_event.event_id in event_id_forgotten:
+                    membership = None
+                else:
+                    membership = membership_event.membership
+            else:
+                membership = None
+
             if membership == Membership.JOIN:
                 return True
 
@@ -78,43 +116,20 @@ class BaseHandler(object):
 
             return True
 
-        event_id_to_state = yield self.store.get_state_for_events(
-            frozenset(e.event_id for e in events),
-            types=(
-                (EventTypes.RoomHistoryVisibility, ""),
-                (EventTypes.Member, user_id),
-            )
-        )
-
-        events_to_return = []
-        for event in events:
-            state = event_id_to_state[event.event_id]
+        defer.returnValue({
+            user_id: [
+                event
+                for event in events
+                if allowed(event, user_id, is_guest)
+            ]
+            for user_id, is_guest in users
+        })
 
-            membership_event = state.get((EventTypes.Member, user_id), None)
-            if membership_event:
-                was_forgotten_at_event = yield self.store.was_forgotten_at(
-                    membership_event.state_key,
-                    membership_event.room_id,
-                    membership_event.event_id
-                )
-                if was_forgotten_at_event:
-                    membership = None
-                else:
-                    membership = membership_event.membership
-            else:
-                membership = None
-
-            visibility_event = state.get((EventTypes.RoomHistoryVisibility, ""), None)
-            if visibility_event:
-                visibility = visibility_event.content.get("history_visibility", "shared")
-            else:
-                visibility = "shared"
-
-            should_include = allowed(event, membership, visibility)
-            if should_include:
-                events_to_return.append(event)
-
-        defer.returnValue(events_to_return)
+    @defer.inlineCallbacks
+    def _filter_events_for_client(self, user_id, events, is_guest=False):
+        # Assumes that user has at some point joined the room if not is_guest.
+        res = yield self._filter_events_for_clients([(user_id, is_guest)], events)
+        defer.returnValue(res.get(user_id, []))
 
     def ratelimit(self, user_id):
         time_now = self.clock.time()