diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index d5337fe588..384355698d 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -279,14 +279,6 @@ class ReplicationDataHandler:
# may be streaming.
self.notifier.notify_replication()
- def on_remote_server_up(self, server: str) -> None:
- """Called when get a new REMOTE_SERVER_UP command."""
-
- # Let's wake up the transaction queue for the server in case we have
- # pending stuff to send to it.
- if self.send_handler:
- self.send_handler.wake_destination(server)
-
async def wait_for_stream_position(
self,
instance_name: str,
@@ -405,9 +397,6 @@ class FederationSenderHandler:
self._fed_position_linearizer = Linearizer(name="_fed_position_linearizer")
- def wake_destination(self, server: str) -> None:
- self.federation_sender.wake_destination(server)
-
async def process_replication_rows(
self, stream_name: str, token: int, rows: list
) -> None:
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index b668bb5da1..1d586fb180 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -657,8 +657,6 @@ class ReplicationCommandHandler:
self, conn: IReplicationConnection, cmd: RemoteServerUpCommand
) -> None:
"""Called when get a new REMOTE_SERVER_UP command."""
- self._replication_data_handler.on_remote_server_up(cmd.data)
-
self._notifier.notify_remote_server_up(cmd.data)
def on_LOCK_RELEASED(
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index c6088a0f99..5c4d228f3d 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -33,6 +33,7 @@ from synapse.replication.http.streams import ReplicationGetStreamUpdates
if TYPE_CHECKING:
from synapse.server import HomeServer
+ from synapse.storage.util.id_generators import AbstractStreamIdGenerator
logger = logging.getLogger(__name__)
@@ -107,22 +108,10 @@ class Stream:
def __init__(
self,
local_instance_name: str,
- current_token_function: Callable[[str], Token],
update_function: UpdateFunction,
):
"""Instantiate a Stream
- `current_token_function` and `update_function` are callbacks which
- should be implemented by subclasses.
-
- `current_token_function` takes an instance name, which is a writer to
- the stream, and returns the position in the stream of the writer (as
- viewed from the current process). On the writer process this is where
- the writer has successfully written up to, whereas on other processes
- this is the position which we have received updates up to over
- replication. (Note that most streams have a single writer and so their
- implementations ignore the instance name passed in).
-
`update_function` is called to get updates for this stream between a
pair of stream tokens. See the `UpdateFunction` type definition for more
info.
@@ -133,12 +122,28 @@ class Stream:
update_function: callback go get stream updates, as above
"""
self.local_instance_name = local_instance_name
- self.current_token = current_token_function
self.update_function = update_function
# The token from which we last asked for updates
self.last_token = self.current_token(self.local_instance_name)
+ def current_token(self, instance_name: str) -> Token:
+ """This takes an instance name, which is a writer to
+ the stream, and returns the position in the stream of the writer (as
+ viewed from the current process).
+ """
+ # We can't make this an abstract class as it makes mypy unhappy.
+ raise NotImplementedError()
+
+ def minimal_local_current_token(self) -> Token:
+ """Tries to return a minimal current token for the local instance,
+ i.e. for writers this would be the last successful write.
+
+ If local instance is not a writer (or has written yet) then falls back
+ to returning the normal "current token".
+ """
+ raise NotImplementedError()
+
def discard_updates_and_advance(self) -> None:
"""Called when the stream should advance but the updates would be discarded,
e.g. when there are no currently connected workers.
@@ -190,6 +195,25 @@ class Stream:
return updates, upto_token, limited
+class _StreamFromIdGen(Stream):
+ """Helper class for simple streams that use a stream ID generator"""
+
+ def __init__(
+ self,
+ local_instance_name: str,
+ update_function: UpdateFunction,
+ stream_id_gen: "AbstractStreamIdGenerator",
+ ):
+ self._stream_id_gen = stream_id_gen
+ super().__init__(local_instance_name, update_function)
+
+ def current_token(self, instance_name: str) -> Token:
+ return self._stream_id_gen.get_current_token_for_writer(instance_name)
+
+ def minimal_local_current_token(self) -> Token:
+ return self._stream_id_gen.get_minimal_local_current_token()
+
+
def current_token_without_instance(
current_token: Callable[[], int]
) -> Callable[[str], int]:
@@ -242,17 +266,21 @@ class BackfillStream(Stream):
self.store = hs.get_datastores().main
super().__init__(
hs.get_instance_name(),
- self._current_token,
self.store.get_all_new_backfill_event_rows,
)
- def _current_token(self, instance_name: str) -> int:
+ def current_token(self, instance_name: str) -> Token:
# The backfill stream over replication operates on *positive* numbers,
# which means we need to negate it.
return -self.store._backfill_id_gen.get_current_token_for_writer(instance_name)
+ def minimal_local_current_token(self) -> Token:
+ # The backfill stream over replication operates on *positive* numbers,
+ # which means we need to negate it.
+ return -self.store._backfill_id_gen.get_minimal_local_current_token()
-class PresenceStream(Stream):
+
+class PresenceStream(_StreamFromIdGen):
@attr.s(slots=True, frozen=True, auto_attribs=True)
class PresenceStreamRow:
user_id: str
@@ -283,9 +311,7 @@ class PresenceStream(Stream):
update_function = make_http_update_function(hs, self.NAME)
super().__init__(
- hs.get_instance_name(),
- current_token_without_instance(store.get_current_presence_token),
- update_function,
+ hs.get_instance_name(), update_function, store._presence_id_gen
)
@@ -305,13 +331,18 @@ class PresenceFederationStream(Stream):
ROW_TYPE = PresenceFederationStreamRow
def __init__(self, hs: "HomeServer"):
- federation_queue = hs.get_presence_handler().get_federation_queue()
+ self._federation_queue = hs.get_presence_handler().get_federation_queue()
super().__init__(
hs.get_instance_name(),
- federation_queue.get_current_token,
- federation_queue.get_replication_rows,
+ self._federation_queue.get_replication_rows,
)
+ def current_token(self, instance_name: str) -> Token:
+ return self._federation_queue.get_current_token(instance_name)
+
+ def minimal_local_current_token(self) -> Token:
+ return self._federation_queue.get_current_token(self.local_instance_name)
+
class TypingStream(Stream):
@attr.s(slots=True, frozen=True, auto_attribs=True)
@@ -341,20 +372,25 @@ class TypingStream(Stream):
update_function: Callable[
[str, int, int, int], Awaitable[Tuple[List[Tuple[int, Any]], int, bool]]
] = typing_writer_handler.get_all_typing_updates
- current_token_function = typing_writer_handler.get_current_token
+ self.current_token_function = typing_writer_handler.get_current_token
else:
# Query the typing writer process
update_function = make_http_update_function(hs, self.NAME)
- current_token_function = hs.get_typing_handler().get_current_token
+ self.current_token_function = hs.get_typing_handler().get_current_token
super().__init__(
hs.get_instance_name(),
- current_token_without_instance(current_token_function),
update_function,
)
+ def current_token(self, instance_name: str) -> Token:
+ return self.current_token_function()
+
+ def minimal_local_current_token(self) -> Token:
+ return self.current_token_function()
-class ReceiptsStream(Stream):
+
+class ReceiptsStream(_StreamFromIdGen):
@attr.s(slots=True, frozen=True, auto_attribs=True)
class ReceiptsStreamRow:
room_id: str
@@ -371,12 +407,12 @@ class ReceiptsStream(Stream):
store = hs.get_datastores().main
super().__init__(
hs.get_instance_name(),
- current_token_without_instance(store.get_max_receipt_stream_id),
store.get_all_updated_receipts,
+ store._receipts_id_gen,
)
-class PushRulesStream(Stream):
+class PushRulesStream(_StreamFromIdGen):
"""A user has changed their push rules"""
@attr.s(slots=True, frozen=True, auto_attribs=True)
@@ -387,20 +423,16 @@ class PushRulesStream(Stream):
ROW_TYPE = PushRulesStreamRow
def __init__(self, hs: "HomeServer"):
- self.store = hs.get_datastores().main
+ store = hs.get_datastores().main
super().__init__(
hs.get_instance_name(),
- self._current_token,
- self.store.get_all_push_rule_updates,
+ store.get_all_push_rule_updates,
+ store._push_rules_stream_id_gen,
)
- def _current_token(self, instance_name: str) -> int:
- push_rules_token = self.store.get_max_push_rules_stream_id()
- return push_rules_token
-
-class PushersStream(Stream):
+class PushersStream(_StreamFromIdGen):
"""A user has added/changed/removed a pusher"""
@attr.s(slots=True, frozen=True, auto_attribs=True)
@@ -418,8 +450,8 @@ class PushersStream(Stream):
super().__init__(
hs.get_instance_name(),
- current_token_without_instance(store.get_pushers_stream_token),
store.get_all_updated_pushers_rows,
+ store._pushers_id_gen,
)
@@ -447,15 +479,20 @@ class CachesStream(Stream):
ROW_TYPE = CachesStreamRow
def __init__(self, hs: "HomeServer"):
- store = hs.get_datastores().main
+ self.store = hs.get_datastores().main
super().__init__(
hs.get_instance_name(),
- store.get_cache_stream_token_for_writer,
- store.get_all_updated_caches,
+ self.store.get_all_updated_caches,
)
+ def current_token(self, instance_name: str) -> Token:
+ return self.store.get_cache_stream_token_for_writer(instance_name)
+
+ def minimal_local_current_token(self) -> Token:
+ return self.current_token(self.local_instance_name)
+
-class DeviceListsStream(Stream):
+class DeviceListsStream(_StreamFromIdGen):
"""Either a user has updated their devices or a remote server needs to be
told about a device update.
"""
@@ -473,8 +510,8 @@ class DeviceListsStream(Stream):
self.store = hs.get_datastores().main
super().__init__(
hs.get_instance_name(),
- current_token_without_instance(self.store.get_device_stream_token),
self._update_function,
+ self.store._device_list_id_gen,
)
async def _update_function(
@@ -525,7 +562,7 @@ class DeviceListsStream(Stream):
return updates, upper_limit_token, devices_limited or signatures_limited
-class ToDeviceStream(Stream):
+class ToDeviceStream(_StreamFromIdGen):
"""New to_device messages for a client"""
@attr.s(slots=True, frozen=True, auto_attribs=True)
@@ -539,12 +576,12 @@ class ToDeviceStream(Stream):
store = hs.get_datastores().main
super().__init__(
hs.get_instance_name(),
- current_token_without_instance(store.get_to_device_stream_token),
store.get_all_new_device_messages,
+ store._device_inbox_id_gen,
)
-class AccountDataStream(Stream):
+class AccountDataStream(_StreamFromIdGen):
"""Global or per room account data was changed"""
@attr.s(slots=True, frozen=True, auto_attribs=True)
@@ -560,8 +597,8 @@ class AccountDataStream(Stream):
self.store = hs.get_datastores().main
super().__init__(
hs.get_instance_name(),
- current_token_without_instance(self.store.get_max_account_data_stream_id),
self._update_function,
+ self.store._account_data_id_gen,
)
async def _update_function(
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index da6d948e1b..38823113d8 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -19,10 +19,10 @@ from typing import TYPE_CHECKING, Iterable, Optional, Tuple, Type, TypeVar, cast
import attr
from synapse.replication.tcp.streams._base import (
- Stream,
StreamRow,
StreamUpdateResult,
Token,
+ _StreamFromIdGen,
)
if TYPE_CHECKING:
@@ -139,7 +139,7 @@ _EventRows: Tuple[Type[BaseEventsStreamRow], ...] = (
TypeToRow = {Row.TypeId: Row for Row in _EventRows}
-class EventsStream(Stream):
+class EventsStream(_StreamFromIdGen):
"""We received a new event, or an event went from being an outlier to not"""
NAME = "events"
@@ -147,9 +147,7 @@ class EventsStream(Stream):
def __init__(self, hs: "HomeServer"):
self._store = hs.get_datastores().main
super().__init__(
- hs.get_instance_name(),
- self._store._stream_id_gen.get_current_token_for_writer,
- self._update_function,
+ hs.get_instance_name(), self._update_function, self._store._stream_id_gen
)
async def _update_function(
diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py
index 4046bdec69..7f5af5852c 100644
--- a/synapse/replication/tcp/streams/federation.py
+++ b/synapse/replication/tcp/streams/federation.py
@@ -18,6 +18,7 @@ import attr
from synapse.replication.tcp.streams._base import (
Stream,
+ Token,
current_token_without_instance,
make_http_update_function,
)
@@ -47,7 +48,7 @@ class FederationStream(Stream):
# will be a real FederationSender, which has stubs for current_token and
# get_replication_rows.)
federation_sender = hs.get_federation_sender()
- current_token = current_token_without_instance(
+ self.current_token_func = current_token_without_instance(
federation_sender.get_current_token
)
update_function: Callable[
@@ -57,15 +58,21 @@ class FederationStream(Stream):
elif hs.should_send_federation():
# federation sender: Query master process
update_function = make_http_update_function(hs, self.NAME)
- current_token = self._stub_current_token
+ self.current_token_func = self._stub_current_token
else:
# other worker: stub out the update function (we're not interested in
# any updates so when we get a POSITION we do nothing)
update_function = self._stub_update_function
- current_token = self._stub_current_token
+ self.current_token_func = self._stub_current_token
- super().__init__(hs.get_instance_name(), current_token, update_function)
+ super().__init__(hs.get_instance_name(), update_function)
+
+ def current_token(self, instance_name: str) -> Token:
+ return self.current_token_func(instance_name)
+
+ def minimal_local_current_token(self) -> Token:
+ return self.current_token(self.local_instance_name)
@staticmethod
def _stub_current_token(instance_name: str) -> int:
diff --git a/synapse/replication/tcp/streams/partial_state.py b/synapse/replication/tcp/streams/partial_state.py
index a8ce5ffd72..ad181d7e93 100644
--- a/synapse/replication/tcp/streams/partial_state.py
+++ b/synapse/replication/tcp/streams/partial_state.py
@@ -15,7 +15,7 @@ from typing import TYPE_CHECKING
import attr
-from synapse.replication.tcp.streams import Stream
+from synapse.replication.tcp.streams._base import _StreamFromIdGen
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -27,7 +27,7 @@ class UnPartialStatedRoomStreamRow:
room_id: str
-class UnPartialStatedRoomStream(Stream):
+class UnPartialStatedRoomStream(_StreamFromIdGen):
"""
Stream to notify about rooms becoming un-partial-stated;
that is, when the background sync finishes such that we now have full state for
@@ -41,8 +41,8 @@ class UnPartialStatedRoomStream(Stream):
store = hs.get_datastores().main
super().__init__(
hs.get_instance_name(),
- store.get_un_partial_stated_rooms_token,
store.get_un_partial_stated_rooms_from_stream,
+ store._un_partial_stated_rooms_stream_id_gen,
)
@@ -56,7 +56,7 @@ class UnPartialStatedEventStreamRow:
rejection_status_changed: bool
-class UnPartialStatedEventStream(Stream):
+class UnPartialStatedEventStream(_StreamFromIdGen):
"""
Stream to notify about events becoming un-partial-stated.
"""
@@ -68,6 +68,6 @@ class UnPartialStatedEventStream(Stream):
store = hs.get_datastores().main
super().__init__(
hs.get_instance_name(),
- store.get_un_partial_stated_events_token,
store.get_un_partial_stated_events_from_stream,
+ store._un_partial_stated_events_stream_id_gen,
)
|