summary refs log tree commit diff
path: root/synapse/storage/state.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/state.py')
-rw-r--r--synapse/storage/state.py71
1 files changed, 67 insertions, 4 deletions
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 17b14d464b..dd03c4168b 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -116,6 +116,69 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
             _get_current_state_ids_txn,
         )
 
+    # FIXME: how should this be cached?
+    def get_filtered_current_state_ids(self, room_id, types, filtered_types=None):
+        """Get the current state event of a given type for a room based on the
+        current_state_events table.  This may not be as up-to-date as the result
+        of doing a fresh state resolution as per state_handler.get_current_state
+        Args:
+            room_id (str)
+            types (list[(Str, (Str|None))]): List of (type, state_key) tuples
+                which are used to filter the state fetched. `state_key` may be
+                None, which matches any `state_key`
+            filtered_types (list[Str]|None): List of types to apply the above filter to.
+        Returns:
+            deferred: dict of (type, state_key) -> event
+        """
+
+        include_other_types = False if filtered_types is None else True
+
+        def _get_filtered_current_state_ids_txn(txn):
+            results = {}
+            sql = """SELECT type, state_key, event_id FROM current_state_events
+                     WHERE room_id = ? %s"""
+            # Turns out that postgres doesn't like doing a list of OR's and
+            # is about 1000x slower, so we just issue a query for each specific
+            # type seperately.
+            if types:
+                clause_to_args = [
+                    (
+                        "AND type = ? AND state_key = ?",
+                        (etype, state_key)
+                    ) if state_key is not None else (
+                        "AND type = ?",
+                        (etype,)
+                    )
+                    for etype, state_key in types
+                ]
+
+                if include_other_types:
+                    unique_types = set(filtered_types)
+                    clause_to_args.append(
+                        (
+                            "AND type <> ? " * len(unique_types),
+                            list(unique_types)
+                        )
+                    )
+            else:
+                # If types is None we fetch all the state, and so just use an
+                # empty where clause with no extra args.
+                clause_to_args = [("", [])]
+            for where_clause, where_args in clause_to_args:
+                args = [room_id]
+                args.extend(where_args)
+                txn.execute(sql % (where_clause,), args)
+                for row in txn:
+                    typ, state_key, event_id = row
+                    key = (intern_string(typ), intern_string(state_key))
+                    results[key] = event_id
+            return results
+
+        return self.runInteraction(
+            "get_filtered_current_state_ids",
+            _get_filtered_current_state_ids_txn,
+        )
+
     @cached(max_entries=10000, iterable=True)
     def get_state_group_delta(self, state_group):
         """Given a state group try to return a previous group and a delta between
@@ -389,8 +452,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
                 If None, `types` filtering is applied to all events.
 
         Returns:
-            deferred: A list of dicts corresponding to the event_ids given.
-            The dicts are mappings from (type, state_key) -> state_events
+            deferred: A dict of (event_id) -> (type, state_key) -> [state_events]
         """
         event_to_groups = yield self._get_state_group_for_events(
             event_ids,
@@ -418,7 +480,8 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
     @defer.inlineCallbacks
     def get_state_ids_for_events(self, event_ids, types=None, filtered_types=None):
         """
-        Get the state dicts corresponding to a list of events
+        Get the state dicts corresponding to a list of events, containing the event_ids
+        of the state events (as opposed to the events themselves)
 
         Args:
             event_ids(list(str)): events whose state should be returned
@@ -431,7 +494,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
                 If None, `types` filtering is applied to all events.
 
         Returns:
-            A deferred dict from event_id -> (type, state_key) -> state_event
+            A deferred dict from event_id -> (type, state_key) -> event_id
         """
         event_to_groups = yield self._get_state_group_for_events(
             event_ids,