diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index bf0b903af2..e6f97aeece 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -24,10 +24,9 @@ from typing import (
Optional,
Set,
Tuple,
+ cast,
)
-from twisted.internet import defer
-
from synapse.api.constants import ReceiptTypes
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import ReceiptsStream
@@ -38,7 +37,11 @@ from synapse.storage.database import (
LoggingTransaction,
)
from synapse.storage.engines import PostgresEngine
-from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
+from synapse.storage.util.id_generators import (
+ AbstractStreamIdTracker,
+ MultiWriterIdGenerator,
+ StreamIdGenerator,
+)
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
@@ -58,6 +61,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
hs: "HomeServer",
):
self._instance_name = hs.get_instance_name()
+ self._receipts_id_gen: AbstractStreamIdTracker
if isinstance(database.engine, PostgresEngine):
self._can_write_to_receipts = (
@@ -161,7 +165,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
" AND user_id = ?"
)
txn.execute(sql, (user_id,))
- return txn.fetchall()
+ return cast(List[Tuple[str, str, int, int]], txn.fetchall())
rows = await self.db_pool.runInteraction(
"get_receipts_for_user_with_orderings", f
@@ -257,7 +261,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
if not rows:
return []
- content = {}
+ content: JsonDict = {}
for row in rows:
content.setdefault(row["event_id"], {}).setdefault(row["receipt_type"], {})[
row["user_id"]
@@ -305,7 +309,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
"_get_linearized_receipts_for_rooms", f
)
- results = {}
+ results: JsonDict = {}
for row in txn_results:
# We want a single event per room, since we want to batch the
# receipts by room, event and type.
@@ -370,7 +374,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
"get_linearized_receipts_for_all_rooms", f
)
- results = {}
+ results: JsonDict = {}
for row in txn_results:
# We want a single event per room, since we want to batch the
# receipts by room, event and type.
@@ -399,7 +403,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
"""
if last_id == current_id:
- return defer.succeed([])
+ return []
def _get_users_sent_receipts_between_txn(txn: LoggingTransaction) -> List[str]:
sql = """
@@ -453,7 +457,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
"""
txn.execute(sql, (last_id, current_id, limit))
- updates = [(r[0], r[1:5] + (db_to_json(r[5]),)) for r in txn]
+ updates = cast(
+ List[Tuple[int, list]],
+ [(r[0], r[1:5] + (db_to_json(r[5]),)) for r in txn],
+ )
limited = False
upper_bound = current_id
@@ -496,7 +503,13 @@ class ReceiptsWorkerStore(SQLBaseStore):
self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id)
self.get_receipts_for_room.invalidate((room_id, receipt_type))
- def process_replication_rows(self, stream_name, instance_name, token, rows):
+ def process_replication_rows(
+ self,
+ stream_name: str,
+ instance_name: str,
+ token: int,
+ rows: Iterable[Any],
+ ) -> None:
if stream_name == ReceiptsStream.NAME:
self._receipts_id_gen.advance(instance_name, token)
for row in rows:
@@ -584,7 +597,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
)
if receipt_type == ReceiptTypes.READ and stream_ordering is not None:
- self._remove_old_push_actions_before_txn(
+ self._remove_old_push_actions_before_txn( # type: ignore[attr-defined]
txn, room_id=room_id, user_id=user_id, stream_ordering=stream_ordering
)
@@ -637,7 +650,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
"insert_receipt_conv", graph_to_linear
)
- async with self._receipts_id_gen.get_next() as stream_id:
+ async with self._receipts_id_gen.get_next() as stream_id: # type: ignore[attr-defined]
event_ts = await self.db_pool.runInteraction(
"insert_linearized_receipt",
self.insert_linearized_receipt_txn,
|