diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 84f29e3867..1d0f0058a2 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -18,7 +18,7 @@ from ._base import BaseHandler
from synapse.streams.config import PaginationConfig
from synapse.api.constants import Membership, EventTypes
from synapse.util import unwrapFirstError
-from synapse.util.logcontext import LoggingContext, PreserveLoggingContext, preserve_fn
+from synapse.util.logcontext import LoggingContext, preserve_fn
from synapse.util.metrics import Measure
from twisted.internet import defer
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 6c32e8f7b3..90ec50bb50 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -264,26 +264,20 @@ class StateStore(SQLBaseStore):
)
@cachedList(cache=_get_state_group_for_event.cache, list_name="event_ids",
- num_args=1)
+ num_args=1, inlineCallbacks=True)
def _get_state_group_for_events(self, event_ids):
"""Returns mapping event_id -> state_group
"""
- def f(txn):
- results = {}
- for event_id in event_ids:
- results[event_id] = self._simple_select_one_onecol_txn(
- txn,
- table="event_to_state_groups",
- keyvalues={
- "event_id": event_id,
- },
- retcol="state_group",
- allow_none=True,
- )
-
- return results
+ rows = yield self._simple_select_many_batch(
+ table="event_to_state_groups",
+ column="event_id",
+ iterable=event_ids,
+ keyvalues={},
+ retcols=("event_id", "state_group",),
+ desc="_get_state_group_for_events",
+ )
- return self.runInteraction("_get_state_group_for_events", f)
+ defer.returnValue({row["event_id"]: row["state_group"] for row in rows})
def _get_some_state_from_cache(self, group, types):
"""Checks if group is in cache. See `_get_state_for_groups`
|