summary refs log tree commit diff
path: root/synapse/storage/databases/main/receipts.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/receipts.py')
-rw-r--r--synapse/storage/databases/main/receipts.py37
1 files changed, 25 insertions, 12 deletions
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index bf0b903af2..e6f97aeece 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -24,10 +24,9 @@ from typing import (
     Optional,
     Set,
     Tuple,
+    cast,
 )
 
-from twisted.internet import defer
-
 from synapse.api.constants import ReceiptTypes
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.replication.tcp.streams import ReceiptsStream
@@ -38,7 +37,11 @@ from synapse.storage.database import (
     LoggingTransaction,
 )
 from synapse.storage.engines import PostgresEngine
-from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
+from synapse.storage.util.id_generators import (
+    AbstractStreamIdTracker,
+    MultiWriterIdGenerator,
+    StreamIdGenerator,
+)
 from synapse.types import JsonDict
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached, cachedList
@@ -58,6 +61,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
         hs: "HomeServer",
     ):
         self._instance_name = hs.get_instance_name()
+        self._receipts_id_gen: AbstractStreamIdTracker
 
         if isinstance(database.engine, PostgresEngine):
             self._can_write_to_receipts = (
@@ -161,7 +165,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
                 " AND user_id = ?"
             )
             txn.execute(sql, (user_id,))
-            return txn.fetchall()
+            return cast(List[Tuple[str, str, int, int]], txn.fetchall())
 
         rows = await self.db_pool.runInteraction(
             "get_receipts_for_user_with_orderings", f
@@ -257,7 +261,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
         if not rows:
             return []
 
-        content = {}
+        content: JsonDict = {}
         for row in rows:
             content.setdefault(row["event_id"], {}).setdefault(row["receipt_type"], {})[
                 row["user_id"]
@@ -305,7 +309,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
             "_get_linearized_receipts_for_rooms", f
         )
 
-        results = {}
+        results: JsonDict = {}
         for row in txn_results:
             # We want a single event per room, since we want to batch the
             # receipts by room, event and type.
@@ -370,7 +374,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
             "get_linearized_receipts_for_all_rooms", f
         )
 
-        results = {}
+        results: JsonDict = {}
         for row in txn_results:
             # We want a single event per room, since we want to batch the
             # receipts by room, event and type.
@@ -399,7 +403,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
         """
 
         if last_id == current_id:
-            return defer.succeed([])
+            return []
 
         def _get_users_sent_receipts_between_txn(txn: LoggingTransaction) -> List[str]:
             sql = """
@@ -453,7 +457,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
             """
             txn.execute(sql, (last_id, current_id, limit))
 
-            updates = [(r[0], r[1:5] + (db_to_json(r[5]),)) for r in txn]
+            updates = cast(
+                List[Tuple[int, list]],
+                [(r[0], r[1:5] + (db_to_json(r[5]),)) for r in txn],
+            )
 
             limited = False
             upper_bound = current_id
@@ -496,7 +503,13 @@ class ReceiptsWorkerStore(SQLBaseStore):
         self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id)
         self.get_receipts_for_room.invalidate((room_id, receipt_type))
 
-    def process_replication_rows(self, stream_name, instance_name, token, rows):
+    def process_replication_rows(
+        self,
+        stream_name: str,
+        instance_name: str,
+        token: int,
+        rows: Iterable[Any],
+    ) -> None:
         if stream_name == ReceiptsStream.NAME:
             self._receipts_id_gen.advance(instance_name, token)
             for row in rows:
@@ -584,7 +597,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
         )
 
         if receipt_type == ReceiptTypes.READ and stream_ordering is not None:
-            self._remove_old_push_actions_before_txn(
+            self._remove_old_push_actions_before_txn(  # type: ignore[attr-defined]
                 txn, room_id=room_id, user_id=user_id, stream_ordering=stream_ordering
             )
 
@@ -637,7 +650,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
                 "insert_receipt_conv", graph_to_linear
             )
 
-        async with self._receipts_id_gen.get_next() as stream_id:
+        async with self._receipts_id_gen.get_next() as stream_id:  # type: ignore[attr-defined]
             event_ts = await self.db_pool.runInteraction(
                 "insert_linearized_receipt",
                 self.insert_linearized_receipt_txn,