summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2020-08-18 16:20:49 -0400
committerGitHub <noreply@github.com>2020-08-18 16:20:49 -0400
commitf40645e60b9cab69c953094848be61c0989a91cb (patch)
tree2ac850ce839e12a423871860bf288a313a0e1a92 /synapse
parentAdd a link to the matrix-synapse-rest-password-provider. (#8111) (diff)
downloadsynapse-f40645e60b9cab69c953094848be61c0989a91cb.tar.xz
Convert events worker database to async/await. (#8071)
Diffstat (limited to 'synapse')
-rw-r--r--synapse/event_auth.py2
-rw-r--r--synapse/handlers/federation.py16
-rw-r--r--synapse/handlers/message.py6
-rw-r--r--synapse/handlers/room_member.py2
-rw-r--r--synapse/spam_checker_api/__init__.py2
-rw-r--r--synapse/state/__init__.py2
-rw-r--r--synapse/storage/databases/main/event_federation.py30
-rw-r--r--synapse/storage/databases/main/events_worker.py132
-rw-r--r--synapse/storage/databases/main/stream.py1
9 files changed, 100 insertions, 93 deletions
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index c0981eee62..8c907ad596 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -47,7 +47,7 @@ def check(
     Args:
         room_version_obj: the version of the room
         event: the event being checked.
-        auth_events (dict: event-key -> event): the existing room state.
+        auth_events: the existing room state.
 
     Raises:
         AuthError if the checks fail
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 593932adb7..5b270228e7 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -1777,9 +1777,7 @@ class FederationHandler(BaseHandler):
         """Returns the state at the event. i.e. not including said event.
         """
 
-        event = await self.store.get_event(
-            event_id, allow_none=False, check_room_id=room_id
-        )
+        event = await self.store.get_event(event_id, check_room_id=room_id)
 
         state_groups = await self.state_store.get_state_groups(room_id, [event_id])
 
@@ -1805,9 +1803,7 @@ class FederationHandler(BaseHandler):
     async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> List[str]:
         """Returns the state at the event. i.e. not including said event.
         """
-        event = await self.store.get_event(
-            event_id, allow_none=False, check_room_id=room_id
-        )
+        event = await self.store.get_event(event_id, check_room_id=room_id)
 
         state_groups = await self.state_store.get_state_groups_ids(room_id, [event_id])
 
@@ -2155,9 +2151,9 @@ class FederationHandler(BaseHandler):
         auth_types = auth_types_for_event(event)
         current_state_ids = [e for k, e in current_state_ids.items() if k in auth_types]
 
-        current_auth_events = await self.store.get_events(current_state_ids)
+        auth_events_map = await self.store.get_events(current_state_ids)
         current_auth_events = {
-            (e.type, e.state_key): e for e in current_auth_events.values()
+            (e.type, e.state_key): e for e in auth_events_map.values()
         }
 
         try:
@@ -2173,9 +2169,7 @@ class FederationHandler(BaseHandler):
         if not in_room:
             raise AuthError(403, "Host not in room.")
 
-        event = await self.store.get_event(
-            event_id, allow_none=False, check_room_id=room_id
-        )
+        event = await self.store.get_event(event_id, check_room_id=room_id)
 
         # Just go through and process each event in `remote_auth_chain`. We
         # don't want to fall into the trap of `missing` being wrong.
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 532fc30681..b999d91d1a 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -960,7 +960,7 @@ class EventCreationHandler(object):
                     allow_none=True,
                 )
 
-                is_admin_redaction = (
+                is_admin_redaction = bool(
                     original_event and event.sender != original_event.sender
                 )
 
@@ -1080,8 +1080,8 @@ class EventCreationHandler(object):
             auth_events_ids = self.auth.compute_auth_events(
                 event, prev_state_ids, for_verification=True
             )
-            auth_events = await self.store.get_events(auth_events_ids)
-            auth_events = {(e.type, e.state_key): e for e in auth_events.values()}
+            auth_events_map = await self.store.get_events(auth_events_ids)
+            auth_events = {(e.type, e.state_key): e for e in auth_events_map.values()}
 
             room_version = await self.store.get_room_version_id(event.room_id)
             room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 31705cdbdb..aa1ccde211 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -716,7 +716,7 @@ class RoomMemberHandler(object):
 
         guest_access = await self.store.get_event(guest_access_id)
 
-        return (
+        return bool(
             guest_access
             and guest_access.content
             and "guest_access" in guest_access.content
diff --git a/synapse/spam_checker_api/__init__.py b/synapse/spam_checker_api/__init__.py
index 9b78924d96..4d9b13ac04 100644
--- a/synapse/spam_checker_api/__init__.py
+++ b/synapse/spam_checker_api/__init__.py
@@ -51,5 +51,5 @@ class SpamCheckerApi(object):
         state_ids = yield self._store.get_filtered_current_state_ids(
             room_id=room_id, state_filter=StateFilter.from_types(types)
         )
-        state = yield self._store.get_events(state_ids.values())
+        state = yield defer.ensureDeferred(self._store.get_events(state_ids.values()))
         return state.values()
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index a1d3884667..dba8d91eef 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -641,7 +641,7 @@ class StateResolutionStore(object):
             allow_rejected (bool): If True return rejected events.
 
         Returns:
-            Deferred[dict[str, FrozenEvent]]: Dict from event_id to event.
+            Awaitable[dict[str, FrozenEvent]]: Dict from event_id to event.
         """
 
         return self.store.get_events(
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 431bd76693..4826be630c 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -30,7 +30,7 @@ logger = logging.getLogger(__name__)
 
 
 class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
-    def get_auth_chain(self, event_ids, include_given=False):
+    async def get_auth_chain(self, event_ids, include_given=False):
         """Get auth events for given event_ids. The events *must* be state events.
 
         Args:
@@ -40,9 +40,10 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
         Returns:
             list of events
         """
-        return self.get_auth_chain_ids(
+        event_ids = await self.get_auth_chain_ids(
             event_ids, include_given=include_given
-        ).addCallback(self.get_events_as_list)
+        )
+        return await self.get_events_as_list(event_ids)
 
     def get_auth_chain_ids(
         self,
@@ -459,7 +460,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
             "get_forward_extremeties_for_room", get_forward_extremeties_for_room_txn
         )
 
-    def get_backfill_events(self, room_id, event_list, limit):
+    async def get_backfill_events(self, room_id, event_list, limit):
         """Get a list of Events for a given topic that occurred before (and
         including) the events in event_list. Return a list of max size `limit`
 
@@ -469,17 +470,15 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
             event_list (list)
             limit (int)
         """
-        return (
-            self.db_pool.runInteraction(
-                "get_backfill_events",
-                self._get_backfill_events,
-                room_id,
-                event_list,
-                limit,
-            )
-            .addCallback(self.get_events_as_list)
-            .addCallback(lambda l: sorted(l, key=lambda e: -e.depth))
+        event_ids = await self.db_pool.runInteraction(
+            "get_backfill_events",
+            self._get_backfill_events,
+            room_id,
+            event_list,
+            limit,
         )
+        events = await self.get_events_as_list(event_ids)
+        return sorted(events, key=lambda e: -e.depth)
 
     def _get_backfill_events(self, txn, room_id, event_list, limit):
         logger.debug("_get_backfill_events: %s, %r, %s", room_id, event_list, limit)
@@ -540,8 +539,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
             latest_events,
             limit,
         )
-        events = await self.get_events_as_list(ids)
-        return events
+        return await self.get_events_as_list(ids)
 
     def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limit):
 
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 8c63a0dc4d..e3a154a527 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -19,9 +19,10 @@ import itertools
 import logging
 import threading
 from collections import namedtuple
-from typing import List, Optional, Tuple
+from typing import Dict, Iterable, List, Optional, Tuple, overload
 
 from constantly import NamedConstant, Names
+from typing_extensions import Literal
 
 from twisted.internet import defer
 
@@ -32,7 +33,7 @@ from synapse.api.room_versions import (
     EventFormatVersions,
     RoomVersions,
 )
-from synapse.events import make_event_from_dict
+from synapse.events import EventBase, make_event_from_dict
 from synapse.events.utils import prune_event
 from synapse.logging.context import PreserveLoggingContext, current_context
 from synapse.metrics.background_process_metrics import run_as_background_process
@@ -42,8 +43,8 @@ from synapse.replication.tcp.streams.events import EventsStream
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
 from synapse.storage.database import DatabasePool
 from synapse.storage.util.id_generators import StreamIdGenerator
-from synapse.types import get_domain_from_id
-from synapse.util.caches.descriptors import Cache, cachedInlineCallbacks
+from synapse.types import Collection, get_domain_from_id
+from synapse.util.caches.descriptors import Cache, cached
 from synapse.util.iterutils import batch_iter
 from synapse.util.metrics import Measure
 
@@ -137,8 +138,33 @@ class EventsWorkerStore(SQLBaseStore):
             desc="get_received_ts",
         )
 
-    @defer.inlineCallbacks
-    def get_event(
+    # Inform mypy that if allow_none is False (the default) then get_event
+    # always returns an EventBase.
+    @overload
+    async def get_event(
+        self,
+        event_id: str,
+        redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
+        get_prev_content: bool = False,
+        allow_rejected: bool = False,
+        allow_none: Literal[False] = False,
+        check_room_id: Optional[str] = None,
+    ) -> EventBase:
+        ...
+
+    @overload
+    async def get_event(
+        self,
+        event_id: str,
+        redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
+        get_prev_content: bool = False,
+        allow_rejected: bool = False,
+        allow_none: Literal[True] = False,
+        check_room_id: Optional[str] = None,
+    ) -> Optional[EventBase]:
+        ...
+
+    async def get_event(
         self,
         event_id: str,
         redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
@@ -146,7 +172,7 @@ class EventsWorkerStore(SQLBaseStore):
         allow_rejected: bool = False,
         allow_none: bool = False,
         check_room_id: Optional[str] = None,
-    ):
+    ) -> Optional[EventBase]:
         """Get an event from the database by event_id.
 
         Args:
@@ -171,12 +197,12 @@ class EventsWorkerStore(SQLBaseStore):
                 If there is a mismatch, behave as per allow_none.
 
         Returns:
-            Deferred[EventBase|None]
+            The event, or None if the event was not found.
         """
         if not isinstance(event_id, str):
             raise TypeError("Invalid event event_id %r" % (event_id,))
 
-        events = yield self.get_events_as_list(
+        events = await self.get_events_as_list(
             [event_id],
             redact_behaviour=redact_behaviour,
             get_prev_content=get_prev_content,
@@ -194,14 +220,13 @@ class EventsWorkerStore(SQLBaseStore):
 
         return event
 
-    @defer.inlineCallbacks
-    def get_events(
+    async def get_events(
         self,
-        event_ids: List[str],
+        event_ids: Iterable[str],
         redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
         get_prev_content: bool = False,
         allow_rejected: bool = False,
-    ):
+    ) -> Dict[str, EventBase]:
         """Get events from the database
 
         Args:
@@ -220,9 +245,9 @@ class EventsWorkerStore(SQLBaseStore):
                 omits rejeted events from the response.
 
         Returns:
-            Deferred : Dict from event_id to event.
+            A mapping from event_id to event.
         """
-        events = yield self.get_events_as_list(
+        events = await self.get_events_as_list(
             event_ids,
             redact_behaviour=redact_behaviour,
             get_prev_content=get_prev_content,
@@ -231,14 +256,13 @@ class EventsWorkerStore(SQLBaseStore):
 
         return {e.event_id: e for e in events}
 
-    @defer.inlineCallbacks
-    def get_events_as_list(
+    async def get_events_as_list(
         self,
-        event_ids: List[str],
+        event_ids: Collection[str],
         redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
         get_prev_content: bool = False,
         allow_rejected: bool = False,
-    ):
+    ) -> List[EventBase]:
         """Get events from the database and return in a list in the same order
         as given by `event_ids` arg.
 
@@ -259,8 +283,8 @@ class EventsWorkerStore(SQLBaseStore):
                 omits rejected events from the response.
 
         Returns:
-            Deferred[list[EventBase]]: List of events fetched from the database. The
-            events are in the same order as `event_ids` arg.
+            List of events fetched from the database. The events are in the same
+            order as `event_ids` arg.
 
             Note that the returned list may be smaller than the list of event
             IDs if not all events could be fetched.
@@ -270,7 +294,7 @@ class EventsWorkerStore(SQLBaseStore):
             return []
 
         # there may be duplicates so we cast the list to a set
-        event_entry_map = yield self._get_events_from_cache_or_db(
+        event_entry_map = await self._get_events_from_cache_or_db(
             set(event_ids), allow_rejected=allow_rejected
         )
 
@@ -305,7 +329,7 @@ class EventsWorkerStore(SQLBaseStore):
                     continue
 
                 redacted_event_id = entry.event.redacts
-                event_map = yield self._get_events_from_cache_or_db([redacted_event_id])
+                event_map = await self._get_events_from_cache_or_db([redacted_event_id])
                 original_event_entry = event_map.get(redacted_event_id)
                 if not original_event_entry:
                     # we don't have the redacted event (or it was rejected).
@@ -371,7 +395,7 @@ class EventsWorkerStore(SQLBaseStore):
 
             if get_prev_content:
                 if "replaces_state" in event.unsigned:
-                    prev = yield self.get_event(
+                    prev = await self.get_event(
                         event.unsigned["replaces_state"],
                         get_prev_content=False,
                         allow_none=True,
@@ -383,8 +407,7 @@ class EventsWorkerStore(SQLBaseStore):
 
         return events
 
-    @defer.inlineCallbacks
-    def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False):
+    async def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False):
         """Fetch a bunch of events from the cache or the database.
 
         If events are pulled from the database, they will be cached for future lookups.
@@ -399,7 +422,7 @@ class EventsWorkerStore(SQLBaseStore):
                 rejected events are omitted from the response.
 
         Returns:
-            Deferred[Dict[str, _EventCacheEntry]]:
+            Dict[str, _EventCacheEntry]:
                 map from event id to result
         """
         event_entry_map = self._get_events_from_cache(
@@ -417,7 +440,7 @@ class EventsWorkerStore(SQLBaseStore):
             # the events have been redacted, and if so pulling the redaction event out
             # of the database to check it.
             #
-            missing_events = yield self._get_events_from_db(
+            missing_events = await self._get_events_from_db(
                 missing_events_ids, allow_rejected=allow_rejected
             )
 
@@ -525,8 +548,7 @@ class EventsWorkerStore(SQLBaseStore):
                 with PreserveLoggingContext():
                     self.hs.get_reactor().callFromThread(fire, event_list, e)
 
-    @defer.inlineCallbacks
-    def _get_events_from_db(self, event_ids, allow_rejected=False):
+    async def _get_events_from_db(self, event_ids, allow_rejected=False):
         """Fetch a bunch of events from the database.
 
         Returned events will be added to the cache for future lookups.
@@ -540,7 +562,7 @@ class EventsWorkerStore(SQLBaseStore):
                 rejected events are omitted from the response.
 
         Returns:
-            Deferred[Dict[str, _EventCacheEntry]]:
+            Dict[str, _EventCacheEntry]:
                 map from event id to result. May return extra events which
                 weren't asked for.
         """
@@ -548,7 +570,7 @@ class EventsWorkerStore(SQLBaseStore):
         events_to_fetch = event_ids
 
         while events_to_fetch:
-            row_map = yield self._enqueue_events(events_to_fetch)
+            row_map = await self._enqueue_events(events_to_fetch)
 
             # we need to recursively fetch any redactions of those events
             redaction_ids = set()
@@ -650,8 +672,7 @@ class EventsWorkerStore(SQLBaseStore):
 
         return result_map
 
-    @defer.inlineCallbacks
-    def _enqueue_events(self, events):
+    async def _enqueue_events(self, events):
         """Fetches events from the database using the _event_fetch_list. This
         allows batch and bulk fetching of events - it allows us to fetch events
         without having to create a new transaction for each request for events.
@@ -660,7 +681,7 @@ class EventsWorkerStore(SQLBaseStore):
             events (Iterable[str]): events to be fetched.
 
         Returns:
-            Deferred[Dict[str, Dict]]: map from event id to row data from the database.
+            Dict[str, Dict]: map from event id to row data from the database.
                 May contain events that weren't requested.
         """
 
@@ -683,7 +704,7 @@ class EventsWorkerStore(SQLBaseStore):
 
         logger.debug("Loading %d events: %s", len(events), events)
         with PreserveLoggingContext():
-            row_map = yield events_d
+            row_map = await events_d
         logger.debug("Loaded %d events (%d rows)", len(events), len(row_map))
 
         return row_map
@@ -842,33 +863,29 @@ class EventsWorkerStore(SQLBaseStore):
         # no valid redaction found for this event
         return None
 
-    @defer.inlineCallbacks
-    def have_events_in_timeline(self, event_ids):
+    async def have_events_in_timeline(self, event_ids):
         """Given a list of event ids, check if we have already processed and
         stored them as non outliers.
         """
-        rows = yield defer.ensureDeferred(
-            self.db_pool.simple_select_many_batch(
-                table="events",
-                retcols=("event_id",),
-                column="event_id",
-                iterable=list(event_ids),
-                keyvalues={"outlier": False},
-                desc="have_events_in_timeline",
-            )
+        rows = await self.db_pool.simple_select_many_batch(
+            table="events",
+            retcols=("event_id",),
+            column="event_id",
+            iterable=list(event_ids),
+            keyvalues={"outlier": False},
+            desc="have_events_in_timeline",
         )
 
         return {r["event_id"] for r in rows}
 
-    @defer.inlineCallbacks
-    def have_seen_events(self, event_ids):
+    async def have_seen_events(self, event_ids):
         """Given a list of event ids, check if we have already processed them.
 
         Args:
             event_ids (iterable[str]):
 
         Returns:
-            Deferred[set[str]]: The events we have already seen.
+            set[str]: The events we have already seen.
         """
         results = set()
 
@@ -884,7 +901,7 @@ class EventsWorkerStore(SQLBaseStore):
         # break the input up into chunks of 100
         input_iterator = iter(event_ids)
         for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), []):
-            yield self.db_pool.runInteraction(
+            await self.db_pool.runInteraction(
                 "have_seen_events", have_seen_events_txn, chunk
             )
         return results
@@ -914,8 +931,7 @@ class EventsWorkerStore(SQLBaseStore):
             room_id,
         )
 
-    @defer.inlineCallbacks
-    def get_room_complexity(self, room_id):
+    async def get_room_complexity(self, room_id):
         """
         Get a rough approximation of the complexity of the room. This is used by
         remote servers to decide whether they wish to join the room or not.
@@ -926,9 +942,9 @@ class EventsWorkerStore(SQLBaseStore):
             room_id (str)
 
         Returns:
-            Deferred[dict[str:int]] of complexity version to complexity.
+            dict[str:int] of complexity version to complexity.
         """
-        state_events = yield self.get_current_state_event_counts(room_id)
+        state_events = await self.get_current_state_event_counts(room_id)
 
         # Call this one "v1", so we can introduce new ones as we want to develop
         # it.
@@ -1165,9 +1181,9 @@ class EventsWorkerStore(SQLBaseStore):
         to_2, so_2 = await self.get_event_ordering(event_id2)
         return (to_1, so_1) > (to_2, so_2)
 
-    @cachedInlineCallbacks(max_entries=5000)
-    def get_event_ordering(self, event_id):
-        res = yield self.db_pool.simple_select_one(
+    @cached(max_entries=5000)
+    async def get_event_ordering(self, event_id):
+        res = await self.db_pool.simple_select_one(
             table="events",
             retcols=["topological_ordering", "stream_ordering"],
             keyvalues={"event_id": event_id},
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 4377bddb8c..497f607703 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -379,7 +379,6 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         limit: int = 0,
         order: str = "DESC",
     ) -> Tuple[List[EventBase], str]:
-
         """Get new room events in stream ordering since `from_key`.
 
         Args: