summary refs log tree commit diff
path: root/synapse/replication
diff options
context:
space:
mode:
authorErik Johnston <erikj@matrix.org>2023-10-23 16:57:30 +0100
committerGitHub <noreply@github.com>2023-10-23 16:57:30 +0100
commit8f35f8148e1a7ce3ac249e2d2052854409f2c0d6 (patch)
tree72eb736c64970a416e07c123a4b350614dc1518d /synapse/replication
parentFix bug that could cause a `/sync` to tightloop with sqlite after restart (#1... (diff)
downloadsynapse-8f35f8148e1a7ce3ac249e2d2052854409f2c0d6.tar.xz
Fix bug where a new writer advances their token too quickly (#16473)
* Fix bug where a new writer advances their token too quickly

When starting a new writer (for e.g. persisting events), the
`MultiWriterIdGenerator` doesn't have a minimum token for it as there
are no rows matching that new writer in the DB.

This results in the the first stream ID it acquired being announced as
persisted *before* it actually finishes persisting, if another writer
gets and persists a subsequent stream ID. This is due to the logic of
setting the minimum persisted position to the minimum known position of
across all writers, and the new writer starts off not being considered.

* Fix sending out POSITIONs when our token advances without update

Broke in #14820

* For replication HTTP requests, only wait for minimal position
Diffstat (limited to 'synapse/replication')
-rw-r--r--synapse/replication/http/_base.py2
-rw-r--r--synapse/replication/tcp/streams/_base.py129
-rw-r--r--synapse/replication/tcp/streams/events.py8
-rw-r--r--synapse/replication/tcp/streams/federation.py15
-rw-r--r--synapse/replication/tcp/streams/partial_state.py10
5 files changed, 103 insertions, 61 deletions
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index 63cf24a14d..7476839db5 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -238,7 +238,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
 
                     data[_STREAM_POSITION_KEY] = {
                         "streams": {
-                            stream.NAME: stream.current_token(local_instance_name)
+                            stream.NAME: stream.minimal_local_current_token()
                             for stream in streams
                         },
                         "instance_name": local_instance_name,
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index c6088a0f99..5c4d228f3d 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -33,6 +33,7 @@ from synapse.replication.http.streams import ReplicationGetStreamUpdates
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
+    from synapse.storage.util.id_generators import AbstractStreamIdGenerator
 
 logger = logging.getLogger(__name__)
 
@@ -107,22 +108,10 @@ class Stream:
     def __init__(
         self,
         local_instance_name: str,
-        current_token_function: Callable[[str], Token],
         update_function: UpdateFunction,
     ):
         """Instantiate a Stream
 
-        `current_token_function` and `update_function` are callbacks which
-        should be implemented by subclasses.
-
-        `current_token_function` takes an instance name, which is a writer to
-        the stream, and returns the position in the stream of the writer (as
-        viewed from the current process). On the writer process this is where
-        the writer has successfully written up to, whereas on other processes
-        this is the position which we have received updates up to over
-        replication. (Note that most streams have a single writer and so their
-        implementations ignore the instance name passed in).
-
         `update_function` is called to get updates for this stream between a
         pair of stream tokens. See the `UpdateFunction` type definition for more
         info.
@@ -133,12 +122,28 @@ class Stream:
             update_function: callback go get stream updates, as above
         """
         self.local_instance_name = local_instance_name
-        self.current_token = current_token_function
         self.update_function = update_function
 
         # The token from which we last asked for updates
         self.last_token = self.current_token(self.local_instance_name)
 
+    def current_token(self, instance_name: str) -> Token:
+        """This takes an instance name, which is a writer to
+        the stream, and returns the position in the stream of the writer (as
+        viewed from the current process).
+        """
+        # We can't make this an abstract class as it makes mypy unhappy.
+        raise NotImplementedError()
+
+    def minimal_local_current_token(self) -> Token:
+        """Tries to return a minimal current token for the local instance,
+        i.e. for writers this would be the last successful write.
+
+        If local instance is not a writer (or has written yet) then falls back
+        to returning the normal "current token".
+        """
+        raise NotImplementedError()
+
     def discard_updates_and_advance(self) -> None:
         """Called when the stream should advance but the updates would be discarded,
         e.g. when there are no currently connected workers.
@@ -190,6 +195,25 @@ class Stream:
         return updates, upto_token, limited
 
 
+class _StreamFromIdGen(Stream):
+    """Helper class for simple streams that use a stream ID generator"""
+
+    def __init__(
+        self,
+        local_instance_name: str,
+        update_function: UpdateFunction,
+        stream_id_gen: "AbstractStreamIdGenerator",
+    ):
+        self._stream_id_gen = stream_id_gen
+        super().__init__(local_instance_name, update_function)
+
+    def current_token(self, instance_name: str) -> Token:
+        return self._stream_id_gen.get_current_token_for_writer(instance_name)
+
+    def minimal_local_current_token(self) -> Token:
+        return self._stream_id_gen.get_minimal_local_current_token()
+
+
 def current_token_without_instance(
     current_token: Callable[[], int]
 ) -> Callable[[str], int]:
@@ -242,17 +266,21 @@ class BackfillStream(Stream):
         self.store = hs.get_datastores().main
         super().__init__(
             hs.get_instance_name(),
-            self._current_token,
             self.store.get_all_new_backfill_event_rows,
         )
 
-    def _current_token(self, instance_name: str) -> int:
+    def current_token(self, instance_name: str) -> Token:
         # The backfill stream over replication operates on *positive* numbers,
         # which means we need to negate it.
         return -self.store._backfill_id_gen.get_current_token_for_writer(instance_name)
 
+    def minimal_local_current_token(self) -> Token:
+        # The backfill stream over replication operates on *positive* numbers,
+        # which means we need to negate it.
+        return -self.store._backfill_id_gen.get_minimal_local_current_token()
 
-class PresenceStream(Stream):
+
+class PresenceStream(_StreamFromIdGen):
     @attr.s(slots=True, frozen=True, auto_attribs=True)
     class PresenceStreamRow:
         user_id: str
@@ -283,9 +311,7 @@ class PresenceStream(Stream):
             update_function = make_http_update_function(hs, self.NAME)
 
         super().__init__(
-            hs.get_instance_name(),
-            current_token_without_instance(store.get_current_presence_token),
-            update_function,
+            hs.get_instance_name(), update_function, store._presence_id_gen
         )
 
 
@@ -305,13 +331,18 @@ class PresenceFederationStream(Stream):
     ROW_TYPE = PresenceFederationStreamRow
 
     def __init__(self, hs: "HomeServer"):
-        federation_queue = hs.get_presence_handler().get_federation_queue()
+        self._federation_queue = hs.get_presence_handler().get_federation_queue()
         super().__init__(
             hs.get_instance_name(),
-            federation_queue.get_current_token,
-            federation_queue.get_replication_rows,
+            self._federation_queue.get_replication_rows,
         )
 
+    def current_token(self, instance_name: str) -> Token:
+        return self._federation_queue.get_current_token(instance_name)
+
+    def minimal_local_current_token(self) -> Token:
+        return self._federation_queue.get_current_token(self.local_instance_name)
+
 
 class TypingStream(Stream):
     @attr.s(slots=True, frozen=True, auto_attribs=True)
@@ -341,20 +372,25 @@ class TypingStream(Stream):
             update_function: Callable[
                 [str, int, int, int], Awaitable[Tuple[List[Tuple[int, Any]], int, bool]]
             ] = typing_writer_handler.get_all_typing_updates
-            current_token_function = typing_writer_handler.get_current_token
+            self.current_token_function = typing_writer_handler.get_current_token
         else:
             # Query the typing writer process
             update_function = make_http_update_function(hs, self.NAME)
-            current_token_function = hs.get_typing_handler().get_current_token
+            self.current_token_function = hs.get_typing_handler().get_current_token
 
         super().__init__(
             hs.get_instance_name(),
-            current_token_without_instance(current_token_function),
             update_function,
         )
 
+    def current_token(self, instance_name: str) -> Token:
+        return self.current_token_function()
+
+    def minimal_local_current_token(self) -> Token:
+        return self.current_token_function()
 
-class ReceiptsStream(Stream):
+
+class ReceiptsStream(_StreamFromIdGen):
     @attr.s(slots=True, frozen=True, auto_attribs=True)
     class ReceiptsStreamRow:
         room_id: str
@@ -371,12 +407,12 @@ class ReceiptsStream(Stream):
         store = hs.get_datastores().main
         super().__init__(
             hs.get_instance_name(),
-            current_token_without_instance(store.get_max_receipt_stream_id),
             store.get_all_updated_receipts,
+            store._receipts_id_gen,
         )
 
 
-class PushRulesStream(Stream):
+class PushRulesStream(_StreamFromIdGen):
     """A user has changed their push rules"""
 
     @attr.s(slots=True, frozen=True, auto_attribs=True)
@@ -387,20 +423,16 @@ class PushRulesStream(Stream):
     ROW_TYPE = PushRulesStreamRow
 
     def __init__(self, hs: "HomeServer"):
-        self.store = hs.get_datastores().main
+        store = hs.get_datastores().main
 
         super().__init__(
             hs.get_instance_name(),
-            self._current_token,
-            self.store.get_all_push_rule_updates,
+            store.get_all_push_rule_updates,
+            store._push_rules_stream_id_gen,
         )
 
-    def _current_token(self, instance_name: str) -> int:
-        push_rules_token = self.store.get_max_push_rules_stream_id()
-        return push_rules_token
-
 
-class PushersStream(Stream):
+class PushersStream(_StreamFromIdGen):
     """A user has added/changed/removed a pusher"""
 
     @attr.s(slots=True, frozen=True, auto_attribs=True)
@@ -418,8 +450,8 @@ class PushersStream(Stream):
 
         super().__init__(
             hs.get_instance_name(),
-            current_token_without_instance(store.get_pushers_stream_token),
             store.get_all_updated_pushers_rows,
+            store._pushers_id_gen,
         )
 
 
@@ -447,15 +479,20 @@ class CachesStream(Stream):
     ROW_TYPE = CachesStreamRow
 
     def __init__(self, hs: "HomeServer"):
-        store = hs.get_datastores().main
+        self.store = hs.get_datastores().main
         super().__init__(
             hs.get_instance_name(),
-            store.get_cache_stream_token_for_writer,
-            store.get_all_updated_caches,
+            self.store.get_all_updated_caches,
         )
 
+    def current_token(self, instance_name: str) -> Token:
+        return self.store.get_cache_stream_token_for_writer(instance_name)
+
+    def minimal_local_current_token(self) -> Token:
+        return self.current_token(self.local_instance_name)
+
 
-class DeviceListsStream(Stream):
+class DeviceListsStream(_StreamFromIdGen):
     """Either a user has updated their devices or a remote server needs to be
     told about a device update.
     """
@@ -473,8 +510,8 @@ class DeviceListsStream(Stream):
         self.store = hs.get_datastores().main
         super().__init__(
             hs.get_instance_name(),
-            current_token_without_instance(self.store.get_device_stream_token),
             self._update_function,
+            self.store._device_list_id_gen,
         )
 
     async def _update_function(
@@ -525,7 +562,7 @@ class DeviceListsStream(Stream):
         return updates, upper_limit_token, devices_limited or signatures_limited
 
 
-class ToDeviceStream(Stream):
+class ToDeviceStream(_StreamFromIdGen):
     """New to_device messages for a client"""
 
     @attr.s(slots=True, frozen=True, auto_attribs=True)
@@ -539,12 +576,12 @@ class ToDeviceStream(Stream):
         store = hs.get_datastores().main
         super().__init__(
             hs.get_instance_name(),
-            current_token_without_instance(store.get_to_device_stream_token),
             store.get_all_new_device_messages,
+            store._device_inbox_id_gen,
         )
 
 
-class AccountDataStream(Stream):
+class AccountDataStream(_StreamFromIdGen):
     """Global or per room account data was changed"""
 
     @attr.s(slots=True, frozen=True, auto_attribs=True)
@@ -560,8 +597,8 @@ class AccountDataStream(Stream):
         self.store = hs.get_datastores().main
         super().__init__(
             hs.get_instance_name(),
-            current_token_without_instance(self.store.get_max_account_data_stream_id),
             self._update_function,
+            self.store._account_data_id_gen,
         )
 
     async def _update_function(
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index da6d948e1b..38823113d8 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -19,10 +19,10 @@ from typing import TYPE_CHECKING, Iterable, Optional, Tuple, Type, TypeVar, cast
 import attr
 
 from synapse.replication.tcp.streams._base import (
-    Stream,
     StreamRow,
     StreamUpdateResult,
     Token,
+    _StreamFromIdGen,
 )
 
 if TYPE_CHECKING:
@@ -139,7 +139,7 @@ _EventRows: Tuple[Type[BaseEventsStreamRow], ...] = (
 TypeToRow = {Row.TypeId: Row for Row in _EventRows}
 
 
-class EventsStream(Stream):
+class EventsStream(_StreamFromIdGen):
     """We received a new event, or an event went from being an outlier to not"""
 
     NAME = "events"
@@ -147,9 +147,7 @@ class EventsStream(Stream):
     def __init__(self, hs: "HomeServer"):
         self._store = hs.get_datastores().main
         super().__init__(
-            hs.get_instance_name(),
-            self._store._stream_id_gen.get_current_token_for_writer,
-            self._update_function,
+            hs.get_instance_name(), self._update_function, self._store._stream_id_gen
         )
 
     async def _update_function(
diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py
index 4046bdec69..7f5af5852c 100644
--- a/synapse/replication/tcp/streams/federation.py
+++ b/synapse/replication/tcp/streams/federation.py
@@ -18,6 +18,7 @@ import attr
 
 from synapse.replication.tcp.streams._base import (
     Stream,
+    Token,
     current_token_without_instance,
     make_http_update_function,
 )
@@ -47,7 +48,7 @@ class FederationStream(Stream):
             # will be a real FederationSender, which has stubs for current_token and
             # get_replication_rows.)
             federation_sender = hs.get_federation_sender()
-            current_token = current_token_without_instance(
+            self.current_token_func = current_token_without_instance(
                 federation_sender.get_current_token
             )
             update_function: Callable[
@@ -57,15 +58,21 @@ class FederationStream(Stream):
         elif hs.should_send_federation():
             # federation sender: Query master process
             update_function = make_http_update_function(hs, self.NAME)
-            current_token = self._stub_current_token
+            self.current_token_func = self._stub_current_token
 
         else:
             # other worker: stub out the update function (we're not interested in
             # any updates so when we get a POSITION we do nothing)
             update_function = self._stub_update_function
-            current_token = self._stub_current_token
+            self.current_token_func = self._stub_current_token
 
-        super().__init__(hs.get_instance_name(), current_token, update_function)
+        super().__init__(hs.get_instance_name(), update_function)
+
+    def current_token(self, instance_name: str) -> Token:
+        return self.current_token_func(instance_name)
+
+    def minimal_local_current_token(self) -> Token:
+        return self.current_token(self.local_instance_name)
 
     @staticmethod
     def _stub_current_token(instance_name: str) -> int:
diff --git a/synapse/replication/tcp/streams/partial_state.py b/synapse/replication/tcp/streams/partial_state.py
index a8ce5ffd72..ad181d7e93 100644
--- a/synapse/replication/tcp/streams/partial_state.py
+++ b/synapse/replication/tcp/streams/partial_state.py
@@ -15,7 +15,7 @@ from typing import TYPE_CHECKING
 
 import attr
 
-from synapse.replication.tcp.streams import Stream
+from synapse.replication.tcp.streams._base import _StreamFromIdGen
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -27,7 +27,7 @@ class UnPartialStatedRoomStreamRow:
     room_id: str
 
 
-class UnPartialStatedRoomStream(Stream):
+class UnPartialStatedRoomStream(_StreamFromIdGen):
     """
     Stream to notify about rooms becoming un-partial-stated;
     that is, when the background sync finishes such that we now have full state for
@@ -41,8 +41,8 @@ class UnPartialStatedRoomStream(Stream):
         store = hs.get_datastores().main
         super().__init__(
             hs.get_instance_name(),
-            store.get_un_partial_stated_rooms_token,
             store.get_un_partial_stated_rooms_from_stream,
+            store._un_partial_stated_rooms_stream_id_gen,
         )
 
 
@@ -56,7 +56,7 @@ class UnPartialStatedEventStreamRow:
     rejection_status_changed: bool
 
 
-class UnPartialStatedEventStream(Stream):
+class UnPartialStatedEventStream(_StreamFromIdGen):
     """
     Stream to notify about events becoming un-partial-stated.
     """
@@ -68,6 +68,6 @@ class UnPartialStatedEventStream(Stream):
         store = hs.get_datastores().main
         super().__init__(
             hs.get_instance_name(),
-            store.get_un_partial_stated_events_token,
             store.get_un_partial_stated_events_from_stream,
+            store._un_partial_stated_events_stream_id_gen,
         )