diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 1920a8a152..19ad1c056f 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -16,7 +16,7 @@
import abc
import logging
-from typing import List, Tuple
+from typing import List, Optional, Tuple
from twisted.internet import defer
@@ -25,7 +25,7 @@ from synapse.storage.database import DatabasePool
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.util import json_encoder
from synapse.util.async_helpers import ObservableDeferred
-from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
+from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
logger = logging.getLogger(__name__)
@@ -56,9 +56,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
"""
raise NotImplementedError()
- @cachedInlineCallbacks()
- def get_users_with_read_receipts_in_room(self, room_id):
- receipts = yield self.get_receipts_for_room(room_id, "m.read")
+ @cached()
+ async def get_users_with_read_receipts_in_room(self, room_id):
+ receipts = await self.get_receipts_for_room(room_id, "m.read")
return {r["user_id"] for r in receipts}
@cached(num_args=2)
@@ -84,9 +84,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
allow_none=True,
)
- @cachedInlineCallbacks(num_args=2)
- def get_receipts_for_user(self, user_id, receipt_type):
- rows = yield self.db_pool.simple_select_list(
+ @cached(num_args=2)
+ async def get_receipts_for_user(self, user_id, receipt_type):
+ rows = await self.db_pool.simple_select_list(
table="receipts_linearized",
keyvalues={"user_id": user_id, "receipt_type": receipt_type},
retcols=("room_id", "event_id"),
@@ -95,8 +95,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
return {row["room_id"]: row["event_id"] for row in rows}
- @defer.inlineCallbacks
- def get_receipts_for_user_with_orderings(self, user_id, receipt_type):
+ async def get_receipts_for_user_with_orderings(self, user_id, receipt_type):
def f(txn):
sql = (
"SELECT rl.room_id, rl.event_id,"
@@ -110,7 +109,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
txn.execute(sql, (user_id,))
return txn.fetchall()
- rows = yield self.db_pool.runInteraction(
+ rows = await self.db_pool.runInteraction(
"get_receipts_for_user_with_orderings", f
)
return {
@@ -122,56 +121,61 @@ class ReceiptsWorkerStore(SQLBaseStore):
for row in rows
}
- @defer.inlineCallbacks
- def get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
+ async def get_linearized_receipts_for_rooms(
+ self, room_ids: List[str], to_key: int, from_key: Optional[int] = None
+ ) -> List[dict]:
"""Get receipts for multiple rooms for sending to clients.
Args:
- room_ids (list): List of room_ids.
- to_key (int): Max stream id to fetch receipts upto.
- from_key (int): Min stream id to fetch receipts from. None fetches
+ room_id: List of room_ids.
+ to_key: Max stream id to fetch receipts upto.
+ from_key: Min stream id to fetch receipts from. None fetches
from the start.
Returns:
- list: A list of receipts.
+ A list of receipts.
"""
room_ids = set(room_ids)
if from_key is not None:
# Only ask the database about rooms where there have been new
# receipts added since `from_key`
- room_ids = yield self._receipts_stream_cache.get_entities_changed(
+ room_ids = self._receipts_stream_cache.get_entities_changed(
room_ids, from_key
)
- results = yield self._get_linearized_receipts_for_rooms(
+ results = await self._get_linearized_receipts_for_rooms(
room_ids, to_key, from_key=from_key
)
return [ev for res in results.values() for ev in res]
- def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
+ async def get_linearized_receipts_for_room(
+ self, room_id: str, to_key: int, from_key: Optional[int] = None
+ ) -> List[dict]:
"""Get receipts for a single room for sending to clients.
Args:
- room_ids (str): The room id.
- to_key (int): Max stream id to fetch receipts upto.
- from_key (int): Min stream id to fetch receipts from. None fetches
+ room_ids: The room id.
+ to_key: Max stream id to fetch receipts upto.
+ from_key: Min stream id to fetch receipts from. None fetches
from the start.
Returns:
- Deferred[list]: A list of receipts.
+ A list of receipts.
"""
if from_key is not None:
# Check the cache first to see if any new receipts have been added
# since`from_key`. If not we can no-op.
if not self._receipts_stream_cache.has_entity_changed(room_id, from_key):
- defer.succeed([])
+ return []
- return self._get_linearized_receipts_for_room(room_id, to_key, from_key)
+ return await self._get_linearized_receipts_for_room(room_id, to_key, from_key)
- @cachedInlineCallbacks(num_args=3, tree=True)
- def _get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
+ @cached(num_args=3, tree=True)
+ async def _get_linearized_receipts_for_room(
+ self, room_id: str, to_key: int, from_key: Optional[int] = None
+ ) -> List[dict]:
"""See get_linearized_receipts_for_room
"""
@@ -195,7 +199,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
return rows
- rows = yield self.db_pool.runInteraction("get_linearized_receipts_for_room", f)
+ rows = await self.db_pool.runInteraction("get_linearized_receipts_for_room", f)
if not rows:
return []
@@ -212,9 +216,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
cached_method_name="_get_linearized_receipts_for_room",
list_name="room_ids",
num_args=3,
- inlineCallbacks=True,
)
- def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
+ async def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
if not room_ids:
return {}
@@ -243,7 +246,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
return self.db_pool.cursor_to_dict(txn)
- txn_results = yield self.db_pool.runInteraction(
+ txn_results = await self.db_pool.runInteraction(
"_get_linearized_receipts_for_rooms", f
)
@@ -346,7 +349,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
)
def _invalidate_get_users_with_receipts_in_room(
- self, room_id, receipt_type, user_id
+ self, room_id: str, receipt_type: str, user_id: str
):
if receipt_type != "m.read":
return
@@ -472,15 +475,21 @@ class ReceiptsStore(ReceiptsWorkerStore):
return rx_ts
- @defer.inlineCallbacks
- def insert_receipt(self, room_id, receipt_type, user_id, event_ids, data):
+ async def insert_receipt(
+ self,
+ room_id: str,
+ receipt_type: str,
+ user_id: str,
+ event_ids: List[str],
+ data: dict,
+ ) -> Optional[Tuple[int, int]]:
"""Insert a receipt, either from local client or remote server.
Automatically does conversion between linearized and graph
representations.
"""
if not event_ids:
- return
+ return None
if len(event_ids) == 1:
linearized_event_id = event_ids[0]
@@ -507,13 +516,13 @@ class ReceiptsStore(ReceiptsWorkerStore):
else:
raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,))
- linearized_event_id = yield self.db_pool.runInteraction(
+ linearized_event_id = await self.db_pool.runInteraction(
"insert_receipt_conv", graph_to_linear
)
stream_id_manager = self._receipts_id_gen.get_next()
with stream_id_manager as stream_id:
- event_ts = yield self.db_pool.runInteraction(
+ event_ts = await self.db_pool.runInteraction(
"insert_linearized_receipt",
self.insert_linearized_receipt_txn,
room_id,
@@ -535,7 +544,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
now - event_ts,
)
- yield self.insert_graph_receipt(room_id, receipt_type, user_id, event_ids, data)
+ await self.insert_graph_receipt(room_id, receipt_type, user_id, event_ids, data)
max_persisted_id = self._receipts_id_gen.get_current_token()
|