summary refs log tree commit diff
path: root/synapse/storage/databases
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases')
-rw-r--r--synapse/storage/databases/main/event_federation.py30
-rw-r--r--synapse/storage/databases/main/events_worker.py132
-rw-r--r--synapse/storage/databases/main/stream.py1
3 files changed, 88 insertions, 75 deletions
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 431bd76693..4826be630c 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -30,7 +30,7 @@ logger = logging.getLogger(__name__)
 
 
 class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
-    def get_auth_chain(self, event_ids, include_given=False):
+    async def get_auth_chain(self, event_ids, include_given=False):
         """Get auth events for given event_ids. The events *must* be state events.
 
         Args:
@@ -40,9 +40,10 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
         Returns:
             list of events
         """
-        return self.get_auth_chain_ids(
+        event_ids = await self.get_auth_chain_ids(
             event_ids, include_given=include_given
-        ).addCallback(self.get_events_as_list)
+        )
+        return await self.get_events_as_list(event_ids)
 
     def get_auth_chain_ids(
         self,
@@ -459,7 +460,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
             "get_forward_extremeties_for_room", get_forward_extremeties_for_room_txn
         )
 
-    def get_backfill_events(self, room_id, event_list, limit):
+    async def get_backfill_events(self, room_id, event_list, limit):
         """Get a list of Events for a given topic that occurred before (and
         including) the events in event_list. Return a list of max size `limit`
 
@@ -469,17 +470,15 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
             event_list (list)
             limit (int)
         """
-        return (
-            self.db_pool.runInteraction(
-                "get_backfill_events",
-                self._get_backfill_events,
-                room_id,
-                event_list,
-                limit,
-            )
-            .addCallback(self.get_events_as_list)
-            .addCallback(lambda l: sorted(l, key=lambda e: -e.depth))
+        event_ids = await self.db_pool.runInteraction(
+            "get_backfill_events",
+            self._get_backfill_events,
+            room_id,
+            event_list,
+            limit,
         )
+        events = await self.get_events_as_list(event_ids)
+        return sorted(events, key=lambda e: -e.depth)
 
     def _get_backfill_events(self, txn, room_id, event_list, limit):
         logger.debug("_get_backfill_events: %s, %r, %s", room_id, event_list, limit)
@@ -540,8 +539,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
             latest_events,
             limit,
         )
-        events = await self.get_events_as_list(ids)
-        return events
+        return await self.get_events_as_list(ids)
 
     def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limit):
 
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 8c63a0dc4d..e3a154a527 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -19,9 +19,10 @@ import itertools
 import logging
 import threading
 from collections import namedtuple
-from typing import List, Optional, Tuple
+from typing import Dict, Iterable, List, Optional, Tuple, overload
 
 from constantly import NamedConstant, Names
+from typing_extensions import Literal
 
 from twisted.internet import defer
 
@@ -32,7 +33,7 @@ from synapse.api.room_versions import (
     EventFormatVersions,
     RoomVersions,
 )
-from synapse.events import make_event_from_dict
+from synapse.events import EventBase, make_event_from_dict
 from synapse.events.utils import prune_event
 from synapse.logging.context import PreserveLoggingContext, current_context
 from synapse.metrics.background_process_metrics import run_as_background_process
@@ -42,8 +43,8 @@ from synapse.replication.tcp.streams.events import EventsStream
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
 from synapse.storage.database import DatabasePool
 from synapse.storage.util.id_generators import StreamIdGenerator
-from synapse.types import get_domain_from_id
-from synapse.util.caches.descriptors import Cache, cachedInlineCallbacks
+from synapse.types import Collection, get_domain_from_id
+from synapse.util.caches.descriptors import Cache, cached
 from synapse.util.iterutils import batch_iter
 from synapse.util.metrics import Measure
 
@@ -137,8 +138,33 @@ class EventsWorkerStore(SQLBaseStore):
             desc="get_received_ts",
         )
 
-    @defer.inlineCallbacks
-    def get_event(
+    # Inform mypy that if allow_none is False (the default) then get_event
+    # always returns an EventBase.
+    @overload
+    async def get_event(
+        self,
+        event_id: str,
+        redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
+        get_prev_content: bool = False,
+        allow_rejected: bool = False,
+        allow_none: Literal[False] = False,
+        check_room_id: Optional[str] = None,
+    ) -> EventBase:
+        ...
+
+    @overload
+    async def get_event(
+        self,
+        event_id: str,
+        redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
+        get_prev_content: bool = False,
+        allow_rejected: bool = False,
+        allow_none: Literal[True] = False,
+        check_room_id: Optional[str] = None,
+    ) -> Optional[EventBase]:
+        ...
+
+    async def get_event(
         self,
         event_id: str,
         redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
@@ -146,7 +172,7 @@ class EventsWorkerStore(SQLBaseStore):
         allow_rejected: bool = False,
         allow_none: bool = False,
         check_room_id: Optional[str] = None,
-    ):
+    ) -> Optional[EventBase]:
         """Get an event from the database by event_id.
 
         Args:
@@ -171,12 +197,12 @@ class EventsWorkerStore(SQLBaseStore):
                 If there is a mismatch, behave as per allow_none.
 
         Returns:
-            Deferred[EventBase|None]
+            The event, or None if the event was not found.
         """
         if not isinstance(event_id, str):
             raise TypeError("Invalid event event_id %r" % (event_id,))
 
-        events = yield self.get_events_as_list(
+        events = await self.get_events_as_list(
             [event_id],
             redact_behaviour=redact_behaviour,
             get_prev_content=get_prev_content,
@@ -194,14 +220,13 @@ class EventsWorkerStore(SQLBaseStore):
 
         return event
 
-    @defer.inlineCallbacks
-    def get_events(
+    async def get_events(
         self,
-        event_ids: List[str],
+        event_ids: Iterable[str],
         redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
         get_prev_content: bool = False,
         allow_rejected: bool = False,
-    ):
+    ) -> Dict[str, EventBase]:
         """Get events from the database
 
         Args:
@@ -220,9 +245,9 @@ class EventsWorkerStore(SQLBaseStore):
                 omits rejeted events from the response.
 
         Returns:
-            Deferred : Dict from event_id to event.
+            A mapping from event_id to event.
         """
-        events = yield self.get_events_as_list(
+        events = await self.get_events_as_list(
             event_ids,
             redact_behaviour=redact_behaviour,
             get_prev_content=get_prev_content,
@@ -231,14 +256,13 @@ class EventsWorkerStore(SQLBaseStore):
 
         return {e.event_id: e for e in events}
 
-    @defer.inlineCallbacks
-    def get_events_as_list(
+    async def get_events_as_list(
         self,
-        event_ids: List[str],
+        event_ids: Collection[str],
         redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
         get_prev_content: bool = False,
         allow_rejected: bool = False,
-    ):
+    ) -> List[EventBase]:
         """Get events from the database and return in a list in the same order
         as given by `event_ids` arg.
 
@@ -259,8 +283,8 @@ class EventsWorkerStore(SQLBaseStore):
                 omits rejected events from the response.
 
         Returns:
-            Deferred[list[EventBase]]: List of events fetched from the database. The
-            events are in the same order as `event_ids` arg.
+            List of events fetched from the database. The events are in the same
+            order as `event_ids` arg.
 
             Note that the returned list may be smaller than the list of event
             IDs if not all events could be fetched.
@@ -270,7 +294,7 @@ class EventsWorkerStore(SQLBaseStore):
             return []
 
         # there may be duplicates so we cast the list to a set
-        event_entry_map = yield self._get_events_from_cache_or_db(
+        event_entry_map = await self._get_events_from_cache_or_db(
             set(event_ids), allow_rejected=allow_rejected
         )
 
@@ -305,7 +329,7 @@ class EventsWorkerStore(SQLBaseStore):
                     continue
 
                 redacted_event_id = entry.event.redacts
-                event_map = yield self._get_events_from_cache_or_db([redacted_event_id])
+                event_map = await self._get_events_from_cache_or_db([redacted_event_id])
                 original_event_entry = event_map.get(redacted_event_id)
                 if not original_event_entry:
                     # we don't have the redacted event (or it was rejected).
@@ -371,7 +395,7 @@ class EventsWorkerStore(SQLBaseStore):
 
             if get_prev_content:
                 if "replaces_state" in event.unsigned:
-                    prev = yield self.get_event(
+                    prev = await self.get_event(
                         event.unsigned["replaces_state"],
                         get_prev_content=False,
                         allow_none=True,
@@ -383,8 +407,7 @@ class EventsWorkerStore(SQLBaseStore):
 
         return events
 
-    @defer.inlineCallbacks
-    def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False):
+    async def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False):
         """Fetch a bunch of events from the cache or the database.
 
         If events are pulled from the database, they will be cached for future lookups.
@@ -399,7 +422,7 @@ class EventsWorkerStore(SQLBaseStore):
                 rejected events are omitted from the response.
 
         Returns:
-            Deferred[Dict[str, _EventCacheEntry]]:
+            Dict[str, _EventCacheEntry]:
                 map from event id to result
         """
         event_entry_map = self._get_events_from_cache(
@@ -417,7 +440,7 @@ class EventsWorkerStore(SQLBaseStore):
             # the events have been redacted, and if so pulling the redaction event out
             # of the database to check it.
             #
-            missing_events = yield self._get_events_from_db(
+            missing_events = await self._get_events_from_db(
                 missing_events_ids, allow_rejected=allow_rejected
             )
 
@@ -525,8 +548,7 @@ class EventsWorkerStore(SQLBaseStore):
                 with PreserveLoggingContext():
                     self.hs.get_reactor().callFromThread(fire, event_list, e)
 
-    @defer.inlineCallbacks
-    def _get_events_from_db(self, event_ids, allow_rejected=False):
+    async def _get_events_from_db(self, event_ids, allow_rejected=False):
         """Fetch a bunch of events from the database.
 
         Returned events will be added to the cache for future lookups.
@@ -540,7 +562,7 @@ class EventsWorkerStore(SQLBaseStore):
                 rejected events are omitted from the response.
 
         Returns:
-            Deferred[Dict[str, _EventCacheEntry]]:
+            Dict[str, _EventCacheEntry]:
                 map from event id to result. May return extra events which
                 weren't asked for.
         """
@@ -548,7 +570,7 @@ class EventsWorkerStore(SQLBaseStore):
         events_to_fetch = event_ids
 
         while events_to_fetch:
-            row_map = yield self._enqueue_events(events_to_fetch)
+            row_map = await self._enqueue_events(events_to_fetch)
 
             # we need to recursively fetch any redactions of those events
             redaction_ids = set()
@@ -650,8 +672,7 @@ class EventsWorkerStore(SQLBaseStore):
 
         return result_map
 
-    @defer.inlineCallbacks
-    def _enqueue_events(self, events):
+    async def _enqueue_events(self, events):
         """Fetches events from the database using the _event_fetch_list. This
         allows batch and bulk fetching of events - it allows us to fetch events
         without having to create a new transaction for each request for events.
@@ -660,7 +681,7 @@ class EventsWorkerStore(SQLBaseStore):
             events (Iterable[str]): events to be fetched.
 
         Returns:
-            Deferred[Dict[str, Dict]]: map from event id to row data from the database.
+            Dict[str, Dict]: map from event id to row data from the database.
                 May contain events that weren't requested.
         """
 
@@ -683,7 +704,7 @@ class EventsWorkerStore(SQLBaseStore):
 
         logger.debug("Loading %d events: %s", len(events), events)
         with PreserveLoggingContext():
-            row_map = yield events_d
+            row_map = await events_d
         logger.debug("Loaded %d events (%d rows)", len(events), len(row_map))
 
         return row_map
@@ -842,33 +863,29 @@ class EventsWorkerStore(SQLBaseStore):
         # no valid redaction found for this event
         return None
 
-    @defer.inlineCallbacks
-    def have_events_in_timeline(self, event_ids):
+    async def have_events_in_timeline(self, event_ids):
         """Given a list of event ids, check if we have already processed and
         stored them as non outliers.
         """
-        rows = yield defer.ensureDeferred(
-            self.db_pool.simple_select_many_batch(
-                table="events",
-                retcols=("event_id",),
-                column="event_id",
-                iterable=list(event_ids),
-                keyvalues={"outlier": False},
-                desc="have_events_in_timeline",
-            )
+        rows = await self.db_pool.simple_select_many_batch(
+            table="events",
+            retcols=("event_id",),
+            column="event_id",
+            iterable=list(event_ids),
+            keyvalues={"outlier": False},
+            desc="have_events_in_timeline",
         )
 
         return {r["event_id"] for r in rows}
 
-    @defer.inlineCallbacks
-    def have_seen_events(self, event_ids):
+    async def have_seen_events(self, event_ids):
         """Given a list of event ids, check if we have already processed them.
 
         Args:
             event_ids (iterable[str]):
 
         Returns:
-            Deferred[set[str]]: The events we have already seen.
+            set[str]: The events we have already seen.
         """
         results = set()
 
@@ -884,7 +901,7 @@ class EventsWorkerStore(SQLBaseStore):
         # break the input up into chunks of 100
         input_iterator = iter(event_ids)
         for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), []):
-            yield self.db_pool.runInteraction(
+            await self.db_pool.runInteraction(
                 "have_seen_events", have_seen_events_txn, chunk
             )
         return results
@@ -914,8 +931,7 @@ class EventsWorkerStore(SQLBaseStore):
             room_id,
         )
 
-    @defer.inlineCallbacks
-    def get_room_complexity(self, room_id):
+    async def get_room_complexity(self, room_id):
         """
         Get a rough approximation of the complexity of the room. This is used by
         remote servers to decide whether they wish to join the room or not.
@@ -926,9 +942,9 @@ class EventsWorkerStore(SQLBaseStore):
             room_id (str)
 
         Returns:
-            Deferred[dict[str:int]] of complexity version to complexity.
+            dict[str:int] of complexity version to complexity.
         """
-        state_events = yield self.get_current_state_event_counts(room_id)
+        state_events = await self.get_current_state_event_counts(room_id)
 
         # Call this one "v1", so we can introduce new ones as we want to develop
         # it.
@@ -1165,9 +1181,9 @@ class EventsWorkerStore(SQLBaseStore):
         to_2, so_2 = await self.get_event_ordering(event_id2)
         return (to_1, so_1) > (to_2, so_2)
 
-    @cachedInlineCallbacks(max_entries=5000)
-    def get_event_ordering(self, event_id):
-        res = yield self.db_pool.simple_select_one(
+    @cached(max_entries=5000)
+    async def get_event_ordering(self, event_id):
+        res = await self.db_pool.simple_select_one(
             table="events",
             retcols=["topological_ordering", "stream_ordering"],
             keyvalues={"event_id": event_id},
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 4377bddb8c..497f607703 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -379,7 +379,6 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         limit: int = 0,
         order: str = "DESC",
     ) -> Tuple[List[EventBase], str]:
-
         """Get new room events in stream ordering since `from_key`.
 
         Args: