summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/replication/http/_base.py2
-rw-r--r--synapse/replication/tcp/streams/_base.py129
-rw-r--r--synapse/replication/tcp/streams/events.py8
-rw-r--r--synapse/replication/tcp/streams/federation.py15
-rw-r--r--synapse/replication/tcp/streams/partial_state.py10
-rw-r--r--synapse/storage/util/id_generators.py68
6 files changed, 170 insertions, 62 deletions
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index 63cf24a14d..7476839db5 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -238,7 +238,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
 
                     data[_STREAM_POSITION_KEY] = {
                         "streams": {
-                            stream.NAME: stream.current_token(local_instance_name)
+                            stream.NAME: stream.minimal_local_current_token()
                             for stream in streams
                         },
                         "instance_name": local_instance_name,
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index c6088a0f99..5c4d228f3d 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -33,6 +33,7 @@ from synapse.replication.http.streams import ReplicationGetStreamUpdates
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
+    from synapse.storage.util.id_generators import AbstractStreamIdGenerator
 
 logger = logging.getLogger(__name__)
 
@@ -107,22 +108,10 @@ class Stream:
     def __init__(
         self,
         local_instance_name: str,
-        current_token_function: Callable[[str], Token],
         update_function: UpdateFunction,
     ):
         """Instantiate a Stream
 
-        `current_token_function` and `update_function` are callbacks which
-        should be implemented by subclasses.
-
-        `current_token_function` takes an instance name, which is a writer to
-        the stream, and returns the position in the stream of the writer (as
-        viewed from the current process). On the writer process this is where
-        the writer has successfully written up to, whereas on other processes
-        this is the position which we have received updates up to over
-        replication. (Note that most streams have a single writer and so their
-        implementations ignore the instance name passed in).
-
         `update_function` is called to get updates for this stream between a
         pair of stream tokens. See the `UpdateFunction` type definition for more
         info.
@@ -133,12 +122,28 @@ class Stream:
             update_function: callback go get stream updates, as above
         """
         self.local_instance_name = local_instance_name
-        self.current_token = current_token_function
         self.update_function = update_function
 
         # The token from which we last asked for updates
         self.last_token = self.current_token(self.local_instance_name)
 
+    def current_token(self, instance_name: str) -> Token:
+        """This takes an instance name, which is a writer to
+        the stream, and returns the position in the stream of the writer (as
+        viewed from the current process).
+        """
+        # We can't make this an abstract class as it makes mypy unhappy.
+        raise NotImplementedError()
+
+    def minimal_local_current_token(self) -> Token:
+        """Tries to return a minimal current token for the local instance,
+        i.e. for writers this would be the last successful write.
+
+        If local instance is not a writer (or has written yet) then falls back
+        to returning the normal "current token".
+        """
+        raise NotImplementedError()
+
     def discard_updates_and_advance(self) -> None:
         """Called when the stream should advance but the updates would be discarded,
         e.g. when there are no currently connected workers.
@@ -190,6 +195,25 @@ class Stream:
         return updates, upto_token, limited
 
 
+class _StreamFromIdGen(Stream):
+    """Helper class for simple streams that use a stream ID generator"""
+
+    def __init__(
+        self,
+        local_instance_name: str,
+        update_function: UpdateFunction,
+        stream_id_gen: "AbstractStreamIdGenerator",
+    ):
+        self._stream_id_gen = stream_id_gen
+        super().__init__(local_instance_name, update_function)
+
+    def current_token(self, instance_name: str) -> Token:
+        return self._stream_id_gen.get_current_token_for_writer(instance_name)
+
+    def minimal_local_current_token(self) -> Token:
+        return self._stream_id_gen.get_minimal_local_current_token()
+
+
 def current_token_without_instance(
     current_token: Callable[[], int]
 ) -> Callable[[str], int]:
@@ -242,17 +266,21 @@ class BackfillStream(Stream):
         self.store = hs.get_datastores().main
         super().__init__(
             hs.get_instance_name(),
-            self._current_token,
             self.store.get_all_new_backfill_event_rows,
         )
 
-    def _current_token(self, instance_name: str) -> int:
+    def current_token(self, instance_name: str) -> Token:
         # The backfill stream over replication operates on *positive* numbers,
         # which means we need to negate it.
         return -self.store._backfill_id_gen.get_current_token_for_writer(instance_name)
 
+    def minimal_local_current_token(self) -> Token:
+        # The backfill stream over replication operates on *positive* numbers,
+        # which means we need to negate it.
+        return -self.store._backfill_id_gen.get_minimal_local_current_token()
 
-class PresenceStream(Stream):
+
+class PresenceStream(_StreamFromIdGen):
     @attr.s(slots=True, frozen=True, auto_attribs=True)
     class PresenceStreamRow:
         user_id: str
@@ -283,9 +311,7 @@ class PresenceStream(Stream):
             update_function = make_http_update_function(hs, self.NAME)
 
         super().__init__(
-            hs.get_instance_name(),
-            current_token_without_instance(store.get_current_presence_token),
-            update_function,
+            hs.get_instance_name(), update_function, store._presence_id_gen
         )
 
 
@@ -305,13 +331,18 @@ class PresenceFederationStream(Stream):
     ROW_TYPE = PresenceFederationStreamRow
 
     def __init__(self, hs: "HomeServer"):
-        federation_queue = hs.get_presence_handler().get_federation_queue()
+        self._federation_queue = hs.get_presence_handler().get_federation_queue()
         super().__init__(
             hs.get_instance_name(),
-            federation_queue.get_current_token,
-            federation_queue.get_replication_rows,
+            self._federation_queue.get_replication_rows,
         )
 
+    def current_token(self, instance_name: str) -> Token:
+        return self._federation_queue.get_current_token(instance_name)
+
+    def minimal_local_current_token(self) -> Token:
+        return self._federation_queue.get_current_token(self.local_instance_name)
+
 
 class TypingStream(Stream):
     @attr.s(slots=True, frozen=True, auto_attribs=True)
@@ -341,20 +372,25 @@ class TypingStream(Stream):
             update_function: Callable[
                 [str, int, int, int], Awaitable[Tuple[List[Tuple[int, Any]], int, bool]]
             ] = typing_writer_handler.get_all_typing_updates
-            current_token_function = typing_writer_handler.get_current_token
+            self.current_token_function = typing_writer_handler.get_current_token
         else:
             # Query the typing writer process
             update_function = make_http_update_function(hs, self.NAME)
-            current_token_function = hs.get_typing_handler().get_current_token
+            self.current_token_function = hs.get_typing_handler().get_current_token
 
         super().__init__(
             hs.get_instance_name(),
-            current_token_without_instance(current_token_function),
             update_function,
         )
 
+    def current_token(self, instance_name: str) -> Token:
+        return self.current_token_function()
+
+    def minimal_local_current_token(self) -> Token:
+        return self.current_token_function()
 
-class ReceiptsStream(Stream):
+
+class ReceiptsStream(_StreamFromIdGen):
     @attr.s(slots=True, frozen=True, auto_attribs=True)
     class ReceiptsStreamRow:
         room_id: str
@@ -371,12 +407,12 @@ class ReceiptsStream(Stream):
         store = hs.get_datastores().main
         super().__init__(
             hs.get_instance_name(),
-            current_token_without_instance(store.get_max_receipt_stream_id),
             store.get_all_updated_receipts,
+            store._receipts_id_gen,
         )
 
 
-class PushRulesStream(Stream):
+class PushRulesStream(_StreamFromIdGen):
     """A user has changed their push rules"""
 
     @attr.s(slots=True, frozen=True, auto_attribs=True)
@@ -387,20 +423,16 @@ class PushRulesStream(Stream):
     ROW_TYPE = PushRulesStreamRow
 
     def __init__(self, hs: "HomeServer"):
-        self.store = hs.get_datastores().main
+        store = hs.get_datastores().main
 
         super().__init__(
             hs.get_instance_name(),
-            self._current_token,
-            self.store.get_all_push_rule_updates,
+            store.get_all_push_rule_updates,
+            store._push_rules_stream_id_gen,
         )
 
-    def _current_token(self, instance_name: str) -> int:
-        push_rules_token = self.store.get_max_push_rules_stream_id()
-        return push_rules_token
-
 
-class PushersStream(Stream):
+class PushersStream(_StreamFromIdGen):
     """A user has added/changed/removed a pusher"""
 
     @attr.s(slots=True, frozen=True, auto_attribs=True)
@@ -418,8 +450,8 @@ class PushersStream(Stream):
 
         super().__init__(
             hs.get_instance_name(),
-            current_token_without_instance(store.get_pushers_stream_token),
             store.get_all_updated_pushers_rows,
+            store._pushers_id_gen,
         )
 
 
@@ -447,15 +479,20 @@ class CachesStream(Stream):
     ROW_TYPE = CachesStreamRow
 
     def __init__(self, hs: "HomeServer"):
-        store = hs.get_datastores().main
+        self.store = hs.get_datastores().main
         super().__init__(
             hs.get_instance_name(),
-            store.get_cache_stream_token_for_writer,
-            store.get_all_updated_caches,
+            self.store.get_all_updated_caches,
         )
 
+    def current_token(self, instance_name: str) -> Token:
+        return self.store.get_cache_stream_token_for_writer(instance_name)
+
+    def minimal_local_current_token(self) -> Token:
+        return self.current_token(self.local_instance_name)
+
 
-class DeviceListsStream(Stream):
+class DeviceListsStream(_StreamFromIdGen):
     """Either a user has updated their devices or a remote server needs to be
     told about a device update.
     """
@@ -473,8 +510,8 @@ class DeviceListsStream(Stream):
         self.store = hs.get_datastores().main
         super().__init__(
             hs.get_instance_name(),
-            current_token_without_instance(self.store.get_device_stream_token),
             self._update_function,
+            self.store._device_list_id_gen,
         )
 
     async def _update_function(
@@ -525,7 +562,7 @@ class DeviceListsStream(Stream):
         return updates, upper_limit_token, devices_limited or signatures_limited
 
 
-class ToDeviceStream(Stream):
+class ToDeviceStream(_StreamFromIdGen):
     """New to_device messages for a client"""
 
     @attr.s(slots=True, frozen=True, auto_attribs=True)
@@ -539,12 +576,12 @@ class ToDeviceStream(Stream):
         store = hs.get_datastores().main
         super().__init__(
             hs.get_instance_name(),
-            current_token_without_instance(store.get_to_device_stream_token),
             store.get_all_new_device_messages,
+            store._device_inbox_id_gen,
         )
 
 
-class AccountDataStream(Stream):
+class AccountDataStream(_StreamFromIdGen):
     """Global or per room account data was changed"""
 
     @attr.s(slots=True, frozen=True, auto_attribs=True)
@@ -560,8 +597,8 @@ class AccountDataStream(Stream):
         self.store = hs.get_datastores().main
         super().__init__(
             hs.get_instance_name(),
-            current_token_without_instance(self.store.get_max_account_data_stream_id),
             self._update_function,
+            self.store._account_data_id_gen,
         )
 
     async def _update_function(
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index da6d948e1b..38823113d8 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -19,10 +19,10 @@ from typing import TYPE_CHECKING, Iterable, Optional, Tuple, Type, TypeVar, cast
 import attr
 
 from synapse.replication.tcp.streams._base import (
-    Stream,
     StreamRow,
     StreamUpdateResult,
     Token,
+    _StreamFromIdGen,
 )
 
 if TYPE_CHECKING:
@@ -139,7 +139,7 @@ _EventRows: Tuple[Type[BaseEventsStreamRow], ...] = (
 TypeToRow = {Row.TypeId: Row for Row in _EventRows}
 
 
-class EventsStream(Stream):
+class EventsStream(_StreamFromIdGen):
     """We received a new event, or an event went from being an outlier to not"""
 
     NAME = "events"
@@ -147,9 +147,7 @@ class EventsStream(Stream):
     def __init__(self, hs: "HomeServer"):
         self._store = hs.get_datastores().main
         super().__init__(
-            hs.get_instance_name(),
-            self._store._stream_id_gen.get_current_token_for_writer,
-            self._update_function,
+            hs.get_instance_name(), self._update_function, self._store._stream_id_gen
         )
 
     async def _update_function(
diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py
index 4046bdec69..7f5af5852c 100644
--- a/synapse/replication/tcp/streams/federation.py
+++ b/synapse/replication/tcp/streams/federation.py
@@ -18,6 +18,7 @@ import attr
 
 from synapse.replication.tcp.streams._base import (
     Stream,
+    Token,
     current_token_without_instance,
     make_http_update_function,
 )
@@ -47,7 +48,7 @@ class FederationStream(Stream):
             # will be a real FederationSender, which has stubs for current_token and
             # get_replication_rows.)
             federation_sender = hs.get_federation_sender()
-            current_token = current_token_without_instance(
+            self.current_token_func = current_token_without_instance(
                 federation_sender.get_current_token
             )
             update_function: Callable[
@@ -57,15 +58,21 @@ class FederationStream(Stream):
         elif hs.should_send_federation():
             # federation sender: Query master process
             update_function = make_http_update_function(hs, self.NAME)
-            current_token = self._stub_current_token
+            self.current_token_func = self._stub_current_token
 
         else:
             # other worker: stub out the update function (we're not interested in
             # any updates so when we get a POSITION we do nothing)
             update_function = self._stub_update_function
-            current_token = self._stub_current_token
+            self.current_token_func = self._stub_current_token
 
-        super().__init__(hs.get_instance_name(), current_token, update_function)
+        super().__init__(hs.get_instance_name(), update_function)
+
+    def current_token(self, instance_name: str) -> Token:
+        return self.current_token_func(instance_name)
+
+    def minimal_local_current_token(self) -> Token:
+        return self.current_token(self.local_instance_name)
 
     @staticmethod
     def _stub_current_token(instance_name: str) -> int:
diff --git a/synapse/replication/tcp/streams/partial_state.py b/synapse/replication/tcp/streams/partial_state.py
index a8ce5ffd72..ad181d7e93 100644
--- a/synapse/replication/tcp/streams/partial_state.py
+++ b/synapse/replication/tcp/streams/partial_state.py
@@ -15,7 +15,7 @@ from typing import TYPE_CHECKING
 
 import attr
 
-from synapse.replication.tcp.streams import Stream
+from synapse.replication.tcp.streams._base import _StreamFromIdGen
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -27,7 +27,7 @@ class UnPartialStatedRoomStreamRow:
     room_id: str
 
 
-class UnPartialStatedRoomStream(Stream):
+class UnPartialStatedRoomStream(_StreamFromIdGen):
     """
     Stream to notify about rooms becoming un-partial-stated;
     that is, when the background sync finishes such that we now have full state for
@@ -41,8 +41,8 @@ class UnPartialStatedRoomStream(Stream):
         store = hs.get_datastores().main
         super().__init__(
             hs.get_instance_name(),
-            store.get_un_partial_stated_rooms_token,
             store.get_un_partial_stated_rooms_from_stream,
+            store._un_partial_stated_rooms_stream_id_gen,
         )
 
 
@@ -56,7 +56,7 @@ class UnPartialStatedEventStreamRow:
     rejection_status_changed: bool
 
 
-class UnPartialStatedEventStream(Stream):
+class UnPartialStatedEventStream(_StreamFromIdGen):
     """
     Stream to notify about events becoming un-partial-stated.
     """
@@ -68,6 +68,6 @@ class UnPartialStatedEventStream(Stream):
         store = hs.get_datastores().main
         super().__init__(
             hs.get_instance_name(),
-            store.get_un_partial_stated_events_token,
             store.get_un_partial_stated_events_from_stream,
+            store._un_partial_stated_events_stream_id_gen,
         )
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index d2c874b9a8..9c3eafb562 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -134,6 +134,15 @@ class AbstractStreamIdGenerator(metaclass=abc.ABCMeta):
         raise NotImplementedError()
 
     @abc.abstractmethod
+    def get_minimal_local_current_token(self) -> int:
+        """Tries to return a minimal current token for the local instance,
+        i.e. for writers this would be the last successful write.
+
+        If local instance is not a writer (or has written yet) then falls back
+        to returning the normal "current token".
+        """
+
+    @abc.abstractmethod
     def get_next(self) -> AsyncContextManager[int]:
         """
         Usage:
@@ -312,6 +321,9 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
     def get_current_token_for_writer(self, instance_name: str) -> int:
         return self.get_current_token()
 
+    def get_minimal_local_current_token(self) -> int:
+        return self.get_current_token()
+
 
 class MultiWriterIdGenerator(AbstractStreamIdGenerator):
     """Generates and tracks stream IDs for a stream with multiple writers.
@@ -408,6 +420,11 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
         # The maximum stream ID that we have seen been allocated across any writer.
         self._max_seen_allocated_stream_id = 1
 
+        # The maximum position of the local instance. This can be higher than
+        # the corresponding position in `current_positions` table when there are
+        # no active writes in progress.
+        self._max_position_of_local_instance = self._max_seen_allocated_stream_id
+
         self._sequence_gen = PostgresSequenceGenerator(sequence_name)
 
         # We check that the table and sequence haven't diverged.
@@ -427,6 +444,16 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
             self._current_positions.values(), default=1
         )
 
+        # For the case where `stream_positions` is not up to date,
+        # `_persisted_upto_position` may be higher.
+        self._max_seen_allocated_stream_id = max(
+            self._max_seen_allocated_stream_id, self._persisted_upto_position
+        )
+
+        # Bump our local maximum position now that we've loaded things from the
+        # DB.
+        self._max_position_of_local_instance = self._max_seen_allocated_stream_id
+
         if not writers:
             # If there have been no explicit writers given then any instance can
             # write to the stream. In which case, let's pre-seed our own
@@ -545,6 +572,14 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
                     if instance == self._instance_name:
                         self._current_positions[instance] = stream_id
 
+        if self._writers:
+            # If we have explicit writers then make sure that each instance has
+            # a position.
+            for writer in self._writers:
+                self._current_positions.setdefault(
+                    writer, self._persisted_upto_position
+                )
+
         cur.close()
 
     def _load_next_id_txn(self, txn: Cursor) -> int:
@@ -688,6 +723,9 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
             if new_cur:
                 curr = self._current_positions.get(self._instance_name, 0)
                 self._current_positions[self._instance_name] = max(curr, new_cur)
+                self._max_position_of_local_instance = max(
+                    curr, new_cur, self._max_position_of_local_instance
+                )
 
             self._add_persisted_position(next_id)
 
@@ -702,10 +740,26 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
         # persisted up to position. This stops Synapse from doing a full table
         # scan when a new writer announces itself over replication.
         with self._lock:
-            return self._return_factor * self._current_positions.get(
+            if self._instance_name == instance_name:
+                return self._return_factor * self._max_position_of_local_instance
+
+            pos = self._current_positions.get(
                 instance_name, self._persisted_upto_position
             )
 
+            # We want to return the maximum "current token" that we can for a
+            # writer, this helps ensure that streams progress as fast as
+            # possible.
+            pos = max(pos, self._persisted_upto_position)
+
+            return self._return_factor * pos
+
+    def get_minimal_local_current_token(self) -> int:
+        with self._lock:
+            return self._return_factor * self._current_positions.get(
+                self._instance_name, self._persisted_upto_position
+            )
+
     def get_positions(self) -> Dict[str, int]:
         """Get a copy of the current positon map.
 
@@ -774,6 +828,18 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
 
         self._persisted_upto_position = max(min_curr, self._persisted_upto_position)
 
+        # Advance our local max position.
+        self._max_position_of_local_instance = max(
+            self._max_position_of_local_instance, self._persisted_upto_position
+        )
+
+        if not self._unfinished_ids and not self._in_flight_fetches:
+            # If we don't have anything in flight, it's safe to advance to the
+            # max seen stream ID.
+            self._max_position_of_local_instance = max(
+                self._max_seen_allocated_stream_id, self._max_position_of_local_instance
+            )
+
         # We now iterate through the seen positions, discarding those that are
         # less than the current min positions, and incrementing the min position
         # if its exactly one greater.