diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 2f3a70b4e5..55ea567793 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -14,43 +14,71 @@
# limitations under the License.
from ._base import SQLBaseStore
-from twisted.internet import defer
class StateStore(SQLBaseStore):
+ """ Keeps track of the state at a given event.
+
+ This is done by the concept of `state groups`. Every event is a assigned
+ a state group (identified by an arbitrary string), which references a
+ collection of state events. The current state of an event is then the
+ collection of state events referenced by the event's state group.
+
+ Hence, every change in the current state causes a new state group to be
+ generated. However, if no change happens (e.g., if we get a message event
+ with only one parent it inherits the state group from its parent.)
+
+ There are three tables:
+ * `state_groups`: Stores group name, first event with in the group and
+ room id.
+ * `event_to_state_groups`: Maps events to state groups.
+ * `state_groups_state`: Maps state group to state events.
+ """
- @defer.inlineCallbacks
def get_state_groups(self, event_ids):
- groups = set()
- for event_id in event_ids:
- group = yield self._simple_select_one_onecol(
- table="event_to_state_groups",
- keyvalues={"event_id": event_id},
- retcol="state_group",
- allow_none=True,
- )
- if group:
- groups.add(group)
-
- res = {}
- for group in groups:
- state_ids = yield self._simple_select_onecol(
- table="state_groups_state",
- keyvalues={"state_group": group},
- retcol="event_id",
- )
- state = []
- for state_id in state_ids:
- s = yield self.get_event(
- state_id,
+ """ Get the state groups for the given list of event_ids
+
+ The return value is a dict mapping group names to lists of events.
+ """
+
+ def f(txn):
+ groups = set()
+ for event_id in event_ids:
+ group = self._simple_select_one_onecol_txn(
+ txn,
+ table="event_to_state_groups",
+ keyvalues={"event_id": event_id},
+ retcol="state_group",
allow_none=True,
)
- if s:
- state.append(s)
+ if group:
+ groups.add(group)
- res[group] = state
+ res = {}
+ for group in groups:
+ state_ids = self._simple_select_onecol_txn(
+ txn,
+ table="state_groups_state",
+ keyvalues={"state_group": group},
+ retcol="event_id",
+ )
+ state = []
+ for state_id in state_ids:
+ s = self._get_events_txn(
+ txn,
+ [state_id],
+ )
+ if s:
+ state.extend(s)
+
+ res[group] = state
- defer.returnValue(res)
+ return res
+
+ return self.runInteraction(
+ "get_state_groups",
+ f,
+ )
def store_state_groups(self, event):
return self.runInteraction(
|