diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py
index 6a90daea31..9ef7b48c74 100644
--- a/synapse/storage/data_stores/main/state.py
+++ b/synapse/storage/data_stores/main/state.py
@@ -27,8 +27,8 @@ from synapse.api.errors import NotFoundError
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.storage._base import SQLBaseStore
-from synapse.storage.background_updates import BackgroundUpdateStore
from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
+from synapse.storage.database import Database
from synapse.storage.engines import PostgresEngine
from synapse.storage.state import StateFilter
from synapse.util.caches import get_cache_factor_for, intern_string
@@ -89,7 +89,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
count = 0
while next_group:
- next_group = self._simple_select_one_onecol_txn(
+ next_group = self.db.simple_select_one_onecol_txn(
txn,
table="state_group_edges",
keyvalues={"state_group": next_group},
@@ -192,7 +192,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
):
break
- next_group = self._simple_select_one_onecol_txn(
+ next_group = self.db.simple_select_one_onecol_txn(
txn,
table="state_group_edges",
keyvalues={"state_group": next_group},
@@ -214,8 +214,8 @@ class StateGroupWorkerStore(
STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
- def __init__(self, db_conn, hs):
- super(StateGroupWorkerStore, self).__init__(db_conn, hs)
+ def __init__(self, database: Database, db_conn, hs):
+ super(StateGroupWorkerStore, self).__init__(database, db_conn, hs)
# Originally the state store used a single DictionaryCache to cache the
# event IDs for the state types in a given state group to avoid hammering
@@ -348,7 +348,9 @@ class StateGroupWorkerStore(
(intern_string(r[0]), intern_string(r[1])): to_ascii(r[2]) for r in txn
}
- return self.runInteraction("get_current_state_ids", _get_current_state_ids_txn)
+ return self.db.runInteraction(
+ "get_current_state_ids", _get_current_state_ids_txn
+ )
# FIXME: how should this be cached?
def get_filtered_current_state_ids(self, room_id, state_filter=StateFilter.all()):
@@ -392,7 +394,7 @@ class StateGroupWorkerStore(
return results
- return self.runInteraction(
+ return self.db.runInteraction(
"get_filtered_current_state_ids", _get_filtered_current_state_ids_txn
)
@@ -431,7 +433,7 @@ class StateGroupWorkerStore(
"""
def _get_state_group_delta_txn(txn):
- prev_group = self._simple_select_one_onecol_txn(
+ prev_group = self.db.simple_select_one_onecol_txn(
txn,
table="state_group_edges",
keyvalues={"state_group": state_group},
@@ -442,7 +444,7 @@ class StateGroupWorkerStore(
if not prev_group:
return _GetStateGroupDelta(None, None)
- delta_ids = self._simple_select_list_txn(
+ delta_ids = self.db.simple_select_list_txn(
txn,
table="state_groups_state",
keyvalues={"state_group": state_group},
@@ -454,7 +456,9 @@ class StateGroupWorkerStore(
{(row["type"], row["state_key"]): row["event_id"] for row in delta_ids},
)
- return self.runInteraction("get_state_group_delta", _get_state_group_delta_txn)
+ return self.db.runInteraction(
+ "get_state_group_delta", _get_state_group_delta_txn
+ )
@defer.inlineCallbacks
def get_state_groups_ids(self, _room_id, event_ids):
@@ -540,7 +544,7 @@ class StateGroupWorkerStore(
chunks = [groups[i : i + 100] for i in range(0, len(groups), 100)]
for chunk in chunks:
- res = yield self.runInteraction(
+ res = yield self.db.runInteraction(
"_get_state_groups_from_groups",
self._get_state_groups_from_groups_txn,
chunk,
@@ -644,7 +648,7 @@ class StateGroupWorkerStore(
@cached(max_entries=50000)
def _get_state_group_for_event(self, event_id):
- return self._simple_select_one_onecol(
+ return self.db.simple_select_one_onecol(
table="event_to_state_groups",
keyvalues={"event_id": event_id},
retcol="state_group",
@@ -661,7 +665,7 @@ class StateGroupWorkerStore(
def _get_state_group_for_events(self, event_ids):
"""Returns mapping event_id -> state_group
"""
- rows = yield self._simple_select_many_batch(
+ rows = yield self.db.simple_select_many_batch(
table="event_to_state_groups",
column="event_id",
iterable=event_ids,
@@ -902,7 +906,7 @@ class StateGroupWorkerStore(
state_group = self.database_engine.get_next_state_group_id(txn)
- self._simple_insert_txn(
+ self.db.simple_insert_txn(
txn,
table="state_groups",
values={"id": state_group, "room_id": room_id, "event_id": event_id},
@@ -911,7 +915,7 @@ class StateGroupWorkerStore(
# We persist as a delta if we can, while also ensuring the chain
# of deltas isn't tooo long, as otherwise read performance degrades.
if prev_group:
- is_in_db = self._simple_select_one_onecol_txn(
+ is_in_db = self.db.simple_select_one_onecol_txn(
txn,
table="state_groups",
keyvalues={"id": prev_group},
@@ -926,13 +930,13 @@ class StateGroupWorkerStore(
potential_hops = self._count_state_group_hops_txn(txn, prev_group)
if prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
- self._simple_insert_txn(
+ self.db.simple_insert_txn(
txn,
table="state_group_edges",
values={"state_group": state_group, "prev_state_group": prev_group},
)
- self._simple_insert_many_txn(
+ self.db.simple_insert_many_txn(
txn,
table="state_groups_state",
values=[
@@ -947,7 +951,7 @@ class StateGroupWorkerStore(
],
)
else:
- self._simple_insert_many_txn(
+ self.db.simple_insert_many_txn(
txn,
table="state_groups_state",
values=[
@@ -993,7 +997,7 @@ class StateGroupWorkerStore(
return state_group
- return self.runInteraction("store_state_group", _store_state_group_txn)
+ return self.db.runInteraction("store_state_group", _store_state_group_txn)
@defer.inlineCallbacks
def get_referenced_state_groups(self, state_groups):
@@ -1007,7 +1011,7 @@ class StateGroupWorkerStore(
referenced.
"""
- rows = yield self._simple_select_many_batch(
+ rows = yield self.db.simple_select_many_batch(
table="event_to_state_groups",
column="state_group",
iterable=state_groups,
@@ -1019,32 +1023,30 @@ class StateGroupWorkerStore(
return set(row["state_group"] for row in rows)
-class StateBackgroundUpdateStore(
- StateGroupBackgroundUpdateStore, BackgroundUpdateStore
-):
+class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index"
- def __init__(self, db_conn, hs):
- super(StateBackgroundUpdateStore, self).__init__(db_conn, hs)
- self.register_background_update_handler(
+ def __init__(self, database: Database, db_conn, hs):
+ super(StateBackgroundUpdateStore, self).__init__(database, db_conn, hs)
+ self.db.updates.register_background_update_handler(
self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
self._background_deduplicate_state,
)
- self.register_background_update_handler(
+ self.db.updates.register_background_update_handler(
self.STATE_GROUP_INDEX_UPDATE_NAME, self._background_index_state
)
- self.register_background_index_update(
+ self.db.updates.register_background_index_update(
self.CURRENT_STATE_INDEX_UPDATE_NAME,
index_name="current_state_events_member_index",
table="current_state_events",
columns=["state_key"],
where_clause="type='m.room.member'",
)
- self.register_background_index_update(
+ self.db.updates.register_background_index_update(
self.EVENT_STATE_GROUP_INDEX_UPDATE_NAME,
index_name="event_to_state_groups_sg_index",
table="event_to_state_groups",
@@ -1065,7 +1067,7 @@ class StateBackgroundUpdateStore(
batch_size = max(1, int(batch_size / BATCH_SIZE_SCALE_FACTOR))
if max_group is None:
- rows = yield self._execute(
+ rows = yield self.db.execute(
"_background_deduplicate_state",
None,
"SELECT coalesce(max(id), 0) FROM state_groups",
@@ -1135,13 +1137,13 @@ class StateBackgroundUpdateStore(
if prev_state.get(key, None) != value
}
- self._simple_delete_txn(
+ self.db.simple_delete_txn(
txn,
table="state_group_edges",
keyvalues={"state_group": state_group},
)
- self._simple_insert_txn(
+ self.db.simple_insert_txn(
txn,
table="state_group_edges",
values={
@@ -1150,13 +1152,13 @@ class StateBackgroundUpdateStore(
},
)
- self._simple_delete_txn(
+ self.db.simple_delete_txn(
txn,
table="state_groups_state",
keyvalues={"state_group": state_group},
)
- self._simple_insert_many_txn(
+ self.db.simple_insert_many_txn(
txn,
table="state_groups_state",
values=[
@@ -1177,18 +1179,18 @@ class StateBackgroundUpdateStore(
"max_group": max_group,
}
- self._background_update_progress_txn(
+ self.db.updates._background_update_progress_txn(
txn, self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, progress
)
return False, batch_size
- finished, result = yield self.runInteraction(
+ finished, result = yield self.db.runInteraction(
self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, reindex_txn
)
if finished:
- yield self._end_background_update(
+ yield self.db.updates._end_background_update(
self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME
)
@@ -1218,9 +1220,9 @@ class StateBackgroundUpdateStore(
)
txn.execute("DROP INDEX IF EXISTS state_groups_state_id")
- yield self.runWithConnection(reindex_txn)
+ yield self.db.runWithConnection(reindex_txn)
- yield self._end_background_update(self.STATE_GROUP_INDEX_UPDATE_NAME)
+ yield self.db.updates._end_background_update(self.STATE_GROUP_INDEX_UPDATE_NAME)
return 1
@@ -1244,8 +1246,8 @@ class StateStore(StateGroupWorkerStore, StateBackgroundUpdateStore):
* `state_groups_state`: Maps state group to state events.
"""
- def __init__(self, db_conn, hs):
- super(StateStore, self).__init__(db_conn, hs)
+ def __init__(self, database: Database, db_conn, hs):
+ super(StateStore, self).__init__(database, db_conn, hs)
def _store_event_state_mappings_txn(
self, txn, events_and_contexts: Iterable[Tuple[EventBase, EventContext]]
@@ -1263,7 +1265,7 @@ class StateStore(StateGroupWorkerStore, StateBackgroundUpdateStore):
state_groups[event.event_id] = context.state_group
- self._simple_insert_many_txn(
+ self.db.simple_insert_many_txn(
txn,
table="event_to_state_groups",
values=[
|