summary refs log tree commit diff
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2018-03-13 13:17:08 +0000
committerErik Johnston <erik@matrix.org>2018-03-13 13:17:08 +0000
commitbf8e97bd3c11a4bdae9dcf255922b1ac2d03f5b3 (patch)
tree42423f71093a796415971f66ec4a24712bab42c2
parentAdd docstring (diff)
parentMerge pull request #2980 from matrix-org/erikj/rm_priv (diff)
downloadsynapse-bf8e97bd3c11a4bdae9dcf255922b1ac2d03f5b3.tar.xz
Merge branch 'develop' of github.com:matrix-org/synapse into erikj/factor_remote_leave
-rw-r--r--synapse/federation/federation_base.py6
-rw-r--r--synapse/federation/federation_client.py1
-rw-r--r--synapse/federation/federation_server.py123
-rw-r--r--synapse/federation/replication.py22
-rw-r--r--synapse/handlers/device.py6
-rw-r--r--synapse/handlers/devicemessage.py2
-rw-r--r--synapse/handlers/directory.py2
-rw-r--r--synapse/handlers/e2e_keys.py2
-rw-r--r--synapse/handlers/message.py2
-rw-r--r--synapse/handlers/presence.py10
-rw-r--r--synapse/handlers/profile.py2
-rw-r--r--synapse/handlers/receipts.py2
-rw-r--r--synapse/handlers/register.py26
-rw-r--r--synapse/handlers/room_member.py24
-rw-r--r--synapse/handlers/typing.py2
-rw-r--r--synapse/replication/http/send_event.py8
-rw-r--r--synapse/rest/client/v1/room.py2
-rw-r--r--synapse/server.py5
-rw-r--r--synapse/storage/events.py9
-rw-r--r--tests/handlers/test_directory.py9
-rw-r--r--tests/handlers/test_profile.py9
21 files changed, 154 insertions, 120 deletions
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index 7918d3e442..79eaa31031 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -27,7 +27,13 @@ logger = logging.getLogger(__name__)
 
 class FederationBase(object):
     def __init__(self, hs):
+        self.hs = hs
+
+        self.server_name = hs.hostname
+        self.keyring = hs.get_keyring()
         self.spam_checker = hs.get_spam_checker()
+        self.store = hs.get_datastore()
+        self._clock = hs.get_clock()
 
     @defer.inlineCallbacks
     def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False,
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 813907f7f2..38440da5b5 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -58,6 +58,7 @@ class FederationClient(FederationBase):
             self._clear_tried_cache, 60 * 1000,
         )
         self.state = hs.get_state_handler()
+        self.transport_layer = hs.get_federation_transport_client()
 
     def _clear_tried_cache(self):
         """Clear pdu_destination_tried cache"""
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 9849953c9b..f6fd2e86e9 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -17,12 +17,14 @@ import logging
 import simplejson as json
 from twisted.internet import defer
 
-from synapse.api.errors import AuthError, FederationError, SynapseError
+from synapse.api.errors import AuthError, FederationError, SynapseError, NotFoundError
 from synapse.crypto.event_signing import compute_event_signature
 from synapse.federation.federation_base import (
     FederationBase,
     event_from_pdu_json,
 )
+
+from synapse.federation.persistence import TransactionActions
 from synapse.federation.units import Edu, Transaction
 import synapse.metrics
 from synapse.types import get_domain_from_id
@@ -56,6 +58,12 @@ class FederationServer(FederationBase):
         self._server_linearizer = async.Linearizer("fed_server")
         self._transaction_linearizer = async.Linearizer("fed_txn_handler")
 
+        self.transaction_actions = TransactionActions(self.store)
+
+        self.handler = None
+
+        self.registry = hs.get_federation_registry()
+
         # We cache responses to state queries, as they take a while and often
         # come in waves.
         self._state_resp_cache = ResponseCache(hs, timeout_ms=30000)
@@ -67,35 +75,6 @@ class FederationServer(FederationBase):
         """
         self.handler = handler
 
-    def register_edu_handler(self, edu_type, handler):
-        if edu_type in self.edu_handlers:
-            raise KeyError("Already have an EDU handler for %s" % (edu_type,))
-
-        self.edu_handlers[edu_type] = handler
-
-    def register_query_handler(self, query_type, handler):
-        """Sets the handler callable that will be used to handle an incoming
-        federation Query of the given type.
-
-        Args:
-            query_type (str): Category name of the query, which should match
-                the string used by make_query.
-            handler (callable): Invoked to handle incoming queries of this type
-
-        handler is invoked as:
-            result = handler(args)
-
-        where 'args' is a dict mapping strings to strings of the query
-          arguments. It should return a Deferred that will eventually yield an
-          object to encode as JSON.
-        """
-        if query_type in self.query_handlers:
-            raise KeyError(
-                "Already have a Query handler for %s" % (query_type,)
-            )
-
-        self.query_handlers[query_type] = handler
-
     @defer.inlineCallbacks
     @log_function
     def on_backfill_request(self, origin, room_id, versions, limit):
@@ -229,16 +208,7 @@ class FederationServer(FederationBase):
     @defer.inlineCallbacks
     def received_edu(self, origin, edu_type, content):
         received_edus_counter.inc()
-
-        if edu_type in self.edu_handlers:
-            try:
-                yield self.edu_handlers[edu_type](origin, content)
-            except SynapseError as e:
-                logger.info("Failed to handle edu %r: %r", edu_type, e)
-            except Exception as e:
-                logger.exception("Failed to handle edu %r", edu_type)
-        else:
-            logger.warn("Received EDU of type %s with no handler", edu_type)
+        yield self.registry.on_edu(edu_type, origin, content)
 
     @defer.inlineCallbacks
     @log_function
@@ -328,14 +298,8 @@ class FederationServer(FederationBase):
     @defer.inlineCallbacks
     def on_query_request(self, query_type, args):
         received_queries_counter.inc(query_type)
-
-        if query_type in self.query_handlers:
-            response = yield self.query_handlers[query_type](args)
-            defer.returnValue((200, response))
-        else:
-            defer.returnValue(
-                (404, "No handler for Query type '%s'" % (query_type,))
-            )
+        resp = yield self.registry.on_query(query_type, args)
+        defer.returnValue((200, resp))
 
     @defer.inlineCallbacks
     def on_make_join_request(self, room_id, user_id):
@@ -607,3 +571,66 @@ class FederationServer(FederationBase):
             origin, room_id, event_dict
         )
         defer.returnValue(ret)
+
+
+class FederationHandlerRegistry(object):
+    """Allows classes to register themselves as handlers for a given EDU or
+    query type for incoming federation traffic.
+    """
+    def __init__(self):
+        self.edu_handlers = {}
+        self.query_handlers = {}
+
+    def register_edu_handler(self, edu_type, handler):
+        """Sets the handler callable that will be used to handle an incoming
+        federation EDU of the given type.
+
+        Args:
+            edu_type (str): The type of the incoming EDU to register handler for
+            handler (Callable[[str, dict]]): A callable invoked on incoming EDU
+                of the given type. The arguments are the origin server name and
+                the EDU contents.
+        """
+        if edu_type in self.edu_handlers:
+            raise KeyError("Already have an EDU handler for %s" % (edu_type,))
+
+        self.edu_handlers[edu_type] = handler
+
+    def register_query_handler(self, query_type, handler):
+        """Sets the handler callable that will be used to handle an incoming
+        federation query of the given type.
+
+        Args:
+            query_type (str): Category name of the query, which should match
+                the string used by make_query.
+            handler (Callable[[dict], Deferred[dict]]): Invoked to handle
+                incoming queries of this type. The return will be yielded
+                on and the result used as the response to the query request.
+        """
+        if query_type in self.query_handlers:
+            raise KeyError(
+                "Already have a Query handler for %s" % (query_type,)
+            )
+
+        self.query_handlers[query_type] = handler
+
+    @defer.inlineCallbacks
+    def on_edu(self, edu_type, origin, content):
+        handler = self.edu_handlers.get(edu_type)
+        if not handler:
+            logger.warn("No handler registered for EDU type %s", edu_type)
+
+        try:
+            yield handler(origin, content)
+        except SynapseError as e:
+            logger.info("Failed to handle edu %r: %r", edu_type, e)
+        except Exception as e:
+            logger.exception("Failed to handle edu %r", edu_type)
+
+    def on_query(self, query_type, args):
+        handler = self.query_handlers.get(query_type)
+        if not handler:
+            logger.warn("No handler registered for query type %s", query_type)
+            raise NotFoundError("No handler for Query type '%s'" % (query_type,))
+
+        return handler(args)
diff --git a/synapse/federation/replication.py b/synapse/federation/replication.py
index 62d865ec4b..b8b3a3f933 100644
--- a/synapse/federation/replication.py
+++ b/synapse/federation/replication.py
@@ -20,8 +20,6 @@ a given transport.
 from .federation_client import FederationClient
 from .federation_server import FederationServer
 
-from .persistence import TransactionActions
-
 import logging
 
 
@@ -47,26 +45,6 @@ class ReplicationLayer(FederationClient, FederationServer):
     """
 
     def __init__(self, hs, transport_layer):
-        self.server_name = hs.hostname
-
-        self.keyring = hs.get_keyring()
-
-        self.transport_layer = transport_layer
-
-        self.federation_client = self
-
-        self.store = hs.get_datastore()
-
-        self.handler = None
-        self.edu_handlers = {}
-        self.query_handlers = {}
-
-        self._clock = hs.get_clock()
-
-        self.transaction_actions = TransactionActions(self.store)
-
-        self.hs = hs
-
         super(ReplicationLayer, self).__init__(hs)
 
     def __str__(self):
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 0e83453851..9e58dbe64e 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -41,10 +41,12 @@ class DeviceHandler(BaseHandler):
 
         self._edu_updater = DeviceListEduUpdater(hs, self)
 
-        self.federation.register_edu_handler(
+        federation_registry = hs.get_federation_registry()
+
+        federation_registry.register_edu_handler(
             "m.device_list_update", self._edu_updater.incoming_device_list_update,
         )
-        self.federation.register_query_handler(
+        federation_registry.register_query_handler(
             "user_devices", self.on_federation_query_user_devices,
         )
 
diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index d996aa90bb..f147a20b73 100644
--- a/synapse/handlers/devicemessage.py
+++ b/synapse/handlers/devicemessage.py
@@ -37,7 +37,7 @@ class DeviceMessageHandler(object):
         self.is_mine = hs.is_mine
         self.federation = hs.get_federation_sender()
 
-        hs.get_replication_layer().register_edu_handler(
+        hs.get_federation_registry().register_edu_handler(
             "m.direct_to_device", self.on_direct_to_device_edu
         )
 
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index 8580ada60a..e955cb1f3c 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -37,7 +37,7 @@ class DirectoryHandler(BaseHandler):
         self.event_creation_handler = hs.get_event_creation_handler()
 
         self.federation = hs.get_replication_layer()
-        self.federation.register_query_handler(
+        hs.get_federation_registry().register_query_handler(
             "directory", self.on_directory_query
         )
 
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 9aa95f89e6..57f50a4e27 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -40,7 +40,7 @@ class E2eKeysHandler(object):
         # doesn't really work as part of the generic query API, because the
         # query request requires an object POST, but we abuse the
         # "query handler" interface.
-        self.federation.register_query_handler(
+        hs.get_federation_registry().register_query_handler(
             "client_keys", self.on_federation_query_client_keys
         )
 
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 42aab91c50..4f97c8db79 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -667,7 +667,7 @@ class EventCreationHandler(object):
             event (FrozenEvent)
             context (EventContext)
             ratelimit (bool)
-            extra_users (list(str)): Any extra users to notify about event
+            extra_users (list(UserID)): Any extra users to notify about event
         """
 
         try:
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index cb158ba962..b11ae78350 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -98,24 +98,26 @@ class PresenceHandler(object):
 
         self.state = hs.get_state_handler()
 
-        self.replication.register_edu_handler(
+        federation_registry = hs.get_federation_registry()
+
+        federation_registry.register_edu_handler(
             "m.presence", self.incoming_presence
         )
-        self.replication.register_edu_handler(
+        federation_registry.register_edu_handler(
             "m.presence_invite",
             lambda origin, content: self.invite_presence(
                 observed_user=UserID.from_string(content["observed_user"]),
                 observer_user=UserID.from_string(content["observer_user"]),
             )
         )
-        self.replication.register_edu_handler(
+        federation_registry.register_edu_handler(
             "m.presence_accept",
             lambda origin, content: self.accept_presence(
                 observed_user=UserID.from_string(content["observed_user"]),
                 observer_user=UserID.from_string(content["observer_user"]),
             )
         )
-        self.replication.register_edu_handler(
+        federation_registry.register_edu_handler(
             "m.presence_deny",
             lambda origin, content: self.deny_presence(
                 observed_user=UserID.from_string(content["observed_user"]),
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index c9c2879038..c386c79bbd 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -32,7 +32,7 @@ class ProfileHandler(BaseHandler):
         super(ProfileHandler, self).__init__(hs)
 
         self.federation = hs.get_replication_layer()
-        self.federation.register_query_handler(
+        hs.get_federation_registry().register_query_handler(
             "profile", self.on_profile_query
         )
 
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index 0525765272..3f215c2b4e 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -35,7 +35,7 @@ class ReceiptsHandler(BaseHandler):
         self.store = hs.get_datastore()
         self.hs = hs
         self.federation = hs.get_federation_sender()
-        hs.get_replication_layer().register_edu_handler(
+        hs.get_federation_registry().register_edu_handler(
             "m.receipt", self._received_remote_receipt
         )
         self.clock = self.hs.get_clock()
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 9021d4d57f..ed5939880a 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -446,16 +446,34 @@ class RegistrationHandler(BaseHandler):
         return self.hs.get_auth_handler()
 
     @defer.inlineCallbacks
-    def guest_access_token_for(self, medium, address, inviter_user_id):
+    def get_or_register_3pid_guest(self, medium, address, inviter_user_id):
+        """Get a guest access token for a 3PID, creating a guest account if
+        one doesn't already exist.
+
+        Args:
+            medium (str)
+            address (str)
+            inviter_user_id (str): The user ID who is trying to invite the
+                3PID
+
+        Returns:
+            Deferred[(str, str)]: A 2-tuple of `(user_id, access_token)` of the
+            3PID guest account.
+        """
         access_token = yield self.store.get_3pid_guest_access_token(medium, address)
         if access_token:
-            defer.returnValue(access_token)
+            user_info = yield self.auth.get_user_by_access_token(
+                access_token
+            )
 
-        _, access_token = yield self.register(
+            defer.returnValue((user_info["user"].to_string(), access_token))
+
+        user_id, access_token = yield self.register(
             generate_token=True,
             make_guest=True
         )
         access_token = yield self.store.save_or_get_3pid_guest_access_token(
             medium, address, access_token, inviter_user_id
         )
-        defer.returnValue(access_token)
+
+        defer.returnValue((user_id, access_token))
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index da35e604d0..6ee8420d1f 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -138,7 +138,7 @@ class RoomMemberHandler(object):
         defer.returnValue(event)
 
     @defer.inlineCallbacks
-    def remote_join(self, remote_room_hosts, room_id, user, content):
+    def _remote_join(self, remote_room_hosts, room_id, user, content):
         """Try and join a room that this server is not in
 
         Args:
@@ -342,7 +342,7 @@ class RoomMemberHandler(object):
                     raise AuthError(403, "Guest access not allowed")
 
             if not is_host_in_room:
-                inviter = yield self.get_inviter(target.to_string(), room_id)
+                inviter = yield self._get_inviter(target.to_string(), room_id)
                 if inviter and not self.hs.is_mine(inviter):
                     remote_room_hosts.append(inviter.domain)
 
@@ -356,7 +356,7 @@ class RoomMemberHandler(object):
                 if requester.is_guest:
                     content["kind"] = "guest"
 
-                ret = yield self.remote_join(
+                ret = yield self._remote_join(
                     remote_room_hosts, room_id, target, content
                 )
                 defer.returnValue(ret)
@@ -364,7 +364,7 @@ class RoomMemberHandler(object):
         elif effective_membership_state == Membership.LEAVE:
             if not is_host_in_room:
                 # perhaps we've been invited
-                inviter = yield self.get_inviter(target.to_string(), room_id)
+                inviter = yield self._get_inviter(target.to_string(), room_id)
                 if not inviter:
                     raise SynapseError(404, "Not a known room")
 
@@ -528,7 +528,7 @@ class RoomMemberHandler(object):
         defer.returnValue((RoomID.from_string(room_id), servers))
 
     @defer.inlineCallbacks
-    def get_inviter(self, user_id, room_id):
+    def _get_inviter(self, user_id, room_id):
         invite = yield self.store.get_invite_for_user_in_room(
             user_id=user_id,
             room_id=room_id,
@@ -605,7 +605,7 @@ class RoomMemberHandler(object):
             if "mxid" in data:
                 if "signatures" not in data:
                     raise AuthError(401, "No signatures on 3pid binding")
-                yield self.verify_any_signature(data, id_server)
+                yield self._verify_any_signature(data, id_server)
                 defer.returnValue(data["mxid"])
 
         except IOError as e:
@@ -613,7 +613,7 @@ class RoomMemberHandler(object):
             defer.returnValue(None)
 
     @defer.inlineCallbacks
-    def verify_any_signature(self, data, server_hostname):
+    def _verify_any_signature(self, data, server_hostname):
         if server_hostname not in data["signatures"]:
             raise AuthError(401, "No signature from server %s" % (server_hostname,))
         for key_name, signature in data["signatures"][server_hostname].items():
@@ -767,20 +767,16 @@ class RoomMemberHandler(object):
         }
 
         if self.config.invite_3pid_guest:
-            registration_handler = self.registration_handler
-            guest_access_token = yield registration_handler.guest_access_token_for(
+            rh = self.registration_handler
+            guest_user_id, guest_access_token = yield rh.get_or_register_3pid_guest(
                 medium=medium,
                 address=address,
                 inviter_user_id=inviter_user_id,
             )
 
-            guest_user_info = yield self.auth.get_user_by_access_token(
-                guest_access_token
-            )
-
             invite_config.update({
                 "guest_access_token": guest_access_token,
-                "guest_user_id": guest_user_info["user"].to_string(),
+                "guest_user_id": guest_user_id,
             })
 
         data = yield self.simple_http_client.post_urlencoded_get_json(
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 82dedbbc99..77c0cf146f 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -56,7 +56,7 @@ class TypingHandler(object):
 
         self.federation = hs.get_federation_sender()
 
-        hs.get_replication_layer().register_edu_handler("m.typing", self._recv_edu)
+        hs.get_federation_registry().register_edu_handler("m.typing", self._recv_edu)
 
         hs.get_distributor().observe("user_left_room", self.user_left_room)
 
diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py
index 70f2fe456a..bbe2f967b7 100644
--- a/synapse/replication/http/send_event.py
+++ b/synapse/replication/http/send_event.py
@@ -25,7 +25,7 @@ from synapse.util.async import sleep
 from synapse.util.caches.response_cache import ResponseCache
 from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
 from synapse.util.metrics import Measure
-from synapse.types import Requester
+from synapse.types import Requester, UserID
 
 import logging
 import re
@@ -46,7 +46,7 @@ def send_event_to_master(client, host, port, requester, event, context,
         event (FrozenEvent)
         context (EventContext)
         ratelimit (bool)
-        extra_users (list(str)): Any extra users to notify about event
+        extra_users (list(UserID)): Any extra users to notify about event
     """
     uri = "http://%s:%s/_synapse/replication/send_event/%s" % (
         host, port, event.event_id,
@@ -59,7 +59,7 @@ def send_event_to_master(client, host, port, requester, event, context,
         "context": context.serialize(event),
         "requester": requester.serialize(),
         "ratelimit": ratelimit,
-        "extra_users": extra_users,
+        "extra_users": [u.to_string() for u in extra_users],
     }
 
     try:
@@ -143,7 +143,7 @@ class ReplicationSendEventRestServlet(RestServlet):
             context = yield EventContext.deserialize(self.store, content["context"])
 
             ratelimit = content["ratelimit"]
-            extra_users = content["extra_users"]
+            extra_users = [UserID.from_string(u) for u in content["extra_users"]]
 
         if requester.user:
             request.authenticated_entity = requester.user.to_string()
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 9d745174c7..f8999d64d7 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -599,7 +599,7 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
     def register(self, http_server):
         # /rooms/$roomid/[invite|join|leave]
         PATTERNS = ("/rooms/(?P<room_id>[^/]*)/"
-                    "(?P<membership_action>join|invite|leave|ban|unban|kick|forget)")
+                    "(?P<membership_action>join|invite|leave|ban|unban|kick)")
         register_txn_path(self, PATTERNS, http_server)
 
     @defer.inlineCallbacks
diff --git a/synapse/server.py b/synapse/server.py
index 5b6effbe31..1bc8d6f702 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -34,6 +34,7 @@ from synapse.events.builder import EventBuilderFactory
 from synapse.events.spamcheck import SpamChecker
 from synapse.federation import initialize_http_replication
 from synapse.federation.send_queue import FederationRemoteSendQueue
+from synapse.federation.federation_server import FederationHandlerRegistry
 from synapse.federation.transport.client import TransportLayerClient
 from synapse.federation.transaction_queue import TransactionQueue
 from synapse.handlers import Handlers
@@ -147,6 +148,7 @@ class HomeServer(object):
         'groups_attestation_renewer',
         'spam_checker',
         'room_member_handler',
+        'federation_registry',
     ]
 
     def __init__(self, hostname, **kwargs):
@@ -387,6 +389,9 @@ class HomeServer(object):
     def build_room_member_handler(self):
         return RoomMemberHandler(self)
 
+    def build_federation_registry(self):
+        return FederationHandlerRegistry()
+
     def remove_pusher(self, app_id, push_key, user_id):
         return self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
 
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 826fad307e..3890878170 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -283,10 +283,11 @@ class EventsStore(EventsWorkerStore):
     def _maybe_start_persisting(self, room_id):
         @defer.inlineCallbacks
         def persisting_queue(item):
-            yield self._persist_events(
-                item.events_and_contexts,
-                backfilled=item.backfilled,
-            )
+            with Measure(self._clock, "persist_events"):
+                yield self._persist_events(
+                    item.events_and_contexts,
+                    backfilled=item.backfilled,
+                )
 
         self._event_persist_queue.handle_queue(room_id, persisting_queue)
 
diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index 5712773909..b103921498 100644
--- a/tests/handlers/test_directory.py
+++ b/tests/handlers/test_directory.py
@@ -35,21 +35,20 @@ class DirectoryTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def setUp(self):
-        self.mock_federation = Mock(spec=[
-            "make_query",
-            "register_edu_handler",
-        ])
+        self.mock_federation = Mock()
+        self.mock_registry = Mock()
 
         self.query_handlers = {}
 
         def register_query_handler(query_type, handler):
             self.query_handlers[query_type] = handler
-        self.mock_federation.register_query_handler = register_query_handler
+        self.mock_registry.register_query_handler = register_query_handler
 
         hs = yield setup_test_homeserver(
             http_client=None,
             resource_for_federation=Mock(),
             replication_layer=self.mock_federation,
+            federation_registry=self.mock_registry,
         )
         hs.handlers = DirectoryHandlers(hs)
 
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index a5f47181d7..73223ffbd3 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -37,23 +37,22 @@ class ProfileTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def setUp(self):
-        self.mock_federation = Mock(spec=[
-            "make_query",
-            "register_edu_handler",
-        ])
+        self.mock_federation = Mock()
+        self.mock_registry = Mock()
 
         self.query_handlers = {}
 
         def register_query_handler(query_type, handler):
             self.query_handlers[query_type] = handler
 
-        self.mock_federation.register_query_handler = register_query_handler
+        self.mock_registry.register_query_handler = register_query_handler
 
         hs = yield setup_test_homeserver(
             http_client=None,
             handlers=None,
             resource_for_federation=Mock(),
             replication_layer=self.mock_federation,
+            federation_registry=self.mock_registry,
             ratelimiter=NonCallableMock(spec_set=[
                 "send_message",
             ])