summary refs log tree commit diff
path: root/synapse/replication/tcp
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication/tcp')
-rw-r--r--synapse/replication/tcp/client.py31
1 files changed, 24 insertions, 7 deletions
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py

index 139f57cf86..3b88dc68ea 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py
@@ -14,7 +14,9 @@ """A replication client for use by synapse workers. """ import logging -from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple +from typing import TYPE_CHECKING, Dict, Iterable, Optional, Set, Tuple + +from sortedcontainers import SortedList from twisted.internet import defer from twisted.internet.defer import Deferred @@ -26,6 +28,7 @@ from synapse.logging.context import PreserveLoggingContext, make_deferred_yielda from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.tcp.streams import ( AccountDataStream, + CachesStream, DeviceListsStream, PushersStream, PushRulesStream, @@ -73,6 +76,7 @@ class ReplicationDataHandler: self._instance_name = hs.get_instance_name() self._typing_handler = hs.get_typing_handler() self._state_storage_controller = hs.get_storage_controllers().state + self.auth = hs.get_auth() self._notify_pushers = hs.config.worker.start_pushers self._pusher_pool = hs.get_pusherpool() @@ -84,7 +88,9 @@ class ReplicationDataHandler: # Map from stream and instance to list of deferreds waiting for the stream to # arrive at a particular position. The lists are sorted by stream position. - self._streams_to_waiters: Dict[Tuple[str, str], List[Tuple[int, Deferred]]] = {} + self._streams_to_waiters: Dict[ + Tuple[str, str], SortedList[Tuple[int, Deferred]] + ] = {} async def on_rdata( self, stream_name: str, instance_name: str, token: int, rows: list @@ -218,6 +224,16 @@ class ReplicationDataHandler: self._state_storage_controller.notify_event_un_partial_stated( row.event_id ) + # invalidate the introspection token cache + elif stream_name == CachesStream.NAME: + for row in rows: + if row.cache_func == "introspection_token_invalidation": + if row.keys[0] is None: + # invalidate the whole cache + # mypy ignore - the token cache is defined on MSC3861DelegatedAuth + self.auth.invalidate_token_cache() # type: ignore[attr-defined] + else: + self.auth.invalidate_cached_tokens(row.keys) # type: ignore[attr-defined] await self._presence_handler.process_replication_rows( stream_name, instance_name, token, rows @@ -226,7 +242,9 @@ class ReplicationDataHandler: # Notify any waiting deferreds. The list is ordered by position so we # just iterate through the list until we reach a position that is # greater than the received row position. - waiting_list = self._streams_to_waiters.get((stream_name, instance_name), []) + waiting_list = self._streams_to_waiters.get((stream_name, instance_name)) + if not waiting_list: + return # Index of first item with a position after the current token, i.e we # have called all deferreds before this index. If not overwritten by @@ -250,7 +268,7 @@ class ReplicationDataHandler: # Drop all entries in the waiting list that were called in the above # loop. (This maintains the order so no need to resort) - waiting_list[:] = waiting_list[index_of_first_deferred_not_called:] + del waiting_list[:index_of_first_deferred_not_called] for deferred in deferreds_to_callback: try: @@ -310,11 +328,10 @@ class ReplicationDataHandler: ) waiting_list = self._streams_to_waiters.setdefault( - (stream_name, instance_name), [] + (stream_name, instance_name), SortedList(key=lambda t: t[0]) ) - waiting_list.append((position, deferred)) - waiting_list.sort(key=lambda t: t[0]) + waiting_list.add((position, deferred)) # We measure here to get in flight counts and average waiting time. with Measure(self._clock, "repl.wait_for_stream_position"):