summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/state.py113
1 files changed, 60 insertions, 53 deletions
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index c5ff44fef7..ee531a2ce0 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -185,7 +185,7 @@ class StateGroupWorkerStore(SQLBaseStore):
         })
 
     @defer.inlineCallbacks
-    def _get_state_groups_from_groups(self, groups, types):
+    def _get_state_groups_from_groups(self, groups, types, filtered_types=None):
         """Returns the state groups for a given set of groups, filtering on
         types of state events.
 
@@ -193,9 +193,10 @@ class StateGroupWorkerStore(SQLBaseStore):
             groups(list[int]): list of state group IDs to query
             types(list[str|None, str|None])|None: List of 2-tuples of the form
                 (`type`, `state_key`), where a `state_key` of `None` matches all
-                state_keys for the `type`. Presence of type of `None` indicates
-                that types not in the list should not be filtered out. If None,
-                all types are returned.
+                state_keys for the `type`. If None, all types are returned.
+            filtered_types(list[str]|None): Only apply filtering via `types` to this
+                list of event types.  Other types of events are returned unfiltered.
+                If None, `types` filtering is applied to all events.
 
         Returns:
             dictionary state_group -> (dict of (type, state_key) -> event id)
@@ -206,26 +207,21 @@ class StateGroupWorkerStore(SQLBaseStore):
         for chunk in chunks:
             res = yield self.runInteraction(
                 "_get_state_groups_from_groups",
-                self._get_state_groups_from_groups_txn, chunk, types,
+                self._get_state_groups_from_groups_txn, chunk, types, filtered_types
             )
             results.update(res)
 
         defer.returnValue(results)
 
-    def _get_state_groups_from_groups_txn(self, txn, groups, types=None):
+    def _get_state_groups_from_groups_txn(
+        self, txn, groups, types=None, filtered_types=None
+    ):
         results = {group: {} for group in groups}
 
-        include_other_types = False
+        include_other_types = False if filtered_types is None else True
 
         if types is not None:
-            type_set = set(types)
-            if (None, None) in type_set:
-                # special case (None, None) to mean that other types should be
-                # returned - i.e. we were just filtering down the state keys
-                # for particular types.
-                include_other_types = True
-                type_set.remove((None, None))
-            types = list(type_set)  # deduplicate types list
+            types = list(set(types))  # deduplicate types list
 
         if isinstance(self.database_engine, PostgresEngine):
             # Temporarily disable sequential scans in this transaction. This is
@@ -276,7 +272,7 @@ class StateGroupWorkerStore(SQLBaseStore):
                 if include_other_types:
                     # XXX: check whether this slows postgres down like a list of
                     # ORs does too?
-                    unique_types = set([t for (t, _) in types])
+                    unique_types = set(filtered_types)
                     clause_to_args.append(
                         (
                             "AND type <> ? " * len(unique_types),
@@ -313,7 +309,7 @@ class StateGroupWorkerStore(SQLBaseStore):
                         where_args.extend([typ[0], typ[1]])
 
                 if include_other_types:
-                    unique_types = set([t for (t, _) in types])
+                    unique_types = set(filtered_types)
                     where_clauses.append(
                         "(" + " AND ".join(["type <> ?"] * len(unique_types)) + ")"
                     )
@@ -373,18 +369,20 @@ class StateGroupWorkerStore(SQLBaseStore):
         return results
 
     @defer.inlineCallbacks
-    def get_state_for_events(self, event_ids, types):
+    def get_state_for_events(self, event_ids, types, filtered_types):
         """Given a list of event_ids and type tuples, return a list of state
         dicts for each event. The state dicts will only have the type/state_keys
         that are in the `types` list.
 
         Args:
             event_ids (list[string])
-            types (list[(str|None, str|None)]|None): List of (type, state_key) tuples
+            types (list[(str, str|None)]|None): List of (type, state_key) tuples
                 which are used to filter the state fetched. If `state_key` is None,
-                all events are returned of the given type.  Presence of type of `None`
-                indicates that types not in the list should not be filtered out.
+                all events are returned of the given type.
                 May be None, which matches any key.
+            filtered_types(list[str]|None): Only apply filtering via `types` to this
+                list of event types.  Other types of events are returned unfiltered.
+                If None, `types` filtering is applied to all events.
 
         Returns:
             deferred: A list of dicts corresponding to the event_ids given.
@@ -395,7 +393,7 @@ class StateGroupWorkerStore(SQLBaseStore):
         )
 
         groups = set(itervalues(event_to_groups))
-        group_to_state = yield self._get_state_for_groups(groups, types)
+        group_to_state = yield self._get_state_for_groups(groups, types, filtered_types)
 
         state_event_map = yield self.get_events(
             [ev_id for sd in itervalues(group_to_state) for ev_id in itervalues(sd)],
@@ -414,17 +412,19 @@ class StateGroupWorkerStore(SQLBaseStore):
         defer.returnValue({event: event_to_state[event] for event in event_ids})
 
     @defer.inlineCallbacks
-    def get_state_ids_for_events(self, event_ids, types=None):
+    def get_state_ids_for_events(self, event_ids, types=None, filtered_types=None):
         """
         Get the state dicts corresponding to a list of events
 
         Args:
             event_ids(list(str)): events whose state should be returned
-            types(list[(str|None, str|None)]|None): List of (type, state_key) tuples
+            types(list[(str, str|None)]|None): List of (type, state_key) tuples
                 which are used to filter the state fetched. If `state_key` is None,
-                all events are returned of the given type.  Presence of type of `None`
-                indicates that types not in the list should not be filtered out.
+                all events are returned of the given type.
                 May be None, which matches any key.
+            filtered_types(list[str]|None): Only apply filtering via `types` to this
+                list of event types.  Other types of events are returned unfiltered.
+                If None, `types` filtering is applied to all events.
 
         Returns:
             A deferred dict from event_id -> (type, state_key) -> state_event
@@ -434,7 +434,7 @@ class StateGroupWorkerStore(SQLBaseStore):
         )
 
         groups = set(itervalues(event_to_groups))
-        group_to_state = yield self._get_state_for_groups(groups, types)
+        group_to_state = yield self._get_state_for_groups(groups, types, filtered_types)
 
         event_to_state = {
             event_id: group_to_state[group]
@@ -444,41 +444,45 @@ class StateGroupWorkerStore(SQLBaseStore):
         defer.returnValue({event: event_to_state[event] for event in event_ids})
 
     @defer.inlineCallbacks
-    def get_state_for_event(self, event_id, types=None):
+    def get_state_for_event(self, event_id, types=None, filtered_types=None):
         """
         Get the state dict corresponding to a particular event
 
         Args:
             event_id(str): event whose state should be returned
-            types(list[(str|None, str|None)]|None): List of (type, state_key) tuples
+            types(list[(str, str|None)]|None): List of (type, state_key) tuples
                 which are used to filter the state fetched. If `state_key` is None,
-                all events are returned of the given type.  Presence of type of `None`
-                indicates that types not in the list should not be filtered out.
+                all events are returned of the given type.
                 May be None, which matches any key.
+            filtered_types(list[str]|None): Only apply filtering via `types` to this
+                list of event types.  Other types of events are returned unfiltered.
+                If None, `types` filtering is applied to all events.
 
         Returns:
             A deferred dict from (type, state_key) -> state_event
         """
-        state_map = yield self.get_state_for_events([event_id], types)
+        state_map = yield self.get_state_for_events([event_id], types, filtered_types)
         defer.returnValue(state_map[event_id])
 
     @defer.inlineCallbacks
-    def get_state_ids_for_event(self, event_id, types=None):
+    def get_state_ids_for_event(self, event_id, types=None, filtered_types=None):
         """
         Get the state dict corresponding to a particular event
 
         Args:
             event_id(str): event whose state should be returned
-            types(list[(str|None, str|None)]|None): List of (type, state_key) tuples
+            types(list[(str, str|None)]|None): List of (type, state_key) tuples
                 which are used to filter the state fetched. If `state_key` is None,
-                all events are returned of the given type.  Presence of type of `None`
-                indicates that types not in the list should not be filtered out.
+                all events are returned of the given type.
                 May be None, which matches any key.
+            filtered_types(list[str]|None): Only apply filtering via `types` to this
+                list of event types.  Other types of events are returned unfiltered.
+                If None, `types` filtering is applied to all events.
 
         Returns:
             A deferred dict from (type, state_key) -> state_event
         """
-        state_map = yield self.get_state_ids_for_events([event_id], types)
+        state_map = yield self.get_state_ids_for_events([event_id], types, filtered_types)
         defer.returnValue(state_map[event_id])
 
     @cached(max_entries=50000)
@@ -509,7 +513,7 @@ class StateGroupWorkerStore(SQLBaseStore):
 
         defer.returnValue({row["event_id"]: row["state_group"] for row in rows})
 
-    def _get_some_state_from_cache(self, group, types):
+    def _get_some_state_from_cache(self, group, types, filtered_types=None):
         """Checks if group is in cache. See `_get_state_for_groups`
 
         Returns 3-tuple (`state_dict`, `missing_types`, `got_all`).
@@ -520,29 +524,30 @@ class StateGroupWorkerStore(SQLBaseStore):
 
         Args:
             group(int): The state group to lookup
-            types(list[str|None, str|None]): List of 2-tuples of the form
+            types(list[str, str|None]): List of 2-tuples of the form
                 (`type`, `state_key`), where a `state_key` of `None` matches all
-                state_keys for the `type`. Presence of type of `None` indicates
-                that types not in the list should not be filtered out.
+                state_keys for the `type`.
+            filtered_types(list[str]|None): Only apply filtering via `types` to this
+                list of event types.  Other types of events are returned unfiltered.
+                If None, `types` filtering is applied to all events.
         """
         is_all, known_absent, state_dict_ids = self._state_group_cache.get(group)
 
         type_to_key = {}
+
+        # tracks which of the requested types are missing from our cache
         missing_types = set()
 
-        include_other_types = False
+        include_other_types = True if filtered_types is None else False
 
         for typ, state_key in types:
             key = (typ, state_key)
 
-            if typ is None:
-                include_other_types = True
-                next
-
             if state_key is None:
                 type_to_key[typ] = None
                 # XXX: why do we mark the type as missing from our cache just
                 # because we weren't filtering on a specific value of state_key?
+                # is it because the cache doesn't handle wildcards?
                 missing_types.add(key)
             else:
                 if type_to_key.get(typ, object()) is not None:
@@ -556,7 +561,7 @@ class StateGroupWorkerStore(SQLBaseStore):
         def include(typ, state_key):
             valid_state_keys = type_to_key.get(typ, sentinel)
             if valid_state_keys is sentinel:
-                return include_other_types
+                return include_other_types and typ not in filtered_types
             if valid_state_keys is None:
                 return True
             if state_key in valid_state_keys:
@@ -585,21 +590,23 @@ class StateGroupWorkerStore(SQLBaseStore):
         return state_dict_ids, is_all
 
     @defer.inlineCallbacks
-    def _get_state_for_groups(self, groups, types=None):
+    def _get_state_for_groups(self, groups, types=None, filtered_types=None):
         """Gets the state at each of a list of state groups, optionally
         filtering by type/state_key
 
         Args:
             groups (iterable[int]): list of state groups for which we want
                 to get the state.
-            types (None|iterable[(None|str, None|str)]):
+            types (None|iterable[(None, None|str)]):
                 indicates the state type/keys required. If None, the whole
                 state is fetched and returned.
 
                 Otherwise, each entry should be a `(type, state_key)` tuple to
                 include in the response. A `state_key` of None is a wildcard
-                meaning that we require all state with that type. A `type` of None
-                indicates that types not in the list should not be filtered out.
+                meaning that we require all state with that type.
+            filtered_types(list[str]|None): Only apply filtering via `types` to this
+                list of event types.  Other types of events are returned unfiltered.
+                If None, `types` filtering is applied to all events.
 
         Returns:
             Deferred[dict[int, dict[(type, state_key), EventBase]]]
@@ -612,7 +619,7 @@ class StateGroupWorkerStore(SQLBaseStore):
         if types is not None:
             for group in set(groups):
                 state_dict_ids, _, got_all = self._get_some_state_from_cache(
-                    group, types,
+                    group, types, filtered_types
                 )
                 results[group] = state_dict_ids
 
@@ -645,7 +652,7 @@ class StateGroupWorkerStore(SQLBaseStore):
                 types_to_fetch = types
 
             group_to_state_dict = yield self._get_state_groups_from_groups(
-                missing_groups, types_to_fetch,
+                missing_groups, types_to_fetch, filtered_types
             )
 
             for group, group_state_dict in iteritems(group_to_state_dict):