diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 17b14d464b..dd03c4168b 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -116,6 +116,69 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
_get_current_state_ids_txn,
)
+ # FIXME: how should this be cached?
+ def get_filtered_current_state_ids(self, room_id, types, filtered_types=None):
+ """Get the current state event of a given type for a room based on the
+ current_state_events table. This may not be as up-to-date as the result
+ of doing a fresh state resolution as per state_handler.get_current_state
+ Args:
+ room_id (str)
+ types (list[(Str, (Str|None))]): List of (type, state_key) tuples
+ which are used to filter the state fetched. `state_key` may be
+ None, which matches any `state_key`
+ filtered_types (list[Str]|None): List of types to apply the above filter to.
+ Returns:
+ deferred: dict of (type, state_key) -> event
+ """
+
+ include_other_types = False if filtered_types is None else True
+
+ def _get_filtered_current_state_ids_txn(txn):
+ results = {}
+ sql = """SELECT type, state_key, event_id FROM current_state_events
+ WHERE room_id = ? %s"""
+ # Turns out that postgres doesn't like doing a list of OR's and
+ # is about 1000x slower, so we just issue a query for each specific
+ # type seperately.
+ if types:
+ clause_to_args = [
+ (
+ "AND type = ? AND state_key = ?",
+ (etype, state_key)
+ ) if state_key is not None else (
+ "AND type = ?",
+ (etype,)
+ )
+ for etype, state_key in types
+ ]
+
+ if include_other_types:
+ unique_types = set(filtered_types)
+ clause_to_args.append(
+ (
+ "AND type <> ? " * len(unique_types),
+ list(unique_types)
+ )
+ )
+ else:
+ # If types is None we fetch all the state, and so just use an
+ # empty where clause with no extra args.
+ clause_to_args = [("", [])]
+ for where_clause, where_args in clause_to_args:
+ args = [room_id]
+ args.extend(where_args)
+ txn.execute(sql % (where_clause,), args)
+ for row in txn:
+ typ, state_key, event_id = row
+ key = (intern_string(typ), intern_string(state_key))
+ results[key] = event_id
+ return results
+
+ return self.runInteraction(
+ "get_filtered_current_state_ids",
+ _get_filtered_current_state_ids_txn,
+ )
+
@cached(max_entries=10000, iterable=True)
def get_state_group_delta(self, state_group):
"""Given a state group try to return a previous group and a delta between
@@ -389,8 +452,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
If None, `types` filtering is applied to all events.
Returns:
- deferred: A list of dicts corresponding to the event_ids given.
- The dicts are mappings from (type, state_key) -> state_events
+ deferred: A dict of (event_id) -> (type, state_key) -> [state_events]
"""
event_to_groups = yield self._get_state_group_for_events(
event_ids,
@@ -418,7 +480,8 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
@defer.inlineCallbacks
def get_state_ids_for_events(self, event_ids, types=None, filtered_types=None):
"""
- Get the state dicts corresponding to a list of events
+ Get the state dicts corresponding to a list of events, containing the event_ids
+ of the state events (as opposed to the events themselves)
Args:
event_ids(list(str)): events whose state should be returned
@@ -431,7 +494,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
If None, `types` filtering is applied to all events.
Returns:
- A deferred dict from event_id -> (type, state_key) -> state_event
+ A deferred dict from event_id -> (type, state_key) -> event_id
"""
event_to_groups = yield self._get_state_group_for_events(
event_ids,
|