summary refs log tree commit diff
diff options
context:
space:
mode:
authorSean Quah <8349537+squahtx@users.noreply.github.com>2021-12-13 16:28:10 +0000
committerGitHub <noreply@github.com>2021-12-13 16:28:10 +0000
commit6da8591f2ef9597880ace89aaf434332dddaa711 (patch)
treead7366da00d9de508201d95918eff082e6f8fb05
parentMake `get_device` return None if the device doesn't exist rather than raising... (diff)
downloadsynapse-6da8591f2ef9597880ace89aaf434332dddaa711.tar.xz
Add type hints to `synapse/storage/databases/main/account_data.py` (#11546)
-rw-r--r--changelog.d/11546.misc1
-rw-r--r--mypy.ini4
-rw-r--r--synapse/storage/databases/main/account_data.py93
-rw-r--r--synapse/storage/databases/main/tags.py22
4 files changed, 87 insertions, 33 deletions
diff --git a/changelog.d/11546.misc b/changelog.d/11546.misc
new file mode 100644
index 0000000000..d451940bf2
--- /dev/null
+++ b/changelog.d/11546.misc
@@ -0,0 +1 @@
+Add missing type hints to storage classes.
diff --git a/mypy.ini b/mypy.ini
index 1caf807e85..4b2ddafd20 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -25,7 +25,6 @@ exclude = (?x)
   ^(
    |synapse/storage/databases/__init__.py
    |synapse/storage/databases/main/__init__.py
-   |synapse/storage/databases/main/account_data.py
    |synapse/storage/databases/main/cache.py
    |synapse/storage/databases/main/devices.py
    |synapse/storage/databases/main/e2e_room_keys.py
@@ -181,6 +180,9 @@ disallow_untyped_defs = True
 [mypy-synapse.state.*]
 disallow_untyped_defs = True
 
+[mypy-synapse.storage.databases.main.account_data]
+disallow_untyped_defs = True
+
 [mypy-synapse.storage.databases.main.client_ips]
 disallow_untyped_defs = True
 
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index f8bec266ac..32a553fdd7 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -14,15 +14,25 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple, cast
 
 from synapse.api.constants import AccountDataTypes
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream
-from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool
+from synapse.storage._base import db_to_json
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
+from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
 from synapse.storage.engines import PostgresEngine
-from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
+from synapse.storage.util.id_generators import (
+    AbstractStreamIdGenerator,
+    AbstractStreamIdTracker,
+    MultiWriterIdGenerator,
+    StreamIdGenerator,
+)
 from synapse.types import JsonDict
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached
@@ -34,13 +44,19 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
-class AccountDataWorkerStore(SQLBaseStore):
-    """This is an abstract base class where subclasses must implement
-    `get_max_account_data_stream_id` which can be called in the initializer.
-    """
+class AccountDataWorkerStore(CacheInvalidationWorkerStore):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
+        super().__init__(database, db_conn, hs)
 
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
-        self._instance_name = hs.get_instance_name()
+        # `_can_write_to_account_data` indicates whether the current worker is allowed
+        # to write account data. A value of `True` implies that `_account_data_id_gen`
+        # is an `AbstractStreamIdGenerator` and not just a tracker.
+        self._account_data_id_gen: AbstractStreamIdTracker
 
         if isinstance(database.engine, PostgresEngine):
             self._can_write_to_account_data = (
@@ -61,8 +77,6 @@ class AccountDataWorkerStore(SQLBaseStore):
                 writers=hs.config.worker.writers.account_data,
             )
         else:
-            self._can_write_to_account_data = True
-
             # We shouldn't be running in worker mode with SQLite, but its useful
             # to support it for unit tests.
             #
@@ -70,7 +84,8 @@ class AccountDataWorkerStore(SQLBaseStore):
             # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets
             # updated over replication. (Multiple writers are not supported for
             # SQLite).
-            if hs.get_instance_name() in hs.config.worker.writers.account_data:
+            if self._instance_name in hs.config.worker.writers.account_data:
+                self._can_write_to_account_data = True
                 self._account_data_id_gen = StreamIdGenerator(
                     db_conn,
                     "room_account_data",
@@ -90,8 +105,6 @@ class AccountDataWorkerStore(SQLBaseStore):
             "AccountDataAndTagsChangeCache", account_max
         )
 
-        super().__init__(database, db_conn, hs)
-
     def get_max_account_data_stream_id(self) -> int:
         """Get the current max stream ID for account data stream
 
@@ -113,7 +126,9 @@ class AccountDataWorkerStore(SQLBaseStore):
             room_id string to per room account_data dicts.
         """
 
-        def get_account_data_for_user_txn(txn):
+        def get_account_data_for_user_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
             rows = self.db_pool.simple_select_list_txn(
                 txn,
                 "account_data",
@@ -132,7 +147,7 @@ class AccountDataWorkerStore(SQLBaseStore):
                 ["room_id", "account_data_type", "content"],
             )
 
-            by_room = {}
+            by_room: Dict[str, Dict[str, JsonDict]] = {}
             for row in rows:
                 room_data = by_room.setdefault(row["room_id"], {})
                 room_data[row["account_data_type"]] = db_to_json(row["content"])
@@ -177,7 +192,9 @@ class AccountDataWorkerStore(SQLBaseStore):
             A dict of the room account_data
         """
 
-        def get_account_data_for_room_txn(txn):
+        def get_account_data_for_room_txn(
+            txn: LoggingTransaction,
+        ) -> Dict[str, JsonDict]:
             rows = self.db_pool.simple_select_list_txn(
                 txn,
                 "room_account_data",
@@ -207,7 +224,9 @@ class AccountDataWorkerStore(SQLBaseStore):
             The room account_data for that type, or None if there isn't any set.
         """
 
-        def get_account_data_for_room_and_type_txn(txn):
+        def get_account_data_for_room_and_type_txn(
+            txn: LoggingTransaction,
+        ) -> Optional[JsonDict]:
             content_json = self.db_pool.simple_select_one_onecol_txn(
                 txn,
                 table="room_account_data",
@@ -243,14 +262,16 @@ class AccountDataWorkerStore(SQLBaseStore):
         if last_id == current_id:
             return []
 
-        def get_updated_global_account_data_txn(txn):
+        def get_updated_global_account_data_txn(
+            txn: LoggingTransaction,
+        ) -> List[Tuple[int, str, str]]:
             sql = (
                 "SELECT stream_id, user_id, account_data_type"
                 " FROM account_data WHERE ? < stream_id AND stream_id <= ?"
                 " ORDER BY stream_id ASC LIMIT ?"
             )
             txn.execute(sql, (last_id, current_id, limit))
-            return txn.fetchall()
+            return cast(List[Tuple[int, str, str]], txn.fetchall())
 
         return await self.db_pool.runInteraction(
             "get_updated_global_account_data", get_updated_global_account_data_txn
@@ -273,14 +294,16 @@ class AccountDataWorkerStore(SQLBaseStore):
         if last_id == current_id:
             return []
 
-        def get_updated_room_account_data_txn(txn):
+        def get_updated_room_account_data_txn(
+            txn: LoggingTransaction,
+        ) -> List[Tuple[int, str, str, str]]:
             sql = (
                 "SELECT stream_id, user_id, room_id, account_data_type"
                 " FROM room_account_data WHERE ? < stream_id AND stream_id <= ?"
                 " ORDER BY stream_id ASC LIMIT ?"
             )
             txn.execute(sql, (last_id, current_id, limit))
-            return txn.fetchall()
+            return cast(List[Tuple[int, str, str, str]], txn.fetchall())
 
         return await self.db_pool.runInteraction(
             "get_updated_room_account_data", get_updated_room_account_data_txn
@@ -299,7 +322,9 @@ class AccountDataWorkerStore(SQLBaseStore):
             mapping from room_id string to per room account_data dicts.
         """
 
-        def get_updated_account_data_for_user_txn(txn):
+        def get_updated_account_data_for_user_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
             sql = (
                 "SELECT account_data_type, content FROM account_data"
                 " WHERE user_id = ? AND stream_id > ?"
@@ -316,7 +341,7 @@ class AccountDataWorkerStore(SQLBaseStore):
 
             txn.execute(sql, (user_id, stream_id))
 
-            account_data_by_room = {}
+            account_data_by_room: Dict[str, Dict[str, JsonDict]] = {}
             for row in txn:
                 room_account_data = account_data_by_room.setdefault(row[0], {})
                 room_account_data[row[1]] = db_to_json(row[2])
@@ -353,12 +378,15 @@ class AccountDataWorkerStore(SQLBaseStore):
             )
         )
 
-    def process_replication_rows(self, stream_name, instance_name, token, rows):
+    def process_replication_rows(
+        self,
+        stream_name: str,
+        instance_name: str,
+        token: int,
+        rows: Iterable[Any],
+    ) -> None:
         if stream_name == TagAccountDataStream.NAME:
             self._account_data_id_gen.advance(instance_name, token)
-            for row in rows:
-                self.get_tags_for_user.invalidate((row.user_id,))
-                self._account_data_stream_cache.entity_has_changed(row.user_id, token)
         elif stream_name == AccountDataStream.NAME:
             self._account_data_id_gen.advance(instance_name, token)
             for row in rows:
@@ -372,7 +400,8 @@ class AccountDataWorkerStore(SQLBaseStore):
                     (row.user_id, row.room_id, row.data_type)
                 )
                 self._account_data_stream_cache.entity_has_changed(row.user_id, token)
-        return super().process_replication_rows(stream_name, instance_name, token, rows)
+
+        super().process_replication_rows(stream_name, instance_name, token, rows)
 
     async def add_account_data_to_room(
         self, user_id: str, room_id: str, account_data_type: str, content: JsonDict
@@ -389,6 +418,7 @@ class AccountDataWorkerStore(SQLBaseStore):
             The maximum stream ID.
         """
         assert self._can_write_to_account_data
+        assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator)
 
         content_json = json_encoder.encode(content)
 
@@ -431,6 +461,7 @@ class AccountDataWorkerStore(SQLBaseStore):
             The maximum stream ID.
         """
         assert self._can_write_to_account_data
+        assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator)
 
         async with self._account_data_id_gen.get_next() as next_id:
             await self.db_pool.runInteraction(
@@ -452,7 +483,7 @@ class AccountDataWorkerStore(SQLBaseStore):
 
     def _add_account_data_for_user(
         self,
-        txn,
+        txn: LoggingTransaction,
         next_id: int,
         user_id: str,
         account_data_type: str,
diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py
index 8f510de53d..c8e508a910 100644
--- a/synapse/storage/databases/main/tags.py
+++ b/synapse/storage/databases/main/tags.py
@@ -15,11 +15,13 @@
 # limitations under the License.
 
 import logging
-from typing import Dict, List, Tuple, cast
+from typing import Any, Dict, Iterable, List, Tuple, cast
 
+from synapse.replication.tcp.streams import TagAccountDataStream
 from synapse.storage._base import db_to_json
 from synapse.storage.database import LoggingTransaction
 from synapse.storage.databases.main.account_data import AccountDataWorkerStore
+from synapse.storage.util.id_generators import AbstractStreamIdGenerator
 from synapse.types import JsonDict
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached
@@ -204,6 +206,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
             The next account data ID.
         """
         assert self._can_write_to_account_data
+        assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator)
 
         content_json = json_encoder.encode(content)
 
@@ -230,6 +233,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
             The next account data ID.
         """
         assert self._can_write_to_account_data
+        assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator)
 
         def remove_tag_txn(txn: LoggingTransaction, next_id: int) -> None:
             sql = (
@@ -258,6 +262,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
             next_id: The the revision to advance to.
         """
         assert self._can_write_to_account_data
+        assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator)
 
         txn.call_after(
             self._account_data_stream_cache.entity_has_changed, user_id, next_id
@@ -287,6 +292,21 @@ class TagsWorkerStore(AccountDataWorkerStore):
                 # than the id that the client has.
                 pass
 
+    def process_replication_rows(
+        self,
+        stream_name: str,
+        instance_name: str,
+        token: int,
+        rows: Iterable[Any],
+    ) -> None:
+        if stream_name == TagAccountDataStream.NAME:
+            self._account_data_id_gen.advance(instance_name, token)
+            for row in rows:
+                self.get_tags_for_user.invalidate((row.user_id,))
+                self._account_data_stream_cache.entity_has_changed(row.user_id, token)
+
+        super().process_replication_rows(stream_name, instance_name, token, rows)
+
 
 class TagsStore(TagsWorkerStore):
     pass