From 1abfb15f07d4f8119afcf908f9e1903e7feef371 Mon Sep 17 00:00:00 2001
From: Sean Quah <8349537+squahtx@users.noreply.github.com>
Date: Mon, 13 Dec 2021 16:28:26 +0000
Subject: Add type hints to `synapse/storage/databases/main/end_to_end_keys.py`
 (#11551)
---
 synapse/storage/databases/main/__init__.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 9ff2d8d8c3..065145c0d2 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -143,9 +143,6 @@ class DataStore(
                 ("device_lists_outbound_pokes", "stream_id"),
             ],
         )
-        self._cross_signing_id_gen = StreamIdGenerator(
-            db_conn, "e2e_cross_signing_keys", "stream_id"
-        )
         self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
         self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
--
cgit 1.5.1
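The log is filtered to `__init__.py`, so only the deletion is visible here; going by the commit subject, the generator presumably moves into the newly annotated end-to-end keys store. A minimal sketch of that pattern, assuming synapse is importable and simplifying the store's real base classes — the `StreamIdGenerator` call is the one deleted above:

from typing import TYPE_CHECKING

from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.util.id_generators import StreamIdGenerator

if TYPE_CHECKING:
    from synapse.server import HomeServer


class EndToEndKeyWorkerStore(SQLBaseStore):  # base classes simplified
    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ):
        super().__init__(database, db_conn, hs)
        # The same generator DataStore.__init__ used to build; owning it here
        # keeps the attribute and its type visible where it is actually used.
        self._cross_signing_id_gen = StreamIdGenerator(
            db_conn, "e2e_cross_signing_keys", "stream_id"
        )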
From 5305a5e88144828419249fd9e4c5198d92276a44 Mon Sep 17 00:00:00 2001
From: Sean Quah <8349537+squahtx@users.noreply.github.com>
Date: Mon, 13 Dec 2021 17:05:00 +0000
Subject: Type hint the constructors of the data store classes (#11555)
---
 changelog.d/11555.misc                              |  1 +
 synapse/replication/slave/storage/_base.py          |  9 ++++++--
 synapse/replication/slave/storage/client_ips.py     |  9 ++++++--
 synapse/replication/slave/storage/devices.py        |  9 ++++++--
 synapse/replication/slave/storage/events.py         |  9 ++++++--
 synapse/replication/slave/storage/filtering.py      |  9 ++++++--
 synapse/replication/slave/storage/groups.py         |  9 ++++++--
 synapse/storage/_base.py                            | 13 +++++++----
 synapse/storage/database.py                         |  2 +-
 synapse/storage/databases/main/__init__.py          |  9 ++++++--
 synapse/storage/databases/main/appservice.py        | 10 +++++---
 synapse/storage/databases/main/cache.py             |  9 ++++++--
 synapse/storage/databases/main/censor_events.py     | 13 +++++++++--
 synapse/storage/databases/main/client_ips.py        | 22 ++++++++++++++----
 synapse/storage/databases/main/deviceinbox.py       |  7 +++++-
 synapse/storage/databases/main/devices.py           | 22 +++++++++++++++---
 synapse/storage/databases/main/event_federation.py  | 20 +++++++++++++---
 .../storage/databases/main/event_push_actions.py    | 20 +++++++++++++---
 synapse/storage/databases/main/events.py            |  9 +++++---
 .../storage/databases/main/events_bg_updates.py     |  8 ++++++-
 synapse/storage/databases/main/group_server.py      | 10 +++++---
 synapse/storage/databases/main/lock.py              | 14 ++++++++---
 synapse/storage/databases/main/metrics.py           |  9 ++++++--
 .../storage/databases/main/monthly_active_users.py  | 20 +++++++++++++---
 synapse/storage/databases/main/presence.py          |  6 ++---
 synapse/storage/databases/main/push_rule.py         |  9 ++++++--
 synapse/storage/databases/main/receipts.py          | 13 +++++++++--
 synapse/storage/databases/main/room.py              | 27 ++++++++++++++++++----
 synapse/storage/databases/main/roommember.py        | 23 ++++++++++++++----
 synapse/storage/databases/main/search.py            | 20 +++++++++++++---
 synapse/storage/databases/main/state.py             | 27 ++++++++++++++++++----
 synapse/storage/databases/main/stats.py             |  9 ++++++--
 synapse/storage/databases/main/stream.py            |  8 ++++++-
 synapse/storage/databases/main/transactions.py      | 13 +++++++++--
 synapse/storage/databases/main/user_directory.py    | 11 +++++----
 35 files changed, 351 insertions(+), 87 deletions(-)
 create mode 100644 changelog.d/11555.misc

diff --git a/changelog.d/11555.misc b/changelog.d/11555.misc
new file mode 100644
index 0000000000..d451940bf2
--- /dev/null
+++ b/changelog.d/11555.misc
@@ -0,0 +1 @@
+Add missing type hints to storage classes.
diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
index 7ecb446e7c..7644146dba 100644
--- a/synapse/replication/slave/storage/_base.py
+++ b/synapse/replication/slave/storage/_base.py
@@ -15,7 +15,7 @@ import logging
 from typing import TYPE_CHECKING, Optional
 
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import MultiWriterIdGenerator
@@ -27,7 +27,12 @@ logger = logging.getLogger(__name__)
 
 
 class BaseSlavedStore(CacheInvalidationWorkerStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
         if isinstance(self.database_engine, PostgresEngine):
             self._cache_id_gen: Optional[
diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py
index 61cd7e5228..bc888ce1a8 100644
--- a/synapse/replication/slave/storage/client_ips.py
+++ b/synapse/replication/slave/storage/client_ips.py
@@ -14,7 +14,7 @@
 from typing import TYPE_CHECKING
 
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY
 from synapse.util.caches.lrucache import LruCache
 
@@ -25,7 +25,12 @@ if TYPE_CHECKING:
 
 
 class SlavedClientIpStore(BaseSlavedStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self.client_ip_last_seen: LruCache[tuple, int] = LruCache(
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index 0a58296089..a2aff75b70 100644
--- a/synapse/replication/slave/storage/devices.py
+++ b/synapse/replication/slave/storage/devices.py
@@ -17,7 +17,7 @@ from typing import TYPE_CHECKING
 from synapse.replication.slave.storage._base import BaseSlavedStore
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.devices import DeviceWorkerStore
 from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore
 from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -27,7 +27,12 @@ if TYPE_CHECKING:
 
 
 class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self.hs = hs
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index 63ed50caa5..50e7379e83 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -15,7 +15,7 @@ import logging
 from typing import TYPE_CHECKING
 
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.event_federation import EventFederationWorkerStore
 from synapse.storage.databases.main.event_push_actions import (
     EventPushActionsWorkerStore,
@@ -58,7 +58,12 @@ class SlavedEventStore(
     RelationsWorkerStore,
     BaseSlavedStore,
 ):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         events_max = self._stream_id_gen.get_current_token()
diff --git a/synapse/replication/slave/storage/filtering.py b/synapse/replication/slave/storage/filtering.py
index 90284c202d..4d185e2b56 100644
--- a/synapse/replication/slave/storage/filtering.py
+++ b/synapse/replication/slave/storage/filtering.py
@@ -14,7 +14,7 @@
 from typing import TYPE_CHECKING
 
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.filtering import FilteringStore
 
 from ._base import BaseSlavedStore
@@ -24,7 +24,12 @@ if TYPE_CHECKING:
 
 
 class SlavedFilteringStore(BaseSlavedStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
     # Filters are immutable so this cache doesn't need to be expired
diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py
index 497e16c69e..9d90e26375 100644
--- a/synapse/replication/slave/storage/groups.py
+++ b/synapse/replication/slave/storage/groups.py
@@ -17,7 +17,7 @@ from typing import TYPE_CHECKING
 from synapse.replication.slave.storage._base import BaseSlavedStore
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.replication.tcp.streams import GroupServerStream
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.group_server import GroupServerWorkerStore
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
@@ -26,7 +26,12 @@ if TYPE_CHECKING:
 
 
 class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self.hs = hs
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 3056e64ff5..7967011afd 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -17,10 +17,8 @@ import logging
 from abc import ABCMeta
 from typing import TYPE_CHECKING, Any, Collection, Iterable, Optional, Union
 
-from synapse.storage.database import LoggingTransaction  # noqa: F401
-from synapse.storage.database import make_in_list_sql_clause  # noqa: F401
-from synapse.storage.database import DatabasePool
-from synapse.storage.types import Connection
+from synapse.storage.database import make_in_list_sql_clause  # noqa: F401; noqa: F401
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.types import get_domain_from_id
 from synapse.util import json_decoder
 
@@ -38,7 +36,12 @@ class SQLBaseStore(metaclass=ABCMeta):
     per data store (and not one per physical database).
     """
 
-    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         self.hs = hs
         self._clock = hs.get_clock()
         self.database_engine = database.engine
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 5552dd3c5c..3b44e6469c 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -175,7 +175,7 @@ class LoggingDatabaseConnection:
     def rollback(self) -> None:
         self.conn.rollback()
 
-    def __enter__(self) -> "Connection":
+    def __enter__(self) -> "LoggingDatabaseConnection":
         self.conn.__enter__()
         return self
 
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 065145c0d2..716b25dd34 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -18,7 +18,7 @@ import logging
 from typing import TYPE_CHECKING, List, Optional, Tuple
 
 from synapse.config.homeserver import HomeServerConfig
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.stats import UserSortOrder
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import (
@@ -129,7 +129,12 @@ class DataStore(
     LockStore,
     SessionStore,
 ):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         self.hs = hs
         self._clock = hs.get_clock()
         self.database_engine = database.engine
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 4a883dc166..92c95a41d7 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -24,9 +24,8 @@ from synapse.appservice import (
 from synapse.config.appservice import load_appservices
 from synapse.events import EventBase
 from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
-from synapse.storage.types import Connection
 from synapse.types import JsonDict
 from synapse.util import json_encoder
 
@@ -58,7 +57,12 @@ def _make_exclusive_regex(
 
 
 class ApplicationServiceWorkerStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         self.services_cache = load_appservices(
             hs.hostname, hs.config.appservice.app_service_config_files
         )
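The one-line `database.py` fix above is what lets the rest of the series typecheck: `with db_conn:` blocks now keep the `LoggingDatabaseConnection` interface instead of degrading to the bare `Connection` protocol. A standalone sketch of the idiom, using plain stdlib `sqlite3` rather than Synapse code:

import sqlite3
from types import TracebackType
from typing import Optional, Type


class TrackedConnection:
    """Stand-in for a wrapper like LoggingDatabaseConnection."""

    def __init__(self, conn: sqlite3.Connection) -> None:
        self.conn = conn

    def name(self) -> str:
        # An extra method the bare connection does not have.
        return "tracked"

    def __enter__(self) -> "TrackedConnection":
        self.conn.__enter__()
        # Returning `self` is why the annotation must name the wrapper:
        # annotated as the bare connection, mypy would reject c.name() below.
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc: Optional[BaseException],
        tb: Optional[TracebackType],
    ) -> None:
        self.conn.__exit__(exc_type, exc, tb)


with TrackedConnection(sqlite3.connect(":memory:")) as c:
    print(c.name())  # typechecks only because __enter__ returns the wrapper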
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 36e8422fc6..0024348067 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -25,7 +25,7 @@ from synapse.replication.tcp.streams.events import (
     EventsStreamEventRow,
 )
 from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.engines import PostgresEngine
 from synapse.util.iterutils import batch_iter
 
@@ -41,7 +41,12 @@ CURRENT_STATE_CACHE_NAME = "cs_cache_fake"
 
 
 class CacheInvalidationWorkerStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self._instance_name = hs.get_instance_name()
diff --git a/synapse/storage/databases/main/censor_events.py b/synapse/storage/databases/main/censor_events.py
index 0f56e10220..fd3fc298b3 100644
--- a/synapse/storage/databases/main/censor_events.py
+++ b/synapse/storage/databases/main/censor_events.py
@@ -18,7 +18,11 @@ from typing import TYPE_CHECKING, Optional
 from synapse.events.utils import prune_event_dict
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.util import json_encoder
@@ -31,7 +35,12 @@ logger = logging.getLogger(__name__)
 
 
 class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         if (
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index 1dc7f0ebe3..8b0c614ece 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -26,7 +26,6 @@ from synapse.storage.database import (
     make_tuple_comparison_clause,
 )
 from synapse.storage.databases.main.monthly_active_users import MonthlyActiveUsersStore
-from synapse.storage.types import Connection
 from synapse.types import JsonDict, UserID
 from synapse.util.caches.lrucache import LruCache
 
@@ -65,7 +64,12 @@ class LastConnectionInfo(TypedDict):
 
 
 class ClientIpBackgroundUpdateStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self.db_pool.updates.register_background_index_update(
@@ -394,7 +398,12 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
 
 
 class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
-    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self.user_ips_max_age = hs.config.server.user_ips_max_age
@@ -532,7 +541,12 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
 
 
 class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore):
-    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
 
         # (user_id, access_token, ip,) -> last_seen
         self.client_ip_last_seen = LruCache[Tuple[str, str, str], int](
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index ab8766c75b..b410eefdc7 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -601,7 +601,12 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
     REMOVE_HIDDEN_DEVICES = "remove_hidden_devices_from_device_inbox"
     REMOVE_DEAD_DEVICES_FROM_INBOX = "remove_dead_devices_from_device_inbox"
 
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self.db_pool.updates.register_background_index_update(
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index eff825dd22..3932599988 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -38,6 +38,7 @@
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
 from synapse.storage.database import (
     DatabasePool,
+    LoggingDatabaseConnection,
     LoggingTransaction,
     make_tuple_comparison_clause,
 )
@@ -61,7 +62,12 @@ BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes"
 
 
 class DeviceWorkerStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         if hs.config.worker.run_background_tasks:
@@ -953,7 +959,12 @@ class DeviceWorkerStore(SQLBaseStore):
 
 
 class DeviceBackgroundUpdateStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self.db_pool.updates.register_background_index_update(
@@ -1085,7 +1096,12 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
 
 
 class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         # Map of (user_id, device_id) -> bool. If there is an entry that implies
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 9580a40785..2287f1cc68 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -24,7 +24,11 @@ from synapse.api.room_versions import EventFormatVersions, RoomVersion
 from synapse.events import EventBase, make_event_from_dict
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.databases.main.signatures import SignatureWorkerStore
 from synapse.storage.engines import PostgresEngine
@@ -62,7 +66,12 @@ class _NoChainCoverIndex(Exception):
 
 
 class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         if hs.config.worker.run_background_tasks:
@@ -1514,7 +1523,12 @@ class EventFederationStore(EventFederationWorkerStore):
 
     EVENT_AUTH_STATE_ONLY = "event_auth_state_only"
 
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self.db_pool.updates.register_background_update_handler(
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 3efdd0c920..eacff3e432 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -20,7 +20,11 @@ from typing_extensions import TypedDict
 
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached
 
@@ -82,7 +86,12 @@ def _deserialize_action(actions, is_highlight):
 
 
 class EventPushActionsWorkerStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         # These get correctly set by _find_stream_orderings_for_times_txn
@@ -910,7 +919,12 @@ class EventPushActionsWorkerStore(SQLBaseStore):
 class EventPushActionsStore(EventPushActionsWorkerStore):
     EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
 
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self.db_pool.updates.register_background_index_update(
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 5184e6bf85..81e67ece55 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -41,10 +41,13 @@ from synapse.events import EventBase  # noqa: F401
 from synapse.events.snapshot import EventContext  # noqa: F401
 from synapse.logging.utils import log_function
 from synapse.storage._base import db_to_json, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.databases.main.events_worker import EventCacheEntry
 from synapse.storage.databases.main.search import SearchEntry
-from synapse.storage.types import Connection
 from synapse.storage.util.id_generators import AbstractStreamIdGenerator
 from synapse.storage.util.sequence import SequenceGenerator
 from synapse.types import StateMap, get_domain_from_id
@@ -95,7 +98,7 @@ class PersistEventsStore:
         hs: "HomeServer",
         db: DatabasePool,
         main_data_store: "DataStore",
-        db_conn: Connection,
+        db_conn: LoggingDatabaseConnection,
     ):
         self.hs = hs
         self.db_pool = db
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index c88fd35e7f..9b36941fec 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -23,6 +23,7 @@ from synapse.events import make_event_from_dict
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
 from synapse.storage.database import (
     DatabasePool,
+    LoggingDatabaseConnection,
     LoggingTransaction,
     make_tuple_comparison_clause,
 )
@@ -83,7 +84,12 @@ class _CalculateChainCover:
 
 
 class EventsBackgroundUpdatesStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self.db_pool.updates.register_background_update_handler(
diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py
index bb621df0dd..3f6086050b 100644
--- a/synapse/storage/databases/main/group_server.py
+++ b/synapse/storage/databases/main/group_server.py
@@ -19,8 +19,7 @@ from typing_extensions import TypedDict
 
 from synapse.api.errors import SynapseError
 from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool
-from synapse.storage.types import Connection
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.types import JsonDict
 from synapse.util import json_encoder
 
@@ -40,7 +39,12 @@ class _RoomInGroup(TypedDict):
 
 
 class GroupServerWorkerStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         database.updates.register_background_index_update(
             update_name="local_group_updates_index",
             index_name="local_group_updates_stream_id_index",
diff --git a/synapse/storage/databases/main/lock.py b/synapse/storage/databases/main/lock.py
index a540f7fb26..bedacaf0d7 100644
--- a/synapse/storage/databases/main/lock.py
+++ b/synapse/storage/databases/main/lock.py
@@ -20,8 +20,11 @@ from twisted.internet.interfaces import IReactorCore
 
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool, LoggingTransaction
-from synapse.storage.types import Connection
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.util import Clock
 from synapse.util.stringutils import random_string
 
@@ -54,7 +57,12 @@ class LockStore(SQLBaseStore):
     `last_renewed_ts` column with the current time.
     """
 
-    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self._reactor = hs.get_reactor()
diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py
index d901933ae4..3bb21958d1 100644
--- a/synapse/storage/databases/main/metrics.py
+++ b/synapse/storage/databases/main/metrics.py
@@ -19,7 +19,7 @@ from typing import TYPE_CHECKING, Dict
 from synapse.metrics import GaugeBucketCollector
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.event_push_actions import (
     EventPushActionsWorkerStore,
 )
@@ -55,7 +55,12 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
     stats and prometheus metrics.
     """
 
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         # Read the extrems every 60 minutes
diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index 3c98ef876f..65b7e307e1 100644
--- a/synapse/storage/databases/main/monthly_active_users.py
+++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -16,7 +16,11 @@ from typing import TYPE_CHECKING, Dict, List, Optional
 
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool, make_in_list_sql_clause
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    make_in_list_sql_clause,
+)
 from synapse.util.caches.descriptors import cached
 from synapse.util.threepids import canonicalise_email
 
@@ -31,7 +35,12 @@ LAST_SEEN_GRANULARITY = 60 * 60 * 1000
 
 
 class MonthlyActiveUsersWorkerStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
         self._clock = hs.get_clock()
         self.hs = hs
@@ -213,7 +222,12 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
 
 
 class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self._mau_stats_only = hs.config.server.mau_stats_only
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index cc0eebdb46..02d534ae45 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple
 from synapse.api.presence import PresenceState, UserPresenceState
 from synapse.replication.tcp.streams import PresenceStream
 from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.types import Connection
 from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
@@ -33,7 +33,7 @@ class PresenceBackgroundUpdateStore(SQLBaseStore):
     def __init__(
         self,
         database: DatabasePool,
-        db_conn: Connection,
+        db_conn: LoggingDatabaseConnection,
         hs: "HomeServer",
     ):
         super().__init__(database, db_conn, hs)
@@ -52,7 +52,7 @@ class PresenceStore(PresenceBackgroundUpdateStore):
     def __init__(
         self,
        database: DatabasePool,
-        db_conn: Connection,
+        db_conn: LoggingDatabaseConnection,
         hs: "HomeServer",
     ):
         super().__init__(database, db_conn, hs)
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 3b63267395..e01c94930a 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -20,7 +20,7 @@
 from synapse.api.errors import NotFoundError, StoreError
 from synapse.push.baserules import list_with_base_rules
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.databases.main.pusher import PusherWorkerStore
@@ -81,7 +81,12 @@ class PushRulesWorkerStore(
     `get_max_push_rules_stream_id` which can be called in the initializer.
     """
 
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         if hs.config.worker.worker_app is None:
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 9c5625c8bb..bf0b903af2 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -32,7 +32,11 @@ from synapse.api.constants import ReceiptTypes
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.replication.tcp.streams import ReceiptsStream
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
 from synapse.types import JsonDict
@@ -47,7 +51,12 @@ logger = logging.getLogger(__name__)
 
 
 class ReceiptsWorkerStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         self._instance_name = hs.get_instance_name()
 
         if isinstance(database.engine, PostgresEngine):
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 7d694d852d..28c4b65bbd 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -24,7 +24,11 @@ from synapse.api.errors import StoreError
 from synapse.api.room_versions import RoomVersion, RoomVersions
 from synapse.events import EventBase
 from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.databases.main.search import SearchStore
 from synapse.storage.types import Cursor
 from synapse.types import JsonDict, ThirdPartyInstanceID
@@ -72,7 +76,12 @@ class RoomSortOrder(Enum):
 
 
 class RoomWorkerStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self.config = hs.config
@@ -1050,7 +1059,12 @@ _REPLACE_ROOM_DEPTH_SQL_COMMANDS = (
 
 
 class RoomBackgroundUpdateStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self.config = hs.config
@@ -1435,7 +1449,12 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
 
 
 class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self.config = hs.config
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 6b2a8d06a6..cda80d6511 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -37,7 +37,7 @@ from synapse.metrics.background_process_metrics import (
     wrap_as_background_process,
 )
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.engines import Sqlite3Engine
 from synapse.storage.roommember import (
@@ -64,7 +64,12 @@ _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership"
 
 
 class RoomMemberWorkerStore(EventsWorkerStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         # Used by `_get_joined_hosts` to ensure only one thing mutates the cache
@@ -985,7 +990,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
 
 class RoomMemberBackgroundUpdateStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
         self.db_pool.updates.register_background_update_handler(
             _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile
@@ -1135,7 +1145,12 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
 
 
 class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
     async def forget(self, user_id: str, room_id: str) -> None:
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index 7fe233767f..f87acfb866 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -20,7 +20,11 @@ from typing import TYPE_CHECKING, Collection, Iterable, List, Optional, Set
 
 from synapse.api.errors import SynapseError
 from synapse.events import EventBase
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
 
@@ -105,7 +109,12 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
     EVENT_SEARCH_USE_GIST_POSTGRES_NAME = "event_search_postgres_gist"
     EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin"
 
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         if not hs.config.server.enable_search:
@@ -358,7 +367,12 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
 
 
 class SearchStore(SearchBackgroundUpdateStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
     async def search_msgs(self, room_ids, search_term, keys):
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index fa2c3b1feb..4bc044fb16 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -22,7 +22,11 @@ from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
 from synapse.events import EventBase
 from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
 from synapse.storage.state import StateFilter
@@ -56,7 +60,12 @@ class _GetStateGroupDelta(
 class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
     """The parts of StateGroupStore that can be called from workers."""
 
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
     async def get_room_version(self, room_id: str) -> RoomVersion:
@@ -349,7 +358,12 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
     EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index"
     DELETE_CURRENT_STATE_UPDATE_NAME = "delete_old_current_state_events"
 
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self.server_name = hs.hostname
@@ -536,5 +550,10 @@ class StateStore(StateGroupWorkerStore, MainStateBackgroundUpdateStore):
     * `state_groups_state`: Maps state group to state events.
     """
 
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index 5d7b59d861..9020e0976c 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -24,7 +24,7 @@ from twisted.internet.defer import DeferredLock
 
 from synapse.api.constants import EventContentFields, EventTypes, Membership
 from synapse.api.errors import StoreError
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.state_deltas import StateDeltasStore
 from synapse.types import JsonDict
 from synapse.util.caches.descriptors import cached
@@ -96,7 +96,12 @@ class UserSortOrder(Enum):
 
 
 class StatsStore(StateDeltasStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self.server_name = hs.hostname
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 57aab55259..9488fd5094 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -49,6 +49,7 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import (
     DatabasePool,
+    LoggingDatabaseConnection,
     LoggingTransaction,
     make_in_list_sql_clause,
 )
@@ -339,7 +340,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
     which can be called in the initializer.
     """
 
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self._instance_name = hs.get_instance_name()
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index 1622822552..54b41513ee 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -22,7 +22,11 @@ from canonicaljson import encode_canonical_json
 
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import db_to_json
-from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
 from synapse.types import JsonDict
 from synapse.util.caches.descriptors import cached
@@ -71,7 +75,12 @@ class DestinationRetryTimings:
 
 
 class TransactionWorkerStore(CacheInvalidationWorkerStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         if hs.config.worker.run_background_tasks:
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index e98a45b6af..0f9b8575d3 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -32,11 +32,14 @@ if TYPE_CHECKING:
     from synapse.server import HomeServer
 
 from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules
-from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.databases.main.state import StateFilter
 from synapse.storage.databases.main.state_deltas import StateDeltasStore
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
-from synapse.storage.types import Connection
 from synapse.types import JsonDict, get_domain_from_id, get_localpart_from_id
 from synapse.util.caches.descriptors import cached
 
@@ -53,7 +56,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
     def __init__(
         self,
         database: DatabasePool,
-        db_conn: Connection,
+        db_conn: LoggingDatabaseConnection,
         hs: "HomeServer",
     ):
         super().__init__(database, db_conn, hs)
@@ -592,7 +595,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
     def __init__(
         self,
         database: DatabasePool,
-        db_conn: Connection,
+        db_conn: LoggingDatabaseConnection,
         hs: "HomeServer",
     ) -> None:
         super().__init__(database, db_conn, hs)
--
cgit 1.5.1
From c7fe32edb4da58a262ec37ce6fde7bef95920872 Mon Sep 17 00:00:00 2001
From: Sean Quah <8349537+squahtx@users.noreply.github.com>
Date: Wed, 15 Dec 2021 18:00:48 +0000
Subject: Add type hints to `synapse/storage/databases/main/room.py` (#11575)
---
 changelog.d/11575.misc                     |   1 +
 mypy.ini                                   |   4 +-
 synapse/handlers/room_member.py            |   6 +-
 synapse/storage/databases/main/__init__.py |   1 -
 synapse/storage/databases/main/room.py     | 173 +++++++++++++++++------------
 5 files changed, 108 insertions(+), 77 deletions(-)
 create mode 100644 changelog.d/11575.misc

diff --git a/changelog.d/11575.misc b/changelog.d/11575.misc
new file mode 100644
index 0000000000..d451940bf2
--- /dev/null
+++ b/changelog.d/11575.misc
@@ -0,0 +1 @@
+Add missing type hints to storage classes.
diff --git a/mypy.ini b/mypy.ini
index e38ad635aa..cbe1e8302c 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -37,7 +37,6 @@ exclude = (?x)
   |synapse/storage/databases/main/purge_events.py
   |synapse/storage/databases/main/push_rule.py
   |synapse/storage/databases/main/receipts.py
-  |synapse/storage/databases/main/room.py
   |synapse/storage/databases/main/roommember.py
   |synapse/storage/databases/main/search.py
   |synapse/storage/databases/main/state.py
@@ -205,6 +204,9 @@ disallow_untyped_defs = True
 [mypy-synapse.storage.databases.main.events_worker]
 disallow_untyped_defs = True
 
+[mypy-synapse.storage.databases.main.room]
+disallow_untyped_defs = True
+
 [mypy-synapse.storage.databases.main.room_batch]
 disallow_untyped_defs = True
 
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 447e3ce571..6aa910dd10 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -1020,7 +1020,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         # Add new room to the room directory if the old room was there
         # Remove old room from the room directory
         old_room = await self.store.get_room(old_room_id)
-        if old_room and old_room["is_public"]:
+        if old_room is not None and old_room["is_public"]:
             await self.store.set_room_is_public(old_room_id, False)
             await self.store.set_room_is_public(room_id, True)
 
@@ -1031,7 +1031,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         local_group_ids = await self.store.get_local_groups_for_room(old_room_id)
         for group_id in local_group_ids:
             # Add new the new room to those groups
-            await self.store.add_room_to_group(group_id, room_id, old_room["is_public"])
+            await self.store.add_room_to_group(
+                group_id, room_id, old_room is not None and old_room["is_public"]
+            )
 
             # Remove the old room from those groups
             await self.store.remove_room_from_group(group_id, old_room_id)
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 716b25dd34..a594223fc6 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -149,7 +149,6 @@ class DataStore(
             ],
         )
 
-        self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
         self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
         self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
         self._group_updates_id_gen = StreamIdGenerator(
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 6cf6cc8484..4472335af9 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -17,7 +17,7 @@ import collections
 import logging
 from abc import abstractmethod
 from enum import Enum
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Tuple, cast
 
 from synapse.api.constants import EventContentFields, EventTypes, JoinRules
 from synapse.api.errors import StoreError
@@ -29,8 +29,9 @@ from synapse.storage.database import (
     LoggingDatabaseConnection,
     LoggingTransaction,
 )
-from synapse.storage.databases.main.search import SearchStore
+from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
 from synapse.storage.types import Cursor
+from synapse.storage.util.id_generators import IdGenerator
 from synapse.types import JsonDict, ThirdPartyInstanceID
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached
@@ -75,7 +76,7 @@ class RoomSortOrder(Enum):
     STATE_EVENTS = "state_events"
 
 
-class RoomWorkerStore(SQLBaseStore):
+class RoomWorkerStore(CacheInvalidationWorkerStore):
     def __init__(
         self,
         database: DatabasePool,
@@ -92,7 +93,7 @@ class RoomWorkerStore(SQLBaseStore):
         room_creator_user_id: str,
         is_public: bool,
         room_version: RoomVersion,
-    ):
+    ) -> None:
         """Stores a room.
 
         Args:
@@ -120,7 +121,7 @@ class RoomWorkerStore(SQLBaseStore):
             logger.error("store_room with room_id=%s failed: %s", room_id, e)
             raise StoreError(500, "Problem creating room.")
 
-    async def get_room(self, room_id: str) -> dict:
+    async def get_room(self, room_id: str) -> Optional[Dict[str, Any]]:
         """Retrieve a room.
 
         Args:
@@ -145,7 +146,9 @@ class RoomWorkerStore(SQLBaseStore):
             A dict containing the room information, or None if the room is unknown.
         """
 
-        def get_room_with_stats_txn(txn, room_id):
+        def get_room_with_stats_txn(
+            txn: LoggingTransaction, room_id: str
+        ) -> Optional[Dict[str, Any]]:
             sql = """
                 SELECT room_id, state.name, state.canonical_alias, curr.joined_members,
                   curr.local_users_in_room AS joined_local_members, rooms.room_version AS version,
@@ -194,7 +197,7 @@ class RoomWorkerStore(SQLBaseStore):
             ignore_non_federatable: If true filters out non-federatable rooms
         """
 
-        def _count_public_rooms_txn(txn):
+        def _count_public_rooms_txn(txn: LoggingTransaction) -> int:
             query_args = []
 
             if network_tuple:
@@ -235,7 +238,7 @@ class RoomWorkerStore(SQLBaseStore):
             }
 
             txn.execute(sql, query_args)
-            return txn.fetchone()[0]
+            return cast(Tuple[int], txn.fetchone())[0]
 
         return await self.db_pool.runInteraction(
             "count_public_rooms", _count_public_rooms_txn
@@ -244,11 +247,11 @@ class RoomWorkerStore(SQLBaseStore):
     async def get_room_count(self) -> int:
         """Retrieve the total number of rooms."""
 
-        def f(txn):
+        def f(txn: LoggingTransaction) -> int:
             sql = "SELECT count(*)  FROM rooms"
             txn.execute(sql)
-            row = txn.fetchone()
-            return row[0] or 0
+            row = cast(Tuple[int], txn.fetchone())
+            return row[0]
 
         return await self.db_pool.runInteraction("get_rooms", f)
 
@@ -260,7 +263,7 @@ class RoomWorkerStore(SQLBaseStore):
         bounds: Optional[Tuple[int, str]],
         forwards: bool,
         ignore_non_federatable: bool = False,
-    ):
+    ) -> List[Dict[str, Any]]:
         """Gets the largest public rooms (where largest is in terms of joined
         members, as tracked in the statistics table).
 
@@ -381,7 +384,9 @@ class RoomWorkerStore(SQLBaseStore):
             LIMIT ?
         """
 
-        def _get_largest_public_rooms_txn(txn):
+        def _get_largest_public_rooms_txn(
+            txn: LoggingTransaction,
+        ) -> List[Dict[str, Any]]:
             txn.execute(sql, query_args)
 
             results = self.db_pool.cursor_to_dict(txn)
@@ -444,7 +449,7 @@ class RoomWorkerStore(SQLBaseStore):
         """
         # Filter room names by a string
         where_statement = ""
-        search_pattern = []
+        search_pattern: List[object] = []
         if search_term:
             where_statement = """
                 WHERE LOWER(state.name) LIKE ?
@@ -552,7 +557,9 @@ class RoomWorkerStore(SQLBaseStore):
             where_statement,
         )
 
-        def _get_rooms_paginate_txn(txn):
+        def _get_rooms_paginate_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[List[Dict[str, Any]], int]:
             # Add the search term into the WHERE clause
             # and execute the data query
             txn.execute(info_sql, search_pattern + [limit, start])
@@ -584,7 +591,7 @@ class RoomWorkerStore(SQLBaseStore):
 
             # Add the search term into the WHERE clause if present
             txn.execute(count_sql, search_pattern)
-            room_count = txn.fetchone()
+            room_count = cast(Tuple[int], txn.fetchone())
             return rooms, room_count[0]
 
         return await self.db_pool.runInteraction(
@@ -629,7 +636,7 @@ class RoomWorkerStore(SQLBaseStore):
             burst_count: How many actions that can be performed before being limited.
         """
 
-        def set_ratelimit_txn(txn):
+        def set_ratelimit_txn(txn: LoggingTransaction) -> None:
             self.db_pool.simple_upsert_txn(
                 txn,
                 table="ratelimit_override",
@@ -652,7 +659,7 @@ class RoomWorkerStore(SQLBaseStore):
             user_id: user ID of the user
         """
 
-        def delete_ratelimit_txn(txn):
+        def delete_ratelimit_txn(txn: LoggingTransaction) -> None:
             row = self.db_pool.simple_select_one_txn(
                 txn,
                 table="ratelimit_override",
@@ -676,7 +683,7 @@ class RoomWorkerStore(SQLBaseStore):
         await self.db_pool.runInteraction("delete_ratelimit", delete_ratelimit_txn)
 
     @cached()
-    async def get_retention_policy_for_room(self, room_id):
+    async def get_retention_policy_for_room(self, room_id: str) -> Dict[str, int]:
         """Get the retention policy for a given room.
 
         If no retention policy has been found for this room, returns a policy defined
@@ -685,13 +692,15 @@ class RoomWorkerStore(SQLBaseStore):
         configuration).
 
         Args:
-            room_id (str): The ID of the room to get the retention policy of.
+            room_id: The ID of the room to get the retention policy of.
 
         Returns:
-            dict[int, int]: "min_lifetime" and "max_lifetime" for this room.
+            A dict containing "min_lifetime" and "max_lifetime" for this room.
         """
 
-        def get_retention_policy_for_room_txn(txn):
+        def get_retention_policy_for_room_txn(
+            txn: LoggingTransaction,
+        ) -> List[Dict[str, Optional[int]]]:
             txn.execute(
                 """
                 SELECT min_lifetime, max_lifetime FROM room_retention
@@ -716,19 +725,23 @@ class RoomWorkerStore(SQLBaseStore):
                 "max_lifetime": self.config.retention.retention_default_max_lifetime,
             }
 
-        row = ret[0]
+        min_lifetime = ret[0]["min_lifetime"]
+        max_lifetime = ret[0]["max_lifetime"]
 
         # If one of the room's policy's attributes isn't defined, use the matching
         # attribute from the default policy.
         # The default values will be None if no default policy has been defined, or if one
         # of the attributes is missing from the default policy.
-        if row["min_lifetime"] is None:
-            row["min_lifetime"] = self.config.retention.retention_default_min_lifetime
+        if min_lifetime is None:
+            min_lifetime = self.config.retention.retention_default_min_lifetime
 
-        if row["max_lifetime"] is None:
-            row["max_lifetime"] = self.config.retention.retention_default_max_lifetime
+        if max_lifetime is None:
+            max_lifetime = self.config.retention.retention_default_max_lifetime
 
-        return row
+        return {
+            "min_lifetime": min_lifetime,
+            "max_lifetime": max_lifetime,
+        }
 
     async def get_media_mxcs_in_room(self, room_id: str) -> Tuple[List[str], List[str]]:
         """Retrieves all the local and remote media MXC URIs in a given room
@@ -740,7 +753,9 @@ class RoomWorkerStore(SQLBaseStore):
             The local and remote media as a lists of the media IDs.
         """
 
-        def _get_media_mxcs_in_room_txn(txn):
+        def _get_media_mxcs_in_room_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[List[str], List[str]]:
            local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
            local_media_mxcs = []
            remote_media_mxcs = []
@@ -766,7 +781,7 @@ class RoomWorkerStore(SQLBaseStore):
 
         logger.info("Quarantining media in room: %s", room_id)
 
-        def _quarantine_media_in_room_txn(txn):
+        def _quarantine_media_in_room_txn(txn: LoggingTransaction) -> int:
             local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
             return self._quarantine_media_txn(
                 txn, local_mxcs, remote_mxcs, quarantined_by
@@ -776,13 +791,11 @@ class RoomWorkerStore(SQLBaseStore):
             "quarantine_media_in_room", _quarantine_media_in_room_txn
         )
 
-    def _get_media_mxcs_in_room_txn(self, txn, room_id):
+    def _get_media_mxcs_in_room_txn(
+        self, txn: LoggingTransaction, room_id: str
+    ) -> Tuple[List[str], List[Tuple[str, str]]]:
         """Retrieves all the local and remote media MXC URIs in a given room
 
-        Args:
-            txn (cursor)
-            room_id (str)
-
         Returns:
             The local and remote media as a lists of tuples where the key is
             the hostname and the value is the media ID.
@@ -850,7 +863,7 @@ class RoomWorkerStore(SQLBaseStore):
         logger.info("Quarantining media: %s/%s", server_name, media_id)
         is_local = server_name == self.config.server.server_name
 
-        def _quarantine_media_by_id_txn(txn):
+        def _quarantine_media_by_id_txn(txn: LoggingTransaction) -> int:
             local_mxcs = [media_id] if is_local else []
             remote_mxcs = [(server_name, media_id)] if not is_local else []
 
@@ -872,7 +885,7 @@ class RoomWorkerStore(SQLBaseStore):
             quarantined_by: The ID of the user who made the quarantine request
         """
 
-        def _quarantine_media_by_user_txn(txn):
+        def _quarantine_media_by_user_txn(txn: LoggingTransaction) -> int:
             local_media_ids = self._get_media_ids_by_user_txn(txn, user_id)
             return self._quarantine_media_txn(txn, local_media_ids, [], quarantined_by)
 
@@ -880,7 +893,9 @@ class RoomWorkerStore(SQLBaseStore):
             "quarantine_media_by_user", _quarantine_media_by_user_txn
         )
 
-    def _get_media_ids_by_user_txn(self, txn, user_id: str, filter_quarantined=True):
+    def _get_media_ids_by_user_txn(
+        self, txn: LoggingTransaction, user_id: str, filter_quarantined: bool = True
+    ) -> List[str]:
         """Retrieves local media IDs by a given user
 
         Args:
@@ -909,7 +924,7 @@ class RoomWorkerStore(SQLBaseStore):
 
     def _quarantine_media_txn(
         self,
-        txn,
+        txn: LoggingTransaction,
         local_mxcs: List[str],
         remote_mxcs: List[Tuple[str, str]],
         quarantined_by: Optional[str],
@@ -937,12 +952,15 @@ class RoomWorkerStore(SQLBaseStore):
         # set quarantine
         if quarantined_by is not None:
             sql += "AND safe_from_quarantine = ?"
-            rows = [(quarantined_by, media_id, False) for media_id in local_mxcs]
+            txn.executemany(
+                sql, [(quarantined_by, media_id, False) for media_id in local_mxcs]
+            )
         # remove from quarantine
         else:
-            rows = [(quarantined_by, media_id) for media_id in local_mxcs]
+            txn.executemany(
+                sql, [(quarantined_by, media_id) for media_id in local_mxcs]
+            )
 
-        txn.executemany(sql, rows)
         # Note that a rowcount of -1 can be used to indicate no rows were affected.
         total_media_quarantined = txn.rowcount if txn.rowcount > 0 else 0
 
@@ -960,7 +978,7 @@ class RoomWorkerStore(SQLBaseStore):
 
     async def get_rooms_for_retention_period_in_range(
         self, min_ms: Optional[int], max_ms: Optional[int], include_null: bool = False
-    ) -> Dict[str, dict]:
+    ) -> Dict[str, Dict[str, Optional[int]]]:
         """Retrieves all of the rooms within the given retention range.
Optionally includes the rooms which don't have a retention policy. @@ -980,7 +998,9 @@ class RoomWorkerStore(SQLBaseStore): "min_lifetime" (int|None), and "max_lifetime" (int|None). """ - def get_rooms_for_retention_period_in_range_txn(txn): + def get_rooms_for_retention_period_in_range_txn( + txn: LoggingTransaction, + ) -> Dict[str, Dict[str, Optional[int]]]: range_conditions = [] args = [] @@ -1067,8 +1087,6 @@ class RoomBackgroundUpdateStore(SQLBaseStore): ): super().__init__(database, db_conn, hs) - self.config = hs.config - self.db_pool.updates.register_background_update_handler( "insert_room_retention", self._background_insert_retention, @@ -1099,7 +1117,9 @@ class RoomBackgroundUpdateStore(SQLBaseStore): self._background_populate_rooms_creator_column, ) - async def _background_insert_retention(self, progress, batch_size): + async def _background_insert_retention( + self, progress: JsonDict, batch_size: int + ) -> int: """Retrieves a list of all rooms within a range and inserts an entry for each of them into the room_retention table. NULLs the property's columns if missing from the retention event in the room's @@ -1109,7 +1129,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore): last_room = progress.get("room_id", "") - def _background_insert_retention_txn(txn): + def _background_insert_retention_txn(txn: LoggingTransaction) -> bool: txn.execute( """ SELECT state.room_id, state.event_id, events.json @@ -1168,15 +1188,17 @@ class RoomBackgroundUpdateStore(SQLBaseStore): return batch_size async def _background_add_rooms_room_version_column( - self, progress: dict, batch_size: int - ): + self, progress: JsonDict, batch_size: int + ) -> int: """Background update to go and add room version information to `rooms` table from `current_state_events` table. """ last_room_id = progress.get("room_id", "") - def _background_add_rooms_room_version_column_txn(txn: LoggingTransaction): + def _background_add_rooms_room_version_column_txn( + txn: LoggingTransaction, + ) -> bool: sql = """ SELECT room_id, json FROM current_state_events INNER JOIN event_json USING (room_id, event_id) @@ -1237,7 +1259,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore): return batch_size async def _remove_tombstoned_rooms_from_directory( - self, progress, batch_size + self, progress: JsonDict, batch_size: int ) -> int: """Removes any rooms with tombstone events from the room directory @@ -1247,7 +1269,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore): last_room = progress.get("room_id", "") - def _get_rooms(txn): + def _get_rooms(txn: LoggingTransaction) -> List[str]: txn.execute( """ SELECT room_id @@ -1285,7 +1307,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore): return len(rooms) @abstractmethod - def set_room_is_public(self, room_id, is_public): + def set_room_is_public(self, room_id: str, is_public: bool) -> Awaitable[None]: # this will need to be implemented if a background update is performed with # existing (tombstoned, public) rooms in the database. # @@ -1332,7 +1354,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore): 32-bit integer field. 
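# A typing note on the abstract set_room_is_public above: it is declared as
# a plain def returning Awaitable[None], which lets subclasses implement it
# with async def (calling an async def returns an awaitable) while the base
# class never defines a coroutine it won't run. Minimal sketch with
# illustrative class names, not Synapse's:
import asyncio
from abc import ABC, abstractmethod
from typing import Awaitable


class DirectoryStore(ABC):
    @abstractmethod
    def set_room_is_public(self, room_id: str, is_public: bool) -> Awaitable[None]:
        ...


class ConcreteDirectoryStore(DirectoryStore):
    # An async def whose declared return type is None satisfies the
    # Awaitable[None] signature of the base method.
    async def set_room_is_public(self, room_id: str, is_public: bool) -> None:
        print(f"set {room_id} public={is_public}")


asyncio.run(ConcreteDirectoryStore().set_room_is_public("!a:example.org", True))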
""" - def process(txn: Cursor) -> int: + def process(txn: LoggingTransaction) -> int: last_room = progress.get("last_room", "") txn.execute( """ @@ -1389,15 +1411,17 @@ class RoomBackgroundUpdateStore(SQLBaseStore): return 0 async def _background_populate_rooms_creator_column( - self, progress: dict, batch_size: int - ): + self, progress: JsonDict, batch_size: int + ) -> int: """Background update to go and add creator information to `rooms` table from `current_state_events` table. """ last_room_id = progress.get("room_id", "") - def _background_populate_rooms_creator_column_txn(txn: LoggingTransaction): + def _background_populate_rooms_creator_column_txn( + txn: LoggingTransaction, + ) -> bool: sql = """ SELECT room_id, json FROM event_json INNER JOIN rooms AS room USING (room_id) @@ -1448,7 +1472,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore): return batch_size -class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): +class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): def __init__( self, database: DatabasePool, @@ -1457,11 +1481,11 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): ): super().__init__(database, db_conn, hs) - self.config = hs.config + self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id") async def upsert_room_on_join( self, room_id: str, room_version: RoomVersion, auth_events: List[EventBase] - ): + ) -> None: """Ensure that the room is stored in the table Called when we join a room over federation, and overwrites any room version @@ -1507,7 +1531,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): async def maybe_store_room_on_outlier_membership( self, room_id: str, room_version: RoomVersion - ): + ) -> None: """ When we receive an invite or any other event over federation that may relate to a room we are not in, store the version of the room if we don't already know the room version. @@ -1547,8 +1571,8 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): self.hs.get_notifier().on_new_replication_data() async def set_room_is_public_appservice( - self, room_id, appservice_id, network_id, is_public - ): + self, room_id: str, appservice_id: str, network_id: str, is_public: bool + ) -> None: """Edit the appservice/network specific public room list. Each appservice can have a number of published room lists associated @@ -1557,11 +1581,10 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): network. Args: - room_id (str) - appservice_id (str) - network_id (str) - is_public (bool): Whether to publish or unpublish the room from the - list. + room_id + appservice_id + network_id + is_public: Whether to publish or unpublish the room from the list. 
""" if is_public: @@ -1626,7 +1649,9 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): event_report: json list of information from event report """ - def _get_event_report_txn(txn, report_id): + def _get_event_report_txn( + txn: LoggingTransaction, report_id: int + ) -> Optional[Dict[str, Any]]: sql = """ SELECT @@ -1698,9 +1723,11 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): count: total number of event reports matching the filter criteria """ - def _get_event_reports_paginate_txn(txn): + def _get_event_reports_paginate_txn( + txn: LoggingTransaction, + ) -> Tuple[List[Dict[str, Any]], int]: filters = [] - args = [] + args: List[object] = [] if user_id: filters.append("er.user_id LIKE ?") @@ -1724,7 +1751,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): where_clause ) txn.execute(sql, args) - count = txn.fetchone()[0] + count = cast(Tuple[int], txn.fetchone())[0] sql = """ SELECT -- cgit 1.5.1 From 2359ee3864a065229c80e3ff58faa981edd24558 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 4 Jan 2022 16:10:27 +0000 Subject: Remove redundant `get_current_events_token` (#11643) * Push `get_room_{min,max_stream_ordering}` into StreamStore Both implementations of this are identical, so we may as well push it down and get rid of the abstract base class nonsense. * Remove redundant `StreamStore` class This is empty now * Remove redundant `get_current_events_token` This was an exact duplicate of `get_room_max_stream_ordering`, so let's get rid of it. * newsfile --- changelog.d/11643.misc | 1 + synapse/handlers/federation_event.py | 2 +- synapse/handlers/presence.py | 2 +- synapse/replication/slave/storage/events.py | 9 ------- synapse/storage/databases/main/__init__.py | 4 +-- synapse/storage/databases/main/events_worker.py | 4 --- synapse/storage/databases/main/stream.py | 34 +++++++++++-------------- 7 files changed, 20 insertions(+), 36 deletions(-) create mode 100644 changelog.d/11643.misc (limited to 'synapse/storage/databases/main/__init__.py') diff --git a/changelog.d/11643.misc b/changelog.d/11643.misc new file mode 100644 index 0000000000..1c3b3071f6 --- /dev/null +++ b/changelog.d/11643.misc @@ -0,0 +1 @@ +Remove redundant `get_current_events_token` method. diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 9917613298..d08e48da58 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -1838,7 +1838,7 @@ class FederationEventHandler: The stream ID after which all events have been persisted. """ if not event_and_contexts: - return self._store.get_current_events_token() + return self._store.get_room_max_stream_ordering() instance = self._config.worker.events_shard_config.get_instance(room_id) if instance != self._instance_name: diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 454d06c973..c781fefb1b 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -729,7 +729,7 @@ class PresenceHandler(BasePresenceHandler): # Presence is best effort and quickly heals itself, so lets just always # stream from the current state when we restart. 
-        self._event_pos = self.store.get_current_events_token()
+        self._event_pos = self.store.get_room_max_stream_ordering()
         self._event_processing = False
 
     async def _on_shutdown(self) -> None:
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index 50e7379e83..0f08372694 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -80,12 +80,3 @@ class SlavedEventStore(
             min_curr_state_delta_id,
             prefilled_cache=curr_state_delta_prefill,
         )
-
-    # Cached functions can't be accessed through a class instance so we need
-    # to reach inside the __dict__ to extract them.
-
-    def get_room_max_stream_ordering(self):
-        return self._stream_id_gen.get_current_token()
-
-    def get_room_min_stream_ordering(self):
-        return self._backfill_id_gen.get_current_token()
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index a594223fc6..f024761ba7 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -68,7 +68,7 @@ from .session import SessionStore
 from .signatures import SignatureStore
 from .state import StateStore
 from .stats import StatsStore
-from .stream import StreamStore
+from .stream import StreamWorkerStore
 from .tags import TagsStore
 from .transactions import TransactionWorkerStore
 from .ui_auth import UIAuthStore
@@ -87,7 +87,7 @@ class DataStore(
     RoomStore,
     RoomBatchStore,
     RegistrationStore,
-    StreamStore,
+    StreamWorkerStore,
     ProfileStore,
     PresenceStore,
     TransactionWorkerStore,
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index c7b660ac5a..8d4287045a 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -1383,10 +1383,6 @@ class EventsWorkerStore(SQLBaseStore):
 
         return {"v1": complexity_v1}
 
-    def get_current_events_token(self) -> int:
-        """The current maximum token that events have reached"""
-        return self._stream_id_gen.get_current_token()
-
     async def get_all_new_forward_event_rows(
         self, instance_name: str, last_id: int, current_id: int, limit: int
     ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index b0642ca69f..319464b1fa 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -34,7 +34,7 @@ what sort order was used:
 - topological tokens: "t%d-%d", where the integers map to the topological
     and stream ordering columns respectively.
 """
-import abc
+
 import logging
 from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Set, Tuple
@@ -336,12 +336,7 @@ def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]:
     return " AND ".join(clauses), args
 
 
-class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
-    """This is an abstract base class where subclasses must implement
-    `get_room_max_stream_ordering` and `get_room_min_stream_ordering`
-    which can be called in the initializer.
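# Context for the hunk that continues below: StreamWorkerStore previously
# declared get_room_max_stream_ordering / get_room_min_stream_ordering as
# abstract, and both SlavedEventStore and StreamStore supplied the same
# one-line bodies. When every override is identical, the method can simply
# live on the base class. Toy illustration (classes invented for brevity):
class Base:
    def __init__(self) -> None:
        self._stream_token = 7
        self._backfill_token = -3

    # Formerly: @abc.abstractmethod ... raise NotImplementedError(), with
    # identical implementations repeated in each subclass.
    def get_room_max_stream_ordering(self) -> int:
        return self._stream_token

    def get_room_min_stream_ordering(self) -> int:
        return self._backfill_token


class Worker(Base):
    pass  # no per-subclass override needed any more


w = Worker()
print(w.get_room_max_stream_ordering(), w.get_room_min_stream_ordering())  # 7 -3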
- """ - +class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): def __init__( self, database: DatabasePool, @@ -379,13 +374,22 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): self._stream_order_on_start = self.get_room_max_stream_ordering() - @abc.abstractmethod def get_room_max_stream_ordering(self) -> int: - raise NotImplementedError() + """Get the stream_ordering of regular events that we have committed up to + + Returns the maximum stream id such that all stream ids less than or + equal to it have been successfully persisted. + """ + return self._stream_id_gen.get_current_token() - @abc.abstractmethod def get_room_min_stream_ordering(self) -> int: - raise NotImplementedError() + """Get the stream_ordering of backfilled events that we have committed up to + + Backfilled events use *negative* stream orderings, so this returns the + minimum negative stream id such that all stream ids greater than or + equal to it have been successfully persisted. + """ + return self._backfill_id_gen.get_current_token() def get_room_max_token(self) -> RoomStreamToken: """Get a `RoomStreamToken` that marks the current maximum persisted @@ -1351,11 +1355,3 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): retcol="instance_name", desc="get_name_from_instance_id", ) - - -class StreamStore(StreamWorkerStore): - def get_room_max_stream_ordering(self) -> int: - return self._stream_id_gen.get_current_token() - - def get_room_min_stream_ordering(self) -> int: - return self._backfill_id_gen.get_current_token() -- cgit 1.5.1