summary refs log tree commit diff
path: root/synapse/storage/databases/main
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main')
-rw-r--r--synapse/storage/databases/main/__init__.py3
-rw-r--r--synapse/storage/databases/main/client_ips.py4
-rw-r--r--synapse/storage/databases/main/devices.py52
-rw-r--r--synapse/storage/databases/main/directory.py6
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py73
-rw-r--r--synapse/storage/databases/main/filtering.py5
-rw-r--r--synapse/storage/databases/main/openid.py8
-rw-r--r--synapse/storage/databases/main/profile.py6
-rw-r--r--synapse/storage/databases/main/push_rule.py10
-rw-r--r--synapse/storage/databases/main/room.py49
-rw-r--r--synapse/storage/databases/main/signatures.py40
-rw-r--r--synapse/storage/databases/main/ui_auth.py4
-rw-r--r--synapse/storage/databases/main/user_erasure_store.py8
13 files changed, 156 insertions, 112 deletions
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py

index 70cf15dd7f..e6536c8456 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py
@@ -264,6 +264,9 @@ class DataStore( # Used in _generate_user_daily_visits to keep track of progress self._last_user_visit_update = self._get_start_of_day() + def get_device_stream_token(self) -> int: + return self._device_list_id_gen.get_current_token() + def take_presence_startup_info(self): active_on_startup = self._presence_on_startup self._presence_on_startup = None diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index 216a5925fc..c2fc847fbc 100644 --- a/synapse/storage/databases/main/client_ips.py +++ b/synapse/storage/databases/main/client_ips.py
@@ -396,7 +396,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): self._batch_row_update[key] = (user_agent, device_id, now) @wrap_as_background_process("update_client_ips") - def _update_client_ips_batch(self): + async def _update_client_ips_batch(self) -> None: # If the DB pool has already terminated, don't try updating if not self.db_pool.is_running(): @@ -405,7 +405,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): to_update = self._batch_row_update self._batch_row_update = {} - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update ) diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index def96637a2..e8379c73c4 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py
@@ -14,6 +14,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import abc import logging from typing import Any, Dict, Iterable, List, Optional, Set, Tuple @@ -101,7 +102,7 @@ class DeviceWorkerStore(SQLBaseStore): update included in the response), and the list of updates, where each update is a pair of EDU type and EDU contents. """ - now_stream_id = self._device_list_id_gen.get_current_token() + now_stream_id = self.get_device_stream_token() has_changed = self._device_list_federation_stream_cache.has_entity_changed( destination, int(from_stream_id) @@ -412,8 +413,10 @@ class DeviceWorkerStore(SQLBaseStore): }, ) + @abc.abstractmethod def get_device_stream_token(self) -> int: - return self._device_list_id_gen.get_current_token() + """Get the current stream id from the _device_list_id_gen""" + ... @trace async def get_user_devices_from_cache( @@ -481,51 +484,6 @@ class DeviceWorkerStore(SQLBaseStore): device["device_id"]: db_to_json(device["content"]) for device in devices } - def get_devices_with_keys_by_user(self, user_id: str): - """Get all devices (with any device keys) for a user - - Returns: - Deferred which resolves to (stream_id, devices) - """ - return self.db_pool.runInteraction( - "get_devices_with_keys_by_user", - self._get_devices_with_keys_by_user_txn, - user_id, - ) - - def _get_devices_with_keys_by_user_txn( - self, txn: LoggingTransaction, user_id: str - ) -> Tuple[int, List[JsonDict]]: - now_stream_id = self._device_list_id_gen.get_current_token() - - devices = self._get_e2e_device_keys_txn(txn, [(user_id, None)]) - - if devices: - user_devices = devices[user_id] - results = [] - for device_id, device in user_devices.items(): - result = {"device_id": device_id} - - key_json = device.get("key_json", None) - if key_json: - result["keys"] = db_to_json(key_json) - - if "signatures" in device: - for sig_user_id, sigs in device["signatures"].items(): - result["keys"].setdefault("signatures", {}).setdefault( - sig_user_id, {} - ).update(sigs) - - device_display_name = device.get("device_display_name", None) - if device_display_name: - result["device_display_name"] = device_display_name - - results.append(result) - - return now_stream_id, results - - return now_stream_id, [] - async def get_users_whose_devices_changed( self, from_key: str, user_ids: Iterable[str] ) -> Set[str]: diff --git a/synapse/storage/databases/main/directory.py b/synapse/storage/databases/main/directory.py
index 405b5eafa5..e5060d4c46 100644 --- a/synapse/storage/databases/main/directory.py +++ b/synapse/storage/databases/main/directory.py
@@ -159,9 +159,9 @@ class DirectoryStore(DirectoryWorkerStore): return room_id - def update_aliases_for_room( + async def update_aliases_for_room( self, old_room_id: str, new_room_id: str, creator: Optional[str] = None, - ): + ) -> None: """Repoint all of the aliases for a given room, to a different room. Args: @@ -189,6 +189,6 @@ class DirectoryStore(DirectoryWorkerStore): txn, self.get_aliases_for_room, (new_room_id,) ) - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "_update_aliases_for_room_txn", _update_aliases_for_room_txn ) diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index af0b85e2c9..fb3b1f94de 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -14,6 +14,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import abc from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple from canonicaljson import encode_canonical_json @@ -22,7 +23,8 @@ from twisted.enterprise.adbapi import Connection from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.storage.database import make_in_list_sql_clause +from synapse.storage.database import LoggingTransaction, make_in_list_sql_clause +from synapse.types import JsonDict from synapse.util import json_encoder from synapse.util.caches.descriptors import cached, cachedList from synapse.util.iterutils import batch_iter @@ -32,18 +34,58 @@ if TYPE_CHECKING: class EndToEndKeyWorkerStore(SQLBaseStore): + def get_e2e_device_keys_for_federation_query(self, user_id: str): + """Get all devices (with any device keys) for a user + + Returns: + Deferred which resolves to (stream_id, devices) + """ + return self.db_pool.runInteraction( + "get_e2e_device_keys_for_federation_query", + self._get_e2e_device_keys_for_federation_query_txn, + user_id, + ) + + def _get_e2e_device_keys_for_federation_query_txn( + self, txn: LoggingTransaction, user_id: str + ) -> Tuple[int, List[JsonDict]]: + now_stream_id = self.get_device_stream_token() + + devices = self._get_e2e_device_keys_txn(txn, [(user_id, None)]) + + if devices: + user_devices = devices[user_id] + results = [] + for device_id, device in user_devices.items(): + result = {"device_id": device_id} + + key_json = device.get("key_json", None) + if key_json: + result["keys"] = db_to_json(key_json) + + if "signatures" in device: + for sig_user_id, sigs in device["signatures"].items(): + result["keys"].setdefault("signatures", {}).setdefault( + sig_user_id, {} + ).update(sigs) + + device_display_name = device.get("device_display_name", None) + if device_display_name: + result["device_display_name"] = device_display_name + + results.append(result) + + return now_stream_id, results + + return now_stream_id, [] + @trace - async def get_e2e_device_keys( - self, query_list, include_all_devices=False, include_deleted_devices=False - ): - """Fetch a list of device keys. + async def get_e2e_device_keys_for_cs_api( + self, query_list: List[Tuple[str, Optional[str]]] + ) -> Dict[str, Dict[str, JsonDict]]: + """Fetch a list of device keys, formatted suitably for the C/S API. Args: query_list(list): List of pairs of user_ids and device_ids. - include_all_devices (bool): whether to include entries for devices - that don't have device keys - include_deleted_devices (bool): whether to include null entries for - devices which no longer exist (but were in the query_list). - This option only takes effect if include_all_devices is true. Returns: Dict mapping from user-id to dict mapping from device_id to key data. The key data will be a dict in the same format as the @@ -54,11 +96,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): return {} results = await self.db_pool.runInteraction( - "get_e2e_device_keys", - self._get_e2e_device_keys_txn, - query_list, - include_all_devices, - include_deleted_devices, + "get_e2e_device_keys", self._get_e2e_device_keys_txn, query_list, ) # Build the result structure, un-jsonify the results, and add the @@ -541,6 +579,11 @@ class EndToEndKeyWorkerStore(SQLBaseStore): _get_all_user_signature_changes_for_remotes_txn, ) + @abc.abstractmethod + def get_device_stream_token(self) -> int: + """Get the current stream id from the _device_list_id_gen""" + ... + class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys): diff --git a/synapse/storage/databases/main/filtering.py b/synapse/storage/databases/main/filtering.py
index 45a1760170..d2f5b9a502 100644 --- a/synapse/storage/databases/main/filtering.py +++ b/synapse/storage/databases/main/filtering.py
@@ -17,6 +17,7 @@ from canonicaljson import encode_canonical_json from synapse.api.errors import Codes, SynapseError from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.types import JsonDict from synapse.util.caches.descriptors import cached @@ -40,7 +41,7 @@ class FilteringStore(SQLBaseStore): return db_to_json(def_json) - def add_user_filter(self, user_localpart, user_filter): + async def add_user_filter(self, user_localpart: str, user_filter: JsonDict) -> str: def_json = encode_canonical_json(user_filter) # Need an atomic transaction to SELECT the maximal ID so far then @@ -71,4 +72,4 @@ class FilteringStore(SQLBaseStore): return filter_id - return self.db_pool.runInteraction("add_user_filter", _do_txn) + return await self.db_pool.runInteraction("add_user_filter", _do_txn) diff --git a/synapse/storage/databases/main/openid.py b/synapse/storage/databases/main/openid.py
index 4db8949da7..2aac64901b 100644 --- a/synapse/storage/databases/main/openid.py +++ b/synapse/storage/databases/main/openid.py
@@ -1,3 +1,5 @@ +from typing import Optional + from synapse.storage._base import SQLBaseStore @@ -15,7 +17,9 @@ class OpenIdStore(SQLBaseStore): desc="insert_open_id_token", ) - def get_user_id_for_open_id_token(self, token, ts_now_ms): + async def get_user_id_for_open_id_token( + self, token: str, ts_now_ms: int + ) -> Optional[str]: def get_user_id_for_token_txn(txn): sql = ( "SELECT user_id FROM open_id_tokens" @@ -30,6 +34,6 @@ class OpenIdStore(SQLBaseStore): else: return rows[0][0] - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_user_id_for_token", get_user_id_for_token_txn ) diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py
index 8b50e00553..de37866d25 100644 --- a/synapse/storage/databases/main/profile.py +++ b/synapse/storage/databases/main/profile.py
@@ -252,7 +252,9 @@ class ProfileStore(ProfileWorkerStore): desc="delete_remote_profile_cache", ) - def get_remote_profile_cache_entries_that_expire(self, last_checked): + async def get_remote_profile_cache_entries_that_expire( + self, last_checked: int + ) -> Dict[str, str]: """Get all users who haven't been checked since `last_checked` """ @@ -267,7 +269,7 @@ class ProfileStore(ProfileWorkerStore): return self.db_pool.cursor_to_dict(txn) - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_remote_profile_cache_entries_that_expire", _get_remote_profile_cache_entries_that_expire_txn, ) diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 2fb5b02d7d..0de802a86b 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py
@@ -18,8 +18,6 @@ import abc import logging from typing import List, Tuple, Union -from twisted.internet import defer - from synapse.push.baserules import list_with_base_rules from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.storage._base import SQLBaseStore, db_to_json @@ -149,9 +147,11 @@ class PushRulesWorkerStore( ) return {r["rule_id"]: False if r["enabled"] == 0 else True for r in results} - def have_push_rules_changed_for_user(self, user_id, last_id): + async def have_push_rules_changed_for_user( + self, user_id: str, last_id: int + ) -> bool: if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id): - return defer.succeed(False) + return False else: def have_push_rules_changed_txn(txn): @@ -163,7 +163,7 @@ class PushRulesWorkerStore( (count,) = txn.fetchone() return bool(count) - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "have_push_rules_changed", have_push_rules_changed_txn ) diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index c1d8ef5286..d5dc7a36bb 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py
@@ -89,7 +89,7 @@ class RoomWorkerStore(SQLBaseStore): allow_none=True, ) - def get_room_with_stats(self, room_id: str): + async def get_room_with_stats(self, room_id: str) -> Optional[Dict[str, Any]]: """Retrieve room with statistics. Args: @@ -121,7 +121,7 @@ class RoomWorkerStore(SQLBaseStore): res["public"] = bool(res["public"]) return res - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_room_with_stats", get_room_with_stats_txn, room_id ) @@ -133,13 +133,17 @@ class RoomWorkerStore(SQLBaseStore): desc="get_public_room_ids", ) - def count_public_rooms(self, network_tuple, ignore_non_federatable): + async def count_public_rooms( + self, + network_tuple: Optional[ThirdPartyInstanceID], + ignore_non_federatable: bool, + ) -> int: """Counts the number of public rooms as tracked in the room_stats_current and room_stats_state table. Args: - network_tuple (ThirdPartyInstanceID|None) - ignore_non_federatable (bool): If true filters out non-federatable rooms + network_tuple + ignore_non_federatable: If true filters out non-federatable rooms """ def _count_public_rooms_txn(txn): @@ -183,7 +187,7 @@ class RoomWorkerStore(SQLBaseStore): txn.execute(sql, query_args) return txn.fetchone()[0] - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "count_public_rooms", _count_public_rooms_txn ) @@ -608,15 +612,14 @@ class RoomWorkerStore(SQLBaseStore): return row - def get_media_mxcs_in_room(self, room_id): + async def get_media_mxcs_in_room(self, room_id: str) -> Tuple[List[str], List[str]]: """Retrieves all the local and remote media MXC URIs in a given room Args: - room_id (str) + room_id Returns: - The local and remote media as a lists of tuples where the key is - the hostname and the value is the media ID. + The local and remote media as a lists of the media IDs. """ def _get_media_mxcs_in_room_txn(txn): @@ -632,11 +635,13 @@ class RoomWorkerStore(SQLBaseStore): return local_media_mxcs, remote_media_mxcs - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_media_ids_in_room", _get_media_mxcs_in_room_txn ) - def quarantine_media_ids_in_room(self, room_id, quarantined_by): + async def quarantine_media_ids_in_room( + self, room_id: str, quarantined_by: str + ) -> int: """For a room loops through all events with media and quarantines the associated media """ @@ -649,7 +654,7 @@ class RoomWorkerStore(SQLBaseStore): txn, local_mxcs, remote_mxcs, quarantined_by ) - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "quarantine_media_in_room", _quarantine_media_in_room_txn ) @@ -712,9 +717,9 @@ class RoomWorkerStore(SQLBaseStore): return local_media_mxcs, remote_media_mxcs - def quarantine_media_by_id( + async def quarantine_media_by_id( self, server_name: str, media_id: str, quarantined_by: str, - ): + ) -> int: """quarantines a single local or remote media id Args: @@ -733,11 +738,13 @@ class RoomWorkerStore(SQLBaseStore): txn, local_mxcs, remote_mxcs, quarantined_by ) - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "quarantine_media_by_user", _quarantine_media_by_id_txn ) - def quarantine_media_ids_by_user(self, user_id: str, quarantined_by: str): + async def quarantine_media_ids_by_user( + self, user_id: str, quarantined_by: str + ) -> int: """quarantines all local media associated with a single user Args: @@ -749,7 +756,7 @@ class RoomWorkerStore(SQLBaseStore): local_media_ids = self._get_media_ids_by_user_txn(txn, user_id) return self._quarantine_media_txn(txn, local_media_ids, [], quarantined_by) - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "quarantine_media_by_user", _quarantine_media_by_user_txn ) @@ -1306,8 +1313,8 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): ) self.hs.get_notifier().on_new_replication_data() - def get_room_count(self): - """Retrieve a list of all rooms + async def get_room_count(self) -> int: + """Retrieve the total number of rooms. """ def f(txn): @@ -1316,7 +1323,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): row = txn.fetchone() return row[0] or 0 - return self.db_pool.runInteraction("get_rooms", f) + return await self.db_pool.runInteraction("get_rooms", f) async def add_event_report( self, diff --git a/synapse/storage/databases/main/signatures.py b/synapse/storage/databases/main/signatures.py
index be191dd870..c8c67953e4 100644 --- a/synapse/storage/databases/main/signatures.py +++ b/synapse/storage/databases/main/signatures.py
@@ -13,9 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Dict, Iterable, List, Tuple + from unpaddedbase64 import encode_base64 from synapse.storage._base import SQLBaseStore +from synapse.storage.types import Cursor from synapse.util.caches.descriptors import cached, cachedList @@ -29,16 +32,37 @@ class SignatureWorkerStore(SQLBaseStore): @cachedList( cached_method_name="get_event_reference_hash", list_name="event_ids", num_args=1 ) - def get_event_reference_hashes(self, event_ids): + async def get_event_reference_hashes( + self, event_ids: Iterable[str] + ) -> Dict[str, Dict[str, bytes]]: + """Get all hashes for given events. + + Args: + event_ids: The event IDs to get hashes for. + + Returns: + A mapping of event ID to a mapping of algorithm to hash. + """ + def f(txn): return { event_id: self._get_event_reference_hashes_txn(txn, event_id) for event_id in event_ids } - return self.db_pool.runInteraction("get_event_reference_hashes", f) + return await self.db_pool.runInteraction("get_event_reference_hashes", f) - async def add_event_hashes(self, event_ids): + async def add_event_hashes( + self, event_ids: Iterable[str] + ) -> List[Tuple[str, Dict[str, str]]]: + """ + + Args: + event_ids: The event IDs + + Returns: + A list of tuples of event ID and a mapping of algorithm to base-64 encoded hash. + """ hashes = await self.get_event_reference_hashes(event_ids) hashes = { e_id: {k: encode_base64(v) for k, v in h.items() if k == "sha256"} @@ -47,13 +71,15 @@ class SignatureWorkerStore(SQLBaseStore): return list(hashes.items()) - def _get_event_reference_hashes_txn(self, txn, event_id): + def _get_event_reference_hashes_txn( + self, txn: Cursor, event_id: str + ) -> Dict[str, bytes]: """Get all the hashes for a given PDU. Args: - txn (cursor): - event_id (str): Id for the Event. + txn: + event_id: Id for the Event. Returns: - A dict[unicode, bytes] of algorithm -> hash. + A mapping of algorithm -> hash. """ query = ( "SELECT algorithm, hash" diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py
index 9eef8e57c5..b89668d561 100644 --- a/synapse/storage/databases/main/ui_auth.py +++ b/synapse/storage/databases/main/ui_auth.py
@@ -290,7 +290,7 @@ class UIAuthWorkerStore(SQLBaseStore): class UIAuthStore(UIAuthWorkerStore): - def delete_old_ui_auth_sessions(self, expiration_time: int): + async def delete_old_ui_auth_sessions(self, expiration_time: int) -> None: """ Remove sessions which were last used earlier than the expiration time. @@ -299,7 +299,7 @@ class UIAuthStore(UIAuthWorkerStore): This is an epoch time in milliseconds. """ - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "delete_old_ui_auth_sessions", self._delete_old_ui_auth_sessions_txn, expiration_time, diff --git a/synapse/storage/databases/main/user_erasure_store.py b/synapse/storage/databases/main/user_erasure_store.py
index e3547e53b3..2f7c95fc74 100644 --- a/synapse/storage/databases/main/user_erasure_store.py +++ b/synapse/storage/databases/main/user_erasure_store.py
@@ -66,7 +66,7 @@ class UserErasureWorkerStore(SQLBaseStore): class UserErasureStore(UserErasureWorkerStore): - def mark_user_erased(self, user_id: str) -> None: + async def mark_user_erased(self, user_id: str) -> None: """Indicate that user_id wishes their message history to be erased. Args: @@ -84,9 +84,9 @@ class UserErasureStore(UserErasureWorkerStore): self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,)) - return self.db_pool.runInteraction("mark_user_erased", f) + await self.db_pool.runInteraction("mark_user_erased", f) - def mark_user_not_erased(self, user_id: str) -> None: + async def mark_user_not_erased(self, user_id: str) -> None: """Indicate that user_id is no longer erased. Args: @@ -106,4 +106,4 @@ class UserErasureStore(UserErasureWorkerStore): self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,)) - return self.db_pool.runInteraction("mark_user_not_erased", f) + await self.db_pool.runInteraction("mark_user_not_erased", f)