summary refs log tree commit diff
path: root/synapse/storage/state.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/state.py')
-rw-r--r--synapse/storage/state.py62
1 files changed, 32 insertions, 30 deletions
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 0bfe1b4550..1980a87108 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -422,7 +422,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
 
         # Retrieve the room's create event
         create_event = yield self.get_create_event_for_room(room_id)
-        defer.returnValue(create_event.content.get("room_version", "1"))
+        return create_event.content.get("room_version", "1")
 
     @defer.inlineCallbacks
     def get_room_predecessor(self, room_id):
@@ -442,7 +442,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         create_event = yield self.get_create_event_for_room(room_id)
 
         # Return predecessor if present
-        defer.returnValue(create_event.content.get("predecessor", None))
+        return create_event.content.get("predecessor", None)
 
     @defer.inlineCallbacks
     def get_create_event_for_room(self, room_id):
@@ -466,7 +466,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
 
         # Retrieve the room's create event and return
         create_event = yield self.get_event(create_id)
-        defer.returnValue(create_event)
+        return create_event
 
     @cached(max_entries=100000, iterable=True)
     def get_current_state_ids(self, room_id):
@@ -510,6 +510,12 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
             event ID.
         """
 
+        where_clause, where_args = state_filter.make_sql_filter_clause()
+
+        if not where_clause:
+            # We delegate to the cached version
+            return self.get_current_state_ids(room_id)
+
         def _get_filtered_current_state_ids_txn(txn):
             results = {}
             sql = """
@@ -517,8 +523,6 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
                 WHERE room_id = ?
             """
 
-            where_clause, where_args = state_filter.make_sql_filter_clause()
-
             if where_clause:
                 sql += " AND (%s)" % (where_clause,)
 
@@ -559,7 +563,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         if not event:
             return
 
-        defer.returnValue(event.content.get("canonical_alias"))
+        return event.content.get("canonical_alias")
 
     @cached(max_entries=10000, iterable=True)
     def get_state_group_delta(self, state_group):
@@ -609,14 +613,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
                 dict of state_group_id -> (dict of (type, state_key) -> event id)
         """
         if not event_ids:
-            defer.returnValue({})
+            return {}
 
         event_to_groups = yield self._get_state_group_for_events(event_ids)
 
         groups = set(itervalues(event_to_groups))
         group_to_state = yield self._get_state_for_groups(groups)
 
-        defer.returnValue(group_to_state)
+        return group_to_state
 
     @defer.inlineCallbacks
     def get_state_ids_for_group(self, state_group):
@@ -630,7 +634,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         """
         group_to_state = yield self._get_state_for_groups((state_group,))
 
-        defer.returnValue(group_to_state[state_group])
+        return group_to_state[state_group]
 
     @defer.inlineCallbacks
     def get_state_groups(self, room_id, event_ids):
@@ -641,7 +645,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
                 dict of state_group_id -> list of state events.
         """
         if not event_ids:
-            defer.returnValue({})
+            return {}
 
         group_to_ids = yield self.get_state_groups_ids(room_id, event_ids)
 
@@ -654,16 +658,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
             get_prev_content=False,
         )
 
-        defer.returnValue(
-            {
-                group: [
-                    state_event_map[v]
-                    for v in itervalues(event_id_map)
-                    if v in state_event_map
-                ]
-                for group, event_id_map in iteritems(group_to_ids)
-            }
-        )
+        return {
+            group: [
+                state_event_map[v]
+                for v in itervalues(event_id_map)
+                if v in state_event_map
+            ]
+            for group, event_id_map in iteritems(group_to_ids)
+        }
 
     @defer.inlineCallbacks
     def _get_state_groups_from_groups(self, groups, state_filter):
@@ -690,7 +692,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
             )
             results.update(res)
 
-        defer.returnValue(results)
+        return results
 
     def _get_state_groups_from_groups_txn(
         self, txn, groups, state_filter=StateFilter.all()
@@ -825,7 +827,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
             for event_id, group in iteritems(event_to_groups)
         }
 
-        defer.returnValue({event: event_to_state[event] for event in event_ids})
+        return {event: event_to_state[event] for event in event_ids}
 
     @defer.inlineCallbacks
     def get_state_ids_for_events(self, event_ids, state_filter=StateFilter.all()):
@@ -851,7 +853,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
             for event_id, group in iteritems(event_to_groups)
         }
 
-        defer.returnValue({event: event_to_state[event] for event in event_ids})
+        return {event: event_to_state[event] for event in event_ids}
 
     @defer.inlineCallbacks
     def get_state_for_event(self, event_id, state_filter=StateFilter.all()):
@@ -867,7 +869,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
             A deferred dict from (type, state_key) -> state_event
         """
         state_map = yield self.get_state_for_events([event_id], state_filter)
-        defer.returnValue(state_map[event_id])
+        return state_map[event_id]
 
     @defer.inlineCallbacks
     def get_state_ids_for_event(self, event_id, state_filter=StateFilter.all()):
@@ -883,7 +885,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
             A deferred dict from (type, state_key) -> state_event
         """
         state_map = yield self.get_state_ids_for_events([event_id], state_filter)
-        defer.returnValue(state_map[event_id])
+        return state_map[event_id]
 
     @cached(max_entries=50000)
     def _get_state_group_for_event(self, event_id):
@@ -913,7 +915,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
             desc="_get_state_group_for_events",
         )
 
-        defer.returnValue({row["event_id"]: row["state_group"] for row in rows})
+        return {row["event_id"]: row["state_group"] for row in rows}
 
     def _get_state_for_group_using_cache(self, cache, group, state_filter):
         """Checks if group is in cache. See `_get_state_for_groups`
@@ -993,7 +995,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         incomplete_groups = incomplete_groups_m | incomplete_groups_nm
 
         if not incomplete_groups:
-            defer.returnValue(state)
+            return state
 
         cache_sequence_nm = self._state_group_cache.sequence
         cache_sequence_m = self._state_group_members_cache.sequence
@@ -1020,7 +1022,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
             # everything we need from the database anyway.
             state[group] = state_filter.filter_state(group_state_dict)
 
-        defer.returnValue(state)
+        return state
 
     def _get_state_for_groups_using_cache(self, groups, cache, state_filter):
         """Gets the state at each of a list of state groups, optionally
@@ -1494,7 +1496,7 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
                 self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME
             )
 
-        defer.returnValue(result * BATCH_SIZE_SCALE_FACTOR)
+        return result * BATCH_SIZE_SCALE_FACTOR
 
     @defer.inlineCallbacks
     def _background_index_state(self, progress, batch_size):
@@ -1524,4 +1526,4 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
 
         yield self._end_background_update(self.STATE_GROUP_INDEX_UPDATE_NAME)
 
-        defer.returnValue(1)
+        return 1