diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 139f57cf86..3b88dc68ea 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -14,7 +14,9 @@
"""A replication client for use by synapse workers.
"""
import logging
-from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, Optional, Set, Tuple
+
+from sortedcontainers import SortedList
from twisted.internet import defer
from twisted.internet.defer import Deferred
@@ -26,6 +28,7 @@ from synapse.logging.context import PreserveLoggingContext, make_deferred_yielda
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.tcp.streams import (
AccountDataStream,
+ CachesStream,
DeviceListsStream,
PushersStream,
PushRulesStream,
@@ -73,6 +76,7 @@ class ReplicationDataHandler:
self._instance_name = hs.get_instance_name()
self._typing_handler = hs.get_typing_handler()
self._state_storage_controller = hs.get_storage_controllers().state
+ self.auth = hs.get_auth()
self._notify_pushers = hs.config.worker.start_pushers
self._pusher_pool = hs.get_pusherpool()
@@ -84,7 +88,9 @@ class ReplicationDataHandler:
# Map from stream and instance to list of deferreds waiting for the stream to
# arrive at a particular position. The lists are sorted by stream position.
- self._streams_to_waiters: Dict[Tuple[str, str], List[Tuple[int, Deferred]]] = {}
+ self._streams_to_waiters: Dict[
+ Tuple[str, str], SortedList[Tuple[int, Deferred]]
+ ] = {}
async def on_rdata(
self, stream_name: str, instance_name: str, token: int, rows: list
@@ -218,6 +224,16 @@ class ReplicationDataHandler:
self._state_storage_controller.notify_event_un_partial_stated(
row.event_id
)
+ # invalidate the introspection token cache
+ elif stream_name == CachesStream.NAME:
+ for row in rows:
+ if row.cache_func == "introspection_token_invalidation":
+ if row.keys[0] is None:
+ # invalidate the whole cache
+ # mypy ignore - the token cache is defined on MSC3861DelegatedAuth
+ self.auth.invalidate_token_cache() # type: ignore[attr-defined]
+ else:
+ self.auth.invalidate_cached_tokens(row.keys) # type: ignore[attr-defined]
await self._presence_handler.process_replication_rows(
stream_name, instance_name, token, rows
@@ -226,7 +242,9 @@ class ReplicationDataHandler:
# Notify any waiting deferreds. The list is ordered by position so we
# just iterate through the list until we reach a position that is
# greater than the received row position.
- waiting_list = self._streams_to_waiters.get((stream_name, instance_name), [])
+ waiting_list = self._streams_to_waiters.get((stream_name, instance_name))
+ if not waiting_list:
+ return
# Index of first item with a position after the current token, i.e we
# have called all deferreds before this index. If not overwritten by
@@ -250,7 +268,7 @@ class ReplicationDataHandler:
# Drop all entries in the waiting list that were called in the above
# loop. (This maintains the order so no need to resort)
- waiting_list[:] = waiting_list[index_of_first_deferred_not_called:]
+ del waiting_list[:index_of_first_deferred_not_called]
for deferred in deferreds_to_callback:
try:
@@ -310,11 +328,10 @@ class ReplicationDataHandler:
)
waiting_list = self._streams_to_waiters.setdefault(
- (stream_name, instance_name), []
+ (stream_name, instance_name), SortedList(key=lambda t: t[0])
)
- waiting_list.append((position, deferred))
- waiting_list.sort(key=lambda t: t[0])
+ waiting_list.add((position, deferred))
# We measure here to get in flight counts and average waiting time.
with Measure(self._clock, "repl.wait_for_stream_position"):
|