diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index c9175bb33d..b5bcfd705a 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -127,6 +127,16 @@ class FederationClient(FederationBase):
)
@log_function
+ def query_user_devices(self, destination, user_id, timeout=30000):
+ """Query the device keys for a list of user ids hosted on a remote
+ server.
+ """
+ sent_queries_counter.inc("user_devices")
+ return self.transport_layer.query_user_devices(
+ destination, user_id, timeout
+ )
+
+ @log_function
def claim_client_keys(self, destination, content, timeout):
"""Claims one-time keys for a device hosted on a remote server.
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 862ccbef5d..e922b7ff4a 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -416,6 +416,9 @@ class FederationServer(FederationBase):
def on_query_client_keys(self, origin, content):
return self.on_query_request("client_keys", content)
+ def on_query_user_devices(self, origin, user_id):
+ return self.on_query_request("user_devices", user_id)
+
@defer.inlineCallbacks
@log_function
def on_claim_client_keys(self, origin, content):
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 915af34409..f49e8a2cc4 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -348,6 +348,32 @@ class TransportLayerClient(object):
@defer.inlineCallbacks
@log_function
+ def query_user_devices(self, destination, user_id, timeout):
+ """Query the devices for a user id hosted on a remote server.
+
+ Response:
+ {
+ "stream_id": "...",
+ "devices": [ { ... } ]
+ }
+
+ Args:
+ destination(str): The server to query.
+ query_content(dict): The user ids to query.
+ Returns:
+ A dict containg the device keys.
+ """
+ path = PREFIX + "/user/devices/" + user_id
+
+ content = yield self.client.get_json(
+ destination=destination,
+ path=path,
+ timeout=timeout,
+ )
+ defer.returnValue(content)
+
+ @defer.inlineCallbacks
+ @log_function
def claim_client_keys(self, destination, query_content, timeout):
"""Claim one-time keys for a list of devices hosted on a remote server.
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 159dbd1747..c840da834c 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -409,6 +409,13 @@ class FederationClientKeysQueryServlet(BaseFederationServlet):
return self.handler.on_query_client_keys(origin, content)
+class FederationUserDevicesQueryServlet(BaseFederationServlet):
+ PATH = "/user/devices/(?P<user_id>[^/]*)"
+
+ def on_GET(self, origin, content, query, user_id):
+ return self.handler.on_query_user_devices(origin, user_id)
+
+
class FederationClientKeysClaimServlet(BaseFederationServlet):
PATH = "/user/keys/claim"
@@ -613,6 +620,7 @@ SERVLET_CLASSES = (
FederationGetMissingEventsServlet,
FederationEventAuthServlet,
FederationClientKeysQueryServlet,
+ FederationUserDevicesQueryServlet,
FederationClientKeysClaimServlet,
FederationThirdPartyInviteExchangeServlet,
On3pidBindServlet,
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index ba4c48d590..2d66b3721a 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -15,6 +15,7 @@
from synapse.api import errors
from synapse.util import stringutils
+from synapse.util.async import Linearizer
from synapse.types import get_domain_from_id
from twisted.internet import defer
from ._base import BaseHandler
@@ -28,8 +29,18 @@ class DeviceHandler(BaseHandler):
def __init__(self, hs):
super(DeviceHandler, self).__init__(hs)
+ self.hs = hs
self.state = hs.get_state_handler()
- self.federation = hs.get_federation_sender()
+ self.federation_sender = hs.get_federation_sender()
+ self.federation = hs.get_replication_layer()
+ self._remote_edue_linearizer = Linearizer(name="remote_device_list")
+
+ self.federation.register_edu_handler(
+ "m.device_list_update", self._incoming_device_list_update,
+ )
+ self.federation.register_query_handler(
+ "user_devices", self.on_federation_query_user_devices,
+ )
@defer.inlineCallbacks
def check_device_registered(self, user_id, device_id,
@@ -55,7 +66,7 @@ class DeviceHandler(BaseHandler):
initial_device_display_name=initial_device_display_name,
)
if new_device:
- yield self.notify_device_update(user_id, device_id)
+ yield self.notify_device_update(user_id, [device_id])
defer.returnValue(device_id)
# if the device id is not specified, we'll autogen one, but loop a few
@@ -69,7 +80,7 @@ class DeviceHandler(BaseHandler):
initial_device_display_name=initial_device_display_name,
)
if new_device:
- yield self.notify_device_update(user_id, device_id)
+ yield self.notify_device_update(user_id, [device_id])
defer.returnValue(device_id)
attempts += 1
@@ -151,7 +162,7 @@ class DeviceHandler(BaseHandler):
user_id=user_id, device_id=device_id
)
- yield self.notify_device_update(user_id, device_id)
+ yield self.notify_device_update(user_id, [device_id])
@defer.inlineCallbacks
def update_device(self, user_id, device_id, content):
@@ -172,7 +183,7 @@ class DeviceHandler(BaseHandler):
device_id,
new_display_name=content.get("display_name")
)
- yield self.notify_device_update(user_id, device_id)
+ yield self.notify_device_update(user_id, [device_id])
except errors.StoreError, e:
if e.code == 404:
raise errors.NotFoundError()
@@ -180,26 +191,28 @@ class DeviceHandler(BaseHandler):
raise
@defer.inlineCallbacks
- def notify_device_update(self, user_id, device_id):
+ def notify_device_update(self, user_id, device_ids):
rooms = yield self.store.get_rooms_for_user(user_id)
room_ids = [r.room_id for r in rooms]
hosts = set()
- for room_id in room_ids:
- users = yield self.state.get_current_user_in_room(room_id)
- hosts.update(get_domain_from_id(u) for u in users)
- hosts.discard(self.server_name)
+ if self.hs.is_mine_id(user_id):
+ for room_id in room_ids:
+ users = yield self.state.get_current_user_in_room(room_id)
+ hosts.update(get_domain_from_id(u) for u in users)
+ hosts.discard(self.server_name)
position = yield self.store.add_device_change_to_streams(
- user_id, device_id, list(hosts)
+ user_id, device_ids, list(hosts)
)
yield self.notifier.on_new_event(
"device_list_key", position, rooms=room_ids,
)
+ logger.info("Sending device list update notif to: %r", hosts)
for host in hosts:
- self.federation.send_device_messages(host)
+ self.federation_sender.send_device_messages(host)
@defer.inlineCallbacks
def get_device_list_changes(self, user_id, room_ids, from_key):
@@ -214,6 +227,54 @@ class DeviceHandler(BaseHandler):
defer.returnValue(user_ids_changed)
+ @defer.inlineCallbacks
+ def _incoming_device_list_update(self, origin, edu_content):
+ user_id = edu_content["user_id"]
+ device_id = edu_content["device_id"]
+ stream_id = edu_content["stream_id"]
+ prev_ids = edu_content.get("prev_id", [])
+
+ if get_domain_from_id(user_id) != origin:
+ # TODO: Raise?
+ return
+
+ logger.info("Got edu: %r", edu_content)
+
+ with (yield self._remote_edue_linearizer.queue(user_id)):
+ resync = True
+ if len(prev_ids) == 1:
+ extremity = yield self.store.get_device_list_remote_extremity(user_id)
+ logger.info("Extrem: %r, prev_ids: %r", extremity, prev_ids)
+ if str(extremity) == str(prev_ids[0]):
+ resync = False
+
+ if resync:
+ result = yield self.federation.query_user_devices(origin, user_id)
+ stream_id = result["stream_id"]
+ devices = result["devices"]
+ yield self.store.update_remote_device_list_cache(
+ user_id, devices, stream_id,
+ )
+ device_ids = [device["device_id"] for device in devices]
+ yield self.notify_device_update(user_id, device_ids)
+ else:
+ content = dict(edu_content)
+ for key in ("user_id", "device_id", "stream_id", "prev_ids"):
+ content.pop(key, None)
+ yield self.store.update_remote_device_list_cache_entry(
+ user_id, device_id, content, stream_id,
+ )
+ yield self.notify_device_update(user_id, [device_id])
+
+ @defer.inlineCallbacks
+ def on_federation_query_user_devices(self, user_id):
+ stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id)
+ defer.returnValue({
+ "user_id": user_id,
+ "stream_id": stream_id,
+ "devices": devices,
+ })
+
def _update_device_from_client_ips(device, client_ips):
ip = client_ips.get((device["user_id"], device["device_id"]), {})
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 38c2a2d39e..832998a6d3 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -73,8 +73,7 @@ class E2eKeysHandler(object):
if self.is_mine_id(user_id):
local_query[user_id] = device_ids
else:
- domain = get_domain_from_id(user_id)
- remote_queries.setdefault(domain, {})[user_id] = device_ids
+ remote_queries[user_id] = device_ids
# do the queries
failures = {}
@@ -85,9 +84,40 @@ class E2eKeysHandler(object):
if user_id in local_query:
results[user_id] = keys
+ remote_queries_not_in_cache = {}
+ if remote_queries:
+ query_list = []
+ for user_id, device_ids in remote_queries.iteritems():
+ if device_ids:
+ query_list.extend((user_id, device_id) for device_id in device_ids)
+ else:
+ query_list.append((user_id, None))
+
+ user_ids_not_in_cache, remote_results = (
+ yield self.store.get_user_devices_from_cache(
+ query_list
+ )
+ )
+ for user_id, devices in remote_results.iteritems():
+ user_devices = results.setdefault(user_id, {})
+ for device_id, device in devices.iteritems():
+ keys = device.get("keys", None)
+ device_display_name = device.get("device_display_name", None)
+ if keys:
+ result = dict(keys)
+ unsigned = result.setdefault("unsigned", {})
+ if device_display_name:
+ unsigned["device_display_name"] = device_display_name
+ user_devices[device_id] = result
+
+ for user_id in user_ids_not_in_cache:
+ domain = get_domain_from_id(user_id)
+ r = remote_queries_not_in_cache.setdefault(domain, {})
+ r[user_id] = remote_queries[user_id]
+
@defer.inlineCallbacks
def do_remote_query(destination):
- destination_query = remote_queries[destination]
+ destination_query = remote_queries_not_in_cache[destination]
try:
limiter = yield get_retry_limiter(
destination, self.clock, self.store
@@ -119,7 +149,7 @@ class E2eKeysHandler(object):
yield preserve_context_over_deferred(defer.gatherResults([
preserve_fn(do_remote_query)(destination)
- for destination in remote_queries
+ for destination in remote_queries_not_in_cache
]))
defer.returnValue({
@@ -259,7 +289,7 @@ class E2eKeysHandler(object):
user_id, device_id, time_now,
encode_canonical_json(device_keys)
)
- yield self.device_handler.notify_device_update(user_id, device_id)
+ yield self.device_handler.notify_device_update(user_id, [device_id])
one_time_keys = keys.get("one_time_keys", None)
if one_time_keys:
diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py
index 9628e2ff75..8ee3119db2 100644
--- a/synapse/storage/devices.py
+++ b/synapse/storage/devices.py
@@ -138,6 +138,89 @@ class DeviceStore(SQLBaseStore):
defer.returnValue({d["device_id"]: d for d in devices})
+ def get_device_list_remote_extremity(self, user_id):
+ return self._simple_select_one_onecol(
+ table="device_lists_remote_extremeties",
+ keyvalues={"user_id": user_id},
+ retcol="stream_id",
+ desc="get_device_list_remote_extremity",
+ allow_none=True,
+ )
+
+ def update_remote_device_list_cache_entry(self, user_id, device_id, content,
+ stream_id):
+ return self.runInteraction(
+ "update_remote_device_list_cache_entry",
+ self._update_remote_device_list_cache_entry_txn,
+ user_id, device_id, content, stream_id,
+ )
+
+ def _update_remote_device_list_cache_entry_txn(self, txn, user_id, device_id,
+ content, stream_id):
+ self._simple_upsert_txn(
+ txn,
+ table="device_lists_remote_cache",
+ keyvalues={
+ "user_id": user_id,
+ "device_id": device_id,
+ },
+ values={
+ "content": json.dumps(content),
+ }
+ )
+
+ self._simple_upsert_txn(
+ txn,
+ table="device_lists_remote_extremeties",
+ keyvalues={
+ "user_id": user_id,
+ },
+ values={
+ "stream_id": stream_id,
+ }
+ )
+
+ def update_remote_device_list_cache(self, user_id, devices, stream_id):
+ return self.runInteraction(
+ "update_remote_device_list_cache",
+ self._update_remote_device_list_cache_txn,
+ user_id, devices, stream_id,
+ )
+
+ def _update_remote_device_list_cache_txn(self, txn, user_id, devices,
+ stream_id):
+ self._simple_delete_txn(
+ txn,
+ table="device_lists_remote_cache",
+ keyvalues={
+ "user_id": user_id,
+ },
+ )
+
+ self._simple_insert_many_txn(
+ txn,
+ table="device_lists_remote_cache",
+ values=[
+ {
+ "user_id": user_id,
+ "device_id": content["device_id"],
+ "content": json.dumps(content),
+ }
+ for content in devices
+ ]
+ )
+
+ self._simple_upsert_txn(
+ txn,
+ table="device_lists_remote_extremeties",
+ keyvalues={
+ "user_id": user_id,
+ },
+ values={
+ "stream_id": stream_id,
+ }
+ )
+
def get_devices_by_remote(self, destination, from_stream_id):
now_stream_id = self._device_list_id_gen.get_current_token()
@@ -184,7 +267,7 @@ class DeviceStore(SQLBaseStore):
txn.execute(prev_sent_id_sql, (destination, user_id, True))
rows = txn.fetchall()
prev_id = rows[0][0]
- for device_id, result in user_devices.iteritems():
+ for device_id, device in user_devices.iteritems():
stream_id = query_map[(user_id, device_id)]
result = {
"user_id": user_id,
@@ -195,10 +278,10 @@ class DeviceStore(SQLBaseStore):
prev_id = stream_id
- key_json = result.get("key_json", None)
+ key_json = device.get("key_json", None)
if key_json:
result["keys"] = json.loads(key_json)
- device_display_name = result.get("device_display_name", None)
+ device_display_name = device.get("device_display_name", None)
if device_display_name:
result["device_display_name"] = device_display_name
@@ -206,6 +289,96 @@ class DeviceStore(SQLBaseStore):
return (now_stream_id, results)
+ def get_user_devices_from_cache(self, query_list):
+ return self.runInteraction(
+ "get_user_devices_from_cache", self._get_user_devices_from_cache_txn,
+ query_list,
+ )
+
+ def _get_user_devices_from_cache_txn(self, txn, query_list):
+ user_ids = {user_id for user_id, _ in query_list}
+
+ user_ids_in_cache = set()
+ for user_id in user_ids:
+ stream_ids = self._simple_select_onecol_txn(
+ txn,
+ table="device_lists_remote_extremeties",
+ keyvalues={
+ "user_id": user_id,
+ },
+ retcol="stream_id",
+ )
+ if stream_ids:
+ user_ids_in_cache.add(user_id)
+
+ user_ids_not_in_cache = user_ids - user_ids_in_cache
+
+ results = {}
+ for user_id, device_id in query_list:
+ if user_id not in user_ids_in_cache:
+ continue
+
+ if device_id:
+ content = self._simple_select_one_onecol_txn(
+ txn,
+ table="device_lists_remote_cache",
+ keyvalues={
+ "user_id": user_id,
+ "device_id": device_id,
+ },
+ retcol="content",
+ )
+ results.setdefault(user_id, {})[device_id] = json.loads(content)
+ else:
+ devices = self._simple_select_list_txn(
+ txn,
+ table="device_lists_remote_cache",
+ keyvalues={
+ "user_id": user_id,
+ },
+ retcols=("device_id", "content"),
+ )
+ results[user_id] = {
+ device["device_id"]: json.loads(device["content"])
+ for device in devices
+ }
+ user_ids_in_cache.discard(user_id)
+
+ return user_ids_not_in_cache, results
+
+ def get_devices_with_keys_by_user(self, user_id):
+ return self.runInteraction(
+ "get_devices_with_keys_by_user",
+ self._get_devices_with_keys_by_user_txn, user_id,
+ )
+
+ def _get_devices_with_keys_by_user_txn(self, txn, user_id):
+ now_stream_id = self._device_list_id_gen.get_current_token()
+
+ devices = self._get_e2e_device_keys_txn(
+ txn, [(user_id, None)], include_all_devices=True
+ )
+
+ for user_id, user_devices in devices.iteritems():
+ results = []
+ for device_id, device in user_devices.iteritems():
+ result = {
+ "device_id": device_id,
+ }
+
+ key_json = device.get("key_json", None)
+ if key_json:
+ result["keys"] = json.loads(key_json)
+ device_display_name = device.get("device_display_name", None)
+ if device_display_name:
+ result["device_display_name"] = device_display_name
+
+ results.append(result)
+
+ return now_stream_id, results
+
+ return now_stream_id, []
+
def mark_as_sent_devices_by_remote(self, destination, stream_id):
return self.runInteraction(
"mark_as_sent_devices_by_remote", self._mark_as_sent_devices_by_remote_txn,
@@ -242,17 +415,17 @@ class DeviceStore(SQLBaseStore):
defer.returnValue(set(row["user_id"] for row in rows))
@defer.inlineCallbacks
- def add_device_change_to_streams(self, user_id, device_id, hosts):
+ def add_device_change_to_streams(self, user_id, device_ids, hosts):
# device_lists_stream
# device_lists_outbound_pokes
with self._device_list_id_gen.get_next() as stream_id:
yield self.runInteraction(
"add_device_change_to_streams", self._add_device_change_txn,
- user_id, device_id, hosts, stream_id,
+ user_id, device_ids, hosts, stream_id,
)
defer.returnValue(stream_id)
- def _add_device_change_txn(self, txn, user_id, device_id, hosts, stream_id):
+ def _add_device_change_txn(self, txn, user_id, device_ids, hosts, stream_id):
txn.call_after(
self._device_list_stream_cache.entity_has_changed,
user_id, stream_id,
@@ -263,14 +436,17 @@ class DeviceStore(SQLBaseStore):
host, stream_id,
)
- self._simple_insert_txn(
+ self._simple_insert_many_txn(
txn,
table="device_lists_stream",
- values={
- "stream_id": stream_id,
- "user_id": user_id,
- "device_id": device_id,
- }
+ values=[
+ {
+ "stream_id": stream_id,
+ "user_id": user_id,
+ "device_id": device_id,
+ }
+ for device_id in device_ids
+ ]
)
self._simple_insert_many_txn(
@@ -285,6 +461,7 @@ class DeviceStore(SQLBaseStore):
"sent": False,
}
for destination in hosts
+ for device_id in device_ids
]
)
diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py
index f82943a7a8..a915c790ff 100644
--- a/synapse/storage/end_to_end_keys.py
+++ b/synapse/storage/end_to_end_keys.py
@@ -52,11 +52,11 @@ class EndToEndKeyStore(SQLBaseStore):
query_params = []
for (user_id, device_id) in query_list:
- query_clause = "k.user_id = ?"
+ query_clause = "user_id = ?"
query_params.append(user_id)
if device_id:
- query_clause += " AND k.device_id = ?"
+ query_clause += " AND device_id = ?"
query_params.append(device_id)
query_clauses.append(query_clause)
diff --git a/synapse/storage/schema/delta/40/device_list_streams.sql b/synapse/storage/schema/delta/40/device_list_streams.sql
index 61cac63bbb..d1051c6ddf 100644
--- a/synapse/storage/schema/delta/40/device_list_streams.sql
+++ b/synapse/storage/schema/delta/40/device_list_streams.sql
@@ -13,18 +13,6 @@
* limitations under the License.
*/
-CREATE TABLE device_list_streams_remote (
- list_id TEXT NOT NULL,
- origin TEXT NOT NULL,
- user_id TEXT NOT NULL,
- is_full BOOLEAN NOT NULL,
- ts BIGINT NOT NULL
-);
-
-CREATE INDEX device_list_streams_remote_id_origin ON device_list_streams_remote(
- origin, list_id, user_id
-);
-
CREATE TABLE device_lists_remote_cache (
user_id TEXT NOT NULL,
@@ -35,6 +23,14 @@ CREATE TABLE device_lists_remote_cache (
CREATE INDEX device_lists_remote_cache_id ON device_lists_remote_cache(user_id, device_id);
+CREATE TABLE device_lists_remote_extremeties (
+ user_id TEXT NOT NULL,
+ stream_id TEXT NOT NULL
+);
+
+CREATE INDEX device_lists_remote_extremeties_id ON device_lists_remote_extremeties(user_id, stream_id);
+
+
CREATE TABLE device_lists_stream (
stream_id BIGINT NOT NULL,
user_id TEXT NOT NULL,
diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index 85a970a6c9..2eaaa8253c 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -35,51 +35,51 @@ class DeviceTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
- hs = yield utils.setup_test_homeserver(handlers=None)
- self.handler = synapse.handlers.device.DeviceHandler(hs)
+ hs = yield utils.setup_test_homeserver()
+ self.handler = hs.get_device_handler()
self.store = hs.get_datastore()
self.clock = hs.get_clock()
@defer.inlineCallbacks
def test_device_is_created_if_doesnt_exist(self):
res = yield self.handler.check_device_registered(
- user_id="boris",
+ user_id="@boris:foo",
device_id="fco",
initial_device_display_name="display name"
)
self.assertEqual(res, "fco")
- dev = yield self.handler.store.get_device("boris", "fco")
+ dev = yield self.handler.store.get_device("@boris:foo", "fco")
self.assertEqual(dev["display_name"], "display name")
@defer.inlineCallbacks
def test_device_is_preserved_if_exists(self):
res1 = yield self.handler.check_device_registered(
- user_id="boris",
+ user_id="@boris:foo",
device_id="fco",
initial_device_display_name="display name"
)
self.assertEqual(res1, "fco")
res2 = yield self.handler.check_device_registered(
- user_id="boris",
+ user_id="@boris:foo",
device_id="fco",
initial_device_display_name="new display name"
)
self.assertEqual(res2, "fco")
- dev = yield self.handler.store.get_device("boris", "fco")
+ dev = yield self.handler.store.get_device("@boris:foo", "fco")
self.assertEqual(dev["display_name"], "display name")
@defer.inlineCallbacks
def test_device_id_is_made_up_if_unspecified(self):
device_id = yield self.handler.check_device_registered(
- user_id="theresa",
+ user_id="@theresa:foo",
device_id=None,
initial_device_display_name="display"
)
- dev = yield self.handler.store.get_device("theresa", device_id)
+ dev = yield self.handler.store.get_device("@theresa:foo", device_id)
self.assertEqual(dev["display_name"], "display")
@defer.inlineCallbacks
diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index 5d602c1531..ceb9aa5765 100644
--- a/tests/handlers/test_directory.py
+++ b/tests/handlers/test_directory.py
@@ -37,6 +37,7 @@ class DirectoryTestCase(unittest.TestCase):
def setUp(self):
self.mock_federation = Mock(spec=[
"make_query",
+ "register_edu_handler",
])
self.query_handlers = {}
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index f1f664275f..979cebf600 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -39,6 +39,7 @@ class ProfileTestCase(unittest.TestCase):
def setUp(self):
self.mock_federation = Mock(spec=[
"make_query",
+ "register_edu_handler",
])
self.query_handlers = {}
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index 9ff1abcd80..9e98d0e330 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -39,7 +39,11 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
event_cache_size=1,
password_providers=[],
)
- hs = yield setup_test_homeserver(config=config, federation_sender=Mock())
+ hs = yield setup_test_homeserver(
+ config=config,
+ federation_sender=Mock(),
+ replication_layer=Mock(),
+ )
self.as_token = "token1"
self.as_url = "some_url"
@@ -112,7 +116,11 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
event_cache_size=1,
password_providers=[],
)
- hs = yield setup_test_homeserver(config=config, federation_sender=Mock())
+ hs = yield setup_test_homeserver(
+ config=config,
+ federation_sender=Mock(),
+ replication_layer=Mock(),
+ )
self.db_pool = hs.get_db_pool()
self.as_list = [
@@ -446,7 +454,8 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
hs = yield setup_test_homeserver(
config=config,
datastore=Mock(),
- federation_sender=Mock()
+ federation_sender=Mock(),
+ replication_layer=Mock(),
)
ApplicationServiceStore(hs)
@@ -463,7 +472,8 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
hs = yield setup_test_homeserver(
config=config,
datastore=Mock(),
- federation_sender=Mock()
+ federation_sender=Mock(),
+ replication_layer=Mock(),
)
with self.assertRaises(ConfigError) as cm:
@@ -486,7 +496,8 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
hs = yield setup_test_homeserver(
config=config,
datastore=Mock(),
- federation_sender=Mock()
+ federation_sender=Mock(),
+ replication_layer=Mock(),
)
with self.assertRaises(ConfigError) as cm:
|