diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index b24de34f23..7ce51b9bdc 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import SQLBaseStore, cached
+from ._base import SQLBaseStore, cachedInlineCallbacks
from twisted.internet import defer
@@ -81,31 +81,41 @@ class StateStore(SQLBaseStore):
f,
)
- @defer.inlineCallbacks
- def c(vals):
- vals[:] = yield self._get_events(vals, get_prev_content=False)
-
- yield defer.gatherResults(
+ state_list = yield defer.gatherResults(
[
- c(vals)
- for vals in states.values()
+ self._fetch_events_for_group(group, vals)
+ for group, vals in states.items()
],
consumeErrors=True,
)
- defer.returnValue(states)
+ defer.returnValue(dict(state_list))
+
+ def _fetch_events_for_group(self, key, events):
+ return self._get_events(
+ events, get_prev_content=False
+ ).addCallback(
+ lambda evs: (key, evs)
+ )
def _store_state_groups_txn(self, txn, event, context):
- if context.current_state is None:
- return
+ return self._store_mult_state_groups_txn(txn, [(event, context)])
- state_events = dict(context.current_state)
+ def _store_mult_state_groups_txn(self, txn, events_and_contexts):
+ state_groups = {}
+ for event, context in events_and_contexts:
+ if context.current_state is None:
+ continue
- if event.is_state():
- state_events[(event.type, event.state_key)] = event
+ if context.state_group is not None:
+ state_groups[event.event_id] = context.state_group
+ continue
+
+ state_events = dict(context.current_state)
+
+ if event.is_state():
+ state_events[(event.type, event.state_key)] = event
- state_group = context.state_group
- if not state_group:
state_group = self._state_groups_id_gen.get_next_txn(txn)
self._simple_insert_txn(
txn,
@@ -131,14 +141,19 @@ class StateStore(SQLBaseStore):
for state in state_events.values()
],
)
+ state_groups[event.event_id] = state_group
- self._simple_insert_txn(
+ self._simple_insert_many_txn(
txn,
table="event_to_state_groups",
- values={
- "state_group": state_group,
- "event_id": event.event_id,
- },
+ values=[
+ {
+ "state_group": state_groups[event.event_id],
+ "event_id": event.event_id,
+ }
+ for event, context in events_and_contexts
+ if context.current_state is not None
+ ],
)
@defer.inlineCallbacks
@@ -173,8 +188,7 @@ class StateStore(SQLBaseStore):
events = yield self._get_events(event_ids, get_prev_content=False)
defer.returnValue(events)
- @cached(num_args=3)
- @defer.inlineCallbacks
+ @cachedInlineCallbacks(num_args=3)
def get_current_state_for_key(self, room_id, event_type, state_key):
def f(txn):
sql = (
@@ -190,6 +204,65 @@ class StateStore(SQLBaseStore):
events = yield self._get_events(event_ids, get_prev_content=False)
defer.returnValue(events)
+ @defer.inlineCallbacks
+ def get_state_for_events(self, room_id, event_ids):
+ def f(txn):
+ groups = set()
+ event_to_group = {}
+ for event_id in event_ids:
+ # TODO: Remove this loop.
+ group = self._simple_select_one_onecol_txn(
+ txn,
+ table="event_to_state_groups",
+ keyvalues={"event_id": event_id},
+ retcol="state_group",
+ allow_none=True,
+ )
+ if group:
+ event_to_group[event_id] = group
+ groups.add(group)
+
+ group_to_state_ids = {}
+ for group in groups:
+ state_ids = self._simple_select_onecol_txn(
+ txn,
+ table="state_groups_state",
+ keyvalues={"state_group": group},
+ retcol="event_id",
+ )
+
+ group_to_state_ids[group] = state_ids
+
+ return event_to_group, group_to_state_ids
+
+ res = yield self.runInteraction(
+ "annotate_events_with_state_groups",
+ f,
+ )
+
+ event_to_group, group_to_state_ids = res
+
+ state_list = yield defer.gatherResults(
+ [
+ self._fetch_events_for_group(group, vals)
+ for group, vals in group_to_state_ids.items()
+ ],
+ consumeErrors=True,
+ )
+
+ state_dict = {
+ group: {
+ (ev.type, ev.state_key): ev
+ for ev in state
+ }
+ for group, state in state_list
+ }
+
+ defer.returnValue([
+ state_dict.get(event_to_group.get(event, None), None)
+ for event in event_ids
+ ])
+
def _make_group_id(clock):
return str(int(clock.time_msec())) + random_string(5)
|