summary refs log tree commit diff
path: root/synapse/handlers/_base.py
diff options
context:
space:
mode:
authorDavid Baker <dave@matrix.org>2016-01-19 18:17:23 +0000
committerDavid Baker <dave@matrix.org>2016-01-19 18:17:23 +0000
commitafb7b377f23b275bf0274d6cbbfae462362cfc8c (patch)
tree212e275af6f3d52dba6e4367553ea649d9965d33 /synapse/handlers/_base.py
parentUse the unread notification count to send accurate badge counts in push notif... (diff)
parentMerge pull request #505 from matrix-org/erikj/push_fast (diff)
downloadsynapse-afb7b377f23b275bf0274d6cbbfae462362cfc8c.tar.xz
Merge branch 'develop' into push_badge_counts
Diffstat (limited to 'synapse/handlers/_base.py')
-rw-r--r--synapse/handlers/_base.py103
1 files changed, 58 insertions, 45 deletions
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index 66e35de6e4..5c7617de44 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -53,16 +53,54 @@ class BaseHandler(object):
         self.event_builder_factory = hs.get_event_builder_factory()
 
     @defer.inlineCallbacks
-    def _filter_events_for_client(self, user_id, events, is_guest=False):
-        # Assumes that user has at some point joined the room if not is_guest.
+    def _filter_events_for_clients(self, users, events):
+        """ Returns dict of user_id -> list of events that user is allowed to
+        see.
+        """
+        event_id_to_state = yield self.store.get_state_for_events(
+            frozenset(e.event_id for e in events),
+            types=(
+                (EventTypes.RoomHistoryVisibility, ""),
+                (EventTypes.Member, None),
+            )
+        )
+
+        forgotten = yield defer.gatherResults([
+            self.store.who_forgot_in_room(
+                room_id,
+            )
+            for room_id in frozenset(e.room_id for e in events)
+        ], consumeErrors=True)
+
+        # Set of membership event_ids that have been forgotten
+        event_id_forgotten = frozenset(
+            row["event_id"] for rows in forgotten for row in rows
+        )
+
+        def allowed(event, user_id, is_guest):
+            state = event_id_to_state[event.event_id]
+
+            visibility_event = state.get((EventTypes.RoomHistoryVisibility, ""), None)
+            if visibility_event:
+                visibility = visibility_event.content.get("history_visibility", "shared")
+            else:
+                visibility = "shared"
 
-        def allowed(event, membership, visibility):
             if visibility == "world_readable":
                 return True
 
             if is_guest:
                 return False
 
+            membership_event = state.get((EventTypes.Member, user_id), None)
+            if membership_event:
+                if membership_event.event_id in event_id_forgotten:
+                    membership = None
+                else:
+                    membership = membership_event.membership
+            else:
+                membership = None
+
             if membership == Membership.JOIN:
                 return True
 
@@ -78,43 +116,20 @@ class BaseHandler(object):
 
             return True
 
-        event_id_to_state = yield self.store.get_state_for_events(
-            frozenset(e.event_id for e in events),
-            types=(
-                (EventTypes.RoomHistoryVisibility, ""),
-                (EventTypes.Member, user_id),
-            )
-        )
-
-        events_to_return = []
-        for event in events:
-            state = event_id_to_state[event.event_id]
+        defer.returnValue({
+            user_id: [
+                event
+                for event in events
+                if allowed(event, user_id, is_guest)
+            ]
+            for user_id, is_guest in users
+        })
 
-            membership_event = state.get((EventTypes.Member, user_id), None)
-            if membership_event:
-                was_forgotten_at_event = yield self.store.was_forgotten_at(
-                    membership_event.state_key,
-                    membership_event.room_id,
-                    membership_event.event_id
-                )
-                if was_forgotten_at_event:
-                    membership = None
-                else:
-                    membership = membership_event.membership
-            else:
-                membership = None
-
-            visibility_event = state.get((EventTypes.RoomHistoryVisibility, ""), None)
-            if visibility_event:
-                visibility = visibility_event.content.get("history_visibility", "shared")
-            else:
-                visibility = "shared"
-
-            should_include = allowed(event, membership, visibility)
-            if should_include:
-                events_to_return.append(event)
-
-        defer.returnValue(events_to_return)
+    @defer.inlineCallbacks
+    def _filter_events_for_client(self, user_id, events, is_guest=False):
+        # Assumes that user has at some point joined the room if not is_guest.
+        res = yield self._filter_events_for_clients([(user_id, is_guest)], events)
+        defer.returnValue(res.get(user_id, []))
 
     def ratelimit(self, user_id):
         time_now = self.clock.time()
@@ -171,12 +186,10 @@ class BaseHandler(object):
         )
 
     @defer.inlineCallbacks
-    def handle_new_client_event(self, event, context, extra_destinations=[],
-                                extra_users=[], suppress_auth=False):
+    def handle_new_client_event(self, event, context, extra_users=[]):
         # We now need to go and hit out to wherever we need to hit out to.
 
-        if not suppress_auth:
-            self.auth.check(event, auth_events=context.current_state)
+        self.auth.check(event, auth_events=context.current_state)
 
         yield self.maybe_kick_guest_users(event, context.current_state.values())
 
@@ -253,12 +266,12 @@ class BaseHandler(object):
             event, context=context
         )
 
-        action_generator = ActionGenerator(self.store)
+        action_generator = ActionGenerator(self.hs)
         yield action_generator.handle_push_actions_for_event(
             event, self
         )
 
-        destinations = set(extra_destinations)
+        destinations = set()
         for k, s in context.current_state.items():
             try:
                 if k[0] == EventTypes.Member: