summary refs log tree commit diff
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2015-07-10 13:47:04 +0100
committerErik Johnston <erik@matrix.org>2015-07-10 13:47:04 +0100
commit532fcc997ac79e1e9c0d888b9239b1fb06de25e2 (patch)
tree8d74fd9e2c1c525a345d1aef42afa66829ad8e1e
parentMerge pull request #194 from matrix-org/erikj/bulk_verify_sigs (diff)
parentAdd comment (diff)
downloadsynapse-532fcc997ac79e1e9c0d888b9239b1fb06de25e2.tar.xz
Merge pull request #196 from matrix-org/erikj/room_history
Add ability to restrict room history.
Diffstat (limited to '')
-rw-r--r--synapse/api/auth.py3
-rw-r--r--synapse/api/constants.py2
-rw-r--r--synapse/events/utils.py2
-rw-r--r--synapse/handlers/federation.py54
-rw-r--r--synapse/handlers/message.py66
-rw-r--r--synapse/handlers/room.py1
-rw-r--r--synapse/handlers/sync.py48
-rw-r--r--synapse/storage/state.py63
8 files changed, 235 insertions, 4 deletions
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 4da62e5d8d..1a25bf1086 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
 
 AuthEventTypes = (
     EventTypes.Create, EventTypes.Member, EventTypes.PowerLevels,
-    EventTypes.JoinRules,
+    EventTypes.JoinRules, EventTypes.RoomHistoryVisibility,
 )
 
 
@@ -575,6 +575,7 @@ class Auth(object):
         levels_to_check = [
             ("users_default", []),
             ("events_default", []),
+            ("state_default", []),
             ("ban", []),
             ("redact", []),
             ("kick", []),
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index d8a18ee87b..3e15e8a9d7 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -75,6 +75,8 @@ class EventTypes(object):
     Redaction = "m.room.redaction"
     Feedback = "m.room.message.feedback"
 
+    RoomHistoryVisibility = "m.room.history_visibility"
+
     # These are used for validation
     Message = "m.room.message"
     Topic = "m.room.topic"
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index 1aa952150e..7bd78343f0 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -74,6 +74,8 @@ def prune_event(event):
         )
     elif event_type == EventTypes.Aliases:
         add_fields("aliases")
+    elif event_type == EventTypes.RoomHistoryVisibility:
+        add_fields("history_visibility")
 
     allowed_fields = {
         k: v
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index b5d882fd65..d7f197f247 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -31,6 +31,8 @@ from synapse.crypto.event_signing import (
 )
 from synapse.types import UserID
 
+from synapse.events.utils import prune_event
+
 from synapse.util.retryutils import NotRetryingDestination
 
 from twisted.internet import defer
@@ -222,6 +224,56 @@ class FederationHandler(BaseHandler):
                     "user_joined_room", user=user, room_id=event.room_id
                 )
 
+    @defer.inlineCallbacks
+    def _filter_events_for_server(self, server_name, room_id, events):
+        states = yield self.store.get_state_for_events(
+            room_id, [e.event_id for e in events],
+        )
+
+        events_and_states = zip(events, states)
+
+        def redact_disallowed(event_and_state):
+            event, state = event_and_state
+
+            if not state:
+                return event
+
+            history = state.get((EventTypes.RoomHistoryVisibility, ''), None)
+            if history:
+                visibility = history.content.get("history_visibility", "shared")
+                if visibility in ["invited", "joined"]:
+                    # We now loop through all state events looking for
+                    # membership states for the requesting server to determine
+                    # if the server is either in the room or has been invited
+                    # into the room.
+                    for ev in state.values():
+                        if ev.type != EventTypes.Member:
+                            continue
+                        try:
+                            domain = UserID.from_string(ev.state_key).domain
+                        except:
+                            continue
+
+                        if domain != server_name:
+                            continue
+
+                        memtype = ev.membership
+                        if memtype == Membership.JOIN:
+                            return event
+                        elif memtype == Membership.INVITE:
+                            if visibility == "invited":
+                                return event
+                    else:
+                        return prune_event(event)
+
+            return event
+
+        res = map(redact_disallowed, events_and_states)
+
+        logger.info("_filter_events_for_server %r", res)
+
+        defer.returnValue(res)
+
     @log_function
     @defer.inlineCallbacks
     def backfill(self, dest, room_id, limit, extremities=[]):
@@ -882,6 +934,8 @@ class FederationHandler(BaseHandler):
             limit
         )
 
+        events = yield self._filter_events_for_server(origin, room_id, events)
+
         defer.returnValue(events)
 
     @defer.inlineCallbacks
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index e324662f18..d8b117612d 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -113,11 +113,21 @@ class MessageHandler(BaseHandler):
             "room_key", next_key
         )
 
+        if not events:
+            defer.returnValue({
+                "chunk": [],
+                "start": pagin_config.from_token.to_string(),
+                "end": next_token.to_string(),
+            })
+
+        events = yield self._filter_events_for_client(user_id, room_id, events)
+
         time_now = self.clock.time_msec()
 
         chunk = {
             "chunk": [
-                serialize_event(e, time_now, as_client_event) for e in events
+                serialize_event(e, time_now, as_client_event)
+                for e in events
             ],
             "start": pagin_config.from_token.to_string(),
             "end": next_token.to_string(),
@@ -126,6 +136,52 @@ class MessageHandler(BaseHandler):
         defer.returnValue(chunk)
 
     @defer.inlineCallbacks
+    def _filter_events_for_client(self, user_id, room_id, events):
+        states = yield self.store.get_state_for_events(
+            room_id, [e.event_id for e in events],
+        )
+
+        events_and_states = zip(events, states)
+
+        def allowed(event_and_state):
+            event, state = event_and_state
+
+            if event.type == EventTypes.RoomHistoryVisibility:
+                return True
+
+            membership_ev = state.get((EventTypes.Member, user_id), None)
+            if membership_ev:
+                membership = membership_ev.membership
+            else:
+                membership = Membership.LEAVE
+
+            if membership == Membership.JOIN:
+                return True
+
+            history = state.get((EventTypes.RoomHistoryVisibility, ''), None)
+            if history:
+                visibility = history.content.get("history_visibility", "shared")
+            else:
+                visibility = "shared"
+
+            if visibility == "public":
+                return True
+            elif visibility == "shared":
+                return True
+            elif visibility == "joined":
+                return membership == Membership.JOIN
+            elif visibility == "invited":
+                return membership == Membership.INVITE
+
+            return True
+
+        events_and_states = filter(allowed, events_and_states)
+        defer.returnValue([
+            ev
+            for ev, _ in events_and_states
+        ])
+
+    @defer.inlineCallbacks
     def create_and_send_event(self, event_dict, ratelimit=True,
                               client=None, txn_id=None):
         """ Given a dict from a client, create and handle a new event.
@@ -316,6 +372,10 @@ class MessageHandler(BaseHandler):
                     ]
                 ).addErrback(unwrapFirstError)
 
+                messages = yield self._filter_events_for_client(
+                    user_id, event.room_id, messages
+                )
+
                 start_token = now_token.copy_and_replace("room_key", token[0])
                 end_token = now_token.copy_and_replace("room_key", token[1])
                 time_now = self.clock.time_msec()
@@ -417,6 +477,10 @@ class MessageHandler(BaseHandler):
             consumeErrors=True,
         ).addErrback(unwrapFirstError)
 
+        messages = yield self._filter_events_for_client(
+            user_id, room_id, messages
+        )
+
         start_token = now_token.copy_and_replace("room_key", token[0])
         end_token = now_token.copy_and_replace("room_key", token[1])
 
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 4bd027d9bb..891707df44 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -213,6 +213,7 @@ class RoomCreationHandler(BaseHandler):
                 "events": {
                     EventTypes.Name: 100,
                     EventTypes.PowerLevels: 100,
+                    EventTypes.RoomHistoryVisibility: 100,
                 },
                 "events_default": 0,
                 "state_default": 50,
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index bd8c603681..6cff6230c1 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -293,6 +293,51 @@ class SyncHandler(BaseHandler):
         ))
 
     @defer.inlineCallbacks
+    def _filter_events_for_client(self, user_id, room_id, events):
+        states = yield self.store.get_state_for_events(
+            room_id, [e.event_id for e in events],
+        )
+
+        events_and_states = zip(events, states)
+
+        def allowed(event_and_state):
+            event, state = event_and_state
+
+            if event.type == EventTypes.RoomHistoryVisibility:
+                return True
+
+            membership_ev = state.get((EventTypes.Member, user_id), None)
+            if membership_ev:
+                membership = membership_ev.membership
+            else:
+                membership = Membership.LEAVE
+
+            if membership == Membership.JOIN:
+                return True
+
+            history = state.get((EventTypes.RoomHistoryVisibility, ''), None)
+            if history:
+                visibility = history.content.get("history_visibility", "shared")
+            else:
+                visibility = "shared"
+
+            if visibility == "public":
+                return True
+            elif visibility == "shared":
+                return True
+            elif visibility == "joined":
+                return membership == Membership.JOIN
+            elif visibility == "invited":
+                return membership == Membership.INVITE
+
+            return True
+        events_and_states = filter(allowed, events_and_states)
+        defer.returnValue([
+            ev
+            for ev, _ in events_and_states
+        ])
+
+    @defer.inlineCallbacks
     def load_filtered_recents(self, room_id, sync_config, now_token,
                               since_token=None):
         limited = True
@@ -313,6 +358,9 @@ class SyncHandler(BaseHandler):
             (room_key, _) = keys
             end_key = "s" + room_key.split('-')[-1]
             loaded_recents = sync_config.filter.filter_room_events(events)
+            loaded_recents = yield self._filter_events_for_client(
+                sync_config.user.to_string(), room_id, loaded_recents,
+            )
             loaded_recents.extend(recents)
             recents = loaded_recents
             if len(events) <= load_limit:
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index f2b17f29ea..d7844edee3 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -92,11 +92,11 @@ class StateStore(SQLBaseStore):
         defer.returnValue(dict(state_list))
 
     @cached(num_args=1)
-    def _fetch_events_for_group(self, state_group, events):
+    def _fetch_events_for_group(self, key, events):
         return self._get_events(
             events, get_prev_content=False
         ).addCallback(
-            lambda evs: (state_group, evs)
+            lambda evs: (key, evs)
         )
 
     def _store_state_groups_txn(self, txn, event, context):
@@ -194,6 +194,65 @@ class StateStore(SQLBaseStore):
         events = yield self._get_events(event_ids, get_prev_content=False)
         defer.returnValue(events)
 
+    @defer.inlineCallbacks
+    def get_state_for_events(self, room_id, event_ids):
+        def f(txn):
+            groups = set()
+            event_to_group = {}
+            for event_id in event_ids:
+                # TODO: Remove this loop.
+                group = self._simple_select_one_onecol_txn(
+                    txn,
+                    table="event_to_state_groups",
+                    keyvalues={"event_id": event_id},
+                    retcol="state_group",
+                    allow_none=True,
+                )
+                if group:
+                    event_to_group[event_id] = group
+                    groups.add(group)
+
+            group_to_state_ids = {}
+            for group in groups:
+                state_ids = self._simple_select_onecol_txn(
+                    txn,
+                    table="state_groups_state",
+                    keyvalues={"state_group": group},
+                    retcol="event_id",
+                )
+
+                group_to_state_ids[group] = state_ids
+
+            return event_to_group, group_to_state_ids
+
+        res = yield self.runInteraction(
+            "annotate_events_with_state_groups",
+            f,
+        )
+
+        event_to_group, group_to_state_ids = res
+
+        state_list = yield defer.gatherResults(
+            [
+                self._fetch_events_for_group(group, vals)
+                for group, vals in group_to_state_ids.items()
+            ],
+            consumeErrors=True,
+        )
+
+        state_dict = {
+            group: {
+                (ev.type, ev.state_key): ev
+                for ev in state
+            }
+            for group, state in state_list
+        }
+
+        defer.returnValue([
+            state_dict.get(event_to_group.get(event, None), None)
+            for event in event_ids
+        ])
+
 
 def _make_group_id(clock):
     return str(int(clock.time_msec())) + random_string(5)