diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 6ddc4055d2..0bfe1b4550 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -40,10 +40,13 @@ logger = logging.getLogger(__name__)
MAX_STATE_DELTA_HOPS = 100
-class _GetStateGroupDelta(namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids"))):
+class _GetStateGroupDelta(
+ namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids"))
+):
"""Return type of get_state_group_delta that implements __len__, which lets
us use the itrable flag when caching
"""
+
__slots__ = []
def __len__(self):
@@ -70,10 +73,7 @@ class StateFilter(object):
# If `include_others` is set we canonicalise the filter by removing
# wildcards from the types dictionary
if self.include_others:
- self.types = {
- k: v for k, v in iteritems(self.types)
- if v is not None
- }
+ self.types = {k: v for k, v in iteritems(self.types) if v is not None}
@staticmethod
def all():
@@ -130,10 +130,7 @@ class StateFilter(object):
Returns:
StateFilter
"""
- return StateFilter(
- types={EventTypes.Member: set(members)},
- include_others=True,
- )
+ return StateFilter(types={EventTypes.Member: set(members)}, include_others=True)
def return_expanded(self):
"""Creates a new StateFilter where type wild cards have been removed
@@ -243,9 +240,7 @@ class StateFilter(object):
if where_clause:
where_clause += " OR "
- where_clause += "type NOT IN (%s)" % (
- ",".join(["?"] * len(self.types)),
- )
+ where_clause += "type NOT IN (%s)" % (",".join(["?"] * len(self.types)),)
where_args.extend(self.types)
return where_clause, where_args
@@ -305,12 +300,8 @@ class StateFilter(object):
bool
"""
- return (
- self.include_others
- or any(
- state_keys is None
- for state_keys in itervalues(self.types)
- )
+ return self.include_others or any(
+ state_keys is None for state_keys in itervalues(self.types)
)
def concrete_types(self):
@@ -406,11 +397,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
self._state_group_cache = DictionaryCache(
"*stateGroupCache*",
# TODO: this hasn't been tuned yet
- 50000 * get_cache_factor_for("stateGroupCache")
+ 50000 * get_cache_factor_for("stateGroupCache"),
)
self._state_group_members_cache = DictionaryCache(
"*stateGroupMembersCache*",
- 500000 * get_cache_factor_for("stateGroupMembersCache")
+ 500000 * get_cache_factor_for("stateGroupMembersCache"),
)
@defer.inlineCallbacks
@@ -488,22 +479,20 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
Returns:
deferred: dict of (type, state_key) -> event_id
"""
+
def _get_current_state_ids_txn(txn):
txn.execute(
"""SELECT type, state_key, event_id FROM current_state_events
WHERE room_id = ?
""",
- (room_id,)
+ (room_id,),
)
return {
(intern_string(r[0]), intern_string(r[1])): to_ascii(r[2]) for r in txn
}
- return self.runInteraction(
- "get_current_state_ids",
- _get_current_state_ids_txn,
- )
+ return self.runInteraction("get_current_state_ids", _get_current_state_ids_txn)
# FIXME: how should this be cached?
def get_filtered_current_state_ids(self, room_id, state_filter=StateFilter.all()):
@@ -544,8 +533,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
return results
return self.runInteraction(
- "get_filtered_current_state_ids",
- _get_filtered_current_state_ids_txn,
+ "get_filtered_current_state_ids", _get_filtered_current_state_ids_txn
)
@defer.inlineCallbacks
@@ -559,9 +547,9 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
Deferred[str|None]: The canonical alias, if any
"""
- state = yield self.get_filtered_current_state_ids(room_id, StateFilter.from_types(
- [(EventTypes.CanonicalAlias, "")]
- ))
+ state = yield self.get_filtered_current_state_ids(
+ room_id, StateFilter.from_types([(EventTypes.CanonicalAlias, "")])
+ )
event_id = state.get((EventTypes.CanonicalAlias, ""))
if not event_id:
@@ -581,13 +569,12 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
Returns:
(prev_group, delta_ids), where both may be None.
"""
+
def _get_state_group_delta_txn(txn):
prev_group = self._simple_select_one_onecol_txn(
txn,
table="state_group_edges",
- keyvalues={
- "state_group": state_group,
- },
+ keyvalues={"state_group": state_group},
retcol="prev_state_group",
allow_none=True,
)
@@ -598,20 +585,16 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
delta_ids = self._simple_select_list_txn(
txn,
table="state_groups_state",
- keyvalues={
- "state_group": state_group,
- },
- retcols=("type", "state_key", "event_id",)
+ keyvalues={"state_group": state_group},
+ retcols=("type", "state_key", "event_id"),
)
- return _GetStateGroupDelta(prev_group, {
- (row["type"], row["state_key"]): row["event_id"]
- for row in delta_ids
- })
- return self.runInteraction(
- "get_state_group_delta",
- _get_state_group_delta_txn,
- )
+ return _GetStateGroupDelta(
+ prev_group,
+ {(row["type"], row["state_key"]): row["event_id"] for row in delta_ids},
+ )
+
+ return self.runInteraction("get_state_group_delta", _get_state_group_delta_txn)
@defer.inlineCallbacks
def get_state_groups_ids(self, _room_id, event_ids):
@@ -628,9 +611,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
if not event_ids:
defer.returnValue({})
- event_to_groups = yield self._get_state_group_for_events(
- event_ids,
- )
+ event_to_groups = yield self._get_state_group_for_events(event_ids)
groups = set(itervalues(event_to_groups))
group_to_state = yield self._get_state_for_groups(groups)
@@ -666,19 +647,23 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
state_event_map = yield self.get_events(
[
- ev_id for group_ids in itervalues(group_to_ids)
+ ev_id
+ for group_ids in itervalues(group_to_ids)
for ev_id in itervalues(group_ids)
],
- get_prev_content=False
+ get_prev_content=False,
)
- defer.returnValue({
- group: [
- state_event_map[v] for v in itervalues(event_id_map)
- if v in state_event_map
- ]
- for group, event_id_map in iteritems(group_to_ids)
- })
+ defer.returnValue(
+ {
+ group: [
+ state_event_map[v]
+ for v in itervalues(event_id_map)
+ if v in state_event_map
+ ]
+ for group, event_id_map in iteritems(group_to_ids)
+ }
+ )
@defer.inlineCallbacks
def _get_state_groups_from_groups(self, groups, state_filter):
@@ -695,18 +680,20 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
"""
results = {}
- chunks = [groups[i:i + 100] for i in range(0, len(groups), 100)]
+ chunks = [groups[i : i + 100] for i in range(0, len(groups), 100)]
for chunk in chunks:
res = yield self.runInteraction(
"_get_state_groups_from_groups",
- self._get_state_groups_from_groups_txn, chunk, state_filter,
+ self._get_state_groups_from_groups_txn,
+ chunk,
+ state_filter,
)
results.update(res)
defer.returnValue(results)
def _get_state_groups_from_groups_txn(
- self, txn, groups, state_filter=StateFilter.all(),
+ self, txn, groups, state_filter=StateFilter.all()
):
results = {group: {} for group in groups}
@@ -776,7 +763,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
txn.execute(
"SELECT type, state_key, event_id FROM state_groups_state"
" WHERE state_group = ? " + where_clause,
- args
+ args,
)
results[group].update(
((typ, state_key), event_id)
@@ -791,8 +778,8 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
# wildcards (i.e. Nones) in which case we have to do an exhaustive
# search
if (
- max_entries_returned is not None and
- len(results[group]) == max_entries_returned
+ max_entries_returned is not None
+ and len(results[group]) == max_entries_returned
):
break
@@ -819,16 +806,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
Returns:
deferred: A dict of (event_id) -> (type, state_key) -> [state_events]
"""
- event_to_groups = yield self._get_state_group_for_events(
- event_ids,
- )
+ event_to_groups = yield self._get_state_group_for_events(event_ids)
groups = set(itervalues(event_to_groups))
group_to_state = yield self._get_state_for_groups(groups, state_filter)
state_event_map = yield self.get_events(
[ev_id for sd in itervalues(group_to_state) for ev_id in itervalues(sd)],
- get_prev_content=False
+ get_prev_content=False,
)
event_to_state = {
@@ -856,9 +841,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
Returns:
A deferred dict from event_id -> (type, state_key) -> event_id
"""
- event_to_groups = yield self._get_state_group_for_events(
- event_ids,
- )
+ event_to_groups = yield self._get_state_group_for_events(event_ids)
groups = set(itervalues(event_to_groups))
group_to_state = yield self._get_state_for_groups(groups, state_filter)
@@ -906,16 +889,18 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
def _get_state_group_for_event(self, event_id):
return self._simple_select_one_onecol(
table="event_to_state_groups",
- keyvalues={
- "event_id": event_id,
- },
+ keyvalues={"event_id": event_id},
retcol="state_group",
allow_none=True,
desc="_get_state_group_for_event",
)
- @cachedList(cached_method_name="_get_state_group_for_event",
- list_name="event_ids", num_args=1, inlineCallbacks=True)
+ @cachedList(
+ cached_method_name="_get_state_group_for_event",
+ list_name="event_ids",
+ num_args=1,
+ inlineCallbacks=True,
+ )
def _get_state_group_for_events(self, event_ids):
"""Returns mapping event_id -> state_group
"""
@@ -924,7 +909,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
column="event_id",
iterable=event_ids,
keyvalues={},
- retcols=("event_id", "state_group",),
+ retcols=("event_id", "state_group"),
desc="_get_state_group_for_events",
)
@@ -989,15 +974,13 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
# Now we look them up in the member and non-member caches
non_member_state, incomplete_groups_nm, = (
yield self._get_state_for_groups_using_cache(
- groups, self._state_group_cache,
- state_filter=non_member_filter,
+ groups, self._state_group_cache, state_filter=non_member_filter
)
)
member_state, incomplete_groups_m, = (
yield self._get_state_for_groups_using_cache(
- groups, self._state_group_members_cache,
- state_filter=member_filter,
+ groups, self._state_group_members_cache, state_filter=member_filter
)
)
@@ -1019,8 +1002,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
db_state_filter = state_filter.return_expanded()
group_to_state_dict = yield self._get_state_groups_from_groups(
- list(incomplete_groups),
- state_filter=db_state_filter,
+ list(incomplete_groups), state_filter=db_state_filter
)
# Now lets update the caches
@@ -1040,9 +1022,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
defer.returnValue(state)
- def _get_state_for_groups_using_cache(
- self, groups, cache, state_filter,
- ):
+ def _get_state_for_groups_using_cache(self, groups, cache, state_filter):
"""Gets the state at each of a list of state groups, optionally
filtering by type/state_key, querying from a specific cache.
@@ -1074,8 +1054,13 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
return results, incomplete_groups
- def _insert_into_cache(self, group_to_state_dict, state_filter,
- cache_seq_num_members, cache_seq_num_non_members):
+ def _insert_into_cache(
+ self,
+ group_to_state_dict,
+ state_filter,
+ cache_seq_num_members,
+ cache_seq_num_non_members,
+ ):
"""Inserts results from querying the database into the relevant cache.
Args:
@@ -1132,8 +1117,9 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
fetched_keys=non_member_types,
)
- def store_state_group(self, event_id, room_id, prev_group, delta_ids,
- current_state_ids):
+ def store_state_group(
+ self, event_id, room_id, prev_group, delta_ids, current_state_ids
+ ):
"""Store a new set of state, returning a newly assigned state group.
Args:
@@ -1149,6 +1135,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
Returns:
Deferred[int]: The state group ID
"""
+
def _store_state_group_txn(txn):
if current_state_ids is None:
# AFAIK, this can never happen
@@ -1159,11 +1146,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
self._simple_insert_txn(
txn,
table="state_groups",
- values={
- "id": state_group,
- "room_id": room_id,
- "event_id": event_id,
- },
+ values={"id": state_group, "room_id": room_id, "event_id": event_id},
)
# We persist as a delta if we can, while also ensuring the chain
@@ -1182,17 +1165,12 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
% (prev_group,)
)
- potential_hops = self._count_state_group_hops_txn(
- txn, prev_group
- )
+ potential_hops = self._count_state_group_hops_txn(txn, prev_group)
if prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
self._simple_insert_txn(
txn,
table="state_group_edges",
- values={
- "state_group": state_group,
- "prev_state_group": prev_group,
- },
+ values={"state_group": state_group, "prev_state_group": prev_group},
)
self._simple_insert_many_txn(
@@ -1264,7 +1242,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
This is used to ensure the delta chains don't get too long.
"""
if isinstance(self.database_engine, PostgresEngine):
- sql = ("""
+ sql = """
WITH RECURSIVE state(state_group) AS (
VALUES(?::bigint)
UNION ALL
@@ -1272,7 +1250,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
WHERE s.state_group = e.state_group
)
SELECT count(*) FROM state;
- """)
+ """
txn.execute(sql, (state_group,))
row = txn.fetchone()
@@ -1331,8 +1309,7 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
self._background_deduplicate_state,
)
self.register_background_update_handler(
- self.STATE_GROUP_INDEX_UPDATE_NAME,
- self._background_index_state,
+ self.STATE_GROUP_INDEX_UPDATE_NAME, self._background_index_state
)
self.register_background_index_update(
self.CURRENT_STATE_INDEX_UPDATE_NAME,
@@ -1366,18 +1343,14 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
txn,
table="event_to_state_groups",
values=[
- {
- "state_group": state_group_id,
- "event_id": event_id,
- }
+ {"state_group": state_group_id, "event_id": event_id}
for event_id, state_group_id in iteritems(state_groups)
],
)
for event_id, state_group_id in iteritems(state_groups):
txn.call_after(
- self._get_state_group_for_event.prefill,
- (event_id,), state_group_id
+ self._get_state_group_for_event.prefill, (event_id,), state_group_id
)
@defer.inlineCallbacks
@@ -1395,7 +1368,8 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
if max_group is None:
rows = yield self._execute(
- "_background_deduplicate_state", None,
+ "_background_deduplicate_state",
+ None,
"SELECT coalesce(max(id), 0) FROM state_groups",
)
max_group = rows[0][0]
@@ -1408,7 +1382,7 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
" WHERE ? < id AND id <= ?"
" ORDER BY id ASC"
" LIMIT 1",
- (new_last_state_group, max_group,)
+ (new_last_state_group, max_group),
)
row = txn.fetchone()
if row:
@@ -1420,7 +1394,7 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
txn.execute(
"SELECT state_group FROM state_group_edges"
" WHERE state_group = ?",
- (state_group,)
+ (state_group,),
)
# If we reach a point where we've already started inserting
@@ -1431,27 +1405,25 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
txn.execute(
"SELECT coalesce(max(id), 0) FROM state_groups"
" WHERE id < ? AND room_id = ?",
- (state_group, room_id,)
+ (state_group, room_id),
)
prev_group, = txn.fetchone()
new_last_state_group = state_group
if prev_group:
- potential_hops = self._count_state_group_hops_txn(
- txn, prev_group
- )
+ potential_hops = self._count_state_group_hops_txn(txn, prev_group)
if potential_hops >= MAX_STATE_DELTA_HOPS:
# We want to ensure chains are at most this long,#
# otherwise read performance degrades.
continue
prev_state = self._get_state_groups_from_groups_txn(
- txn, [prev_group],
+ txn, [prev_group]
)
prev_state = prev_state[prev_group]
curr_state = self._get_state_groups_from_groups_txn(
- txn, [state_group],
+ txn, [state_group]
)
curr_state = curr_state[state_group]
@@ -1460,16 +1432,15 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
# of keys
delta_state = {
- key: value for key, value in iteritems(curr_state)
+ key: value
+ for key, value in iteritems(curr_state)
if prev_state.get(key, None) != value
}
self._simple_delete_txn(
txn,
table="state_group_edges",
- keyvalues={
- "state_group": state_group,
- }
+ keyvalues={"state_group": state_group},
)
self._simple_insert_txn(
@@ -1478,15 +1449,13 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
values={
"state_group": state_group,
"prev_state_group": prev_group,
- }
+ },
)
self._simple_delete_txn(
txn,
table="state_groups_state",
- keyvalues={
- "state_group": state_group,
- }
+ keyvalues={"state_group": state_group},
)
self._simple_insert_many_txn(
@@ -1521,7 +1490,9 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
)
if finished:
- yield self._end_background_update(self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME)
+ yield self._end_background_update(
+ self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME
+ )
defer.returnValue(result * BATCH_SIZE_SCALE_FACTOR)
@@ -1538,9 +1509,7 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
"CREATE INDEX CONCURRENTLY state_groups_state_type_idx"
" ON state_groups_state(state_group, type, state_key)"
)
- txn.execute(
- "DROP INDEX IF EXISTS state_groups_state_id"
- )
+ txn.execute("DROP INDEX IF EXISTS state_groups_state_id")
finally:
conn.set_session(autocommit=False)
else:
@@ -1549,9 +1518,7 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
"CREATE INDEX state_groups_state_type_idx"
" ON state_groups_state(state_group, type, state_key)"
)
- txn.execute(
- "DROP INDEX IF EXISTS state_groups_state_id"
- )
+ txn.execute("DROP INDEX IF EXISTS state_groups_state_id")
yield self.runWithConnection(reindex_txn)
|