author     Erik Johnston <erik@matrix.org>  2020-05-22 14:21:54 +0100
committer  GitHub <noreply@github.com>      2020-05-22 14:21:54 +0100
commit     1531b214fc57714c14046a8f66c7b5fe5ec5dcdd (patch)
tree       fd150f21a14dcc6f5cf3f373984478dd12d43c95 /synapse/replication/tcp/client.py
parent     Convert sending mail to async/await. (#7557) (diff)
Add ability to wait for replication streams (#7542)
The idea here is that if an instance persists an event via the replication HTTP API, it can return before we receive that event over replication, which can lead to races where code assumes that persisting an event immediately updates various caches (e.g. the current state of the room).

Most of Synapse doesn't hit such races, so we don't do the waiting automagically; instead we do so only where necessary, to avoid unnecessary delays. We may decide to change our minds here if it turns out there are a lot of subtle races going on.

People probably want to look at this commit by commit.
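For illustration, the intended calling pattern looks roughly like the sketch below: persist the event on the writer, note the stream position it was persisted at, then wait for local replication to catch up before reading from caches. Only wait_for_stream_position() comes from this change; the surrounding helpers (_send_event_to_writer, _state_handler, the returned stream position) are hypothetical stand-ins, not real Synapse APIs.

# Hypothetical caller on a worker (sketch only).
async def persist_and_read_state(self, event, context):
    # Assume a helper that persists via the replication HTTP API and
    # reports which instance wrote the event and at what stream position.
    instance_name, stream_position = await self._send_event_to_writer(
        event, context
    )

    # Block (up to the 30 second timeout) until this worker has seen the
    # events stream advance to that position over replication.
    await self._replication_data_handler.wait_for_stream_position(
        instance_name, "events", stream_position
    )

    # Caches fed by replication (e.g. current room state) now include
    # the new event, so reading them here is race-free.
    return await self._state_handler.get_current_state(event.room_id)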
Diffstat (limited to 'synapse/replication/tcp/client.py')
-rw-r--r--  synapse/replication/tcp/client.py | 90
1 file changed, 88 insertions(+), 2 deletions(-)
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 28826302f5..508ad1b720 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -14,19 +14,23 @@
 # limitations under the License.
 """A replication client for use by synapse workers.
 """
-
+import heapq
 import logging
-from typing import TYPE_CHECKING, Tuple
+from typing import TYPE_CHECKING, Dict, List, Tuple
 
+from twisted.internet.defer import Deferred
 from twisted.internet.protocol import ReconnectingClientFactory
 
 from synapse.api.constants import EventTypes
+from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
 from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
 from synapse.replication.tcp.streams.events import (
     EventsStream,
     EventsStreamEventRow,
     EventsStreamRow,
 )
+from synapse.util.async_helpers import timeout_deferred
+from synapse.util.metrics import Measure
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -35,6 +39,10 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
+# How long we allow callers to wait for replication updates before timing out.
+_WAIT_FOR_REPLICATION_TIMEOUT_SECONDS = 30
+
+
 class DirectTcpReplicationClientFactory(ReconnectingClientFactory):
     """Factory for building connections to the master. Will reconnect if the
     connection is lost.
@@ -92,6 +100,16 @@ class ReplicationDataHandler:
         self.store = hs.get_datastore()
         self.pusher_pool = hs.get_pusherpool()
         self.notifier = hs.get_notifier()
+        self._reactor = hs.get_reactor()
+        self._clock = hs.get_clock()
+        self._streams = hs.get_replication_streams()
+        self._instance_name = hs.get_instance_name()
+
+        # Map from stream to list of deferreds waiting for the stream to
+        # arrive at a particular position. The lists are sorted by stream position.
+        self._streams_to_waiters = (
+            {}
+        )  # type: Dict[str, List[Tuple[int, Deferred[None]]]]
 
     async def on_rdata(
         self, stream_name: str, instance_name: str, token: int, rows: list
@@ -131,8 +149,76 @@ class ReplicationDataHandler:
 
             await self.pusher_pool.on_new_notifications(token, token)
 
+        # Notify any waiting deferreds. The list is ordered by position so we
+        # just iterate through the list until we reach a position that is
+        # greater than the received row position.
+        waiting_list = self._streams_to_waiters.get(stream_name, [])
+
+        # Index of the first item with a position after the current token,
+        # i.e. we have called all deferreds before this index. If it is not
+        # overwritten by the loop below, then either a) the list is empty, so
+        # this is a no-op, or b) all items in the list were called, so the list
+        # should be cleared. Setting it to `len(list)` works for both cases.
+        index_of_first_deferred_not_called = len(waiting_list)
+
+        for idx, (position, deferred) in enumerate(waiting_list):
+            if position <= token:
+                try:
+                    with PreserveLoggingContext():
+                        deferred.callback(None)
+                except Exception:
+                    # The deferred has been cancelled or timed out.
+                    pass
+            else:
+                # The list is sorted by position so we don't need to continue
+                # checking any further entries in the list.
+                index_of_first_deferred_not_called = idx
+                break
+
+        # Drop all entries in the waiting list that were called in the above
+        # loop. (This preserves the order, so there is no need to re-sort.)
+        waiting_list[:] = waiting_list[index_of_first_deferred_not_called:]
+
     async def on_position(self, stream_name: str, instance_name: str, token: int):
         self.store.process_replication_rows(stream_name, instance_name, token, [])
 
     def on_remote_server_up(self, server: str):
         """Called when get a new REMOTE_SERVER_UP command."""
+
+    async def wait_for_stream_position(
+        self, instance_name: str, stream_name: str, position: int
+    ):
+        """Wait until this instance has received updates up to and including
+        the given stream position.
+        """
+
+        if instance_name == self._instance_name:
+            # We don't get told about updates written by this process, and
+            # anyway in that case we don't need to wait.
+            return
+
+        current_position = self._streams[stream_name].current_token(self._instance_name)
+        if position <= current_position:
+            # We're already past the position
+            return
+
+        # Create a new deferred that times out after N seconds, as we don't want
+        # to wedge here forever.
+        deferred = Deferred()
+        deferred = timeout_deferred(
+            deferred, _WAIT_FOR_REPLICATION_TIMEOUT_SECONDS, self._reactor
+        )
+
+        waiting_list = self._streams_to_waiters.setdefault(stream_name, [])
+
+        # We insert into the list using heapq, as it is more efficient than
+        # appending and then re-sorting on every insert.
+        heapq.heappush(waiting_list, (position, deferred))
+
+        # We measure here to get in flight counts and average waiting time.
+        with Measure(self._clock, "repl.wait_for_stream_position"):
+            logger.info("Waiting for repl stream %r to reach %s", stream_name, position)
+            await make_deferred_yieldable(deferred)
+            logger.info(
+                "Finished waiting for repl stream %r to reach %s", stream_name, position
+            )
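
The waiter bookkeeping above boils down to a per-stream list kept ordered by stream position: heapq.heappush inserts in O(log n) without re-sorting, and each token that arrives over replication fires every waiter at or below it. The self-contained sketch below shows the same pattern with plain callbacks standing in for Twisted deferreds; it deliberately differs from the code above in two ways: a sequence counter is added as a tiebreaker so the heap never compares two callbacks directly, and the drain uses heappop (which always yields positions in order) instead of slicing off a prefix.

import heapq
from itertools import count
from typing import Callable, Dict, List, Tuple

class StreamWaiters:
    """Illustrative sketch of the per-stream waiter pattern (not Synapse code)."""

    def __init__(self) -> None:
        # Map from stream name to a min-heap of (position, seq, callback).
        self._waiters = {}  # type: Dict[str, List[Tuple[int, int, Callable[[], None]]]]
        self._seq = count()

    def add_waiter(self, stream_name: str, position: int, callback: Callable[[], None]) -> None:
        # heappush keeps the smallest position at the front without a
        # full re-sort on every insert.
        waiting_list = self._waiters.setdefault(stream_name, [])
        heapq.heappush(waiting_list, (position, next(self._seq), callback))

    def on_new_token(self, stream_name: str, token: int) -> None:
        waiting_list = self._waiters.get(stream_name, [])
        # Pop and fire every waiter whose position is at or below the new
        # token; anything left in the heap is waiting for a later position.
        while waiting_list and waiting_list[0][0] <= token:
            _, _, callback = heapq.heappop(waiting_list)
            callback()

# Example: only the waiter at position 5 fires when the token reaches 7.
waiters = StreamWaiters()
waiters.add_waiter("events", 5, lambda: print("reached 5"))
waiters.add_waiter("events", 9, lambda: print("reached 9"))
waiters.on_new_token("events", 7)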