summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/api/constants.py2
-rw-r--r--synapse/handlers/message.py33
-rw-r--r--synapse/storage/state.py63
3 files changed, 95 insertions, 3 deletions
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index d8a18ee87b..3e15e8a9d7 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -75,6 +75,8 @@ class EventTypes(object):
     Redaction = "m.room.redaction"
     Feedback = "m.room.message.feedback"
 
+    RoomHistoryVisibility = "m.room.history_visibility"
+
     # These are used for validation
     Message = "m.room.message"
     Topic = "m.room.topic"
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index e324662f18..17c75f33c9 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -113,11 +113,42 @@ class MessageHandler(BaseHandler):
             "room_key", next_key
         )
 
+        if not events:
+            defer.returnValue({
+                "chunk": [],
+                "start": pagin_config.from_token.to_string(),
+                "end": next_token.to_string(),
+            })
+
+        states = yield self.store.get_state_for_events(
+            room_id, [e.event_id for e in events],
+        )
+
+        events_and_states = zip(events, states)
+
+        def allowed(event_and_state):
+            _, state = event_and_state
+
+            membership = state.get((EventTypes.Member, user_id), None)
+            if membership and membership.membership == Membership.JOIN:
+                return True
+
+            history = state.get((EventTypes.RoomHistoryVisibility, ''), None)
+            if history and history.content["visibility"] == "after_join":
+                return False
+
+        events_and_states = filter(allowed, events_and_states)
+        events = [
+            ev
+            for ev, _ in events_and_states
+        ]
+
         time_now = self.clock.time_msec()
 
         chunk = {
             "chunk": [
-                serialize_event(e, time_now, as_client_event) for e in events
+                serialize_event(e, time_now, as_client_event)
+                for e in events
             ],
             "start": pagin_config.from_token.to_string(),
             "end": next_token.to_string(),
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index f2b17f29ea..d7844edee3 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -92,11 +92,11 @@ class StateStore(SQLBaseStore):
         defer.returnValue(dict(state_list))
 
     @cached(num_args=1)
-    def _fetch_events_for_group(self, state_group, events):
+    def _fetch_events_for_group(self, key, events):
         return self._get_events(
             events, get_prev_content=False
         ).addCallback(
-            lambda evs: (state_group, evs)
+            lambda evs: (key, evs)
         )
 
     def _store_state_groups_txn(self, txn, event, context):
@@ -194,6 +194,65 @@ class StateStore(SQLBaseStore):
         events = yield self._get_events(event_ids, get_prev_content=False)
         defer.returnValue(events)
 
+    @defer.inlineCallbacks
+    def get_state_for_events(self, room_id, event_ids):
+        def f(txn):
+            groups = set()
+            event_to_group = {}
+            for event_id in event_ids:
+                # TODO: Remove this loop.
+                group = self._simple_select_one_onecol_txn(
+                    txn,
+                    table="event_to_state_groups",
+                    keyvalues={"event_id": event_id},
+                    retcol="state_group",
+                    allow_none=True,
+                )
+                if group:
+                    event_to_group[event_id] = group
+                    groups.add(group)
+
+            group_to_state_ids = {}
+            for group in groups:
+                state_ids = self._simple_select_onecol_txn(
+                    txn,
+                    table="state_groups_state",
+                    keyvalues={"state_group": group},
+                    retcol="event_id",
+                )
+
+                group_to_state_ids[group] = state_ids
+
+            return event_to_group, group_to_state_ids
+
+        res = yield self.runInteraction(
+            "annotate_events_with_state_groups",
+            f,
+        )
+
+        event_to_group, group_to_state_ids = res
+
+        state_list = yield defer.gatherResults(
+            [
+                self._fetch_events_for_group(group, vals)
+                for group, vals in group_to_state_ids.items()
+            ],
+            consumeErrors=True,
+        )
+
+        state_dict = {
+            group: {
+                (ev.type, ev.state_key): ev
+                for ev in state
+            }
+            for group, state in state_list
+        }
+
+        defer.returnValue([
+            state_dict.get(event_to_group.get(event, None), None)
+            for event in event_ids
+        ])
+
 
 def _make_group_id(clock):
     return str(int(clock.time_msec())) + random_string(5)