diff options
author | Amber Brown <hawkowl@atleastfornow.net> | 2019-04-03 20:07:29 +1100 |
---|---|---|
committer | Richard van der Hoff <1389908+richvdh@users.noreply.github.com> | 2019-04-03 10:07:29 +0100 |
commit | 7efd1d87c2c424365c99ba6103135edb1845fd88 (patch) | |
tree | eadf93e88f277cca7f6fb694c8457ca5c73f5646 /synapse/storage/state.py | |
parent | Merge pull request #4991 from matrix-org/erikj/stagger_push_startup (diff) | |
download | synapse-7efd1d87c2c424365c99ba6103135edb1845fd88.tar.xz |
Run black on the rest of the storage module (#4996)
Diffstat (limited to 'synapse/storage/state.py')
-rw-r--r-- | synapse/storage/state.py | 237 |
1 files changed, 102 insertions, 135 deletions
diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 6ddc4055d2..0bfe1b4550 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -40,10 +40,13 @@ logger = logging.getLogger(__name__) MAX_STATE_DELTA_HOPS = 100 -class _GetStateGroupDelta(namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids"))): +class _GetStateGroupDelta( + namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids")) +): """Return type of get_state_group_delta that implements __len__, which lets us use the itrable flag when caching """ + __slots__ = [] def __len__(self): @@ -70,10 +73,7 @@ class StateFilter(object): # If `include_others` is set we canonicalise the filter by removing # wildcards from the types dictionary if self.include_others: - self.types = { - k: v for k, v in iteritems(self.types) - if v is not None - } + self.types = {k: v for k, v in iteritems(self.types) if v is not None} @staticmethod def all(): @@ -130,10 +130,7 @@ class StateFilter(object): Returns: StateFilter """ - return StateFilter( - types={EventTypes.Member: set(members)}, - include_others=True, - ) + return StateFilter(types={EventTypes.Member: set(members)}, include_others=True) def return_expanded(self): """Creates a new StateFilter where type wild cards have been removed @@ -243,9 +240,7 @@ class StateFilter(object): if where_clause: where_clause += " OR " - where_clause += "type NOT IN (%s)" % ( - ",".join(["?"] * len(self.types)), - ) + where_clause += "type NOT IN (%s)" % (",".join(["?"] * len(self.types)),) where_args.extend(self.types) return where_clause, where_args @@ -305,12 +300,8 @@ class StateFilter(object): bool """ - return ( - self.include_others - or any( - state_keys is None - for state_keys in itervalues(self.types) - ) + return self.include_others or any( + state_keys is None for state_keys in itervalues(self.types) ) def concrete_types(self): @@ -406,11 +397,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): self._state_group_cache = DictionaryCache( "*stateGroupCache*", # TODO: this hasn't been tuned yet - 50000 * get_cache_factor_for("stateGroupCache") + 50000 * get_cache_factor_for("stateGroupCache"), ) self._state_group_members_cache = DictionaryCache( "*stateGroupMembersCache*", - 500000 * get_cache_factor_for("stateGroupMembersCache") + 500000 * get_cache_factor_for("stateGroupMembersCache"), ) @defer.inlineCallbacks @@ -488,22 +479,20 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): Returns: deferred: dict of (type, state_key) -> event_id """ + def _get_current_state_ids_txn(txn): txn.execute( """SELECT type, state_key, event_id FROM current_state_events WHERE room_id = ? """, - (room_id,) + (room_id,), ) return { (intern_string(r[0]), intern_string(r[1])): to_ascii(r[2]) for r in txn } - return self.runInteraction( - "get_current_state_ids", - _get_current_state_ids_txn, - ) + return self.runInteraction("get_current_state_ids", _get_current_state_ids_txn) # FIXME: how should this be cached? def get_filtered_current_state_ids(self, room_id, state_filter=StateFilter.all()): @@ -544,8 +533,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): return results return self.runInteraction( - "get_filtered_current_state_ids", - _get_filtered_current_state_ids_txn, + "get_filtered_current_state_ids", _get_filtered_current_state_ids_txn ) @defer.inlineCallbacks @@ -559,9 +547,9 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): Deferred[str|None]: The canonical alias, if any """ - state = yield self.get_filtered_current_state_ids(room_id, StateFilter.from_types( - [(EventTypes.CanonicalAlias, "")] - )) + state = yield self.get_filtered_current_state_ids( + room_id, StateFilter.from_types([(EventTypes.CanonicalAlias, "")]) + ) event_id = state.get((EventTypes.CanonicalAlias, "")) if not event_id: @@ -581,13 +569,12 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): Returns: (prev_group, delta_ids), where both may be None. """ + def _get_state_group_delta_txn(txn): prev_group = self._simple_select_one_onecol_txn( txn, table="state_group_edges", - keyvalues={ - "state_group": state_group, - }, + keyvalues={"state_group": state_group}, retcol="prev_state_group", allow_none=True, ) @@ -598,20 +585,16 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): delta_ids = self._simple_select_list_txn( txn, table="state_groups_state", - keyvalues={ - "state_group": state_group, - }, - retcols=("type", "state_key", "event_id",) + keyvalues={"state_group": state_group}, + retcols=("type", "state_key", "event_id"), ) - return _GetStateGroupDelta(prev_group, { - (row["type"], row["state_key"]): row["event_id"] - for row in delta_ids - }) - return self.runInteraction( - "get_state_group_delta", - _get_state_group_delta_txn, - ) + return _GetStateGroupDelta( + prev_group, + {(row["type"], row["state_key"]): row["event_id"] for row in delta_ids}, + ) + + return self.runInteraction("get_state_group_delta", _get_state_group_delta_txn) @defer.inlineCallbacks def get_state_groups_ids(self, _room_id, event_ids): @@ -628,9 +611,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): if not event_ids: defer.returnValue({}) - event_to_groups = yield self._get_state_group_for_events( - event_ids, - ) + event_to_groups = yield self._get_state_group_for_events(event_ids) groups = set(itervalues(event_to_groups)) group_to_state = yield self._get_state_for_groups(groups) @@ -666,19 +647,23 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): state_event_map = yield self.get_events( [ - ev_id for group_ids in itervalues(group_to_ids) + ev_id + for group_ids in itervalues(group_to_ids) for ev_id in itervalues(group_ids) ], - get_prev_content=False + get_prev_content=False, ) - defer.returnValue({ - group: [ - state_event_map[v] for v in itervalues(event_id_map) - if v in state_event_map - ] - for group, event_id_map in iteritems(group_to_ids) - }) + defer.returnValue( + { + group: [ + state_event_map[v] + for v in itervalues(event_id_map) + if v in state_event_map + ] + for group, event_id_map in iteritems(group_to_ids) + } + ) @defer.inlineCallbacks def _get_state_groups_from_groups(self, groups, state_filter): @@ -695,18 +680,20 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): """ results = {} - chunks = [groups[i:i + 100] for i in range(0, len(groups), 100)] + chunks = [groups[i : i + 100] for i in range(0, len(groups), 100)] for chunk in chunks: res = yield self.runInteraction( "_get_state_groups_from_groups", - self._get_state_groups_from_groups_txn, chunk, state_filter, + self._get_state_groups_from_groups_txn, + chunk, + state_filter, ) results.update(res) defer.returnValue(results) def _get_state_groups_from_groups_txn( - self, txn, groups, state_filter=StateFilter.all(), + self, txn, groups, state_filter=StateFilter.all() ): results = {group: {} for group in groups} @@ -776,7 +763,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): txn.execute( "SELECT type, state_key, event_id FROM state_groups_state" " WHERE state_group = ? " + where_clause, - args + args, ) results[group].update( ((typ, state_key), event_id) @@ -791,8 +778,8 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): # wildcards (i.e. Nones) in which case we have to do an exhaustive # search if ( - max_entries_returned is not None and - len(results[group]) == max_entries_returned + max_entries_returned is not None + and len(results[group]) == max_entries_returned ): break @@ -819,16 +806,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): Returns: deferred: A dict of (event_id) -> (type, state_key) -> [state_events] """ - event_to_groups = yield self._get_state_group_for_events( - event_ids, - ) + event_to_groups = yield self._get_state_group_for_events(event_ids) groups = set(itervalues(event_to_groups)) group_to_state = yield self._get_state_for_groups(groups, state_filter) state_event_map = yield self.get_events( [ev_id for sd in itervalues(group_to_state) for ev_id in itervalues(sd)], - get_prev_content=False + get_prev_content=False, ) event_to_state = { @@ -856,9 +841,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): Returns: A deferred dict from event_id -> (type, state_key) -> event_id """ - event_to_groups = yield self._get_state_group_for_events( - event_ids, - ) + event_to_groups = yield self._get_state_group_for_events(event_ids) groups = set(itervalues(event_to_groups)) group_to_state = yield self._get_state_for_groups(groups, state_filter) @@ -906,16 +889,18 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): def _get_state_group_for_event(self, event_id): return self._simple_select_one_onecol( table="event_to_state_groups", - keyvalues={ - "event_id": event_id, - }, + keyvalues={"event_id": event_id}, retcol="state_group", allow_none=True, desc="_get_state_group_for_event", ) - @cachedList(cached_method_name="_get_state_group_for_event", - list_name="event_ids", num_args=1, inlineCallbacks=True) + @cachedList( + cached_method_name="_get_state_group_for_event", + list_name="event_ids", + num_args=1, + inlineCallbacks=True, + ) def _get_state_group_for_events(self, event_ids): """Returns mapping event_id -> state_group """ @@ -924,7 +909,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): column="event_id", iterable=event_ids, keyvalues={}, - retcols=("event_id", "state_group",), + retcols=("event_id", "state_group"), desc="_get_state_group_for_events", ) @@ -989,15 +974,13 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): # Now we look them up in the member and non-member caches non_member_state, incomplete_groups_nm, = ( yield self._get_state_for_groups_using_cache( - groups, self._state_group_cache, - state_filter=non_member_filter, + groups, self._state_group_cache, state_filter=non_member_filter ) ) member_state, incomplete_groups_m, = ( yield self._get_state_for_groups_using_cache( - groups, self._state_group_members_cache, - state_filter=member_filter, + groups, self._state_group_members_cache, state_filter=member_filter ) ) @@ -1019,8 +1002,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): db_state_filter = state_filter.return_expanded() group_to_state_dict = yield self._get_state_groups_from_groups( - list(incomplete_groups), - state_filter=db_state_filter, + list(incomplete_groups), state_filter=db_state_filter ) # Now lets update the caches @@ -1040,9 +1022,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): defer.returnValue(state) - def _get_state_for_groups_using_cache( - self, groups, cache, state_filter, - ): + def _get_state_for_groups_using_cache(self, groups, cache, state_filter): """Gets the state at each of a list of state groups, optionally filtering by type/state_key, querying from a specific cache. @@ -1074,8 +1054,13 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): return results, incomplete_groups - def _insert_into_cache(self, group_to_state_dict, state_filter, - cache_seq_num_members, cache_seq_num_non_members): + def _insert_into_cache( + self, + group_to_state_dict, + state_filter, + cache_seq_num_members, + cache_seq_num_non_members, + ): """Inserts results from querying the database into the relevant cache. Args: @@ -1132,8 +1117,9 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): fetched_keys=non_member_types, ) - def store_state_group(self, event_id, room_id, prev_group, delta_ids, - current_state_ids): + def store_state_group( + self, event_id, room_id, prev_group, delta_ids, current_state_ids + ): """Store a new set of state, returning a newly assigned state group. Args: @@ -1149,6 +1135,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): Returns: Deferred[int]: The state group ID """ + def _store_state_group_txn(txn): if current_state_ids is None: # AFAIK, this can never happen @@ -1159,11 +1146,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): self._simple_insert_txn( txn, table="state_groups", - values={ - "id": state_group, - "room_id": room_id, - "event_id": event_id, - }, + values={"id": state_group, "room_id": room_id, "event_id": event_id}, ) # We persist as a delta if we can, while also ensuring the chain @@ -1182,17 +1165,12 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): % (prev_group,) ) - potential_hops = self._count_state_group_hops_txn( - txn, prev_group - ) + potential_hops = self._count_state_group_hops_txn(txn, prev_group) if prev_group and potential_hops < MAX_STATE_DELTA_HOPS: self._simple_insert_txn( txn, table="state_group_edges", - values={ - "state_group": state_group, - "prev_state_group": prev_group, - }, + values={"state_group": state_group, "prev_state_group": prev_group}, ) self._simple_insert_many_txn( @@ -1264,7 +1242,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): This is used to ensure the delta chains don't get too long. """ if isinstance(self.database_engine, PostgresEngine): - sql = (""" + sql = """ WITH RECURSIVE state(state_group) AS ( VALUES(?::bigint) UNION ALL @@ -1272,7 +1250,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): WHERE s.state_group = e.state_group ) SELECT count(*) FROM state; - """) + """ txn.execute(sql, (state_group,)) row = txn.fetchone() @@ -1331,8 +1309,7 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore): self._background_deduplicate_state, ) self.register_background_update_handler( - self.STATE_GROUP_INDEX_UPDATE_NAME, - self._background_index_state, + self.STATE_GROUP_INDEX_UPDATE_NAME, self._background_index_state ) self.register_background_index_update( self.CURRENT_STATE_INDEX_UPDATE_NAME, @@ -1366,18 +1343,14 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore): txn, table="event_to_state_groups", values=[ - { - "state_group": state_group_id, - "event_id": event_id, - } + {"state_group": state_group_id, "event_id": event_id} for event_id, state_group_id in iteritems(state_groups) ], ) for event_id, state_group_id in iteritems(state_groups): txn.call_after( - self._get_state_group_for_event.prefill, - (event_id,), state_group_id + self._get_state_group_for_event.prefill, (event_id,), state_group_id ) @defer.inlineCallbacks @@ -1395,7 +1368,8 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore): if max_group is None: rows = yield self._execute( - "_background_deduplicate_state", None, + "_background_deduplicate_state", + None, "SELECT coalesce(max(id), 0) FROM state_groups", ) max_group = rows[0][0] @@ -1408,7 +1382,7 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore): " WHERE ? < id AND id <= ?" " ORDER BY id ASC" " LIMIT 1", - (new_last_state_group, max_group,) + (new_last_state_group, max_group), ) row = txn.fetchone() if row: @@ -1420,7 +1394,7 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore): txn.execute( "SELECT state_group FROM state_group_edges" " WHERE state_group = ?", - (state_group,) + (state_group,), ) # If we reach a point where we've already started inserting @@ -1431,27 +1405,25 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore): txn.execute( "SELECT coalesce(max(id), 0) FROM state_groups" " WHERE id < ? AND room_id = ?", - (state_group, room_id,) + (state_group, room_id), ) prev_group, = txn.fetchone() new_last_state_group = state_group if prev_group: - potential_hops = self._count_state_group_hops_txn( - txn, prev_group - ) + potential_hops = self._count_state_group_hops_txn(txn, prev_group) if potential_hops >= MAX_STATE_DELTA_HOPS: # We want to ensure chains are at most this long,# # otherwise read performance degrades. continue prev_state = self._get_state_groups_from_groups_txn( - txn, [prev_group], + txn, [prev_group] ) prev_state = prev_state[prev_group] curr_state = self._get_state_groups_from_groups_txn( - txn, [state_group], + txn, [state_group] ) curr_state = curr_state[state_group] @@ -1460,16 +1432,15 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore): # of keys delta_state = { - key: value for key, value in iteritems(curr_state) + key: value + for key, value in iteritems(curr_state) if prev_state.get(key, None) != value } self._simple_delete_txn( txn, table="state_group_edges", - keyvalues={ - "state_group": state_group, - } + keyvalues={"state_group": state_group}, ) self._simple_insert_txn( @@ -1478,15 +1449,13 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore): values={ "state_group": state_group, "prev_state_group": prev_group, - } + }, ) self._simple_delete_txn( txn, table="state_groups_state", - keyvalues={ - "state_group": state_group, - } + keyvalues={"state_group": state_group}, ) self._simple_insert_many_txn( @@ -1521,7 +1490,9 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore): ) if finished: - yield self._end_background_update(self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME) + yield self._end_background_update( + self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME + ) defer.returnValue(result * BATCH_SIZE_SCALE_FACTOR) @@ -1538,9 +1509,7 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore): "CREATE INDEX CONCURRENTLY state_groups_state_type_idx" " ON state_groups_state(state_group, type, state_key)" ) - txn.execute( - "DROP INDEX IF EXISTS state_groups_state_id" - ) + txn.execute("DROP INDEX IF EXISTS state_groups_state_id") finally: conn.set_session(autocommit=False) else: @@ -1549,9 +1518,7 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore): "CREATE INDEX state_groups_state_type_idx" " ON state_groups_state(state_group, type, state_key)" ) - txn.execute( - "DROP INDEX IF EXISTS state_groups_state_id" - ) + txn.execute("DROP INDEX IF EXISTS state_groups_state_id") yield self.runWithConnection(reindex_txn) |