summary refs log tree commit diff
path: root/synapse/storage/databases/main
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main')
-rw-r--r--synapse/storage/databases/main/__init__.py17
-rw-r--r--synapse/storage/databases/main/account_data.py93
-rw-r--r--synapse/storage/databases/main/appservice.py10
-rw-r--r--synapse/storage/databases/main/cache.py9
-rw-r--r--synapse/storage/databases/main/censor_events.py13
-rw-r--r--synapse/storage/databases/main/client_ips.py22
-rw-r--r--synapse/storage/databases/main/deviceinbox.py11
-rw-r--r--synapse/storage/databases/main/devices.py56
-rw-r--r--synapse/storage/databases/main/directory.py10
-rw-r--r--synapse/storage/databases/main/e2e_room_keys.py237
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py211
-rw-r--r--synapse/storage/databases/main/event_federation.py32
-rw-r--r--synapse/storage/databases/main/event_push_actions.py273
-rw-r--r--synapse/storage/databases/main/events.py148
-rw-r--r--synapse/storage/databases/main/events_bg_updates.py77
-rw-r--r--synapse/storage/databases/main/events_worker.py4
-rw-r--r--synapse/storage/databases/main/filtering.py4
-rw-r--r--synapse/storage/databases/main/group_server.py10
-rw-r--r--synapse/storage/databases/main/lock.py14
-rw-r--r--synapse/storage/databases/main/media_repository.py3
-rw-r--r--synapse/storage/databases/main/metrics.py27
-rw-r--r--synapse/storage/databases/main/monthly_active_users.py27
-rw-r--r--synapse/storage/databases/main/presence.py11
-rw-r--r--synapse/storage/databases/main/push_rule.py9
-rw-r--r--synapse/storage/databases/main/pusher.py29
-rw-r--r--synapse/storage/databases/main/receipts.py112
-rw-r--r--synapse/storage/databases/main/registration.py24
-rw-r--r--synapse/storage/databases/main/relations.py42
-rw-r--r--synapse/storage/databases/main/room.py226
-rw-r--r--synapse/storage/databases/main/roommember.py23
-rw-r--r--synapse/storage/databases/main/search.py36
-rw-r--r--synapse/storage/databases/main/state.py47
-rw-r--r--synapse/storage/databases/main/state_deltas.py4
-rw-r--r--synapse/storage/databases/main/stats.py103
-rw-r--r--synapse/storage/databases/main/stream.py54
-rw-r--r--synapse/storage/databases/main/tags.py22
-rw-r--r--synapse/storage/databases/main/transactions.py62
-rw-r--r--synapse/storage/databases/main/ui_auth.py15
-rw-r--r--synapse/storage/databases/main/user_directory.py11
39 files changed, 1396 insertions, 742 deletions
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py

index 9ff2d8d8c3..f024761ba7 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py
@@ -18,7 +18,7 @@ import logging from typing import TYPE_CHECKING, List, Optional, Tuple from synapse.config.homeserver import HomeServerConfig -from synapse.storage.database import DatabasePool +from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.storage.databases.main.stats import UserSortOrder from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import ( @@ -68,7 +68,7 @@ from .session import SessionStore from .signatures import SignatureStore from .state import StateStore from .stats import StatsStore -from .stream import StreamStore +from .stream import StreamWorkerStore from .tags import TagsStore from .transactions import TransactionWorkerStore from .ui_auth import UIAuthStore @@ -87,7 +87,7 @@ class DataStore( RoomStore, RoomBatchStore, RegistrationStore, - StreamStore, + StreamWorkerStore, ProfileStore, PresenceStore, TransactionWorkerStore, @@ -129,7 +129,12 @@ class DataStore( LockStore, SessionStore, ): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): self.hs = hs self._clock = hs.get_clock() self.database_engine = database.engine @@ -143,11 +148,7 @@ class DataStore( ("device_lists_outbound_pokes", "stream_id"), ], ) - self._cross_signing_id_gen = StreamIdGenerator( - db_conn, "e2e_cross_signing_keys", "stream_id" - ) - self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id") self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id") self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id") self._group_updates_id_gen = StreamIdGenerator( diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index f8bec266ac..32a553fdd7 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py
@@ -14,15 +14,25 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple, cast from synapse.api.constants import AccountDataTypes from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream -from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.storage.database import DatabasePool +from synapse.storage._base import db_to_json +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) +from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.engines import PostgresEngine -from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator +from synapse.storage.util.id_generators import ( + AbstractStreamIdGenerator, + AbstractStreamIdTracker, + MultiWriterIdGenerator, + StreamIdGenerator, +) from synapse.types import JsonDict from synapse.util import json_encoder from synapse.util.caches.descriptors import cached @@ -34,13 +44,19 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class AccountDataWorkerStore(SQLBaseStore): - """This is an abstract base class where subclasses must implement - `get_max_account_data_stream_id` which can be called in the initializer. - """ +class AccountDataWorkerStore(CacheInvalidationWorkerStore): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): + super().__init__(database, db_conn, hs) - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): - self._instance_name = hs.get_instance_name() + # `_can_write_to_account_data` indicates whether the current worker is allowed + # to write account data. A value of `True` implies that `_account_data_id_gen` + # is an `AbstractStreamIdGenerator` and not just a tracker. + self._account_data_id_gen: AbstractStreamIdTracker if isinstance(database.engine, PostgresEngine): self._can_write_to_account_data = ( @@ -61,8 +77,6 @@ class AccountDataWorkerStore(SQLBaseStore): writers=hs.config.worker.writers.account_data, ) else: - self._can_write_to_account_data = True - # We shouldn't be running in worker mode with SQLite, but its useful # to support it for unit tests. # @@ -70,7 +84,8 @@ class AccountDataWorkerStore(SQLBaseStore): # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets # updated over replication. (Multiple writers are not supported for # SQLite). - if hs.get_instance_name() in hs.config.worker.writers.account_data: + if self._instance_name in hs.config.worker.writers.account_data: + self._can_write_to_account_data = True self._account_data_id_gen = StreamIdGenerator( db_conn, "room_account_data", @@ -90,8 +105,6 @@ class AccountDataWorkerStore(SQLBaseStore): "AccountDataAndTagsChangeCache", account_max ) - super().__init__(database, db_conn, hs) - def get_max_account_data_stream_id(self) -> int: """Get the current max stream ID for account data stream @@ -113,7 +126,9 @@ class AccountDataWorkerStore(SQLBaseStore): room_id string to per room account_data dicts. """ - def get_account_data_for_user_txn(txn): + def get_account_data_for_user_txn( + txn: LoggingTransaction, + ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]: rows = self.db_pool.simple_select_list_txn( txn, "account_data", @@ -132,7 +147,7 @@ class AccountDataWorkerStore(SQLBaseStore): ["room_id", "account_data_type", "content"], ) - by_room = {} + by_room: Dict[str, Dict[str, JsonDict]] = {} for row in rows: room_data = by_room.setdefault(row["room_id"], {}) room_data[row["account_data_type"]] = db_to_json(row["content"]) @@ -177,7 +192,9 @@ class AccountDataWorkerStore(SQLBaseStore): A dict of the room account_data """ - def get_account_data_for_room_txn(txn): + def get_account_data_for_room_txn( + txn: LoggingTransaction, + ) -> Dict[str, JsonDict]: rows = self.db_pool.simple_select_list_txn( txn, "room_account_data", @@ -207,7 +224,9 @@ class AccountDataWorkerStore(SQLBaseStore): The room account_data for that type, or None if there isn't any set. """ - def get_account_data_for_room_and_type_txn(txn): + def get_account_data_for_room_and_type_txn( + txn: LoggingTransaction, + ) -> Optional[JsonDict]: content_json = self.db_pool.simple_select_one_onecol_txn( txn, table="room_account_data", @@ -243,14 +262,16 @@ class AccountDataWorkerStore(SQLBaseStore): if last_id == current_id: return [] - def get_updated_global_account_data_txn(txn): + def get_updated_global_account_data_txn( + txn: LoggingTransaction, + ) -> List[Tuple[int, str, str]]: sql = ( "SELECT stream_id, user_id, account_data_type" " FROM account_data WHERE ? < stream_id AND stream_id <= ?" " ORDER BY stream_id ASC LIMIT ?" ) txn.execute(sql, (last_id, current_id, limit)) - return txn.fetchall() + return cast(List[Tuple[int, str, str]], txn.fetchall()) return await self.db_pool.runInteraction( "get_updated_global_account_data", get_updated_global_account_data_txn @@ -273,14 +294,16 @@ class AccountDataWorkerStore(SQLBaseStore): if last_id == current_id: return [] - def get_updated_room_account_data_txn(txn): + def get_updated_room_account_data_txn( + txn: LoggingTransaction, + ) -> List[Tuple[int, str, str, str]]: sql = ( "SELECT stream_id, user_id, room_id, account_data_type" " FROM room_account_data WHERE ? < stream_id AND stream_id <= ?" " ORDER BY stream_id ASC LIMIT ?" ) txn.execute(sql, (last_id, current_id, limit)) - return txn.fetchall() + return cast(List[Tuple[int, str, str, str]], txn.fetchall()) return await self.db_pool.runInteraction( "get_updated_room_account_data", get_updated_room_account_data_txn @@ -299,7 +322,9 @@ class AccountDataWorkerStore(SQLBaseStore): mapping from room_id string to per room account_data dicts. """ - def get_updated_account_data_for_user_txn(txn): + def get_updated_account_data_for_user_txn( + txn: LoggingTransaction, + ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]: sql = ( "SELECT account_data_type, content FROM account_data" " WHERE user_id = ? AND stream_id > ?" @@ -316,7 +341,7 @@ class AccountDataWorkerStore(SQLBaseStore): txn.execute(sql, (user_id, stream_id)) - account_data_by_room = {} + account_data_by_room: Dict[str, Dict[str, JsonDict]] = {} for row in txn: room_account_data = account_data_by_room.setdefault(row[0], {}) room_account_data[row[1]] = db_to_json(row[2]) @@ -353,12 +378,15 @@ class AccountDataWorkerStore(SQLBaseStore): ) ) - def process_replication_rows(self, stream_name, instance_name, token, rows): + def process_replication_rows( + self, + stream_name: str, + instance_name: str, + token: int, + rows: Iterable[Any], + ) -> None: if stream_name == TagAccountDataStream.NAME: self._account_data_id_gen.advance(instance_name, token) - for row in rows: - self.get_tags_for_user.invalidate((row.user_id,)) - self._account_data_stream_cache.entity_has_changed(row.user_id, token) elif stream_name == AccountDataStream.NAME: self._account_data_id_gen.advance(instance_name, token) for row in rows: @@ -372,7 +400,8 @@ class AccountDataWorkerStore(SQLBaseStore): (row.user_id, row.room_id, row.data_type) ) self._account_data_stream_cache.entity_has_changed(row.user_id, token) - return super().process_replication_rows(stream_name, instance_name, token, rows) + + super().process_replication_rows(stream_name, instance_name, token, rows) async def add_account_data_to_room( self, user_id: str, room_id: str, account_data_type: str, content: JsonDict @@ -389,6 +418,7 @@ class AccountDataWorkerStore(SQLBaseStore): The maximum stream ID. """ assert self._can_write_to_account_data + assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator) content_json = json_encoder.encode(content) @@ -431,6 +461,7 @@ class AccountDataWorkerStore(SQLBaseStore): The maximum stream ID. """ assert self._can_write_to_account_data + assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator) async with self._account_data_id_gen.get_next() as next_id: await self.db_pool.runInteraction( @@ -452,7 +483,7 @@ class AccountDataWorkerStore(SQLBaseStore): def _add_account_data_for_user( self, - txn, + txn: LoggingTransaction, next_id: int, user_id: str, account_data_type: str, diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 4a883dc166..92c95a41d7 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py
@@ -24,9 +24,8 @@ from synapse.appservice import ( from synapse.config.appservice import load_appservices from synapse.events import EventBase from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.storage.database import DatabasePool +from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.storage.databases.main.events_worker import EventsWorkerStore -from synapse.storage.types import Connection from synapse.types import JsonDict from synapse.util import json_encoder @@ -58,7 +57,12 @@ def _make_exclusive_regex( class ApplicationServiceWorkerStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): self.services_cache = load_appservices( hs.hostname, hs.config.appservice.app_service_config_files ) diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 36e8422fc6..0024348067 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py
@@ -25,7 +25,7 @@ from synapse.replication.tcp.streams.events import ( EventsStreamEventRow, ) from synapse.storage._base import SQLBaseStore -from synapse.storage.database import DatabasePool +from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.storage.engines import PostgresEngine from synapse.util.iterutils import batch_iter @@ -41,7 +41,12 @@ CURRENT_STATE_CACHE_NAME = "cs_cache_fake" class CacheInvalidationWorkerStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) self._instance_name = hs.get_instance_name() diff --git a/synapse/storage/databases/main/censor_events.py b/synapse/storage/databases/main/censor_events.py
index 0f56e10220..fd3fc298b3 100644 --- a/synapse/storage/databases/main/censor_events.py +++ b/synapse/storage/databases/main/censor_events.py
@@ -18,7 +18,11 @@ from typing import TYPE_CHECKING, Optional from synapse.events.utils import prune_event_dict from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore -from synapse.storage.database import DatabasePool, LoggingTransaction +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.util import json_encoder @@ -31,7 +35,12 @@ logger = logging.getLogger(__name__) class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) if ( diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index a6fd9f2636..f3881671fd 100644 --- a/synapse/storage/databases/main/client_ips.py +++ b/synapse/storage/databases/main/client_ips.py
@@ -26,7 +26,6 @@ from synapse.storage.database import ( make_tuple_comparison_clause, ) from synapse.storage.databases.main.monthly_active_users import MonthlyActiveUsersStore -from synapse.storage.types import Connection from synapse.types import JsonDict, UserID from synapse.util.caches.lrucache import LruCache @@ -65,7 +64,12 @@ class LastConnectionInfo(TypedDict): class ClientIpBackgroundUpdateStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_index_update( @@ -394,7 +398,12 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore): class ClientIpWorkerStore(ClientIpBackgroundUpdateStore): - def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) self.user_ips_max_age = hs.config.server.user_ips_max_age @@ -532,7 +541,12 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore): class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore): - def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): # (user_id, access_token, ip,) -> last_seen self.client_ip_last_seen = LruCache[Tuple[str, str, str], int]( diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index ab8766c75b..3682cb6a81 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py
@@ -14,7 +14,7 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, List, Optional, Tuple +from typing import TYPE_CHECKING, List, Optional, Tuple, cast from synapse.logging import issue9533_logger from synapse.logging.opentracing import log_kv, set_tag, trace @@ -601,7 +601,12 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore): REMOVE_HIDDEN_DEVICES = "remove_hidden_devices_from_device_inbox" REMOVE_DEAD_DEVICES_FROM_INBOX = "remove_dead_devices_from_device_inbox" - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_index_update( @@ -668,7 +673,7 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore): # There's a type mismatch here between how we want to type the row and # what fetchone says it returns, but we silence it because we know that # res can't be None. - res: Tuple[Optional[int]] = txn.fetchone() # type: ignore[assignment] + res = cast(Tuple[Optional[int]], txn.fetchone()) if res[0] is None: # this can only happen if the `device_inbox` table is empty, in which # case we have no work to do. diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index d5a4a661cd..273adb61fd 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py
@@ -38,6 +38,7 @@ from synapse.metrics.background_process_metrics import wrap_as_background_proces from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import ( DatabasePool, + LoggingDatabaseConnection, LoggingTransaction, make_tuple_comparison_clause, ) @@ -61,7 +62,12 @@ BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes" class DeviceWorkerStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) if hs.config.worker.run_background_tasks: @@ -101,7 +107,9 @@ class DeviceWorkerStore(SQLBaseStore): "count_devices_by_users", count_devices_by_users_txn, user_ids ) - async def get_device(self, user_id: str, device_id: str) -> Dict[str, Any]: + async def get_device( + self, user_id: str, device_id: str + ) -> Optional[Dict[str, Any]]: """Retrieve a device. Only returns devices that are not marked as hidden. @@ -109,15 +117,35 @@ class DeviceWorkerStore(SQLBaseStore): user_id: The ID of the user which owns the device device_id: The ID of the device to retrieve Returns: - A dict containing the device information - Raises: - StoreError: if the device is not found + A dict containing the device information, or `None` if the device does not + exist. """ return await self.db_pool.simple_select_one( table="devices", keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, retcols=("user_id", "device_id", "display_name"), desc="get_device", + allow_none=True, + ) + + async def get_device_opt( + self, user_id: str, device_id: str + ) -> Optional[Dict[str, Any]]: + """Retrieve a device. Only returns devices that are not marked as + hidden. + + Args: + user_id: The ID of the user which owns the device + device_id: The ID of the device to retrieve + Returns: + A dict containing the device information, or None if the device does not exist. + """ + return await self.db_pool.simple_select_one( + table="devices", + keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, + retcols=("user_id", "device_id", "display_name"), + desc="get_device", + allow_none=True, ) async def get_devices_by_user(self, user_id: str) -> Dict[str, Dict[str, str]]: @@ -274,7 +302,9 @@ class DeviceWorkerStore(SQLBaseStore): # add the updated cross-signing keys to the results list for user_id, result in cross_signing_keys_by_user.items(): result["user_id"] = user_id - # FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec + results.append(("m.signing_key_update", result)) + # also send the unstable version + # FIXME: remove this when enough servers have upgraded results.append(("org.matrix.signing_key_update", result)) return now_stream_id, results @@ -949,7 +979,12 @@ class DeviceWorkerStore(SQLBaseStore): class DeviceBackgroundUpdateStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_index_update( @@ -1081,7 +1116,12 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) # Map of (user_id, device_id) -> bool. If there is an entry that implies diff --git a/synapse/storage/databases/main/directory.py b/synapse/storage/databases/main/directory.py
index a3442814d7..f76c6121e8 100644 --- a/synapse/storage/databases/main/directory.py +++ b/synapse/storage/databases/main/directory.py
@@ -12,16 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. -from collections import namedtuple from typing import Iterable, List, Optional, Tuple +import attr + from synapse.api.errors import SynapseError from synapse.storage.database import LoggingTransaction from synapse.storage.databases.main import CacheInvalidationWorkerStore from synapse.types import RoomAlias from synapse.util.caches.descriptors import cached -RoomAliasMapping = namedtuple("RoomAliasMapping", ("room_id", "room_alias", "servers")) + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class RoomAliasMapping: + room_id: str + room_alias: str + servers: List[str] class DirectoryWorkerStore(CacheInvalidationWorkerStore): diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py
index b15fb71e62..0cb48b9dd7 100644 --- a/synapse/storage/databases/main/e2e_room_keys.py +++ b/synapse/storage/databases/main/e2e_room_keys.py
@@ -13,35 +13,71 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +from typing import Dict, Iterable, Mapping, Optional, Tuple, cast + +from typing_extensions import Literal, TypedDict from synapse.api.errors import StoreError from synapse.logging.opentracing import log_kv, trace from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.storage.database import LoggingTransaction +from synapse.types import JsonDict, JsonSerializable from synapse.util import json_encoder +class RoomKey(TypedDict): + """`KeyBackupData` in the Matrix spec. + + https://spec.matrix.org/v1.1/client-server-api/#get_matrixclientv3room_keyskeysroomidsessionid + """ + + first_message_index: int + forwarded_count: int + is_verified: bool + session_data: JsonSerializable + + class EndToEndRoomKeyStore(SQLBaseStore): + """The store for end to end room key backups. + + See https://spec.matrix.org/v1.1/client-server-api/#server-side-key-backups + + As per the spec, backups are identified by an opaque version string. Internally, + version identifiers are assigned using incrementing integers. Non-numeric version + strings are treated as if they do not exist, since we would have never issued them. + """ + async def update_e2e_room_key( - self, user_id, version, room_id, session_id, room_key - ): + self, + user_id: str, + version: str, + room_id: str, + session_id: str, + room_key: RoomKey, + ) -> None: """Replaces the encrypted E2E room key for a given session in a given backup Args: - user_id(str): the user whose backup we're setting - version(str): the version ID of the backup we're updating - room_id(str): the ID of the room whose keys we're setting - session_id(str): the session whose room_key we're setting - room_key(dict): the room_key being set + user_id: the user whose backup we're setting + version: the version ID of the backup we're updating + room_id: the ID of the room whose keys we're setting + session_id: the session whose room_key we're setting + room_key: the room_key being set Raises: StoreError """ + try: + version_int = int(version) + except ValueError: + # Our versions are all ints so if we can't convert it to an integer, + # it doesn't exist. + raise StoreError(404, "No backup with that version exists") await self.db_pool.simple_update_one( table="e2e_room_keys", keyvalues={ "user_id": user_id, - "version": version, + "version": version_int, "room_id": room_id, "session_id": session_id, }, @@ -54,22 +90,29 @@ class EndToEndRoomKeyStore(SQLBaseStore): desc="update_e2e_room_key", ) - async def add_e2e_room_keys(self, user_id, version, room_keys): + async def add_e2e_room_keys( + self, user_id: str, version: str, room_keys: Iterable[Tuple[str, str, RoomKey]] + ) -> None: """Bulk add room keys to a given backup. Args: - user_id (str): the user whose backup we're adding to - version (str): the version ID of the backup for the set of keys we're adding to - room_keys (iterable[(str, str, dict)]): the keys to add, in the form - (roomID, sessionID, keyData) + user_id: the user whose backup we're adding to + version: the version ID of the backup for the set of keys we're adding to + room_keys: the keys to add, in the form (roomID, sessionID, keyData) """ + try: + version_int = int(version) + except ValueError: + # Our versions are all ints so if we can't convert it to an integer, + # it doesn't exist. + raise StoreError(404, "No backup with that version exists") values = [] for (room_id, session_id, room_key) in room_keys: values.append( { "user_id": user_id, - "version": version, + "version": version_int, "room_id": room_id, "session_id": session_id, "first_message_index": room_key["first_message_index"], @@ -92,31 +135,39 @@ class EndToEndRoomKeyStore(SQLBaseStore): ) @trace - async def get_e2e_room_keys(self, user_id, version, room_id=None, session_id=None): + async def get_e2e_room_keys( + self, + user_id: str, + version: str, + room_id: Optional[str] = None, + session_id: Optional[str] = None, + ) -> Dict[ + Literal["rooms"], Dict[str, Dict[Literal["sessions"], Dict[str, RoomKey]]] + ]: """Bulk get the E2E room keys for a given backup, optionally filtered to a given room, or a given session. Args: - user_id (str): the user whose backup we're querying - version (str): the version ID of the backup for the set of keys we're querying - room_id (str): Optional. the ID of the room whose keys we're querying, if any. + user_id: the user whose backup we're querying + version: the version ID of the backup for the set of keys we're querying + room_id: Optional. the ID of the room whose keys we're querying, if any. If not specified, we return the keys for all the rooms in the backup. - session_id (str): Optional. the session whose room_key we're querying, if any. + session_id: Optional. the session whose room_key we're querying, if any. If specified, we also require the room_id to be specified. If not specified, we return all the keys in this version of the backup (or for the specified room) Returns: - A list of dicts giving the session_data and message metadata for - these room keys. + A dict giving the session_data and message metadata for these room keys. + `{"rooms": {room_id: {"sessions": {session_id: room_key}}}}` """ try: - version = int(version) + version_int = int(version) except ValueError: return {"rooms": {}} - keyvalues = {"user_id": user_id, "version": version} + keyvalues = {"user_id": user_id, "version": version_int} if room_id: keyvalues["room_id"] = room_id if session_id: @@ -137,7 +188,9 @@ class EndToEndRoomKeyStore(SQLBaseStore): desc="get_e2e_room_keys", ) - sessions = {"rooms": {}} + sessions: Dict[ + Literal["rooms"], Dict[str, Dict[Literal["sessions"], Dict[str, RoomKey]]] + ] = {"rooms": {}} for row in rows: room_entry = sessions["rooms"].setdefault(row["room_id"], {"sessions": {}}) room_entry["sessions"][row["session_id"]] = { @@ -150,7 +203,12 @@ class EndToEndRoomKeyStore(SQLBaseStore): return sessions - async def get_e2e_room_keys_multi(self, user_id, version, room_keys): + async def get_e2e_room_keys_multi( + self, + user_id: str, + version: str, + room_keys: Mapping[str, Mapping[Literal["sessions"], Iterable[str]]], + ) -> Dict[str, Dict[str, RoomKey]]: """Get multiple room keys at a time. The difference between this function and get_e2e_room_keys is that this function can be used to retrieve multiple specific keys at a time, whereas get_e2e_room_keys is used for @@ -158,26 +216,36 @@ class EndToEndRoomKeyStore(SQLBaseStore): specific key. Args: - user_id (str): the user whose backup we're querying - version (str): the version ID of the backup we're querying about - room_keys (dict[str, dict[str, iterable[str]]]): a map from - room ID -> {"session": [session ids]} indicating the session IDs - that we want to query + user_id: the user whose backup we're querying + version: the version ID of the backup we're querying about + room_keys: a map from room ID -> {"sessions": [session ids]} + indicating the session IDs that we want to query Returns: - dict[str, dict[str, dict]]: a map of room IDs to session IDs to room key + A map of room IDs to session IDs to room key """ + try: + version_int = int(version) + except ValueError: + # Our versions are all ints so if we can't convert it to an integer, + # it doesn't exist. + return {} return await self.db_pool.runInteraction( "get_e2e_room_keys_multi", self._get_e2e_room_keys_multi_txn, user_id, - version, + version_int, room_keys, ) @staticmethod - def _get_e2e_room_keys_multi_txn(txn, user_id, version, room_keys): + def _get_e2e_room_keys_multi_txn( + txn: LoggingTransaction, + user_id: str, + version: int, + room_keys: Mapping[str, Mapping[Literal["sessions"], Iterable[str]]], + ) -> Dict[str, Dict[str, RoomKey]]: if not room_keys: return {} @@ -209,7 +277,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): txn.execute(sql, params) - ret = {} + ret: Dict[str, Dict[str, RoomKey]] = {} for row in txn: room_id = row[0] @@ -231,36 +299,49 @@ class EndToEndRoomKeyStore(SQLBaseStore): user_id: the user whose backup we're querying version: the version ID of the backup we're querying about """ + try: + version_int = int(version) + except ValueError: + # Our versions are all ints so if we can't convert it to an integer, + # it doesn't exist. + return 0 return await self.db_pool.simple_select_one_onecol( table="e2e_room_keys", - keyvalues={"user_id": user_id, "version": version}, + keyvalues={"user_id": user_id, "version": version_int}, retcol="COUNT(*)", desc="count_e2e_room_keys", ) @trace async def delete_e2e_room_keys( - self, user_id, version, room_id=None, session_id=None - ): + self, + user_id: str, + version: str, + room_id: Optional[str] = None, + session_id: Optional[str] = None, + ) -> None: """Bulk delete the E2E room keys for a given backup, optionally filtered to a given room or a given session. Args: - user_id(str): the user whose backup we're deleting from - version(str): the version ID of the backup for the set of keys we're deleting - room_id(str): Optional. the ID of the room whose keys we're deleting, if any. + user_id: the user whose backup we're deleting from + version: the version ID of the backup for the set of keys we're deleting + room_id: Optional. the ID of the room whose keys we're deleting, if any. If not specified, we delete the keys for all the rooms in the backup. - session_id(str): Optional. the session whose room_key we're querying, if any. + session_id: Optional. the session whose room_key we're querying, if any. If specified, we also require the room_id to be specified. If not specified, we delete all the keys in this version of the backup (or for the specified room) - - Returns: - The deletion transaction """ + try: + version_int = int(version) + except ValueError: + # Our versions are all ints so if we can't convert it to an integer, + # it doesn't exist. + return - keyvalues = {"user_id": user_id, "version": int(version)} + keyvalues = {"user_id": user_id, "version": version_int} if room_id: keyvalues["room_id"] = room_id if session_id: @@ -271,23 +352,27 @@ class EndToEndRoomKeyStore(SQLBaseStore): ) @staticmethod - def _get_current_version(txn, user_id): + def _get_current_version(txn: LoggingTransaction, user_id: str) -> int: txn.execute( "SELECT MAX(version) FROM e2e_room_keys_versions " "WHERE user_id=? AND deleted=0", (user_id,), ) - row = txn.fetchone() - if not row: + # `SELECT MAX() FROM ...` will always return 1 row. The value in that row will + # be `NULL` when there are no available versions. + row = cast(Tuple[Optional[int]], txn.fetchone()) + if row[0] is None: raise StoreError(404, "No current backup version") return row[0] - async def get_e2e_room_keys_version_info(self, user_id, version=None): + async def get_e2e_room_keys_version_info( + self, user_id: str, version: Optional[str] = None + ) -> JsonDict: """Get info metadata about a version of our room_keys backup. Args: - user_id(str): the user whose backup we're querying - version(str): Optional. the version ID of the backup we're querying about + user_id: the user whose backup we're querying + version: Optional. the version ID of the backup we're querying about If missing, we return the information about the current version. Raises: StoreError: with code 404 if there are no e2e_room_keys_versions present @@ -300,7 +385,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): etag(int): tag of the keys in the backup """ - def _get_e2e_room_keys_version_info_txn(txn): + def _get_e2e_room_keys_version_info_txn(txn: LoggingTransaction) -> JsonDict: if version is None: this_version = self._get_current_version(txn, user_id) else: @@ -309,14 +394,16 @@ class EndToEndRoomKeyStore(SQLBaseStore): except ValueError: # Our versions are all ints so if we can't convert it to an integer, # it isn't there. - raise StoreError(404, "No row found") + raise StoreError(404, "No backup with that version exists") result = self.db_pool.simple_select_one_txn( txn, table="e2e_room_keys_versions", keyvalues={"user_id": user_id, "version": this_version, "deleted": 0}, retcols=("version", "algorithm", "auth_data", "etag"), + allow_none=False, ) + assert result is not None # see comment on `simple_select_one_txn` result["auth_data"] = db_to_json(result["auth_data"]) result["version"] = str(result["version"]) if result["etag"] is None: @@ -328,28 +415,28 @@ class EndToEndRoomKeyStore(SQLBaseStore): ) @trace - async def create_e2e_room_keys_version(self, user_id: str, info: dict) -> str: + async def create_e2e_room_keys_version(self, user_id: str, info: JsonDict) -> str: """Atomically creates a new version of this user's e2e_room_keys store with the given version info. Args: - user_id(str): the user whose backup we're creating a version - info(dict): the info about the backup version to be created + user_id: the user whose backup we're creating a version + info: the info about the backup version to be created Returns: The newly created version ID """ - def _create_e2e_room_keys_version_txn(txn): + def _create_e2e_room_keys_version_txn(txn: LoggingTransaction) -> str: txn.execute( "SELECT MAX(version) FROM e2e_room_keys_versions WHERE user_id=?", (user_id,), ) - current_version = txn.fetchone()[0] + current_version = cast(Tuple[Optional[int]], txn.fetchone())[0] if current_version is None: - current_version = "0" + current_version = 0 - new_version = str(int(current_version) + 1) + new_version = current_version + 1 self.db_pool.simple_insert_txn( txn, @@ -362,7 +449,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): }, ) - return new_version + return str(new_version) return await self.db_pool.runInteraction( "create_e2e_room_keys_version_txn", _create_e2e_room_keys_version_txn @@ -373,7 +460,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): self, user_id: str, version: str, - info: Optional[dict] = None, + info: Optional[JsonDict] = None, version_etag: Optional[int] = None, ) -> None: """Update a given backup version @@ -386,7 +473,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): version_etag: etag of the keys in the backup. If None, then the etag is not updated. """ - updatevalues = {} + updatevalues: Dict[str, object] = {} if info is not None and "auth_data" in info: updatevalues["auth_data"] = json_encoder.encode(info["auth_data"]) @@ -394,9 +481,16 @@ class EndToEndRoomKeyStore(SQLBaseStore): updatevalues["etag"] = version_etag if updatevalues: - await self.db_pool.simple_update( + try: + version_int = int(version) + except ValueError: + # Our versions are all ints so if we can't convert it to an integer, + # it doesn't exist. + raise StoreError(404, "No backup with that version exists") + + await self.db_pool.simple_update_one( table="e2e_room_keys_versions", - keyvalues={"user_id": user_id, "version": version}, + keyvalues={"user_id": user_id, "version": version_int}, updatevalues=updatevalues, desc="update_e2e_room_keys_version", ) @@ -417,13 +511,16 @@ class EndToEndRoomKeyStore(SQLBaseStore): or if the version requested doesn't exist. """ - def _delete_e2e_room_keys_version_txn(txn): + def _delete_e2e_room_keys_version_txn(txn: LoggingTransaction) -> None: if version is None: this_version = self._get_current_version(txn, user_id) - if this_version is None: - raise StoreError(404, "No current backup version") else: - this_version = version + try: + this_version = int(version) + except ValueError: + # Our versions are all ints so if we can't convert it to an integer, + # it isn't there. + raise StoreError(404, "No backup with that version exists") self.db_pool.simple_delete_txn( txn, diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index b06c1dc45b..57b5ffbad3 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -14,19 +14,32 @@ # See the License for the specific language governing permissions and # limitations under the License. import abc -from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple +from typing import ( + TYPE_CHECKING, + Collection, + Dict, + Iterable, + List, + Optional, + Tuple, + cast, +) import attr from canonicaljson import encode_canonical_json -from twisted.enterprise.adbapi import Connection - from synapse.api.constants import DeviceKeyAlgorithms from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.storage.database import DatabasePool, make_in_list_sql_clause +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, + make_in_list_sql_clause, +) +from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.engines import PostgresEngine -from synapse.storage.types import Cursor +from synapse.storage.util.id_generators import StreamIdGenerator from synapse.types import JsonDict from synapse.util import json_encoder from synapse.util.caches.descriptors import cached, cachedList @@ -50,7 +63,12 @@ class DeviceKeyLookupResult: class EndToEndKeyBackgroundStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_index_update( @@ -62,8 +80,13 @@ class EndToEndKeyBackgroundStore(SQLBaseStore): ) -class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore): - def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): +class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorkerStore): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) self._allow_device_name_lookup_over_federation = ( @@ -124,7 +147,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore): # Build the result structure, un-jsonify the results, and add the # "unsigned" section - rv = {} + rv: Dict[str, Dict[str, JsonDict]] = {} for user_id, device_keys in results.items(): rv[user_id] = {} for device_id, device_info in device_keys.items(): @@ -195,6 +218,10 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore): # add each cross-signing signature to the correct device in the result dict. for (user_id, key_id, device_id, signature) in cross_sigs_result: target_device_result = result[user_id][device_id] + # We've only looked up cross-signatures for non-deleted devices with key + # data. + assert target_device_result is not None + assert target_device_result.keys is not None target_device_signatures = target_device_result.keys.setdefault( "signatures", {} ) @@ -207,7 +234,11 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore): return result def _get_e2e_device_keys_txn( - self, txn, query_list, include_all_devices=False, include_deleted_devices=False + self, + txn: LoggingTransaction, + query_list: Collection[Tuple[str, str]], + include_all_devices: bool = False, + include_deleted_devices: bool = False, ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]: """Get information on devices from the database @@ -263,7 +294,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore): return result def _get_e2e_cross_signing_signatures_for_devices_txn( - self, txn: Cursor, device_query: Iterable[Tuple[str, str]] + self, txn: LoggingTransaction, device_query: Iterable[Tuple[str, str]] ) -> List[Tuple[str, str, str, str]]: """Get cross-signing signatures for a given list of devices @@ -289,7 +320,17 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore): ) txn.execute(signature_sql, signature_query_params) - return txn.fetchall() + return cast( + List[ + Tuple[ + str, + str, + str, + str, + ] + ], + txn.fetchall(), + ) async def get_e2e_one_time_keys( self, user_id: str, device_id: str, key_ids: List[str] @@ -335,7 +376,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore): new_keys: keys to add - each a tuple of (algorithm, key_id, key json) """ - def _add_e2e_one_time_keys(txn): + def _add_e2e_one_time_keys(txn: LoggingTransaction) -> None: set_tag("user_id", user_id) set_tag("device_id", device_id) set_tag("new_keys", new_keys) @@ -375,7 +416,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore): A mapping from algorithm to number of keys for that algorithm. """ - def _count_e2e_one_time_keys(txn): + def _count_e2e_one_time_keys(txn: LoggingTransaction) -> Dict[str, int]: sql = ( "SELECT algorithm, COUNT(key_id) FROM e2e_one_time_keys_json" " WHERE user_id = ? AND device_id = ?" @@ -421,7 +462,11 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore): ) def _set_e2e_fallback_keys_txn( - self, txn: Connection, user_id: str, device_id: str, fallback_keys: JsonDict + self, + txn: LoggingTransaction, + user_id: str, + device_id: str, + fallback_keys: JsonDict, ) -> None: # fallback_keys will usually only have one item in it, so using a for # loop (as opposed to calling simple_upsert_many_txn) won't be too bad @@ -483,7 +528,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore): async def get_e2e_cross_signing_key( self, user_id: str, key_type: str, from_user_id: Optional[str] = None - ) -> Optional[dict]: + ) -> Optional[JsonDict]: """Returns a user's cross-signing key. Args: @@ -504,7 +549,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore): return user_keys.get(key_type) @cached(num_args=1) - def _get_bare_e2e_cross_signing_keys(self, user_id): + def _get_bare_e2e_cross_signing_keys(self, user_id: str) -> Dict[str, JsonDict]: """Dummy function. Only used to make a cache for _get_bare_e2e_cross_signing_keys_bulk. """ @@ -517,7 +562,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore): ) async def _get_bare_e2e_cross_signing_keys_bulk( self, user_ids: Iterable[str] - ) -> Dict[str, Dict[str, dict]]: + ) -> Dict[str, Optional[Dict[str, JsonDict]]]: """Returns the cross-signing keys for a set of users. The output of this function should be passed to _get_e2e_cross_signing_signatures_txn if the signatures for the calling user need to be fetched. @@ -531,32 +576,35 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore): their user ID will map to None. """ - return await self.db_pool.runInteraction( + result = await self.db_pool.runInteraction( "get_bare_e2e_cross_signing_keys_bulk", self._get_bare_e2e_cross_signing_keys_bulk_txn, user_ids, ) + # The `Optional` comes from the `@cachedList` decorator. + return cast(Dict[str, Optional[Dict[str, JsonDict]]], result) + def _get_bare_e2e_cross_signing_keys_bulk_txn( self, - txn: Connection, + txn: LoggingTransaction, user_ids: Iterable[str], - ) -> Dict[str, Dict[str, dict]]: + ) -> Dict[str, Dict[str, JsonDict]]: """Returns the cross-signing keys for a set of users. The output of this function should be passed to _get_e2e_cross_signing_signatures_txn if the signatures for the calling user need to be fetched. Args: - txn (twisted.enterprise.adbapi.Connection): db connection - user_ids (list[str]): the users whose keys are being requested + txn: db connection + user_ids: the users whose keys are being requested Returns: - dict[str, dict[str, dict]]: mapping from user ID to key type to key - data. If a user's cross-signing keys were not found, their user - ID will not be in the dict. + Mapping from user ID to key type to key data. + If a user's cross-signing keys were not found, their user ID will not be in + the dict. """ - result = {} + result: Dict[str, Dict[str, JsonDict]] = {} for user_chunk in batch_iter(user_ids, 100): clause, params = make_in_list_sql_clause( @@ -596,43 +644,48 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore): user_id = row["user_id"] key_type = row["keytype"] key = db_to_json(row["keydata"]) - user_info = result.setdefault(user_id, {}) - user_info[key_type] = key + user_keys = result.setdefault(user_id, {}) + user_keys[key_type] = key return result def _get_e2e_cross_signing_signatures_txn( self, - txn: Connection, - keys: Dict[str, Dict[str, dict]], + txn: LoggingTransaction, + keys: Dict[str, Optional[Dict[str, JsonDict]]], from_user_id: str, - ) -> Dict[str, Dict[str, dict]]: + ) -> Dict[str, Optional[Dict[str, JsonDict]]]: """Returns the cross-signing signatures made by a user on a set of keys. Args: - txn (twisted.enterprise.adbapi.Connection): db connection - keys (dict[str, dict[str, dict]]): a map of user ID to key type to - key data. This dict will be modified to add signatures. - from_user_id (str): fetch the signatures made by this user + txn: db connection + keys: a map of user ID to key type to key data. + This dict will be modified to add signatures. + from_user_id: fetch the signatures made by this user Returns: - dict[str, dict[str, dict]]: mapping from user ID to key type to key - data. The return value will be the same as the keys argument, - with the modifications included. + Mapping from user ID to key type to key data. + The return value will be the same as the keys argument, with the + modifications included. """ # find out what cross-signing keys (a.k.a. devices) we need to get # signatures for. This is a map of (user_id, device_id) to key type # (device_id is the key's public part). - devices = {} + devices: Dict[Tuple[str, str], str] = {} - for user_id, user_info in keys.items(): - if user_info is None: + for user_id, user_keys in keys.items(): + if user_keys is None: continue - for key_type, key in user_info.items(): + for key_type, key in user_keys.items(): device_id = None for k in key["keys"].values(): device_id = k + # `key` ought to be a `CrossSigningKey`, whose .keys property is a + # dictionary with a single entry: + # "algorithm:base64_public_key": "base64_public_key" + # See https://spec.matrix.org/v1.1/client-server-api/#cross-signing + assert isinstance(device_id, str) devices[(user_id, device_id)] = key_type for batch in batch_iter(devices.keys(), size=100): @@ -656,15 +709,20 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore): # and add the signatures to the appropriate keys for row in rows: - key_id = row["key_id"] - target_user_id = row["target_user_id"] - target_device_id = row["target_device_id"] + key_id: str = row["key_id"] + target_user_id: str = row["target_user_id"] + target_device_id: str = row["target_device_id"] key_type = devices[(target_user_id, target_device_id)] # We need to copy everything, because the result may have come # from the cache. dict.copy only does a shallow copy, so we # need to recursively copy the dicts that will be modified. - user_info = keys[target_user_id] = keys[target_user_id].copy() - target_user_key = user_info[key_type] = user_info[key_type].copy() + user_keys = keys[target_user_id] + # `user_keys` cannot be `None` because we only fetched signatures for + # users with keys + assert user_keys is not None + user_keys = keys[target_user_id] = user_keys.copy() + + target_user_key = user_keys[key_type] = user_keys[key_type].copy() if "signatures" in target_user_key: signatures = target_user_key["signatures"] = target_user_key[ "signatures" @@ -683,7 +741,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore): async def get_e2e_cross_signing_keys_bulk( self, user_ids: List[str], from_user_id: Optional[str] = None - ) -> Dict[str, Optional[Dict[str, dict]]]: + ) -> Dict[str, Optional[Dict[str, JsonDict]]]: """Returns the cross-signing keys for a set of users. Args: @@ -741,7 +799,9 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore): if last_id == current_id: return [], current_id, False - def _get_all_user_signature_changes_for_remotes_txn(txn): + def _get_all_user_signature_changes_for_remotes_txn( + txn: LoggingTransaction, + ) -> Tuple[List[Tuple[int, tuple]], int, bool]: sql = """ SELECT stream_id, from_user_id AS user_id FROM user_signature_stream @@ -785,7 +845,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore): @trace def _claim_e2e_one_time_key_simple( - txn, user_id: str, device_id: str, algorithm: str + txn: LoggingTransaction, user_id: str, device_id: str, algorithm: str ) -> Optional[Tuple[str, str]]: """Claim OTK for device for DBs that don't support RETURNING. @@ -825,7 +885,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore): @trace def _claim_e2e_one_time_key_returning( - txn, user_id: str, device_id: str, algorithm: str + txn: LoggingTransaction, user_id: str, device_id: str, algorithm: str ) -> Optional[Tuple[str, str]]: """Claim OTK for device for DBs that support RETURNING. @@ -860,7 +920,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore): key_id, key_json = otk_row return f"{algorithm}:{key_id}", key_json - results = {} + results: Dict[str, Dict[str, Dict[str, str]]] = {} for user_id, device_id, algorithm in query_list: if self.database_engine.supports_returning: # If we support RETURNING clause we can use a single query that @@ -930,6 +990,18 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore): class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): + super().__init__(database, db_conn, hs) + + self._cross_signing_id_gen = StreamIdGenerator( + db_conn, "e2e_cross_signing_keys", "stream_id" + ) + async def set_e2e_device_keys( self, user_id: str, device_id: str, time_now: int, device_keys: JsonDict ) -> bool: @@ -937,7 +1009,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): or the keys were already in the database. """ - def _set_e2e_device_keys_txn(txn): + def _set_e2e_device_keys_txn(txn: LoggingTransaction) -> bool: set_tag("user_id", user_id) set_tag("device_id", device_id) set_tag("time_now", time_now) @@ -973,7 +1045,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): ) async def delete_e2e_keys_by_device(self, user_id: str, device_id: str) -> None: - def delete_e2e_keys_by_device_txn(txn): + def delete_e2e_keys_by_device_txn(txn: LoggingTransaction) -> None: log_kv( { "message": "Deleting keys for device", @@ -1012,17 +1084,24 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn ) - def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key, stream_id): + def _set_e2e_cross_signing_key_txn( + self, + txn: LoggingTransaction, + user_id: str, + key_type: str, + key: JsonDict, + stream_id: int, + ) -> None: """Set a user's cross-signing key. Args: - txn (twisted.enterprise.adbapi.Connection): db connection - user_id (str): the user to set the signing key for - key_type (str): the type of key that is being set: either 'master' + txn: db connection + user_id: the user to set the signing key for + key_type: the type of key that is being set: either 'master' for a master key, 'self_signing' for a self-signing key, or 'user_signing' for a user-signing key - key (dict): the key data - stream_id (int) + key: the key data + stream_id """ # the 'key' dict will look something like: # { @@ -1075,13 +1154,15 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): txn, self._get_bare_e2e_cross_signing_keys, (user_id,) ) - async def set_e2e_cross_signing_key(self, user_id, key_type, key): + async def set_e2e_cross_signing_key( + self, user_id: str, key_type: str, key: JsonDict + ) -> None: """Set a user's cross-signing key. Args: - user_id (str): the user to set the user-signing key for - key_type (str): the type of cross-signing key to set - key (dict): the key data + user_id: the user to set the user-signing key for + key_type: the type of cross-signing key to set + key: the key data """ async with self._cross_signing_id_gen.get_next() as stream_id: diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 9580a40785..270b30800b 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py
@@ -24,7 +24,11 @@ from synapse.api.room_versions import EventFormatVersions, RoomVersion from synapse.events import EventBase, make_event_from_dict from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause -from synapse.storage.database import DatabasePool, LoggingTransaction +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.signatures import SignatureWorkerStore from synapse.storage.engines import PostgresEngine @@ -62,7 +66,12 @@ class _NoChainCoverIndex(Exception): class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) if hs.config.worker.run_background_tasks: @@ -279,7 +288,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas new_front = set() for chunk in batch_iter(front, 100): # Pull the auth events either from the cache or DB. - to_fetch = [] # Event IDs to fetch from DB # type: List[str] + to_fetch: List[str] = [] # Event IDs to fetch from DB for event_id in chunk: res = self._event_auth_cache.get(event_id) if res is None: @@ -606,8 +615,8 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas # currently walking, either from cache or DB. search, chunk = search[:-100], search[-100:] - found = [] # Results found # type: List[Tuple[str, str, int]] - to_fetch = [] # Event IDs to fetch from DB # type: List[str] + found: List[Tuple[str, str, int]] = [] # Results found + to_fetch: List[str] = [] # Event IDs to fetch from DB for _, event_id in chunk: res = self._event_auth_cache.get(event_id) if res is None: @@ -1384,7 +1393,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas count = await self.db_pool.simple_select_one_onecol( table="federation_inbound_events_staging", keyvalues={"room_id": room_id}, - retcol="COALESCE(COUNT(*), 0)", + retcol="COUNT(*)", desc="prune_staged_events_in_room_count", ) @@ -1476,9 +1485,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas """Update the prometheus metrics for the inbound federation staging area.""" def _get_stats_for_federation_staging_txn(txn): - txn.execute( - "SELECT coalesce(count(*), 0) FROM federation_inbound_events_staging" - ) + txn.execute("SELECT count(*) FROM federation_inbound_events_staging") (count,) = txn.fetchone() txn.execute( @@ -1514,7 +1521,12 @@ class EventFederationStore(EventFederationWorkerStore): EVENT_AUTH_STATE_ONLY = "event_auth_state_only" - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_update_handler( diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 3efdd0c920..a98e6b2593 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py
@@ -13,14 +13,17 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, cast import attr -from typing_extensions import TypedDict from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.storage.database import DatabasePool, LoggingTransaction +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) from synapse.util import json_encoder from synapse.util.caches.descriptors import cached @@ -30,29 +33,64 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -DEFAULT_NOTIF_ACTION = ["notify", {"set_tweak": "highlight", "value": False}] -DEFAULT_HIGHLIGHT_ACTION = [ +DEFAULT_NOTIF_ACTION: List[Union[dict, str]] = [ + "notify", + {"set_tweak": "highlight", "value": False}, +] +DEFAULT_HIGHLIGHT_ACTION: List[Union[dict, str]] = [ "notify", {"set_tweak": "sound", "value": "default"}, {"set_tweak": "highlight"}, ] -class BasePushAction(TypedDict): - event_id: str - actions: List[Union[dict, str]] - +@attr.s(slots=True, frozen=True, auto_attribs=True) +class HttpPushAction: + """ + HttpPushAction instances include the information used to generate HTTP + requests to a push gateway. + """ -class HttpPushAction(BasePushAction): + event_id: str room_id: str stream_ordering: int + actions: List[Union[dict, str]] +@attr.s(slots=True, frozen=True, auto_attribs=True) class EmailPushAction(HttpPushAction): + """ + EmailPushAction instances include the information used to render an email + push notification. + """ + received_ts: Optional[int] -def _serialize_action(actions, is_highlight): +@attr.s(slots=True, frozen=True, auto_attribs=True) +class UserPushAction(EmailPushAction): + """ + UserPushAction instances include the necessary information to respond to + /notifications requests. + """ + + topological_ordering: int + highlight: bool + profile_tag: str + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class NotifCounts: + """ + The per-user, per-room count of notifications. Used by sync and push. + """ + + notify_count: int + unread_count: int + highlight_count: int + + +def _serialize_action(actions: List[Union[dict, str]], is_highlight: bool) -> str: """Custom serializer for actions. This allows us to "compress" common actions. We use the fact that most users have the same actions for notifs (and for @@ -70,7 +108,7 @@ def _serialize_action(actions, is_highlight): return json_encoder.encode(actions) -def _deserialize_action(actions, is_highlight): +def _deserialize_action(actions: str, is_highlight: bool) -> List[Union[dict, str]]: """Custom deserializer for actions. This allows us to "compress" common actions""" if actions: return db_to_json(actions) @@ -82,12 +120,17 @@ def _deserialize_action(actions, is_highlight): class EventPushActionsWorkerStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) # These get correctly set by _find_stream_orderings_for_times_txn - self.stream_ordering_month_ago = None - self.stream_ordering_day_ago = None + self.stream_ordering_month_ago: Optional[int] = None + self.stream_ordering_day_ago: Optional[int] = None cur = db_conn.cursor(txn_name="_find_stream_orderings_for_times_txn") self._find_stream_orderings_for_times_txn(cur) @@ -111,7 +154,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): room_id: str, user_id: str, last_read_event_id: Optional[str], - ) -> Dict[str, int]: + ) -> NotifCounts: """Get the notification count, the highlight count and the unread message count for a given user in a given room after the given read receipt. @@ -140,15 +183,15 @@ class EventPushActionsWorkerStore(SQLBaseStore): def _get_unread_counts_by_receipt_txn( self, - txn, - room_id, - user_id, - last_read_event_id, - ): + txn: LoggingTransaction, + room_id: str, + user_id: str, + last_read_event_id: Optional[str], + ) -> NotifCounts: stream_ordering = None if last_read_event_id is not None: - stream_ordering = self.get_stream_id_for_event_txn( + stream_ordering = self.get_stream_id_for_event_txn( # type: ignore[attr-defined] txn, last_read_event_id, allow_none=True, @@ -166,13 +209,15 @@ class EventPushActionsWorkerStore(SQLBaseStore): retcol="event_id", ) - stream_ordering = self.get_stream_id_for_event_txn(txn, event_id) + stream_ordering = self.get_stream_id_for_event_txn(txn, event_id) # type: ignore[attr-defined] return self._get_unread_counts_by_pos_txn( txn, room_id, user_id, stream_ordering ) - def _get_unread_counts_by_pos_txn(self, txn, room_id, user_id, stream_ordering): + def _get_unread_counts_by_pos_txn( + self, txn: LoggingTransaction, room_id: str, user_id: str, stream_ordering: int + ) -> NotifCounts: sql = ( "SELECT" " COUNT(CASE WHEN notif = 1 THEN 1 END)," @@ -210,16 +255,16 @@ class EventPushActionsWorkerStore(SQLBaseStore): # for this row. unread_count += row[1] - return { - "notify_count": notif_count, - "unread_count": unread_count, - "highlight_count": highlight_count, - } + return NotifCounts( + notify_count=notif_count, + unread_count=unread_count, + highlight_count=highlight_count, + ) async def get_push_action_users_in_range( - self, min_stream_ordering, max_stream_ordering - ): - def f(txn): + self, min_stream_ordering: int, max_stream_ordering: int + ) -> List[str]: + def f(txn: LoggingTransaction) -> List[str]: sql = ( "SELECT DISTINCT(user_id) FROM event_push_actions WHERE" " stream_ordering >= ? AND stream_ordering <= ? AND notif = 1" @@ -227,8 +272,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): txn.execute(sql, (min_stream_ordering, max_stream_ordering)) return [r[0] for r in txn] - ret = await self.db_pool.runInteraction("get_push_action_users_in_range", f) - return ret + return await self.db_pool.runInteraction("get_push_action_users_in_range", f) async def get_unread_push_actions_for_user_in_range_for_http( self, @@ -254,7 +298,9 @@ class EventPushActionsWorkerStore(SQLBaseStore): """ # find rooms that have a read receipt in them and return the next # push actions - def get_after_receipt(txn): + def get_after_receipt( + txn: LoggingTransaction, + ) -> List[Tuple[str, str, int, str, bool]]: # find rooms that have a read receipt in them and return the next # push actions sql = ( @@ -280,7 +326,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): ) args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] txn.execute(sql, args) - return txn.fetchall() + return cast(List[Tuple[str, str, int, str, bool]], txn.fetchall()) after_read_receipt = await self.db_pool.runInteraction( "get_unread_push_actions_for_user_in_range_http_arr", get_after_receipt @@ -289,7 +335,9 @@ class EventPushActionsWorkerStore(SQLBaseStore): # There are rooms with push actions in them but you don't have a read receipt in # them e.g. rooms you've been invited to, so get push actions for rooms which do # not have read receipts in them too. - def get_no_receipt(txn): + def get_no_receipt( + txn: LoggingTransaction, + ) -> List[Tuple[str, str, int, str, bool]]: sql = ( "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions," " ep.highlight " @@ -309,19 +357,19 @@ class EventPushActionsWorkerStore(SQLBaseStore): ) args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] txn.execute(sql, args) - return txn.fetchall() + return cast(List[Tuple[str, str, int, str, bool]], txn.fetchall()) no_read_receipt = await self.db_pool.runInteraction( "get_unread_push_actions_for_user_in_range_http_nrr", get_no_receipt ) notifs = [ - { - "event_id": row[0], - "room_id": row[1], - "stream_ordering": row[2], - "actions": _deserialize_action(row[3], row[4]), - } + HttpPushAction( + event_id=row[0], + room_id=row[1], + stream_ordering=row[2], + actions=_deserialize_action(row[3], row[4]), + ) for row in after_read_receipt + no_read_receipt ] @@ -329,7 +377,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): # contain results from the first query, correctly ordered, followed # by results from the second query, but we want them all ordered # by stream_ordering, oldest first. - notifs.sort(key=lambda r: r["stream_ordering"]) + notifs.sort(key=lambda r: r.stream_ordering) # Take only up to the limit. We have to stop at the limit because # one of the subqueries may have hit the limit. @@ -359,7 +407,9 @@ class EventPushActionsWorkerStore(SQLBaseStore): """ # find rooms that have a read receipt in them and return the most recent # push actions - def get_after_receipt(txn): + def get_after_receipt( + txn: LoggingTransaction, + ) -> List[Tuple[str, str, int, str, bool, int]]: sql = ( "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions," " ep.highlight, e.received_ts" @@ -384,7 +434,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): ) args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] txn.execute(sql, args) - return txn.fetchall() + return cast(List[Tuple[str, str, int, str, bool, int]], txn.fetchall()) after_read_receipt = await self.db_pool.runInteraction( "get_unread_push_actions_for_user_in_range_email_arr", get_after_receipt @@ -393,7 +443,9 @@ class EventPushActionsWorkerStore(SQLBaseStore): # There are rooms with push actions in them but you don't have a read receipt in # them e.g. rooms you've been invited to, so get push actions for rooms which do # not have read receipts in them too. - def get_no_receipt(txn): + def get_no_receipt( + txn: LoggingTransaction, + ) -> List[Tuple[str, str, int, str, bool, int]]: sql = ( "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions," " ep.highlight, e.received_ts" @@ -413,7 +465,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): ) args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] txn.execute(sql, args) - return txn.fetchall() + return cast(List[Tuple[str, str, int, str, bool, int]], txn.fetchall()) no_read_receipt = await self.db_pool.runInteraction( "get_unread_push_actions_for_user_in_range_email_nrr", get_no_receipt @@ -421,13 +473,13 @@ class EventPushActionsWorkerStore(SQLBaseStore): # Make a list of dicts from the two sets of results. notifs = [ - { - "event_id": row[0], - "room_id": row[1], - "stream_ordering": row[2], - "actions": _deserialize_action(row[3], row[4]), - "received_ts": row[5], - } + EmailPushAction( + event_id=row[0], + room_id=row[1], + stream_ordering=row[2], + actions=_deserialize_action(row[3], row[4]), + received_ts=row[5], + ) for row in after_read_receipt + no_read_receipt ] @@ -435,7 +487,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): # contain results from the first query, correctly ordered, followed # by results from the second query, but we want them all ordered # by received_ts (most recent first) - notifs.sort(key=lambda r: -(r["received_ts"] or 0)) + notifs.sort(key=lambda r: -(r.received_ts or 0)) # Now return the first `limit` return notifs[:limit] @@ -456,7 +508,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): True if there may be push to process, False if there definitely isn't. """ - def _get_if_maybe_push_in_range_for_user_txn(txn): + def _get_if_maybe_push_in_range_for_user_txn(txn: LoggingTransaction) -> bool: sql = """ SELECT 1 FROM event_push_actions WHERE user_id = ? AND stream_ordering > ? AND notif = 1 @@ -490,19 +542,21 @@ class EventPushActionsWorkerStore(SQLBaseStore): # This is a helper function for generating the necessary tuple that # can be used to insert into the `event_push_actions_staging` table. - def _gen_entry(user_id, actions): + def _gen_entry( + user_id: str, actions: List[Union[dict, str]] + ) -> Tuple[str, str, str, int, int, int]: is_highlight = 1 if _action_has_highlight(actions) else 0 notif = 1 if "notify" in actions else 0 return ( event_id, # event_id column user_id, # user_id column - _serialize_action(actions, is_highlight), # actions column + _serialize_action(actions, bool(is_highlight)), # actions column notif, # notif column is_highlight, # highlight column int(count_as_unread), # unread column ) - def _add_push_actions_to_staging_txn(txn): + def _add_push_actions_to_staging_txn(txn: LoggingTransaction) -> None: # We don't use simple_insert_many here to avoid the overhead # of generating lists of dicts. @@ -530,12 +584,11 @@ class EventPushActionsWorkerStore(SQLBaseStore): """ try: - res = await self.db_pool.simple_delete( + await self.db_pool.simple_delete( table="event_push_actions_staging", keyvalues={"event_id": event_id}, desc="remove_push_actions_from_staging", ) - return res except Exception: # this method is called from an exception handler, so propagating # another exception here really isn't helpful - there's nothing @@ -588,7 +641,9 @@ class EventPushActionsWorkerStore(SQLBaseStore): ) @staticmethod - def _find_first_stream_ordering_after_ts_txn(txn, ts): + def _find_first_stream_ordering_after_ts_txn( + txn: LoggingTransaction, ts: int + ) -> int: """ Find the stream_ordering of the first event that was received on or after a given timestamp. This is relatively slow as there is no index @@ -600,14 +655,14 @@ class EventPushActionsWorkerStore(SQLBaseStore): stream_ordering Args: - txn (twisted.enterprise.adbapi.Transaction): - ts (int): timestamp to search for + txn: + ts: timestamp to search for Returns: - int: stream ordering + The stream ordering """ txn.execute("SELECT MAX(stream_ordering) FROM events") - max_stream_ordering = txn.fetchone()[0] + max_stream_ordering = cast(Tuple[Optional[int]], txn.fetchone())[0] if max_stream_ordering is None: return 0 @@ -663,8 +718,10 @@ class EventPushActionsWorkerStore(SQLBaseStore): return range_end - async def get_time_of_last_push_action_before(self, stream_ordering): - def f(txn): + async def get_time_of_last_push_action_before( + self, stream_ordering: int + ) -> Optional[int]: + def f(txn: LoggingTransaction) -> Optional[Tuple[int]]: sql = ( "SELECT e.received_ts" " FROM event_push_actions AS ep" @@ -674,7 +731,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): " LIMIT 1" ) txn.execute(sql, (stream_ordering,)) - return txn.fetchone() + return cast(Optional[Tuple[int]], txn.fetchone()) result = await self.db_pool.runInteraction( "get_time_of_last_push_action_before", f @@ -682,7 +739,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): return result[0] if result else None @wrap_as_background_process("rotate_notifs") - async def _rotate_notifs(self): + async def _rotate_notifs(self) -> None: if self._doing_notif_rotation or self.stream_ordering_day_ago is None: return self._doing_notif_rotation = True @@ -700,7 +757,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): finally: self._doing_notif_rotation = False - def _rotate_notifs_txn(self, txn): + def _rotate_notifs_txn(self, txn: LoggingTransaction) -> bool: """Archives older notifications into event_push_summary. Returns whether the archiving process has caught up or not. """ @@ -725,6 +782,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): stream_row = txn.fetchone() if stream_row: (offset_stream_ordering,) = stream_row + assert self.stream_ordering_day_ago is not None rotate_to_stream_ordering = min( self.stream_ordering_day_ago, offset_stream_ordering ) @@ -740,7 +798,9 @@ class EventPushActionsWorkerStore(SQLBaseStore): # We have caught up iff we were limited by `stream_ordering_day_ago` return caught_up - def _rotate_notifs_before_txn(self, txn, rotate_to_stream_ordering): + def _rotate_notifs_before_txn( + self, txn: LoggingTransaction, rotate_to_stream_ordering: int + ) -> None: old_rotate_stream_ordering = self.db_pool.simple_select_one_onecol_txn( txn, table="event_push_summary_stream_ordering", @@ -861,8 +921,8 @@ class EventPushActionsWorkerStore(SQLBaseStore): ) def _remove_old_push_actions_before_txn( - self, txn, room_id, user_id, stream_ordering - ): + self, txn: LoggingTransaction, room_id: str, user_id: str, stream_ordering: int + ) -> None: """ Purges old push actions for a user and room before a given stream_ordering. @@ -910,7 +970,12 @@ class EventPushActionsWorkerStore(SQLBaseStore): class EventPushActionsStore(EventPushActionsWorkerStore): EPA_HIGHLIGHT_INDEX = "epa_highlight_index" - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_index_update( @@ -929,9 +994,15 @@ class EventPushActionsStore(EventPushActionsWorkerStore): ) async def get_push_actions_for_user( - self, user_id, before=None, limit=50, only_highlight=False - ): - def f(txn): + self, + user_id: str, + before: Optional[str] = None, + limit: int = 50, + only_highlight: bool = False, + ) -> List[UserPushAction]: + def f( + txn: LoggingTransaction, + ) -> List[Tuple[str, str, int, int, str, bool, str, int]]: before_clause = "" if before: before_clause = "AND epa.stream_ordering < ?" @@ -958,32 +1029,44 @@ class EventPushActionsStore(EventPushActionsWorkerStore): " LIMIT ?" % (before_clause,) ) txn.execute(sql, args) - return self.db_pool.cursor_to_dict(txn) + return cast( + List[Tuple[str, str, int, int, str, bool, str, int]], txn.fetchall() + ) push_actions = await self.db_pool.runInteraction("get_push_actions_for_user", f) - for pa in push_actions: - pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"]) - return push_actions + return [ + UserPushAction( + event_id=row[0], + room_id=row[1], + stream_ordering=row[2], + actions=_deserialize_action(row[4], row[5]), + received_ts=row[7], + topological_ordering=row[3], + highlight=row[5], + profile_tag=row[6], + ) + for row in push_actions + ] -def _action_has_highlight(actions): +def _action_has_highlight(actions: List[Union[dict, str]]) -> bool: for action in actions: - try: - if action.get("set_tweak", None) == "highlight": - return action.get("value", True) - except AttributeError: - pass + if not isinstance(action, dict): + continue + + if action.get("set_tweak", None) == "highlight": + return action.get("value", True) return False -@attr.s(slots=True) +@attr.s(slots=True, auto_attribs=True) class _EventPushSummary: """Summary of pending event push actions for a given user in a given room. Used in _rotate_notifs_before_txn to manipulate results from event_push_actions. """ - unread_count = attr.ib(type=int) - stream_ordering = attr.ib(type=int) - old_user_id = attr.ib(type=str) - notif_count = attr.ib(type=int) + unread_count: int + stream_ordering: int + old_user_id: str + notif_count: int diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 4e528612ea..dd255aefb9 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py
@@ -19,6 +19,7 @@ from collections import OrderedDict from typing import ( TYPE_CHECKING, Any, + Collection, Dict, Generator, Iterable, @@ -40,10 +41,13 @@ from synapse.events import EventBase # noqa: F401 from synapse.events.snapshot import EventContext # noqa: F401 from synapse.logging.utils import log_function from synapse.storage._base import db_to_json, make_in_list_sql_clause -from synapse.storage.database import DatabasePool, LoggingTransaction +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) from synapse.storage.databases.main.events_worker import EventCacheEntry from synapse.storage.databases.main.search import SearchEntry -from synapse.storage.types import Connection from synapse.storage.util.id_generators import AbstractStreamIdGenerator from synapse.storage.util.sequence import SequenceGenerator from synapse.types import StateMap, get_domain_from_id @@ -94,7 +98,7 @@ class PersistEventsStore: hs: "HomeServer", db: DatabasePool, main_data_store: "DataStore", - db_conn: Connection, + db_conn: LoggingDatabaseConnection, ): self.hs = hs self.db_pool = db @@ -1319,14 +1323,13 @@ class PersistEventsStore: return [ec for ec in events_and_contexts if ec[0] not in to_remove] - def _store_event_txn(self, txn, events_and_contexts): + def _store_event_txn( + self, + txn: LoggingTransaction, + events_and_contexts: Collection[Tuple[EventBase, EventContext]], + ) -> None: """Insert new events into the event, event_json, redaction and state_events tables. - - Args: - txn (twisted.enterprise.adbapi.Connection): db connection - events_and_contexts (list[(EventBase, EventContext)]): events - we are persisting """ if not events_and_contexts: @@ -1339,46 +1342,58 @@ class PersistEventsStore: d.pop("redacted_because", None) return d - self.db_pool.simple_insert_many_txn( + self.db_pool.simple_insert_many_values_txn( txn, table="event_json", - values=[ - { - "event_id": event.event_id, - "room_id": event.room_id, - "internal_metadata": json_encoder.encode( - event.internal_metadata.get_dict() - ), - "json": json_encoder.encode(event_dict(event)), - "format_version": event.format_version, - } + keys=("event_id", "room_id", "internal_metadata", "json", "format_version"), + values=( + ( + event.event_id, + event.room_id, + json_encoder.encode(event.internal_metadata.get_dict()), + json_encoder.encode(event_dict(event)), + event.format_version, + ) for event, _ in events_and_contexts - ], + ), ) - self.db_pool.simple_insert_many_txn( + self.db_pool.simple_insert_many_values_txn( txn, table="events", - values=[ - { - "instance_name": self._instance_name, - "stream_ordering": event.internal_metadata.stream_ordering, - "topological_ordering": event.depth, - "depth": event.depth, - "event_id": event.event_id, - "room_id": event.room_id, - "type": event.type, - "processed": True, - "outlier": event.internal_metadata.is_outlier(), - "origin_server_ts": int(event.origin_server_ts), - "received_ts": self._clock.time_msec(), - "sender": event.sender, - "contains_url": ( - "url" in event.content and isinstance(event.content["url"], str) - ), - } + keys=( + "instance_name", + "stream_ordering", + "topological_ordering", + "depth", + "event_id", + "room_id", + "type", + "processed", + "outlier", + "origin_server_ts", + "received_ts", + "sender", + "contains_url", + ), + values=( + ( + self._instance_name, + event.internal_metadata.stream_ordering, + event.depth, # topological_ordering + event.depth, # depth + event.event_id, + event.room_id, + event.type, + True, # processed + event.internal_metadata.is_outlier(), + int(event.origin_server_ts), + self._clock.time_msec(), + event.sender, + "url" in event.content and isinstance(event.content["url"], str), + ) for event, _ in events_and_contexts - ], + ), ) # If we're persisting an unredacted event we go and ensure @@ -1397,27 +1412,15 @@ class PersistEventsStore: ) txn.execute(sql + clause, [False] + args) - state_events_and_contexts = [ - ec for ec in events_and_contexts if ec[0].is_state() - ] - - state_values = [] - for event, _ in state_events_and_contexts: - vals = { - "event_id": event.event_id, - "room_id": event.room_id, - "type": event.type, - "state_key": event.state_key, - } - - # TODO: How does this work with backfilling? - if hasattr(event, "replaces_state"): - vals["prev_state"] = event.replaces_state - - state_values.append(vals) - - self.db_pool.simple_insert_many_txn( - txn, table="state_events", values=state_values + self.db_pool.simple_insert_many_values_txn( + txn, + table="state_events", + keys=("event_id", "room_id", "type", "state_key"), + values=( + (event.event_id, event.room_id, event.type, event.state_key) + for event, _ in events_and_contexts + if event.is_state() + ), ) def _store_rejected_events_txn(self, txn, events_and_contexts): @@ -1780,10 +1783,14 @@ class PersistEventsStore: ) if rel_type == RelationTypes.REPLACE: - txn.call_after(self.store.get_applicable_edit.invalidate, (parent_id,)) + txn.call_after( + self.store.get_applicable_edit.invalidate, (parent_id, event.room_id) + ) if rel_type == RelationTypes.THREAD: - txn.call_after(self.store.get_thread_summary.invalidate, (parent_id,)) + txn.call_after( + self.store.get_thread_summary.invalidate, (parent_id, event.room_id) + ) def _handle_insertion_event(self, txn: LoggingTransaction, event: EventBase): """Handles keeping track of insertion events and edges/connections. @@ -1969,14 +1976,17 @@ class PersistEventsStore: txn, self.store.get_retention_policy_for_room, (event.room_id,) ) - def store_event_search_txn(self, txn, event, key, value): + def store_event_search_txn( + self, txn: LoggingTransaction, event: EventBase, key: str, value: str + ) -> None: """Add event to the search table Args: - txn (cursor): - event (EventBase): - key (str): - value (str): + txn: The database transaction. + event: The event being added to the search table. + key: A key describing the search value (one of "content.name", + "content.topic", or "content.body") + value: The value from the event's content. """ self.store.store_search_entries_txn( txn, diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index c88fd35e7f..a68f14ba48 100644 --- a/synapse/storage/databases/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -13,7 +13,7 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, Tuple, cast import attr @@ -23,6 +23,7 @@ from synapse.events import make_event_from_dict from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import ( DatabasePool, + LoggingDatabaseConnection, LoggingTransaction, make_tuple_comparison_clause, ) @@ -83,7 +84,12 @@ class _CalculateChainCover: class EventsBackgroundUpdatesStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_update_handler( @@ -234,12 +240,14 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): ################################################################################ - async def _background_reindex_fields_sender(self, progress, batch_size): + async def _background_reindex_fields_sender( + self, progress: JsonDict, batch_size: int + ) -> int: target_min_stream_id = progress["target_min_stream_id_inclusive"] max_stream_id = progress["max_stream_id_exclusive"] rows_inserted = progress.get("rows_inserted", 0) - def reindex_txn(txn): + def reindex_txn(txn: LoggingTransaction) -> int: sql = ( "SELECT stream_ordering, event_id, json FROM events" " INNER JOIN event_json USING (event_id)" @@ -301,12 +309,14 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): return result - async def _background_reindex_origin_server_ts(self, progress, batch_size): + async def _background_reindex_origin_server_ts( + self, progress: JsonDict, batch_size: int + ) -> int: target_min_stream_id = progress["target_min_stream_id_inclusive"] max_stream_id = progress["max_stream_id_exclusive"] rows_inserted = progress.get("rows_inserted", 0) - def reindex_search_txn(txn): + def reindex_search_txn(txn: LoggingTransaction) -> int: sql = ( "SELECT stream_ordering, event_id FROM events" " WHERE ? <= stream_ordering AND stream_ordering < ?" @@ -375,7 +385,9 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): return result - async def _cleanup_extremities_bg_update(self, progress, batch_size): + async def _cleanup_extremities_bg_update( + self, progress: JsonDict, batch_size: int + ) -> int: """Background update to clean out extremities that should have been deleted previously. @@ -396,12 +408,12 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): # have any descendants, but if they do then we should delete those # extremities. - def _cleanup_extremities_bg_update_txn(txn): + def _cleanup_extremities_bg_update_txn(txn: LoggingTransaction) -> int: # The set of extremity event IDs that we're checking this round original_set = set() - # A dict[str, set[str]] of event ID to their prev events. - graph = {} + # A dict[str, Set[str]] of event ID to their prev events. + graph: Dict[str, Set[str]] = {} # The set of descendants of the original set that are not rejected # nor soft-failed. Ancestors of these events should be removed @@ -530,7 +542,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): room_ids = {row["room_id"] for row in rows} for room_id in room_ids: txn.call_after( - self.get_latest_event_ids_in_room.invalidate, (room_id,) + self.get_latest_event_ids_in_room.invalidate, (room_id,) # type: ignore[attr-defined] ) self.db_pool.simple_delete_many_txn( @@ -552,7 +564,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): _BackgroundUpdates.DELETE_SOFT_FAILED_EXTREMITIES ) - def _drop_table_txn(txn): + def _drop_table_txn(txn: LoggingTransaction) -> None: txn.execute("DROP TABLE _extremities_to_check") await self.db_pool.runInteraction( @@ -561,11 +573,11 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): return num_handled - async def _redactions_received_ts(self, progress, batch_size): + async def _redactions_received_ts(self, progress: JsonDict, batch_size: int) -> int: """Handles filling out the `received_ts` column in redactions.""" last_event_id = progress.get("last_event_id", "") - def _redactions_received_ts_txn(txn): + def _redactions_received_ts_txn(txn: LoggingTransaction) -> int: # Fetch the set of event IDs that we want to update sql = """ SELECT event_id FROM redactions @@ -616,10 +628,12 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): return count - async def _event_fix_redactions_bytes(self, progress, batch_size): + async def _event_fix_redactions_bytes( + self, progress: JsonDict, batch_size: int + ) -> int: """Undoes hex encoded censored redacted event JSON.""" - def _event_fix_redactions_bytes_txn(txn): + def _event_fix_redactions_bytes_txn(txn: LoggingTransaction) -> None: # This update is quite fast due to new index. txn.execute( """ @@ -644,11 +658,11 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): return 1 - async def _event_store_labels(self, progress, batch_size): + async def _event_store_labels(self, progress: JsonDict, batch_size: int) -> int: """Background update handler which will store labels for existing events.""" last_event_id = progress.get("last_event_id", "") - def _event_store_labels_txn(txn): + def _event_store_labels_txn(txn: LoggingTransaction) -> int: txn.execute( """ SELECT event_id, json FROM event_json @@ -748,7 +762,10 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): ), ) - return [(row[0], row[1], db_to_json(row[2]), row[3], row[4]) for row in txn] # type: ignore + return cast( + List[Tuple[str, str, JsonDict, bool, bool]], + [(row[0], row[1], db_to_json(row[2]), row[3], row[4]) for row in txn], + ) results = await self.db_pool.runInteraction( desc="_rejected_events_metadata_get", func=get_rejected_events @@ -906,7 +923,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): def _calculate_chain_cover_txn( self, - txn: Cursor, + txn: LoggingTransaction, last_room_id: str, last_depth: int, last_stream: int, @@ -1017,10 +1034,10 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): PersistEventsStore._add_chain_cover_index( txn, self.db_pool, - self.event_chain_id_gen, + self.event_chain_id_gen, # type: ignore[attr-defined] event_to_room_id, event_to_types, - event_to_auth_chain, + cast(Dict[str, Sequence[str]], event_to_auth_chain), ) return _CalculateChainCover( @@ -1040,7 +1057,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): """ current_event_id = progress.get("current_event_id", "") - def purged_chain_cover_txn(txn) -> int: + def purged_chain_cover_txn(txn: LoggingTransaction) -> int: # The event ID from events will be null if the chain ID / sequence # number points to a purged event. sql = """ @@ -1175,14 +1192,14 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): # Iterate the parent IDs and invalidate caches. for parent_id in {r[1] for r in relations_to_insert}: cache_tuple = (parent_id,) - self._invalidate_cache_and_stream( - txn, self.get_relations_for_event, cache_tuple + self._invalidate_cache_and_stream( # type: ignore[attr-defined] + txn, self.get_relations_for_event, cache_tuple # type: ignore[attr-defined] ) - self._invalidate_cache_and_stream( - txn, self.get_aggregation_groups_for_event, cache_tuple + self._invalidate_cache_and_stream( # type: ignore[attr-defined] + txn, self.get_aggregation_groups_for_event, cache_tuple # type: ignore[attr-defined] ) - self._invalidate_cache_and_stream( - txn, self.get_thread_summary, cache_tuple + self._invalidate_cache_and_stream( # type: ignore[attr-defined] + txn, self.get_thread_summary, cache_tuple # type: ignore[attr-defined] ) if results: @@ -1214,7 +1231,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): """ batch_size = max(batch_size, 1) - def process(txn: Cursor) -> int: + def process(txn: LoggingTransaction) -> int: last_stream = progress.get("last_stream", -(1 << 31)) txn.execute( """ diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index c7b660ac5a..8d4287045a 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py
@@ -1383,10 +1383,6 @@ class EventsWorkerStore(SQLBaseStore): return {"v1": complexity_v1} - def get_current_events_token(self) -> int: - """The current maximum token that events have reached""" - return self._stream_id_gen.get_current_token() - async def get_all_new_forward_event_rows( self, instance_name: str, last_id: int, current_id: int, limit: int ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]: diff --git a/synapse/storage/databases/main/filtering.py b/synapse/storage/databases/main/filtering.py
index cf842803bc..cb9ee08fa8 100644 --- a/synapse/storage/databases/main/filtering.py +++ b/synapse/storage/databases/main/filtering.py
@@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Union +from typing import Optional, Tuple, Union, cast from canonicaljson import encode_canonical_json @@ -63,7 +63,7 @@ class FilteringStore(SQLBaseStore): sql = "SELECT MAX(filter_id) FROM user_filters WHERE user_id = ?" txn.execute(sql, (user_localpart,)) - max_id = txn.fetchone()[0] # type: ignore[index] + max_id = cast(Tuple[Optional[int]], txn.fetchone())[0] if max_id is None: filter_id = 0 else: diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py
index bb621df0dd..3f6086050b 100644 --- a/synapse/storage/databases/main/group_server.py +++ b/synapse/storage/databases/main/group_server.py
@@ -19,8 +19,7 @@ from typing_extensions import TypedDict from synapse.api.errors import SynapseError from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.storage.database import DatabasePool -from synapse.storage.types import Connection +from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.types import JsonDict from synapse.util import json_encoder @@ -40,7 +39,12 @@ class _RoomInGroup(TypedDict): class GroupServerWorkerStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): database.updates.register_background_index_update( update_name="local_group_updates_index", index_name="local_group_updates_stream_id_index", diff --git a/synapse/storage/databases/main/lock.py b/synapse/storage/databases/main/lock.py
index a540f7fb26..bedacaf0d7 100644 --- a/synapse/storage/databases/main/lock.py +++ b/synapse/storage/databases/main/lock.py
@@ -20,8 +20,11 @@ from twisted.internet.interfaces import IReactorCore from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore -from synapse.storage.database import DatabasePool, LoggingTransaction -from synapse.storage.types import Connection +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) from synapse.util import Clock from synapse.util.stringutils import random_string @@ -54,7 +57,12 @@ class LockStore(SQLBaseStore): `last_renewed_ts` column with the current time. """ - def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) self._reactor = hs.get_reactor() diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index 1b076683f7..cbba356b4a 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py
@@ -23,6 +23,7 @@ from typing import ( Optional, Tuple, Union, + cast, ) from synapse.storage._base import SQLBaseStore @@ -220,7 +221,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): WHERE user_id = ? """ txn.execute(sql, args) - count = txn.fetchone()[0] # type: ignore[index] + count = cast(Tuple[int], txn.fetchone())[0] sql = """ SELECT diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py
index d901933ae4..1480a0f048 100644 --- a/synapse/storage/databases/main/metrics.py +++ b/synapse/storage/databases/main/metrics.py
@@ -19,7 +19,7 @@ from typing import TYPE_CHECKING, Dict from synapse.metrics import GaugeBucketCollector from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore -from synapse.storage.database import DatabasePool +from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.storage.databases.main.event_push_actions import ( EventPushActionsWorkerStore, ) @@ -55,7 +55,12 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): stats and prometheus metrics. """ - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) # Read the extrems every 60 minutes @@ -100,7 +105,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): def _count_messages(txn): sql = """ - SELECT COALESCE(COUNT(*), 0) FROM events + SELECT COUNT(*) FROM events WHERE type = 'm.room.encrypted' AND stream_ordering > ? """ @@ -117,7 +122,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): like_clause = "%:" + self.hs.hostname sql = """ - SELECT COALESCE(COUNT(*), 0) FROM events + SELECT COUNT(*) FROM events WHERE type = 'm.room.encrypted' AND sender LIKE ? AND stream_ordering > ? @@ -134,7 +139,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): async def count_daily_active_e2ee_rooms(self): def _count(txn): sql = """ - SELECT COALESCE(COUNT(DISTINCT room_id), 0) FROM events + SELECT COUNT(DISTINCT room_id) FROM events WHERE type = 'm.room.encrypted' AND stream_ordering > ? """ @@ -156,7 +161,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): def _count_messages(txn): sql = """ - SELECT COALESCE(COUNT(*), 0) FROM events + SELECT COUNT(*) FROM events WHERE type = 'm.room.message' AND stream_ordering > ? """ @@ -173,7 +178,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): like_clause = "%:" + self.hs.hostname sql = """ - SELECT COALESCE(COUNT(*), 0) FROM events + SELECT COUNT(*) FROM events WHERE type = 'm.room.message' AND sender LIKE ? AND stream_ordering > ? @@ -190,7 +195,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): async def count_daily_active_rooms(self): def _count(txn): sql = """ - SELECT COALESCE(COUNT(DISTINCT room_id), 0) FROM events + SELECT COUNT(DISTINCT room_id) FROM events WHERE type = 'm.room.message' AND stream_ordering > ? """ @@ -226,7 +231,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): Returns number of users seen in the past time_from period """ sql = """ - SELECT COALESCE(count(*), 0) FROM ( + SELECT COUNT(*) FROM ( SELECT user_id FROM user_ips WHERE last_seen > ? GROUP BY user_id @@ -253,7 +258,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): thirty_days_ago_in_secs = now - thirty_days_in_secs sql = """ - SELECT platform, COALESCE(count(*), 0) FROM ( + SELECT platform, COUNT(*) FROM ( SELECT users.name, platform, users.creation_ts * 1000, MAX(uip.last_seen) @@ -291,7 +296,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): results[row[0]] = row[1] sql = """ - SELECT COALESCE(count(*), 0) FROM ( + SELECT COUNT(*) FROM ( SELECT users.name, users.creation_ts * 1000, MAX(uip.last_seen) FROM users diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index b5284e4f67..8f09dd8e87 100644 --- a/synapse/storage/databases/main/monthly_active_users.py +++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -16,8 +16,13 @@ from typing import TYPE_CHECKING, Dict, List, Optional from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore -from synapse.storage.database import DatabasePool, make_in_list_sql_clause +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + make_in_list_sql_clause, +) from synapse.util.caches.descriptors import cached +from synapse.util.threepids import canonicalise_email if TYPE_CHECKING: from synapse.server import HomeServer @@ -30,7 +35,12 @@ LAST_SEEN_GRANULARITY = 60 * 60 * 1000 class MonthlyActiveUsersWorkerStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) self._clock = hs.get_clock() self.hs = hs @@ -49,7 +59,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): def _count_users(txn): # Exclude app service users sql = """ - SELECT COALESCE(count(*), 0) + SELECT COUNT(*) FROM monthly_active_users LEFT JOIN users ON monthly_active_users.user_id=users.name @@ -76,7 +86,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): def _count_users_by_service(txn): sql = """ - SELECT COALESCE(appservice_id, 'native'), COALESCE(count(*), 0) + SELECT COALESCE(appservice_id, 'native'), COUNT(*) FROM monthly_active_users LEFT JOIN users ON monthly_active_users.user_id=users.name GROUP BY appservice_id; @@ -103,7 +113,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): : self.hs.config.server.max_mau_value ]: user_id = await self.hs.get_datastore().get_user_id_by_threepid( - tp["medium"], tp["address"] + tp["medium"], canonicalise_email(tp["address"]) ) if user_id: users.append(user_id) @@ -212,7 +222,12 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) self._mau_stats_only = hs.config.server.mau_stats_only diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index cc0eebdb46..cbf9ec38f7 100644 --- a/synapse/storage/databases/main/presence.py +++ b/synapse/storage/databases/main/presence.py
@@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple from synapse.api.presence import PresenceState, UserPresenceState from synapse.replication.tcp.streams import PresenceStream from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause -from synapse.storage.database import DatabasePool +from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.storage.engines import PostgresEngine from synapse.storage.types import Connection from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator @@ -33,7 +33,7 @@ class PresenceBackgroundUpdateStore(SQLBaseStore): def __init__( self, database: DatabasePool, - db_conn: Connection, + db_conn: LoggingDatabaseConnection, hs: "HomeServer", ): super().__init__(database, db_conn, hs) @@ -52,7 +52,7 @@ class PresenceStore(PresenceBackgroundUpdateStore): def __init__( self, database: DatabasePool, - db_conn: Connection, + db_conn: LoggingDatabaseConnection, hs: "HomeServer", ): super().__init__(database, db_conn, hs) @@ -269,6 +269,7 @@ class PresenceStore(PresenceBackgroundUpdateStore): """ # Add user entries to the table, updating the presence_stream_id column if the user already # exists in the table. + presence_stream_id = self._presence_id_gen.get_current_token() await self.db_pool.simple_upsert_many( table="users_to_send_full_presence_to", key_names=("user_id",), @@ -279,9 +280,7 @@ class PresenceStore(PresenceBackgroundUpdateStore): # devices at different times, each device will receive full presence once - when # the presence stream ID in their sync token is less than the one in the table # for their user ID. - value_values=( - (self._presence_id_gen.get_current_token(),) for _ in user_ids - ), + value_values=[(presence_stream_id,) for _ in user_ids], desc="add_users_to_send_full_presence_to", ) diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 3b63267395..e01c94930a 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py
@@ -20,7 +20,7 @@ from synapse.api.errors import NotFoundError, StoreError from synapse.push.baserules import list_with_base_rules from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.storage.database import DatabasePool +from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.pusher import PusherWorkerStore @@ -81,7 +81,12 @@ class PushRulesWorkerStore( `get_max_push_rules_stream_id` which can be called in the initializer. """ - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) if hs.config.worker.worker_app is None: diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index b73ce53c91..747b4f31df 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py
@@ -22,7 +22,7 @@ from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.storage.util.id_generators import StreamIdGenerator from synapse.types import JsonDict from synapse.util import json_encoder -from synapse.util.caches.descriptors import cached, cachedList +from synapse.util.caches.descriptors import cached if TYPE_CHECKING: from synapse.server import HomeServer @@ -196,27 +196,6 @@ class PusherWorkerStore(SQLBaseStore): # This only exists for the cachedList decorator raise NotImplementedError() - @cachedList( - cached_method_name="get_if_user_has_pusher", - list_name="user_ids", - num_args=1, - ) - async def get_if_users_have_pushers( - self, user_ids: Iterable[str] - ) -> Dict[str, bool]: - rows = await self.db_pool.simple_select_many_batch( - table="pushers", - column="user_name", - iterable=user_ids, - retcols=["user_name"], - desc="get_if_users_have_pushers", - ) - - result = {user_id: False for user_id in user_ids} - result.update({r["user_name"]: True for r in rows}) - - return result - async def update_pusher_last_stream_ordering( self, app_id, pushkey, user_id, last_stream_ordering ) -> None: @@ -515,7 +494,7 @@ class PusherStore(PusherWorkerStore): # invalidate, since we the user might not have had a pusher before await self.db_pool.runInteraction( "add_pusher", - self._invalidate_cache_and_stream, # type: ignore + self._invalidate_cache_and_stream, # type: ignore[attr-defined] self.get_if_user_has_pusher, (user_id,), ) @@ -524,7 +503,7 @@ class PusherStore(PusherWorkerStore): self, app_id: str, pushkey: str, user_id: str ) -> None: def delete_pusher_txn(txn, stream_id): - self._invalidate_cache_and_stream( # type: ignore + self._invalidate_cache_and_stream( # type: ignore[attr-defined] txn, self.get_if_user_has_pusher, (user_id,) ) @@ -569,7 +548,7 @@ class PusherStore(PusherWorkerStore): pushers = list(await self.get_pushers_by_user_id(user_id)) def delete_pushers_txn(txn, stream_ids): - self._invalidate_cache_and_stream( # type: ignore + self._invalidate_cache_and_stream( # type: ignore[attr-defined] txn, self.get_if_user_has_pusher, (user_id,) ) diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index c99f8aebdb..bf0b903af2 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py
@@ -14,14 +14,29 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple +from typing import ( + TYPE_CHECKING, + Any, + Collection, + Dict, + Iterable, + List, + Optional, + Set, + Tuple, +) from twisted.internet import defer +from synapse.api.constants import ReceiptTypes from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import ReceiptsStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause -from synapse.storage.database import DatabasePool +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator from synapse.types import JsonDict @@ -36,7 +51,12 @@ logger = logging.getLogger(__name__) class ReceiptsWorkerStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): self._instance_name = hs.get_instance_name() if isinstance(database.engine, PostgresEngine): @@ -78,17 +98,13 @@ class ReceiptsWorkerStore(SQLBaseStore): "ReceiptsRoomChangeCache", self.get_max_receipt_stream_id() ) - def get_max_receipt_stream_id(self): - """Get the current max stream ID for receipts stream - - Returns: - int - """ + def get_max_receipt_stream_id(self) -> int: + """Get the current max stream ID for receipts stream""" return self._receipts_id_gen.get_current_token() @cached() - async def get_users_with_read_receipts_in_room(self, room_id): - receipts = await self.get_receipts_for_room(room_id, "m.read") + async def get_users_with_read_receipts_in_room(self, room_id: str) -> Set[str]: + receipts = await self.get_receipts_for_room(room_id, ReceiptTypes.READ) return {r["user_id"] for r in receipts} @cached(num_args=2) @@ -119,7 +135,9 @@ class ReceiptsWorkerStore(SQLBaseStore): ) @cached(num_args=2) - async def get_receipts_for_user(self, user_id, receipt_type): + async def get_receipts_for_user( + self, user_id: str, receipt_type: str + ) -> Dict[str, str]: rows = await self.db_pool.simple_select_list( table="receipts_linearized", keyvalues={"user_id": user_id, "receipt_type": receipt_type}, @@ -129,8 +147,10 @@ class ReceiptsWorkerStore(SQLBaseStore): return {row["room_id"]: row["event_id"] for row in rows} - async def get_receipts_for_user_with_orderings(self, user_id, receipt_type): - def f(txn): + async def get_receipts_for_user_with_orderings( + self, user_id: str, receipt_type: str + ) -> JsonDict: + def f(txn: LoggingTransaction) -> List[Tuple[str, str, int, int]]: sql = ( "SELECT rl.room_id, rl.event_id," " e.topological_ordering, e.stream_ordering" @@ -209,10 +229,10 @@ class ReceiptsWorkerStore(SQLBaseStore): @cached(num_args=3, tree=True) async def _get_linearized_receipts_for_room( self, room_id: str, to_key: int, from_key: Optional[int] = None - ) -> List[dict]: + ) -> List[JsonDict]: """See get_linearized_receipts_for_room""" - def f(txn): + def f(txn: LoggingTransaction) -> List[Dict[str, Any]]: if from_key: sql = ( "SELECT * FROM receipts_linearized WHERE" @@ -250,11 +270,13 @@ class ReceiptsWorkerStore(SQLBaseStore): list_name="room_ids", num_args=3, ) - async def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None): + async def _get_linearized_receipts_for_rooms( + self, room_ids: Collection[str], to_key: int, from_key: Optional[int] = None + ) -> Dict[str, List[JsonDict]]: if not room_ids: return {} - def f(txn): + def f(txn: LoggingTransaction) -> List[Dict[str, Any]]: if from_key: sql = """ SELECT * FROM receipts_linearized WHERE @@ -323,7 +345,7 @@ class ReceiptsWorkerStore(SQLBaseStore): A dictionary of roomids to a list of receipts. """ - def f(txn): + def f(txn: LoggingTransaction) -> List[Dict[str, Any]]: if from_key: sql = """ SELECT * FROM receipts_linearized WHERE @@ -379,7 +401,7 @@ class ReceiptsWorkerStore(SQLBaseStore): if last_id == current_id: return defer.succeed([]) - def _get_users_sent_receipts_between_txn(txn): + def _get_users_sent_receipts_between_txn(txn: LoggingTransaction) -> List[str]: sql = """ SELECT DISTINCT user_id FROM receipts_linearized WHERE ? < stream_id AND stream_id <= ? @@ -419,7 +441,9 @@ class ReceiptsWorkerStore(SQLBaseStore): if last_id == current_id: return [], current_id, False - def get_all_updated_receipts_txn(txn): + def get_all_updated_receipts_txn( + txn: LoggingTransaction, + ) -> Tuple[List[Tuple[int, list]], int, bool]: sql = """ SELECT stream_id, room_id, receipt_type, user_id, event_id, data FROM receipts_linearized @@ -446,8 +470,8 @@ class ReceiptsWorkerStore(SQLBaseStore): def _invalidate_get_users_with_receipts_in_room( self, room_id: str, receipt_type: str, user_id: str - ): - if receipt_type != "m.read": + ) -> None: + if receipt_type != ReceiptTypes.READ: return res = self.get_users_with_read_receipts_in_room.cache.get_immediate( @@ -461,7 +485,9 @@ class ReceiptsWorkerStore(SQLBaseStore): self.get_users_with_read_receipts_in_room.invalidate((room_id,)) - def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id): + def invalidate_caches_for_receipt( + self, room_id: str, receipt_type: str, user_id: str + ) -> None: self.get_receipts_for_user.invalidate((user_id, receipt_type)) self._get_linearized_receipts_for_room.invalidate((room_id,)) self.get_last_receipt_event_id_for_user.invalidate( @@ -482,11 +508,18 @@ class ReceiptsWorkerStore(SQLBaseStore): return super().process_replication_rows(stream_name, instance_name, token, rows) def insert_linearized_receipt_txn( - self, txn, room_id, receipt_type, user_id, event_id, data, stream_id - ): + self, + txn: LoggingTransaction, + room_id: str, + receipt_type: str, + user_id: str, + event_id: str, + data: JsonDict, + stream_id: int, + ) -> Optional[int]: """Inserts a read-receipt into the database if it's newer than the current RR - Returns: int|None + Returns: None if the RR is older than the current RR otherwise, the rx timestamp of the event that the RR corresponds to (or 0 if the event is unknown) @@ -550,7 +583,7 @@ class ReceiptsWorkerStore(SQLBaseStore): lock=False, ) - if receipt_type == "m.read" and stream_ordering is not None: + if receipt_type == ReceiptTypes.READ and stream_ordering is not None: self._remove_old_push_actions_before_txn( txn, room_id=room_id, user_id=user_id, stream_ordering=stream_ordering ) @@ -580,7 +613,7 @@ class ReceiptsWorkerStore(SQLBaseStore): else: # we need to points in graph -> linearized form. # TODO: Make this better. - def graph_to_linear(txn): + def graph_to_linear(txn: LoggingTransaction) -> str: clause, args = make_in_list_sql_clause( self.database_engine, "event_id", event_ids ) @@ -634,11 +667,16 @@ class ReceiptsWorkerStore(SQLBaseStore): return stream_id, max_persisted_id async def insert_graph_receipt( - self, room_id, receipt_type, user_id, event_ids, data - ): + self, + room_id: str, + receipt_type: str, + user_id: str, + event_ids: List[str], + data: JsonDict, + ) -> None: assert self._can_write_to_receipts - return await self.db_pool.runInteraction( + await self.db_pool.runInteraction( "insert_graph_receipt", self.insert_graph_receipt_txn, room_id, @@ -649,8 +687,14 @@ class ReceiptsWorkerStore(SQLBaseStore): ) def insert_graph_receipt_txn( - self, txn, room_id, receipt_type, user_id, event_ids, data - ): + self, + txn: LoggingTransaction, + room_id: str, + receipt_type: str, + user_id: str, + event_ids: List[str], + data: JsonDict, + ) -> None: assert self._can_write_to_receipts txn.call_after(self.get_receipts_for_room.invalidate, (room_id, receipt_type)) diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index e1ddf06916..4175c82a25 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py
@@ -16,7 +16,7 @@ import logging import random import re -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast import attr @@ -794,7 +794,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): yesterday = int(self._clock.time()) - (60 * 60 * 24) sql = """ - SELECT user_type, COALESCE(count(*), 0) AS count FROM ( + SELECT user_type, COUNT(*) AS count FROM ( SELECT CASE WHEN is_guest=0 AND appservice_id IS NULL THEN 'native' @@ -819,7 +819,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): def _count_users(txn): txn.execute( """ - SELECT COALESCE(COUNT(*), 0) FROM users + SELECT COUNT(*) FROM users WHERE appservice_id IS NULL """ ) @@ -856,7 +856,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): Args: medium: threepid medium e.g. email - address: threepid address e.g. me@example.com + address: threepid address e.g. me@example.com. This must already be + in canonical form. Returns: The user ID or None if no user id/threepid mapping exists @@ -1356,12 +1357,15 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): # Override type because the return type is only optional if # allow_none is True, and we don't want mypy throwing errors # about None not being indexable. - res: Dict[str, Any] = self.db_pool.simple_select_one_txn( - txn, - "registration_tokens", - keyvalues={"token": token}, - retcols=["pending", "completed"], - ) # type: ignore + res = cast( + Dict[str, Any], + self.db_pool.simple_select_one_txn( + txn, + "registration_tokens", + keyvalues={"token": token}, + retcols=["pending", "completed"], + ), + ) # Decrement pending and increment completed self.db_pool.simple_update_one_txn( diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 0a43acda07..4ff6aed253 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py
@@ -13,7 +13,7 @@ # limitations under the License. import logging -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union, cast import attr @@ -37,6 +37,7 @@ class RelationsWorkerStore(SQLBaseStore): async def get_relations_for_event( self, event_id: str, + room_id: str, relation_type: Optional[str] = None, event_type: Optional[str] = None, aggregation_key: Optional[str] = None, @@ -49,6 +50,7 @@ class RelationsWorkerStore(SQLBaseStore): Args: event_id: Fetch events that relate to this event ID. + room_id: The room the event belongs to. relation_type: Only fetch events with this relation type, if given. event_type: Only fetch events with this event type, if given. aggregation_key: Only fetch events with this aggregation key, if given. @@ -63,8 +65,8 @@ class RelationsWorkerStore(SQLBaseStore): the form `{"event_id": "..."}`. """ - where_clause = ["relates_to_id = ?"] - where_args: List[Union[str, int]] = [event_id] + where_clause = ["relates_to_id = ?", "room_id = ?"] + where_args: List[Union[str, int]] = [event_id, room_id] if relation_type is not None: where_clause.append("relation_type = ?") @@ -199,6 +201,7 @@ class RelationsWorkerStore(SQLBaseStore): async def get_aggregation_groups_for_event( self, event_id: str, + room_id: str, event_type: Optional[str] = None, limit: int = 5, direction: str = "b", @@ -213,6 +216,7 @@ class RelationsWorkerStore(SQLBaseStore): Args: event_id: Fetch events that relate to this event ID. + room_id: The room the event belongs to. event_type: Only fetch events with this event type, if given. limit: Only fetch the `limit` groups. direction: Whether to fetch the highest count first (`"b"`) or @@ -225,8 +229,12 @@ class RelationsWorkerStore(SQLBaseStore): `type`, `key` and `count` fields. """ - where_clause = ["relates_to_id = ?", "relation_type = ?"] - where_args: List[Union[str, int]] = [event_id, RelationTypes.ANNOTATION] + where_clause = ["relates_to_id = ?", "room_id = ?", "relation_type = ?"] + where_args: List[Union[str, int]] = [ + event_id, + room_id, + RelationTypes.ANNOTATION, + ] if event_type: where_clause.append("type = ?") @@ -288,7 +296,9 @@ class RelationsWorkerStore(SQLBaseStore): ) @cached() - async def get_applicable_edit(self, event_id: str) -> Optional[EventBase]: + async def get_applicable_edit( + self, event_id: str, room_id: str + ) -> Optional[EventBase]: """Get the most recent edit (if any) that has happened for the given event. @@ -296,6 +306,7 @@ class RelationsWorkerStore(SQLBaseStore): Args: event_id: The original event ID + room_id: The original event's room ID Returns: The most recent edit, if any. @@ -317,13 +328,14 @@ class RelationsWorkerStore(SQLBaseStore): WHERE relates_to_id = ? AND relation_type = ? + AND edit.room_id = ? AND edit.type = 'm.room.message' ORDER by edit.origin_server_ts DESC, edit.event_id DESC LIMIT 1 """ def _get_applicable_edit_txn(txn: LoggingTransaction) -> Optional[str]: - txn.execute(sql, (event_id, RelationTypes.REPLACE)) + txn.execute(sql, (event_id, RelationTypes.REPLACE, room_id)) row = txn.fetchone() if row: return row[0] @@ -340,13 +352,14 @@ class RelationsWorkerStore(SQLBaseStore): @cached() async def get_thread_summary( - self, event_id: str + self, event_id: str, room_id: str ) -> Tuple[int, Optional[EventBase]]: """Get the number of threaded replies, the senders of those replies, and the latest reply (if any) for the given event. Args: - event_id: The original event ID + event_id: Summarize the thread related to this event ID. + room_id: The room the event belongs to. Returns: The number of items in the thread and the most recent response, if any. @@ -363,12 +376,13 @@ class RelationsWorkerStore(SQLBaseStore): INNER JOIN events USING (event_id) WHERE relates_to_id = ? + AND room_id = ? AND relation_type = ? ORDER BY topological_ordering DESC, stream_ordering DESC LIMIT 1 """ - txn.execute(sql, (event_id, RelationTypes.THREAD)) + txn.execute(sql, (event_id, room_id, RelationTypes.THREAD)) row = txn.fetchone() if row is None: return 0, None @@ -376,14 +390,16 @@ class RelationsWorkerStore(SQLBaseStore): latest_event_id = row[0] sql = """ - SELECT COALESCE(COUNT(event_id), 0) + SELECT COUNT(event_id) FROM event_relations + INNER JOIN events USING (event_id) WHERE relates_to_id = ? + AND room_id = ? AND relation_type = ? """ - txn.execute(sql, (event_id, RelationTypes.THREAD)) - count = txn.fetchone()[0] # type: ignore[index] + txn.execute(sql, (event_id, room_id, RelationTypes.THREAD)) + count = cast(Tuple[int], txn.fetchone())[0] return count, latest_event_id diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 7d694d852d..c0e837854a 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py
@@ -13,20 +13,36 @@ # See the License for the specific language governing permissions and # limitations under the License. -import collections import logging from abc import abstractmethod from enum import Enum -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Dict, + List, + Optional, + Tuple, + Union, + cast, +) + +import attr from synapse.api.constants import EventContentFields, EventTypes, JoinRules from synapse.api.errors import StoreError from synapse.api.room_versions import RoomVersion, RoomVersions from synapse.events import EventBase from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.storage.database import DatabasePool, LoggingTransaction -from synapse.storage.databases.main.search import SearchStore +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) +from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.types import Cursor +from synapse.storage.util.id_generators import IdGenerator from synapse.types import JsonDict, ThirdPartyInstanceID from synapse.util import json_encoder from synapse.util.caches.descriptors import cached @@ -38,9 +54,10 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -RatelimitOverride = collections.namedtuple( - "RatelimitOverride", ("messages_per_second", "burst_count") -) +@attr.s(slots=True, frozen=True, auto_attribs=True) +class RatelimitOverride: + messages_per_second: int + burst_count: int class RoomSortOrder(Enum): @@ -71,8 +88,13 @@ class RoomSortOrder(Enum): STATE_EVENTS = "state_events" -class RoomWorkerStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): +class RoomWorkerStore(CacheInvalidationWorkerStore): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) self.config = hs.config @@ -83,7 +105,7 @@ class RoomWorkerStore(SQLBaseStore): room_creator_user_id: str, is_public: bool, room_version: RoomVersion, - ): + ) -> None: """Stores a room. Args: @@ -111,7 +133,7 @@ class RoomWorkerStore(SQLBaseStore): logger.error("store_room with room_id=%s failed: %s", room_id, e) raise StoreError(500, "Problem creating room.") - async def get_room(self, room_id: str) -> dict: + async def get_room(self, room_id: str) -> Optional[Dict[str, Any]]: """Retrieve a room. Args: @@ -136,7 +158,9 @@ class RoomWorkerStore(SQLBaseStore): A dict containing the room information, or None if the room is unknown. """ - def get_room_with_stats_txn(txn, room_id): + def get_room_with_stats_txn( + txn: LoggingTransaction, room_id: str + ) -> Optional[Dict[str, Any]]: sql = """ SELECT room_id, state.name, state.canonical_alias, curr.joined_members, curr.local_users_in_room AS joined_local_members, rooms.room_version AS version, @@ -185,7 +209,7 @@ class RoomWorkerStore(SQLBaseStore): ignore_non_federatable: If true filters out non-federatable rooms """ - def _count_public_rooms_txn(txn): + def _count_public_rooms_txn(txn: LoggingTransaction) -> int: query_args = [] if network_tuple: @@ -195,6 +219,7 @@ class RoomWorkerStore(SQLBaseStore): WHERE appservice_id = ? AND network_id = ? """ query_args.append(network_tuple.appservice_id) + assert network_tuple.network_id is not None query_args.append(network_tuple.network_id) else: published_sql = """ @@ -208,7 +233,7 @@ class RoomWorkerStore(SQLBaseStore): sql = """ SELECT - COALESCE(COUNT(*), 0) + COUNT(*) FROM ( %(published_sql)s ) published @@ -226,7 +251,7 @@ class RoomWorkerStore(SQLBaseStore): } txn.execute(sql, query_args) - return txn.fetchone()[0] + return cast(Tuple[int], txn.fetchone())[0] return await self.db_pool.runInteraction( "count_public_rooms", _count_public_rooms_txn @@ -235,11 +260,11 @@ class RoomWorkerStore(SQLBaseStore): async def get_room_count(self) -> int: """Retrieve the total number of rooms.""" - def f(txn): + def f(txn: LoggingTransaction) -> int: sql = "SELECT count(*) FROM rooms" txn.execute(sql) - row = txn.fetchone() - return row[0] or 0 + row = cast(Tuple[int], txn.fetchone()) + return row[0] return await self.db_pool.runInteraction("get_rooms", f) @@ -251,7 +276,7 @@ class RoomWorkerStore(SQLBaseStore): bounds: Optional[Tuple[int, str]], forwards: bool, ignore_non_federatable: bool = False, - ): + ) -> List[Dict[str, Any]]: """Gets the largest public rooms (where largest is in terms of joined members, as tracked in the statistics table). @@ -272,7 +297,7 @@ class RoomWorkerStore(SQLBaseStore): """ where_clauses = [] - query_args = [] + query_args: List[Union[str, int]] = [] if network_tuple: if network_tuple.appservice_id: @@ -281,6 +306,7 @@ class RoomWorkerStore(SQLBaseStore): WHERE appservice_id = ? AND network_id = ? """ query_args.append(network_tuple.appservice_id) + assert network_tuple.network_id is not None query_args.append(network_tuple.network_id) else: published_sql = """ @@ -372,7 +398,9 @@ class RoomWorkerStore(SQLBaseStore): LIMIT ? """ - def _get_largest_public_rooms_txn(txn): + def _get_largest_public_rooms_txn( + txn: LoggingTransaction, + ) -> List[Dict[str, Any]]: txn.execute(sql, query_args) results = self.db_pool.cursor_to_dict(txn) @@ -435,7 +463,7 @@ class RoomWorkerStore(SQLBaseStore): """ # Filter room names by a string where_statement = "" - search_pattern = [] + search_pattern: List[object] = [] if search_term: where_statement = """ WHERE LOWER(state.name) LIKE ? @@ -543,7 +571,9 @@ class RoomWorkerStore(SQLBaseStore): where_statement, ) - def _get_rooms_paginate_txn(txn): + def _get_rooms_paginate_txn( + txn: LoggingTransaction, + ) -> Tuple[List[Dict[str, Any]], int]: # Add the search term into the WHERE clause # and execute the data query txn.execute(info_sql, search_pattern + [limit, start]) @@ -575,7 +605,7 @@ class RoomWorkerStore(SQLBaseStore): # Add the search term into the WHERE clause if present txn.execute(count_sql, search_pattern) - room_count = txn.fetchone() + room_count = cast(Tuple[int], txn.fetchone()) return rooms, room_count[0] return await self.db_pool.runInteraction( @@ -620,7 +650,7 @@ class RoomWorkerStore(SQLBaseStore): burst_count: How many actions that can be performed before being limited. """ - def set_ratelimit_txn(txn): + def set_ratelimit_txn(txn: LoggingTransaction) -> None: self.db_pool.simple_upsert_txn( txn, table="ratelimit_override", @@ -643,7 +673,7 @@ class RoomWorkerStore(SQLBaseStore): user_id: user ID of the user """ - def delete_ratelimit_txn(txn): + def delete_ratelimit_txn(txn: LoggingTransaction) -> None: row = self.db_pool.simple_select_one_txn( txn, table="ratelimit_override", @@ -667,7 +697,7 @@ class RoomWorkerStore(SQLBaseStore): await self.db_pool.runInteraction("delete_ratelimit", delete_ratelimit_txn) @cached() - async def get_retention_policy_for_room(self, room_id): + async def get_retention_policy_for_room(self, room_id: str) -> Dict[str, int]: """Get the retention policy for a given room. If no retention policy has been found for this room, returns a policy defined @@ -676,13 +706,15 @@ class RoomWorkerStore(SQLBaseStore): configuration). Args: - room_id (str): The ID of the room to get the retention policy of. + room_id: The ID of the room to get the retention policy of. Returns: - dict[int, int]: "min_lifetime" and "max_lifetime" for this room. + A dict containing "min_lifetime" and "max_lifetime" for this room. """ - def get_retention_policy_for_room_txn(txn): + def get_retention_policy_for_room_txn( + txn: LoggingTransaction, + ) -> List[Dict[str, Optional[int]]]: txn.execute( """ SELECT min_lifetime, max_lifetime FROM room_retention @@ -707,19 +739,23 @@ class RoomWorkerStore(SQLBaseStore): "max_lifetime": self.config.retention.retention_default_max_lifetime, } - row = ret[0] + min_lifetime = ret[0]["min_lifetime"] + max_lifetime = ret[0]["max_lifetime"] # If one of the room's policy's attributes isn't defined, use the matching # attribute from the default policy. # The default values will be None if no default policy has been defined, or if one # of the attributes is missing from the default policy. - if row["min_lifetime"] is None: - row["min_lifetime"] = self.config.retention.retention_default_min_lifetime + if min_lifetime is None: + min_lifetime = self.config.retention.retention_default_min_lifetime - if row["max_lifetime"] is None: - row["max_lifetime"] = self.config.retention.retention_default_max_lifetime + if max_lifetime is None: + max_lifetime = self.config.retention.retention_default_max_lifetime - return row + return { + "min_lifetime": min_lifetime, + "max_lifetime": max_lifetime, + } async def get_media_mxcs_in_room(self, room_id: str) -> Tuple[List[str], List[str]]: """Retrieves all the local and remote media MXC URIs in a given room @@ -731,7 +767,9 @@ class RoomWorkerStore(SQLBaseStore): The local and remote media as a lists of the media IDs. """ - def _get_media_mxcs_in_room_txn(txn): + def _get_media_mxcs_in_room_txn( + txn: LoggingTransaction, + ) -> Tuple[List[str], List[str]]: local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id) local_media_mxcs = [] remote_media_mxcs = [] @@ -757,7 +795,7 @@ class RoomWorkerStore(SQLBaseStore): logger.info("Quarantining media in room: %s", room_id) - def _quarantine_media_in_room_txn(txn): + def _quarantine_media_in_room_txn(txn: LoggingTransaction) -> int: local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id) return self._quarantine_media_txn( txn, local_mxcs, remote_mxcs, quarantined_by @@ -767,13 +805,11 @@ class RoomWorkerStore(SQLBaseStore): "quarantine_media_in_room", _quarantine_media_in_room_txn ) - def _get_media_mxcs_in_room_txn(self, txn, room_id): + def _get_media_mxcs_in_room_txn( + self, txn: LoggingTransaction, room_id: str + ) -> Tuple[List[str], List[Tuple[str, str]]]: """Retrieves all the local and remote media MXC URIs in a given room - Args: - txn (cursor) - room_id (str) - Returns: The local and remote media as a lists of tuples where the key is the hostname and the value is the media ID. @@ -841,7 +877,7 @@ class RoomWorkerStore(SQLBaseStore): logger.info("Quarantining media: %s/%s", server_name, media_id) is_local = server_name == self.config.server.server_name - def _quarantine_media_by_id_txn(txn): + def _quarantine_media_by_id_txn(txn: LoggingTransaction) -> int: local_mxcs = [media_id] if is_local else [] remote_mxcs = [(server_name, media_id)] if not is_local else [] @@ -863,7 +899,7 @@ class RoomWorkerStore(SQLBaseStore): quarantined_by: The ID of the user who made the quarantine request """ - def _quarantine_media_by_user_txn(txn): + def _quarantine_media_by_user_txn(txn: LoggingTransaction) -> int: local_media_ids = self._get_media_ids_by_user_txn(txn, user_id) return self._quarantine_media_txn(txn, local_media_ids, [], quarantined_by) @@ -871,7 +907,9 @@ class RoomWorkerStore(SQLBaseStore): "quarantine_media_by_user", _quarantine_media_by_user_txn ) - def _get_media_ids_by_user_txn(self, txn, user_id: str, filter_quarantined=True): + def _get_media_ids_by_user_txn( + self, txn: LoggingTransaction, user_id: str, filter_quarantined: bool = True + ) -> List[str]: """Retrieves local media IDs by a given user Args: @@ -900,7 +938,7 @@ class RoomWorkerStore(SQLBaseStore): def _quarantine_media_txn( self, - txn, + txn: LoggingTransaction, local_mxcs: List[str], remote_mxcs: List[Tuple[str, str]], quarantined_by: Optional[str], @@ -928,12 +966,15 @@ class RoomWorkerStore(SQLBaseStore): # set quarantine if quarantined_by is not None: sql += "AND safe_from_quarantine = ?" - rows = [(quarantined_by, media_id, False) for media_id in local_mxcs] + txn.executemany( + sql, [(quarantined_by, media_id, False) for media_id in local_mxcs] + ) # remove from quarantine else: - rows = [(quarantined_by, media_id) for media_id in local_mxcs] + txn.executemany( + sql, [(quarantined_by, media_id) for media_id in local_mxcs] + ) - txn.executemany(sql, rows) # Note that a rowcount of -1 can be used to indicate no rows were affected. total_media_quarantined = txn.rowcount if txn.rowcount > 0 else 0 @@ -951,7 +992,7 @@ class RoomWorkerStore(SQLBaseStore): async def get_rooms_for_retention_period_in_range( self, min_ms: Optional[int], max_ms: Optional[int], include_null: bool = False - ) -> Dict[str, dict]: + ) -> Dict[str, Dict[str, Optional[int]]]: """Retrieves all of the rooms within the given retention range. Optionally includes the rooms which don't have a retention policy. @@ -971,7 +1012,9 @@ class RoomWorkerStore(SQLBaseStore): "min_lifetime" (int|None), and "max_lifetime" (int|None). """ - def get_rooms_for_retention_period_in_range_txn(txn): + def get_rooms_for_retention_period_in_range_txn( + txn: LoggingTransaction, + ) -> Dict[str, Dict[str, Optional[int]]]: range_conditions = [] args = [] @@ -1050,11 +1093,14 @@ _REPLACE_ROOM_DEPTH_SQL_COMMANDS = ( class RoomBackgroundUpdateStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) - self.config = hs.config - self.db_pool.updates.register_background_update_handler( "insert_room_retention", self._background_insert_retention, @@ -1085,7 +1131,9 @@ class RoomBackgroundUpdateStore(SQLBaseStore): self._background_populate_rooms_creator_column, ) - async def _background_insert_retention(self, progress, batch_size): + async def _background_insert_retention( + self, progress: JsonDict, batch_size: int + ) -> int: """Retrieves a list of all rooms within a range and inserts an entry for each of them into the room_retention table. NULLs the property's columns if missing from the retention event in the room's @@ -1095,7 +1143,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore): last_room = progress.get("room_id", "") - def _background_insert_retention_txn(txn): + def _background_insert_retention_txn(txn: LoggingTransaction) -> bool: txn.execute( """ SELECT state.room_id, state.event_id, events.json @@ -1154,15 +1202,17 @@ class RoomBackgroundUpdateStore(SQLBaseStore): return batch_size async def _background_add_rooms_room_version_column( - self, progress: dict, batch_size: int - ): + self, progress: JsonDict, batch_size: int + ) -> int: """Background update to go and add room version information to `rooms` table from `current_state_events` table. """ last_room_id = progress.get("room_id", "") - def _background_add_rooms_room_version_column_txn(txn: LoggingTransaction): + def _background_add_rooms_room_version_column_txn( + txn: LoggingTransaction, + ) -> bool: sql = """ SELECT room_id, json FROM current_state_events INNER JOIN event_json USING (room_id, event_id) @@ -1223,7 +1273,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore): return batch_size async def _remove_tombstoned_rooms_from_directory( - self, progress, batch_size + self, progress: JsonDict, batch_size: int ) -> int: """Removes any rooms with tombstone events from the room directory @@ -1233,7 +1283,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore): last_room = progress.get("room_id", "") - def _get_rooms(txn): + def _get_rooms(txn: LoggingTransaction) -> List[str]: txn.execute( """ SELECT room_id @@ -1271,7 +1321,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore): return len(rooms) @abstractmethod - def set_room_is_public(self, room_id, is_public): + def set_room_is_public(self, room_id: str, is_public: bool) -> Awaitable[None]: # this will need to be implemented if a background update is performed with # existing (tombstoned, public) rooms in the database. # @@ -1318,7 +1368,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore): 32-bit integer field. """ - def process(txn: Cursor) -> int: + def process(txn: LoggingTransaction) -> int: last_room = progress.get("last_room", "") txn.execute( """ @@ -1375,15 +1425,17 @@ class RoomBackgroundUpdateStore(SQLBaseStore): return 0 async def _background_populate_rooms_creator_column( - self, progress: dict, batch_size: int - ): + self, progress: JsonDict, batch_size: int + ) -> int: """Background update to go and add creator information to `rooms` table from `current_state_events` table. """ last_room_id = progress.get("room_id", "") - def _background_populate_rooms_creator_column_txn(txn: LoggingTransaction): + def _background_populate_rooms_creator_column_txn( + txn: LoggingTransaction, + ) -> bool: sql = """ SELECT room_id, json FROM event_json INNER JOIN rooms AS room USING (room_id) @@ -1434,15 +1486,20 @@ class RoomBackgroundUpdateStore(SQLBaseStore): return batch_size -class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): +class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) - self.config = hs.config + self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id") async def upsert_room_on_join( self, room_id: str, room_version: RoomVersion, auth_events: List[EventBase] - ): + ) -> None: """Ensure that the room is stored in the table Called when we join a room over federation, and overwrites any room version @@ -1488,7 +1545,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): async def maybe_store_room_on_outlier_membership( self, room_id: str, room_version: RoomVersion - ): + ) -> None: """ When we receive an invite or any other event over federation that may relate to a room we are not in, store the version of the room if we don't already know the room version. @@ -1528,8 +1585,8 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): self.hs.get_notifier().on_new_replication_data() async def set_room_is_public_appservice( - self, room_id, appservice_id, network_id, is_public - ): + self, room_id: str, appservice_id: str, network_id: str, is_public: bool + ) -> None: """Edit the appservice/network specific public room list. Each appservice can have a number of published room lists associated @@ -1538,11 +1595,10 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): network. Args: - room_id (str) - appservice_id (str) - network_id (str) - is_public (bool): Whether to publish or unpublish the room from the - list. + room_id + appservice_id + network_id + is_public: Whether to publish or unpublish the room from the list. """ if is_public: @@ -1607,7 +1663,9 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): event_report: json list of information from event report """ - def _get_event_report_txn(txn, report_id): + def _get_event_report_txn( + txn: LoggingTransaction, report_id: int + ) -> Optional[Dict[str, Any]]: sql = """ SELECT @@ -1679,9 +1737,11 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): count: total number of event reports matching the filter criteria """ - def _get_event_reports_paginate_txn(txn): + def _get_event_reports_paginate_txn( + txn: LoggingTransaction, + ) -> Tuple[List[Dict[str, Any]], int]: filters = [] - args = [] + args: List[object] = [] if user_id: filters.append("er.user_id LIKE ?") @@ -1705,7 +1765,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): where_clause ) txn.execute(sql, args) - count = txn.fetchone()[0] + count = cast(Tuple[int], txn.fetchone())[0] sql = """ SELECT diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 6b2a8d06a6..cda80d6511 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py
@@ -37,7 +37,7 @@ from synapse.metrics.background_process_metrics import ( wrap_as_background_process, ) from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause -from synapse.storage.database import DatabasePool +from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.engines import Sqlite3Engine from synapse.storage.roommember import ( @@ -64,7 +64,12 @@ _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership" class RoomMemberWorkerStore(EventsWorkerStore): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) # Used by `_get_joined_hosts` to ensure only one thing mutates the cache @@ -985,7 +990,12 @@ class RoomMemberWorkerStore(EventsWorkerStore): class RoomMemberBackgroundUpdateStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_update_handler( _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile @@ -1135,7 +1145,12 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore): class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) async def forget(self, user_id: str, room_id: str) -> None: diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index 642560a70d..3cbaca21b5 100644 --- a/synapse/storage/databases/main/search.py +++ b/synapse/storage/databases/main/search.py
@@ -14,13 +14,18 @@ import logging import re -from collections import namedtuple from typing import TYPE_CHECKING, Collection, Iterable, List, Optional, Set +import attr + from synapse.api.errors import SynapseError from synapse.events import EventBase from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause -from synapse.storage.database import DatabasePool, LoggingTransaction +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.storage.engines import PostgresEngine, Sqlite3Engine @@ -29,10 +34,15 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -SearchEntry = namedtuple( - "SearchEntry", - ["key", "value", "event_id", "room_id", "stream_ordering", "origin_server_ts"], -) + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class SearchEntry: + key: str + value: str + event_id: str + room_id: str + stream_ordering: Optional[int] + origin_server_ts: int def _clean_value_for_search(value: str) -> str: @@ -105,7 +115,12 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): EVENT_SEARCH_USE_GIST_POSTGRES_NAME = "event_search_postgres_gist" EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin" - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) if not hs.config.server.enable_search: @@ -358,7 +373,12 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): class SearchStore(SearchBackgroundUpdateStore): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) async def search_msgs(self, room_ids, search_term, keys): diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index fa2c3b1feb..2fb3e65192 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py
@@ -14,7 +14,6 @@ # limitations under the License. import collections.abc import logging -from collections import namedtuple from typing import TYPE_CHECKING, Iterable, Optional, Set from synapse.api.constants import EventTypes, Membership @@ -22,7 +21,11 @@ from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion from synapse.events import EventBase from synapse.storage._base import SQLBaseStore -from synapse.storage.database import DatabasePool, LoggingTransaction +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.roommember import RoomMemberWorkerStore from synapse.storage.state import StateFilter @@ -39,24 +42,16 @@ logger = logging.getLogger(__name__) MAX_STATE_DELTA_HOPS = 100 -class _GetStateGroupDelta( - namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids")) -): - """Return type of get_state_group_delta that implements __len__, which lets - us use the itrable flag when caching - """ - - __slots__ = [] - - def __len__(self): - return len(self.delta_ids) if self.delta_ids else 0 - - # this inherits from EventsWorkerStore because it calls self.get_events class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): """The parts of StateGroupStore that can be called from workers.""" - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) async def get_room_version(self, room_id: str) -> RoomVersion: @@ -182,11 +177,15 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): NotFoundError if the room is unknown """ state_ids = await self.get_current_state_ids(room_id) + + if not state_ids: + raise NotFoundError(f"Current state for room {room_id} is empty") + create_id = state_ids.get((EventTypes.Create, "")) # If we can't find the create event, assume we've hit a dead end if not create_id: - raise NotFoundError("Unknown room %s" % (room_id,)) + raise NotFoundError(f"No create event in current state for room {room_id}") # Retrieve the room's create event and return create_event = await self.get_event(create_id) @@ -349,7 +348,12 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore): EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index" DELETE_CURRENT_STATE_UPDATE_NAME = "delete_old_current_state_events" - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) self.server_name = hs.hostname @@ -536,5 +540,10 @@ class StateStore(StateGroupWorkerStore, MainStateBackgroundUpdateStore): * `state_groups_state`: Maps state group to state events. """ - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) diff --git a/synapse/storage/databases/main/state_deltas.py b/synapse/storage/databases/main/state_deltas.py
index 7f3624b128..188afec332 100644 --- a/synapse/storage/databases/main/state_deltas.py +++ b/synapse/storage/databases/main/state_deltas.py
@@ -56,7 +56,9 @@ class StateDeltasStore(SQLBaseStore): prev_stream_id = int(prev_stream_id) # check we're not going backwards - assert prev_stream_id <= max_stream_id + assert ( + prev_stream_id <= max_stream_id + ), f"New stream id {max_stream_id} is smaller than prev stream id {prev_stream_id}" if not self._curr_state_delta_stream_cache.has_any_entity_changed( prev_stream_id diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index 5d7b59d861..427ae1f649 100644 --- a/synapse/storage/databases/main/stats.py +++ b/synapse/storage/databases/main/stats.py
@@ -16,7 +16,7 @@ import logging from enum import Enum from itertools import chain -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, cast from typing_extensions import Counter @@ -24,7 +24,11 @@ from twisted.internet.defer import DeferredLock from synapse.api.constants import EventContentFields, EventTypes, Membership from synapse.api.errors import StoreError -from synapse.storage.database import DatabasePool +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) from synapse.storage.databases.main.state_deltas import StateDeltasStore from synapse.types import JsonDict from synapse.util.caches.descriptors import cached @@ -96,7 +100,12 @@ class UserSortOrder(Enum): class StatsStore(StateDeltasStore): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) self.server_name = hs.hostname @@ -117,7 +126,9 @@ class StatsStore(StateDeltasStore): self.db_pool.updates.register_noop_background_update("populate_stats_cleanup") self.db_pool.updates.register_noop_background_update("populate_stats_prepare") - async def _populate_stats_process_users(self, progress, batch_size): + async def _populate_stats_process_users( + self, progress: JsonDict, batch_size: int + ) -> int: """ This is a background update which regenerates statistics for users. """ @@ -129,7 +140,7 @@ class StatsStore(StateDeltasStore): last_user_id = progress.get("last_user_id", "") - def _get_next_batch(txn): + def _get_next_batch(txn: LoggingTransaction) -> List[str]: sql = """ SELECT DISTINCT name FROM users WHERE name > ? @@ -163,7 +174,9 @@ class StatsStore(StateDeltasStore): return len(users_to_work_on) - async def _populate_stats_process_rooms(self, progress, batch_size): + async def _populate_stats_process_rooms( + self, progress: JsonDict, batch_size: int + ) -> int: """This is a background update which regenerates statistics for rooms.""" if not self.stats_enabled: await self.db_pool.updates._end_background_update( @@ -173,7 +186,7 @@ class StatsStore(StateDeltasStore): last_room_id = progress.get("last_room_id", "") - def _get_next_batch(txn): + def _get_next_batch(txn: LoggingTransaction) -> List[str]: sql = """ SELECT DISTINCT room_id FROM current_state_events WHERE room_id > ? @@ -302,7 +315,7 @@ class StatsStore(StateDeltasStore): stream_id: Current position. """ - def _bulk_update_stats_delta_txn(txn): + def _bulk_update_stats_delta_txn(txn: LoggingTransaction) -> None: for stats_type, stats_updates in updates.items(): for stats_id, fields in stats_updates.items(): logger.debug( @@ -334,7 +347,7 @@ class StatsStore(StateDeltasStore): stats_type: str, stats_id: str, fields: Dict[str, int], - complete_with_stream_id: Optional[int], + complete_with_stream_id: int, absolute_field_overrides: Optional[Dict[str, int]] = None, ) -> None: """ @@ -367,14 +380,14 @@ class StatsStore(StateDeltasStore): def _update_stats_delta_txn( self, - txn, - ts, - stats_type, - stats_id, - fields, - complete_with_stream_id, - absolute_field_overrides=None, - ): + txn: LoggingTransaction, + ts: int, + stats_type: str, + stats_id: str, + fields: Dict[str, int], + complete_with_stream_id: int, + absolute_field_overrides: Optional[Dict[str, int]] = None, + ) -> None: if absolute_field_overrides is None: absolute_field_overrides = {} @@ -417,20 +430,23 @@ class StatsStore(StateDeltasStore): ) def _upsert_with_additive_relatives_txn( - self, txn, table, keyvalues, absolutes, additive_relatives - ): + self, + txn: LoggingTransaction, + table: str, + keyvalues: Dict[str, Any], + absolutes: Dict[str, Any], + additive_relatives: Dict[str, int], + ) -> None: """Used to update values in the stats tables. This is basically a slightly convoluted upsert that *adds* to any existing rows. Args: - txn - table (str): Table name - keyvalues (dict[str, any]): Row-identifying key values - absolutes (dict[str, any]): Absolute (set) fields - additive_relatives (dict[str, int]): Fields that will be added onto - if existing row present. + table: Table name + keyvalues: Row-identifying key values + absolutes: Absolute (set) fields + additive_relatives: Fields that will be added onto if existing row present. """ if self.database_engine.can_native_upsert: absolute_updates = [ @@ -486,20 +502,17 @@ class StatsStore(StateDeltasStore): current_row.update(absolutes) self.db_pool.simple_update_one_txn(txn, table, keyvalues, current_row) - async def _calculate_and_set_initial_state_for_room( - self, room_id: str - ) -> Tuple[dict, dict, int]: + async def _calculate_and_set_initial_state_for_room(self, room_id: str) -> None: """Calculate and insert an entry into room_stats_current. Args: room_id: The room ID under calculation. - - Returns: - A tuple of room state, membership counts and stream position. """ - def _fetch_current_state_stats(txn): - pos = self.get_room_max_stream_ordering() + def _fetch_current_state_stats( + txn: LoggingTransaction, + ) -> Tuple[List[str], Dict[str, int], int, List[str], int]: + pos = self.get_room_max_stream_ordering() # type: ignore[attr-defined] rows = self.db_pool.simple_select_many_txn( txn, @@ -519,7 +532,7 @@ class StatsStore(StateDeltasStore): retcols=["event_id"], ) - event_ids = [row["event_id"] for row in rows] + event_ids = cast(List[str], [row["event_id"] for row in rows]) txn.execute( """ @@ -533,15 +546,15 @@ class StatsStore(StateDeltasStore): txn.execute( """ - SELECT COALESCE(count(*), 0) FROM current_state_events + SELECT COUNT(*) FROM current_state_events WHERE room_id = ? """, (room_id,), ) - (current_state_events_count,) = txn.fetchone() + current_state_events_count = cast(Tuple[int], txn.fetchone())[0] - users_in_room = self.get_users_in_room_txn(txn, room_id) + users_in_room = self.get_users_in_room_txn(txn, room_id) # type: ignore[attr-defined] return ( event_ids, @@ -561,7 +574,7 @@ class StatsStore(StateDeltasStore): "get_initial_state_for_room", _fetch_current_state_stats ) - state_event_map = await self.get_events(event_ids, get_prev_content=False) + state_event_map = await self.get_events(event_ids, get_prev_content=False) # type: ignore[attr-defined] room_state = { "join_rules": None, @@ -617,8 +630,10 @@ class StatsStore(StateDeltasStore): }, ) - async def _calculate_and_set_initial_state_for_user(self, user_id): - def _calculate_and_set_initial_state_for_user_txn(txn): + async def _calculate_and_set_initial_state_for_user(self, user_id: str) -> None: + def _calculate_and_set_initial_state_for_user_txn( + txn: LoggingTransaction, + ) -> Tuple[int, int]: pos = self._get_max_stream_id_in_current_state_deltas_txn(txn) txn.execute( @@ -629,7 +644,7 @@ class StatsStore(StateDeltasStore): """, (user_id,), ) - (count,) = txn.fetchone() + count = cast(Tuple[int], txn.fetchone())[0] return count, pos joined_rooms, pos = await self.db_pool.runInteraction( @@ -673,7 +688,9 @@ class StatsStore(StateDeltasStore): users that exist given this query """ - def get_users_media_usage_paginate_txn(txn): + def get_users_media_usage_paginate_txn( + txn: LoggingTransaction, + ) -> Tuple[List[JsonDict], int]: filters = [] args = [self.hs.config.server.server_name] @@ -728,7 +745,7 @@ class StatsStore(StateDeltasStore): sql_base=sql_base, ) txn.execute(sql, args) - count = txn.fetchone()[0] + count = cast(Tuple[int], txn.fetchone())[0] sql = """ SELECT diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 57aab55259..319464b1fa 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py
@@ -34,11 +34,11 @@ what sort order was used: - topological tokems: "t%d-%d", where the integers map to the topological and stream ordering columns respectively. """ -import abc + import logging -from collections import namedtuple from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Set, Tuple +import attr from frozendict import frozendict from twisted.internet import defer @@ -49,6 +49,7 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( DatabasePool, + LoggingDatabaseConnection, LoggingTransaction, make_in_list_sql_clause, ) @@ -73,9 +74,11 @@ _TOPOLOGICAL_TOKEN = "topological" # Used as return values for pagination APIs -_EventDictReturn = namedtuple( - "_EventDictReturn", ("event_id", "topological_ordering", "stream_ordering") -) +@attr.s(slots=True, frozen=True, auto_attribs=True) +class _EventDictReturn: + event_id: str + topological_ordering: Optional[int] + stream_ordering: int def generate_pagination_where_clause( @@ -333,13 +336,13 @@ def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]: return " AND ".join(clauses), args -class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): - """This is an abstract base class where subclasses must implement - `get_room_max_stream_ordering` and `get_room_min_stream_ordering` - which can be called in the initializer. - """ - - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): +class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) self._instance_name = hs.get_instance_name() @@ -371,13 +374,22 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): self._stream_order_on_start = self.get_room_max_stream_ordering() - @abc.abstractmethod def get_room_max_stream_ordering(self) -> int: - raise NotImplementedError() + """Get the stream_ordering of regular events that we have committed up to + + Returns the maximum stream id such that all stream ids less than or + equal to it have been successfully persisted. + """ + return self._stream_id_gen.get_current_token() - @abc.abstractmethod def get_room_min_stream_ordering(self) -> int: - raise NotImplementedError() + """Get the stream_ordering of backfilled events that we have committed up to + + Backfilled events use *negative* stream orderings, so this returns the + minimum negative stream id such that all stream ids greater than or + equal to it have been successfully persisted. + """ + return self._backfill_id_gen.get_current_token() def get_room_max_token(self) -> RoomStreamToken: """Get a `RoomStreamToken` that marks the current maximum persisted @@ -819,7 +831,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): for event, row in zip(events, rows): stream = row.stream_ordering if topo_order and row.topological_ordering: - topo = row.topological_ordering + topo: Optional[int] = row.topological_ordering else: topo = None internal = event.internal_metadata @@ -1343,11 +1355,3 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): retcol="instance_name", desc="get_name_from_instance_id", ) - - -class StreamStore(StreamWorkerStore): - def get_room_max_stream_ordering(self) -> int: - return self._stream_id_gen.get_current_token() - - def get_room_min_stream_ordering(self) -> int: - return self._backfill_id_gen.get_current_token() diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py
index 8f510de53d..c8e508a910 100644 --- a/synapse/storage/databases/main/tags.py +++ b/synapse/storage/databases/main/tags.py
@@ -15,11 +15,13 @@ # limitations under the License. import logging -from typing import Dict, List, Tuple, cast +from typing import Any, Dict, Iterable, List, Tuple, cast +from synapse.replication.tcp.streams import TagAccountDataStream from synapse.storage._base import db_to_json from synapse.storage.database import LoggingTransaction from synapse.storage.databases.main.account_data import AccountDataWorkerStore +from synapse.storage.util.id_generators import AbstractStreamIdGenerator from synapse.types import JsonDict from synapse.util import json_encoder from synapse.util.caches.descriptors import cached @@ -204,6 +206,7 @@ class TagsWorkerStore(AccountDataWorkerStore): The next account data ID. """ assert self._can_write_to_account_data + assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator) content_json = json_encoder.encode(content) @@ -230,6 +233,7 @@ class TagsWorkerStore(AccountDataWorkerStore): The next account data ID. """ assert self._can_write_to_account_data + assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator) def remove_tag_txn(txn: LoggingTransaction, next_id: int) -> None: sql = ( @@ -258,6 +262,7 @@ class TagsWorkerStore(AccountDataWorkerStore): next_id: The the revision to advance to. """ assert self._can_write_to_account_data + assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator) txn.call_after( self._account_data_stream_cache.entity_has_changed, user_id, next_id @@ -287,6 +292,21 @@ class TagsWorkerStore(AccountDataWorkerStore): # than the id that the client has. pass + def process_replication_rows( + self, + stream_name: str, + instance_name: str, + token: int, + rows: Iterable[Any], + ) -> None: + if stream_name == TagAccountDataStream.NAME: + self._account_data_id_gen.advance(instance_name, token) + for row in rows: + self.get_tags_for_user.invalidate((row.user_id,)) + self._account_data_stream_cache.entity_has_changed(row.user_id, token) + + super().process_replication_rows(stream_name, instance_name, token, rows) + class TagsStore(TagsWorkerStore): pass diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index 1622822552..6c299cafa5 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py
@@ -13,16 +13,19 @@ # limitations under the License. import logging -from collections import namedtuple from enum import Enum -from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple +from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, cast import attr from canonicaljson import encode_canonical_json from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import db_to_json -from synapse.storage.database import DatabasePool, LoggingTransaction +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.types import JsonDict from synapse.util.caches.descriptors import cached @@ -35,16 +38,6 @@ db_binary_type = memoryview logger = logging.getLogger(__name__) -_TransactionRow = namedtuple( - "_TransactionRow", - ("id", "transaction_id", "destination", "ts", "response_code", "response_json"), -) - -_UpdateTransactionRow = namedtuple( - "_TransactionRow", ("response_code", "response_json") -) - - class DestinationSortOrder(Enum): """Enum to define the sorting method used when returning destinations.""" @@ -71,7 +64,12 @@ class DestinationRetryTimings: class TransactionWorkerStore(CacheInvalidationWorkerStore): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) if hs.config.worker.run_background_tasks: @@ -82,7 +80,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): now = self._clock.time_msec() month_ago = now - 30 * 24 * 60 * 60 * 1000 - def _cleanup_transactions_txn(txn): + def _cleanup_transactions_txn(txn: LoggingTransaction) -> None: txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,)) await self.db_pool.runInteraction( @@ -112,7 +110,9 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): origin, ) - def _get_received_txn_response(self, txn, transaction_id, origin): + def _get_received_txn_response( + self, txn: LoggingTransaction, transaction_id: str, origin: str + ) -> Optional[Tuple[int, JsonDict]]: result = self.db_pool.simple_select_one_txn( txn, table="received_transactions", @@ -187,7 +187,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): return result def _get_destination_retry_timings( - self, txn, destination: str + self, txn: LoggingTransaction, destination: str ) -> Optional[DestinationRetryTimings]: result = self.db_pool.simple_select_one_txn( txn, @@ -222,7 +222,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): """ if self.database_engine.can_native_upsert: - return await self.db_pool.runInteraction( + await self.db_pool.runInteraction( "set_destination_retry_timings", self._set_destination_retry_timings_native, destination, @@ -232,7 +232,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): db_autocommit=True, # Safe as its a single upsert ) else: - return await self.db_pool.runInteraction( + await self.db_pool.runInteraction( "set_destination_retry_timings", self._set_destination_retry_timings_emulated, destination, @@ -242,8 +242,13 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): ) def _set_destination_retry_timings_native( - self, txn, destination, failure_ts, retry_last_ts, retry_interval - ): + self, + txn: LoggingTransaction, + destination: str, + failure_ts: Optional[int], + retry_last_ts: int, + retry_interval: int, + ) -> None: assert self.database_engine.can_native_upsert # Upsert retry time interval if retry_interval is zero (i.e. we're @@ -273,8 +278,13 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): ) def _set_destination_retry_timings_emulated( - self, txn, destination, failure_ts, retry_last_ts, retry_interval - ): + self, + txn: LoggingTransaction, + destination: str, + failure_ts: Optional[int], + retry_last_ts: int, + retry_interval: int, + ) -> None: self.database_engine.lock_table(txn, "destinations") # We need to be careful here as the data may have changed from under us @@ -384,7 +394,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): last_successful_stream_ordering: the stream_ordering of the most recent successfully-sent PDU """ - return await self.db_pool.simple_upsert( + await self.db_pool.simple_upsert( "destinations", keyvalues={"destination": destination}, values={"last_successful_stream_ordering": last_successful_stream_ordering}, @@ -525,7 +535,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): else: order = "ASC" - args = [] + args: List[object] = [] where_statement = "" if destination: args.extend(["%" + destination.lower() + "%"]) @@ -534,7 +544,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): sql_base = f"FROM destinations {where_statement} " sql = f"SELECT COUNT(*) as total_destinations {sql_base}" txn.execute(sql, args) - count = txn.fetchone()[0] + count = cast(Tuple[int], txn.fetchone())[0] sql = f""" SELECT destination, retry_last_ts, retry_interval, failure_ts, diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py
index 340ca9e47d..a1a1a6a14a 100644 --- a/synapse/storage/databases/main/ui_auth.py +++ b/synapse/storage/databases/main/ui_auth.py
@@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union, cast import attr @@ -225,11 +225,14 @@ class UIAuthWorkerStore(SQLBaseStore): self, txn: LoggingTransaction, session_id: str, key: str, value: Any ): # Get the current value. - result: Dict[str, Any] = self.db_pool.simple_select_one_txn( # type: ignore - txn, - table="ui_auth_sessions", - keyvalues={"session_id": session_id}, - retcols=("serverdict",), + result = cast( + Dict[str, Any], + self.db_pool.simple_select_one_txn( + txn, + table="ui_auth_sessions", + keyvalues={"session_id": session_id}, + retcols=("serverdict",), + ), ) # Update it and add it back to the database. diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index e98a45b6af..0f9b8575d3 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py
@@ -32,11 +32,14 @@ if TYPE_CHECKING: from synapse.server import HomeServer from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules -from synapse.storage.database import DatabasePool, LoggingTransaction +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) from synapse.storage.databases.main.state import StateFilter from synapse.storage.databases.main.state_deltas import StateDeltasStore from synapse.storage.engines import PostgresEngine, Sqlite3Engine -from synapse.storage.types import Connection from synapse.types import JsonDict, get_domain_from_id, get_localpart_from_id from synapse.util.caches.descriptors import cached @@ -53,7 +56,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): def __init__( self, database: DatabasePool, - db_conn: Connection, + db_conn: LoggingDatabaseConnection, hs: "HomeServer", ): super().__init__(database, db_conn, hs) @@ -592,7 +595,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): def __init__( self, database: DatabasePool, - db_conn: Connection, + db_conn: LoggingDatabaseConnection, hs: "HomeServer", ) -> None: super().__init__(database, db_conn, hs)