diff --git a/changelog.d/5727.feature b/changelog.d/5727.feature
new file mode 100644
index 0000000000..819bebf2d7
--- /dev/null
+++ b/changelog.d/5727.feature
@@ -0,0 +1 @@
+Add federation support for cross-signing.
diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index cc75c39476..d5d4a60c88 100644
--- a/synapse/federation/sender/per_destination_queue.py
+++ b/synapse/federation/sender/per_destination_queue.py
@@ -366,10 +366,10 @@ class PerDestinationQueue(object):
Edu(
origin=self._server_name,
destination=self._destination,
- edu_type="m.device_list_update",
+ edu_type=edu_type,
content=content,
)
- for content in results
+ for (edu_type, content) in results
]
assert len(edus) <= limit, "get_devices_by_remote returned too many EDUs"
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 5f23ee4488..fd8d14b680 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -458,7 +458,18 @@ class DeviceHandler(DeviceWorkerHandler):
@defer.inlineCallbacks
def on_federation_query_user_devices(self, user_id):
stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id)
- return {"user_id": user_id, "stream_id": stream_id, "devices": devices}
+ master_key = yield self.store.get_e2e_cross_signing_key(user_id, "master")
+ self_signing_key = yield self.store.get_e2e_cross_signing_key(
+ user_id, "self_signing"
+ )
+
+ return {
+ "user_id": user_id,
+ "stream_id": stream_id,
+ "devices": devices,
+ "master_key": master_key,
+ "self_signing_key": self_signing_key,
+ }
@defer.inlineCallbacks
def user_left_room(self, user, room_id):
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 5ea54f60be..4ab75a351e 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -36,6 +36,8 @@ from synapse.types import (
get_verify_key_from_cross_signing_key,
)
from synapse.util import unwrapFirstError
+from synapse.util.async_helpers import Linearizer
+from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.retryutils import NotRetryingDestination
logger = logging.getLogger(__name__)
@@ -49,10 +51,19 @@ class E2eKeysHandler(object):
self.is_mine = hs.is_mine
self.clock = hs.get_clock()
+ self._edu_updater = SigningKeyEduUpdater(hs, self)
+
+ federation_registry = hs.get_federation_registry()
+
+ # FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec
+ federation_registry.register_edu_handler(
+ "org.matrix.signing_key_update",
+ self._edu_updater.incoming_signing_key_update,
+ )
# doesn't really work as part of the generic query API, because the
# query request requires an object POST, but we abuse the
# "query handler" interface.
- hs.get_federation_registry().register_query_handler(
+ federation_registry.register_query_handler(
"client_keys", self.on_federation_query_client_keys
)
@@ -207,13 +218,15 @@ class E2eKeysHandler(object):
if user_id in destination_query:
results[user_id] = keys
- for user_id, key in remote_result["master_keys"].items():
- if user_id in destination_query:
- cross_signing_keys["master_keys"][user_id] = key
+ if "master_keys" in remote_result:
+ for user_id, key in remote_result["master_keys"].items():
+ if user_id in destination_query:
+ cross_signing_keys["master_keys"][user_id] = key
- for user_id, key in remote_result["self_signing_keys"].items():
- if user_id in destination_query:
- cross_signing_keys["self_signing_keys"][user_id] = key
+ if "self_signing_keys" in remote_result:
+ for user_id, key in remote_result["self_signing_keys"].items():
+ if user_id in destination_query:
+ cross_signing_keys["self_signing_keys"][user_id] = key
except Exception as e:
failure = _exception_to_failure(e)
@@ -251,7 +264,7 @@ class E2eKeysHandler(object):
Returns:
defer.Deferred[dict[str, dict[str, dict]]]: map from
- (master|self_signing|user_signing) -> user_id -> key
+ (master_keys|self_signing_keys|user_signing_keys) -> user_id -> key
"""
master_keys = {}
self_signing_keys = {}
@@ -343,7 +356,16 @@ class E2eKeysHandler(object):
"""
device_keys_query = query_body.get("device_keys", {})
res = yield self.query_local_devices(device_keys_query)
- return {"device_keys": res}
+ ret = {"device_keys": res}
+
+ # add in the cross-signing keys
+ cross_signing_keys = yield self.get_cross_signing_keys_from_cache(
+ device_keys_query, None
+ )
+
+ ret.update(cross_signing_keys)
+
+ return ret
@trace
@defer.inlineCallbacks
@@ -1047,3 +1069,100 @@ class SignatureListItem:
target_user_id = attr.ib()
target_device_id = attr.ib()
signature = attr.ib()
+
+
+class SigningKeyEduUpdater(object):
+ "Handles incoming signing key updates from federation and updates the DB"
+
+ def __init__(self, hs, e2e_keys_handler):
+ self.store = hs.get_datastore()
+ self.federation = hs.get_federation_client()
+ self.clock = hs.get_clock()
+ self.e2e_keys_handler = e2e_keys_handler
+
+ self._remote_edu_linearizer = Linearizer(name="remote_signing_key")
+
+ # user_id -> list of updates waiting to be handled.
+ self._pending_updates = {}
+
+ # Recently seen stream ids. We don't bother keeping these in the DB,
+ # but they're useful to have them about to reduce the number of spurious
+ # resyncs.
+ self._seen_updates = ExpiringCache(
+ cache_name="signing_key_update_edu",
+ clock=self.clock,
+ max_len=10000,
+ expiry_ms=30 * 60 * 1000,
+ iterable=True,
+ )
+
+ @defer.inlineCallbacks
+ def incoming_signing_key_update(self, origin, edu_content):
+ """Called on incoming signing key update from federation. Responsible for
+ parsing the EDU and adding to pending updates list.
+
+ Args:
+ origin (string): the server that sent the EDU
+ edu_content (dict): the contents of the EDU
+ """
+
+ user_id = edu_content.pop("user_id")
+ master_key = edu_content.pop("master_key", None)
+ self_signing_key = edu_content.pop("self_signing_key", None)
+
+ if get_domain_from_id(user_id) != origin:
+ # TODO: Raise?
+ logger.warning("Got signing key update edu for %r from %r", user_id, origin)
+ return
+
+ room_ids = yield self.store.get_rooms_for_user(user_id)
+ if not room_ids:
+ # We don't share any rooms with this user. Ignore update, as we
+ # probably won't get any further updates.
+ return
+
+ self._pending_updates.setdefault(user_id, []).append(
+ (master_key, self_signing_key, edu_content)
+ )
+
+ yield self._handle_signing_key_updates(user_id)
+
+ @defer.inlineCallbacks
+ def _handle_signing_key_updates(self, user_id):
+ """Actually handle pending updates.
+
+ Args:
+ user_id (string): the user whose updates we are processing
+ """
+
+ device_handler = self.e2e_keys_handler.device_handler
+
+ with (yield self._remote_edu_linearizer.queue(user_id)):
+ pending_updates = self._pending_updates.pop(user_id, [])
+ if not pending_updates:
+ # This can happen since we batch updates
+ return
+
+ device_ids = []
+
+ logger.info("pending updates: %r", pending_updates)
+
+ for master_key, self_signing_key, edu_content in pending_updates:
+ if master_key:
+ yield self.store.set_e2e_cross_signing_key(
+ user_id, "master", master_key
+ )
+ device_id = get_verify_key_from_cross_signing_key(master_key)[
+ 1
+ ].version
+ device_ids.append(device_id)
+ if self_signing_key:
+ yield self.store.set_e2e_cross_signing_key(
+ user_id, "self_signing", self_signing_key
+ )
+ device_id = get_verify_key_from_cross_signing_key(self_signing_key)[
+ 1
+ ].version
+ device_ids.append(device_id)
+
+ yield device_handler.notify_device_update(user_id, device_ids)
diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py
index f7a3542348..6ac165068e 100644
--- a/synapse/storage/data_stores/main/devices.py
+++ b/synapse/storage/data_stores/main/devices.py
@@ -37,6 +37,7 @@ from synapse.storage._base import (
make_in_list_sql_clause,
)
from synapse.storage.background_updates import BackgroundUpdateStore
+from synapse.types import get_verify_key_from_cross_signing_key
from synapse.util import batch_iter
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
@@ -94,9 +95,10 @@ class DeviceWorkerStore(SQLBaseStore):
"""Get stream of updates to send to remote servers
Returns:
- Deferred[tuple[int, list[dict]]]:
+ Deferred[tuple[int, list[tuple[string,dict]]]]:
current stream id (ie, the stream id of the last update included in the
- response), and the list of updates
+ response), and the list of updates, where each update is a pair of EDU
+ type and EDU contents
"""
now_stream_id = self._device_list_id_gen.get_current_token()
@@ -129,6 +131,33 @@ class DeviceWorkerStore(SQLBaseStore):
if not updates:
return now_stream_id, []
+ # get the cross-signing keys of the users the list
+ users = set(r[0] for r in updates)
+ master_key_by_user = {}
+ self_signing_key_by_user = {}
+ for user in users:
+ cross_signing_key = yield self.get_e2e_cross_signing_key(user, "master")
+ if cross_signing_key:
+ key_id, verify_key = get_verify_key_from_cross_signing_key(
+ cross_signing_key
+ )
+ master_key_by_user[user] = {
+ "key_info": cross_signing_key,
+ "pubkey": verify_key.version,
+ }
+
+ cross_signing_key = yield self.get_e2e_cross_signing_key(
+ user, "self_signing"
+ )
+ if cross_signing_key:
+ key_id, verify_key = get_verify_key_from_cross_signing_key(
+ cross_signing_key
+ )
+ self_signing_key_by_user[user] = {
+ "key_info": cross_signing_key,
+ "pubkey": verify_key.version,
+ }
+
# if we have exceeded the limit, we need to exclude any results with the
# same stream_id as the last row.
if len(updates) > limit:
@@ -158,6 +187,16 @@ class DeviceWorkerStore(SQLBaseStore):
# Stop processing updates
break
+ # skip over cross-signing keys
+ if (
+ update[0] in master_key_by_user
+ and update[1] == master_key_by_user[update[0]]["pubkey"]
+ ) or (
+ update[0] in master_key_by_user
+ and update[1] == self_signing_key_by_user[update[0]]["pubkey"]
+ ):
+ continue
+
key = (update[0], update[1])
update_context = update[3]
@@ -172,16 +211,40 @@ class DeviceWorkerStore(SQLBaseStore):
# means that there are more than limit updates all of which have the same
# steam_id.
+ # figure out which cross-signing keys were changed by intersecting the
+ # update list with the master/self-signing key by user maps
+ cross_signing_keys_by_user = {}
+ for user_id, device_id, stream, _opentracing_context in updates:
+ if device_id == master_key_by_user.get(user_id, {}).get("pubkey", None):
+ result = cross_signing_keys_by_user.setdefault(user_id, {})
+ result["master_key"] = master_key_by_user[user_id]["key_info"]
+ elif device_id == self_signing_key_by_user.get(user_id, {}).get(
+ "pubkey", None
+ ):
+ result = cross_signing_keys_by_user.setdefault(user_id, {})
+ result["self_signing_key"] = self_signing_key_by_user[user_id][
+ "key_info"
+ ]
+
+ cross_signing_results = []
+
+ # add the updated cross-signing keys to the results list
+ for user_id, result in iteritems(cross_signing_keys_by_user):
+ result["user_id"] = user_id
+ # FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec
+ cross_signing_results.append(("org.matrix.signing_key_update", result))
+
# That should only happen if a client is spamming the server with new
# devices, in which case E2E isn't going to work well anyway. We'll just
# skip that stream_id and return an empty list, and continue with the next
# stream_id next time.
- if not query_map:
+ if not query_map and not cross_signing_results:
return stream_id_cutoff, []
results = yield self._get_device_update_edus_by_remote(
destination, from_stream_id, query_map
)
+ results.extend(cross_signing_results)
return now_stream_id, results
@@ -200,6 +263,7 @@ class DeviceWorkerStore(SQLBaseStore):
Returns:
List: List of device updates
"""
+ # get the list of device updates that need to be sent
sql = """
SELECT user_id, device_id, stream_id, opentracing_context FROM device_lists_outbound_pokes
WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
@@ -225,12 +289,16 @@ class DeviceWorkerStore(SQLBaseStore):
List[Dict]: List of objects representing an device update EDU
"""
- devices = yield self.runInteraction(
- "_get_e2e_device_keys_txn",
- self._get_e2e_device_keys_txn,
- query_map.keys(),
- include_all_devices=True,
- include_deleted_devices=True,
+ devices = (
+ yield self.runInteraction(
+ "_get_e2e_device_keys_txn",
+ self._get_e2e_device_keys_txn,
+ query_map.keys(),
+ include_all_devices=True,
+ include_deleted_devices=True,
+ )
+ if query_map
+ else {}
)
results = []
@@ -262,7 +330,7 @@ class DeviceWorkerStore(SQLBaseStore):
else:
result["deleted"] = True
- results.append(result)
+ results.append(("m.device_list_update", result))
return results
diff --git a/synapse/storage/data_stores/main/end_to_end_keys.py b/synapse/storage/data_stores/main/end_to_end_keys.py
index f5c3ed9dc2..a0bc6f2d18 100644
--- a/synapse/storage/data_stores/main/end_to_end_keys.py
+++ b/synapse/storage/data_stores/main/end_to_end_keys.py
@@ -248,6 +248,73 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
return self.runInteraction("count_e2e_one_time_keys", _count_e2e_one_time_keys)
+ def _get_e2e_cross_signing_key_txn(self, txn, user_id, key_type, from_user_id=None):
+ """Returns a user's cross-signing key.
+
+ Args:
+ txn (twisted.enterprise.adbapi.Connection): db connection
+ user_id (str): the user whose key is being requested
+ key_type (str): the type of key that is being set: either 'master'
+ for a master key, 'self_signing' for a self-signing key, or
+ 'user_signing' for a user-signing key
+ from_user_id (str): if specified, signatures made by this user on
+ the key will be included in the result
+
+ Returns:
+ dict of the key data or None if not found
+ """
+ sql = (
+ "SELECT keydata "
+ " FROM e2e_cross_signing_keys "
+ " WHERE user_id = ? AND keytype = ? ORDER BY stream_id DESC LIMIT 1"
+ )
+ txn.execute(sql, (user_id, key_type))
+ row = txn.fetchone()
+ if not row:
+ return None
+ key = json.loads(row[0])
+
+ device_id = None
+ for k in key["keys"].values():
+ device_id = k
+
+ if from_user_id is not None:
+ sql = (
+ "SELECT key_id, signature "
+ " FROM e2e_cross_signing_signatures "
+ " WHERE user_id = ? "
+ " AND target_user_id = ? "
+ " AND target_device_id = ? "
+ )
+ txn.execute(sql, (from_user_id, user_id, device_id))
+ row = txn.fetchone()
+ if row:
+ key.setdefault("signatures", {}).setdefault(from_user_id, {})[
+ row[0]
+ ] = row[1]
+
+ return key
+
+ def get_e2e_cross_signing_key(self, user_id, key_type, from_user_id=None):
+ """Returns a user's cross-signing key.
+
+ Args:
+ user_id (str): the user whose self-signing key is being requested
+ key_type (str): the type of cross-signing key to get
+ from_user_id (str): if specified, signatures made by this user on
+ the self-signing key will be included in the result
+
+ Returns:
+ dict of the key data or None if not found
+ """
+ return self.runInteraction(
+ "get_e2e_cross_signing_key",
+ self._get_e2e_cross_signing_key_txn,
+ user_id,
+ key_type,
+ from_user_id,
+ )
+
class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys):
@@ -426,73 +493,6 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
key,
)
- def _get_e2e_cross_signing_key_txn(self, txn, user_id, key_type, from_user_id=None):
- """Returns a user's cross-signing key.
-
- Args:
- txn (twisted.enterprise.adbapi.Connection): db connection
- user_id (str): the user whose key is being requested
- key_type (str): the type of key that is being set: either 'master'
- for a master key, 'self_signing' for a self-signing key, or
- 'user_signing' for a user-signing key
- from_user_id (str): if specified, signatures made by this user on
- the key will be included in the result
-
- Returns:
- dict of the key data or None if not found
- """
- sql = (
- "SELECT keydata "
- " FROM e2e_cross_signing_keys "
- " WHERE user_id = ? AND keytype = ? ORDER BY stream_id DESC LIMIT 1"
- )
- txn.execute(sql, (user_id, key_type))
- row = txn.fetchone()
- if not row:
- return None
- key = json.loads(row[0])
-
- device_id = None
- for k in key["keys"].values():
- device_id = k
-
- if from_user_id is not None:
- sql = (
- "SELECT key_id, signature "
- " FROM e2e_cross_signing_signatures "
- " WHERE user_id = ? "
- " AND target_user_id = ? "
- " AND target_device_id = ? "
- )
- txn.execute(sql, (from_user_id, user_id, device_id))
- row = txn.fetchone()
- if row:
- key.setdefault("signatures", {}).setdefault(from_user_id, {})[
- row[0]
- ] = row[1]
-
- return key
-
- def get_e2e_cross_signing_key(self, user_id, key_type, from_user_id=None):
- """Returns a user's cross-signing key.
-
- Args:
- user_id (str): the user whose self-signing key is being requested
- key_type (str): the type of cross-signing key to get
- from_user_id (str): if specified, signatures made by this user on
- the self-signing key will be included in the result
-
- Returns:
- dict of the key data or None if not found
- """
- return self.runInteraction(
- "get_e2e_cross_signing_key",
- self._get_e2e_cross_signing_key_txn,
- user_id,
- key_type,
- from_user_id,
- )
-
def store_e2e_cross_signing_signatures(self, user_id, signatures):
"""Stores cross-signing signatures.
diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index 3cc18f9f1c..039cc79357 100644
--- a/tests/storage/test_devices.py
+++ b/tests/storage/test_devices.py
@@ -137,7 +137,9 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
"""Check that an specific device ids exist in a list of device update EDUs"""
self.assertEqual(len(device_updates), len(expected_device_ids))
- received_device_ids = {update["device_id"] for update in device_updates}
+ received_device_ids = {
+ update["device_id"] for edu_type, update in device_updates
+ }
self.assertEqual(received_device_ids, set(expected_device_ids))
@defer.inlineCallbacks
|