diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 70cf15dd7f..e6536c8456 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -264,6 +264,9 @@ class DataStore(
# Used in _generate_user_daily_visits to keep track of progress
self._last_user_visit_update = self._get_start_of_day()
+ def get_device_stream_token(self) -> int:
+ return self._device_list_id_gen.get_current_token()
+
def take_presence_startup_info(self):
active_on_startup = self._presence_on_startup
self._presence_on_startup = None
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index 216a5925fc..c2fc847fbc 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -396,7 +396,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
self._batch_row_update[key] = (user_agent, device_id, now)
@wrap_as_background_process("update_client_ips")
- def _update_client_ips_batch(self):
+ async def _update_client_ips_batch(self) -> None:
# If the DB pool has already terminated, don't try updating
if not self.db_pool.is_running():
@@ -405,7 +405,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
to_update = self._batch_row_update
self._batch_row_update = {}
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
)
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index def96637a2..e8379c73c4 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -14,6 +14,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import abc
import logging
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple
@@ -101,7 +102,7 @@ class DeviceWorkerStore(SQLBaseStore):
update included in the response), and the list of updates, where
each update is a pair of EDU type and EDU contents.
"""
- now_stream_id = self._device_list_id_gen.get_current_token()
+ now_stream_id = self.get_device_stream_token()
has_changed = self._device_list_federation_stream_cache.has_entity_changed(
destination, int(from_stream_id)
@@ -412,8 +413,10 @@ class DeviceWorkerStore(SQLBaseStore):
},
)
+ @abc.abstractmethod
def get_device_stream_token(self) -> int:
- return self._device_list_id_gen.get_current_token()
+ """Get the current stream id from the _device_list_id_gen"""
+ ...
@trace
async def get_user_devices_from_cache(
@@ -481,51 +484,6 @@ class DeviceWorkerStore(SQLBaseStore):
device["device_id"]: db_to_json(device["content"]) for device in devices
}
- def get_devices_with_keys_by_user(self, user_id: str):
- """Get all devices (with any device keys) for a user
-
- Returns:
- Deferred which resolves to (stream_id, devices)
- """
- return self.db_pool.runInteraction(
- "get_devices_with_keys_by_user",
- self._get_devices_with_keys_by_user_txn,
- user_id,
- )
-
- def _get_devices_with_keys_by_user_txn(
- self, txn: LoggingTransaction, user_id: str
- ) -> Tuple[int, List[JsonDict]]:
- now_stream_id = self._device_list_id_gen.get_current_token()
-
- devices = self._get_e2e_device_keys_txn(txn, [(user_id, None)])
-
- if devices:
- user_devices = devices[user_id]
- results = []
- for device_id, device in user_devices.items():
- result = {"device_id": device_id}
-
- key_json = device.get("key_json", None)
- if key_json:
- result["keys"] = db_to_json(key_json)
-
- if "signatures" in device:
- for sig_user_id, sigs in device["signatures"].items():
- result["keys"].setdefault("signatures", {}).setdefault(
- sig_user_id, {}
- ).update(sigs)
-
- device_display_name = device.get("device_display_name", None)
- if device_display_name:
- result["device_display_name"] = device_display_name
-
- results.append(result)
-
- return now_stream_id, results
-
- return now_stream_id, []
-
async def get_users_whose_devices_changed(
self, from_key: str, user_ids: Iterable[str]
) -> Set[str]:
diff --git a/synapse/storage/databases/main/directory.py b/synapse/storage/databases/main/directory.py
index 405b5eafa5..e5060d4c46 100644
--- a/synapse/storage/databases/main/directory.py
+++ b/synapse/storage/databases/main/directory.py
@@ -159,9 +159,9 @@ class DirectoryStore(DirectoryWorkerStore):
return room_id
- def update_aliases_for_room(
+ async def update_aliases_for_room(
self, old_room_id: str, new_room_id: str, creator: Optional[str] = None,
- ):
+ ) -> None:
"""Repoint all of the aliases for a given room, to a different room.
Args:
@@ -189,6 +189,6 @@ class DirectoryStore(DirectoryWorkerStore):
txn, self.get_aliases_for_room, (new_room_id,)
)
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"_update_aliases_for_room_txn", _update_aliases_for_room_txn
)
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index af0b85e2c9..fb3b1f94de 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -14,6 +14,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import abc
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
from canonicaljson import encode_canonical_json
@@ -22,7 +23,8 @@ from twisted.enterprise.adbapi import Connection
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import make_in_list_sql_clause
+from synapse.storage.database import LoggingTransaction, make_in_list_sql_clause
+from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.iterutils import batch_iter
@@ -32,18 +34,58 @@ if TYPE_CHECKING:
class EndToEndKeyWorkerStore(SQLBaseStore):
+ def get_e2e_device_keys_for_federation_query(self, user_id: str):
+ """Get all devices (with any device keys) for a user
+
+ Returns:
+ Deferred which resolves to (stream_id, devices)
+ """
+ return self.db_pool.runInteraction(
+ "get_e2e_device_keys_for_federation_query",
+ self._get_e2e_device_keys_for_federation_query_txn,
+ user_id,
+ )
+
+ def _get_e2e_device_keys_for_federation_query_txn(
+ self, txn: LoggingTransaction, user_id: str
+ ) -> Tuple[int, List[JsonDict]]:
+ now_stream_id = self.get_device_stream_token()
+
+ devices = self._get_e2e_device_keys_txn(txn, [(user_id, None)])
+
+ if devices:
+ user_devices = devices[user_id]
+ results = []
+ for device_id, device in user_devices.items():
+ result = {"device_id": device_id}
+
+ key_json = device.get("key_json", None)
+ if key_json:
+ result["keys"] = db_to_json(key_json)
+
+ if "signatures" in device:
+ for sig_user_id, sigs in device["signatures"].items():
+ result["keys"].setdefault("signatures", {}).setdefault(
+ sig_user_id, {}
+ ).update(sigs)
+
+ device_display_name = device.get("device_display_name", None)
+ if device_display_name:
+ result["device_display_name"] = device_display_name
+
+ results.append(result)
+
+ return now_stream_id, results
+
+ return now_stream_id, []
+
@trace
- async def get_e2e_device_keys(
- self, query_list, include_all_devices=False, include_deleted_devices=False
- ):
- """Fetch a list of device keys.
+ async def get_e2e_device_keys_for_cs_api(
+ self, query_list: List[Tuple[str, Optional[str]]]
+ ) -> Dict[str, Dict[str, JsonDict]]:
+ """Fetch a list of device keys, formatted suitably for the C/S API.
Args:
query_list(list): List of pairs of user_ids and device_ids.
- include_all_devices (bool): whether to include entries for devices
- that don't have device keys
- include_deleted_devices (bool): whether to include null entries for
- devices which no longer exist (but were in the query_list).
- This option only takes effect if include_all_devices is true.
Returns:
Dict mapping from user-id to dict mapping from device_id to
key data. The key data will be a dict in the same format as the
@@ -54,11 +96,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
return {}
results = await self.db_pool.runInteraction(
- "get_e2e_device_keys",
- self._get_e2e_device_keys_txn,
- query_list,
- include_all_devices,
- include_deleted_devices,
+ "get_e2e_device_keys", self._get_e2e_device_keys_txn, query_list,
)
# Build the result structure, un-jsonify the results, and add the
@@ -541,6 +579,11 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
_get_all_user_signature_changes_for_remotes_txn,
)
+ @abc.abstractmethod
+ def get_device_stream_token(self) -> int:
+ """Get the current stream id from the _device_list_id_gen"""
+ ...
+
class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys):
diff --git a/synapse/storage/databases/main/filtering.py b/synapse/storage/databases/main/filtering.py
index 45a1760170..d2f5b9a502 100644
--- a/synapse/storage/databases/main/filtering.py
+++ b/synapse/storage/databases/main/filtering.py
@@ -17,6 +17,7 @@ from canonicaljson import encode_canonical_json
from synapse.api.errors import Codes, SynapseError
from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.types import JsonDict
from synapse.util.caches.descriptors import cached
@@ -40,7 +41,7 @@ class FilteringStore(SQLBaseStore):
return db_to_json(def_json)
- def add_user_filter(self, user_localpart, user_filter):
+ async def add_user_filter(self, user_localpart: str, user_filter: JsonDict) -> str:
def_json = encode_canonical_json(user_filter)
# Need an atomic transaction to SELECT the maximal ID so far then
@@ -71,4 +72,4 @@ class FilteringStore(SQLBaseStore):
return filter_id
- return self.db_pool.runInteraction("add_user_filter", _do_txn)
+ return await self.db_pool.runInteraction("add_user_filter", _do_txn)
diff --git a/synapse/storage/databases/main/openid.py b/synapse/storage/databases/main/openid.py
index 4db8949da7..2aac64901b 100644
--- a/synapse/storage/databases/main/openid.py
+++ b/synapse/storage/databases/main/openid.py
@@ -1,3 +1,5 @@
+from typing import Optional
+
from synapse.storage._base import SQLBaseStore
@@ -15,7 +17,9 @@ class OpenIdStore(SQLBaseStore):
desc="insert_open_id_token",
)
- def get_user_id_for_open_id_token(self, token, ts_now_ms):
+ async def get_user_id_for_open_id_token(
+ self, token: str, ts_now_ms: int
+ ) -> Optional[str]:
def get_user_id_for_token_txn(txn):
sql = (
"SELECT user_id FROM open_id_tokens"
@@ -30,6 +34,6 @@ class OpenIdStore(SQLBaseStore):
else:
return rows[0][0]
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_user_id_for_token", get_user_id_for_token_txn
)
diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py
index 8b50e00553..de37866d25 100644
--- a/synapse/storage/databases/main/profile.py
+++ b/synapse/storage/databases/main/profile.py
@@ -252,7 +252,9 @@ class ProfileStore(ProfileWorkerStore):
desc="delete_remote_profile_cache",
)
- def get_remote_profile_cache_entries_that_expire(self, last_checked):
+ async def get_remote_profile_cache_entries_that_expire(
+ self, last_checked: int
+ ) -> Dict[str, str]:
"""Get all users who haven't been checked since `last_checked`
"""
@@ -267,7 +269,7 @@ class ProfileStore(ProfileWorkerStore):
return self.db_pool.cursor_to_dict(txn)
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_remote_profile_cache_entries_that_expire",
_get_remote_profile_cache_entries_that_expire_txn,
)
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 2fb5b02d7d..0de802a86b 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -18,8 +18,6 @@ import abc
import logging
from typing import List, Tuple, Union
-from twisted.internet import defer
-
from synapse.push.baserules import list_with_base_rules
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.storage._base import SQLBaseStore, db_to_json
@@ -149,9 +147,11 @@ class PushRulesWorkerStore(
)
return {r["rule_id"]: False if r["enabled"] == 0 else True for r in results}
- def have_push_rules_changed_for_user(self, user_id, last_id):
+ async def have_push_rules_changed_for_user(
+ self, user_id: str, last_id: int
+ ) -> bool:
if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id):
- return defer.succeed(False)
+ return False
else:
def have_push_rules_changed_txn(txn):
@@ -163,7 +163,7 @@ class PushRulesWorkerStore(
(count,) = txn.fetchone()
return bool(count)
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"have_push_rules_changed", have_push_rules_changed_txn
)
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index c1d8ef5286..d5dc7a36bb 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -89,7 +89,7 @@ class RoomWorkerStore(SQLBaseStore):
allow_none=True,
)
- def get_room_with_stats(self, room_id: str):
+ async def get_room_with_stats(self, room_id: str) -> Optional[Dict[str, Any]]:
"""Retrieve room with statistics.
Args:
@@ -121,7 +121,7 @@ class RoomWorkerStore(SQLBaseStore):
res["public"] = bool(res["public"])
return res
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_room_with_stats", get_room_with_stats_txn, room_id
)
@@ -133,13 +133,17 @@ class RoomWorkerStore(SQLBaseStore):
desc="get_public_room_ids",
)
- def count_public_rooms(self, network_tuple, ignore_non_federatable):
+ async def count_public_rooms(
+ self,
+ network_tuple: Optional[ThirdPartyInstanceID],
+ ignore_non_federatable: bool,
+ ) -> int:
"""Counts the number of public rooms as tracked in the room_stats_current
and room_stats_state table.
Args:
- network_tuple (ThirdPartyInstanceID|None)
- ignore_non_federatable (bool): If true filters out non-federatable rooms
+ network_tuple
+ ignore_non_federatable: If true filters out non-federatable rooms
"""
def _count_public_rooms_txn(txn):
@@ -183,7 +187,7 @@ class RoomWorkerStore(SQLBaseStore):
txn.execute(sql, query_args)
return txn.fetchone()[0]
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"count_public_rooms", _count_public_rooms_txn
)
@@ -608,15 +612,14 @@ class RoomWorkerStore(SQLBaseStore):
return row
- def get_media_mxcs_in_room(self, room_id):
+ async def get_media_mxcs_in_room(self, room_id: str) -> Tuple[List[str], List[str]]:
"""Retrieves all the local and remote media MXC URIs in a given room
Args:
- room_id (str)
+ room_id
Returns:
- The local and remote media as a lists of tuples where the key is
- the hostname and the value is the media ID.
+ The local and remote media as a lists of the media IDs.
"""
def _get_media_mxcs_in_room_txn(txn):
@@ -632,11 +635,13 @@ class RoomWorkerStore(SQLBaseStore):
return local_media_mxcs, remote_media_mxcs
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_media_ids_in_room", _get_media_mxcs_in_room_txn
)
- def quarantine_media_ids_in_room(self, room_id, quarantined_by):
+ async def quarantine_media_ids_in_room(
+ self, room_id: str, quarantined_by: str
+ ) -> int:
"""For a room loops through all events with media and quarantines
the associated media
"""
@@ -649,7 +654,7 @@ class RoomWorkerStore(SQLBaseStore):
txn, local_mxcs, remote_mxcs, quarantined_by
)
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"quarantine_media_in_room", _quarantine_media_in_room_txn
)
@@ -712,9 +717,9 @@ class RoomWorkerStore(SQLBaseStore):
return local_media_mxcs, remote_media_mxcs
- def quarantine_media_by_id(
+ async def quarantine_media_by_id(
self, server_name: str, media_id: str, quarantined_by: str,
- ):
+ ) -> int:
"""quarantines a single local or remote media id
Args:
@@ -733,11 +738,13 @@ class RoomWorkerStore(SQLBaseStore):
txn, local_mxcs, remote_mxcs, quarantined_by
)
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"quarantine_media_by_user", _quarantine_media_by_id_txn
)
- def quarantine_media_ids_by_user(self, user_id: str, quarantined_by: str):
+ async def quarantine_media_ids_by_user(
+ self, user_id: str, quarantined_by: str
+ ) -> int:
"""quarantines all local media associated with a single user
Args:
@@ -749,7 +756,7 @@ class RoomWorkerStore(SQLBaseStore):
local_media_ids = self._get_media_ids_by_user_txn(txn, user_id)
return self._quarantine_media_txn(txn, local_media_ids, [], quarantined_by)
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"quarantine_media_by_user", _quarantine_media_by_user_txn
)
@@ -1306,8 +1313,8 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
)
self.hs.get_notifier().on_new_replication_data()
- def get_room_count(self):
- """Retrieve a list of all rooms
+ async def get_room_count(self) -> int:
+ """Retrieve the total number of rooms.
"""
def f(txn):
@@ -1316,7 +1323,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
row = txn.fetchone()
return row[0] or 0
- return self.db_pool.runInteraction("get_rooms", f)
+ return await self.db_pool.runInteraction("get_rooms", f)
async def add_event_report(
self,
diff --git a/synapse/storage/databases/main/signatures.py b/synapse/storage/databases/main/signatures.py
index be191dd870..c8c67953e4 100644
--- a/synapse/storage/databases/main/signatures.py
+++ b/synapse/storage/databases/main/signatures.py
@@ -13,9 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Dict, Iterable, List, Tuple
+
from unpaddedbase64 import encode_base64
from synapse.storage._base import SQLBaseStore
+from synapse.storage.types import Cursor
from synapse.util.caches.descriptors import cached, cachedList
@@ -29,16 +32,37 @@ class SignatureWorkerStore(SQLBaseStore):
@cachedList(
cached_method_name="get_event_reference_hash", list_name="event_ids", num_args=1
)
- def get_event_reference_hashes(self, event_ids):
+ async def get_event_reference_hashes(
+ self, event_ids: Iterable[str]
+ ) -> Dict[str, Dict[str, bytes]]:
+ """Get all hashes for given events.
+
+ Args:
+ event_ids: The event IDs to get hashes for.
+
+ Returns:
+ A mapping of event ID to a mapping of algorithm to hash.
+ """
+
def f(txn):
return {
event_id: self._get_event_reference_hashes_txn(txn, event_id)
for event_id in event_ids
}
- return self.db_pool.runInteraction("get_event_reference_hashes", f)
+ return await self.db_pool.runInteraction("get_event_reference_hashes", f)
- async def add_event_hashes(self, event_ids):
+ async def add_event_hashes(
+ self, event_ids: Iterable[str]
+ ) -> List[Tuple[str, Dict[str, str]]]:
+ """
+
+ Args:
+ event_ids: The event IDs
+
+ Returns:
+ A list of tuples of event ID and a mapping of algorithm to base-64 encoded hash.
+ """
hashes = await self.get_event_reference_hashes(event_ids)
hashes = {
e_id: {k: encode_base64(v) for k, v in h.items() if k == "sha256"}
@@ -47,13 +71,15 @@ class SignatureWorkerStore(SQLBaseStore):
return list(hashes.items())
- def _get_event_reference_hashes_txn(self, txn, event_id):
+ def _get_event_reference_hashes_txn(
+ self, txn: Cursor, event_id: str
+ ) -> Dict[str, bytes]:
"""Get all the hashes for a given PDU.
Args:
- txn (cursor):
- event_id (str): Id for the Event.
+ txn:
+ event_id: Id for the Event.
Returns:
- A dict[unicode, bytes] of algorithm -> hash.
+ A mapping of algorithm -> hash.
"""
query = (
"SELECT algorithm, hash"
diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py
index 9eef8e57c5..b89668d561 100644
--- a/synapse/storage/databases/main/ui_auth.py
+++ b/synapse/storage/databases/main/ui_auth.py
@@ -290,7 +290,7 @@ class UIAuthWorkerStore(SQLBaseStore):
class UIAuthStore(UIAuthWorkerStore):
- def delete_old_ui_auth_sessions(self, expiration_time: int):
+ async def delete_old_ui_auth_sessions(self, expiration_time: int) -> None:
"""
Remove sessions which were last used earlier than the expiration time.
@@ -299,7 +299,7 @@ class UIAuthStore(UIAuthWorkerStore):
This is an epoch time in milliseconds.
"""
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"delete_old_ui_auth_sessions",
self._delete_old_ui_auth_sessions_txn,
expiration_time,
diff --git a/synapse/storage/databases/main/user_erasure_store.py b/synapse/storage/databases/main/user_erasure_store.py
index e3547e53b3..2f7c95fc74 100644
--- a/synapse/storage/databases/main/user_erasure_store.py
+++ b/synapse/storage/databases/main/user_erasure_store.py
@@ -66,7 +66,7 @@ class UserErasureWorkerStore(SQLBaseStore):
class UserErasureStore(UserErasureWorkerStore):
- def mark_user_erased(self, user_id: str) -> None:
+ async def mark_user_erased(self, user_id: str) -> None:
"""Indicate that user_id wishes their message history to be erased.
Args:
@@ -84,9 +84,9 @@ class UserErasureStore(UserErasureWorkerStore):
self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,))
- return self.db_pool.runInteraction("mark_user_erased", f)
+ await self.db_pool.runInteraction("mark_user_erased", f)
- def mark_user_not_erased(self, user_id: str) -> None:
+ async def mark_user_not_erased(self, user_id: str) -> None:
"""Indicate that user_id is no longer erased.
Args:
@@ -106,4 +106,4 @@ class UserErasureStore(UserErasureWorkerStore):
self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,))
- return self.db_pool.runInteraction("mark_user_not_erased", f)
+ await self.db_pool.runInteraction("mark_user_not_erased", f)
|