summary refs log tree commit diff
path: root/synapse/storage/state.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/state.py')
-rw-r--r--synapse/storage/state.py128
1 files changed, 93 insertions, 35 deletions
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 0e8fa93e1f..ec551b0b4f 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -44,11 +44,7 @@ class StateStore(SQLBaseStore):
     """
 
     @defer.inlineCallbacks
-    def get_state_groups(self, room_id, event_ids):
-        """ Get the state groups for the given list of event_ids
-
-        The return value is a dict mapping group names to lists of events.
-        """
+    def get_state_groups_ids(self, room_id, event_ids):
         if not event_ids:
             defer.returnValue({})
 
@@ -59,36 +55,64 @@ class StateStore(SQLBaseStore):
         groups = set(event_to_groups.values())
         group_to_state = yield self._get_state_for_groups(groups)
 
+        defer.returnValue(group_to_state)
+
+    @defer.inlineCallbacks
+    def get_state_groups(self, room_id, event_ids):
+        """ Get the state groups for the given list of event_ids
+
+        The return value is a dict mapping group names to lists of events.
+        """
+        if not event_ids:
+            defer.returnValue({})
+
+        group_to_ids = yield self.get_state_groups_ids(room_id, event_ids)
+
+        state_event_map = yield self.get_events(
+            [
+                ev_id for group_ids in group_to_ids.values()
+                for ev_id in group_ids.values()
+            ],
+            get_prev_content=False
+        )
+
         defer.returnValue({
-            group: state_map.values()
-            for group, state_map in group_to_state.items()
+            group: [
+                state_event_map[v] for v in event_id_map.values() if v in state_event_map
+            ]
+            for group, event_id_map in group_to_ids.items()
         })
 
+    def _have_persisted_state_group_txn(self, txn, state_group):
+        txn.execute(
+            "SELECT count(*) FROM state_groups WHERE id = ?",
+            (state_group,)
+        )
+        row = txn.fetchone()
+        return row and row[0]
+
     def _store_mult_state_groups_txn(self, txn, events_and_contexts):
         state_groups = {}
         for event, context in events_and_contexts:
             if event.internal_metadata.is_outlier():
                 continue
 
-            if context.current_state is None:
-                continue
-
-            if context.state_group is not None:
-                state_groups[event.event_id] = context.state_group
+            if context.current_state_ids is None:
                 continue
 
-            state_events = dict(context.current_state)
+            state_groups[event.event_id] = context.state_group
 
-            if event.is_state():
-                state_events[(event.type, event.state_key)] = event
+            if self._have_persisted_state_group_txn(txn, context.state_group):
+                logger.info("Already persisted state_group: %r", context.state_group)
+                continue
 
-            state_group = context.new_state_group_id
+            state_event_ids = dict(context.current_state_ids)
 
             self._simple_insert_txn(
                 txn,
                 table="state_groups",
                 values={
-                    "id": state_group,
+                    "id": context.state_group,
                     "room_id": event.room_id,
                     "event_id": event.event_id,
                 },
@@ -99,16 +123,15 @@ class StateStore(SQLBaseStore):
                 table="state_groups_state",
                 values=[
                     {
-                        "state_group": state_group,
-                        "room_id": state.room_id,
-                        "type": state.type,
-                        "state_key": state.state_key,
-                        "event_id": state.event_id,
+                        "state_group": context.state_group,
+                        "room_id": event.room_id,
+                        "type": key[0],
+                        "state_key": key[1],
+                        "event_id": state_id,
                     }
-                    for state in state_events.values()
+                    for key, state_id in state_event_ids.items()
                 ],
             )
-            state_groups[event.event_id] = state_group
 
         self._simple_insert_many_txn(
             txn,
@@ -248,6 +271,31 @@ class StateStore(SQLBaseStore):
         groups = set(event_to_groups.values())
         group_to_state = yield self._get_state_for_groups(groups, types)
 
+        state_event_map = yield self.get_events(
+            [ev_id for sd in group_to_state.values() for ev_id in sd.values()],
+            get_prev_content=False
+        )
+
+        event_to_state = {
+            event_id: {
+                k: state_event_map[v]
+                for k, v in group_to_state[group].items()
+                if v in state_event_map
+            }
+            for event_id, group in event_to_groups.items()
+        }
+
+        defer.returnValue({event: event_to_state[event] for event in event_ids})
+
+    @defer.inlineCallbacks
+    def get_state_ids_for_events(self, event_ids, types):
+        event_to_groups = yield self._get_state_group_for_events(
+            event_ids,
+        )
+
+        groups = set(event_to_groups.values())
+        group_to_state = yield self._get_state_for_groups(groups, types)
+
         event_to_state = {
             event_id: group_to_state[group]
             for event_id, group in event_to_groups.items()
@@ -272,6 +320,23 @@ class StateStore(SQLBaseStore):
         state_map = yield self.get_state_for_events([event_id], types)
         defer.returnValue(state_map[event_id])
 
+    @defer.inlineCallbacks
+    def get_state_ids_for_event(self, event_id, types=None):
+        """
+        Get the state dict corresponding to a particular event
+
+        Args:
+            event_id(str): event whose state should be returned
+            types(list[(str, str)]|None): List of (type, state_key) tuples
+                which are used to filter the state fetched. May be None, which
+                matches any key
+
+        Returns:
+            A deferred dict from (type, state_key) -> state_event
+        """
+        state_map = yield self.get_state_ids_for_events([event_id], types)
+        defer.returnValue(state_map[event_id])
+
     @cached(num_args=2, max_entries=10000)
     def _get_state_group_for_event(self, room_id, event_id):
         return self._simple_select_one_onecol(
@@ -428,20 +493,13 @@ class StateStore(SQLBaseStore):
                     full=(types is None),
                 )
 
-        state_events = yield self._get_events(
-            [ev_id for sd in results.values() for ev_id in sd.values()],
-            get_prev_content=False
-        )
-
-        state_events = {e.event_id: e for e in state_events}
-
         # Remove all the entries with None values. The None values were just
         # used for bookkeeping in the cache.
         for group, state_dict in results.items():
             results[group] = {
-                key: state_events[event_id]
+                key: event_id
                 for key, event_id in state_dict.items()
-                if event_id and event_id in state_events
+                if event_id
             }
 
         defer.returnValue(results)
@@ -473,5 +531,5 @@ class StateStore(SQLBaseStore):
             "get_all_new_state_groups", get_all_new_state_groups_txn
         )
 
-    def get_state_stream_token(self):
-        return self._state_groups_id_gen.get_current_token()
+    def get_next_state_group(self):
+        return self._state_groups_id_gen.get_next()