diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index c99f8aebdb..9c5625c8bb 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -14,14 +14,25 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Collection,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Set,
+ Tuple,
+)
from twisted.internet import defer
+from synapse.api.constants import ReceiptTypes
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import ReceiptsStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.types import JsonDict
@@ -78,17 +89,13 @@ class ReceiptsWorkerStore(SQLBaseStore):
"ReceiptsRoomChangeCache", self.get_max_receipt_stream_id()
)
- def get_max_receipt_stream_id(self):
- """Get the current max stream ID for receipts stream
-
- Returns:
- int
- """
+ def get_max_receipt_stream_id(self) -> int:
+ """Get the current max stream ID for receipts stream"""
return self._receipts_id_gen.get_current_token()
@cached()
- async def get_users_with_read_receipts_in_room(self, room_id):
- receipts = await self.get_receipts_for_room(room_id, "m.read")
+ async def get_users_with_read_receipts_in_room(self, room_id: str) -> Set[str]:
+ receipts = await self.get_receipts_for_room(room_id, ReceiptTypes.READ)
return {r["user_id"] for r in receipts}
@cached(num_args=2)
@@ -119,7 +126,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
)
@cached(num_args=2)
- async def get_receipts_for_user(self, user_id, receipt_type):
+ async def get_receipts_for_user(
+ self, user_id: str, receipt_type: str
+ ) -> Dict[str, str]:
rows = await self.db_pool.simple_select_list(
table="receipts_linearized",
keyvalues={"user_id": user_id, "receipt_type": receipt_type},
@@ -129,8 +138,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
return {row["room_id"]: row["event_id"] for row in rows}
- async def get_receipts_for_user_with_orderings(self, user_id, receipt_type):
- def f(txn):
+ async def get_receipts_for_user_with_orderings(
+ self, user_id: str, receipt_type: str
+ ) -> JsonDict:
+ def f(txn: LoggingTransaction) -> List[Tuple[str, str, int, int]]:
sql = (
"SELECT rl.room_id, rl.event_id,"
" e.topological_ordering, e.stream_ordering"
@@ -209,10 +220,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
@cached(num_args=3, tree=True)
async def _get_linearized_receipts_for_room(
self, room_id: str, to_key: int, from_key: Optional[int] = None
- ) -> List[dict]:
+ ) -> List[JsonDict]:
"""See get_linearized_receipts_for_room"""
- def f(txn):
+ def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
if from_key:
sql = (
"SELECT * FROM receipts_linearized WHERE"
@@ -250,11 +261,13 @@ class ReceiptsWorkerStore(SQLBaseStore):
list_name="room_ids",
num_args=3,
)
- async def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
+ async def _get_linearized_receipts_for_rooms(
+ self, room_ids: Collection[str], to_key: int, from_key: Optional[int] = None
+ ) -> Dict[str, List[JsonDict]]:
if not room_ids:
return {}
- def f(txn):
+ def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
if from_key:
sql = """
SELECT * FROM receipts_linearized WHERE
@@ -323,7 +336,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
A dictionary of roomids to a list of receipts.
"""
- def f(txn):
+ def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
if from_key:
sql = """
SELECT * FROM receipts_linearized WHERE
@@ -379,7 +392,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
if last_id == current_id:
return defer.succeed([])
- def _get_users_sent_receipts_between_txn(txn):
+ def _get_users_sent_receipts_between_txn(txn: LoggingTransaction) -> List[str]:
sql = """
SELECT DISTINCT user_id FROM receipts_linearized
WHERE ? < stream_id AND stream_id <= ?
@@ -419,7 +432,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
if last_id == current_id:
return [], current_id, False
- def get_all_updated_receipts_txn(txn):
+ def get_all_updated_receipts_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[List[Tuple[int, list]], int, bool]:
sql = """
SELECT stream_id, room_id, receipt_type, user_id, event_id, data
FROM receipts_linearized
@@ -446,8 +461,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
def _invalidate_get_users_with_receipts_in_room(
self, room_id: str, receipt_type: str, user_id: str
- ):
- if receipt_type != "m.read":
+ ) -> None:
+ if receipt_type != ReceiptTypes.READ:
return
res = self.get_users_with_read_receipts_in_room.cache.get_immediate(
@@ -461,7 +476,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
self.get_users_with_read_receipts_in_room.invalidate((room_id,))
- def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
+ def invalidate_caches_for_receipt(
+ self, room_id: str, receipt_type: str, user_id: str
+ ) -> None:
self.get_receipts_for_user.invalidate((user_id, receipt_type))
self._get_linearized_receipts_for_room.invalidate((room_id,))
self.get_last_receipt_event_id_for_user.invalidate(
@@ -482,11 +499,18 @@ class ReceiptsWorkerStore(SQLBaseStore):
return super().process_replication_rows(stream_name, instance_name, token, rows)
def insert_linearized_receipt_txn(
- self, txn, room_id, receipt_type, user_id, event_id, data, stream_id
- ):
+ self,
+ txn: LoggingTransaction,
+ room_id: str,
+ receipt_type: str,
+ user_id: str,
+ event_id: str,
+ data: JsonDict,
+ stream_id: int,
+ ) -> Optional[int]:
"""Inserts a read-receipt into the database if it's newer than the current RR
- Returns: int|None
+ Returns:
None if the RR is older than the current RR
otherwise, the rx timestamp of the event that the RR corresponds to
(or 0 if the event is unknown)
@@ -550,7 +574,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
lock=False,
)
- if receipt_type == "m.read" and stream_ordering is not None:
+ if receipt_type == ReceiptTypes.READ and stream_ordering is not None:
self._remove_old_push_actions_before_txn(
txn, room_id=room_id, user_id=user_id, stream_ordering=stream_ordering
)
@@ -580,7 +604,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
else:
# we need to points in graph -> linearized form.
# TODO: Make this better.
- def graph_to_linear(txn):
+ def graph_to_linear(txn: LoggingTransaction) -> str:
clause, args = make_in_list_sql_clause(
self.database_engine, "event_id", event_ids
)
@@ -634,11 +658,16 @@ class ReceiptsWorkerStore(SQLBaseStore):
return stream_id, max_persisted_id
async def insert_graph_receipt(
- self, room_id, receipt_type, user_id, event_ids, data
- ):
+ self,
+ room_id: str,
+ receipt_type: str,
+ user_id: str,
+ event_ids: List[str],
+ data: JsonDict,
+ ) -> None:
assert self._can_write_to_receipts
- return await self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"insert_graph_receipt",
self.insert_graph_receipt_txn,
room_id,
@@ -649,8 +678,14 @@ class ReceiptsWorkerStore(SQLBaseStore):
)
def insert_graph_receipt_txn(
- self, txn, room_id, receipt_type, user_id, event_ids, data
- ):
+ self,
+ txn: LoggingTransaction,
+ room_id: str,
+ receipt_type: str,
+ user_id: str,
+ event_ids: List[str],
+ data: JsonDict,
+ ) -> None:
assert self._can_write_to_receipts
txn.call_after(self.get_receipts_for_room.invalidate, (room_id, receipt_type))
|