summary refs log tree commit diff
diff options
context:
space:
mode:
authorDavid Robertson <davidr@element.io>2021-11-15 12:57:39 +0000
committerDavid Robertson <davidr@element.io>2021-11-15 13:00:04 +0000
commitbb150be1ad7922b88057b713bdd81e369aa7c6c0 (patch)
tree525e12059b175abede3f3da9e3b53d3adb6401a8
parentDatabase storage profile passes mypy (#11342) (diff)
downloadsynapse-bb150be1ad7922b88057b713bdd81e369aa7c6c0.tar.xz
Annotate get_all_updates_caches_txn
-rw-r--r--synapse/storage/databases/main/cache.py30
1 files changed, 25 insertions, 5 deletions
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 36e8422fc6..e6098152ec 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -25,7 +25,11 @@ from synapse.replication.tcp.streams.events import (
     EventsStreamEventRow,
 )
 from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.engines import PostgresEngine
 from synapse.util.iterutils import batch_iter
 
@@ -39,16 +43,24 @@ logger = logging.getLogger(__name__)
 # based on the current state when notifying workers over replication.
 CURRENT_STATE_CACHE_NAME = "cs_cache_fake"
 
+# Corresponds to the (cache_func, keys, invalidation_ts) db columns.
+_CacheData = Tuple[str, Optional[List[str]], Optional[int]]
+
 
 class CacheInvalidationWorkerStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self._instance_name = hs.get_instance_name()
 
     async def get_all_updated_caches(
         self, instance_name: str, last_id: int, current_id: int, limit: int
-    ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
+    ) -> Tuple[List[Tuple[int, _CacheData]], int, bool]:
         """Get updates for caches replication stream.
 
         Args:
@@ -73,7 +85,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
         if last_id == current_id:
             return [], current_id, False
 
-        def get_all_updated_caches_txn(txn):
+        def get_all_updated_caches_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[List[Tuple[int, _CacheData]], int, bool]:
             # We purposefully don't bound by the current token, as we want to
             # send across cache invalidations as quickly as possible. Cache
             # invalidations are idempotent, so duplicates are fine.
@@ -85,7 +99,13 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
                 LIMIT ?
             """
             txn.execute(sql, (last_id, instance_name, limit))
-            updates = [(row[0], row[1:]) for row in txn]
+            updates: List[Tuple[int, _CacheData]] = []
+            row: Tuple[int, str, Optional[List[str]], Optional[int]]
+            # Type saftey: iterating over `txn` yields `Tuple`, i.e.
+            # `Tuple[Any, ...]` of arbitrary length. Mypy detects assigning a
+            # variadic tuple to a fixed length tuple and flags it up as an error.
+            for row in txn:  # type: ignore[assignment]
+                updates.append((row[0], row[1:]))
             limited = False
             upto_token = current_id
             if len(updates) >= limit: