summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/storage/state.py219
1 files changed, 142 insertions, 77 deletions
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 5588c9e697..a04731ae11 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -13,11 +13,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ._base import SQLBaseStore, cached, cachedInlineCallbacks
+from ._base import SQLBaseStore, cached, cachedInlineCallbacks, cachedList
 
 from twisted.internet import defer
 
-from synapse.util import unwrapFirstError
 from synapse.util.stringutils import random_string
 
 import logging
@@ -50,32 +49,20 @@ class StateStore(SQLBaseStore):
 
         The return value is a dict mapping group names to lists of events.
         """
+        if not event_ids:
+            defer.returnValue({})
 
-        event_and_groups = yield defer.gatherResults(
-            [
-                self._get_state_group_for_event(
-                    room_id, event_id,
-                ).addCallback(lambda group, event_id: (event_id, group), event_id)
-                for event_id in event_ids
-            ],
-            consumeErrors=True,
-        ).addErrback(unwrapFirstError)
+        event_to_groups = yield self._get_state_group_for_events(
+            room_id, event_ids,
+        )
 
-        groups = set(group for _, group in event_and_groups if group)
+        groups = set(event_to_groups.values())
 
-        group_to_state = yield defer.gatherResults(
-            [
-                self._get_state_for_group(
-                    group,
-                ).addCallback(lambda state_dict, group: (group, state_dict), group)
-                for group in groups
-            ],
-            consumeErrors=True,
-        ).addErrback(unwrapFirstError)
+        group_to_state = yield self._get_state_for_groups(groups)
 
         defer.returnValue({
             group: state_map.values()
-            for group, state_map in group_to_state
+            for group, state_map in group_to_state.items()
         })
 
     @cached(num_args=1)
@@ -212,17 +199,48 @@ class StateStore(SQLBaseStore):
 
             txn.execute(sql, args)
 
-            return group, [
-                r[0]
-                for r in txn.fetchall()
-            ]
+            return [r[0] for r in txn.fetchall()]
 
         return self.runInteraction(
             "_get_state_groups_from_group",
             f,
         )
 
-    @cached(num_args=3, lru=True, max_entries=20000)
+    def _get_state_groups_from_groups(self, groups_and_types):
+        def f(txn):
+            results = {}
+            for group, types in groups_and_types:
+                if types is not None:
+                    where_clause = "AND (%s)" % (
+                        " OR ".join(["(type = ? AND state_key = ?)"] * len(types)),
+                    )
+                else:
+                    where_clause = ""
+
+                sql = (
+                    "SELECT event_id FROM state_groups_state WHERE"
+                    " state_group = ? %s"
+                ) % (where_clause,)
+
+                args = [group]
+                if types is not None:
+                    args.extend([i for typ in types for i in typ])
+
+                txn.execute(sql, args)
+
+                results[group] = [
+                    r[0]
+                    for r in txn.fetchall()
+                ]
+
+            return results
+
+        return self.runInteraction(
+            "_get_state_groups_from_groups",
+            f,
+        )
+
+    @cached(num_args=3, lru=True, max_entries=10000)
     def _get_state_for_event_id(self, room_id, event_id, types):
         def f(txn):
             type_and_state_sql = " OR ".join([
@@ -274,33 +292,19 @@ class StateStore(SQLBaseStore):
             deferred: A list of dicts corresponding to the event_ids given.
             The dicts are mappings from (type, state_key) -> state_events
         """
-        event_and_groups = yield defer.gatherResults(
-            [
-                self._get_state_group_for_event(
-                    room_id, event_id,
-                ).addCallback(lambda group, event_id: (event_id, group), event_id)
-                for event_id in event_ids
-            ],
-            consumeErrors=True,
-        ).addErrback(unwrapFirstError)
-
-        groups = set(group for _, group in event_and_groups)
+        event_to_groups = yield self._get_state_group_for_events(
+            room_id, event_ids,
+        )
 
-        res = yield defer.gatherResults(
-            [
-                self._get_state_for_group(
-                    group, types
-                ).addCallback(lambda state_dict, group: (group, state_dict), group)
-                for group in groups
-            ],
-            consumeErrors=True,
-        ).addErrback(unwrapFirstError)
+        groups = set(event_to_groups.values())
 
-        group_to_state = dict(res)
+        group_to_state = yield self._get_state_for_groups(
+            groups, types
+        )
 
         event_to_state = {
             event_id: group_to_state[group]
-            for event_id, group in event_and_groups
+            for event_id, group in event_to_groups.items()
         }
 
         defer.returnValue([
@@ -320,8 +324,29 @@ class StateStore(SQLBaseStore):
             desc="_get_state_group_for_event",
         )
 
-    @defer.inlineCallbacks
-    def _get_state_for_group(self, group, types=None):
+    @cachedList(cache=_get_state_group_for_event.cache, list_name="event_ids", num_args=2)
+    def _get_state_group_for_events(self, room_id, event_ids):
+        def f(txn):
+            results = {}
+            for event_id in event_ids:
+                results[event_id] = self._simple_select_one_onecol_txn(
+                    txn,
+                    table="event_to_state_groups",
+                    keyvalues={
+                        "event_id": event_id,
+                    },
+                    retcol="state_group",
+                    allow_none=True,
+                )
+
+            return results
+
+        return self.runInteraction(
+            "_get_state_group_for_events",
+            f,
+        )
+
+    def _get_state_for_group_from_cache(self, group, types=None):
         is_all, state_dict = self._state_group_cache.get(group)
 
         type_to_key = {}
@@ -339,7 +364,7 @@ class StateStore(SQLBaseStore):
                         missing_types.add((typ, state_key))
 
         if is_all and types is None:
-            defer.returnValue(state_dict)
+            return state_dict, missing_types
 
         if is_all or (types is not None and not missing_types):
             sentinel = object()
@@ -354,41 +379,81 @@ class StateStore(SQLBaseStore):
                     return True
                 return False
 
-            defer.returnValue({
+            return {
                 k: v
                 for k, v in state_dict.items()
                 if v and include(k[0], k[1])
-            })
+            }, missing_types
+
+        return {}, missing_types
+
+    @defer.inlineCallbacks
+    def _get_state_for_groups(self, groups, types=None):
+        results = {}
+        missing_groups_and_types = []
+        for group in groups:
+            state_dict, missing_types = self._get_state_for_group_from_cache(
+                group, types
+            )
+
+            if types is not None and not missing_types:
+                results[group] = {
+                    key: value
+                    for key, value in state_dict.items()
+                    if value
+                }
+            else:
+                missing_groups_and_types.append((
+                    group,
+                    missing_types if types else None
+                ))
+
+        if not missing_groups_and_types:
+            defer.returnValue(results)
 
         # Okay, so we have some missing_types, lets fetch them.
         cache_seq_num = self._state_group_cache.sequence
-        _, state_ids = yield self._get_state_groups_from_group(
-            group,
-            frozenset(missing_types) if types else None
+
+        group_state_dict = yield self._get_state_groups_from_groups(
+            missing_groups_and_types
         )
-        state_events = yield self._get_events(state_ids, get_prev_content=False)
-        state_dict = {
-            key: None
-            for key in missing_types
-        }
-        state_dict.update({
-            (e.type, e.state_key): e
-            for e in state_events
-        })
 
-        # Update the cache
-        self._state_group_cache.update(
-            cache_seq_num,
-            key=group,
-            value=state_dict,
-            full=(types is None),
+        state_events = yield self._get_events(
+            [e_id for l in group_state_dict.values() for e_id in l],
+            get_prev_content=False
         )
 
-        defer.returnValue({
-            key: value
-            for key, value in state_dict.items()
-            if value
-        })
+        state_events = {
+            e.event_id: e
+            for e in state_events
+        }
+
+        for group, state_ids in group_state_dict.items():
+            state_dict = {
+                key: None
+                for key in missing_types
+            }
+            evs = [state_events[e_id] for e_id in state_ids]
+            state_dict.update({
+                (e.type, e.state_key): e
+                for e in evs
+            })
+
+            # Update the cache
+            self._state_group_cache.update(
+                cache_seq_num,
+                key=group,
+                value=state_dict,
+                full=(types is None),
+            )
+
+            results[group] = {
+                key: value
+                for key, value in state_dict.items()
+                if value
+            }
+
+        defer.returnValue(results)
 
 
 def _make_group_id(clock):