summary refs log tree commit diff
path: root/synapse/handlers/_base.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/_base.py')
-rw-r--r--synapse/handlers/_base.py39
1 files changed, 17 insertions, 22 deletions
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index fa83d3e464..064e8723c8 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -53,25 +53,10 @@ class BaseHandler(object):
         self.event_builder_factory = hs.get_event_builder_factory()
 
     @defer.inlineCallbacks
-    def _filter_events_for_clients(self, user_tuples, events):
+    def _filter_events_for_clients(self, user_tuples, events, event_id_to_state):
         """ Returns dict of user_id -> list of events that user is allowed to
         see.
         """
-        # If there is only one user, just get the state for that one user,
-        # otherwise just get all the state.
-        if len(user_tuples) == 1:
-            types = (
-                (EventTypes.RoomHistoryVisibility, ""),
-                (EventTypes.Member, user_tuples[0][0]),
-            )
-        else:
-            types = None
-
-        event_id_to_state = yield self.store.get_state_for_events(
-            frozenset(e.event_id for e in events),
-            types=types
-        )
-
         forgotten = yield defer.gatherResults([
             self.store.who_forgot_in_room(
                 room_id,
@@ -135,7 +120,17 @@ class BaseHandler(object):
     @defer.inlineCallbacks
     def _filter_events_for_client(self, user_id, events, is_peeking=False):
         # Assumes that user has at some point joined the room if not is_guest.
-        res = yield self._filter_events_for_clients([(user_id, is_peeking)], events)
+        types = (
+            (EventTypes.RoomHistoryVisibility, ""),
+            (EventTypes.Member, user_id),
+        )
+        event_id_to_state = yield self.store.get_state_for_events(
+            frozenset(e.event_id for e in events),
+            types=types
+        )
+        res = yield self._filter_events_for_clients(
+            [(user_id, is_peeking)], events, event_id_to_state
+        )
         defer.returnValue(res.get(user_id, []))
 
     def ratelimit(self, user_id):
@@ -269,13 +264,13 @@ class BaseHandler(object):
                         "You don't have permission to redact events"
                     )
 
-        (event_stream_id, max_stream_id) = yield self.store.persist_event(
-            event, context=context
-        )
-
         action_generator = ActionGenerator(self.hs)
         yield action_generator.handle_push_actions_for_event(
-            event, self
+            event, context, self
+        )
+
+        (event_stream_id, max_stream_id) = yield self.store.persist_event(
+            event, context=context
         )
 
         destinations = set()