author     Patrick Cloke <clokep@users.noreply.github.com>  2020-08-28 09:37:55 -0400
committer  GitHub <noreply@github.com>  2020-08-28 09:37:55 -0400
commit     aec7085179464cc0a05177108ad2962d09ca4f4a
tree       accd3496816a54257d835982a95fe73995428987
parent     Ensure that the OpenID Connect remote ID is a string. (#8190)
download   synapse-aec7085179464cc0a05177108ad2962d09ca4f4a.tar.xz
Convert state and stream stores and related code to async (#8194)
Diffstat
-rw-r--r--  changelog.d/8194.misc                            1
-rw-r--r--  synapse/handlers/room.py                         2
-rw-r--r--  synapse/storage/databases/main/state.py         19
-rw-r--r--  synapse/storage/databases/main/state_deltas.py  21
-rw-r--r--  synapse/storage/databases/main/stream.py        11
-rw-r--r--  synapse/storage/databases/state/store.py        26
-rw-r--r--  synapse/storage/state.py                        16
7 files changed, 51 insertions, 45 deletions
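
Every change below follows the same pattern: a store method that returned the Deferred from db_pool.runInteraction becomes an async def that awaits it, and the docstring and annotations then describe the resolved value rather than a Deferred. A minimal sketch of that pattern, using a hypothetical store, method and table (none of these names appear in the diff itself):

from typing import List

from synapse.storage._base import SQLBaseStore


class ExampleWorkerStore(SQLBaseStore):
    # Before: the Deferred from runInteraction is handed straight back, so the
    # docstring has to describe a Deferred and callers must yield/await it.
    def get_example_ids_old(self, room_id):
        def _get_example_ids_txn(txn):
            txn.execute(
                "SELECT event_id FROM example_table WHERE room_id = ?", (room_id,)
            )
            return [row[0] for row in txn]

        return self.db_pool.runInteraction("get_example_ids", _get_example_ids_txn)

    # After: a native coroutine with type hints; the interaction is awaited
    # here, so callers simply await the method and receive the plain value.
    async def get_example_ids(self, room_id: str) -> List[str]:
        def _get_example_ids_txn(txn):
            txn.execute(
                "SELECT event_id FROM example_table WHERE room_id = ?", (room_id,)
            )
            return [row[0] for row in txn]

        return await self.db_pool.runInteraction(
            "get_example_ids", _get_example_ids_txn
        )
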
diff --git a/changelog.d/8194.misc b/changelog.d/8194.misc
new file mode 100644
index 0000000000..dfe4c03171
--- /dev/null
+++ b/changelog.d/8194.misc
@@ -0,0 +1 @@
+Convert various parts of the codebase to async/await.
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 1419d72e94..9d5b1828df 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -451,7 +451,7 @@ class RoomCreationHandler(BaseHandler):
         old_room_member_state_events = await self.store.get_events(
             old_room_member_state_ids.values()
         )
-        for k, old_event in old_room_member_state_events.items():
+        for old_event in old_room_member_state_events.values():
             # Only transfer ban events
             if (
                 "membership" in old_event.content
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index 458f169617..5c6168e301 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -27,6 +27,7 @@ from synapse.storage.database import DatabasePool
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
 from synapse.storage.state import StateFilter
+from synapse.types import StateMap
 from synapse.util.caches import intern_string
 from synapse.util.caches.descriptors import cached, cachedList
 
@@ -163,15 +164,15 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         return create_event
 
     @cached(max_entries=100000, iterable=True)
-    def get_current_state_ids(self, room_id):
+    async def get_current_state_ids(self, room_id: str) -> StateMap[str]:
         """Get the current state event ids for a room based on the
         current_state_events table.
 
         Args:
-            room_id (str)
+            room_id: The room to get the state IDs of.
 
         Returns:
-            deferred: dict of (type, state_key) -> event_id
+            The current state of the room.
         """
 
         def _get_current_state_ids_txn(txn):
@@ -184,14 +185,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
 
             return {(intern_string(r[0]), intern_string(r[1])): r[2] for r in txn}
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_current_state_ids", _get_current_state_ids_txn
         )
 
     # FIXME: how should this be cached?
-    def get_filtered_current_state_ids(
+    async def get_filtered_current_state_ids(
         self, room_id: str, state_filter: StateFilter = StateFilter.all()
-    ):
+    ) -> StateMap[str]:
         """Get the current state event of a given type for a room based on the
         current_state_events table.  This may not be as up-to-date as the result
         of doing a fresh state resolution as per state_handler.get_current_state
@@ -202,14 +203,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
                 from the database.
 
         Returns:
-            defer.Deferred[StateMap[str]]: Map from type/state_key to event ID.
+            Map from type/state_key to event ID.
         """
 
         where_clause, where_args = state_filter.make_sql_filter_clause()
 
         if not where_clause:
             # We delegate to the cached version
-            return self.get_current_state_ids(room_id)
+            return await self.get_current_state_ids(room_id)
 
         def _get_filtered_current_state_ids_txn(txn):
             results = {}
@@ -231,7 +232,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
 
             return results
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_filtered_current_state_ids", _get_filtered_current_state_ids_txn
         )
 
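With the state store methods above converted, callers await them directly and get a StateMap[str] back. A caller-side sketch, assuming store is a StateGroupWorkerStore and room_id/user_id are valid identifiers (the helper function itself is hypothetical):

from synapse.api.constants import EventTypes
from synapse.storage.state import StateFilter


async def current_state_example(store, room_id: str, user_id: str):
    # Resolves to a dict of (event type, state_key) -> event_id.
    state_ids = await store.get_current_state_ids(room_id)
    create_event_id = state_ids.get((EventTypes.Create, ""))

    # With a non-empty filter a WHERE clause is built in the transaction; with
    # an empty filter the method just delegates to the cached call above.
    member_event_ids = await store.get_filtered_current_state_ids(
        room_id, StateFilter.from_types([(EventTypes.Member, user_id)])
    )
    return create_event_id, member_event_ids
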
diff --git a/synapse/storage/databases/main/state_deltas.py b/synapse/storage/databases/main/state_deltas.py
index 0d963c98ff..356623fc6e 100644
--- a/synapse/storage/databases/main/state_deltas.py
+++ b/synapse/storage/databases/main/state_deltas.py
@@ -14,8 +14,7 @@
 # limitations under the License.
 
 import logging
-
-from twisted.internet import defer
+from typing import Any, Dict, List, Tuple
 
 from synapse.storage._base import SQLBaseStore
 
@@ -23,7 +22,9 @@ logger = logging.getLogger(__name__)
 
 
 class StateDeltasStore(SQLBaseStore):
-    def get_current_state_deltas(self, prev_stream_id: int, max_stream_id: int):
+    async def get_current_state_deltas(
+        self, prev_stream_id: int, max_stream_id: int
+    ) -> Tuple[int, List[Dict[str, Any]]]:
         """Fetch a list of room state changes since the given stream id
 
         Each entry in the result contains the following fields:
@@ -37,12 +38,12 @@ class StateDeltasStore(SQLBaseStore):
                 if it's new state.
 
         Args:
-            prev_stream_id (int): point to get changes since (exclusive)
-            max_stream_id (int): the point that we know has been correctly persisted
+            prev_stream_id: point to get changes since (exclusive)
+            max_stream_id: the point that we know has been correctly persisted
                - ie, an upper limit to return changes from.
 
         Returns:
-            Deferred[tuple[int, list[dict]]: A tuple consisting of:
+            A tuple consisting of:
                - the stream id which these results go up to
                - list of current_state_delta_stream rows. If it is empty, we are
                  up to date.
@@ -58,7 +59,7 @@ class StateDeltasStore(SQLBaseStore):
             # if the CSDs haven't changed between prev_stream_id and now, we
             # know for certain that they haven't changed between prev_stream_id and
             # max_stream_id.
-            return defer.succeed((max_stream_id, []))
+            return (max_stream_id, [])
 
         def get_current_state_deltas_txn(txn):
             # First we calculate the max stream id that will give us less than
@@ -102,7 +103,7 @@ class StateDeltasStore(SQLBaseStore):
             txn.execute(sql, (prev_stream_id, clipped_stream_id))
             return clipped_stream_id, self.db_pool.cursor_to_dict(txn)
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_current_state_deltas", get_current_state_deltas_txn
         )
 
@@ -114,8 +115,8 @@ class StateDeltasStore(SQLBaseStore):
             retcol="COALESCE(MAX(stream_id), -1)",
         )
 
-    def get_max_stream_id_in_current_state_deltas(self):
-        return self.db_pool.runInteraction(
+    async def get_max_stream_id_in_current_state_deltas(self):
+        return await self.db_pool.runInteraction(
             "get_max_stream_id_in_current_state_deltas",
             self._get_max_stream_id_in_current_state_deltas_txn,
         )
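
get_current_state_deltas now resolves straight to the (stream id, rows) tuple. A short usage sketch with a hypothetical polling helper (the logging is illustrative only):

import logging

logger = logging.getLogger(__name__)


async def poll_state_deltas(store, prev_stream_id: int, max_stream_id: int) -> int:
    # Resolves to (clipped stream id, current_state_delta_stream rows); the
    # row list is empty when we are already up to date.
    clipped_stream_id, deltas = await store.get_current_state_deltas(
        prev_stream_id, max_stream_id
    )
    for delta in deltas:
        # Per the docstring, each row is a dict including room_id, type,
        # state_key and event_id (the latter is None for deleted state).
        logger.info("state change in %s: %s", delta["room_id"], delta["type"])
    return clipped_stream_id
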
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 497f607703..24f44a7e36 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -539,7 +539,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
         return rows, token
 
-    def get_room_event_before_stream_ordering(self, room_id: str, stream_ordering: int):
+    async def get_room_event_before_stream_ordering(
+        self, room_id: str, stream_ordering: int
+    ) -> Tuple[int, int, str]:
         """Gets details of the first event in a room at or before a stream ordering
 
         Args:
@@ -547,8 +549,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             stream_ordering:
 
         Returns:
-            Deferred[(int, int, str)]:
-                (stream ordering, topological ordering, event_id)
+            A tuple of (stream ordering, topological ordering, event_id)
         """
 
         def _f(txn):
@@ -563,7 +564,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             txn.execute(sql, (room_id, stream_ordering))
             return txn.fetchone()
 
-        return self.db_pool.runInteraction("get_room_event_before_stream_ordering", _f)
+        return await self.db_pool.runInteraction(
+            "get_room_event_before_stream_ordering", _f
+        )
 
     async def get_room_events_max_id(self, room_id: Optional[str] = None) -> str:
         """Returns the current token for rooms stream.
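
get_room_event_before_stream_ordering likewise resolves to the row tuple itself. A hedged caller sketch (the helper is hypothetical; note that the underlying transaction returns txn.fetchone(), so the result may still be None when no event exists at or before the given ordering):

from typing import Optional


async def event_id_before(store, room_id: str, stream_ordering: int) -> Optional[str]:
    # Resolves to (stream ordering, topological ordering, event_id), or None.
    row = await store.get_room_event_before_stream_ordering(room_id, stream_ordering)
    if row is None:
        return None
    _stream, _topological, event_id = row
    return event_id
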
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index 7f104ad936..e924f1ca3b 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -17,8 +17,6 @@ import logging
 from collections import namedtuple
 from typing import Dict, Iterable, List, Set, Tuple
 
-from twisted.internet import defer
-
 from synapse.api.constants import EventTypes
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import DatabasePool
@@ -103,7 +101,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
         )
 
     @cached(max_entries=10000, iterable=True)
-    def get_state_group_delta(self, state_group):
+    async def get_state_group_delta(self, state_group):
         """Given a state group try to return a previous group and a delta between
         the old and the new.
 
@@ -135,7 +133,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
                 {(row["type"], row["state_key"]): row["event_id"] for row in delta_ids},
             )
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_state_group_delta", _get_state_group_delta_txn
         )
 
@@ -367,9 +365,9 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
                 fetched_keys=non_member_types,
             )
 
-    def store_state_group(
+    async def store_state_group(
         self, event_id, room_id, prev_group, delta_ids, current_state_ids
-    ):
+    ) -> int:
         """Store a new set of state, returning a newly assigned state group.
 
         Args:
@@ -383,7 +381,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
                 to event_id.
 
         Returns:
-            Deferred[int]: The state group ID
+            The state group ID
         """
 
         def _store_state_group_txn(txn):
@@ -484,11 +482,13 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
 
             return state_group
 
-        return self.db_pool.runInteraction("store_state_group", _store_state_group_txn)
+        return await self.db_pool.runInteraction(
+            "store_state_group", _store_state_group_txn
+        )
 
-    def purge_unreferenced_state_groups(
+    async def purge_unreferenced_state_groups(
         self, room_id: str, state_groups_to_delete
-    ) -> defer.Deferred:
+    ) -> None:
         """Deletes no longer referenced state groups and de-deltas any state
         groups that reference them.
 
@@ -499,7 +499,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
                 to delete.
         """
 
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "purge_unreferenced_state_groups",
             self._purge_unreferenced_state_groups,
             room_id,
@@ -594,7 +594,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
 
         return {row["state_group"]: row["prev_state_group"] for row in rows}
 
-    def purge_room_state(self, room_id, state_groups_to_delete):
+    async def purge_room_state(self, room_id, state_groups_to_delete):
         """Deletes all record of a room from state tables
 
         Args:
@@ -602,7 +602,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
             state_groups_to_delete (list[int]): State groups to delete
         """
 
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "purge_room_state",
             self._purge_room_state_txn,
             room_id,
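
On the state store side, store_state_group now resolves to the new state group ID and the purge methods resolve to None. A sketch of how a caller drives them after this change (the persist_and_prune helper is hypothetical):

async def persist_and_prune(
    state_store, event_id: str, room_id: str, current_state_ids: dict, old_groups
) -> int:
    # store_state_group resolves straight to the newly assigned state group ID
    # rather than a Deferred[int].  Passing prev_group=None stores full state.
    state_group = await state_store.store_state_group(
        event_id,
        room_id,
        prev_group=None,
        delta_ids=None,
        current_state_ids=current_state_ids,
    )

    # The purge methods are now async and return None: there is no Deferred to
    # hold on to, the caller just awaits completion.
    await state_store.purge_unreferenced_state_groups(room_id, old_groups)
    return state_group
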
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 534883361f..96a1b59d64 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -333,7 +333,7 @@ class StateGroupStorage(object):
     def __init__(self, hs, stores):
         self.stores = stores
 
-    def get_state_group_delta(self, state_group: int):
+    async def get_state_group_delta(self, state_group: int):
         """Given a state group try to return a previous group and a delta between
         the old and the new.
 
@@ -341,11 +341,11 @@ class StateGroupStorage(object):
             state_group: The state group used to retrieve state deltas.
 
         Returns:
-            Deferred[Tuple[Optional[int], Optional[StateMap[str]]]]:
+            Tuple[Optional[int], Optional[StateMap[str]]]:
                 (prev_group, delta_ids)
         """
 
-        return self.stores.state.get_state_group_delta(state_group)
+        return await self.stores.state.get_state_group_delta(state_group)
 
     async def get_state_groups_ids(
         self, _room_id: str, event_ids: Iterable[str]
@@ -525,7 +525,7 @@ class StateGroupStorage(object):
             state_filter: The state filter used to fetch state from the database.
 
         Returns:
-            A deferred dict from (type, state_key) -> state_event
+            A dict from (type, state_key) -> state_event
         """
         state_map = await self.get_state_ids_for_events([event_id], state_filter)
         return state_map[event_id]
@@ -546,14 +546,14 @@ class StateGroupStorage(object):
         """
         return self.stores.state._get_state_for_groups(groups, state_filter)
 
-    def store_state_group(
+    async def store_state_group(
         self,
         event_id: str,
         room_id: str,
         prev_group: Optional[int],
         delta_ids: Optional[dict],
         current_state_ids: dict,
-    ):
+    ) -> int:
         """Store a new set of state, returning a newly assigned state group.
 
         Args:
@@ -567,8 +567,8 @@ class StateGroupStorage(object):
                 to event_id.
 
         Returns:
-            Deferred[int]: The state group ID
+            The state group ID
         """
-        return self.stores.state.store_state_group(
+        return await self.stores.state.store_state_group(
             event_id, room_id, prev_group, delta_ids, current_state_ids
         )
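
The StateGroupStorage wrapper methods are now coroutines as well, so they await the underlying state store and hand back plain values. A final sketch, assuming storage is the object returned by hs.get_storage() (the wrapper_example helper is hypothetical):

async def wrapper_example(
    storage, event_id: str, room_id: str, current_state_ids: dict
) -> int:
    # store_state_group on the wrapper awaits the state store and returns the
    # new state group ID directly.
    state_group = await storage.state.store_state_group(
        event_id,
        room_id,
        prev_group=None,
        delta_ids=None,
        current_state_ids=current_state_ids,
    )

    # get_state_group_delta likewise resolves to (prev_group, delta_ids).
    prev_group, delta_ids = await storage.state.get_state_group_delta(state_group)
    return state_group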