From 1a605456260bfb46d8bb9cff2d40d19aec03daa4 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 2 Jul 2015 16:20:10 +0100 Subject: Add basic impl for room history ACL on GET /messages client API --- synapse/api/constants.py | 2 ++ synapse/handlers/message.py | 33 +++++++++++++++++++++++- synapse/storage/state.py | 63 +++++++++++++++++++++++++++++++++++++++++++-- 3 files changed, 95 insertions(+), 3 deletions(-) diff --git a/synapse/api/constants.py b/synapse/api/constants.py index d8a18ee87b..3e15e8a9d7 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -75,6 +75,8 @@ class EventTypes(object): Redaction = "m.room.redaction" Feedback = "m.room.message.feedback" + RoomHistoryVisibility = "m.room.history_visibility" + # These are used for validation Message = "m.room.message" Topic = "m.room.topic" diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index e324662f18..17c75f33c9 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -113,11 +113,42 @@ class MessageHandler(BaseHandler): "room_key", next_key ) + if not events: + defer.returnValue({ + "chunk": [], + "start": pagin_config.from_token.to_string(), + "end": next_token.to_string(), + }) + + states = yield self.store.get_state_for_events( + room_id, [e.event_id for e in events], + ) + + events_and_states = zip(events, states) + + def allowed(event_and_state): + _, state = event_and_state + + membership = state.get((EventTypes.Member, user_id), None) + if membership and membership.membership == Membership.JOIN: + return True + + history = state.get((EventTypes.RoomHistoryVisibility, ''), None) + if history and history.content["visibility"] == "after_join": + return False + + events_and_states = filter(allowed, events_and_states) + events = [ + ev + for ev, _ in events_and_states + ] + time_now = self.clock.time_msec() chunk = { "chunk": [ - serialize_event(e, time_now, as_client_event) for e in events + serialize_event(e, time_now, as_client_event) + for e in events ], "start": pagin_config.from_token.to_string(), "end": next_token.to_string(), diff --git a/synapse/storage/state.py b/synapse/storage/state.py index f2b17f29ea..d7844edee3 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -92,11 +92,11 @@ class StateStore(SQLBaseStore): defer.returnValue(dict(state_list)) @cached(num_args=1) - def _fetch_events_for_group(self, state_group, events): + def _fetch_events_for_group(self, key, events): return self._get_events( events, get_prev_content=False ).addCallback( - lambda evs: (state_group, evs) + lambda evs: (key, evs) ) def _store_state_groups_txn(self, txn, event, context): @@ -194,6 +194,65 @@ class StateStore(SQLBaseStore): events = yield self._get_events(event_ids, get_prev_content=False) defer.returnValue(events) + @defer.inlineCallbacks + def get_state_for_events(self, room_id, event_ids): + def f(txn): + groups = set() + event_to_group = {} + for event_id in event_ids: + # TODO: Remove this loop. + group = self._simple_select_one_onecol_txn( + txn, + table="event_to_state_groups", + keyvalues={"event_id": event_id}, + retcol="state_group", + allow_none=True, + ) + if group: + event_to_group[event_id] = group + groups.add(group) + + group_to_state_ids = {} + for group in groups: + state_ids = self._simple_select_onecol_txn( + txn, + table="state_groups_state", + keyvalues={"state_group": group}, + retcol="event_id", + ) + + group_to_state_ids[group] = state_ids + + return event_to_group, group_to_state_ids + + res = yield self.runInteraction( + "annotate_events_with_state_groups", + f, + ) + + event_to_group, group_to_state_ids = res + + state_list = yield defer.gatherResults( + [ + self._fetch_events_for_group(group, vals) + for group, vals in group_to_state_ids.items() + ], + consumeErrors=True, + ) + + state_dict = { + group: { + (ev.type, ev.state_key): ev + for ev in state + } + for group, state in state_list + } + + defer.returnValue([ + state_dict.get(event_to_group.get(event, None), None) + for event in event_ids + ]) + def _make_group_id(clock): return str(int(clock.time_msec())) + random_string(5) -- cgit 1.4.1