From a7bdf98d01d2225a479753a85ba81adf02b16a32 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 5 Aug 2020 21:38:57 +0100 Subject: Rename database classes to make some sense (#8033) --- synapse/replication/slave/storage/devices.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) (limited to 'synapse/replication/slave/storage/devices.py') diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py index 9d8067342f..722f3745e9 100644 --- a/synapse/replication/slave/storage/devices.py +++ b/synapse/replication/slave/storage/devices.py @@ -16,14 +16,14 @@ from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream -from synapse.storage.data_stores.main.devices import DeviceWorkerStore -from synapse.storage.data_stores.main.end_to_end_keys import EndToEndKeyWorkerStore -from synapse.storage.database import Database +from synapse.storage.database import DatabasePool +from synapse.storage.databases.main.devices import DeviceWorkerStore +from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore from synapse.util.caches.stream_change_cache import StreamChangeCache class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(SlavedDeviceStore, self).__init__(database, db_conn, hs) self.hs = hs -- cgit 1.5.1 From e3c91a3c5593c52298a7c511737d3e0eec4135ae Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 26 Aug 2020 13:15:20 +0100 Subject: Make SlavedIdTracker.advance have same interface as MultiWriterIDGenerator (#8171) --- changelog.d/8171.misc | 1 + synapse/replication/slave/storage/_slaved_id_tracker.py | 4 ++-- synapse/replication/slave/storage/account_data.py | 4 ++-- synapse/replication/slave/storage/deviceinbox.py | 2 +- synapse/replication/slave/storage/devices.py | 4 ++-- synapse/replication/slave/storage/groups.py | 2 +- synapse/replication/slave/storage/presence.py | 2 +- synapse/replication/slave/storage/push_rule.py | 2 +- synapse/replication/slave/storage/pushers.py | 2 +- synapse/replication/slave/storage/receipts.py | 2 +- synapse/replication/slave/storage/room.py | 2 +- synapse/storage/databases/main/events_worker.py | 4 ++-- 12 files changed, 16 insertions(+), 15 deletions(-) create mode 100644 changelog.d/8171.misc (limited to 'synapse/replication/slave/storage/devices.py') diff --git a/changelog.d/8171.misc b/changelog.d/8171.misc new file mode 100644 index 0000000000..cafbf23d83 --- /dev/null +++ b/changelog.d/8171.misc @@ -0,0 +1 @@ +Make `SlavedIdTracker.advance` have the same interface as `MultiWriterIDGenerator`. diff --git a/synapse/replication/slave/storage/_slaved_id_tracker.py b/synapse/replication/slave/storage/_slaved_id_tracker.py index d43eaf3a29..047f2c50f7 100644 --- a/synapse/replication/slave/storage/_slaved_id_tracker.py +++ b/synapse/replication/slave/storage/_slaved_id_tracker.py @@ -21,9 +21,9 @@ class SlavedIdTracker(object): self.step = step self._current = _load_current_id(db_conn, table, column, step) for table, column in extra_tables: - self.advance(_load_current_id(db_conn, table, column)) + self.advance(None, _load_current_id(db_conn, table, column)) - def advance(self, new_id): + def advance(self, instance_name, new_id): self._current = (max if self.step > 0 else min)(self._current, new_id) def get_current_token(self): diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py index 154f0e687c..bb66ba9b80 100644 --- a/synapse/replication/slave/storage/account_data.py +++ b/synapse/replication/slave/storage/account_data.py @@ -41,12 +41,12 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == TagAccountDataStream.NAME: - self._account_data_id_gen.advance(token) + self._account_data_id_gen.advance(instance_name, token) for row in rows: self.get_tags_for_user.invalidate((row.user_id,)) self._account_data_stream_cache.entity_has_changed(row.user_id, token) elif stream_name == AccountDataStream.NAME: - self._account_data_id_gen.advance(token) + self._account_data_id_gen.advance(instance_name, token) for row in rows: if not row.room_id: self.get_global_account_data_by_type_for_user.invalidate( diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py index ee7f69a918..533d927701 100644 --- a/synapse/replication/slave/storage/deviceinbox.py +++ b/synapse/replication/slave/storage/deviceinbox.py @@ -46,7 +46,7 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore): def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == ToDeviceStream.NAME: - self._device_inbox_id_gen.advance(token) + self._device_inbox_id_gen.advance(instance_name, token) for row in rows: if row.entity.startswith("@"): self._device_inbox_stream_cache.entity_has_changed( diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py index 722f3745e9..596c72eb92 100644 --- a/synapse/replication/slave/storage/devices.py +++ b/synapse/replication/slave/storage/devices.py @@ -50,10 +50,10 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == DeviceListsStream.NAME: - self._device_list_id_gen.advance(token) + self._device_list_id_gen.advance(instance_name, token) self._invalidate_caches_for_devices(token, rows) elif stream_name == UserSignatureStream.NAME: - self._device_list_id_gen.advance(token) + self._device_list_id_gen.advance(instance_name, token) for row in rows: self._user_signature_stream_cache.entity_has_changed(row.user_id, token) return super().process_replication_rows(stream_name, instance_name, token, rows) diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py index 3291558c7a..567b4a5cc1 100644 --- a/synapse/replication/slave/storage/groups.py +++ b/synapse/replication/slave/storage/groups.py @@ -40,7 +40,7 @@ class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore): def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == GroupServerStream.NAME: - self._group_updates_id_gen.advance(token) + self._group_updates_id_gen.advance(instance_name, token) for row in rows: self._group_updates_stream_cache.entity_has_changed(row.user_id, token) diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py index a912c04360..025f6f6be8 100644 --- a/synapse/replication/slave/storage/presence.py +++ b/synapse/replication/slave/storage/presence.py @@ -44,7 +44,7 @@ class SlavedPresenceStore(BaseSlavedStore): def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == PresenceStream.NAME: - self._presence_id_gen.advance(token) + self._presence_id_gen.advance(instance_name, token) for row in rows: self.presence_stream_cache.entity_has_changed(row.user_id, token) self._get_presence_for_user.invalidate((row.user_id,)) diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py index 90d90833f9..de904c943c 100644 --- a/synapse/replication/slave/storage/push_rule.py +++ b/synapse/replication/slave/storage/push_rule.py @@ -30,7 +30,7 @@ class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore): assert isinstance(self._push_rules_stream_id_gen, SlavedIdTracker) if stream_name == PushRulesStream.NAME: - self._push_rules_stream_id_gen.advance(token) + self._push_rules_stream_id_gen.advance(instance_name, token) for row in rows: self.get_push_rules_for_user.invalidate((row.user_id,)) self.get_push_rules_enabled_for_user.invalidate((row.user_id,)) diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py index 63300e5da6..9da218bfe8 100644 --- a/synapse/replication/slave/storage/pushers.py +++ b/synapse/replication/slave/storage/pushers.py @@ -34,5 +34,5 @@ class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore): def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == PushersStream.NAME: - self._pushers_id_gen.advance(token) + self._pushers_id_gen.advance(instance_name, token) return super().process_replication_rows(stream_name, instance_name, token, rows) diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py index 17ba1f22ac..5c2986e050 100644 --- a/synapse/replication/slave/storage/receipts.py +++ b/synapse/replication/slave/storage/receipts.py @@ -46,7 +46,7 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore): def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == ReceiptsStream.NAME: - self._receipts_id_gen.advance(token) + self._receipts_id_gen.advance(instance_name, token) for row in rows: self.invalidate_caches_for_receipt( row.room_id, row.receipt_type, row.user_id diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py index 427c81772b..80ae803ad9 100644 --- a/synapse/replication/slave/storage/room.py +++ b/synapse/replication/slave/storage/room.py @@ -33,6 +33,6 @@ class RoomStore(RoomWorkerStore, BaseSlavedStore): def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == PublicRoomsStream.NAME: - self._public_room_id_gen.advance(token) + self._public_room_id_gen.advance(instance_name, token) return super().process_replication_rows(stream_name, instance_name, token, rows) diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index d59d73938a..e6247d682d 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -113,9 +113,9 @@ class EventsWorkerStore(SQLBaseStore): def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == EventsStream.NAME: - self._stream_id_gen.advance(token) + self._stream_id_gen.advance(instance_name, token) elif stream_name == BackfillStream.NAME: - self._backfill_id_gen.advance(-token) + self._backfill_id_gen.advance(instance_name, -token) super().process_replication_rows(stream_name, instance_name, token, rows) -- cgit 1.5.1 From aa07c37cf0a3b812e6aa1bb2d97d543e6925c8e2 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 1 Sep 2020 12:41:21 +0100 Subject: Move and rename `get_devices_with_keys_by_user` (#8204) * Move `get_devices_with_keys_by_user` to `EndToEndKeyWorkerStore` this seems a better fit for it. This commit simply moves the existing code: no other changes at all. * Rename `get_devices_with_keys_by_user` to better reflect what it does. * get_device_stream_token abstract method To avoid referencing fields which are declared in the derived classes, make `get_device_stream_token` abstract, and define that in the classes which define `_device_list_id_gen`. --- changelog.d/8204.misc | 1 + synapse/handlers/device.py | 4 +- synapse/replication/slave/storage/devices.py | 3 ++ synapse/storage/databases/main/__init__.py | 3 ++ synapse/storage/databases/main/devices.py | 52 +++------------------- synapse/storage/databases/main/end_to_end_keys.py | 53 ++++++++++++++++++++++- 6 files changed, 67 insertions(+), 49 deletions(-) create mode 100644 changelog.d/8204.misc (limited to 'synapse/replication/slave/storage/devices.py') diff --git a/changelog.d/8204.misc b/changelog.d/8204.misc new file mode 100644 index 0000000000..979c8b227b --- /dev/null +++ b/changelog.d/8204.misc @@ -0,0 +1 @@ +Refactor queries for device keys and cross-signatures. diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index db417d60de..ee4666337a 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -234,7 +234,9 @@ class DeviceWorkerHandler(BaseHandler): return result async def on_federation_query_user_devices(self, user_id): - stream_id, devices = await self.store.get_devices_with_keys_by_user(user_id) + stream_id, devices = await self.store.get_e2e_device_keys_for_federation_query( + user_id + ) master_key = await self.store.get_e2e_cross_signing_key(user_id, "master") self_signing_key = await self.store.get_e2e_cross_signing_key( user_id, "self_signing" diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py index 596c72eb92..3b788c9625 100644 --- a/synapse/replication/slave/storage/devices.py +++ b/synapse/replication/slave/storage/devices.py @@ -48,6 +48,9 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto "DeviceListFederationStreamChangeCache", device_list_max ) + def get_device_stream_token(self) -> int: + return self._device_list_id_gen.get_current_token() + def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == DeviceListsStream.NAME: self._device_list_id_gen.advance(instance_name, token) diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index 70cf15dd7f..e6536c8456 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -264,6 +264,9 @@ class DataStore( # Used in _generate_user_daily_visits to keep track of progress self._last_user_visit_update = self._get_start_of_day() + def get_device_stream_token(self) -> int: + return self._device_list_id_gen.get_current_token() + def take_presence_startup_info(self): active_on_startup = self._presence_on_startup self._presence_on_startup = None diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index def96637a2..e8379c73c4 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -14,6 +14,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import abc import logging from typing import Any, Dict, Iterable, List, Optional, Set, Tuple @@ -101,7 +102,7 @@ class DeviceWorkerStore(SQLBaseStore): update included in the response), and the list of updates, where each update is a pair of EDU type and EDU contents. """ - now_stream_id = self._device_list_id_gen.get_current_token() + now_stream_id = self.get_device_stream_token() has_changed = self._device_list_federation_stream_cache.has_entity_changed( destination, int(from_stream_id) @@ -412,8 +413,10 @@ class DeviceWorkerStore(SQLBaseStore): }, ) + @abc.abstractmethod def get_device_stream_token(self) -> int: - return self._device_list_id_gen.get_current_token() + """Get the current stream id from the _device_list_id_gen""" + ... @trace async def get_user_devices_from_cache( @@ -481,51 +484,6 @@ class DeviceWorkerStore(SQLBaseStore): device["device_id"]: db_to_json(device["content"]) for device in devices } - def get_devices_with_keys_by_user(self, user_id: str): - """Get all devices (with any device keys) for a user - - Returns: - Deferred which resolves to (stream_id, devices) - """ - return self.db_pool.runInteraction( - "get_devices_with_keys_by_user", - self._get_devices_with_keys_by_user_txn, - user_id, - ) - - def _get_devices_with_keys_by_user_txn( - self, txn: LoggingTransaction, user_id: str - ) -> Tuple[int, List[JsonDict]]: - now_stream_id = self._device_list_id_gen.get_current_token() - - devices = self._get_e2e_device_keys_txn(txn, [(user_id, None)]) - - if devices: - user_devices = devices[user_id] - results = [] - for device_id, device in user_devices.items(): - result = {"device_id": device_id} - - key_json = device.get("key_json", None) - if key_json: - result["keys"] = db_to_json(key_json) - - if "signatures" in device: - for sig_user_id, sigs in device["signatures"].items(): - result["keys"].setdefault("signatures", {}).setdefault( - sig_user_id, {} - ).update(sigs) - - device_display_name = device.get("device_display_name", None) - if device_display_name: - result["device_display_name"] = device_display_name - - results.append(result) - - return now_stream_id, results - - return now_stream_id, [] - async def get_users_whose_devices_changed( self, from_key: str, user_ids: Iterable[str] ) -> Set[str]: diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 50ecddf7fa..fb3b1f94de 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -14,6 +14,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import abc from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple from canonicaljson import encode_canonical_json @@ -22,7 +23,7 @@ from twisted.enterprise.adbapi import Connection from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.storage.database import make_in_list_sql_clause +from synapse.storage.database import LoggingTransaction, make_in_list_sql_clause from synapse.types import JsonDict from synapse.util import json_encoder from synapse.util.caches.descriptors import cached, cachedList @@ -33,6 +34,51 @@ if TYPE_CHECKING: class EndToEndKeyWorkerStore(SQLBaseStore): + def get_e2e_device_keys_for_federation_query(self, user_id: str): + """Get all devices (with any device keys) for a user + + Returns: + Deferred which resolves to (stream_id, devices) + """ + return self.db_pool.runInteraction( + "get_e2e_device_keys_for_federation_query", + self._get_e2e_device_keys_for_federation_query_txn, + user_id, + ) + + def _get_e2e_device_keys_for_federation_query_txn( + self, txn: LoggingTransaction, user_id: str + ) -> Tuple[int, List[JsonDict]]: + now_stream_id = self.get_device_stream_token() + + devices = self._get_e2e_device_keys_txn(txn, [(user_id, None)]) + + if devices: + user_devices = devices[user_id] + results = [] + for device_id, device in user_devices.items(): + result = {"device_id": device_id} + + key_json = device.get("key_json", None) + if key_json: + result["keys"] = db_to_json(key_json) + + if "signatures" in device: + for sig_user_id, sigs in device["signatures"].items(): + result["keys"].setdefault("signatures", {}).setdefault( + sig_user_id, {} + ).update(sigs) + + device_display_name = device.get("device_display_name", None) + if device_display_name: + result["device_display_name"] = device_display_name + + results.append(result) + + return now_stream_id, results + + return now_stream_id, [] + @trace async def get_e2e_device_keys_for_cs_api( self, query_list: List[Tuple[str, Optional[str]]] @@ -533,6 +579,11 @@ class EndToEndKeyWorkerStore(SQLBaseStore): _get_all_user_signature_changes_for_remotes_txn, ) + @abc.abstractmethod + def get_device_stream_token(self) -> int: + """Get the current stream id from the _device_list_id_gen""" + ... + class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys): -- cgit 1.5.1