summary refs log tree commit diff
path: root/synapse/storage/state.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/state.py')
-rw-r--r--synapse/storage/state.py114
1 files changed, 72 insertions, 42 deletions
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index ffa4246031..89a05c4618 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -13,17 +13,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from collections import namedtuple
 import logging
+from collections import namedtuple
+
+from six import iteritems, itervalues
+from six.moves import range
 
 from twisted.internet import defer
 
 from synapse.storage.background_updates import BackgroundUpdateStore
 from synapse.storage.engines import PostgresEngine
-from synapse.util.caches import intern_string, CACHE_SIZE_FACTOR
+from synapse.util.caches import get_cache_factor_for, intern_string
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.caches.dictionary_cache import DictionaryCache
 from synapse.util.stringutils import to_ascii
+
 from ._base import SQLBaseStore
 
 logger = logging.getLogger(__name__)
@@ -54,7 +58,7 @@ class StateGroupWorkerStore(SQLBaseStore):
         super(StateGroupWorkerStore, self).__init__(db_conn, hs)
 
         self._state_group_cache = DictionaryCache(
-            "*stateGroupCache*", 100000 * CACHE_SIZE_FACTOR
+            "*stateGroupCache*", 500000 * get_cache_factor_for("stateGroupCache")
         )
 
     @cached(max_entries=100000, iterable=True)
@@ -134,7 +138,7 @@ class StateGroupWorkerStore(SQLBaseStore):
             event_ids,
         )
 
-        groups = set(event_to_groups.itervalues())
+        groups = set(itervalues(event_to_groups))
         group_to_state = yield self._get_state_for_groups(groups)
 
         defer.returnValue(group_to_state)
@@ -166,18 +170,18 @@ class StateGroupWorkerStore(SQLBaseStore):
 
         state_event_map = yield self.get_events(
             [
-                ev_id for group_ids in group_to_ids.itervalues()
-                for ev_id in group_ids.itervalues()
+                ev_id for group_ids in itervalues(group_to_ids)
+                for ev_id in itervalues(group_ids)
             ],
             get_prev_content=False
         )
 
         defer.returnValue({
             group: [
-                state_event_map[v] for v in event_id_map.itervalues()
+                state_event_map[v] for v in itervalues(event_id_map)
                 if v in state_event_map
             ]
-            for group, event_id_map in group_to_ids.iteritems()
+            for group, event_id_map in iteritems(group_to_ids)
         })
 
     @defer.inlineCallbacks
@@ -186,7 +190,7 @@ class StateGroupWorkerStore(SQLBaseStore):
         """
         results = {}
 
-        chunks = [groups[i:i + 100] for i in xrange(0, len(groups), 100)]
+        chunks = [groups[i:i + 100] for i in range(0, len(groups), 100)]
         for chunk in chunks:
             res = yield self.runInteraction(
                 "_get_state_groups_from_groups",
@@ -269,7 +273,7 @@ class StateGroupWorkerStore(SQLBaseStore):
                 for typ in types:
                     if typ[1] is None:
                         where_clauses.append("(type = ?)")
-                        where_args.extend(typ[0])
+                        where_args.append(typ[0])
                         wildcard_types = True
                     else:
                         where_clauses.append("(type = ? AND state_key = ?)")
@@ -347,21 +351,21 @@ class StateGroupWorkerStore(SQLBaseStore):
             event_ids,
         )
 
-        groups = set(event_to_groups.itervalues())
+        groups = set(itervalues(event_to_groups))
         group_to_state = yield self._get_state_for_groups(groups, types)
 
         state_event_map = yield self.get_events(
-            [ev_id for sd in group_to_state.itervalues() for ev_id in sd.itervalues()],
+            [ev_id for sd in itervalues(group_to_state) for ev_id in itervalues(sd)],
             get_prev_content=False
         )
 
         event_to_state = {
             event_id: {
                 k: state_event_map[v]
-                for k, v in group_to_state[group].iteritems()
+                for k, v in iteritems(group_to_state[group])
                 if v in state_event_map
             }
-            for event_id, group in event_to_groups.iteritems()
+            for event_id, group in iteritems(event_to_groups)
         }
 
         defer.returnValue({event: event_to_state[event] for event in event_ids})
@@ -384,12 +388,12 @@ class StateGroupWorkerStore(SQLBaseStore):
             event_ids,
         )
 
-        groups = set(event_to_groups.itervalues())
+        groups = set(itervalues(event_to_groups))
         group_to_state = yield self._get_state_for_groups(groups, types)
 
         event_to_state = {
             event_id: group_to_state[group]
-            for event_id, group in event_to_groups.iteritems()
+            for event_id, group in iteritems(event_to_groups)
         }
 
         defer.returnValue({event: event_to_state[event] for event in event_ids})
@@ -503,7 +507,7 @@ class StateGroupWorkerStore(SQLBaseStore):
         got_all = is_all or not missing_types
 
         return {
-            k: v for k, v in state_dict_ids.iteritems()
+            k: v for k, v in iteritems(state_dict_ids)
             if include(k[0], k[1])
         }, missing_types, got_all
 
@@ -523,10 +527,23 @@ class StateGroupWorkerStore(SQLBaseStore):
 
     @defer.inlineCallbacks
     def _get_state_for_groups(self, groups, types=None):
-        """Given list of groups returns dict of group -> list of state events
-        with matching types. `types` is a list of `(type, state_key)`, where
-        a `state_key` of None matches all state_keys. If `types` is None then
-        all events are returned.
+        """Gets the state at each of a list of state groups, optionally
+        filtering by type/state_key
+
+        Args:
+            groups (iterable[int]): list of state groups for which we want
+                to get the state.
+            types (None|iterable[(str, None|str)]):
+                indicates the state type/keys required. If None, the whole
+                state is fetched and returned.
+
+                Otherwise, each entry should be a `(type, state_key)` tuple to
+                include in the response. A `state_key` of None is a wildcard
+                meaning that we require all state with that type.
+
+        Returns:
+            Deferred[dict[int, dict[(type, state_key), EventBase]]]
+                a dictionary mapping from state group to state dictionary.
         """
         if types:
             types = frozenset(types)
@@ -535,7 +552,7 @@ class StateGroupWorkerStore(SQLBaseStore):
         if types is not None:
             for group in set(groups):
                 state_dict_ids, _, got_all = self._get_some_state_from_cache(
-                    group, types
+                    group, types,
                 )
                 results[group] = state_dict_ids
 
@@ -556,26 +573,40 @@ class StateGroupWorkerStore(SQLBaseStore):
             # Okay, so we have some missing_types, lets fetch them.
             cache_seq_num = self._state_group_cache.sequence
 
+            # the DictionaryCache knows if it has *all* the state, but
+            # does not know if it has all of the keys of a particular type,
+            # which makes wildcard lookups expensive unless we have a complete
+            # cache. Hence, if we are doing a wildcard lookup, populate the
+            # cache fully so that we can do an efficient lookup next time.
+
+            if types and any(k is None for (t, k) in types):
+                types_to_fetch = None
+            else:
+                types_to_fetch = types
+
             group_to_state_dict = yield self._get_state_groups_from_groups(
-                missing_groups, types
+                missing_groups, types_to_fetch,
             )
 
-            # Now we want to update the cache with all the things we fetched
-            # from the database.
-            for group, group_state_dict in group_to_state_dict.iteritems():
+            for group, group_state_dict in iteritems(group_to_state_dict):
                 state_dict = results[group]
 
-                state_dict.update(
-                    ((intern_string(k[0]), intern_string(k[1])), to_ascii(v))
-                    for k, v in group_state_dict.iteritems()
-                )
-
+                # update the result, filtering by `types`.
+                if types:
+                    for k, v in iteritems(group_state_dict):
+                        (typ, _) = k
+                        if k in types or (typ, None) in types:
+                            state_dict[k] = v
+                else:
+                    state_dict.update(group_state_dict)
+
+                # update the cache with all the things we fetched from the
+                # database.
                 self._state_group_cache.update(
                     cache_seq_num,
                     key=group,
-                    value=state_dict,
-                    full=(types is None),
-                    known_absent=types,
+                    value=group_state_dict,
+                    fetched_keys=types_to_fetch,
                 )
 
         defer.returnValue(results)
@@ -654,7 +685,7 @@ class StateGroupWorkerStore(SQLBaseStore):
                             "state_key": key[1],
                             "event_id": state_id,
                         }
-                        for key, state_id in delta_ids.iteritems()
+                        for key, state_id in iteritems(delta_ids)
                     ],
                 )
             else:
@@ -669,7 +700,7 @@ class StateGroupWorkerStore(SQLBaseStore):
                             "state_key": key[1],
                             "event_id": state_id,
                         }
-                        for key, state_id in current_state_ids.iteritems()
+                        for key, state_id in iteritems(current_state_ids)
                     ],
                 )
 
@@ -682,7 +713,6 @@ class StateGroupWorkerStore(SQLBaseStore):
                 self._state_group_cache.sequence,
                 key=state_group,
                 value=dict(current_state_ids),
-                full=True,
             )
 
             return state_group
@@ -794,11 +824,11 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
                     "state_group": state_group_id,
                     "event_id": event_id,
                 }
-                for event_id, state_group_id in state_groups.iteritems()
+                for event_id, state_group_id in iteritems(state_groups)
             ],
         )
 
-        for event_id, state_group_id in state_groups.iteritems():
+        for event_id, state_group_id in iteritems(state_groups):
             txn.call_after(
                 self._get_state_group_for_event.prefill,
                 (event_id,), state_group_id
@@ -826,7 +856,7 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
 
         def reindex_txn(txn):
             new_last_state_group = last_state_group
-            for count in xrange(batch_size):
+            for count in range(batch_size):
                 txn.execute(
                     "SELECT id, room_id FROM state_groups"
                     " WHERE ? < id AND id <= ?"
@@ -884,7 +914,7 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
                         # of keys
 
                         delta_state = {
-                            key: value for key, value in curr_state.iteritems()
+                            key: value for key, value in iteritems(curr_state)
                             if prev_state.get(key, None) != value
                         }
 
@@ -924,7 +954,7 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
                                     "state_key": key[1],
                                     "event_id": state_id,
                                 }
-                                for key, state_id in delta_state.iteritems()
+                                for key, state_id in iteritems(delta_state)
                             ],
                         )