summary refs log tree commit diff
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2016-01-18 10:45:09 +0000
committerErik Johnston <erik@matrix.org>2016-01-18 14:43:50 +0000
commitcc66a9a5e3fc954b0da48ba891e9f77be31aa832 (patch)
tree230d80668cd13e86931ffa61b711746a5b432de7
parentMerge pull request #501 from matrix-org/daniel/unban (diff)
downloadsynapse-cc66a9a5e3fc954b0da48ba891e9f77be31aa832.tar.xz
Allow filtering events for multiple users at once
-rw-r--r--synapse/handlers/_base.py93
-rw-r--r--synapse/storage/roommember.py13
2 files changed, 67 insertions, 39 deletions
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index bb2c6733d5..2d1167296a 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -53,16 +53,54 @@ class BaseHandler(object):
         self.event_builder_factory = hs.get_event_builder_factory()
 
     @defer.inlineCallbacks
-    def _filter_events_for_client(self, user_id, events, is_guest=False):
-        # Assumes that user has at some point joined the room if not is_guest.
+    def _filter_events_for_clients(self, users, events):
+        """ Returns dict of user_id -> list of events that user is allowed to
+        see.
+        """
+        event_id_to_state = yield self.store.get_state_for_events(
+            frozenset(e.event_id for e in events),
+            types=(
+                (EventTypes.RoomHistoryVisibility, ""),
+                (EventTypes.Member, None),
+            )
+        )
+
+        forgotten = yield defer.gatherResults([
+            self.store.who_forgot_in_room(
+                room_id,
+            )
+            for room_id in frozenset(e.room_id for e in events)
+        ], consumeErrors=True)
+
+        # Set of membership event_ids that have been forgotten
+        event_id_forgotten = frozenset(
+            row["event_id"] for rows in forgotten for row in rows
+        )
+
+        def allowed(event, user_id, is_guest):
+            state = event_id_to_state[event.event_id]
+
+            visibility_event = state.get((EventTypes.RoomHistoryVisibility, ""), None)
+            if visibility_event:
+                visibility = visibility_event.content.get("history_visibility", "shared")
+            else:
+                visibility = "shared"
 
-        def allowed(event, membership, visibility):
             if visibility == "world_readable":
                 return True
 
             if is_guest:
                 return False
 
+            membership_event = state.get((EventTypes.Member, user_id), None)
+            if membership_event:
+                if membership_event.event_id in event_id_forgotten:
+                    membership = None
+                else:
+                    membership = membership_event.membership
+            else:
+                membership = None
+
             if membership == Membership.JOIN:
                 return True
 
@@ -78,43 +116,20 @@ class BaseHandler(object):
 
             return True
 
-        event_id_to_state = yield self.store.get_state_for_events(
-            frozenset(e.event_id for e in events),
-            types=(
-                (EventTypes.RoomHistoryVisibility, ""),
-                (EventTypes.Member, user_id),
-            )
-        )
-
-        events_to_return = []
-        for event in events:
-            state = event_id_to_state[event.event_id]
+        defer.returnValue({
+            user_id: [
+                event
+                for event in events
+                if allowed(event, user_id, is_guest)
+            ]
+            for user_id, is_guest in users
+        })
 
-            membership_event = state.get((EventTypes.Member, user_id), None)
-            if membership_event:
-                was_forgotten_at_event = yield self.store.was_forgotten_at(
-                    membership_event.state_key,
-                    membership_event.room_id,
-                    membership_event.event_id
-                )
-                if was_forgotten_at_event:
-                    membership = None
-                else:
-                    membership = membership_event.membership
-            else:
-                membership = None
-
-            visibility_event = state.get((EventTypes.RoomHistoryVisibility, ""), None)
-            if visibility_event:
-                visibility = visibility_event.content.get("history_visibility", "shared")
-            else:
-                visibility = "shared"
-
-            should_include = allowed(event, membership, visibility)
-            if should_include:
-                events_to_return.append(event)
-
-        defer.returnValue(events_to_return)
+    @defer.inlineCallbacks
+    def _filter_events_for_client(self, user_id, events, is_guest=False):
+        # Assumes that user has at some point joined the room if not is_guest.
+        res = yield self._filter_events_for_clients([(user_id, is_guest)], events)
+        defer.returnValue(res.get(user_id, []))
 
     def ratelimit(self, user_id):
         time_now = self.clock.time()
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 7d3ce4579d..68ac88905f 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -287,6 +287,7 @@ class RoomMemberStore(SQLBaseStore):
             txn.execute(sql, (user_id, room_id))
         yield self.runInteraction("forget_membership", f)
         self.was_forgotten_at.invalidate_all()
+        self.who_forgot_in_room.invalidate_all()
         self.did_forget.invalidate((user_id, room_id))
 
     @cachedInlineCallbacks(num_args=2)
@@ -336,3 +337,15 @@ class RoomMemberStore(SQLBaseStore):
             return rows[0][0]
         forgot = yield self.runInteraction("did_forget_membership_at", f)
         defer.returnValue(forgot == 1)
+
+    @cached()
+    def who_forgot_in_room(self, room_id):
+        return self._simple_select_list(
+            table="room_memberships",
+            retcols=("user_id", "event_id"),
+            keyvalues={
+                "room_id": room_id,
+                "forgotten": 1,
+            },
+            desc="who_forgot"
+        )