summary refs log tree commit diff
path: root/synapse/handlers/federation.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/federation.py')
-rw-r--r--synapse/handlers/federation.py324
1 files changed, 152 insertions, 172 deletions
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 60391d07c4..533b82c783 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -21,8 +21,8 @@ import logging
 import sys
 
 import six
-from six import iteritems
-from six.moves import http_client
+from six import iteritems, itervalues
+from six.moves import http_client, zip
 
 from signedjson.key import decode_verify_key_bytes
 from signedjson.sign import verify_signed_json
@@ -76,7 +76,7 @@ class FederationHandler(BaseHandler):
         self.hs = hs
 
         self.store = hs.get_datastore()
-        self.replication_layer = hs.get_federation_client()
+        self.federation_client = hs.get_federation_client()
         self.state_handler = hs.get_state_handler()
         self.server_name = hs.hostname
         self.keyring = hs.get_keyring()
@@ -255,7 +255,7 @@ class FederationHandler(BaseHandler):
                     # know about
                     for p in prevs - seen:
                         state, got_auth_chain = (
-                            yield self.replication_layer.get_state_for_room(
+                            yield self.federation_client.get_state_for_room(
                                 origin, pdu.room_id, p
                             )
                         )
@@ -338,7 +338,7 @@ class FederationHandler(BaseHandler):
         #
         # see https://github.com/matrix-org/synapse/pull/1744
 
-        missing_events = yield self.replication_layer.get_missing_events(
+        missing_events = yield self.federation_client.get_missing_events(
             origin,
             pdu.room_id,
             earliest_events_ids=list(latest),
@@ -400,7 +400,7 @@ class FederationHandler(BaseHandler):
             )
 
             try:
-                event_stream_id, max_stream_id = yield self._persist_auth_tree(
+                yield self._persist_auth_tree(
                     origin, auth_chain, state, event
                 )
             except AuthError as e:
@@ -444,7 +444,7 @@ class FederationHandler(BaseHandler):
                 yield self._handle_new_events(origin, event_infos)
 
             try:
-                context, event_stream_id, max_stream_id = yield self._handle_new_event(
+                context = yield self._handle_new_event(
                     origin,
                     event,
                     state=state,
@@ -469,24 +469,16 @@ class FederationHandler(BaseHandler):
             except StoreError:
                 logger.exception("Failed to store room.")
 
-        extra_users = []
-        if event.type == EventTypes.Member:
-            target_user_id = event.state_key
-            target_user = UserID.from_string(target_user_id)
-            extra_users.append(target_user)
-
-        self.notifier.on_new_room_event(
-            event, event_stream_id, max_stream_id,
-            extra_users=extra_users
-        )
-
         if event.type == EventTypes.Member:
             if event.membership == Membership.JOIN:
                 # Only fire user_joined_room if the user has acutally
                 # joined the room. Don't bother if the user is just
                 # changing their profile info.
                 newly_joined = True
-                prev_state_id = context.prev_state_ids.get(
+
+                prev_state_ids = yield context.get_prev_state_ids(self.store)
+
+                prev_state_id = prev_state_ids.get(
                     (event.type, event.state_key)
                 )
                 if prev_state_id:
@@ -498,7 +490,7 @@ class FederationHandler(BaseHandler):
 
                 if newly_joined:
                     user = UserID.from_string(event.state_key)
-                    yield user_joined_room(self.distributor, user, event.room_id)
+                    yield self.user_joined_room(user, event.room_id)
 
     @log_function
     @defer.inlineCallbacks
@@ -519,7 +511,7 @@ class FederationHandler(BaseHandler):
         if dest == self.server_name:
             raise SynapseError(400, "Can't backfill from self.")
 
-        events = yield self.replication_layer.backfill(
+        events = yield self.federation_client.backfill(
             dest,
             room_id,
             limit=limit,
@@ -567,7 +559,7 @@ class FederationHandler(BaseHandler):
         state_events = {}
         events_to_state = {}
         for e_id in edges:
-            state, auth = yield self.replication_layer.get_state_for_room(
+            state, auth = yield self.federation_client.get_state_for_room(
                 destination=dest,
                 room_id=room_id,
                 event_id=e_id
@@ -609,7 +601,7 @@ class FederationHandler(BaseHandler):
                 results = yield logcontext.make_deferred_yieldable(defer.gatherResults(
                     [
                         logcontext.run_in_background(
-                            self.replication_layer.get_pdu,
+                            self.federation_client.get_pdu,
                             [dest],
                             event_id,
                             outlier=True,
@@ -731,7 +723,7 @@ class FederationHandler(BaseHandler):
             """
             joined_users = [
                 (state_key, int(event.depth))
-                for (e_type, state_key), event in state.iteritems()
+                for (e_type, state_key), event in iteritems(state)
                 if e_type == EventTypes.Member
                 and event.membership == Membership.JOIN
             ]
@@ -748,7 +740,7 @@ class FederationHandler(BaseHandler):
                 except Exception:
                     pass
 
-            return sorted(joined_domains.iteritems(), key=lambda d: d[1])
+            return sorted(joined_domains.items(), key=lambda d: d[1])
 
         curr_domains = get_domains_from_state(curr_state)
 
@@ -811,7 +803,7 @@ class FederationHandler(BaseHandler):
         tried_domains = set(likely_domains)
         tried_domains.add(self.server_name)
 
-        event_ids = list(extremities.iterkeys())
+        event_ids = list(extremities.keys())
 
         logger.debug("calling resolve_state_groups in _maybe_backfill")
         resolve = logcontext.preserve_fn(
@@ -827,15 +819,15 @@ class FederationHandler(BaseHandler):
         states = dict(zip(event_ids, [s.state for s in states]))
 
         state_map = yield self.store.get_events(
-            [e_id for ids in states.itervalues() for e_id in ids.itervalues()],
+            [e_id for ids in itervalues(states) for e_id in itervalues(ids)],
             get_prev_content=False
         )
         states = {
             key: {
                 k: state_map[e_id]
-                for k, e_id in state_dict.iteritems()
+                for k, e_id in iteritems(state_dict)
                 if e_id in state_map
-            } for key, state_dict in states.iteritems()
+            } for key, state_dict in iteritems(states)
         }
 
         for e_id, _ in sorted_extremeties_tuple:
@@ -890,7 +882,7 @@ class FederationHandler(BaseHandler):
 
         Invites must be signed by the invitee's server before distribution.
         """
-        pdu = yield self.replication_layer.send_invite(
+        pdu = yield self.federation_client.send_invite(
             destination=target_host,
             room_id=event.room_id,
             event_id=event.event_id,
@@ -906,16 +898,6 @@ class FederationHandler(BaseHandler):
             [auth_id for auth_id, _ in event.auth_events],
             include_given=True
         )
-
-        for event in auth:
-            event.signatures.update(
-                compute_event_signature(
-                    event,
-                    self.hs.hostname,
-                    self.hs.config.signing_key[0]
-                )
-            )
-
         defer.returnValue([e for e in auth])
 
     @log_function
@@ -949,7 +931,7 @@ class FederationHandler(BaseHandler):
 
         self.room_queues[room_id] = []
 
-        yield self.store.clean_room_for_join(room_id)
+        yield self._clean_room_for_join(room_id)
 
         handled_events = set()
 
@@ -962,7 +944,7 @@ class FederationHandler(BaseHandler):
                 target_hosts.insert(0, origin)
             except ValueError:
                 pass
-            ret = yield self.replication_layer.send_join(target_hosts, event)
+            ret = yield self.federation_client.send_join(target_hosts, event)
 
             origin = ret["origin"]
             state = ret["state"]
@@ -988,15 +970,10 @@ class FederationHandler(BaseHandler):
                 # FIXME
                 pass
 
-            event_stream_id, max_stream_id = yield self._persist_auth_tree(
+            yield self._persist_auth_tree(
                 origin, auth_chain, state, event
             )
 
-            self.notifier.on_new_room_event(
-                event, event_stream_id, max_stream_id,
-                extra_users=[joinee]
-            )
-
             logger.debug("Finished joining %s to %s", joinee, room_id)
         finally:
             room_queue = self.room_queues[room_id]
@@ -1091,7 +1068,7 @@ class FederationHandler(BaseHandler):
         # would introduce the danger of backwards-compatibility problems.
         event.internal_metadata.send_on_behalf_of = origin
 
-        context, event_stream_id, max_stream_id = yield self._handle_new_event(
+        context = yield self._handle_new_event(
             origin, event
         )
 
@@ -1101,25 +1078,17 @@ class FederationHandler(BaseHandler):
             event.signatures,
         )
 
-        extra_users = []
-        if event.type == EventTypes.Member:
-            target_user_id = event.state_key
-            target_user = UserID.from_string(target_user_id)
-            extra_users.append(target_user)
-
-        self.notifier.on_new_room_event(
-            event, event_stream_id, max_stream_id, extra_users=extra_users
-        )
-
         if event.type == EventTypes.Member:
             if event.content["membership"] == Membership.JOIN:
                 user = UserID.from_string(event.state_key)
-                yield user_joined_room(self.distributor, user, event.room_id)
+                yield self.user_joined_room(user, event.room_id)
+
+        prev_state_ids = yield context.get_prev_state_ids(self.store)
 
-        state_ids = list(context.prev_state_ids.values())
+        state_ids = list(prev_state_ids.values())
         auth_chain = yield self.store.get_auth_chain(state_ids)
 
-        state = yield self.store.get_events(list(context.prev_state_ids.values()))
+        state = yield self.store.get_events(list(prev_state_ids.values()))
 
         defer.returnValue({
             "state": list(state.values()),
@@ -1181,17 +1150,7 @@ class FederationHandler(BaseHandler):
         )
 
         context = yield self.state_handler.compute_event_context(event)
-
-        event_stream_id, max_stream_id = yield self.store.persist_event(
-            event,
-            context=context,
-        )
-
-        target_user = UserID.from_string(event.state_key)
-        self.notifier.on_new_room_event(
-            event, event_stream_id, max_stream_id,
-            extra_users=[target_user],
-        )
+        yield self._persist_events([(event, context)])
 
         defer.returnValue(event)
 
@@ -1216,30 +1175,20 @@ class FederationHandler(BaseHandler):
         except ValueError:
             pass
 
-        yield self.replication_layer.send_leave(
+        yield self.federation_client.send_leave(
             target_hosts,
             event
         )
 
         context = yield self.state_handler.compute_event_context(event)
-
-        event_stream_id, max_stream_id = yield self.store.persist_event(
-            event,
-            context=context,
-        )
-
-        target_user = UserID.from_string(event.state_key)
-        self.notifier.on_new_room_event(
-            event, event_stream_id, max_stream_id,
-            extra_users=[target_user],
-        )
+        yield self._persist_events([(event, context)])
 
         defer.returnValue(event)
 
     @defer.inlineCallbacks
     def _make_and_verify_event(self, target_hosts, room_id, user_id, membership,
                                content={},):
-        origin, pdu = yield self.replication_layer.make_membership_event(
+        origin, pdu = yield self.federation_client.make_membership_event(
             target_hosts,
             room_id,
             user_id,
@@ -1284,7 +1233,7 @@ class FederationHandler(BaseHandler):
     @log_function
     def on_make_leave_request(self, room_id, user_id):
         """ We've received a /make_leave/ request, so we create a partial
-        join event for the room and return that. We do *not* persist or
+        leave event for the room and return that. We do *not* persist or
         process it until the other server has signed it and sent it back.
         """
         builder = self.event_builder_factory.new({
@@ -1323,7 +1272,7 @@ class FederationHandler(BaseHandler):
 
         event.internal_metadata.outlier = False
 
-        context, event_stream_id, max_stream_id = yield self._handle_new_event(
+        yield self._handle_new_event(
             origin, event
         )
 
@@ -1333,16 +1282,6 @@ class FederationHandler(BaseHandler):
             event.signatures,
         )
 
-        extra_users = []
-        if event.type == EventTypes.Member:
-            target_user_id = event.state_key
-            target_user = UserID.from_string(target_user_id)
-            extra_users.append(target_user)
-
-        self.notifier.on_new_room_event(
-            event, event_stream_id, max_stream_id, extra_users=extra_users
-        )
-
         defer.returnValue(None)
 
     @defer.inlineCallbacks
@@ -1375,18 +1314,6 @@ class FederationHandler(BaseHandler):
                     del results[(event.type, event.state_key)]
 
             res = list(results.values())
-            for event in res:
-                # We sign these again because there was a bug where we
-                # incorrectly signed things the first time round
-                if self.is_mine_id(event.event_id):
-                    event.signatures.update(
-                        compute_event_signature(
-                            event,
-                            self.hs.hostname,
-                            self.hs.config.signing_key[0]
-                        )
-                    )
-
             defer.returnValue(res)
         else:
             defer.returnValue([])
@@ -1461,18 +1388,6 @@ class FederationHandler(BaseHandler):
         )
 
         if event:
-            if self.is_mine_id(event.event_id):
-                # FIXME: This is a temporary work around where we occasionally
-                # return events slightly differently than when they were
-                # originally signed
-                event.signatures.update(
-                    compute_event_signature(
-                        event,
-                        self.hs.hostname,
-                        self.hs.config.signing_key[0]
-                    )
-                )
-
             in_room = yield self.auth.check_host_in_room(
                 event.room_id,
                 origin
@@ -1508,9 +1423,8 @@ class FederationHandler(BaseHandler):
                     event, context
                 )
 
-            event_stream_id, max_stream_id = yield self.store.persist_event(
-                event,
-                context=context,
+            yield self._persist_events(
+                [(event, context)],
                 backfilled=backfilled,
             )
         except:  # noqa: E722, as we reraise the exception this is fine.
@@ -1523,15 +1437,7 @@ class FederationHandler(BaseHandler):
 
             six.reraise(tp, value, tb)
 
-        if not backfilled:
-            # this intentionally does not yield: we don't care about the result
-            # and don't need to wait for it.
-            logcontext.run_in_background(
-                self.pusher_pool.on_new_notifications,
-                event_stream_id, max_stream_id,
-            )
-
-        defer.returnValue((context, event_stream_id, max_stream_id))
+        defer.returnValue(context)
 
     @defer.inlineCallbacks
     def _handle_new_events(self, origin, event_infos, backfilled=False):
@@ -1539,6 +1445,8 @@ class FederationHandler(BaseHandler):
         should not depend on one another, e.g. this should be used to persist
         a bunch of outliers, but not a chunk of individual events that depend
         on each other for state calculations.
+
+        Notifies about the events where appropriate.
         """
         contexts = yield logcontext.make_deferred_yieldable(defer.gatherResults(
             [
@@ -1553,10 +1461,10 @@ class FederationHandler(BaseHandler):
             ], consumeErrors=True,
         ))
 
-        yield self.store.persist_events(
+        yield self._persist_events(
             [
                 (ev_info["event"], context)
-                for ev_info, context in itertools.izip(event_infos, contexts)
+                for ev_info, context in zip(event_infos, contexts)
             ],
             backfilled=backfilled,
         )
@@ -1565,7 +1473,8 @@ class FederationHandler(BaseHandler):
     def _persist_auth_tree(self, origin, auth_events, state, event):
         """Checks the auth chain is valid (and passes auth checks) for the
         state and event. Then persists the auth chain and state atomically.
-        Persists the event seperately.
+        Persists the event separately. Notifies about the persisted events
+        where appropriate.
 
         Will attempt to fetch missing auth events.
 
@@ -1576,8 +1485,7 @@ class FederationHandler(BaseHandler):
             event (Event)
 
         Returns:
-            2-tuple of (event_stream_id, max_stream_id) from the persist_event
-            call for `event`
+            Deferred
         """
         events_to_context = {}
         for e in itertools.chain(auth_events, state):
@@ -1603,7 +1511,7 @@ class FederationHandler(BaseHandler):
                     missing_auth_events.add(e_id)
 
         for e_id in missing_auth_events:
-            m_ev = yield self.replication_layer.get_pdu(
+            m_ev = yield self.federation_client.get_pdu(
                 [origin],
                 e_id,
                 outlier=True,
@@ -1641,7 +1549,7 @@ class FederationHandler(BaseHandler):
                     raise
                 events_to_context[e.event_id].rejected = RejectedReason.AUTH_ERROR
 
-        yield self.store.persist_events(
+        yield self._persist_events(
             [
                 (e, events_to_context[e.event_id])
                 for e in itertools.chain(auth_events, state)
@@ -1652,12 +1560,10 @@ class FederationHandler(BaseHandler):
             event, old_state=state
         )
 
-        event_stream_id, max_stream_id = yield self.store.persist_event(
-            event, new_event_context,
+        yield self._persist_events(
+            [(event, new_event_context)],
         )
 
-        defer.returnValue((event_stream_id, max_stream_id))
-
     @defer.inlineCallbacks
     def _prep_event(self, origin, event, state=None, auth_events=None):
         """
@@ -1676,8 +1582,9 @@ class FederationHandler(BaseHandler):
         )
 
         if not auth_events:
+            prev_state_ids = yield context.get_prev_state_ids(self.store)
             auth_events_ids = yield self.auth.compute_auth_events(
-                event, context.prev_state_ids, for_verification=True,
+                event, prev_state_ids, for_verification=True,
             )
             auth_events = yield self.store.get_events(auth_events_ids)
             auth_events = {
@@ -1747,15 +1654,6 @@ class FederationHandler(BaseHandler):
             local_auth_chain, remote_auth_chain
         )
 
-        for event in ret["auth_chain"]:
-            event.signatures.update(
-                compute_event_signature(
-                    event,
-                    self.hs.hostname,
-                    self.hs.config.signing_key[0]
-                )
-            )
-
         logger.debug("on_query_auth returning: %s", ret)
 
         defer.returnValue(ret)
@@ -1831,7 +1729,7 @@ class FederationHandler(BaseHandler):
             logger.info("Missing auth: %s", missing_auth)
             # If we don't have all the auth events, we need to get them.
             try:
-                remote_auth_chain = yield self.replication_layer.get_event_auth(
+                remote_auth_chain = yield self.federation_client.get_event_auth(
                     origin, event.room_id, event.event_id
                 )
 
@@ -1936,9 +1834,10 @@ class FederationHandler(BaseHandler):
                         break
 
             if do_resolution:
+                prev_state_ids = yield context.get_prev_state_ids(self.store)
                 # 1. Get what we think is the auth chain.
                 auth_ids = yield self.auth.compute_auth_events(
-                    event, context.prev_state_ids
+                    event, prev_state_ids
                 )
                 local_auth_chain = yield self.store.get_auth_chain(
                     auth_ids, include_given=True
@@ -1946,7 +1845,7 @@ class FederationHandler(BaseHandler):
 
                 try:
                     # 2. Get remote difference.
-                    result = yield self.replication_layer.query_auth(
+                    result = yield self.federation_client.query_auth(
                         origin,
                         event.room_id,
                         event.event_id,
@@ -2028,21 +1927,34 @@ class FederationHandler(BaseHandler):
             k: a.event_id for k, a in iteritems(auth_events)
             if k != event_key
         }
-        context.current_state_ids = dict(context.current_state_ids)
-        context.current_state_ids.update(state_updates)
-        if context.delta_ids is not None:
-            context.delta_ids = dict(context.delta_ids)
-            context.delta_ids.update(state_updates)
-        context.prev_state_ids = dict(context.prev_state_ids)
-        context.prev_state_ids.update({
+        current_state_ids = yield context.get_current_state_ids(self.store)
+        current_state_ids = dict(current_state_ids)
+
+        current_state_ids.update(state_updates)
+
+        prev_state_ids = yield context.get_prev_state_ids(self.store)
+        prev_state_ids = dict(prev_state_ids)
+
+        prev_state_ids.update({
             k: a.event_id for k, a in iteritems(auth_events)
         })
-        context.state_group = yield self.store.store_state_group(
+
+        # create a new state group as a delta from the existing one.
+        prev_group = context.state_group
+        state_group = yield self.store.store_state_group(
             event.event_id,
             event.room_id,
-            prev_group=context.prev_group,
-            delta_ids=context.delta_ids,
-            current_state_ids=context.current_state_ids,
+            prev_group=prev_group,
+            delta_ids=state_updates,
+            current_state_ids=current_state_ids,
+        )
+
+        yield context.update_state(
+            state_group=state_group,
+            current_state_ids=current_state_ids,
+            prev_state_ids=prev_state_ids,
+            prev_group=prev_group,
+            delta_ids=state_updates,
         )
 
     @defer.inlineCallbacks
@@ -2232,7 +2144,7 @@ class FederationHandler(BaseHandler):
             yield member_handler.send_membership_event(None, event, context)
         else:
             destinations = set(x.split(":", 1)[-1] for x in (sender_user_id, room_id))
-            yield self.replication_layer.forward_third_party_invite(
+            yield self.federation_client.forward_third_party_invite(
                 destinations,
                 room_id,
                 event_dict,
@@ -2282,7 +2194,8 @@ class FederationHandler(BaseHandler):
             event.content["third_party_invite"]["signed"]["token"]
         )
         original_invite = None
-        original_invite_id = context.prev_state_ids.get(key)
+        prev_state_ids = yield context.get_prev_state_ids(self.store)
+        original_invite_id = prev_state_ids.get(key)
         if original_invite_id:
             original_invite = yield self.store.get_event(
                 original_invite_id, allow_none=True
@@ -2324,7 +2237,8 @@ class FederationHandler(BaseHandler):
         signed = event.content["third_party_invite"]["signed"]
         token = signed["token"]
 
-        invite_event_id = context.prev_state_ids.get(
+        prev_state_ids = yield context.get_prev_state_ids(self.store)
+        invite_event_id = prev_state_ids.get(
             (EventTypes.ThirdPartyInvite, token,)
         )
 
@@ -2385,3 +2299,69 @@ class FederationHandler(BaseHandler):
             )
         if "valid" not in response or not response["valid"]:
             raise AuthError(403, "Third party certificate was invalid")
+
+    @defer.inlineCallbacks
+    def _persist_events(self, event_and_contexts, backfilled=False):
+        """Persists events and tells the notifier/pushers about them, if
+        necessary.
+
+        Args:
+            event_and_contexts(list[tuple[FrozenEvent, EventContext]])
+            backfilled (bool): Whether these events are a result of
+                backfilling or not
+
+        Returns:
+            Deferred
+        """
+        max_stream_id = yield self.store.persist_events(
+            event_and_contexts,
+            backfilled=backfilled,
+        )
+
+        if not backfilled:  # Never notify for backfilled events
+            for event, _ in event_and_contexts:
+                self._notify_persisted_event(event, max_stream_id)
+
+    def _notify_persisted_event(self, event, max_stream_id):
+        """Checks to see if notifier/pushers should be notified about the
+        event or not.
+
+        Args:
+            event (FrozenEvent)
+            max_stream_id (int): The max_stream_id returned by persist_events
+        """
+
+        extra_users = []
+        if event.type == EventTypes.Member:
+            target_user_id = event.state_key
+
+            # We notify for memberships if its an invite for one of our
+            # users
+            if event.internal_metadata.is_outlier():
+                if event.membership != Membership.INVITE:
+                    if not self.is_mine_id(target_user_id):
+                        return
+
+            target_user = UserID.from_string(target_user_id)
+            extra_users.append(target_user)
+        elif event.internal_metadata.is_outlier():
+            return
+
+        event_stream_id = event.internal_metadata.stream_ordering
+        self.notifier.on_new_room_event(
+            event, event_stream_id, max_stream_id,
+            extra_users=extra_users
+        )
+
+        logcontext.run_in_background(
+            self.pusher_pool.on_new_notifications,
+            event_stream_id, max_stream_id,
+        )
+
+    def _clean_room_for_join(self, room_id):
+        return self.store.clean_room_for_join(room_id)
+
+    def user_joined_room(self, user, room_id):
+        """Called when a new user has joined the room
+        """
+        return user_joined_room(self.distributor, user, room_id)