diff --git a/changelog.d/16613.feature b/changelog.d/16613.feature
new file mode 100644
index 0000000000..254d04a90e
--- /dev/null
+++ b/changelog.d/16613.feature
@@ -0,0 +1 @@
+Improve the performance of claiming encryption keys in multi-worker deployments.
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index f50a4ce2fc..0af0507307 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -1116,7 +1116,7 @@ class DatabasePool:
def simple_insert_many_txn(
txn: LoggingTransaction,
table: str,
- keys: Collection[str],
+ keys: Sequence[str],
values: Collection[Iterable[Any]],
) -> None:
"""Executes an INSERT query on the named table.
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 4d0470ffd9..d7232f566b 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -483,6 +483,30 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
txn.call_after(cache_func.invalidate, keys)
self._send_invalidation_to_replication(txn, cache_func.__name__, keys)
+ def _invalidate_cache_and_stream_bulk(
+ self,
+ txn: LoggingTransaction,
+ cache_func: CachedFunction,
+ key_tuples: Collection[Tuple[Any, ...]],
+ ) -> None:
+ """A bulk version of _invalidate_cache_and_stream.
+
+ Locally invalidate every key-tuple in `key_tuples`, then emit invalidations
+ for each key-tuple over replication.
+
+ This implementation is more efficient than a loop which repeatedly calls the
+ non-bulk version.
+ """
+ if not key_tuples:
+ return
+
+ for keys in key_tuples:
+ txn.call_after(cache_func.invalidate, keys)
+
+ self._send_invalidation_to_replication_bulk(
+ txn, cache_func.__name__, key_tuples
+ )
+
def _invalidate_all_cache_and_stream(
self, txn: LoggingTransaction, cache_func: CachedFunction
) -> None:
@@ -564,10 +588,6 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
if isinstance(self.database_engine, PostgresEngine):
assert self._cache_id_gen is not None
- # get_next() returns a context manager which is designed to wrap
- # the transaction. However, we want to only get an ID when we want
- # to use it, here, so we need to call __enter__ manually, and have
- # __exit__ called after the transaction finishes.
stream_id = self._cache_id_gen.get_next_txn(txn)
txn.call_after(self.hs.get_notifier().on_new_replication_data)
@@ -586,6 +606,53 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
},
)
+ def _send_invalidation_to_replication_bulk(
+ self,
+ txn: LoggingTransaction,
+ cache_name: str,
+ key_tuples: Collection[Tuple[Any, ...]],
+ ) -> None:
+ """Announce the invalidation of multiple (but not all) cache entries.
+
+ This is more efficient than repeated calls to the non-bulk version. It should
+ NOT be used to invalidating the entire cache: use
+ `_send_invalidation_to_replication` with keys=None.
+
+ Note that this does *not* invalidate the cache locally.
+
+ Args:
+ txn
+ cache_name
+ key_tuples: Key-tuples to invalidate. Assumed to be non-empty.
+ """
+ if isinstance(self.database_engine, PostgresEngine):
+ assert self._cache_id_gen is not None
+
+ stream_ids = self._cache_id_gen.get_next_mult_txn(txn, len(key_tuples))
+ ts = self._clock.time_msec()
+ txn.call_after(self.hs.get_notifier().on_new_replication_data)
+ self.db_pool.simple_insert_many_txn(
+ txn,
+ table="cache_invalidation_stream_by_instance",
+ keys=(
+ "stream_id",
+ "instance_name",
+ "cache_func",
+ "keys",
+ "invalidation_ts",
+ ),
+ values=[
+ # We convert key_tuples to a list here because psycopg2 serialises
+ # lists as pq arrrays, but serialises tuples as "composite types".
+ # (We need an array because the `keys` column has type `[]text`.)
+ # See:
+ # https://www.psycopg.org/docs/usage.html#adapt-list
+ # https://www.psycopg.org/docs/usage.html#adapt-tuple
+ (stream_id, self._instance_name, cache_name, list(key_tuple), ts)
+ for stream_id, key_tuple in zip(stream_ids, key_tuples)
+ ],
+ )
+
def get_cache_stream_token_for_writer(self, instance_name: str) -> int:
if self._cache_id_gen:
return self._cache_id_gen.get_current_token_for_writer(instance_name)
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 4f96ac25c7..3005e2a2c5 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -1237,13 +1237,11 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
for user_id, device_id, algorithm, key_id, key_json in claimed_keys:
device_results = results.setdefault(user_id, {}).setdefault(device_id, {})
device_results[f"{algorithm}:{key_id}"] = json_decoder.decode(key_json)
-
- if (user_id, device_id) in seen_user_device:
- continue
seen_user_device.add((user_id, device_id))
- self._invalidate_cache_and_stream(
- txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id)
- )
+
+ self._invalidate_cache_and_stream_bulk(
+ txn, self.get_e2e_unused_fallback_key_types, seen_user_device
+ )
return results
@@ -1376,14 +1374,14 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
List[Tuple[str, str, str, str, str]], txn.execute_values(sql, query_list)
)
- seen_user_device: Set[Tuple[str, str]] = set()
- for user_id, device_id, _, _, _ in otk_rows:
- if (user_id, device_id) in seen_user_device:
- continue
- seen_user_device.add((user_id, device_id))
- self._invalidate_cache_and_stream(
- txn, self.count_e2e_one_time_keys, (user_id, device_id)
- )
+ seen_user_device = {
+ (user_id, device_id) for user_id, device_id, _, _, _ in otk_rows
+ }
+ self._invalidate_cache_and_stream_bulk(
+ txn,
+ self.count_e2e_one_time_keys,
+ seen_user_device,
+ )
return otk_rows
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 9c3eafb562..bd3c81827f 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -650,8 +650,8 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
next_id = self._load_next_id_txn(txn)
- txn.call_after(self._mark_id_as_finished, next_id)
- txn.call_on_exception(self._mark_id_as_finished, next_id)
+ txn.call_after(self._mark_ids_as_finished, [next_id])
+ txn.call_on_exception(self._mark_ids_as_finished, [next_id])
txn.call_after(self._notifier.notify_replication)
# Update the `stream_positions` table with newly updated stream
@@ -671,14 +671,50 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
return self._return_factor * next_id
- def _mark_id_as_finished(self, next_id: int) -> None:
- """The ID has finished being processed so we should advance the
+ def get_next_mult_txn(self, txn: LoggingTransaction, n: int) -> List[int]:
+ """
+ Usage:
+
+ stream_id = stream_id_gen.get_next_txn(txn)
+ # ... persist event ...
+ """
+
+ # If we have a list of instances that are allowed to write to this
+ # stream, make sure we're in it.
+ if self._writers and self._instance_name not in self._writers:
+ raise Exception("Tried to allocate stream ID on non-writer")
+
+ next_ids = self._load_next_mult_id_txn(txn, n)
+
+ txn.call_after(self._mark_ids_as_finished, next_ids)
+ txn.call_on_exception(self._mark_ids_as_finished, next_ids)
+ txn.call_after(self._notifier.notify_replication)
+
+ # Update the `stream_positions` table with newly updated stream
+ # ID (unless self._writers is not set in which case we don't
+ # bother, as nothing will read it).
+ #
+ # We only do this on the success path so that the persisted current
+ # position points to a persisted row with the correct instance name.
+ if self._writers:
+ txn.call_after(
+ run_as_background_process,
+ "MultiWriterIdGenerator._update_table",
+ self._db.runInteraction,
+ "MultiWriterIdGenerator._update_table",
+ self._update_stream_positions_table_txn,
+ )
+
+ return [self._return_factor * next_id for next_id in next_ids]
+
+ def _mark_ids_as_finished(self, next_ids: List[int]) -> None:
+ """These IDs have finished being processed so we should advance the
current position if possible.
"""
with self._lock:
- self._unfinished_ids.discard(next_id)
- self._finished_ids.add(next_id)
+ self._unfinished_ids.difference_update(next_ids)
+ self._finished_ids.update(next_ids)
new_cur: Optional[int] = None
@@ -727,7 +763,10 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
curr, new_cur, self._max_position_of_local_instance
)
- self._add_persisted_position(next_id)
+ # TODO Can we call this for just the last position or somehow batch
+ # _add_persisted_position.
+ for next_id in next_ids:
+ self._add_persisted_position(next_id)
def get_current_token(self) -> int:
return self.get_persisted_upto_position()
@@ -933,8 +972,7 @@ class _MultiWriterCtxManager:
exc: Optional[BaseException],
tb: Optional[TracebackType],
) -> bool:
- for i in self.stream_ids:
- self.id_gen._mark_id_as_finished(i)
+ self.id_gen._mark_ids_as_finished(self.stream_ids)
self.notifier.notify_replication()
diff --git a/tests/storage/databases/main/test_cache.py b/tests/storage/databases/main/test_cache.py
new file mode 100644
index 0000000000..3f71f5d102
--- /dev/null
+++ b/tests/storage/databases/main/test_cache.py
@@ -0,0 +1,117 @@
+# Copyright 2023 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from unittest.mock import Mock, call
+
+from synapse.storage.database import LoggingTransaction
+
+from tests.replication._base import BaseMultiWorkerStreamTestCase
+from tests.unittest import HomeserverTestCase
+
+
+class CacheInvalidationTestCase(HomeserverTestCase):
+ def setUp(self) -> None:
+ super().setUp()
+ self.store = self.hs.get_datastores().main
+
+ def test_bulk_invalidation(self) -> None:
+ master_invalidate = Mock()
+
+ self.store._get_cached_user_device.invalidate = master_invalidate
+
+ keys_to_invalidate = [
+ ("a", "b"),
+ ("c", "d"),
+ ("e", "f"),
+ ("g", "h"),
+ ]
+
+ def test_txn(txn: LoggingTransaction) -> None:
+ self.store._invalidate_cache_and_stream_bulk(
+ txn,
+ # This is an arbitrarily chosen cached store function. It was chosen
+ # because it takes more than one argument. We'll use this later to
+ # check that the invalidation was actioned over replication.
+ cache_func=self.store._get_cached_user_device,
+ key_tuples=keys_to_invalidate,
+ )
+
+ self.get_success(
+ self.store.db_pool.runInteraction(
+ "test_invalidate_cache_and_stream_bulk", test_txn
+ )
+ )
+
+ master_invalidate.assert_has_calls(
+ [call(key_list) for key_list in keys_to_invalidate],
+ any_order=True,
+ )
+
+
+class CacheInvalidationOverReplicationTestCase(BaseMultiWorkerStreamTestCase):
+ def setUp(self) -> None:
+ super().setUp()
+ self.store = self.hs.get_datastores().main
+
+ def test_bulk_invalidation_replicates(self) -> None:
+ """Like test_bulk_invalidation, but also checks the invalidations replicate."""
+ master_invalidate = Mock()
+ worker_invalidate = Mock()
+
+ self.store._get_cached_user_device.invalidate = master_invalidate
+ worker = self.make_worker_hs("synapse.app.generic_worker")
+ worker_ds = worker.get_datastores().main
+ worker_ds._get_cached_user_device.invalidate = worker_invalidate
+
+ keys_to_invalidate = [
+ ("a", "b"),
+ ("c", "d"),
+ ("e", "f"),
+ ("g", "h"),
+ ]
+
+ def test_txn(txn: LoggingTransaction) -> None:
+ self.store._invalidate_cache_and_stream_bulk(
+ txn,
+ # This is an arbitrarily chosen cached store function. It was chosen
+ # because it takes more than one argument. We'll use this later to
+ # check that the invalidation was actioned over replication.
+ cache_func=self.store._get_cached_user_device,
+ key_tuples=keys_to_invalidate,
+ )
+
+ assert self.store._cache_id_gen is not None
+ initial_token = self.store._cache_id_gen.get_current_token()
+ self.get_success(
+ self.database_pool.runInteraction(
+ "test_invalidate_cache_and_stream_bulk", test_txn
+ )
+ )
+ second_token = self.store._cache_id_gen.get_current_token()
+
+ self.assertGreaterEqual(second_token, initial_token + len(keys_to_invalidate))
+
+ self.get_success(
+ worker.get_replication_data_handler().wait_for_stream_position(
+ "master", "caches", second_token
+ )
+ )
+
+ master_invalidate.assert_has_calls(
+ [call(key_list) for key_list in keys_to_invalidate],
+ any_order=True,
+ )
+ worker_invalidate.assert_has_calls(
+ [call(key_list) for key_list in keys_to_invalidate],
+ any_order=True,
+ )
|