summary refs log tree commit diff
path: root/synapse/handlers/federation.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/federation.py')
-rw-r--r--synapse/handlers/federation.py288
1 files changed, 177 insertions, 111 deletions
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 49068c06d9..3dd107a285 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -30,7 +30,12 @@ from unpaddedbase64 import decode_base64
 
 from twisted.internet import defer
 
-from synapse.api.constants import EventTypes, Membership, RejectedReason
+from synapse.api.constants import (
+    KNOWN_ROOM_VERSIONS,
+    EventTypes,
+    Membership,
+    RejectedReason,
+)
 from synapse.api.errors import (
     AuthError,
     CodeMessageException,
@@ -44,10 +49,15 @@ from synapse.crypto.event_signing import (
     compute_event_signature,
 )
 from synapse.events.validator import EventValidator
+from synapse.replication.http.federation import (
+    ReplicationCleanRoomRestServlet,
+    ReplicationFederationSendEventsRestServlet,
+)
+from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet
 from synapse.state import resolve_events_with_factory
 from synapse.types import UserID, get_domain_from_id
 from synapse.util import logcontext, unwrapFirstError
-from synapse.util.async import Linearizer
+from synapse.util.async_helpers import Linearizer
 from synapse.util.distributor import user_joined_room
 from synapse.util.frozenutils import unfreeze
 from synapse.util.logutils import log_function
@@ -76,7 +86,7 @@ class FederationHandler(BaseHandler):
         self.hs = hs
 
         self.store = hs.get_datastore()
-        self.replication_layer = hs.get_federation_client()
+        self.federation_client = hs.get_federation_client()
         self.state_handler = hs.get_state_handler()
         self.server_name = hs.hostname
         self.keyring = hs.get_keyring()
@@ -86,6 +96,18 @@ class FederationHandler(BaseHandler):
         self.spam_checker = hs.get_spam_checker()
         self.event_creation_handler = hs.get_event_creation_handler()
         self._server_notices_mxid = hs.config.server_notices_mxid
+        self.config = hs.config
+        self.http_client = hs.get_simple_http_client()
+
+        self._send_events_to_master = (
+            ReplicationFederationSendEventsRestServlet.make_client(hs)
+        )
+        self._notify_user_membership_change = (
+            ReplicationUserJoinedLeftRoomRestServlet.make_client(hs)
+        )
+        self._clean_room_for_join_client = (
+            ReplicationCleanRoomRestServlet.make_client(hs)
+        )
 
         # When joining a room we need to queue any events for that room up
         self.room_queues = {}
@@ -255,7 +277,7 @@ class FederationHandler(BaseHandler):
                     # know about
                     for p in prevs - seen:
                         state, got_auth_chain = (
-                            yield self.replication_layer.get_state_for_room(
+                            yield self.federation_client.get_state_for_room(
                                 origin, pdu.room_id, p
                             )
                         )
@@ -338,7 +360,7 @@ class FederationHandler(BaseHandler):
         #
         # see https://github.com/matrix-org/synapse/pull/1744
 
-        missing_events = yield self.replication_layer.get_missing_events(
+        missing_events = yield self.federation_client.get_missing_events(
             origin,
             pdu.room_id,
             earliest_events_ids=list(latest),
@@ -400,7 +422,7 @@ class FederationHandler(BaseHandler):
             )
 
             try:
-                event_stream_id, max_stream_id = yield self._persist_auth_tree(
+                yield self._persist_auth_tree(
                     origin, auth_chain, state, event
                 )
             except AuthError as e:
@@ -444,7 +466,7 @@ class FederationHandler(BaseHandler):
                 yield self._handle_new_events(origin, event_infos)
 
             try:
-                context, event_stream_id, max_stream_id = yield self._handle_new_event(
+                context = yield self._handle_new_event(
                     origin,
                     event,
                     state=state,
@@ -469,17 +491,6 @@ class FederationHandler(BaseHandler):
             except StoreError:
                 logger.exception("Failed to store room.")
 
-        extra_users = []
-        if event.type == EventTypes.Member:
-            target_user_id = event.state_key
-            target_user = UserID.from_string(target_user_id)
-            extra_users.append(target_user)
-
-        self.notifier.on_new_room_event(
-            event, event_stream_id, max_stream_id,
-            extra_users=extra_users
-        )
-
         if event.type == EventTypes.Member:
             if event.membership == Membership.JOIN:
                 # Only fire user_joined_room if the user has acutally
@@ -501,7 +512,7 @@ class FederationHandler(BaseHandler):
 
                 if newly_joined:
                     user = UserID.from_string(event.state_key)
-                    yield user_joined_room(self.distributor, user, event.room_id)
+                    yield self.user_joined_room(user, event.room_id)
 
     @log_function
     @defer.inlineCallbacks
@@ -522,7 +533,7 @@ class FederationHandler(BaseHandler):
         if dest == self.server_name:
             raise SynapseError(400, "Can't backfill from self.")
 
-        events = yield self.replication_layer.backfill(
+        events = yield self.federation_client.backfill(
             dest,
             room_id,
             limit=limit,
@@ -570,7 +581,7 @@ class FederationHandler(BaseHandler):
         state_events = {}
         events_to_state = {}
         for e_id in edges:
-            state, auth = yield self.replication_layer.get_state_for_room(
+            state, auth = yield self.federation_client.get_state_for_room(
                 destination=dest,
                 room_id=room_id,
                 event_id=e_id
@@ -612,7 +623,7 @@ class FederationHandler(BaseHandler):
                 results = yield logcontext.make_deferred_yieldable(defer.gatherResults(
                     [
                         logcontext.run_in_background(
-                            self.replication_layer.get_pdu,
+                            self.federation_client.get_pdu,
                             [dest],
                             event_id,
                             outlier=True,
@@ -893,7 +904,7 @@ class FederationHandler(BaseHandler):
 
         Invites must be signed by the invitee's server before distribution.
         """
-        pdu = yield self.replication_layer.send_invite(
+        pdu = yield self.federation_client.send_invite(
             destination=target_host,
             room_id=event.room_id,
             event_id=event.event_id,
@@ -933,6 +944,9 @@ class FederationHandler(BaseHandler):
             joinee,
             "join",
             content,
+            params={
+                "ver": KNOWN_ROOM_VERSIONS,
+            },
         )
 
         # This shouldn't happen, because the RoomMemberHandler has a
@@ -942,7 +956,7 @@ class FederationHandler(BaseHandler):
 
         self.room_queues[room_id] = []
 
-        yield self.store.clean_room_for_join(room_id)
+        yield self._clean_room_for_join(room_id)
 
         handled_events = set()
 
@@ -955,7 +969,7 @@ class FederationHandler(BaseHandler):
                 target_hosts.insert(0, origin)
             except ValueError:
                 pass
-            ret = yield self.replication_layer.send_join(target_hosts, event)
+            ret = yield self.federation_client.send_join(target_hosts, event)
 
             origin = ret["origin"]
             state = ret["state"]
@@ -981,15 +995,10 @@ class FederationHandler(BaseHandler):
                 # FIXME
                 pass
 
-            event_stream_id, max_stream_id = yield self._persist_auth_tree(
+            yield self._persist_auth_tree(
                 origin, auth_chain, state, event
             )
 
-            self.notifier.on_new_room_event(
-                event, event_stream_id, max_stream_id,
-                extra_users=[joinee]
-            )
-
             logger.debug("Finished joining %s to %s", joinee, room_id)
         finally:
             room_queue = self.room_queues[room_id]
@@ -1084,7 +1093,7 @@ class FederationHandler(BaseHandler):
         # would introduce the danger of backwards-compatibility problems.
         event.internal_metadata.send_on_behalf_of = origin
 
-        context, event_stream_id, max_stream_id = yield self._handle_new_event(
+        context = yield self._handle_new_event(
             origin, event
         )
 
@@ -1094,20 +1103,10 @@ class FederationHandler(BaseHandler):
             event.signatures,
         )
 
-        extra_users = []
-        if event.type == EventTypes.Member:
-            target_user_id = event.state_key
-            target_user = UserID.from_string(target_user_id)
-            extra_users.append(target_user)
-
-        self.notifier.on_new_room_event(
-            event, event_stream_id, max_stream_id, extra_users=extra_users
-        )
-
         if event.type == EventTypes.Member:
             if event.content["membership"] == Membership.JOIN:
                 user = UserID.from_string(event.state_key)
-                yield user_joined_room(self.distributor, user, event.room_id)
+                yield self.user_joined_room(user, event.room_id)
 
         prev_state_ids = yield context.get_prev_state_ids(self.store)
 
@@ -1176,17 +1175,7 @@ class FederationHandler(BaseHandler):
         )
 
         context = yield self.state_handler.compute_event_context(event)
-
-        event_stream_id, max_stream_id = yield self.store.persist_event(
-            event,
-            context=context,
-        )
-
-        target_user = UserID.from_string(event.state_key)
-        self.notifier.on_new_room_event(
-            event, event_stream_id, max_stream_id,
-            extra_users=[target_user],
-        )
+        yield self.persist_events_and_notify([(event, context)])
 
         defer.returnValue(event)
 
@@ -1211,35 +1200,26 @@ class FederationHandler(BaseHandler):
         except ValueError:
             pass
 
-        yield self.replication_layer.send_leave(
+        yield self.federation_client.send_leave(
             target_hosts,
             event
         )
 
         context = yield self.state_handler.compute_event_context(event)
-
-        event_stream_id, max_stream_id = yield self.store.persist_event(
-            event,
-            context=context,
-        )
-
-        target_user = UserID.from_string(event.state_key)
-        self.notifier.on_new_room_event(
-            event, event_stream_id, max_stream_id,
-            extra_users=[target_user],
-        )
+        yield self.persist_events_and_notify([(event, context)])
 
         defer.returnValue(event)
 
     @defer.inlineCallbacks
     def _make_and_verify_event(self, target_hosts, room_id, user_id, membership,
-                               content={},):
-        origin, pdu = yield self.replication_layer.make_membership_event(
+                               content={}, params=None):
+        origin, pdu = yield self.federation_client.make_membership_event(
             target_hosts,
             room_id,
             user_id,
             membership,
             content,
+            params=params,
         )
 
         logger.debug("Got response to make_%s: %s", membership, pdu)
@@ -1318,7 +1298,7 @@ class FederationHandler(BaseHandler):
 
         event.internal_metadata.outlier = False
 
-        context, event_stream_id, max_stream_id = yield self._handle_new_event(
+        yield self._handle_new_event(
             origin, event
         )
 
@@ -1328,22 +1308,17 @@ class FederationHandler(BaseHandler):
             event.signatures,
         )
 
-        extra_users = []
-        if event.type == EventTypes.Member:
-            target_user_id = event.state_key
-            target_user = UserID.from_string(target_user_id)
-            extra_users.append(target_user)
-
-        self.notifier.on_new_room_event(
-            event, event_stream_id, max_stream_id, extra_users=extra_users
-        )
-
         defer.returnValue(None)
 
     @defer.inlineCallbacks
     def get_state_for_pdu(self, room_id, event_id):
         """Returns the state at the event. i.e. not including said event.
         """
+
+        event = yield self.store.get_event(
+            event_id, allow_none=False, check_room_id=room_id,
+        )
+
         state_groups = yield self.store.get_state_groups(
             room_id, [event_id]
         )
@@ -1354,8 +1329,7 @@ class FederationHandler(BaseHandler):
                 (e.type, e.state_key): e for e in state
             }
 
-            event = yield self.store.get_event(event_id)
-            if event and event.is_state():
+            if event.is_state():
                 # Get previous state
                 if "replaces_state" in event.unsigned:
                     prev_id = event.unsigned["replaces_state"]
@@ -1374,6 +1348,10 @@ class FederationHandler(BaseHandler):
     def get_state_ids_for_pdu(self, room_id, event_id):
         """Returns the state at the event. i.e. not including said event.
         """
+        event = yield self.store.get_event(
+            event_id, allow_none=False, check_room_id=room_id,
+        )
+
         state_groups = yield self.store.get_state_groups_ids(
             room_id, [event_id]
         )
@@ -1382,8 +1360,7 @@ class FederationHandler(BaseHandler):
             _, state = state_groups.items().pop()
             results = state
 
-            event = yield self.store.get_event(event_id)
-            if event and event.is_state():
+            if event.is_state():
                 # Get previous state
                 if "replaces_state" in event.unsigned:
                     prev_id = event.unsigned["replaces_state"]
@@ -1472,9 +1449,8 @@ class FederationHandler(BaseHandler):
                     event, context
                 )
 
-            event_stream_id, max_stream_id = yield self.store.persist_event(
-                event,
-                context=context,
+            yield self.persist_events_and_notify(
+                [(event, context)],
                 backfilled=backfilled,
             )
         except:  # noqa: E722, as we reraise the exception this is fine.
@@ -1487,15 +1463,7 @@ class FederationHandler(BaseHandler):
 
             six.reraise(tp, value, tb)
 
-        if not backfilled:
-            # this intentionally does not yield: we don't care about the result
-            # and don't need to wait for it.
-            logcontext.run_in_background(
-                self.pusher_pool.on_new_notifications,
-                event_stream_id, max_stream_id,
-            )
-
-        defer.returnValue((context, event_stream_id, max_stream_id))
+        defer.returnValue(context)
 
     @defer.inlineCallbacks
     def _handle_new_events(self, origin, event_infos, backfilled=False):
@@ -1503,6 +1471,8 @@ class FederationHandler(BaseHandler):
         should not depend on one another, e.g. this should be used to persist
         a bunch of outliers, but not a chunk of individual events that depend
         on each other for state calculations.
+
+        Notifies about the events where appropriate.
         """
         contexts = yield logcontext.make_deferred_yieldable(defer.gatherResults(
             [
@@ -1517,7 +1487,7 @@ class FederationHandler(BaseHandler):
             ], consumeErrors=True,
         ))
 
-        yield self.store.persist_events(
+        yield self.persist_events_and_notify(
             [
                 (ev_info["event"], context)
                 for ev_info, context in zip(event_infos, contexts)
@@ -1529,7 +1499,8 @@ class FederationHandler(BaseHandler):
     def _persist_auth_tree(self, origin, auth_events, state, event):
         """Checks the auth chain is valid (and passes auth checks) for the
         state and event. Then persists the auth chain and state atomically.
-        Persists the event seperately.
+        Persists the event separately. Notifies about the persisted events
+        where appropriate.
 
         Will attempt to fetch missing auth events.
 
@@ -1540,8 +1511,7 @@ class FederationHandler(BaseHandler):
             event (Event)
 
         Returns:
-            2-tuple of (event_stream_id, max_stream_id) from the persist_event
-            call for `event`
+            Deferred
         """
         events_to_context = {}
         for e in itertools.chain(auth_events, state):
@@ -1567,7 +1537,7 @@ class FederationHandler(BaseHandler):
                     missing_auth_events.add(e_id)
 
         for e_id in missing_auth_events:
-            m_ev = yield self.replication_layer.get_pdu(
+            m_ev = yield self.federation_client.get_pdu(
                 [origin],
                 e_id,
                 outlier=True,
@@ -1605,7 +1575,7 @@ class FederationHandler(BaseHandler):
                     raise
                 events_to_context[e.event_id].rejected = RejectedReason.AUTH_ERROR
 
-        yield self.store.persist_events(
+        yield self.persist_events_and_notify(
             [
                 (e, events_to_context[e.event_id])
                 for e in itertools.chain(auth_events, state)
@@ -1616,12 +1586,10 @@ class FederationHandler(BaseHandler):
             event, old_state=state
         )
 
-        event_stream_id, max_stream_id = yield self.store.persist_event(
-            event, new_event_context,
+        yield self.persist_events_and_notify(
+            [(event, new_event_context)],
         )
 
-        defer.returnValue((event_stream_id, max_stream_id))
-
     @defer.inlineCallbacks
     def _prep_event(self, origin, event, state=None, auth_events=None):
         """
@@ -1678,8 +1646,19 @@ class FederationHandler(BaseHandler):
         defer.returnValue(context)
 
     @defer.inlineCallbacks
-    def on_query_auth(self, origin, event_id, remote_auth_chain, rejects,
+    def on_query_auth(self, origin, event_id, room_id, remote_auth_chain, rejects,
                       missing):
+        in_room = yield self.auth.check_host_in_room(
+            room_id,
+            origin
+        )
+        if not in_room:
+            raise AuthError(403, "Host not in room.")
+
+        event = yield self.store.get_event(
+            event_id, allow_none=False, check_room_id=room_id
+        )
+
         # Just go through and process each event in `remote_auth_chain`. We
         # don't want to fall into the trap of `missing` being wrong.
         for e in remote_auth_chain:
@@ -1689,7 +1668,6 @@ class FederationHandler(BaseHandler):
                 pass
 
         # Now get the current auth_chain for the event.
-        event = yield self.store.get_event(event_id)
         local_auth_chain = yield self.store.get_auth_chain(
             [auth_id for auth_id, _ in event.auth_events],
             include_given=True
@@ -1777,7 +1755,7 @@ class FederationHandler(BaseHandler):
             logger.info("Missing auth: %s", missing_auth)
             # If we don't have all the auth events, we need to get them.
             try:
-                remote_auth_chain = yield self.replication_layer.get_event_auth(
+                remote_auth_chain = yield self.federation_client.get_event_auth(
                     origin, event.room_id, event.event_id
                 )
 
@@ -1893,7 +1871,7 @@ class FederationHandler(BaseHandler):
 
                 try:
                     # 2. Get remote difference.
-                    result = yield self.replication_layer.query_auth(
+                    result = yield self.federation_client.query_auth(
                         origin,
                         event.room_id,
                         event.event_id,
@@ -2192,7 +2170,7 @@ class FederationHandler(BaseHandler):
             yield member_handler.send_membership_event(None, event, context)
         else:
             destinations = set(x.split(":", 1)[-1] for x in (sender_user_id, room_id))
-            yield self.replication_layer.forward_third_party_invite(
+            yield self.federation_client.forward_third_party_invite(
                 destinations,
                 room_id,
                 event_dict,
@@ -2336,7 +2314,7 @@ class FederationHandler(BaseHandler):
                 for revocation.
         """
         try:
-            response = yield self.hs.get_simple_http_client().get_json(
+            response = yield self.http_client.get_json(
                 url,
                 {"public_key": public_key}
             )
@@ -2347,3 +2325,91 @@ class FederationHandler(BaseHandler):
             )
         if "valid" not in response or not response["valid"]:
             raise AuthError(403, "Third party certificate was invalid")
+
+    @defer.inlineCallbacks
+    def persist_events_and_notify(self, event_and_contexts, backfilled=False):
+        """Persists events and tells the notifier/pushers about them, if
+        necessary.
+
+        Args:
+            event_and_contexts(list[tuple[FrozenEvent, EventContext]])
+            backfilled (bool): Whether these events are a result of
+                backfilling or not
+
+        Returns:
+            Deferred
+        """
+        if self.config.worker_app:
+            yield self._send_events_to_master(
+                store=self.store,
+                event_and_contexts=event_and_contexts,
+                backfilled=backfilled
+            )
+        else:
+            max_stream_id = yield self.store.persist_events(
+                event_and_contexts,
+                backfilled=backfilled,
+            )
+
+            if not backfilled:  # Never notify for backfilled events
+                for event, _ in event_and_contexts:
+                    self._notify_persisted_event(event, max_stream_id)
+
+    def _notify_persisted_event(self, event, max_stream_id):
+        """Checks to see if notifier/pushers should be notified about the
+        event or not.
+
+        Args:
+            event (FrozenEvent)
+            max_stream_id (int): The max_stream_id returned by persist_events
+        """
+
+        extra_users = []
+        if event.type == EventTypes.Member:
+            target_user_id = event.state_key
+
+            # We notify for memberships if its an invite for one of our
+            # users
+            if event.internal_metadata.is_outlier():
+                if event.membership != Membership.INVITE:
+                    if not self.is_mine_id(target_user_id):
+                        return
+
+            target_user = UserID.from_string(target_user_id)
+            extra_users.append(target_user)
+        elif event.internal_metadata.is_outlier():
+            return
+
+        event_stream_id = event.internal_metadata.stream_ordering
+        self.notifier.on_new_room_event(
+            event, event_stream_id, max_stream_id,
+            extra_users=extra_users
+        )
+
+        self.pusher_pool.on_new_notifications(
+            event_stream_id, max_stream_id,
+        )
+
+    def _clean_room_for_join(self, room_id):
+        """Called to clean up any data in DB for a given room, ready for the
+        server to join the room.
+
+        Args:
+            room_id (str)
+        """
+        if self.config.worker_app:
+            return self._clean_room_for_join_client(room_id)
+        else:
+            return self.store.clean_room_for_join(room_id)
+
+    def user_joined_room(self, user, room_id):
+        """Called when a new user has joined the room
+        """
+        if self.config.worker_app:
+            return self._notify_user_membership_change(
+                room_id=room_id,
+                user_id=user.to_string(),
+                change="joined",
+            )
+        else:
+            return user_joined_room(self.distributor, user, room_id)