summary refs log tree commit diff
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2020-08-14 10:05:19 -0400
committerGitHub <noreply@github.com>2020-08-14 10:05:19 -0400
commite8861957d9005a9f6cad050a55a478b7706f34c9 (patch)
tree685ce94358811e1735cd32b9078fb019c3cca565
parentAdd type hints to synapse.handlers.room (#8090) (diff)
downloadsynapse-e8861957d9005a9f6cad050a55a478b7706f34c9.tar.xz
Convert receipts and events databases to async/await. (#8076)
-rw-r--r--changelog.d/8076.misc1
-rw-r--r--synapse/storage/databases/main/events.py33
-rw-r--r--synapse/storage/databases/main/events_bg_updates.py46
-rw-r--r--synapse/storage/databases/main/receipts.py82
4 files changed, 80 insertions, 82 deletions
diff --git a/changelog.d/8076.misc b/changelog.d/8076.misc
new file mode 100644
index 0000000000..dfe4c03171
--- /dev/null
+++ b/changelog.d/8076.misc
@@ -0,0 +1 @@
+Convert various parts of the codebase to async/await.
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 1a68bf32cb..b90e6de2d5 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -17,13 +17,11 @@
 import itertools
 import logging
 from collections import OrderedDict, namedtuple
-from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, List, Set, Tuple
 
 import attr
 from prometheus_client import Counter
 
-from twisted.internet import defer
-
 import synapse.metrics
 from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
 from synapse.api.room_versions import RoomVersions
@@ -113,15 +111,14 @@ class PersistEventsStore:
             hs.config.worker.writers.events == hs.get_instance_name()
         ), "Can only instantiate EventsStore on master"
 
-    @defer.inlineCallbacks
-    def _persist_events_and_state_updates(
+    async def _persist_events_and_state_updates(
         self,
         events_and_contexts: List[Tuple[EventBase, EventContext]],
         current_state_for_room: Dict[str, StateMap[str]],
         state_delta_for_room: Dict[str, DeltaState],
         new_forward_extremeties: Dict[str, List[str]],
         backfilled: bool = False,
-    ):
+    ) -> None:
         """Persist a set of events alongside updates to the current state and
         forward extremities tables.
 
@@ -136,7 +133,7 @@ class PersistEventsStore:
             backfilled
 
         Returns:
-            Deferred: resolves when the events have been persisted
+            Resolves when the events have been persisted
         """
 
         # We want to calculate the stream orderings as late as possible, as
@@ -168,7 +165,7 @@ class PersistEventsStore:
             for (event, context), stream in zip(events_and_contexts, stream_orderings):
                 event.internal_metadata.stream_ordering = stream
 
-            yield self.db_pool.runInteraction(
+            await self.db_pool.runInteraction(
                 "persist_events",
                 self._persist_events_txn,
                 events_and_contexts=events_and_contexts,
@@ -206,16 +203,15 @@ class PersistEventsStore:
                     (room_id,), list(latest_event_ids)
                 )
 
-    @defer.inlineCallbacks
-    def _get_events_which_are_prevs(self, event_ids):
+    async def _get_events_which_are_prevs(self, event_ids: Iterable[str]) -> List[str]:
         """Filter the supplied list of event_ids to get those which are prev_events of
         existing (non-outlier/rejected) events.
 
         Args:
-            event_ids (Iterable[str]): event ids to filter
+            event_ids: event ids to filter
 
         Returns:
-            Deferred[List[str]]: filtered event ids
+            Filtered event ids
         """
         results = []
 
@@ -240,14 +236,13 @@ class PersistEventsStore:
             results.extend(r[0] for r in txn if not db_to_json(r[1]).get("soft_failed"))
 
         for chunk in batch_iter(event_ids, 100):
-            yield self.db_pool.runInteraction(
+            await self.db_pool.runInteraction(
                 "_get_events_which_are_prevs", _get_events_which_are_prevs_txn, chunk
             )
 
         return results
 
-    @defer.inlineCallbacks
-    def _get_prevs_before_rejected(self, event_ids):
+    async def _get_prevs_before_rejected(self, event_ids: Iterable[str]) -> Set[str]:
         """Get soft-failed ancestors to remove from the extremities.
 
         Given a set of events, find all those that have been soft-failed or
@@ -259,11 +254,11 @@ class PersistEventsStore:
         are separated by soft failed events.
 
         Args:
-            event_ids (Iterable[str]): Events to find prev events for. Note
-                that these must have already been persisted.
+            event_ids: Events to find prev events for. Note that these must have
+                already been persisted.
 
         Returns:
-            Deferred[set[str]]
+            The previous events.
         """
 
         # The set of event_ids to return. This includes all soft-failed events
@@ -304,7 +299,7 @@ class PersistEventsStore:
                         existing_prevs.add(prev_event_id)
 
         for chunk in batch_iter(event_ids, 100):
-            yield self.db_pool.runInteraction(
+            await self.db_pool.runInteraction(
                 "_get_prevs_before_rejected", _get_prevs_before_rejected_txn, chunk
             )
 
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index 35a0e09e3c..e53c6373a8 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -15,8 +15,6 @@
 
 import logging
 
-from twisted.internet import defer
-
 from synapse.api.constants import EventContentFields
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
 from synapse.storage.database import DatabasePool
@@ -94,8 +92,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
             where_clause="NOT have_censored",
         )
 
-    @defer.inlineCallbacks
-    def _background_reindex_fields_sender(self, progress, batch_size):
+    async def _background_reindex_fields_sender(self, progress, batch_size):
         target_min_stream_id = progress["target_min_stream_id_inclusive"]
         max_stream_id = progress["max_stream_id_exclusive"]
         rows_inserted = progress.get("rows_inserted", 0)
@@ -155,19 +152,18 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
 
             return len(rows)
 
-        result = yield self.db_pool.runInteraction(
+        result = await self.db_pool.runInteraction(
             self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, reindex_txn
         )
 
         if not result:
-            yield self.db_pool.updates._end_background_update(
+            await self.db_pool.updates._end_background_update(
                 self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME
             )
 
         return result
 
-    @defer.inlineCallbacks
-    def _background_reindex_origin_server_ts(self, progress, batch_size):
+    async def _background_reindex_origin_server_ts(self, progress, batch_size):
         target_min_stream_id = progress["target_min_stream_id_inclusive"]
         max_stream_id = progress["max_stream_id_exclusive"]
         rows_inserted = progress.get("rows_inserted", 0)
@@ -234,19 +230,18 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
 
             return len(rows_to_update)
 
-        result = yield self.db_pool.runInteraction(
+        result = await self.db_pool.runInteraction(
             self.EVENT_ORIGIN_SERVER_TS_NAME, reindex_search_txn
         )
 
         if not result:
-            yield self.db_pool.updates._end_background_update(
+            await self.db_pool.updates._end_background_update(
                 self.EVENT_ORIGIN_SERVER_TS_NAME
             )
 
         return result
 
-    @defer.inlineCallbacks
-    def _cleanup_extremities_bg_update(self, progress, batch_size):
+    async def _cleanup_extremities_bg_update(self, progress, batch_size):
         """Background update to clean out extremities that should have been
         deleted previously.
 
@@ -414,26 +409,25 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
 
             return len(original_set)
 
-        num_handled = yield self.db_pool.runInteraction(
+        num_handled = await self.db_pool.runInteraction(
             "_cleanup_extremities_bg_update", _cleanup_extremities_bg_update_txn
         )
 
         if not num_handled:
-            yield self.db_pool.updates._end_background_update(
+            await self.db_pool.updates._end_background_update(
                 self.DELETE_SOFT_FAILED_EXTREMITIES
             )
 
             def _drop_table_txn(txn):
                 txn.execute("DROP TABLE _extremities_to_check")
 
-            yield self.db_pool.runInteraction(
+            await self.db_pool.runInteraction(
                 "_cleanup_extremities_bg_update_drop_table", _drop_table_txn
             )
 
         return num_handled
 
-    @defer.inlineCallbacks
-    def _redactions_received_ts(self, progress, batch_size):
+    async def _redactions_received_ts(self, progress, batch_size):
         """Handles filling out the `received_ts` column in redactions.
         """
         last_event_id = progress.get("last_event_id", "")
@@ -480,17 +474,16 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
 
             return len(rows)
 
-        count = yield self.db_pool.runInteraction(
+        count = await self.db_pool.runInteraction(
             "_redactions_received_ts", _redactions_received_ts_txn
         )
 
         if not count:
-            yield self.db_pool.updates._end_background_update("redactions_received_ts")
+            await self.db_pool.updates._end_background_update("redactions_received_ts")
 
         return count
 
-    @defer.inlineCallbacks
-    def _event_fix_redactions_bytes(self, progress, batch_size):
+    async def _event_fix_redactions_bytes(self, progress, batch_size):
         """Undoes hex encoded censored redacted event JSON.
         """
 
@@ -511,16 +504,15 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
 
             txn.execute("DROP INDEX redactions_censored_redacts")
 
-        yield self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "_event_fix_redactions_bytes", _event_fix_redactions_bytes_txn
         )
 
-        yield self.db_pool.updates._end_background_update("event_fix_redactions_bytes")
+        await self.db_pool.updates._end_background_update("event_fix_redactions_bytes")
 
         return 1
 
-    @defer.inlineCallbacks
-    def _event_store_labels(self, progress, batch_size):
+    async def _event_store_labels(self, progress, batch_size):
         """Background update handler which will store labels for existing events."""
         last_event_id = progress.get("last_event_id", "")
 
@@ -575,11 +567,11 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
 
             return nbrows
 
-        num_rows = yield self.db_pool.runInteraction(
+        num_rows = await self.db_pool.runInteraction(
             desc="event_store_labels", func=_event_store_labels_txn
         )
 
         if not num_rows:
-            yield self.db_pool.updates._end_background_update("event_store_labels")
+            await self.db_pool.updates._end_background_update("event_store_labels")
 
         return num_rows
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 579b7bb17b..19ad1c056f 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -16,7 +16,7 @@
 
 import abc
 import logging
-from typing import List, Tuple
+from typing import List, Optional, Tuple
 
 from twisted.internet import defer
 
@@ -25,7 +25,7 @@ from synapse.storage.database import DatabasePool
 from synapse.storage.util.id_generators import StreamIdGenerator
 from synapse.util import json_encoder
 from synapse.util.async_helpers import ObservableDeferred
-from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
+from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 logger = logging.getLogger(__name__)
@@ -56,9 +56,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
         """
         raise NotImplementedError()
 
-    @cachedInlineCallbacks()
-    def get_users_with_read_receipts_in_room(self, room_id):
-        receipts = yield self.get_receipts_for_room(room_id, "m.read")
+    @cached()
+    async def get_users_with_read_receipts_in_room(self, room_id):
+        receipts = await self.get_receipts_for_room(room_id, "m.read")
         return {r["user_id"] for r in receipts}
 
     @cached(num_args=2)
@@ -84,9 +84,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
             allow_none=True,
         )
 
-    @cachedInlineCallbacks(num_args=2)
-    def get_receipts_for_user(self, user_id, receipt_type):
-        rows = yield self.db_pool.simple_select_list(
+    @cached(num_args=2)
+    async def get_receipts_for_user(self, user_id, receipt_type):
+        rows = await self.db_pool.simple_select_list(
             table="receipts_linearized",
             keyvalues={"user_id": user_id, "receipt_type": receipt_type},
             retcols=("room_id", "event_id"),
@@ -95,8 +95,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
 
         return {row["room_id"]: row["event_id"] for row in rows}
 
-    @defer.inlineCallbacks
-    def get_receipts_for_user_with_orderings(self, user_id, receipt_type):
+    async def get_receipts_for_user_with_orderings(self, user_id, receipt_type):
         def f(txn):
             sql = (
                 "SELECT rl.room_id, rl.event_id,"
@@ -110,7 +109,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
             txn.execute(sql, (user_id,))
             return txn.fetchall()
 
-        rows = yield self.db_pool.runInteraction(
+        rows = await self.db_pool.runInteraction(
             "get_receipts_for_user_with_orderings", f
         )
         return {
@@ -122,56 +121,61 @@ class ReceiptsWorkerStore(SQLBaseStore):
             for row in rows
         }
 
-    @defer.inlineCallbacks
-    def get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
+    async def get_linearized_receipts_for_rooms(
+        self, room_ids: List[str], to_key: int, from_key: Optional[int] = None
+    ) -> List[dict]:
         """Get receipts for multiple rooms for sending to clients.
 
         Args:
-            room_ids (list): List of room_ids.
-            to_key (int): Max stream id to fetch receipts upto.
-            from_key (int): Min stream id to fetch receipts from. None fetches
+            room_id: List of room_ids.
+            to_key: Max stream id to fetch receipts upto.
+            from_key: Min stream id to fetch receipts from. None fetches
                 from the start.
 
         Returns:
-            list: A list of receipts.
+            A list of receipts.
         """
         room_ids = set(room_ids)
 
         if from_key is not None:
             # Only ask the database about rooms where there have been new
             # receipts added since `from_key`
-            room_ids = yield self._receipts_stream_cache.get_entities_changed(
+            room_ids = self._receipts_stream_cache.get_entities_changed(
                 room_ids, from_key
             )
 
-        results = yield self._get_linearized_receipts_for_rooms(
+        results = await self._get_linearized_receipts_for_rooms(
             room_ids, to_key, from_key=from_key
         )
 
         return [ev for res in results.values() for ev in res]
 
-    def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
+    async def get_linearized_receipts_for_room(
+        self, room_id: str, to_key: int, from_key: Optional[int] = None
+    ) -> List[dict]:
         """Get receipts for a single room for sending to clients.
 
         Args:
-            room_ids (str): The room id.
-            to_key (int): Max stream id to fetch receipts upto.
-            from_key (int): Min stream id to fetch receipts from. None fetches
+            room_ids: The room id.
+            to_key: Max stream id to fetch receipts upto.
+            from_key: Min stream id to fetch receipts from. None fetches
                 from the start.
 
         Returns:
-            Deferred[list]: A list of receipts.
+            A list of receipts.
         """
         if from_key is not None:
             # Check the cache first to see if any new receipts have been added
             # since`from_key`. If not we can no-op.
             if not self._receipts_stream_cache.has_entity_changed(room_id, from_key):
-                defer.succeed([])
+                return []
 
-        return self._get_linearized_receipts_for_room(room_id, to_key, from_key)
+        return await self._get_linearized_receipts_for_room(room_id, to_key, from_key)
 
-    @cachedInlineCallbacks(num_args=3, tree=True)
-    def _get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
+    @cached(num_args=3, tree=True)
+    async def _get_linearized_receipts_for_room(
+        self, room_id: str, to_key: int, from_key: Optional[int] = None
+    ) -> List[dict]:
         """See get_linearized_receipts_for_room
         """
 
@@ -195,7 +199,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
 
             return rows
 
-        rows = yield self.db_pool.runInteraction("get_linearized_receipts_for_room", f)
+        rows = await self.db_pool.runInteraction("get_linearized_receipts_for_room", f)
 
         if not rows:
             return []
@@ -345,7 +349,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
         )
 
     def _invalidate_get_users_with_receipts_in_room(
-        self, room_id, receipt_type, user_id
+        self, room_id: str, receipt_type: str, user_id: str
     ):
         if receipt_type != "m.read":
             return
@@ -471,15 +475,21 @@ class ReceiptsStore(ReceiptsWorkerStore):
 
         return rx_ts
 
-    @defer.inlineCallbacks
-    def insert_receipt(self, room_id, receipt_type, user_id, event_ids, data):
+    async def insert_receipt(
+        self,
+        room_id: str,
+        receipt_type: str,
+        user_id: str,
+        event_ids: List[str],
+        data: dict,
+    ) -> Optional[Tuple[int, int]]:
         """Insert a receipt, either from local client or remote server.
 
         Automatically does conversion between linearized and graph
         representations.
         """
         if not event_ids:
-            return
+            return None
 
         if len(event_ids) == 1:
             linearized_event_id = event_ids[0]
@@ -506,13 +516,13 @@ class ReceiptsStore(ReceiptsWorkerStore):
                 else:
                     raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,))
 
-            linearized_event_id = yield self.db_pool.runInteraction(
+            linearized_event_id = await self.db_pool.runInteraction(
                 "insert_receipt_conv", graph_to_linear
             )
 
         stream_id_manager = self._receipts_id_gen.get_next()
         with stream_id_manager as stream_id:
-            event_ts = yield self.db_pool.runInteraction(
+            event_ts = await self.db_pool.runInteraction(
                 "insert_linearized_receipt",
                 self.insert_linearized_receipt_txn,
                 room_id,
@@ -534,7 +544,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
             now - event_ts,
         )
 
-        yield self.insert_graph_receipt(room_id, receipt_type, user_id, event_ids, data)
+        await self.insert_graph_receipt(room_id, receipt_type, user_id, event_ids, data)
 
         max_persisted_id = self._receipts_id_gen.get_current_token()