diff options
Diffstat (limited to 'synapse/storage/databases/main/receipts.py')
-rw-r--r-- | synapse/storage/databases/main/receipts.py | 37 |
1 files changed, 25 insertions, 12 deletions
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index bf0b903af2..e6f97aeece 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -24,10 +24,9 @@ from typing import ( Optional, Set, Tuple, + cast, ) -from twisted.internet import defer - from synapse.api.constants import ReceiptTypes from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import ReceiptsStream @@ -38,7 +37,11 @@ from synapse.storage.database import ( LoggingTransaction, ) from synapse.storage.engines import PostgresEngine -from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator +from synapse.storage.util.id_generators import ( + AbstractStreamIdTracker, + MultiWriterIdGenerator, + StreamIdGenerator, +) from synapse.types import JsonDict from synapse.util import json_encoder from synapse.util.caches.descriptors import cached, cachedList @@ -58,6 +61,7 @@ class ReceiptsWorkerStore(SQLBaseStore): hs: "HomeServer", ): self._instance_name = hs.get_instance_name() + self._receipts_id_gen: AbstractStreamIdTracker if isinstance(database.engine, PostgresEngine): self._can_write_to_receipts = ( @@ -161,7 +165,7 @@ class ReceiptsWorkerStore(SQLBaseStore): " AND user_id = ?" ) txn.execute(sql, (user_id,)) - return txn.fetchall() + return cast(List[Tuple[str, str, int, int]], txn.fetchall()) rows = await self.db_pool.runInteraction( "get_receipts_for_user_with_orderings", f @@ -257,7 +261,7 @@ class ReceiptsWorkerStore(SQLBaseStore): if not rows: return [] - content = {} + content: JsonDict = {} for row in rows: content.setdefault(row["event_id"], {}).setdefault(row["receipt_type"], {})[ row["user_id"] @@ -305,7 +309,7 @@ class ReceiptsWorkerStore(SQLBaseStore): "_get_linearized_receipts_for_rooms", f ) - results = {} + results: JsonDict = {} for row in txn_results: # We want a single event per room, since we want to batch the # receipts by room, event and type. @@ -370,7 +374,7 @@ class ReceiptsWorkerStore(SQLBaseStore): "get_linearized_receipts_for_all_rooms", f ) - results = {} + results: JsonDict = {} for row in txn_results: # We want a single event per room, since we want to batch the # receipts by room, event and type. @@ -399,7 +403,7 @@ class ReceiptsWorkerStore(SQLBaseStore): """ if last_id == current_id: - return defer.succeed([]) + return [] def _get_users_sent_receipts_between_txn(txn: LoggingTransaction) -> List[str]: sql = """ @@ -453,7 +457,10 @@ class ReceiptsWorkerStore(SQLBaseStore): """ txn.execute(sql, (last_id, current_id, limit)) - updates = [(r[0], r[1:5] + (db_to_json(r[5]),)) for r in txn] + updates = cast( + List[Tuple[int, list]], + [(r[0], r[1:5] + (db_to_json(r[5]),)) for r in txn], + ) limited = False upper_bound = current_id @@ -496,7 +503,13 @@ class ReceiptsWorkerStore(SQLBaseStore): self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id) self.get_receipts_for_room.invalidate((room_id, receipt_type)) - def process_replication_rows(self, stream_name, instance_name, token, rows): + def process_replication_rows( + self, + stream_name: str, + instance_name: str, + token: int, + rows: Iterable[Any], + ) -> None: if stream_name == ReceiptsStream.NAME: self._receipts_id_gen.advance(instance_name, token) for row in rows: @@ -584,7 +597,7 @@ class ReceiptsWorkerStore(SQLBaseStore): ) if receipt_type == ReceiptTypes.READ and stream_ordering is not None: - self._remove_old_push_actions_before_txn( + self._remove_old_push_actions_before_txn( # type: ignore[attr-defined] txn, room_id=room_id, user_id=user_id, stream_ordering=stream_ordering ) @@ -637,7 +650,7 @@ class ReceiptsWorkerStore(SQLBaseStore): "insert_receipt_conv", graph_to_linear ) - async with self._receipts_id_gen.get_next() as stream_id: + async with self._receipts_id_gen.get_next() as stream_id: # type: ignore[attr-defined] event_ts = await self.db_pool.runInteraction( "insert_linearized_receipt", self.insert_linearized_receipt_txn, |