diff --git a/changelog.d/8194.misc b/changelog.d/8194.misc
new file mode 100644
index 0000000000..dfe4c03171
--- /dev/null
+++ b/changelog.d/8194.misc
@@ -0,0 +1 @@
+Convert various parts of the codebase to async/await.
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 1419d72e94..9d5b1828df 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -451,7 +451,7 @@ class RoomCreationHandler(BaseHandler):
old_room_member_state_events = await self.store.get_events(
old_room_member_state_ids.values()
)
- for k, old_event in old_room_member_state_events.items():
+ for old_event in old_room_member_state_events.values():
# Only transfer ban events
if (
"membership" in old_event.content
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index 458f169617..5c6168e301 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -27,6 +27,7 @@ from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.storage.state import StateFilter
+from synapse.types import StateMap
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import cached, cachedList
@@ -163,15 +164,15 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
return create_event
@cached(max_entries=100000, iterable=True)
- def get_current_state_ids(self, room_id):
+ async def get_current_state_ids(self, room_id: str) -> StateMap[str]:
"""Get the current state event ids for a room based on the
current_state_events table.
Args:
- room_id (str)
+ room_id: The room to get the state IDs of.
Returns:
- deferred: dict of (type, state_key) -> event_id
+ The current state of the room.
"""
def _get_current_state_ids_txn(txn):
@@ -184,14 +185,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
return {(intern_string(r[0]), intern_string(r[1])): r[2] for r in txn}
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_current_state_ids", _get_current_state_ids_txn
)
# FIXME: how should this be cached?
- def get_filtered_current_state_ids(
+ async def get_filtered_current_state_ids(
self, room_id: str, state_filter: StateFilter = StateFilter.all()
- ):
+ ) -> StateMap[str]:
"""Get the current state event of a given type for a room based on the
current_state_events table. This may not be as up-to-date as the result
of doing a fresh state resolution as per state_handler.get_current_state
@@ -202,14 +203,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
from the database.
Returns:
- defer.Deferred[StateMap[str]]: Map from type/state_key to event ID.
+ Map from type/state_key to event ID.
"""
where_clause, where_args = state_filter.make_sql_filter_clause()
if not where_clause:
# We delegate to the cached version
- return self.get_current_state_ids(room_id)
+ return await self.get_current_state_ids(room_id)
def _get_filtered_current_state_ids_txn(txn):
results = {}
@@ -231,7 +232,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
return results
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_filtered_current_state_ids", _get_filtered_current_state_ids_txn
)
diff --git a/synapse/storage/databases/main/state_deltas.py b/synapse/storage/databases/main/state_deltas.py
index 0d963c98ff..356623fc6e 100644
--- a/synapse/storage/databases/main/state_deltas.py
+++ b/synapse/storage/databases/main/state_deltas.py
@@ -14,8 +14,7 @@
# limitations under the License.
import logging
-
-from twisted.internet import defer
+from typing import Any, Dict, List, Tuple
from synapse.storage._base import SQLBaseStore
@@ -23,7 +22,9 @@ logger = logging.getLogger(__name__)
class StateDeltasStore(SQLBaseStore):
- def get_current_state_deltas(self, prev_stream_id: int, max_stream_id: int):
+ async def get_current_state_deltas(
+ self, prev_stream_id: int, max_stream_id: int
+ ) -> Tuple[int, List[Dict[str, Any]]]:
"""Fetch a list of room state changes since the given stream id
Each entry in the result contains the following fields:
@@ -37,12 +38,12 @@ class StateDeltasStore(SQLBaseStore):
if it's new state.
Args:
- prev_stream_id (int): point to get changes since (exclusive)
- max_stream_id (int): the point that we know has been correctly persisted
+ prev_stream_id: point to get changes since (exclusive)
+ max_stream_id: the point that we know has been correctly persisted
- ie, an upper limit to return changes from.
Returns:
- Deferred[tuple[int, list[dict]]: A tuple consisting of:
+ A tuple consisting of:
- the stream id which these results go up to
- list of current_state_delta_stream rows. If it is empty, we are
up to date.
@@ -58,7 +59,7 @@ class StateDeltasStore(SQLBaseStore):
# if the CSDs haven't changed between prev_stream_id and now, we
# know for certain that they haven't changed between prev_stream_id and
# max_stream_id.
- return defer.succeed((max_stream_id, []))
+ return (max_stream_id, [])
def get_current_state_deltas_txn(txn):
# First we calculate the max stream id that will give us less than
@@ -102,7 +103,7 @@ class StateDeltasStore(SQLBaseStore):
txn.execute(sql, (prev_stream_id, clipped_stream_id))
return clipped_stream_id, self.db_pool.cursor_to_dict(txn)
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_current_state_deltas", get_current_state_deltas_txn
)
@@ -114,8 +115,8 @@ class StateDeltasStore(SQLBaseStore):
retcol="COALESCE(MAX(stream_id), -1)",
)
- def get_max_stream_id_in_current_state_deltas(self):
- return self.db_pool.runInteraction(
+ async def get_max_stream_id_in_current_state_deltas(self):
+ return await self.db_pool.runInteraction(
"get_max_stream_id_in_current_state_deltas",
self._get_max_stream_id_in_current_state_deltas_txn,
)
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 497f607703..24f44a7e36 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -539,7 +539,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return rows, token
- def get_room_event_before_stream_ordering(self, room_id: str, stream_ordering: int):
+ async def get_room_event_before_stream_ordering(
+ self, room_id: str, stream_ordering: int
+ ) -> Tuple[int, int, str]:
"""Gets details of the first event in a room at or before a stream ordering
Args:
@@ -547,8 +549,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
stream_ordering:
Returns:
- Deferred[(int, int, str)]:
- (stream ordering, topological ordering, event_id)
+ A tuple of (stream ordering, topological ordering, event_id)
"""
def _f(txn):
@@ -563,7 +564,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
txn.execute(sql, (room_id, stream_ordering))
return txn.fetchone()
- return self.db_pool.runInteraction("get_room_event_before_stream_ordering", _f)
+ return await self.db_pool.runInteraction(
+ "get_room_event_before_stream_ordering", _f
+ )
async def get_room_events_max_id(self, room_id: Optional[str] = None) -> str:
"""Returns the current token for rooms stream.
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index 7f104ad936..e924f1ca3b 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -17,8 +17,6 @@ import logging
from collections import namedtuple
from typing import Dict, Iterable, List, Set, Tuple
-from twisted.internet import defer
-
from synapse.api.constants import EventTypes
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool
@@ -103,7 +101,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
)
@cached(max_entries=10000, iterable=True)
- def get_state_group_delta(self, state_group):
+ async def get_state_group_delta(self, state_group):
"""Given a state group try to return a previous group and a delta between
the old and the new.
@@ -135,7 +133,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
{(row["type"], row["state_key"]): row["event_id"] for row in delta_ids},
)
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_state_group_delta", _get_state_group_delta_txn
)
@@ -367,9 +365,9 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
fetched_keys=non_member_types,
)
- def store_state_group(
+ async def store_state_group(
self, event_id, room_id, prev_group, delta_ids, current_state_ids
- ):
+ ) -> int:
"""Store a new set of state, returning a newly assigned state group.
Args:
@@ -383,7 +381,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
to event_id.
Returns:
- Deferred[int]: The state group ID
+ The state group ID
"""
def _store_state_group_txn(txn):
@@ -484,11 +482,13 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
return state_group
- return self.db_pool.runInteraction("store_state_group", _store_state_group_txn)
+ return await self.db_pool.runInteraction(
+ "store_state_group", _store_state_group_txn
+ )
- def purge_unreferenced_state_groups(
+ async def purge_unreferenced_state_groups(
self, room_id: str, state_groups_to_delete
- ) -> defer.Deferred:
+ ) -> None:
"""Deletes no longer referenced state groups and de-deltas any state
groups that reference them.
@@ -499,7 +499,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
to delete.
"""
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"purge_unreferenced_state_groups",
self._purge_unreferenced_state_groups,
room_id,
@@ -594,7 +594,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
return {row["state_group"]: row["prev_state_group"] for row in rows}
- def purge_room_state(self, room_id, state_groups_to_delete):
+ async def purge_room_state(self, room_id, state_groups_to_delete):
"""Deletes all record of a room from state tables
Args:
@@ -602,7 +602,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
state_groups_to_delete (list[int]): State groups to delete
"""
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"purge_room_state",
self._purge_room_state_txn,
room_id,
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 534883361f..96a1b59d64 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -333,7 +333,7 @@ class StateGroupStorage(object):
def __init__(self, hs, stores):
self.stores = stores
- def get_state_group_delta(self, state_group: int):
+ async def get_state_group_delta(self, state_group: int):
"""Given a state group try to return a previous group and a delta between
the old and the new.
@@ -341,11 +341,11 @@ class StateGroupStorage(object):
state_group: The state group used to retrieve state deltas.
Returns:
- Deferred[Tuple[Optional[int], Optional[StateMap[str]]]]:
+ Tuple[Optional[int], Optional[StateMap[str]]]:
(prev_group, delta_ids)
"""
- return self.stores.state.get_state_group_delta(state_group)
+ return await self.stores.state.get_state_group_delta(state_group)
async def get_state_groups_ids(
self, _room_id: str, event_ids: Iterable[str]
@@ -525,7 +525,7 @@ class StateGroupStorage(object):
state_filter: The state filter used to fetch state from the database.
Returns:
- A deferred dict from (type, state_key) -> state_event
+ A dict from (type, state_key) -> state_event
"""
state_map = await self.get_state_ids_for_events([event_id], state_filter)
return state_map[event_id]
@@ -546,14 +546,14 @@ class StateGroupStorage(object):
"""
return self.stores.state._get_state_for_groups(groups, state_filter)
- def store_state_group(
+ async def store_state_group(
self,
event_id: str,
room_id: str,
prev_group: Optional[int],
delta_ids: Optional[dict],
current_state_ids: dict,
- ):
+ ) -> int:
"""Store a new set of state, returning a newly assigned state group.
Args:
@@ -567,8 +567,8 @@ class StateGroupStorage(object):
to event_id.
Returns:
- Deferred[int]: The state group ID
+ The state group ID
"""
- return self.stores.state.store_state_group(
+ return await self.stores.state.store_state_group(
event_id, room_id, prev_group, delta_ids, current_state_ids
)
|