summary refs log tree commit diff
path: root/synapse/handlers/federation.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/federation.py')
-rw-r--r--synapse/handlers/federation.py292
1 files changed, 204 insertions, 88 deletions
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 46ce3699d7..f7155fd8d3 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -31,6 +31,8 @@ from synapse.crypto.event_signing import (
 )
 from synapse.types import UserID
 
+from synapse.events.utils import prune_event
+
 from synapse.util.retryutils import NotRetryingDestination
 
 from twisted.internet import defer
@@ -138,26 +140,29 @@ class FederationHandler(BaseHandler):
         if state and auth_chain is not None:
             # If we have any state or auth_chain given to us by the replication
             # layer, then we should handle them (if we haven't before.)
+
+            event_infos = []
+
             for e in itertools.chain(auth_chain, state):
                 if e.event_id in seen_ids:
                     continue
-
                 e.internal_metadata.outlier = True
-                try:
-                    auth_ids = [e_id for e_id, _ in e.auth_events]
-                    auth = {
-                        (e.type, e.state_key): e for e in auth_chain
-                        if e.event_id in auth_ids
-                    }
-                    yield self._handle_new_event(
-                        origin, e, auth_events=auth
-                    )
-                    seen_ids.add(e.event_id)
-                except:
-                    logger.exception(
-                        "Failed to handle state event %s",
-                        e.event_id,
-                    )
+                auth_ids = [e_id for e_id, _ in e.auth_events]
+                auth = {
+                    (e.type, e.state_key): e for e in auth_chain
+                    if e.event_id in auth_ids
+                }
+                event_infos.append({
+                    "event": e,
+                    "auth_events": auth,
+                })
+                seen_ids.add(e.event_id)
+
+            yield self._handle_new_events(
+                origin,
+                event_infos,
+                outliers=True
+            )
 
         try:
             _, event_stream_id, max_stream_id = yield self._handle_new_event(
@@ -222,6 +227,56 @@ class FederationHandler(BaseHandler):
                     "user_joined_room", user=user, room_id=event.room_id
                 )
 
+    @defer.inlineCallbacks
+    def _filter_events_for_server(self, server_name, room_id, events):
+        states = yield self.store.get_state_for_events(
+            room_id, [e.event_id for e in events],
+        )
+
+        events_and_states = zip(events, states)
+
+        def redact_disallowed(event_and_state):
+            event, state = event_and_state
+
+            if not state:
+                return event
+
+            history = state.get((EventTypes.RoomHistoryVisibility, ''), None)
+            if history:
+                visibility = history.content.get("history_visibility", "shared")
+                if visibility in ["invited", "joined"]:
+                    # We now loop through all state events looking for
+                    # membership states for the requesting server to determine
+                    # if the server is either in the room or has been invited
+                    # into the room.
+                    for ev in state.values():
+                        if ev.type != EventTypes.Member:
+                            continue
+                        try:
+                            domain = UserID.from_string(ev.state_key).domain
+                        except:
+                            continue
+
+                        if domain != server_name:
+                            continue
+
+                        memtype = ev.membership
+                        if memtype == Membership.JOIN:
+                            return event
+                        elif memtype == Membership.INVITE:
+                            if visibility == "invited":
+                                return event
+                    else:
+                        return prune_event(event)
+
+            return event
+
+        res = map(redact_disallowed, events_and_states)
+
+        logger.info("_filter_events_for_server %r", res)
+
+        defer.returnValue(res)
+
     @log_function
     @defer.inlineCallbacks
     def backfill(self, dest, room_id, limit, extremities=[]):
@@ -247,9 +302,15 @@ class FederationHandler(BaseHandler):
             if set(e_id for e_id, _ in ev.prev_events) - event_ids
         ]
 
+        logger.info(
+            "backfill: Got %d events with %d edges",
+            len(events), len(edges),
+        )
+
         # For each edge get the current state.
 
         auth_events = {}
+        state_events = {}
         events_to_state = {}
         for e_id in edges:
             state, auth = yield self.replication_layer.get_state_for_room(
@@ -258,27 +319,57 @@ class FederationHandler(BaseHandler):
                 event_id=e_id
             )
             auth_events.update({a.event_id: a for a in auth})
+            auth_events.update({s.event_id: s for s in state})
+            state_events.update({s.event_id: s for s in state})
             events_to_state[e_id] = state
 
-        yield defer.gatherResults(
-            [
-                self._handle_new_event(dest, a)
-                for a in auth_events.values()
-            ],
-            consumeErrors=True,
-        ).addErrback(unwrapFirstError)
+        seen_events = yield self.store.have_events(
+            set(auth_events.keys()) | set(state_events.keys())
+        )
+
+        all_events = events + state_events.values() + auth_events.values()
+        required_auth = set(
+            a_id for event in all_events for a_id, _ in event.auth_events
+        )
 
-        yield defer.gatherResults(
+        missing_auth = required_auth - set(auth_events)
+        results = yield defer.gatherResults(
             [
-                self._handle_new_event(
-                    dest, event_map[e_id],
-                    state=events_to_state[e_id],
-                    backfilled=True,
+                self.replication_layer.get_pdu(
+                    [dest],
+                    event_id,
+                    outlier=True,
+                    timeout=10000,
                 )
-                for e_id in events_to_state
+                for event_id in missing_auth
             ],
             consumeErrors=True
         ).addErrback(unwrapFirstError)
+        auth_events.update({a.event_id: a for a in results})
+
+        ev_infos = []
+        for a in auth_events.values():
+            if a.event_id in seen_events:
+                continue
+            ev_infos.append({
+                "event": a,
+                "auth_events": {
+                    (auth_events[a_id].type, auth_events[a_id].state_key):
+                    auth_events[a_id]
+                    for a_id, _ in a.auth_events
+                }
+            })
+
+        for e_id in events_to_state:
+            ev_infos.append({
+                "event": event_map[e_id],
+                "state": events_to_state[e_id],
+                "auth_events": {
+                    (auth_events[a_id].type, auth_events[a_id].state_key):
+                    auth_events[a_id]
+                    for a_id, _ in event_map[e_id].auth_events
+                }
+            })
 
         events.sort(key=lambda e: e.depth)
 
@@ -286,10 +377,14 @@ class FederationHandler(BaseHandler):
             if event in events_to_state:
                 continue
 
-            yield self._handle_new_event(
-                dest, event,
-                backfilled=True,
-            )
+            ev_infos.append({
+                "event": event,
+            })
+
+        yield self._handle_new_events(
+            dest, ev_infos,
+            backfilled=True,
+        )
 
         defer.returnValue(events)
 
@@ -555,32 +650,22 @@ class FederationHandler(BaseHandler):
                 # FIXME
                 pass
 
-            yield self._handle_auth_events(
-                origin, [e for e in auth_chain if e.event_id != event.event_id]
-            )
-
-            @defer.inlineCallbacks
-            def handle_state(e):
+            ev_infos = []
+            for e in itertools.chain(state, auth_chain):
                 if e.event_id == event.event_id:
-                    return
+                    continue
 
                 e.internal_metadata.outlier = True
-                try:
-                    auth_ids = [e_id for e_id, _ in e.auth_events]
-                    auth = {
+                auth_ids = [e_id for e_id, _ in e.auth_events]
+                ev_infos.append({
+                    "event": e,
+                    "auth_events": {
                         (e.type, e.state_key): e for e in auth_chain
                         if e.event_id in auth_ids
                     }
-                    yield self._handle_new_event(
-                        origin, e, auth_events=auth
-                    )
-                except:
-                    logger.exception(
-                        "Failed to handle state event %s",
-                        e.event_id,
-                    )
+                })
 
-            yield defer.DeferredList([handle_state(e) for e in state])
+            yield self._handle_new_events(origin, ev_infos, outliers=True)
 
             auth_ids = [e_id for e_id, _ in event.auth_events]
             auth_events = {
@@ -837,6 +922,8 @@ class FederationHandler(BaseHandler):
             limit
         )
 
+        events = yield self._filter_events_for_server(origin, room_id, events)
+
         defer.returnValue(events)
 
     @defer.inlineCallbacks
@@ -895,24 +982,62 @@ class FederationHandler(BaseHandler):
     def _handle_new_event(self, origin, event, state=None, backfilled=False,
                           current_state=None, auth_events=None):
 
-        logger.debug(
-            "_handle_new_event: %s, sigs: %s",
-            event.event_id, event.signatures,
+        outlier = event.internal_metadata.is_outlier()
+
+        context = yield self._prep_event(
+            origin, event,
+            state=state,
+            backfilled=backfilled,
+            current_state=current_state,
+            auth_events=auth_events,
         )
 
-        context = yield self.state_handler.compute_event_context(
-            event, old_state=state
+        event_stream_id, max_stream_id = yield self.store.persist_event(
+            event,
+            context=context,
+            backfilled=backfilled,
+            is_new_state=(not outlier and not backfilled),
+            current_state=current_state,
         )
 
-        if not auth_events:
-            auth_events = context.current_state
+        defer.returnValue((context, event_stream_id, max_stream_id))
 
-        logger.debug(
-            "_handle_new_event: %s, auth_events: %s",
-            event.event_id, auth_events,
+    @defer.inlineCallbacks
+    def _handle_new_events(self, origin, event_infos, backfilled=False,
+                           outliers=False):
+        contexts = yield defer.gatherResults(
+            [
+                self._prep_event(
+                    origin,
+                    ev_info["event"],
+                    state=ev_info.get("state"),
+                    backfilled=backfilled,
+                    auth_events=ev_info.get("auth_events"),
+                )
+                for ev_info in event_infos
+            ]
         )
 
-        is_new_state = not event.internal_metadata.is_outlier()
+        yield self.store.persist_events(
+            [
+                (ev_info["event"], context)
+                for ev_info, context in itertools.izip(event_infos, contexts)
+            ],
+            backfilled=backfilled,
+            is_new_state=(not outliers and not backfilled),
+        )
+
+    @defer.inlineCallbacks
+    def _prep_event(self, origin, event, state=None, backfilled=False,
+                    current_state=None, auth_events=None):
+        outlier = event.internal_metadata.is_outlier()
+
+        context = yield self.state_handler.compute_event_context(
+            event, old_state=state, outlier=outlier,
+        )
+
+        if not auth_events:
+            auth_events = context.current_state
 
         # This is a hack to fix some old rooms where the initial join event
         # didn't reference the create event in its auth events.
@@ -937,26 +1062,7 @@ class FederationHandler(BaseHandler):
 
             context.rejected = RejectedReason.AUTH_ERROR
 
-            # FIXME: Don't store as rejected with AUTH_ERROR if we haven't
-            # seen all the auth events.
-            yield self.store.persist_event(
-                event,
-                context=context,
-                backfilled=backfilled,
-                is_new_state=False,
-                current_state=current_state,
-            )
-            raise
-
-        event_stream_id, max_stream_id = yield self.store.persist_event(
-            event,
-            context=context,
-            backfilled=backfilled,
-            is_new_state=(is_new_state and not backfilled),
-            current_state=current_state,
-        )
-
-        defer.returnValue((context, event_stream_id, max_stream_id))
+        defer.returnValue(context)
 
     @defer.inlineCallbacks
     def on_query_auth(self, origin, event_id, remote_auth_chain, rejects,
@@ -1019,14 +1125,24 @@ class FederationHandler(BaseHandler):
     @log_function
     def do_auth(self, origin, event, context, auth_events):
         # Check if we have all the auth events.
-        have_events = yield self.store.have_events(
-            [e_id for e_id, _ in event.auth_events]
-        )
-
+        current_state = set(e.event_id for e in auth_events.values())
         event_auth_events = set(e_id for e_id, _ in event.auth_events)
+
+        if event_auth_events - current_state:
+            have_events = yield self.store.have_events(
+                event_auth_events - current_state
+            )
+        else:
+            have_events = {}
+
+        have_events.update({
+            e.event_id: ""
+            for e in auth_events.values()
+        })
+
         seen_events = set(have_events.keys())
 
-        missing_auth = event_auth_events - seen_events
+        missing_auth = event_auth_events - seen_events - current_state
 
         if missing_auth:
             logger.info("Missing auth: %s", missing_auth)