diff options
author | Sean Quah <8349537+squahtx@users.noreply.github.com> | 2021-12-13 16:28:10 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-12-13 16:28:10 +0000 |
commit | 6da8591f2ef9597880ace89aaf434332dddaa711 (patch) | |
tree | ad7366da00d9de508201d95918eff082e6f8fb05 /synapse/storage/databases/main | |
parent | Make `get_device` return None if the device doesn't exist rather than raising... (diff) | |
download | synapse-6da8591f2ef9597880ace89aaf434332dddaa711.tar.xz |
Add type hints to `synapse/storage/databases/main/account_data.py` (#11546)
Diffstat (limited to 'synapse/storage/databases/main')
-rw-r--r-- | synapse/storage/databases/main/account_data.py | 93 | ||||
-rw-r--r-- | synapse/storage/databases/main/tags.py | 22 |
2 files changed, 83 insertions, 32 deletions
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index f8bec266ac..32a553fdd7 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -14,15 +14,25 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple, cast from synapse.api.constants import AccountDataTypes from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream -from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.storage.database import DatabasePool +from synapse.storage._base import db_to_json +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) +from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.engines import PostgresEngine -from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator +from synapse.storage.util.id_generators import ( + AbstractStreamIdGenerator, + AbstractStreamIdTracker, + MultiWriterIdGenerator, + StreamIdGenerator, +) from synapse.types import JsonDict from synapse.util import json_encoder from synapse.util.caches.descriptors import cached @@ -34,13 +44,19 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class AccountDataWorkerStore(SQLBaseStore): - """This is an abstract base class where subclasses must implement - `get_max_account_data_stream_id` which can be called in the initializer. - """ +class AccountDataWorkerStore(CacheInvalidationWorkerStore): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): + super().__init__(database, db_conn, hs) - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): - self._instance_name = hs.get_instance_name() + # `_can_write_to_account_data` indicates whether the current worker is allowed + # to write account data. A value of `True` implies that `_account_data_id_gen` + # is an `AbstractStreamIdGenerator` and not just a tracker. + self._account_data_id_gen: AbstractStreamIdTracker if isinstance(database.engine, PostgresEngine): self._can_write_to_account_data = ( @@ -61,8 +77,6 @@ class AccountDataWorkerStore(SQLBaseStore): writers=hs.config.worker.writers.account_data, ) else: - self._can_write_to_account_data = True - # We shouldn't be running in worker mode with SQLite, but its useful # to support it for unit tests. # @@ -70,7 +84,8 @@ class AccountDataWorkerStore(SQLBaseStore): # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets # updated over replication. (Multiple writers are not supported for # SQLite). - if hs.get_instance_name() in hs.config.worker.writers.account_data: + if self._instance_name in hs.config.worker.writers.account_data: + self._can_write_to_account_data = True self._account_data_id_gen = StreamIdGenerator( db_conn, "room_account_data", @@ -90,8 +105,6 @@ class AccountDataWorkerStore(SQLBaseStore): "AccountDataAndTagsChangeCache", account_max ) - super().__init__(database, db_conn, hs) - def get_max_account_data_stream_id(self) -> int: """Get the current max stream ID for account data stream @@ -113,7 +126,9 @@ class AccountDataWorkerStore(SQLBaseStore): room_id string to per room account_data dicts. """ - def get_account_data_for_user_txn(txn): + def get_account_data_for_user_txn( + txn: LoggingTransaction, + ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]: rows = self.db_pool.simple_select_list_txn( txn, "account_data", @@ -132,7 +147,7 @@ class AccountDataWorkerStore(SQLBaseStore): ["room_id", "account_data_type", "content"], ) - by_room = {} + by_room: Dict[str, Dict[str, JsonDict]] = {} for row in rows: room_data = by_room.setdefault(row["room_id"], {}) room_data[row["account_data_type"]] = db_to_json(row["content"]) @@ -177,7 +192,9 @@ class AccountDataWorkerStore(SQLBaseStore): A dict of the room account_data """ - def get_account_data_for_room_txn(txn): + def get_account_data_for_room_txn( + txn: LoggingTransaction, + ) -> Dict[str, JsonDict]: rows = self.db_pool.simple_select_list_txn( txn, "room_account_data", @@ -207,7 +224,9 @@ class AccountDataWorkerStore(SQLBaseStore): The room account_data for that type, or None if there isn't any set. """ - def get_account_data_for_room_and_type_txn(txn): + def get_account_data_for_room_and_type_txn( + txn: LoggingTransaction, + ) -> Optional[JsonDict]: content_json = self.db_pool.simple_select_one_onecol_txn( txn, table="room_account_data", @@ -243,14 +262,16 @@ class AccountDataWorkerStore(SQLBaseStore): if last_id == current_id: return [] - def get_updated_global_account_data_txn(txn): + def get_updated_global_account_data_txn( + txn: LoggingTransaction, + ) -> List[Tuple[int, str, str]]: sql = ( "SELECT stream_id, user_id, account_data_type" " FROM account_data WHERE ? < stream_id AND stream_id <= ?" " ORDER BY stream_id ASC LIMIT ?" ) txn.execute(sql, (last_id, current_id, limit)) - return txn.fetchall() + return cast(List[Tuple[int, str, str]], txn.fetchall()) return await self.db_pool.runInteraction( "get_updated_global_account_data", get_updated_global_account_data_txn @@ -273,14 +294,16 @@ class AccountDataWorkerStore(SQLBaseStore): if last_id == current_id: return [] - def get_updated_room_account_data_txn(txn): + def get_updated_room_account_data_txn( + txn: LoggingTransaction, + ) -> List[Tuple[int, str, str, str]]: sql = ( "SELECT stream_id, user_id, room_id, account_data_type" " FROM room_account_data WHERE ? < stream_id AND stream_id <= ?" " ORDER BY stream_id ASC LIMIT ?" ) txn.execute(sql, (last_id, current_id, limit)) - return txn.fetchall() + return cast(List[Tuple[int, str, str, str]], txn.fetchall()) return await self.db_pool.runInteraction( "get_updated_room_account_data", get_updated_room_account_data_txn @@ -299,7 +322,9 @@ class AccountDataWorkerStore(SQLBaseStore): mapping from room_id string to per room account_data dicts. """ - def get_updated_account_data_for_user_txn(txn): + def get_updated_account_data_for_user_txn( + txn: LoggingTransaction, + ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]: sql = ( "SELECT account_data_type, content FROM account_data" " WHERE user_id = ? AND stream_id > ?" @@ -316,7 +341,7 @@ class AccountDataWorkerStore(SQLBaseStore): txn.execute(sql, (user_id, stream_id)) - account_data_by_room = {} + account_data_by_room: Dict[str, Dict[str, JsonDict]] = {} for row in txn: room_account_data = account_data_by_room.setdefault(row[0], {}) room_account_data[row[1]] = db_to_json(row[2]) @@ -353,12 +378,15 @@ class AccountDataWorkerStore(SQLBaseStore): ) ) - def process_replication_rows(self, stream_name, instance_name, token, rows): + def process_replication_rows( + self, + stream_name: str, + instance_name: str, + token: int, + rows: Iterable[Any], + ) -> None: if stream_name == TagAccountDataStream.NAME: self._account_data_id_gen.advance(instance_name, token) - for row in rows: - self.get_tags_for_user.invalidate((row.user_id,)) - self._account_data_stream_cache.entity_has_changed(row.user_id, token) elif stream_name == AccountDataStream.NAME: self._account_data_id_gen.advance(instance_name, token) for row in rows: @@ -372,7 +400,8 @@ class AccountDataWorkerStore(SQLBaseStore): (row.user_id, row.room_id, row.data_type) ) self._account_data_stream_cache.entity_has_changed(row.user_id, token) - return super().process_replication_rows(stream_name, instance_name, token, rows) + + super().process_replication_rows(stream_name, instance_name, token, rows) async def add_account_data_to_room( self, user_id: str, room_id: str, account_data_type: str, content: JsonDict @@ -389,6 +418,7 @@ class AccountDataWorkerStore(SQLBaseStore): The maximum stream ID. """ assert self._can_write_to_account_data + assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator) content_json = json_encoder.encode(content) @@ -431,6 +461,7 @@ class AccountDataWorkerStore(SQLBaseStore): The maximum stream ID. """ assert self._can_write_to_account_data + assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator) async with self._account_data_id_gen.get_next() as next_id: await self.db_pool.runInteraction( @@ -452,7 +483,7 @@ class AccountDataWorkerStore(SQLBaseStore): def _add_account_data_for_user( self, - txn, + txn: LoggingTransaction, next_id: int, user_id: str, account_data_type: str, diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py index 8f510de53d..c8e508a910 100644 --- a/synapse/storage/databases/main/tags.py +++ b/synapse/storage/databases/main/tags.py @@ -15,11 +15,13 @@ # limitations under the License. import logging -from typing import Dict, List, Tuple, cast +from typing import Any, Dict, Iterable, List, Tuple, cast +from synapse.replication.tcp.streams import TagAccountDataStream from synapse.storage._base import db_to_json from synapse.storage.database import LoggingTransaction from synapse.storage.databases.main.account_data import AccountDataWorkerStore +from synapse.storage.util.id_generators import AbstractStreamIdGenerator from synapse.types import JsonDict from synapse.util import json_encoder from synapse.util.caches.descriptors import cached @@ -204,6 +206,7 @@ class TagsWorkerStore(AccountDataWorkerStore): The next account data ID. """ assert self._can_write_to_account_data + assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator) content_json = json_encoder.encode(content) @@ -230,6 +233,7 @@ class TagsWorkerStore(AccountDataWorkerStore): The next account data ID. """ assert self._can_write_to_account_data + assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator) def remove_tag_txn(txn: LoggingTransaction, next_id: int) -> None: sql = ( @@ -258,6 +262,7 @@ class TagsWorkerStore(AccountDataWorkerStore): next_id: The the revision to advance to. """ assert self._can_write_to_account_data + assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator) txn.call_after( self._account_data_stream_cache.entity_has_changed, user_id, next_id @@ -287,6 +292,21 @@ class TagsWorkerStore(AccountDataWorkerStore): # than the id that the client has. pass + def process_replication_rows( + self, + stream_name: str, + instance_name: str, + token: int, + rows: Iterable[Any], + ) -> None: + if stream_name == TagAccountDataStream.NAME: + self._account_data_id_gen.advance(instance_name, token) + for row in rows: + self.get_tags_for_user.invalidate((row.user_id,)) + self._account_data_stream_cache.entity_has_changed(row.user_id, token) + + super().process_replication_rows(stream_name, instance_name, token, rows) + class TagsStore(TagsWorkerStore): pass |