summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/state.py132
1 files changed, 98 insertions, 34 deletions
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 89a05c4618..40ca8bd2a2 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -185,8 +185,21 @@ class StateGroupWorkerStore(SQLBaseStore):
         })
 
     @defer.inlineCallbacks
-    def _get_state_groups_from_groups(self, groups, types):
-        """Returns dictionary state_group -> (dict of (type, state_key) -> event id)
+    def _get_state_groups_from_groups(self, groups, types, filtered_types=None):
+        """Returns the state groups for a given set of groups, filtering on
+        types of state events.
+
+        Args:
+            groups(list[int]): list of state group IDs to query
+            types (Iterable[str, str|None]|None): list of 2-tuples of the form
+                (`type`, `state_key`), where a `state_key` of `None` matches all
+                state_keys for the `type`. If None, all types are returned.
+            filtered_types(Iterable[str]|None): Only apply filtering via `types` to this
+                list of event types.  Other types of events are returned unfiltered.
+                If None, `types` filtering is applied to all events.
+
+        Returns:
+            dictionary state_group -> (dict of (type, state_key) -> event id)
         """
         results = {}
 
@@ -194,14 +207,17 @@ class StateGroupWorkerStore(SQLBaseStore):
         for chunk in chunks:
             res = yield self.runInteraction(
                 "_get_state_groups_from_groups",
-                self._get_state_groups_from_groups_txn, chunk, types,
+                self._get_state_groups_from_groups_txn, chunk, types, filtered_types,
             )
             results.update(res)
 
         defer.returnValue(results)
 
-    def _get_state_groups_from_groups_txn(self, txn, groups, types=None):
+    def _get_state_groups_from_groups_txn(
+        self, txn, groups, types=None, filtered_types=None,
+    ):
         results = {group: {} for group in groups}
+
         if types is not None:
             types = list(set(types))  # deduplicate types list
 
@@ -250,6 +266,17 @@ class StateGroupWorkerStore(SQLBaseStore):
                     )
                     for etype, state_key in types
                 ]
+
+                if filtered_types is not None:
+                    # XXX: check whether this slows postgres down like a list of
+                    # ORs does too?
+                    unique_types = set(filtered_types)
+                    clause_to_args.append(
+                        (
+                            "AND type <> ? " * len(unique_types),
+                            list(unique_types)
+                        )
+                    )
             else:
                 # If types is None we fetch all the state, and so just use an
                 # empty where clause with no extra args.
@@ -278,6 +305,14 @@ class StateGroupWorkerStore(SQLBaseStore):
                     else:
                         where_clauses.append("(type = ? AND state_key = ?)")
                         where_args.extend([typ[0], typ[1]])
+
+                if filtered_types is not None:
+                    unique_types = set(filtered_types)
+                    where_clauses.append(
+                        "(" + " AND ".join(["type <> ?"] * len(unique_types)) + ")"
+                    )
+                    where_args.extend(list(unique_types))
+
                 where_clause = "AND (%s)" % (" OR ".join(where_clauses))
             else:
                 where_clause = ""
@@ -332,16 +367,20 @@ class StateGroupWorkerStore(SQLBaseStore):
         return results
 
     @defer.inlineCallbacks
-    def get_state_for_events(self, event_ids, types):
+    def get_state_for_events(self, event_ids, types, filtered_types=None):
         """Given a list of event_ids and type tuples, return a list of state
         dicts for each event. The state dicts will only have the type/state_keys
         that are in the `types` list.
 
         Args:
-            event_ids (list)
-            types (list): List of (type, state_key) tuples which are used to
-                filter the state fetched. `state_key` may be None, which matches
-                any `state_key`
+            event_ids (list[string])
+            types (list[(str, str|None)]|None): List of (type, state_key) tuples
+                which are used to filter the state fetched. If `state_key` is None,
+                all events are returned of the given type.
+                May be None, which matches any key.
+            filtered_types(list[str]|None): Only apply filtering via `types` to this
+                list of event types.  Other types of events are returned unfiltered.
+                If None, `types` filtering is applied to all events.
 
         Returns:
             deferred: A list of dicts corresponding to the event_ids given.
@@ -352,7 +391,7 @@ class StateGroupWorkerStore(SQLBaseStore):
         )
 
         groups = set(itervalues(event_to_groups))
-        group_to_state = yield self._get_state_for_groups(groups, types)
+        group_to_state = yield self._get_state_for_groups(groups, types, filtered_types)
 
         state_event_map = yield self.get_events(
             [ev_id for sd in itervalues(group_to_state) for ev_id in itervalues(sd)],
@@ -371,15 +410,19 @@ class StateGroupWorkerStore(SQLBaseStore):
         defer.returnValue({event: event_to_state[event] for event in event_ids})
 
     @defer.inlineCallbacks
-    def get_state_ids_for_events(self, event_ids, types=None):
+    def get_state_ids_for_events(self, event_ids, types=None, filtered_types=None):
         """
         Get the state dicts corresponding to a list of events
 
         Args:
             event_ids(list(str)): events whose state should be returned
-            types(list[(str, str)]|None): List of (type, state_key) tuples
-                which are used to filter the state fetched. May be None, which
-                matches any key
+            types(list[(str, str|None)]|None): List of (type, state_key) tuples
+                which are used to filter the state fetched. If `state_key` is None,
+                all events are returned of the given type.
+                May be None, which matches any key.
+            filtered_types(list[str]|None): Only apply filtering via `types` to this
+                list of event types.  Other types of events are returned unfiltered.
+                If None, `types` filtering is applied to all events.
 
         Returns:
             A deferred dict from event_id -> (type, state_key) -> state_event
@@ -389,7 +432,7 @@ class StateGroupWorkerStore(SQLBaseStore):
         )
 
         groups = set(itervalues(event_to_groups))
-        group_to_state = yield self._get_state_for_groups(groups, types)
+        group_to_state = yield self._get_state_for_groups(groups, types, filtered_types)
 
         event_to_state = {
             event_id: group_to_state[group]
@@ -399,37 +442,45 @@ class StateGroupWorkerStore(SQLBaseStore):
         defer.returnValue({event: event_to_state[event] for event in event_ids})
 
     @defer.inlineCallbacks
-    def get_state_for_event(self, event_id, types=None):
+    def get_state_for_event(self, event_id, types=None, filtered_types=None):
         """
         Get the state dict corresponding to a particular event
 
         Args:
             event_id(str): event whose state should be returned
-            types(list[(str, str)]|None): List of (type, state_key) tuples
-                which are used to filter the state fetched. May be None, which
-                matches any key
+            types(list[(str, str|None)]|None): List of (type, state_key) tuples
+                which are used to filter the state fetched. If `state_key` is None,
+                all events are returned of the given type.
+                May be None, which matches any key.
+            filtered_types(list[str]|None): Only apply filtering via `types` to this
+                list of event types.  Other types of events are returned unfiltered.
+                If None, `types` filtering is applied to all events.
 
         Returns:
             A deferred dict from (type, state_key) -> state_event
         """
-        state_map = yield self.get_state_for_events([event_id], types)
+        state_map = yield self.get_state_for_events([event_id], types, filtered_types)
         defer.returnValue(state_map[event_id])
 
     @defer.inlineCallbacks
-    def get_state_ids_for_event(self, event_id, types=None):
+    def get_state_ids_for_event(self, event_id, types=None, filtered_types=None):
         """
         Get the state dict corresponding to a particular event
 
         Args:
             event_id(str): event whose state should be returned
-            types(list[(str, str)]|None): List of (type, state_key) tuples
-                which are used to filter the state fetched. May be None, which
-                matches any key
+            types(list[(str, str|None)]|None): List of (type, state_key) tuples
+                which are used to filter the state fetched. If `state_key` is None,
+                all events are returned of the given type.
+                May be None, which matches any key.
+            filtered_types(list[str]|None): Only apply filtering via `types` to this
+                list of event types.  Other types of events are returned unfiltered.
+                If None, `types` filtering is applied to all events.
 
         Returns:
             A deferred dict from (type, state_key) -> state_event
         """
-        state_map = yield self.get_state_ids_for_events([event_id], types)
+        state_map = yield self.get_state_ids_for_events([event_id], types, filtered_types)
         defer.returnValue(state_map[event_id])
 
     @cached(max_entries=50000)
@@ -460,7 +511,7 @@ class StateGroupWorkerStore(SQLBaseStore):
 
         defer.returnValue({row["event_id"]: row["state_group"] for row in rows})
 
-    def _get_some_state_from_cache(self, group, types):
+    def _get_some_state_from_cache(self, group, types, filtered_types=None):
         """Checks if group is in cache. See `_get_state_for_groups`
 
         Returns 3-tuple (`state_dict`, `missing_types`, `got_all`).
@@ -470,20 +521,30 @@ class StateGroupWorkerStore(SQLBaseStore):
         missing state.
 
         Args:
-            group: The state group to lookup
-            types (list): List of 2-tuples of the form (`type`, `state_key`),
-                where a `state_key` of `None` matches all state_keys for the
-                `type`.
+            group(int): The state group to lookup
+            types(list[str, str|None]): List of 2-tuples of the form
+                (`type`, `state_key`), where a `state_key` of `None` matches all
+                state_keys for the `type`.
+            filtered_types(list[str]|None): Only apply filtering via `types` to this
+                list of event types.  Other types of events are returned unfiltered.
+                If None, `types` filtering is applied to all events.
         """
         is_all, known_absent, state_dict_ids = self._state_group_cache.get(group)
 
         type_to_key = {}
+
+        # tracks which of the requested types are missing from our cache
         missing_types = set()
 
         for typ, state_key in types:
             key = (typ, state_key)
+
             if state_key is None:
                 type_to_key[typ] = None
+                # we mark the type as missing from the cache because
+                # when the cache was populated it might have been done with a
+                # restricted set of state_keys, so the wildcard will not work
+                # and the cache may be incomplete.
                 missing_types.add(key)
             else:
                 if type_to_key.get(typ, object()) is not None:
@@ -497,7 +558,7 @@ class StateGroupWorkerStore(SQLBaseStore):
         def include(typ, state_key):
             valid_state_keys = type_to_key.get(typ, sentinel)
             if valid_state_keys is sentinel:
-                return False
+                return filtered_types is not None and typ not in filtered_types
             if valid_state_keys is None:
                 return True
             if state_key in valid_state_keys:
@@ -526,7 +587,7 @@ class StateGroupWorkerStore(SQLBaseStore):
         return state_dict_ids, is_all
 
     @defer.inlineCallbacks
-    def _get_state_for_groups(self, groups, types=None):
+    def _get_state_for_groups(self, groups, types=None, filtered_types=None):
         """Gets the state at each of a list of state groups, optionally
         filtering by type/state_key
 
@@ -540,6 +601,9 @@ class StateGroupWorkerStore(SQLBaseStore):
                 Otherwise, each entry should be a `(type, state_key)` tuple to
                 include in the response. A `state_key` of None is a wildcard
                 meaning that we require all state with that type.
+            filtered_types(list[str]|None): Only apply filtering via `types` to this
+                list of event types.  Other types of events are returned unfiltered.
+                If None, `types` filtering is applied to all events.
 
         Returns:
             Deferred[dict[int, dict[(type, state_key), EventBase]]]
@@ -552,7 +616,7 @@ class StateGroupWorkerStore(SQLBaseStore):
         if types is not None:
             for group in set(groups):
                 state_dict_ids, _, got_all = self._get_some_state_from_cache(
-                    group, types,
+                    group, types, filtered_types
                 )
                 results[group] = state_dict_ids
 
@@ -585,7 +649,7 @@ class StateGroupWorkerStore(SQLBaseStore):
                 types_to_fetch = types
 
             group_to_state_dict = yield self._get_state_groups_from_groups(
-                missing_groups, types_to_fetch,
+                missing_groups, types_to_fetch, filtered_types
             )
 
             for group, group_state_dict in iteritems(group_to_state_dict):