summary refs log tree commit diff
path: root/synapse/handlers/_base.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/_base.py')
-rw-r--r--synapse/handlers/_base.py31
1 files changed, 13 insertions, 18 deletions
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index fa83d3e464..d3f722b22e 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -53,25 +53,10 @@ class BaseHandler(object):
         self.event_builder_factory = hs.get_event_builder_factory()
 
     @defer.inlineCallbacks
-    def _filter_events_for_clients(self, user_tuples, events):
+    def _filter_events_for_clients(self, user_tuples, events, event_id_to_state):
         """ Returns dict of user_id -> list of events that user is allowed to
         see.
         """
-        # If there is only one user, just get the state for that one user,
-        # otherwise just get all the state.
-        if len(user_tuples) == 1:
-            types = (
-                (EventTypes.RoomHistoryVisibility, ""),
-                (EventTypes.Member, user_tuples[0][0]),
-            )
-        else:
-            types = None
-
-        event_id_to_state = yield self.store.get_state_for_events(
-            frozenset(e.event_id for e in events),
-            types=types
-        )
-
         forgotten = yield defer.gatherResults([
             self.store.who_forgot_in_room(
                 room_id,
@@ -135,7 +120,17 @@ class BaseHandler(object):
     @defer.inlineCallbacks
     def _filter_events_for_client(self, user_id, events, is_peeking=False):
         # Assumes that user has at some point joined the room if not is_guest.
-        res = yield self._filter_events_for_clients([(user_id, is_peeking)], events)
+        types = (
+            (EventTypes.RoomHistoryVisibility, ""),
+            (EventTypes.Member, user_id),
+        )
+        event_id_to_state = yield self.store.get_state_for_events(
+            frozenset(e.event_id for e in events),
+            types=types
+        )
+        res = yield self._filter_events_for_clients(
+            [(user_id, is_peeking)], events, event_id_to_state
+        )
         defer.returnValue(res.get(user_id, []))
 
     def ratelimit(self, user_id):
@@ -275,7 +270,7 @@ class BaseHandler(object):
 
         action_generator = ActionGenerator(self.hs)
         yield action_generator.handle_push_actions_for_event(
-            event, self
+            event, self, context.current_state
         )
 
         destinations = set()