diff --git a/changelog.d/12372.feature b/changelog.d/12372.feature
new file mode 100644
index 0000000000..34bb60e966
--- /dev/null
+++ b/changelog.d/12372.feature
@@ -0,0 +1 @@
+Reduce overhead of restarting synchrotrons.
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index 00a634d3a9..30717c2bd0 100644
--- a/synapse/replication/slave/storage/devices.py
+++ b/synapse/replication/slave/storage/devices.py
@@ -20,7 +20,6 @@ from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatu
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.devices import DeviceWorkerStore
from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore
-from synapse.util.caches.stream_change_cache import StreamChangeCache
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -33,8 +32,6 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
- super().__init__(database, db_conn, hs)
-
self.hs = hs
self._device_list_id_gen = SlavedIdTracker(
@@ -47,26 +44,8 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
("device_lists_changes_in_room", "stream_id"),
],
)
- device_list_max = self._device_list_id_gen.get_current_token()
- device_list_prefill, min_device_list_id = self.db_pool.get_cache_dict(
- db_conn,
- "device_lists_stream",
- entity_column="user_id",
- stream_column="stream_id",
- max_value=device_list_max,
- limit=1000,
- )
- self._device_list_stream_cache = StreamChangeCache(
- "DeviceListStreamChangeCache",
- min_device_list_id,
- prefilled_cache=device_list_prefill,
- )
- self._user_signature_stream_cache = StreamChangeCache(
- "UserSignatureStreamChangeCache", device_list_max
- )
- self._device_list_federation_stream_cache = StreamChangeCache(
- "DeviceListFederationStreamChangeCache", device_list_max
- )
+
+ super().__init__(database, db_conn, hs)
def get_device_stream_token(self) -> int:
return self._device_list_id_gen.get_current_token()
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 0264dea61d..12750d9b89 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -2030,29 +2030,40 @@ class DatabasePool:
max_value: int,
limit: int = 100000,
) -> Tuple[Dict[Any, int], int]:
- # Fetch a mapping of room_id -> max stream position for "recent" rooms.
- # It doesn't really matter how many we get, the StreamChangeCache will
- # do the right thing to ensure it respects the max size of cache.
- sql = (
- "SELECT %(entity)s, MAX(%(stream)s) FROM %(table)s"
- " WHERE %(stream)s > ? - %(limit)s"
- " GROUP BY %(entity)s"
- ) % {
- "table": table,
- "entity": entity_column,
- "stream": stream_column,
- "limit": limit,
- }
+ """Gets roughly the last N changes in the given stream table as a
+ map from entity to the stream ID of the most recent change.
+
+ Also returns the minimum stream ID.
+ """
+
+ # This may return many rows for the same entity, but the `limit` is only
+ # a suggestion so we don't care that much.
+ #
+    # Note: Some stream tables can have multiple rows with the same stream
+    # ID. Instead of handling this with complicated SQL, we simply add one
+    # to the returned minimum stream ID to ensure correctness.
+ sql = f"""
+ SELECT {entity_column}, {stream_column}
+ FROM {table}
+ ORDER BY {stream_column} DESC
+ LIMIT ?
+ """
txn = db_conn.cursor(txn_name="get_cache_dict")
- txn.execute(sql, (int(max_value),))
+ txn.execute(sql, (limit,))
- cache = {row[0]: int(row[1]) for row in txn}
+ # The rows come out in reverse stream ID order, so we want to keep the
+ # stream ID of the first row for each entity.
+ cache: Dict[Any, int] = {}
+ for row in txn:
+ cache.setdefault(row[0], int(row[1]))
txn.close()
if cache:
- min_val = min(cache.values())
+            # We add one here because rows sharing the minimum stream ID may
+            # have been truncated by the LIMIT, so we may not have seen them all.
+ min_val = min(cache.values()) + 1
else:
min_val = max_value
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index d4a38daa9a..951031af50 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -183,27 +183,6 @@ class DataStore(
super().__init__(database, db_conn, hs)
- device_list_max = self._device_list_id_gen.get_current_token()
- device_list_prefill, min_device_list_id = self.db_pool.get_cache_dict(
- db_conn,
- "device_lists_stream",
- entity_column="user_id",
- stream_column="stream_id",
- max_value=device_list_max,
- limit=1000,
- )
- self._device_list_stream_cache = StreamChangeCache(
- "DeviceListStreamChangeCache",
- min_device_list_id,
- prefilled_cache=device_list_prefill,
- )
- self._user_signature_stream_cache = StreamChangeCache(
- "UserSignatureStreamChangeCache", device_list_max
- )
- self._device_list_federation_stream_cache = StreamChangeCache(
- "DeviceListFederationStreamChangeCache", device_list_max
- )
-
events_max = self._stream_id_gen.get_current_token()
curr_state_delta_prefill, min_curr_state_delta_id = self.db_pool.get_cache_dict(
db_conn,
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 07eea4b3d2..dc8009b23d 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -46,6 +46,7 @@ from synapse.types import JsonDict, get_verify_key_from_cross_signing_key
from synapse.util import json_decoder, json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.lrucache import LruCache
+from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.iterutils import batch_iter
from synapse.util.stringutils import shortstr
@@ -71,6 +72,55 @@ class DeviceWorkerStore(SQLBaseStore):
):
super().__init__(database, db_conn, hs)
+ device_list_max = self._device_list_id_gen.get_current_token()
+ device_list_prefill, min_device_list_id = self.db_pool.get_cache_dict(
+ db_conn,
+ "device_lists_stream",
+ entity_column="user_id",
+ stream_column="stream_id",
+ max_value=device_list_max,
+ limit=10000,
+ )
+ self._device_list_stream_cache = StreamChangeCache(
+ "DeviceListStreamChangeCache",
+ min_device_list_id,
+ prefilled_cache=device_list_prefill,
+ )
+
+ (
+ user_signature_stream_prefill,
+ user_signature_stream_list_id,
+ ) = self.db_pool.get_cache_dict(
+ db_conn,
+ "user_signature_stream",
+ entity_column="from_user_id",
+ stream_column="stream_id",
+ max_value=device_list_max,
+ limit=1000,
+ )
+ self._user_signature_stream_cache = StreamChangeCache(
+ "UserSignatureStreamChangeCache",
+ user_signature_stream_list_id,
+ prefilled_cache=user_signature_stream_prefill,
+ )
+
+ (
+ device_list_federation_prefill,
+ device_list_federation_list_id,
+ ) = self.db_pool.get_cache_dict(
+ db_conn,
+ "device_lists_outbound_pokes",
+ entity_column="destination",
+ stream_column="stream_id",
+ max_value=device_list_max,
+ limit=10000,
+ )
+ self._device_list_federation_stream_cache = StreamChangeCache(
+ "DeviceListFederationStreamChangeCache",
+ device_list_federation_list_id,
+ prefilled_cache=device_list_federation_prefill,
+ )
+
if hs.config.worker.run_background_tasks:
self._clock.looping_call(
self._prune_old_outbound_device_pokes, 60 * 60 * 1000
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index e6f97aeece..332e901dda 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -98,8 +98,19 @@ class ReceiptsWorkerStore(SQLBaseStore):
super().__init__(database, db_conn, hs)
+ max_receipts_stream_id = self.get_max_receipt_stream_id()
+ receipts_stream_prefill, min_receipts_stream_id = self.db_pool.get_cache_dict(
+ db_conn,
+ "receipts_linearized",
+ entity_column="room_id",
+ stream_column="stream_id",
+ max_value=max_receipts_stream_id,
+ limit=10000,
+ )
self._receipts_stream_cache = StreamChangeCache(
- "ReceiptsRoomChangeCache", self.get_max_receipt_stream_id()
+ "ReceiptsRoomChangeCache",
+ min_receipts_stream_id,
+ prefilled_cache=receipts_stream_prefill,
)
def get_max_receipt_stream_id(self) -> int:
|