summary refs log tree commit diff
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2017-01-26 16:06:54 +0000
committerErik Johnston <erik@matrix.org>2017-01-26 16:07:24 +0000
commitc974116f197d211ba9b42159fe61cfd5957411b5 (patch)
treeb4d6e71850ba0d371e0463a5cc99439ce8072c40
parentFix up sending of m.device_list_update edus (diff)
downloadsynapse-c974116f197d211ba9b42159fe61cfd5957411b5.tar.xz
Implement device key caching over federation
-rw-r--r--synapse/federation/federation_client.py10
-rw-r--r--synapse/federation/federation_server.py3
-rw-r--r--synapse/federation/transport/client.py26
-rw-r--r--synapse/federation/transport/server.py8
-rw-r--r--synapse/handlers/device.py85
-rw-r--r--synapse/handlers/e2e_keys.py40
-rw-r--r--synapse/storage/devices.py201
-rw-r--r--synapse/storage/end_to_end_keys.py4
-rw-r--r--synapse/storage/schema/delta/40/device_list_streams.sql20
-rw-r--r--tests/handlers/test_device.py18
-rw-r--r--tests/handlers/test_directory.py1
-rw-r--r--tests/handlers/test_profile.py1
-rw-r--r--tests/storage/test_appservice.py21
13 files changed, 381 insertions, 57 deletions
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index c9175bb33d..b5bcfd705a 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -127,6 +127,16 @@ class FederationClient(FederationBase):
         )
 
     @log_function
+    def query_user_devices(self, destination, user_id, timeout=30000):
+        """Query the device keys for a list of user ids hosted on a remote
+        server.
+        """
+        sent_queries_counter.inc("user_devices")
+        return self.transport_layer.query_user_devices(
+            destination, user_id, timeout
+        )
+
+    @log_function
     def claim_client_keys(self, destination, content, timeout):
         """Claims one-time keys for a device hosted on a remote server.
 
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 862ccbef5d..e922b7ff4a 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -416,6 +416,9 @@ class FederationServer(FederationBase):
     def on_query_client_keys(self, origin, content):
         return self.on_query_request("client_keys", content)
 
+    def on_query_user_devices(self, origin, user_id):
+        return self.on_query_request("user_devices", user_id)
+
     @defer.inlineCallbacks
     @log_function
     def on_claim_client_keys(self, origin, content):
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 915af34409..f49e8a2cc4 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -348,6 +348,32 @@ class TransportLayerClient(object):
 
     @defer.inlineCallbacks
     @log_function
+    def query_user_devices(self, destination, user_id, timeout):
+        """Query the devices for a user id hosted on a remote server.
+
+        Response:
+            {
+              "stream_id": "...",
+              "devices": [ { ... } ]
+            }
+
+        Args:
+            destination(str): The server to query.
+            query_content(dict): The user ids to query.
+        Returns:
+            A dict containg the device keys.
+        """
+        path = PREFIX + "/user/devices/" + user_id
+
+        content = yield self.client.get_json(
+            destination=destination,
+            path=path,
+            timeout=timeout,
+        )
+        defer.returnValue(content)
+
+    @defer.inlineCallbacks
+    @log_function
     def claim_client_keys(self, destination, query_content, timeout):
         """Claim one-time keys for a list of devices hosted on a remote server.
 
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 159dbd1747..c840da834c 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -409,6 +409,13 @@ class FederationClientKeysQueryServlet(BaseFederationServlet):
         return self.handler.on_query_client_keys(origin, content)
 
 
+class FederationUserDevicesQueryServlet(BaseFederationServlet):
+    PATH = "/user/devices/(?P<user_id>[^/]*)"
+
+    def on_GET(self, origin, content, query, user_id):
+        return self.handler.on_query_user_devices(origin, user_id)
+
+
 class FederationClientKeysClaimServlet(BaseFederationServlet):
     PATH = "/user/keys/claim"
 
@@ -613,6 +620,7 @@ SERVLET_CLASSES = (
     FederationGetMissingEventsServlet,
     FederationEventAuthServlet,
     FederationClientKeysQueryServlet,
+    FederationUserDevicesQueryServlet,
     FederationClientKeysClaimServlet,
     FederationThirdPartyInviteExchangeServlet,
     On3pidBindServlet,
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index ba4c48d590..2d66b3721a 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -15,6 +15,7 @@
 
 from synapse.api import errors
 from synapse.util import stringutils
+from synapse.util.async import Linearizer
 from synapse.types import get_domain_from_id
 from twisted.internet import defer
 from ._base import BaseHandler
@@ -28,8 +29,18 @@ class DeviceHandler(BaseHandler):
     def __init__(self, hs):
         super(DeviceHandler, self).__init__(hs)
 
+        self.hs = hs
         self.state = hs.get_state_handler()
-        self.federation = hs.get_federation_sender()
+        self.federation_sender = hs.get_federation_sender()
+        self.federation = hs.get_replication_layer()
+        self._remote_edue_linearizer = Linearizer(name="remote_device_list")
+
+        self.federation.register_edu_handler(
+            "m.device_list_update", self._incoming_device_list_update,
+        )
+        self.federation.register_query_handler(
+            "user_devices", self.on_federation_query_user_devices,
+        )
 
     @defer.inlineCallbacks
     def check_device_registered(self, user_id, device_id,
@@ -55,7 +66,7 @@ class DeviceHandler(BaseHandler):
                 initial_device_display_name=initial_device_display_name,
             )
             if new_device:
-                yield self.notify_device_update(user_id, device_id)
+                yield self.notify_device_update(user_id, [device_id])
             defer.returnValue(device_id)
 
         # if the device id is not specified, we'll autogen one, but loop a few
@@ -69,7 +80,7 @@ class DeviceHandler(BaseHandler):
                 initial_device_display_name=initial_device_display_name,
             )
             if new_device:
-                yield self.notify_device_update(user_id, device_id)
+                yield self.notify_device_update(user_id, [device_id])
                 defer.returnValue(device_id)
             attempts += 1
 
@@ -151,7 +162,7 @@ class DeviceHandler(BaseHandler):
             user_id=user_id, device_id=device_id
         )
 
-        yield self.notify_device_update(user_id, device_id)
+        yield self.notify_device_update(user_id, [device_id])
 
     @defer.inlineCallbacks
     def update_device(self, user_id, device_id, content):
@@ -172,7 +183,7 @@ class DeviceHandler(BaseHandler):
                 device_id,
                 new_display_name=content.get("display_name")
             )
-            yield self.notify_device_update(user_id, device_id)
+            yield self.notify_device_update(user_id, [device_id])
         except errors.StoreError, e:
             if e.code == 404:
                 raise errors.NotFoundError()
@@ -180,26 +191,28 @@ class DeviceHandler(BaseHandler):
                 raise
 
     @defer.inlineCallbacks
-    def notify_device_update(self, user_id, device_id):
+    def notify_device_update(self, user_id, device_ids):
         rooms = yield self.store.get_rooms_for_user(user_id)
         room_ids = [r.room_id for r in rooms]
 
         hosts = set()
-        for room_id in room_ids:
-            users = yield self.state.get_current_user_in_room(room_id)
-            hosts.update(get_domain_from_id(u) for u in users)
-        hosts.discard(self.server_name)
+        if self.hs.is_mine_id(user_id):
+            for room_id in room_ids:
+                users = yield self.state.get_current_user_in_room(room_id)
+                hosts.update(get_domain_from_id(u) for u in users)
+            hosts.discard(self.server_name)
 
         position = yield self.store.add_device_change_to_streams(
-            user_id, device_id, list(hosts)
+            user_id, device_ids, list(hosts)
         )
 
         yield self.notifier.on_new_event(
             "device_list_key", position, rooms=room_ids,
         )
 
+        logger.info("Sending device list update notif to: %r", hosts)
         for host in hosts:
-            self.federation.send_device_messages(host)
+            self.federation_sender.send_device_messages(host)
 
     @defer.inlineCallbacks
     def get_device_list_changes(self, user_id, room_ids, from_key):
@@ -214,6 +227,54 @@ class DeviceHandler(BaseHandler):
 
         defer.returnValue(user_ids_changed)
 
+    @defer.inlineCallbacks
+    def _incoming_device_list_update(self, origin, edu_content):
+        user_id = edu_content["user_id"]
+        device_id = edu_content["device_id"]
+        stream_id = edu_content["stream_id"]
+        prev_ids = edu_content.get("prev_id", [])
+
+        if get_domain_from_id(user_id) != origin:
+            # TODO: Raise?
+            return
+
+        logger.info("Got edu: %r", edu_content)
+
+        with (yield self._remote_edue_linearizer.queue(user_id)):
+            resync = True
+            if len(prev_ids) == 1:
+                extremity = yield self.store.get_device_list_remote_extremity(user_id)
+                logger.info("Extrem: %r, prev_ids: %r", extremity, prev_ids)
+                if str(extremity) == str(prev_ids[0]):
+                    resync = False
+
+            if resync:
+                result = yield self.federation.query_user_devices(origin, user_id)
+                stream_id = result["stream_id"]
+                devices = result["devices"]
+                yield self.store.update_remote_device_list_cache(
+                    user_id, devices, stream_id,
+                )
+                device_ids = [device["device_id"] for device in devices]
+                yield self.notify_device_update(user_id, device_ids)
+            else:
+                content = dict(edu_content)
+                for key in ("user_id", "device_id", "stream_id", "prev_ids"):
+                    content.pop(key, None)
+                yield self.store.update_remote_device_list_cache_entry(
+                    user_id, device_id, content, stream_id,
+                )
+                yield self.notify_device_update(user_id, [device_id])
+
+    @defer.inlineCallbacks
+    def on_federation_query_user_devices(self, user_id):
+        stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id)
+        defer.returnValue({
+            "user_id": user_id,
+            "stream_id": stream_id,
+            "devices": devices,
+        })
+
 
 def _update_device_from_client_ips(device, client_ips):
     ip = client_ips.get((device["user_id"], device["device_id"]), {})
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 38c2a2d39e..832998a6d3 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -73,8 +73,7 @@ class E2eKeysHandler(object):
             if self.is_mine_id(user_id):
                 local_query[user_id] = device_ids
             else:
-                domain = get_domain_from_id(user_id)
-                remote_queries.setdefault(domain, {})[user_id] = device_ids
+                remote_queries[user_id] = device_ids
 
         # do the queries
         failures = {}
@@ -85,9 +84,40 @@ class E2eKeysHandler(object):
                 if user_id in local_query:
                     results[user_id] = keys
 
+        remote_queries_not_in_cache = {}
+        if remote_queries:
+            query_list = []
+            for user_id, device_ids in remote_queries.iteritems():
+                if device_ids:
+                    query_list.extend((user_id, device_id) for device_id in device_ids)
+                else:
+                    query_list.append((user_id, None))
+
+            user_ids_not_in_cache, remote_results = (
+                yield self.store.get_user_devices_from_cache(
+                    query_list
+                )
+            )
+            for user_id, devices in remote_results.iteritems():
+                user_devices = results.setdefault(user_id, {})
+                for device_id, device in devices.iteritems():
+                    keys = device.get("keys", None)
+                    device_display_name = device.get("device_display_name", None)
+                    if keys:
+                        result = dict(keys)
+                        unsigned = result.setdefault("unsigned", {})
+                        if device_display_name:
+                            unsigned["device_display_name"] = device_display_name
+                        user_devices[device_id] = result
+
+            for user_id in user_ids_not_in_cache:
+                domain = get_domain_from_id(user_id)
+                r = remote_queries_not_in_cache.setdefault(domain, {})
+                r[user_id] = remote_queries[user_id]
+
         @defer.inlineCallbacks
         def do_remote_query(destination):
-            destination_query = remote_queries[destination]
+            destination_query = remote_queries_not_in_cache[destination]
             try:
                 limiter = yield get_retry_limiter(
                     destination, self.clock, self.store
@@ -119,7 +149,7 @@ class E2eKeysHandler(object):
 
         yield preserve_context_over_deferred(defer.gatherResults([
             preserve_fn(do_remote_query)(destination)
-            for destination in remote_queries
+            for destination in remote_queries_not_in_cache
         ]))
 
         defer.returnValue({
@@ -259,7 +289,7 @@ class E2eKeysHandler(object):
                 user_id, device_id, time_now,
                 encode_canonical_json(device_keys)
             )
-            yield self.device_handler.notify_device_update(user_id, device_id)
+            yield self.device_handler.notify_device_update(user_id, [device_id])
 
         one_time_keys = keys.get("one_time_keys", None)
         if one_time_keys:
diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py
index 9628e2ff75..8ee3119db2 100644
--- a/synapse/storage/devices.py
+++ b/synapse/storage/devices.py
@@ -138,6 +138,89 @@ class DeviceStore(SQLBaseStore):
 
         defer.returnValue({d["device_id"]: d for d in devices})
 
+    def get_device_list_remote_extremity(self, user_id):
+        return self._simple_select_one_onecol(
+            table="device_lists_remote_extremeties",
+            keyvalues={"user_id": user_id},
+            retcol="stream_id",
+            desc="get_device_list_remote_extremity",
+            allow_none=True,
+        )
+
+    def update_remote_device_list_cache_entry(self, user_id, device_id, content,
+                                              stream_id):
+        return self.runInteraction(
+            "update_remote_device_list_cache_entry",
+            self._update_remote_device_list_cache_entry_txn,
+            user_id, device_id, content, stream_id,
+        )
+
+    def _update_remote_device_list_cache_entry_txn(self, txn, user_id, device_id,
+                                                   content, stream_id):
+        self._simple_upsert_txn(
+            txn,
+            table="device_lists_remote_cache",
+            keyvalues={
+                "user_id": user_id,
+                "device_id": device_id,
+            },
+            values={
+                "content": json.dumps(content),
+            }
+        )
+
+        self._simple_upsert_txn(
+            txn,
+            table="device_lists_remote_extremeties",
+            keyvalues={
+                "user_id": user_id,
+            },
+            values={
+                "stream_id": stream_id,
+            }
+        )
+
+    def update_remote_device_list_cache(self, user_id, devices, stream_id):
+        return self.runInteraction(
+            "update_remote_device_list_cache",
+            self._update_remote_device_list_cache_txn,
+            user_id, devices, stream_id,
+        )
+
+    def _update_remote_device_list_cache_txn(self, txn, user_id, devices,
+                                             stream_id):
+        self._simple_delete_txn(
+            txn,
+            table="device_lists_remote_cache",
+            keyvalues={
+                "user_id": user_id,
+            },
+        )
+
+        self._simple_insert_many_txn(
+            txn,
+            table="device_lists_remote_cache",
+            values=[
+                {
+                    "user_id": user_id,
+                    "device_id": content["device_id"],
+                    "content": json.dumps(content),
+                }
+                for content in devices
+            ]
+        )
+
+        self._simple_upsert_txn(
+            txn,
+            table="device_lists_remote_extremeties",
+            keyvalues={
+                "user_id": user_id,
+            },
+            values={
+                "stream_id": stream_id,
+            }
+        )
+
     def get_devices_by_remote(self, destination, from_stream_id):
         now_stream_id = self._device_list_id_gen.get_current_token()
 
@@ -184,7 +267,7 @@ class DeviceStore(SQLBaseStore):
             txn.execute(prev_sent_id_sql, (destination, user_id, True))
             rows = txn.fetchall()
             prev_id = rows[0][0]
-            for device_id, result in user_devices.iteritems():
+            for device_id, device in user_devices.iteritems():
                 stream_id = query_map[(user_id, device_id)]
                 result = {
                     "user_id": user_id,
@@ -195,10 +278,10 @@ class DeviceStore(SQLBaseStore):
 
                 prev_id = stream_id
 
-                key_json = result.get("key_json", None)
+                key_json = device.get("key_json", None)
                 if key_json:
                     result["keys"] = json.loads(key_json)
-                device_display_name = result.get("device_display_name", None)
+                device_display_name = device.get("device_display_name", None)
                 if device_display_name:
                     result["device_display_name"] = device_display_name
 
@@ -206,6 +289,96 @@ class DeviceStore(SQLBaseStore):
 
         return (now_stream_id, results)
 
+    def get_user_devices_from_cache(self, query_list):
+        return self.runInteraction(
+            "get_user_devices_from_cache", self._get_user_devices_from_cache_txn,
+            query_list,
+        )
+
+    def _get_user_devices_from_cache_txn(self, txn, query_list):
+        user_ids = {user_id for user_id, _ in query_list}
+
+        user_ids_in_cache = set()
+        for user_id in user_ids:
+            stream_ids = self._simple_select_onecol_txn(
+                txn,
+                table="device_lists_remote_extremeties",
+                keyvalues={
+                    "user_id": user_id,
+                },
+                retcol="stream_id",
+            )
+            if stream_ids:
+                user_ids_in_cache.add(user_id)
+
+        user_ids_not_in_cache = user_ids - user_ids_in_cache
+
+        results = {}
+        for user_id, device_id in query_list:
+            if user_id not in user_ids_in_cache:
+                continue
+
+            if device_id:
+                content = self._simple_select_one_onecol_txn(
+                    txn,
+                    table="device_lists_remote_cache",
+                    keyvalues={
+                        "user_id": user_id,
+                        "device_id": device_id,
+                    },
+                    retcol="content",
+                )
+                results.setdefault(user_id, {})[device_id] = json.loads(content)
+            else:
+                devices = self._simple_select_list_txn(
+                    txn,
+                    table="device_lists_remote_cache",
+                    keyvalues={
+                        "user_id": user_id,
+                    },
+                    retcols=("device_id", "content"),
+                )
+                results[user_id] = {
+                    device["device_id"]: json.loads(device["content"])
+                    for device in devices
+                }
+                user_ids_in_cache.discard(user_id)
+
+        return user_ids_not_in_cache, results
+
+    def get_devices_with_keys_by_user(self, user_id):
+        return self.runInteraction(
+            "get_devices_with_keys_by_user",
+            self._get_devices_with_keys_by_user_txn, user_id,
+        )
+
+    def _get_devices_with_keys_by_user_txn(self, txn, user_id):
+        now_stream_id = self._device_list_id_gen.get_current_token()
+
+        devices = self._get_e2e_device_keys_txn(
+            txn, [(user_id, None)], include_all_devices=True
+        )
+
+        for user_id, user_devices in devices.iteritems():
+            results = []
+            for device_id, device in user_devices.iteritems():
+                result = {
+                    "device_id": device_id,
+                }
+
+                key_json = device.get("key_json", None)
+                if key_json:
+                    result["keys"] = json.loads(key_json)
+                device_display_name = device.get("device_display_name", None)
+                if device_display_name:
+                    result["device_display_name"] = device_display_name
+
+                results.append(result)
+
+            return now_stream_id, results
+
+        return now_stream_id, []
+
     def mark_as_sent_devices_by_remote(self, destination, stream_id):
         return self.runInteraction(
             "mark_as_sent_devices_by_remote", self._mark_as_sent_devices_by_remote_txn,
@@ -242,17 +415,17 @@ class DeviceStore(SQLBaseStore):
         defer.returnValue(set(row["user_id"] for row in rows))
 
     @defer.inlineCallbacks
-    def add_device_change_to_streams(self, user_id, device_id, hosts):
+    def add_device_change_to_streams(self, user_id, device_ids, hosts):
         # device_lists_stream
         # device_lists_outbound_pokes
         with self._device_list_id_gen.get_next() as stream_id:
             yield self.runInteraction(
                 "add_device_change_to_streams", self._add_device_change_txn,
-                user_id, device_id, hosts, stream_id,
+                user_id, device_ids, hosts, stream_id,
             )
         defer.returnValue(stream_id)
 
-    def _add_device_change_txn(self, txn, user_id, device_id, hosts, stream_id):
+    def _add_device_change_txn(self, txn, user_id, device_ids, hosts, stream_id):
         txn.call_after(
             self._device_list_stream_cache.entity_has_changed,
             user_id, stream_id,
@@ -263,14 +436,17 @@ class DeviceStore(SQLBaseStore):
                 host, stream_id,
             )
 
-        self._simple_insert_txn(
+        self._simple_insert_many_txn(
             txn,
             table="device_lists_stream",
-            values={
-                "stream_id": stream_id,
-                "user_id": user_id,
-                "device_id": device_id,
-            }
+            values=[
+                {
+                    "stream_id": stream_id,
+                    "user_id": user_id,
+                    "device_id": device_id,
+                }
+                for device_id in device_ids
+            ]
         )
 
         self._simple_insert_many_txn(
@@ -285,6 +461,7 @@ class DeviceStore(SQLBaseStore):
                     "sent": False,
                 }
                 for destination in hosts
+                for device_id in device_ids
             ]
         )
 
diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py
index f82943a7a8..a915c790ff 100644
--- a/synapse/storage/end_to_end_keys.py
+++ b/synapse/storage/end_to_end_keys.py
@@ -52,11 +52,11 @@ class EndToEndKeyStore(SQLBaseStore):
         query_params = []
 
         for (user_id, device_id) in query_list:
-            query_clause = "k.user_id = ?"
+            query_clause = "user_id = ?"
             query_params.append(user_id)
 
             if device_id:
-                query_clause += " AND k.device_id = ?"
+                query_clause += " AND device_id = ?"
                 query_params.append(device_id)
 
             query_clauses.append(query_clause)
diff --git a/synapse/storage/schema/delta/40/device_list_streams.sql b/synapse/storage/schema/delta/40/device_list_streams.sql
index 61cac63bbb..d1051c6ddf 100644
--- a/synapse/storage/schema/delta/40/device_list_streams.sql
+++ b/synapse/storage/schema/delta/40/device_list_streams.sql
@@ -13,18 +13,6 @@
  * limitations under the License.
  */
 
-CREATE TABLE device_list_streams_remote (
-    list_id TEXT NOT NULL,
-    origin TEXT NOT NULL,
-    user_id TEXT NOT NULL,
-    is_full BOOLEAN NOT NULL,
-    ts BIGINT NOT NULL
-);
-
-CREATE INDEX device_list_streams_remote_id_origin ON device_list_streams_remote(
-    origin, list_id, user_id
-);
-
 
 CREATE TABLE device_lists_remote_cache (
     user_id TEXT NOT NULL,
@@ -35,6 +23,14 @@ CREATE TABLE device_lists_remote_cache (
 CREATE INDEX device_lists_remote_cache_id ON device_lists_remote_cache(user_id, device_id);
 
 
+CREATE TABLE device_lists_remote_extremeties (
+    user_id TEXT NOT NULL,
+    stream_id TEXT NOT NULL
+);
+
+CREATE INDEX device_lists_remote_extremeties_id ON device_lists_remote_extremeties(user_id, stream_id);
+
+
 CREATE TABLE device_lists_stream (
     stream_id BIGINT NOT NULL,
     user_id TEXT NOT NULL,
diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index 85a970a6c9..2eaaa8253c 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -35,51 +35,51 @@ class DeviceTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def setUp(self):
-        hs = yield utils.setup_test_homeserver(handlers=None)
-        self.handler = synapse.handlers.device.DeviceHandler(hs)
+        hs = yield utils.setup_test_homeserver()
+        self.handler = hs.get_device_handler()
         self.store = hs.get_datastore()
         self.clock = hs.get_clock()
 
     @defer.inlineCallbacks
     def test_device_is_created_if_doesnt_exist(self):
         res = yield self.handler.check_device_registered(
-            user_id="boris",
+            user_id="@boris:foo",
             device_id="fco",
             initial_device_display_name="display name"
         )
         self.assertEqual(res, "fco")
 
-        dev = yield self.handler.store.get_device("boris", "fco")
+        dev = yield self.handler.store.get_device("@boris:foo", "fco")
         self.assertEqual(dev["display_name"], "display name")
 
     @defer.inlineCallbacks
     def test_device_is_preserved_if_exists(self):
         res1 = yield self.handler.check_device_registered(
-            user_id="boris",
+            user_id="@boris:foo",
             device_id="fco",
             initial_device_display_name="display name"
         )
         self.assertEqual(res1, "fco")
 
         res2 = yield self.handler.check_device_registered(
-            user_id="boris",
+            user_id="@boris:foo",
             device_id="fco",
             initial_device_display_name="new display name"
         )
         self.assertEqual(res2, "fco")
 
-        dev = yield self.handler.store.get_device("boris", "fco")
+        dev = yield self.handler.store.get_device("@boris:foo", "fco")
         self.assertEqual(dev["display_name"], "display name")
 
     @defer.inlineCallbacks
     def test_device_id_is_made_up_if_unspecified(self):
         device_id = yield self.handler.check_device_registered(
-            user_id="theresa",
+            user_id="@theresa:foo",
             device_id=None,
             initial_device_display_name="display"
         )
 
-        dev = yield self.handler.store.get_device("theresa", device_id)
+        dev = yield self.handler.store.get_device("@theresa:foo", device_id)
         self.assertEqual(dev["display_name"], "display")
 
     @defer.inlineCallbacks
diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index 5d602c1531..ceb9aa5765 100644
--- a/tests/handlers/test_directory.py
+++ b/tests/handlers/test_directory.py
@@ -37,6 +37,7 @@ class DirectoryTestCase(unittest.TestCase):
     def setUp(self):
         self.mock_federation = Mock(spec=[
             "make_query",
+            "register_edu_handler",
         ])
 
         self.query_handlers = {}
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index f1f664275f..979cebf600 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -39,6 +39,7 @@ class ProfileTestCase(unittest.TestCase):
     def setUp(self):
         self.mock_federation = Mock(spec=[
             "make_query",
+            "register_edu_handler",
         ])
 
         self.query_handlers = {}
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index 9ff1abcd80..9e98d0e330 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -39,7 +39,11 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
             event_cache_size=1,
             password_providers=[],
         )
-        hs = yield setup_test_homeserver(config=config, federation_sender=Mock())
+        hs = yield setup_test_homeserver(
+            config=config,
+            federation_sender=Mock(),
+            replication_layer=Mock(),
+        )
 
         self.as_token = "token1"
         self.as_url = "some_url"
@@ -112,7 +116,11 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
             event_cache_size=1,
             password_providers=[],
         )
-        hs = yield setup_test_homeserver(config=config, federation_sender=Mock())
+        hs = yield setup_test_homeserver(
+            config=config,
+            federation_sender=Mock(),
+            replication_layer=Mock(),
+        )
         self.db_pool = hs.get_db_pool()
 
         self.as_list = [
@@ -446,7 +454,8 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
         hs = yield setup_test_homeserver(
             config=config,
             datastore=Mock(),
-            federation_sender=Mock()
+            federation_sender=Mock(),
+            replication_layer=Mock(),
         )
 
         ApplicationServiceStore(hs)
@@ -463,7 +472,8 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
         hs = yield setup_test_homeserver(
             config=config,
             datastore=Mock(),
-            federation_sender=Mock()
+            federation_sender=Mock(),
+            replication_layer=Mock(),
         )
 
         with self.assertRaises(ConfigError) as cm:
@@ -486,7 +496,8 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
         hs = yield setup_test_homeserver(
             config=config,
             datastore=Mock(),
-            federation_sender=Mock()
+            federation_sender=Mock(),
+            replication_layer=Mock(),
         )
 
         with self.assertRaises(ConfigError) as cm: