From 1b1aed38e3df1daca709b9bb57b1443b9e6da2bd Mon Sep 17 00:00:00 2001 From: David Robertson Date: Mon, 10 Jan 2022 13:40:46 +0000 Subject: Deal with mypy errors w/ type-hinted pynacl 1.5.0 (#11714) * Deal with mypy errors w/ type-hinted pynacl 1.5.0 Fixes #11644. I really don't like that we're monkey patching pynacl SignedKey instances with alg and version objects. But I'm too scared to make the changes necessary right now. (Ideally I would replace `signedjson.types.SingingKey` with a runtime class which wraps or inherits from `nacl.signing.SigningKey`.) C.f. https://github.com/matrix-org/python-signedjson/issues/16 --- tests/crypto/test_event_signing.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/crypto/test_event_signing.py b/tests/crypto/test_event_signing.py index 1c920157f5..a72a0103d3 100644 --- a/tests/crypto/test_event_signing.py +++ b/tests/crypto/test_event_signing.py @@ -14,6 +14,7 @@ import nacl.signing +import signedjson.types from unpaddedbase64 import decode_base64 from synapse.api.room_versions import RoomVersions @@ -35,7 +36,12 @@ HOSTNAME = "domain" class EventSigningTestCase(unittest.TestCase): def setUp(self): - self.signing_key = nacl.signing.SigningKey(SIGNING_KEY_SEED) + # NB: `signedjson` expects `nacl.signing.SigningKey` instances which have been + # monkeypatched to include new `alg` and `version` attributes. This is captured + # by the `signedjson.types.SigningKey` protocol. + self.signing_key: signedjson.types.SigningKey = nacl.signing.SigningKey( + SIGNING_KEY_SEED + ) self.signing_key.alg = KEY_ALG self.signing_key.version = KEY_VER -- cgit 1.4.1 From 22abfca8d9e8c486d9cf4624e8422a84cc361c83 Mon Sep 17 00:00:00 2001 From: reivilibre Date: Wed, 12 Jan 2022 15:21:13 +0000 Subject: Fix a bug introduced in Synapse v1.0.0 whereby device list updates would not be sent to remote homeservers if there were too many to send at once. (#11729) Co-authored-by: Brendan Abolivier --- changelog.d/11729.bugfix | 1 + synapse/storage/databases/main/devices.py | 8 ++++- tests/storage/test_devices.py | 50 ++++++++++++++++++++++++++++++- 3 files changed, 57 insertions(+), 2 deletions(-) create mode 100644 changelog.d/11729.bugfix (limited to 'tests') diff --git a/changelog.d/11729.bugfix b/changelog.d/11729.bugfix new file mode 100644 index 0000000000..8438ce5686 --- /dev/null +++ b/changelog.d/11729.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse v1.0.0 whereby some device list updates would not be sent to remote homeservers if there were too many to send at once. \ No newline at end of file diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 273adb61fd..324bd5f879 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -270,6 +270,10 @@ class DeviceWorkerStore(SQLBaseStore): # The most recent request's opentracing_context is used as the # context which created the Edu. + # This is the stream ID that we will return for the consumer to resume + # following this stream later. + last_processed_stream_id = from_stream_id + query_map = {} cross_signing_keys_by_user = {} for user_id, device_id, update_stream_id, update_context in updates: @@ -295,6 +299,8 @@ class DeviceWorkerStore(SQLBaseStore): if update_stream_id > previous_update_stream_id: query_map[key] = (update_stream_id, update_context) + last_processed_stream_id = update_stream_id + results = await self._get_device_update_edus_by_remote( destination, from_stream_id, query_map ) @@ -307,7 +313,7 @@ class DeviceWorkerStore(SQLBaseStore): # FIXME: remove this when enough servers have upgraded results.append(("org.matrix.signing_key_update", result)) - return now_stream_id, results + return last_processed_stream_id, results def _get_device_updates_by_remote_txn( self, diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py index 6790aa5242..5cd7f6bb7a 100644 --- a/tests/storage/test_devices.py +++ b/tests/storage/test_devices.py @@ -94,7 +94,7 @@ class DeviceStoreTestCase(HomeserverTestCase): def test_get_device_updates_by_remote(self): device_ids = ["device_id1", "device_id2"] - # Add two device updates with a single stream_id + # Add two device updates with sequential `stream_id`s self.get_success( self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"]) ) @@ -107,6 +107,54 @@ class DeviceStoreTestCase(HomeserverTestCase): # Check original device_ids are contained within these updates self._check_devices_in_updates(device_ids, device_updates) + def test_get_device_updates_by_remote_can_limit_properly(self): + """ + Tests that `get_device_updates_by_remote` returns an appropriate + stream_id to resume fetching from (without skipping any results). + """ + + # Add some device updates with sequential `stream_id`s + device_ids = [ + "device_id1", + "device_id2", + "device_id3", + "device_id4", + "device_id5", + ] + self.get_success( + self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"]) + ) + + # Get all device updates ever meant for this remote + next_stream_id, device_updates = self.get_success( + self.store.get_device_updates_by_remote("somehost", -1, limit=3) + ) + + # Check the first three original device_ids are contained within these updates + self._check_devices_in_updates(device_ids[:3], device_updates) + + # Get the next batch of device updates + next_stream_id, device_updates = self.get_success( + self.store.get_device_updates_by_remote("somehost", next_stream_id, limit=3) + ) + + # Check the last two original device_ids are contained within these updates + self._check_devices_in_updates(device_ids[3:], device_updates) + + # Add some more device updates to ensure it still resumes properly + device_ids = ["device_id6", "device_id7"] + self.get_success( + self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"]) + ) + + # Get the next batch of device updates + next_stream_id, device_updates = self.get_success( + self.store.get_device_updates_by_remote("somehost", next_stream_id, limit=3) + ) + + # Check the newly-added device_ids are contained within these updates + self._check_devices_in_updates(device_ids, device_updates) + def _check_devices_in_updates(self, expected_device_ids, device_updates): """Check that an specific device ids exist in a list of device update EDUs""" self.assertEqual(len(device_updates), len(expected_device_ids)) -- cgit 1.4.1 From b602ba194bf9d8066be5fabd403f6d60f5b9883a Mon Sep 17 00:00:00 2001 From: reivilibre Date: Thu, 13 Jan 2022 18:12:18 +0000 Subject: Fix a bug introduced in Synapse v1.50.0rc1 whereby outbound federation could fail because too many EDUs were produced for device updates. (#11730) Co-authored-by: David Robertson --- changelog.d/11730.bugfix | 1 + synapse/storage/databases/main/devices.py | 94 ++++++++++++++++++++----- tests/storage/test_devices.py | 112 +++++++++++++++++++++++++++++- 3 files changed, 190 insertions(+), 17 deletions(-) create mode 100644 changelog.d/11730.bugfix (limited to 'tests') diff --git a/changelog.d/11730.bugfix b/changelog.d/11730.bugfix new file mode 100644 index 0000000000..a0bd7dd1a3 --- /dev/null +++ b/changelog.d/11730.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse v1.50.0rc1 whereby outbound federation could fail because too many EDUs were produced for device updates. \ No newline at end of file diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 324bd5f879..bc7e876047 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -191,7 +191,7 @@ class DeviceWorkerStore(SQLBaseStore): @trace async def get_device_updates_by_remote( self, destination: str, from_stream_id: int, limit: int - ) -> Tuple[int, List[Tuple[str, dict]]]: + ) -> Tuple[int, List[Tuple[str, JsonDict]]]: """Get a stream of device updates to send to the given remote server. Args: @@ -200,9 +200,10 @@ class DeviceWorkerStore(SQLBaseStore): limit: Maximum number of device updates to return Returns: - A mapping from the current stream id (ie, the stream id of the last - update included in the response), and the list of updates, where - each update is a pair of EDU type and EDU contents. + - The current stream id (i.e. the stream id of the last update included + in the response); and + - The list of updates, where each update is a pair of EDU type and + EDU contents. """ now_stream_id = self.get_device_stream_token() @@ -221,6 +222,9 @@ class DeviceWorkerStore(SQLBaseStore): limit, ) + # We need to ensure `updates` doesn't grow too big. + # Currently: `len(updates) <= limit`. + # Return an empty list if there are no updates if not updates: return now_stream_id, [] @@ -277,16 +281,43 @@ class DeviceWorkerStore(SQLBaseStore): query_map = {} cross_signing_keys_by_user = {} for user_id, device_id, update_stream_id, update_context in updates: - if ( + # Calculate the remaining length budget. + # Note that, for now, each entry in `cross_signing_keys_by_user` + # gives rise to two device updates in the result, so those cost twice + # as much (and are the whole reason we need to separately calculate + # the budget; we know len(updates) <= limit otherwise!) + # N.B. len() on dicts is cheap since they store their size. + remaining_length_budget = limit - ( + len(query_map) + 2 * len(cross_signing_keys_by_user) + ) + assert remaining_length_budget >= 0 + + is_master_key_update = ( user_id in master_key_by_user and device_id == master_key_by_user[user_id]["device_id"] - ): - result = cross_signing_keys_by_user.setdefault(user_id, {}) - result["master_key"] = master_key_by_user[user_id]["key_info"] - elif ( + ) + is_self_signing_key_update = ( user_id in self_signing_key_by_user and device_id == self_signing_key_by_user[user_id]["device_id"] + ) + + is_cross_signing_key_update = ( + is_master_key_update or is_self_signing_key_update + ) + + if ( + is_cross_signing_key_update + and user_id not in cross_signing_keys_by_user ): + # This will give rise to 2 device updates. + # If we don't have the budget, stop here! + if remaining_length_budget < 2: + break + + if is_master_key_update: + result = cross_signing_keys_by_user.setdefault(user_id, {}) + result["master_key"] = master_key_by_user[user_id]["key_info"] + elif is_self_signing_key_update: result = cross_signing_keys_by_user.setdefault(user_id, {}) result["self_signing_key"] = self_signing_key_by_user[user_id][ "key_info" @@ -294,23 +325,44 @@ class DeviceWorkerStore(SQLBaseStore): else: key = (user_id, device_id) + if key not in query_map and remaining_length_budget < 1: + # We don't have space for a new entry + break + previous_update_stream_id, _ = query_map.get(key, (0, None)) if update_stream_id > previous_update_stream_id: + # FIXME If this overwrites an older update, this discards the + # previous OpenTracing context. + # It might make it harder to track down issues using OpenTracing. + # If there's a good reason why it doesn't matter, a comment here + # about that would not hurt. query_map[key] = (update_stream_id, update_context) + # As this update has been added to the response, advance the stream + # position. last_processed_stream_id = update_stream_id + # In the worst case scenario, each update is for a distinct user and is + # added either to the query_map or to cross_signing_keys_by_user, + # but not both: + # len(query_map) + len(cross_signing_keys_by_user) <= len(updates) here, + # so len(query_map) + len(cross_signing_keys_by_user) <= limit. + results = await self._get_device_update_edus_by_remote( destination, from_stream_id, query_map ) - # add the updated cross-signing keys to the results list + # len(results) <= len(query_map) here, + # so len(results) + len(cross_signing_keys_by_user) <= limit. + + # Add the updated cross-signing keys to the results list for user_id, result in cross_signing_keys_by_user.items(): result["user_id"] = user_id results.append(("m.signing_key_update", result)) # also send the unstable version # FIXME: remove this when enough servers have upgraded + # and remove the length budgeting above. results.append(("org.matrix.signing_key_update", result)) return last_processed_stream_id, results @@ -322,7 +374,7 @@ class DeviceWorkerStore(SQLBaseStore): from_stream_id: int, now_stream_id: int, limit: int, - ): + ) -> List[Tuple[str, str, int, Optional[str]]]: """Return device update information for a given remote destination Args: @@ -333,7 +385,11 @@ class DeviceWorkerStore(SQLBaseStore): limit: Maximum number of device updates to return Returns: - List: List of device updates + List: List of device update tuples: + - user_id + - device_id + - stream_id + - opentracing_context """ # get the list of device updates that need to be sent sql = """ @@ -357,15 +413,21 @@ class DeviceWorkerStore(SQLBaseStore): Args: destination: The host the device updates are intended for from_stream_id: The minimum stream_id to filter updates by, exclusive - query_map (Dict[(str, str): (int, str|None)]): Dictionary mapping - user_id/device_id to update stream_id and the relevant json-encoded - opentracing context + query_map: Dictionary mapping (user_id, device_id) to + (update stream_id, the relevant json-encoded opentracing context) Returns: - List of objects representing an device update EDU + List of objects representing a device update EDU. + + Postconditions: + The returned list has a length not exceeding that of the query_map: + len(result) <= len(query_map) """ devices = ( await self.get_e2e_device_keys_and_signatures( + # Because these are (user_id, device_id) tuples with all + # device_ids not being None, the returned list's length will not + # exceed that of query_map. query_map.keys(), include_all_devices=True, include_deleted_devices=True, diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py index 5cd7f6bb7a..b547bf8d99 100644 --- a/tests/storage/test_devices.py +++ b/tests/storage/test_devices.py @@ -125,7 +125,7 @@ class DeviceStoreTestCase(HomeserverTestCase): self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"]) ) - # Get all device updates ever meant for this remote + # Get device updates meant for this remote next_stream_id, device_updates = self.get_success( self.store.get_device_updates_by_remote("somehost", -1, limit=3) ) @@ -155,6 +155,116 @@ class DeviceStoreTestCase(HomeserverTestCase): # Check the newly-added device_ids are contained within these updates self._check_devices_in_updates(device_ids, device_updates) + # Check there are no more device updates left. + _, device_updates = self.get_success( + self.store.get_device_updates_by_remote("somehost", next_stream_id, limit=3) + ) + self.assertEqual(device_updates, []) + + def test_get_device_updates_by_remote_cross_signing_key_updates( + self, + ) -> None: + """ + Tests that `get_device_updates_by_remote` limits the length of the return value + properly when cross-signing key updates are present. + Current behaviour is that the cross-signing key updates will always come in pairs, + even if that means leaving an earlier batch one EDU short of the limit. + """ + + assert self.hs.is_mine_id( + "@user_id:test" + ), "Test not valid: this MXID should be considered local" + + self.get_success( + self.store.set_e2e_cross_signing_key( + "@user_id:test", + "master", + { + "keys": { + "ed25519:fakeMaster": "aaafakefakefake1AAAAAAAAAAAAAAAAAAAAAAAAAAA=" + }, + "signatures": { + "@user_id:test": { + "ed25519:fake2": "aaafakefakefake2AAAAAAAAAAAAAAAAAAAAAAAAAAA=" + } + }, + }, + ) + ) + self.get_success( + self.store.set_e2e_cross_signing_key( + "@user_id:test", + "self_signing", + { + "keys": { + "ed25519:fakeSelfSigning": "aaafakefakefake3AAAAAAAAAAAAAAAAAAAAAAAAAAA=" + }, + "signatures": { + "@user_id:test": { + "ed25519:fake4": "aaafakefakefake4AAAAAAAAAAAAAAAAAAAAAAAAAAA=" + } + }, + }, + ) + ) + + # Add some device updates with sequential `stream_id`s + # Note that the public cross-signing keys occupy the same space as device IDs, + # so also notify that those have updated. + device_ids = [ + "device_id1", + "device_id2", + "fakeMaster", + "fakeSelfSigning", + ] + + self.get_success( + self.store.add_device_change_to_streams( + "@user_id:test", device_ids, ["somehost"] + ) + ) + + # Get device updates meant for this remote + next_stream_id, device_updates = self.get_success( + self.store.get_device_updates_by_remote("somehost", -1, limit=3) + ) + + # Here we expect the device updates for `device_id1` and `device_id2`. + # That means we only receive 2 updates this time around. + # If we had a higher limit, we would expect to see the pair of + # (unstable-prefixed & unprefixed) signing key updates for the device + # represented by `fakeMaster` and `fakeSelfSigning`. + # Our implementation only sends these two variants together, so we get + # a short batch. + self.assertEqual(len(device_updates), 2, device_updates) + + # Check the first two devices (device_id1, device_id2) came out. + self._check_devices_in_updates(device_ids[:2], device_updates) + + # Get more device updates meant for this remote + next_stream_id, device_updates = self.get_success( + self.store.get_device_updates_by_remote("somehost", next_stream_id, limit=3) + ) + + # The next 2 updates should be a cross-signing key update + # (the master key update and the self-signing key update are combined into + # one 'signing key update', but the cross-signing key update is emitted + # twice, once with an unprefixed type and once again with an unstable-prefixed type) + # (This is a temporary arrangement for backwards compatibility!) + self.assertEqual(len(device_updates), 2, device_updates) + self.assertEqual( + device_updates[0][0], "m.signing_key_update", device_updates[0] + ) + self.assertEqual( + device_updates[1][0], "org.matrix.signing_key_update", device_updates[1] + ) + + # Check there are no more device updates left. + _, device_updates = self.get_success( + self.store.get_device_updates_by_remote("somehost", next_stream_id, limit=3) + ) + self.assertEqual(device_updates, []) + def _check_devices_in_updates(self, expected_device_ids, device_updates): """Check that an specific device ids exist in a list of device update EDUs""" self.assertEqual(len(device_updates), len(expected_device_ids)) -- cgit 1.4.1