diff options
Diffstat (limited to 'synapse/storage/state.py')
-rw-r--r-- | synapse/storage/state.py | 62 |
1 files changed, 32 insertions, 30 deletions
diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 0bfe1b4550..1980a87108 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -422,7 +422,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): # Retrieve the room's create event create_event = yield self.get_create_event_for_room(room_id) - defer.returnValue(create_event.content.get("room_version", "1")) + return create_event.content.get("room_version", "1") @defer.inlineCallbacks def get_room_predecessor(self, room_id): @@ -442,7 +442,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): create_event = yield self.get_create_event_for_room(room_id) # Return predecessor if present - defer.returnValue(create_event.content.get("predecessor", None)) + return create_event.content.get("predecessor", None) @defer.inlineCallbacks def get_create_event_for_room(self, room_id): @@ -466,7 +466,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): # Retrieve the room's create event and return create_event = yield self.get_event(create_id) - defer.returnValue(create_event) + return create_event @cached(max_entries=100000, iterable=True) def get_current_state_ids(self, room_id): @@ -510,6 +510,12 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): event ID. """ + where_clause, where_args = state_filter.make_sql_filter_clause() + + if not where_clause: + # We delegate to the cached version + return self.get_current_state_ids(room_id) + def _get_filtered_current_state_ids_txn(txn): results = {} sql = """ @@ -517,8 +523,6 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): WHERE room_id = ? """ - where_clause, where_args = state_filter.make_sql_filter_clause() - if where_clause: sql += " AND (%s)" % (where_clause,) @@ -559,7 +563,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): if not event: return - defer.returnValue(event.content.get("canonical_alias")) + return event.content.get("canonical_alias") @cached(max_entries=10000, iterable=True) def get_state_group_delta(self, state_group): @@ -609,14 +613,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): dict of state_group_id -> (dict of (type, state_key) -> event id) """ if not event_ids: - defer.returnValue({}) + return {} event_to_groups = yield self._get_state_group_for_events(event_ids) groups = set(itervalues(event_to_groups)) group_to_state = yield self._get_state_for_groups(groups) - defer.returnValue(group_to_state) + return group_to_state @defer.inlineCallbacks def get_state_ids_for_group(self, state_group): @@ -630,7 +634,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): """ group_to_state = yield self._get_state_for_groups((state_group,)) - defer.returnValue(group_to_state[state_group]) + return group_to_state[state_group] @defer.inlineCallbacks def get_state_groups(self, room_id, event_ids): @@ -641,7 +645,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): dict of state_group_id -> list of state events. """ if not event_ids: - defer.returnValue({}) + return {} group_to_ids = yield self.get_state_groups_ids(room_id, event_ids) @@ -654,16 +658,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): get_prev_content=False, ) - defer.returnValue( - { - group: [ - state_event_map[v] - for v in itervalues(event_id_map) - if v in state_event_map - ] - for group, event_id_map in iteritems(group_to_ids) - } - ) + return { + group: [ + state_event_map[v] + for v in itervalues(event_id_map) + if v in state_event_map + ] + for group, event_id_map in iteritems(group_to_ids) + } @defer.inlineCallbacks def _get_state_groups_from_groups(self, groups, state_filter): @@ -690,7 +692,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): ) results.update(res) - defer.returnValue(results) + return results def _get_state_groups_from_groups_txn( self, txn, groups, state_filter=StateFilter.all() @@ -825,7 +827,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): for event_id, group in iteritems(event_to_groups) } - defer.returnValue({event: event_to_state[event] for event in event_ids}) + return {event: event_to_state[event] for event in event_ids} @defer.inlineCallbacks def get_state_ids_for_events(self, event_ids, state_filter=StateFilter.all()): @@ -851,7 +853,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): for event_id, group in iteritems(event_to_groups) } - defer.returnValue({event: event_to_state[event] for event in event_ids}) + return {event: event_to_state[event] for event in event_ids} @defer.inlineCallbacks def get_state_for_event(self, event_id, state_filter=StateFilter.all()): @@ -867,7 +869,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): A deferred dict from (type, state_key) -> state_event """ state_map = yield self.get_state_for_events([event_id], state_filter) - defer.returnValue(state_map[event_id]) + return state_map[event_id] @defer.inlineCallbacks def get_state_ids_for_event(self, event_id, state_filter=StateFilter.all()): @@ -883,7 +885,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): A deferred dict from (type, state_key) -> state_event """ state_map = yield self.get_state_ids_for_events([event_id], state_filter) - defer.returnValue(state_map[event_id]) + return state_map[event_id] @cached(max_entries=50000) def _get_state_group_for_event(self, event_id): @@ -913,7 +915,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): desc="_get_state_group_for_events", ) - defer.returnValue({row["event_id"]: row["state_group"] for row in rows}) + return {row["event_id"]: row["state_group"] for row in rows} def _get_state_for_group_using_cache(self, cache, group, state_filter): """Checks if group is in cache. See `_get_state_for_groups` @@ -993,7 +995,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): incomplete_groups = incomplete_groups_m | incomplete_groups_nm if not incomplete_groups: - defer.returnValue(state) + return state cache_sequence_nm = self._state_group_cache.sequence cache_sequence_m = self._state_group_members_cache.sequence @@ -1020,7 +1022,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): # everything we need from the database anyway. state[group] = state_filter.filter_state(group_state_dict) - defer.returnValue(state) + return state def _get_state_for_groups_using_cache(self, groups, cache, state_filter): """Gets the state at each of a list of state groups, optionally @@ -1494,7 +1496,7 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore): self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME ) - defer.returnValue(result * BATCH_SIZE_SCALE_FACTOR) + return result * BATCH_SIZE_SCALE_FACTOR @defer.inlineCallbacks def _background_index_state(self, progress, batch_size): @@ -1524,4 +1526,4 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore): yield self._end_background_update(self.STATE_GROUP_INDEX_UPDATE_NAME) - defer.returnValue(1) + return 1 |