diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 1dc347f0c9..5c21402dea 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -61,6 +61,7 @@ from .registration import RegistrationStore
from .rejections import RejectionsStore
from .relations import RelationsStore
from .room import RoomStore
+from .room_batch import RoomBatchStore
from .roommember import RoomMemberStore
from .search import SearchStore
from .session import SessionStore
@@ -81,6 +82,7 @@ class DataStore(
EventsBackgroundUpdatesStore,
RoomMemberStore,
RoomStore,
+ RoomBatchStore,
RegistrationStore,
StreamStore,
ProfileStore,
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 1d02795f43..d0cf3460da 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -494,7 +494,7 @@ class AccountDataWorkerStore(SQLBaseStore):
txn,
table="ignored_users",
column="ignored_user_id",
- iterable=previously_ignored_users - currently_ignored_users,
+ values=previously_ignored_users - currently_ignored_users,
keyvalues={"ignorer_user_id": user_id},
)
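
The calls updated throughout this patch all target the `simple_delete_many_txn` helper, whose keyword argument is renamed from `iterable` to `values`. A minimal sketch of what such a helper does, assuming a naive IN-clause construction rather than Synapse's actual implementation:

```python
from typing import Any, Collection, Dict

def simple_delete_many_txn(
    txn: Any,  # a DB-API cursor in this sketch
    table: str,
    column: str,
    values: Collection[Any],
    keyvalues: Dict[str, Any],
) -> None:
    """Delete rows where `column` is in `values` and all `keyvalues` match."""
    if not values:
        return
    # One placeholder per value for the IN clause.
    clauses = ["%s IN (%s)" % (column, ", ".join("?" for _ in values))]
    args = list(values)
    for key, value in keyvalues.items():
        clauses.append("%s = ?" % key)
        args.append(value)
    txn.execute(
        "DELETE FROM %s WHERE %s" % (table, " AND ".join(clauses)),
        args,
    )
```

The rename matters below: `values` must be something that can be iterated more than once (see the pusher.py hunk).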
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 047782eb06..10184d6ae7 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -1034,13 +1034,13 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
LIMIT ?
"""
- # Find any chunk connections of a given insertion event
- chunk_connection_query = """
+ # Find any batch connections of a given insertion event
+ batch_connection_query = """
SELECT e.depth, c.event_id FROM insertion_events AS i
- /* Find the chunk that connects to the given insertion event */
- INNER JOIN chunk_events AS c
- ON i.next_chunk_id = c.chunk_id
- /* Get the depth of the chunk start event from the events table */
+ /* Find the batch that connects to the given insertion event */
+ INNER JOIN batch_events AS c
+ ON i.next_batch_id = c.batch_id
+ /* Get the depth of the batch start event from the events table */
INNER JOIN events AS e USING (event_id)
/* Find an insertion event which matches the given event_id */
WHERE i.event_id = ?
@@ -1077,12 +1077,12 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
event_results.add(event_id)
- # Try and find any potential historical chunks of message history.
+ # Try and find any potential historical batches of message history.
#
# First we look for an insertion event connected to the current
# event (by prev_event). If we find any, we need to go and try to
- # find any chunk events connected to the insertion event (by
- # chunk_id). If we find any, we'll add them to the queue and
+ # find any batch events connected to the insertion event (by
+ # batch_id). If we find any, we'll add them to the queue and
# navigate up the DAG like normal in the next iteration of the loop.
txn.execute(
connected_insertion_event_query, (event_id, limit - len(event_results))
@@ -1097,17 +1097,17 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
connected_insertion_event = row[1]
queue.put((-connected_insertion_event_depth, connected_insertion_event))
- # Find any chunk connections for the given insertion event
+ # Find any batch connections for the given insertion event
txn.execute(
- chunk_connection_query,
+ batch_connection_query,
(connected_insertion_event, limit - len(event_results)),
)
- chunk_start_event_id_results = txn.fetchall()
+ batch_start_event_id_results = txn.fetchall()
logger.debug(
- "_get_backfill_events: chunk_start_event_id_results %s",
- chunk_start_event_id_results,
+ "_get_backfill_events: batch_start_event_id_results %s",
+ batch_start_event_id_results,
)
- for row in chunk_start_event_id_results:
+ for row in batch_start_event_id_results:
if row[1] not in event_results:
queue.put((-row[0], row[1]))
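
The two queries above feed a best-first walk over the room DAG: events come off a priority queue ordered by depth, and whenever an insertion event turns up, the connected batch start event is enqueued too, so backfill can descend into historical batches. A self-contained sketch of that traversal pattern (the graph inputs stand in for the `event_edges`/`insertion_events`/`batch_events` lookups and are illustrative only):

```python
import heapq
from typing import Dict, List, Set, Tuple

def backfill_walk(
    start: Tuple[int, str],
    prev_events: Dict[str, List[Tuple[int, str]]],
    batch_starts: Dict[str, List[Tuple[int, str]]],
    limit: int,
) -> Set[str]:
    """Walk backwards from `start`, deepest events first.

    `prev_events` maps an event_id to its (depth, event_id) prev-events;
    `batch_starts` maps an insertion event_id to the (depth, event_id)
    batch start events linked via next_batch_id -> batch_id.
    """
    # heapq is a min-heap, so negate depths to pop the deepest event first.
    queue: List[Tuple[int, str]] = [(-start[0], start[1])]
    results: Set[str] = set()
    while queue and len(results) < limit:
        _, event_id = heapq.heappop(queue)
        if event_id in results:
            continue
        results.add(event_id)
        # Normal DAG edges: keep walking up through prev_events.
        for depth, prev_id in prev_events.get(event_id, []):
            heapq.heappush(queue, (-depth, prev_id))
        # MSC2716: jump from an insertion event into its historical batch.
        for depth, batch_start_id in batch_starts.get(event_id, []):
            heapq.heappush(queue, (-depth, batch_start_id))
    return results
```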
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 8e691678e5..584f818ff3 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -667,7 +667,7 @@ class PersistEventsStore:
table="event_auth_chain_to_calculate",
keyvalues={},
column="event_id",
- iterable=new_chain_tuples,
+ values=new_chain_tuples,
)
# Now we need to calculate any new links between chains caused by
@@ -1509,7 +1509,7 @@ class PersistEventsStore:
self._handle_event_relations(txn, event)
self._handle_insertion_event(txn, event)
- self._handle_chunk_event(txn, event)
+ self._handle_batch_event(txn, event)
# Store the labels for this event.
labels = event.content.get(EventContentFields.LABELS)
@@ -1790,23 +1790,23 @@ class PersistEventsStore:
):
return
- next_chunk_id = event.content.get(EventContentFields.MSC2716_NEXT_CHUNK_ID)
- if next_chunk_id is None:
- # Invalid insertion event without next chunk ID
+ next_batch_id = event.content.get(EventContentFields.MSC2716_NEXT_BATCH_ID)
+ if next_batch_id is None:
+ # Invalid insertion event without next batch ID
return
logger.debug(
- "_handle_insertion_event (next_chunk_id=%s) %s", next_chunk_id, event
+ "_handle_insertion_event (next_batch_id=%s) %s", next_batch_id, event
)
- # Keep track of the insertion event and the chunk ID
+ # Keep track of the insertion event and the batch ID
self.db_pool.simple_insert_txn(
txn,
table="insertion_events",
values={
"event_id": event.event_id,
"room_id": event.room_id,
- "next_chunk_id": next_chunk_id,
+ "next_batch_id": next_batch_id,
},
)
@@ -1822,8 +1822,8 @@ class PersistEventsStore:
},
)
- def _handle_chunk_event(self, txn: LoggingTransaction, event: EventBase):
- """Handles inserting the chunk edges/connections between the chunk event
+ def _handle_batch_event(self, txn: LoggingTransaction, event: EventBase):
+ """Handles inserting the batch edges/connections between the batch event
and an insertion event. Part of MSC2716.
Args:
@@ -1831,11 +1831,11 @@ class PersistEventsStore:
event: The event to process
"""
- if event.type != EventTypes.MSC2716_CHUNK:
- # Not a chunk event
+ if event.type != EventTypes.MSC2716_BATCH:
+ # Not a batch event
return
- # Skip processing a chunk event if the room version doesn't
+ # Skip processing a batch event if the room version doesn't
# support it or the event is not from the room creator.
room_version = self.store.get_room_version_txn(txn, event.room_id)
room_creator = self.db_pool.simple_select_one_onecol_txn(
@@ -1852,35 +1852,35 @@ class PersistEventsStore:
):
return
- chunk_id = event.content.get(EventContentFields.MSC2716_CHUNK_ID)
- if chunk_id is None:
- # Invalid chunk event without a chunk ID
+ batch_id = event.content.get(EventContentFields.MSC2716_BATCH_ID)
+ if batch_id is None:
+ # Invalid batch event without a batch ID
return
- logger.debug("_handle_chunk_event chunk_id=%s %s", chunk_id, event)
+ logger.debug("_handle_batch_event batch_id=%s %s", batch_id, event)
- # Keep track of the insertion event and the chunk ID
+ # Keep track of the insertion event and the batch ID
self.db_pool.simple_insert_txn(
txn,
- table="chunk_events",
+ table="batch_events",
values={
"event_id": event.event_id,
"room_id": event.room_id,
- "chunk_id": chunk_id,
+ "batch_id": batch_id,
},
)
- # When we receive an event with a `chunk_id` referencing the
- # `next_chunk_id` of the insertion event, we can remove it from the
+ # When we receive an event with a `batch_id` referencing the
+ # `next_batch_id` of the insertion event, we can remove it from the
# `insertion_event_extremities` table.
sql = """
DELETE FROM insertion_event_extremities WHERE event_id IN (
SELECT event_id FROM insertion_events
- WHERE next_chunk_id = ?
+ WHERE next_batch_id = ?
)
"""
- txn.execute(sql, (chunk_id,))
+ txn.execute(sql, (batch_id,))
def _handle_redaction(self, txn, redacted_event_id):
"""Handles receiving a redaction and checking whether we need to remove
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index 6fcb2b8353..1afc59fafb 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -490,7 +490,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
txn=txn,
table="event_forward_extremities",
column="event_id",
- iterable=to_delete,
+ values=to_delete,
keyvalues={},
)
@@ -520,7 +520,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
txn=txn,
table="_extremities_to_check",
column="event_id",
- iterable=original_set,
+ values=original_set,
keyvalues={},
)
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index 63ac09c61d..a93caae8d0 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -324,7 +324,7 @@ class PusherWorkerStore(SQLBaseStore):
txn,
table="pushers",
column="user_name",
- iterable=users,
+ values=users,
keyvalues={},
)
@@ -373,7 +373,7 @@ class PusherWorkerStore(SQLBaseStore):
txn,
table="pushers",
column="id",
- iterable=(pusher_id for pusher_id, token in pushers if token is None),
+ values=[pusher_id for pusher_id, token in pushers if token is None],
keyvalues={},
)
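
Note that this hunk does more than rename the argument: it also swaps a generator expression for a list comprehension. If the helper iterates its `values` more than once (for example, once to build the IN-clause placeholders and once for the bound parameters, as in the sketch after the account_data.py hunk), a generator would be silently exhausted after the first pass:

```python
pushers = [(1, None), (2, "tok"), (3, None)]

gen = (pusher_id for pusher_id, token in pushers if token is None)
placeholders = ", ".join("?" for _ in gen)  # first pass consumes the generator
args = list(gen)                            # second pass sees nothing

assert placeholders == "?, ?"
assert args == []  # silently wrong: no parameters bound

ids = [pusher_id for pusher_id, token in pushers if token is None]
assert ", ".join("?" for _ in ids) == "?, ?"
assert ids == [1, 3]  # a list survives both passes
```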
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index edeaacd7a6..01a4281301 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -14,7 +14,7 @@
# limitations under the License.
import logging
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, Iterable, List, Optional, Tuple
from twisted.internet import defer
@@ -153,12 +153,12 @@ class ReceiptsWorkerStore(SQLBaseStore):
}
async def get_linearized_receipts_for_rooms(
- self, room_ids: List[str], to_key: int, from_key: Optional[int] = None
+ self, room_ids: Iterable[str], to_key: int, from_key: Optional[int] = None
) -> List[dict]:
"""Get receipts for multiple rooms for sending to clients.
Args:
- room_id: List of room_ids.
+            room_ids: The room IDs to fetch receipts of.
to_key: Max stream id to fetch receipts up to.
from_key: Min stream id to fetch receipts from. None fetches
from the start.
diff --git a/synapse/storage/databases/main/room_batch.py b/synapse/storage/databases/main/room_batch.py
new file mode 100644
index 0000000000..a383388757
--- /dev/null
+++ b/synapse/storage/databases/main/room_batch.py
@@ -0,0 +1,36 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+from synapse.storage._base import SQLBaseStore
+
+
+class RoomBatchStore(SQLBaseStore):
+ async def get_insertion_event_by_batch_id(self, batch_id: str) -> Optional[str]:
+ """Retrieve a insertion event ID.
+
+ Args:
+ batch_id: The batch ID of the insertion event to retrieve.
+
+ Returns:
+            The event_id of an insertion event, or None if there is no known
+            insertion event for the given batch ID.
+ """
+ return await self.db_pool.simple_select_one_onecol(
+ table="insertion_events",
+ keyvalues={"next_batch_id": batch_id},
+ retcol="event_id",
+ allow_none=True,
+ )
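
A usage sketch for the new store method; the caller and batch ID are hypothetical:

```python
from typing import Optional

async def resolve_batch(store: "RoomBatchStore", batch_id: str) -> Optional[str]:
    # Hypothetical caller, e.g. a handler validating that a supplied batch
    # ID points at a known insertion event before appending to its batch.
    insertion_event_id = await store.get_insertion_event_by_batch_id(batch_id)
    if insertion_event_id is None:
        # No insertion event advertises this batch ID; callers would
        # typically reject the request here.
        return None
    return insertion_event_id
```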
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index 8e22da99ae..a8e8dd4577 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -473,7 +473,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
txn,
table="current_state_events",
column="room_id",
- iterable=to_delete,
+ values=to_delete,
keyvalues={},
)
@@ -481,7 +481,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
txn,
table="event_forward_extremities",
column="room_id",
- iterable=to_delete,
+ values=to_delete,
keyvalues={},
)
diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py
index 4d6bbc94c7..340ca9e47d 100644
--- a/synapse/storage/databases/main/ui_auth.py
+++ b/synapse/storage/databases/main/ui_auth.py
@@ -326,7 +326,7 @@ class UIAuthWorkerStore(SQLBaseStore):
txn,
table="ui_auth_sessions_ips",
column="session_id",
- iterable=session_ids,
+ values=session_ids,
keyvalues={},
)
@@ -377,7 +377,7 @@ class UIAuthWorkerStore(SQLBaseStore):
txn,
table="ui_auth_sessions_credentials",
column="session_id",
- iterable=session_ids,
+ values=session_ids,
keyvalues={},
)
@@ -386,7 +386,7 @@ class UIAuthWorkerStore(SQLBaseStore):
txn,
table="ui_auth_sessions",
column="session_id",
- iterable=session_ids,
+ values=session_ids,
keyvalues={},
)
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index 8aebdc2817..718f3e9976 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -85,19 +85,17 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
self.db_pool.simple_insert_many_txn(txn, TEMP_TABLE + "_rooms", rooms)
del rooms
- # If search all users is on, get all the users we want to add.
- if self.hs.config.user_directory_search_all_users:
- sql = (
- "CREATE TABLE IF NOT EXISTS "
- + TEMP_TABLE
- + "_users(user_id TEXT NOT NULL)"
- )
- txn.execute(sql)
+ sql = (
+ "CREATE TABLE IF NOT EXISTS "
+ + TEMP_TABLE
+ + "_users(user_id TEXT NOT NULL)"
+ )
+ txn.execute(sql)
- txn.execute("SELECT name FROM users")
- users = [{"user_id": x[0]} for x in txn.fetchall()]
+ txn.execute("SELECT name FROM users")
+ users = [{"user_id": x[0]} for x in txn.fetchall()]
- self.db_pool.simple_insert_many_txn(txn, TEMP_TABLE + "_users", users)
+ self.db_pool.simple_insert_many_txn(txn, TEMP_TABLE + "_users", users)
new_pos = await self.get_max_stream_id_in_current_state_deltas()
await self.db_pool.runInteraction(
@@ -265,13 +263,8 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
async def _populate_user_directory_process_users(self, progress, batch_size):
"""
- If search_all_users is enabled, add all of the users to the user directory.
+ Add all local users to the user directory.
"""
- if not self.hs.config.user_directory_search_all_users:
- await self.db_pool.updates._end_background_update(
- "populate_user_directory_process_users"
- )
- return 1
def _get_next_batch(txn):
sql = "SELECT user_id FROM %s LIMIT %s" % (
diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py
index c2891cb07f..eb1118d2cb 100644
--- a/synapse/storage/databases/state/bg_updates.py
+++ b/synapse/storage/databases/state/bg_updates.py
@@ -13,12 +13,20 @@
# limitations under the License.
import logging
-from typing import Optional
+from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Union
from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+)
from synapse.storage.engines import PostgresEngine
from synapse.storage.state import StateFilter
+from synapse.types import MutableStateMap, StateMap
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -31,7 +39,9 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
updates.
"""
- def _count_state_group_hops_txn(self, txn, state_group):
+ def _count_state_group_hops_txn(
+ self, txn: LoggingTransaction, state_group: int
+ ) -> int:
"""Given a state group, count how many hops there are in the tree.
This is used to ensure the delta chains don't get too long.
@@ -56,7 +66,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
else:
# We don't use WITH RECURSIVE on sqlite3 as there are distributions
# that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
- next_group = state_group
+ next_group: Optional[int] = state_group
count = 0
while next_group:
@@ -73,11 +83,14 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
return count
def _get_state_groups_from_groups_txn(
- self, txn, groups, state_filter: Optional[StateFilter] = None
- ):
+ self,
+ txn: LoggingTransaction,
+ groups: List[int],
+ state_filter: Optional[StateFilter] = None,
+ ) -> Mapping[int, StateMap[str]]:
state_filter = state_filter or StateFilter.all()
- results = {group: {} for group in groups}
+ results: Dict[int, MutableStateMap[str]] = {group: {} for group in groups}
where_clause, where_args = state_filter.make_sql_filter_clause()
@@ -117,7 +130,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
"""
for group in groups:
- args = [group]
+ args: List[Union[int, str]] = [group]
args.extend(where_args)
txn.execute(sql % (where_clause,), args)
@@ -131,7 +144,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
# We don't use WITH RECURSIVE on sqlite3 as there are distributions
# that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
for group in groups:
- next_group = group
+ next_group: Optional[int] = group
while next_group:
# We did this before by getting the list of group ids, and
@@ -173,6 +186,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
allow_none=True,
)
+ # The results shouldn't be considered mutable.
return results
@@ -182,7 +196,12 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
STATE_GROUPS_ROOM_INDEX_UPDATE_NAME = "state_groups_room_id_idx"
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_update_handler(
self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
@@ -198,7 +217,9 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
columns=["room_id"],
)
- async def _background_deduplicate_state(self, progress, batch_size):
+ async def _background_deduplicate_state(
+ self, progress: dict, batch_size: int
+ ) -> int:
"""This background update will slowly deduplicate state by reencoding
them as deltas.
"""
@@ -218,7 +239,7 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
)
max_group = rows[0][0]
- def reindex_txn(txn):
+ def reindex_txn(txn: LoggingTransaction) -> Tuple[bool, int]:
new_last_state_group = last_state_group
for count in range(batch_size):
txn.execute(
@@ -251,7 +272,8 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
" WHERE id < ? AND room_id = ?",
(state_group, room_id),
)
- (prev_group,) = txn.fetchone()
+ # There will be a result due to the coalesce.
+ (prev_group,) = txn.fetchone() # type: ignore
new_last_state_group = state_group
if prev_group:
@@ -261,15 +283,15 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
# otherwise read performance degrades.
continue
- prev_state = self._get_state_groups_from_groups_txn(
+ prev_state_by_group = self._get_state_groups_from_groups_txn(
txn, [prev_group]
)
- prev_state = prev_state[prev_group]
+ prev_state = prev_state_by_group[prev_group]
- curr_state = self._get_state_groups_from_groups_txn(
+ curr_state_by_group = self._get_state_groups_from_groups_txn(
txn, [state_group]
)
- curr_state = curr_state[state_group]
+ curr_state = curr_state_by_group[state_group]
if not set(prev_state.keys()) - set(curr_state.keys()):
# We can only do a delta if the current has a strict super set
@@ -340,8 +362,8 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
return result * BATCH_SIZE_SCALE_FACTOR
- async def _background_index_state(self, progress, batch_size):
- def reindex_txn(conn):
+ async def _background_index_state(self, progress: dict, batch_size: int) -> int:
+ def reindex_txn(conn: LoggingDatabaseConnection) -> None:
conn.rollback()
if isinstance(self.database_engine, PostgresEngine):
# postgres insists on autocommit for the index
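
Two hunks above annotate `next_group` as `Optional[int]` at its first assignment. Without the annotation, mypy infers plain `int` from the initializer and then rejects the `None` that eventually terminates the loop. A minimal reproduction of the pattern:

```python
from typing import Dict, Optional

def count_hops(edges: Dict[int, Optional[int]], start: int) -> int:
    # Annotate at first assignment: the walk ends by assigning None.
    next_group: Optional[int] = start
    count = 0
    while next_group:
        next_group = edges.get(next_group)
        if next_group:
            count += 1
    return count

assert count_hops({1: 2, 2: 3, 3: None}, 1) == 2
```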
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index f839c0c24f..c4c8c0021b 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -13,43 +13,56 @@
# limitations under the License.
import logging
-from collections import namedtuple
-from typing import Dict, Iterable, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple
+
+import attr
from synapse.api.constants import EventTypes
from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+)
from synapse.storage.databases.state.bg_updates import StateBackgroundUpdateStore
from synapse.storage.state import StateFilter
from synapse.storage.types import Cursor
from synapse.storage.util.sequence import build_sequence_generator
-from synapse.types import MutableStateMap, StateMap
+from synapse.types import MutableStateMap, StateKey, StateMap
from synapse.util.caches.descriptors import cached
from synapse.util.caches.dictionary_cache import DictionaryCache
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
MAX_STATE_DELTA_HOPS = 100
-class _GetStateGroupDelta(
- namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids"))
-):
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class _GetStateGroupDelta:
"""Return type of get_state_group_delta that implements __len__, which lets
- us use the itrable flag when caching
+ us use the iterable flag when caching
"""
- __slots__ = []
+ prev_group: Optional[int]
+ delta_ids: Optional[StateMap[str]]
- def __len__(self):
+ def __len__(self) -> int:
return len(self.delta_ids) if self.delta_ids else 0
class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
"""A data store for fetching/storing state groups."""
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
# Originally the state store used a single DictionaryCache to cache the
@@ -81,19 +94,21 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
# We size the non-members cache to be smaller than the members cache as the
# vast majority of state in Matrix (today) is member events.
- self._state_group_cache = DictionaryCache(
+ self._state_group_cache: DictionaryCache[int, StateKey, str] = DictionaryCache(
"*stateGroupCache*",
# TODO: this hasn't been tuned yet
50000,
)
- self._state_group_members_cache = DictionaryCache(
+ self._state_group_members_cache: DictionaryCache[
+ int, StateKey, str
+ ] = DictionaryCache(
"*stateGroupMembersCache*",
500000,
)
- def get_max_state_group_txn(txn: Cursor):
+ def get_max_state_group_txn(txn: Cursor) -> int:
txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups")
- return txn.fetchone()[0]
+ return txn.fetchone()[0] # type: ignore
self._state_group_seq_gen = build_sequence_generator(
db_conn,
@@ -105,15 +120,15 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
)
@cached(max_entries=10000, iterable=True)
- async def get_state_group_delta(self, state_group):
+ async def get_state_group_delta(self, state_group: int) -> _GetStateGroupDelta:
"""Given a state group try to return a previous group and a delta between
the old and the new.
Returns:
- (prev_group, delta_ids), where both may be None.
+            _GetStateGroupDelta containing prev_group and delta_ids, where
+            both may be None.
"""
- def _get_state_group_delta_txn(txn):
+ def _get_state_group_delta_txn(txn: LoggingTransaction) -> _GetStateGroupDelta:
prev_group = self.db_pool.simple_select_one_onecol_txn(
txn,
table="state_group_edges",
@@ -154,7 +169,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
Returns:
Dict of state group to state map.
"""
- results = {}
+ results: Dict[int, StateMap[str]] = {}
chunks = [groups[i : i + 100] for i in range(0, len(groups), 100)]
for chunk in chunks:
@@ -168,19 +183,24 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
return results
- def _get_state_for_group_using_cache(self, cache, group, state_filter):
+ def _get_state_for_group_using_cache(
+ self,
+ cache: DictionaryCache[int, StateKey, str],
+ group: int,
+ state_filter: StateFilter,
+ ) -> Tuple[MutableStateMap[str], bool]:
"""Checks if group is in cache. See `_get_state_for_groups`
Args:
- cache(DictionaryCache): the state group cache to use
- group(int): The state group to lookup
- state_filter (StateFilter): The state filter used to fetch state
- from the database.
+ cache: the state group cache to use
+ group: The state group to lookup
+ state_filter: The state filter used to fetch state from the database.
- Returns 2-tuple (`state_dict`, `got_all`).
- `got_all` is a bool indicating if we successfully retrieved all
- requests state from the cache, if False we need to query the DB for the
- missing state.
+ Returns:
+ 2-tuple (`state_dict`, `got_all`).
+            `got_all` is a bool indicating if we successfully retrieved all the
+            requested state from the cache; if False, we need to query the DB
+            for the missing state.
"""
cache_entry = cache.get(group)
state_dict_ids = cache_entry.value
@@ -277,8 +297,11 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
return state
def _get_state_for_groups_using_cache(
- self, groups: Iterable[int], cache: DictionaryCache, state_filter: StateFilter
- ) -> Tuple[Dict[int, StateMap[str]], Set[int]]:
+ self,
+ groups: Iterable[int],
+ cache: DictionaryCache[int, StateKey, str],
+ state_filter: StateFilter,
+ ) -> Tuple[Dict[int, MutableStateMap[str]], Set[int]]:
"""Gets the state at each of a list of state groups, optionally
filtering by type/state_key, querying from a specific cache.
@@ -310,21 +333,21 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
def _insert_into_cache(
self,
- group_to_state_dict,
- state_filter,
- cache_seq_num_members,
- cache_seq_num_non_members,
- ):
+ group_to_state_dict: Dict[int, StateMap[str]],
+ state_filter: StateFilter,
+ cache_seq_num_members: int,
+ cache_seq_num_non_members: int,
+ ) -> None:
"""Inserts results from querying the database into the relevant cache.
Args:
- group_to_state_dict (dict): The new entries pulled from database.
+ group_to_state_dict: The new entries pulled from database.
Map from state group to state dict
- state_filter (StateFilter): The state filter used to fetch state
+ state_filter: The state filter used to fetch state
from the database.
- cache_seq_num_members (int): Sequence number of member cache since
+ cache_seq_num_members: Sequence number of member cache since
last lookup in cache
- cache_seq_num_non_members (int): Sequence number of member cache since
+            cache_seq_num_non_members: Sequence number of non-member cache since
last lookup in cache
"""
@@ -395,7 +418,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
The state group ID
"""
- def _store_state_group_txn(txn):
+ def _store_state_group_txn(txn: LoggingTransaction) -> int:
if current_state_ids is None:
# AFAIK, this can never happen
raise Exception("current_state_ids cannot be None")
@@ -426,6 +449,8 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
potential_hops = self._count_state_group_hops_txn(txn, prev_group)
if prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
+ assert delta_ids is not None
+
self.db_pool.simple_insert_txn(
txn,
table="state_group_edges",
@@ -498,7 +523,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
)
async def purge_unreferenced_state_groups(
- self, room_id: str, state_groups_to_delete
+ self, room_id: str, state_groups_to_delete: Collection[int]
) -> None:
"""Deletes no longer referenced state groups and de-deltas any state
groups that reference them.
@@ -506,8 +531,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
Args:
room_id: The room the state groups belong to (must all be in the
same room).
- state_groups_to_delete (Collection[int]): Set of all state groups
- to delete.
+ state_groups_to_delete: Set of all state groups to delete.
"""
await self.db_pool.runInteraction(
@@ -517,7 +541,12 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
state_groups_to_delete,
)
- def _purge_unreferenced_state_groups(self, txn, room_id, state_groups_to_delete):
+ def _purge_unreferenced_state_groups(
+ self,
+ txn: LoggingTransaction,
+ room_id: str,
+ state_groups_to_delete: Collection[int],
+ ) -> None:
logger.info(
"[purge] found %i state groups to delete", len(state_groups_to_delete)
)
@@ -546,8 +575,8 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
# groups to non delta versions.
for sg in remaining_state_groups:
logger.info("[purge] de-delta-ing remaining state group %s", sg)
- curr_state = self._get_state_groups_from_groups_txn(txn, [sg])
- curr_state = curr_state[sg]
+ curr_state_by_group = self._get_state_groups_from_groups_txn(txn, [sg])
+ curr_state = curr_state_by_group[sg]
self.db_pool.simple_delete_txn(
txn, table="state_groups_state", keyvalues={"state_group": sg}
@@ -605,12 +634,14 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
return {row["state_group"]: row["prev_state_group"] for row in rows}
- async def purge_room_state(self, room_id, state_groups_to_delete):
+ async def purge_room_state(
+ self, room_id: str, state_groups_to_delete: Collection[int]
+ ) -> None:
"""Deletes all record of a room from state tables
Args:
- room_id (str):
- state_groups_to_delete (list[int]): State groups to delete
+            room_id: The room to delete state for.
+ state_groups_to_delete: State groups to delete
"""
await self.db_pool.runInteraction(
@@ -620,7 +651,12 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
state_groups_to_delete,
)
- def _purge_room_state_txn(self, txn, room_id, state_groups_to_delete):
+ def _purge_room_state_txn(
+ self,
+ txn: LoggingTransaction,
+ room_id: str,
+ state_groups_to_delete: Collection[int],
+ ) -> None:
# first we have to delete the state groups states
logger.info("[purge] removing %s from state_groups_state", room_id)
@@ -628,7 +664,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
txn,
table="state_groups_state",
column="state_group",
- iterable=state_groups_to_delete,
+ values=state_groups_to_delete,
keyvalues={},
)
@@ -639,7 +675,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
txn,
table="state_group_edges",
column="state_group",
- iterable=state_groups_to_delete,
+ values=state_groups_to_delete,
keyvalues={},
)
@@ -650,6 +686,6 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
txn,
table="state_groups",
column="id",
- iterable=state_groups_to_delete,
+ values=state_groups_to_delete,
keyvalues={},
)
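
Finally, the namedtuple-to-attrs conversion at the top of this file preserves the `__len__` hook that `@cached(..., iterable=True)` relies on to charge the cache by number of delta entries rather than counting each value as one unit. A standalone sketch of the same pattern (names abbreviated; `StateMap` simplified):

```python
from typing import Dict, Optional, Tuple

import attr

# Simplified stand-in for synapse.types.StateMap.
StateMap = Dict[Tuple[str, str], str]

@attr.s(slots=True, frozen=True, auto_attribs=True)
class GetStateGroupDelta:
    """Frozen attrs class whose __len__ lets an iterable-aware cache size
    entries by their number of delta IDs."""

    prev_group: Optional[int]
    delta_ids: Optional[StateMap]

    def __len__(self) -> int:
        return len(self.delta_ids) if self.delta_ids else 0

delta = GetStateGroupDelta(42, {("m.room.member", "@alice:hs"): "$event"})
assert len(delta) == 1
assert len(GetStateGroupDelta(None, None)) == 0
```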
|