From 3085cde577216519d789c8160262831cb2029972 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 1 May 2020 15:21:35 +0100 Subject: Use `stream.current_token()` and remove `stream_positions()` (#7172) We move the processing of typing and federation replication traffic into their handlers so that `Stream.current_token()` points to a valid token. This allows us to remove `get_streams_to_replicate()` and `stream_positions()`. --- tests/replication/tcp/streams/test_typing.py | 3 --- 1 file changed, 3 deletions(-) (limited to 'tests/replication/tcp/streams/test_typing.py') diff --git a/tests/replication/tcp/streams/test_typing.py b/tests/replication/tcp/streams/test_typing.py index 4d354a9db8..e8d17ca68a 100644 --- a/tests/replication/tcp/streams/test_typing.py +++ b/tests/replication/tcp/streams/test_typing.py @@ -38,9 +38,6 @@ class TypingStreamTestCase(BaseStreamTestCase): self.reconnect() - # make the client subscribe to the typing stream - self.test_handler.stream_positions.update({"typing": 0}) - typing._push_update(member=RoomMember(room_id, USER_ID), typing=True) self.reactor.advance(0) -- cgit 1.5.1 From 0e719f23981b8294df66ba7f38b8c7cc99fad228 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 1 May 2020 17:19:56 +0100 Subject: Thread through instance name to replication client. (#7369) For in memory streams when fetching updates on workers we need to query the source of the stream, which currently is hard coded to be master. This PR threads through the source instance we received via `POSITION` through to the update function in each stream, which can then be passed to the replication client for in memory streams. --- changelog.d/7369.misc | 1 + synapse/app/generic_worker.py | 10 +++--- synapse/replication/http/_base.py | 19 +++++++++- synapse/replication/http/streams.py | 4 ++- synapse/replication/tcp/client.py | 12 ++++--- synapse/replication/tcp/handler.py | 20 ++++++++--- synapse/replication/tcp/streams/_base.py | 50 +++++++++++++++++++------- synapse/replication/tcp/streams/events.py | 10 ++++-- synapse/replication/tcp/streams/federation.py | 4 +-- tests/replication/tcp/streams/_base.py | 4 +-- tests/replication/tcp/streams/test_receipts.py | 4 +-- tests/replication/tcp/streams/test_typing.py | 4 +-- 12 files changed, 101 insertions(+), 41 deletions(-) create mode 100644 changelog.d/7369.misc (limited to 'tests/replication/tcp/streams/test_typing.py') diff --git a/changelog.d/7369.misc b/changelog.d/7369.misc new file mode 100644 index 0000000000..060b09c888 --- /dev/null +++ b/changelog.d/7369.misc @@ -0,0 +1 @@ +Thread through instance name to replication client. diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 97b9b81237..667ad20428 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -646,13 +646,11 @@ class GenericWorkerReplicationHandler(ReplicationDataHandler): else: self.send_handler = None - async def on_rdata(self, stream_name, token, rows): - await super(GenericWorkerReplicationHandler, self).on_rdata( - stream_name, token, rows - ) - await self.process_and_notify(stream_name, token, rows) + async def on_rdata(self, stream_name, instance_name, token, rows): + await super().on_rdata(stream_name, instance_name, token, rows) + await self._process_and_notify(stream_name, instance_name, token, rows) - async def process_and_notify(self, stream_name, token, rows): + async def _process_and_notify(self, stream_name, instance_name, token, rows): try: if self.send_handler: await self.send_handler.process_replication_rows( diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index 1be1ccbdf3..f88c80ae84 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -16,6 +16,7 @@ import abc import logging import re +from inspect import signature from typing import Dict, List, Tuple from six import raise_from @@ -60,6 +61,8 @@ class ReplicationEndpoint(object): must call `register` to register the path with the HTTP server. Requests can be sent by calling the client returned by `make_client`. + Requests are sent to master process by default, but can be sent to other + named processes by specifying an `instance_name` keyword argument. Attributes: NAME (str): A name for the endpoint, added to the path as well as used @@ -91,6 +94,16 @@ class ReplicationEndpoint(object): hs, "repl." + self.NAME, timeout_ms=30 * 60 * 1000 ) + # We reserve `instance_name` as a parameter to sending requests, so we + # assert here that sub classes don't try and use the name. + assert ( + "instance_name" not in self.PATH_ARGS + ), "`instance_name` is a reserved paramater name" + assert ( + "instance_name" + not in signature(self.__class__._serialize_payload).parameters + ), "`instance_name` is a reserved paramater name" + assert self.METHOD in ("PUT", "POST", "GET") @abc.abstractmethod @@ -135,7 +148,11 @@ class ReplicationEndpoint(object): @trace(opname="outgoing_replication_request") @defer.inlineCallbacks - def send_request(**kwargs): + def send_request(instance_name="master", **kwargs): + # Currently we only support sending requests to master process. + if instance_name != "master": + raise Exception("Unknown instance") + data = yield cls._serialize_payload(**kwargs) url_args = [ diff --git a/synapse/replication/http/streams.py b/synapse/replication/http/streams.py index f35cebc710..0459f582bf 100644 --- a/synapse/replication/http/streams.py +++ b/synapse/replication/http/streams.py @@ -50,6 +50,8 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint): def __init__(self, hs): super().__init__(hs) + self._instance_name = hs.get_instance_name() + # We pull the streams from the replication steamer (if we try and make # them ourselves we end up in an import loop). self.streams = hs.get_replication_streamer().get_streams() @@ -67,7 +69,7 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint): upto_token = parse_integer(request, "upto_token", required=True) updates, upto_token, limited = await stream.get_updates_since( - from_token, upto_token + self._instance_name, from_token, upto_token ) return ( diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 5c28fd4ac3..3bbf3c3569 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -86,17 +86,19 @@ class ReplicationDataHandler: def __init__(self, store: BaseSlavedStore): self.store = store - async def on_rdata(self, stream_name: str, token: int, rows: list): + async def on_rdata( + self, stream_name: str, instance_name: str, token: int, rows: list + ): """Called to handle a batch of replication data with a given stream token. By default this just pokes the slave store. Can be overridden in subclasses to handle more. Args: - stream_name (str): name of the replication stream for this batch of rows - token (int): stream token for this batch of rows - rows (list): a list of Stream.ROW_TYPE objects as returned by - Stream.parse_row. + stream_name: name of the replication stream for this batch of rows + instance_name: the instance that wrote the rows. + token: stream token for this batch of rows + rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row. """ self.store.process_replication_rows(stream_name, token, rows) diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index d72f3d0cf9..2d1d119c7c 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -278,19 +278,24 @@ class ReplicationCommandHandler: # Check if this is the last of a batch of updates rows = self._pending_batches.pop(stream_name, []) rows.append(row) - await self.on_rdata(stream_name, cmd.token, rows) + await self.on_rdata(stream_name, cmd.instance_name, cmd.token, rows) - async def on_rdata(self, stream_name: str, token: int, rows: list): + async def on_rdata( + self, stream_name: str, instance_name: str, token: int, rows: list + ): """Called to handle a batch of replication data with a given stream token. Args: stream_name: name of the replication stream for this batch of rows + instance_name: the instance that wrote the rows. token: stream token for this batch of rows rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row. """ logger.debug("Received rdata %s -> %s", stream_name, token) - await self._replication_data_handler.on_rdata(stream_name, token, rows) + await self._replication_data_handler.on_rdata( + stream_name, instance_name, token, rows + ) async def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand): if cmd.instance_name == self._instance_name: @@ -325,7 +330,9 @@ class ReplicationCommandHandler: updates, current_token, missing_updates, - ) = await stream.get_updates_since(current_token, cmd.token) + ) = await stream.get_updates_since( + cmd.instance_name, current_token, cmd.token + ) # TODO: add some tests for this @@ -334,7 +341,10 @@ class ReplicationCommandHandler: for token, rows in _batch_updates(updates): await self.on_rdata( - cmd.stream_name, token, [stream.parse_row(row) for row in rows], + cmd.stream_name, + cmd.instance_name, + token, + [stream.parse_row(row) for row in rows], ) # We've now caught up to position sent to us, notify handler. diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 4af1afd119..b0f87c365b 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -16,7 +16,7 @@ import logging from collections import namedtuple -from typing import Any, Awaitable, Callable, Iterable, List, Optional, Tuple +from typing import Any, Awaitable, Callable, List, Optional, Tuple import attr @@ -53,6 +53,7 @@ StreamUpdateResult = Tuple[List[Tuple[Token, StreamRow]], Token, bool] # # The arguments are: # +# * instance_name: the writer of the stream # * from_token: the previous stream token: the starting point for fetching the # updates # * to_token: the new stream token: the point to get updates up to @@ -62,7 +63,7 @@ StreamUpdateResult = Tuple[List[Tuple[Token, StreamRow]], Token, bool] # If there are more updates available, it should set `limited` in the result, and # it will be called again to get the next batch. # -UpdateFunction = Callable[[Token, Token, int], Awaitable[StreamUpdateResult]] +UpdateFunction = Callable[[str, Token, Token, int], Awaitable[StreamUpdateResult]] class Stream(object): @@ -93,6 +94,7 @@ class Stream(object): def __init__( self, + local_instance_name: str, current_token_function: Callable[[], Token], update_function: UpdateFunction, ): @@ -108,9 +110,11 @@ class Stream(object): stream tokens. See the UpdateFunction type definition for more info. Args: + local_instance_name: The instance name of the current process current_token_function: callback to get the current token, as above update_function: callback go get stream updates, as above """ + self.local_instance_name = local_instance_name self.current_token = current_token_function self.update_function = update_function @@ -135,14 +139,14 @@ class Stream(object): """ current_token = self.current_token() updates, current_token, limited = await self.get_updates_since( - self.last_token, current_token + self.local_instance_name, self.last_token, current_token ) self.last_token = current_token return updates, current_token, limited async def get_updates_since( - self, from_token: Token, upto_token: Token + self, instance_name: str, from_token: Token, upto_token: Token ) -> StreamUpdateResult: """Like get_updates except allows specifying from when we should stream updates @@ -160,19 +164,19 @@ class Stream(object): return [], upto_token, False updates, upto_token, limited = await self.update_function( - from_token, upto_token, _STREAM_UPDATE_TARGET_ROW_COUNT, + instance_name, from_token, upto_token, _STREAM_UPDATE_TARGET_ROW_COUNT, ) return updates, upto_token, limited def db_query_to_update_function( - query_function: Callable[[Token, Token, int], Awaitable[Iterable[tuple]]] + query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]] ) -> UpdateFunction: """Wraps a db query function which returns a list of rows to make it suitable for use as an `update_function` for the Stream class """ - async def update_function(from_token, upto_token, limit): + async def update_function(instance_name, from_token, upto_token, limit): rows = await query_function(from_token, upto_token, limit) updates = [(row[0], row[1:]) for row in rows] limited = False @@ -193,10 +197,13 @@ def make_http_update_function(hs, stream_name: str) -> UpdateFunction: client = ReplicationGetStreamUpdates.make_client(hs) async def update_function( - from_token: int, upto_token: int, limit: int + instance_name: str, from_token: int, upto_token: int, limit: int ) -> StreamUpdateResult: result = await client( - stream_name=stream_name, from_token=from_token, upto_token=upto_token, + instance_name=instance_name, + stream_name=stream_name, + from_token=from_token, + upto_token=upto_token, ) return result["updates"], result["upto_token"], result["limited"] @@ -226,6 +233,7 @@ class BackfillStream(Stream): def __init__(self, hs): store = hs.get_datastore() super().__init__( + hs.get_instance_name(), store.get_current_backfill_token, db_query_to_update_function(store.get_all_new_backfill_event_rows), ) @@ -261,7 +269,9 @@ class PresenceStream(Stream): # Query master process update_function = make_http_update_function(hs, self.NAME) - super().__init__(store.get_current_presence_token, update_function) + super().__init__( + hs.get_instance_name(), store.get_current_presence_token, update_function + ) class TypingStream(Stream): @@ -284,7 +294,9 @@ class TypingStream(Stream): # Query master process update_function = make_http_update_function(hs, self.NAME) - super().__init__(typing_handler.get_current_token, update_function) + super().__init__( + hs.get_instance_name(), typing_handler.get_current_token, update_function + ) class ReceiptsStream(Stream): @@ -305,6 +317,7 @@ class ReceiptsStream(Stream): def __init__(self, hs): store = hs.get_datastore() super().__init__( + hs.get_instance_name(), store.get_max_receipt_stream_id, db_query_to_update_function(store.get_all_updated_receipts), ) @@ -322,14 +335,16 @@ class PushRulesStream(Stream): def __init__(self, hs): self.store = hs.get_datastore() super(PushRulesStream, self).__init__( - self._current_token, self._update_function + hs.get_instance_name(), self._current_token, self._update_function ) def _current_token(self) -> int: push_rules_token, _ = self.store.get_push_rules_stream_token() return push_rules_token - async def _update_function(self, from_token: Token, to_token: Token, limit: int): + async def _update_function( + self, instance_name: str, from_token: Token, to_token: Token, limit: int + ): rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit) limited = False @@ -356,6 +371,7 @@ class PushersStream(Stream): store = hs.get_datastore() super().__init__( + hs.get_instance_name(), store.get_pushers_stream_token, db_query_to_update_function(store.get_all_updated_pushers_rows), ) @@ -387,6 +403,7 @@ class CachesStream(Stream): def __init__(self, hs): store = hs.get_datastore() super().__init__( + hs.get_instance_name(), store.get_cache_stream_token, db_query_to_update_function(store.get_all_updated_caches), ) @@ -412,6 +429,7 @@ class PublicRoomsStream(Stream): def __init__(self, hs): store = hs.get_datastore() super().__init__( + hs.get_instance_name(), store.get_current_public_room_stream_id, db_query_to_update_function(store.get_all_new_public_rooms), ) @@ -432,6 +450,7 @@ class DeviceListsStream(Stream): def __init__(self, hs): store = hs.get_datastore() super().__init__( + hs.get_instance_name(), store.get_device_stream_token, db_query_to_update_function(store.get_all_device_list_changes_for_remotes), ) @@ -449,6 +468,7 @@ class ToDeviceStream(Stream): def __init__(self, hs): store = hs.get_datastore() super().__init__( + hs.get_instance_name(), store.get_to_device_stream_token, db_query_to_update_function(store.get_all_new_device_messages), ) @@ -468,6 +488,7 @@ class TagAccountDataStream(Stream): def __init__(self, hs): store = hs.get_datastore() super().__init__( + hs.get_instance_name(), store.get_max_account_data_stream_id, db_query_to_update_function(store.get_all_updated_tags), ) @@ -487,6 +508,7 @@ class AccountDataStream(Stream): def __init__(self, hs): self.store = hs.get_datastore() super().__init__( + hs.get_instance_name(), self.store.get_max_account_data_stream_id, db_query_to_update_function(self._update_function), ) @@ -517,6 +539,7 @@ class GroupServerStream(Stream): def __init__(self, hs): store = hs.get_datastore() super().__init__( + hs.get_instance_name(), store.get_group_stream_token, db_query_to_update_function(store.get_all_groups_changes), ) @@ -534,6 +557,7 @@ class UserSignatureStream(Stream): def __init__(self, hs): store = hs.get_datastore() super().__init__( + hs.get_instance_name(), store.get_device_stream_token, db_query_to_update_function( store.get_all_user_signature_changes_for_remotes diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index 52df81b1bd..890e75d827 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -118,11 +118,17 @@ class EventsStream(Stream): def __init__(self, hs): self._store = hs.get_datastore() super().__init__( - self._store.get_current_events_token, self._update_function, + hs.get_instance_name(), + self._store.get_current_events_token, + self._update_function, ) async def _update_function( - self, from_token: Token, current_token: Token, target_row_count: int + self, + instance_name: str, + from_token: Token, + current_token: Token, + target_row_count: int, ) -> StreamUpdateResult: # the events stream merges together three separate sources: diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py index 75133d7e40..e8bd52e389 100644 --- a/synapse/replication/tcp/streams/federation.py +++ b/synapse/replication/tcp/streams/federation.py @@ -48,8 +48,8 @@ class FederationStream(Stream): current_token = lambda: 0 update_function = self._stub_update_function - super().__init__(current_token, update_function) + super().__init__(hs.get_instance_name(), current_token, update_function) @staticmethod - async def _stub_update_function(from_token, upto_token, limit): + async def _stub_update_function(instance_name, from_token, upto_token, limit): return [], upto_token, False diff --git a/tests/replication/tcp/streams/_base.py b/tests/replication/tcp/streams/_base.py index 8c104f8d1d..7b56d2028d 100644 --- a/tests/replication/tcp/streams/_base.py +++ b/tests/replication/tcp/streams/_base.py @@ -183,8 +183,8 @@ class TestReplicationDataHandler(GenericWorkerReplicationHandler): # list of received (stream_name, token, row) tuples self.received_rdata_rows = [] # type: List[Tuple[str, int, Any]] - async def on_rdata(self, stream_name, token, rows): - await super().on_rdata(stream_name, token, rows) + async def on_rdata(self, stream_name, instance_name, token, rows): + await super().on_rdata(stream_name, instance_name, token, rows) for r in rows: self.received_rdata_rows.append((stream_name, token, r)) diff --git a/tests/replication/tcp/streams/test_receipts.py b/tests/replication/tcp/streams/test_receipts.py index df332ee679..5853314fd4 100644 --- a/tests/replication/tcp/streams/test_receipts.py +++ b/tests/replication/tcp/streams/test_receipts.py @@ -41,7 +41,7 @@ class ReceiptsStreamTestCase(BaseStreamTestCase): # there should be one RDATA command self.test_handler.on_rdata.assert_called_once() - stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0] + stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0] self.assertEqual(stream_name, "receipts") self.assertEqual(1, len(rdata_rows)) row = rdata_rows[0] # type: ReceiptsStream.ReceiptsStreamRow @@ -71,7 +71,7 @@ class ReceiptsStreamTestCase(BaseStreamTestCase): # We should now have caught up and get the missing data self.test_handler.on_rdata.assert_called_once() - stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0] + stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0] self.assertEqual(stream_name, "receipts") self.assertEqual(token, 3) self.assertEqual(1, len(rdata_rows)) diff --git a/tests/replication/tcp/streams/test_typing.py b/tests/replication/tcp/streams/test_typing.py index e8d17ca68a..d25a7b194e 100644 --- a/tests/replication/tcp/streams/test_typing.py +++ b/tests/replication/tcp/streams/test_typing.py @@ -47,7 +47,7 @@ class TypingStreamTestCase(BaseStreamTestCase): self.assert_request_is_get_repl_stream_updates(request, "typing") self.test_handler.on_rdata.assert_called_once() - stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0] + stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0] self.assertEqual(stream_name, "typing") self.assertEqual(1, len(rdata_rows)) row = rdata_rows[0] # type: TypingStream.TypingStreamRow @@ -74,7 +74,7 @@ class TypingStreamTestCase(BaseStreamTestCase): self.assertEqual(int(request.args[b"from_token"][0]), token) self.test_handler.on_rdata.assert_called_once() - stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0] + stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0] self.assertEqual(stream_name, "typing") self.assertEqual(1, len(rdata_rows)) row = rdata_rows[0] -- cgit 1.5.1 From d5aa7d93ed1f7963524125d16ab640ebf6cb91c2 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 5 May 2020 14:15:57 +0100 Subject: Fix catchup-on-reconnect for the Federation Stream (#7374) looks like we managed to break this during the refactorathon. --- changelog.d/7374.misc | 1 + synapse/federation/send_queue.py | 40 ++++++----- synapse/federation/sender/__init__.py | 12 ++-- synapse/federation/sender/per_destination_queue.py | 6 +- synapse/federation/sender/transaction_manager.py | 6 +- synapse/replication/tcp/resource.py | 2 +- synapse/replication/tcp/streams/_base.py | 3 +- synapse/replication/tcp/streams/federation.py | 30 +++++--- tests/replication/tcp/streams/_base.py | 20 ++++-- tests/replication/tcp/streams/test_federation.py | 81 ++++++++++++++++++++++ tests/replication/tcp/streams/test_typing.py | 5 -- 11 files changed, 158 insertions(+), 48 deletions(-) create mode 100644 changelog.d/7374.misc create mode 100644 tests/replication/tcp/streams/test_federation.py (limited to 'tests/replication/tcp/streams/test_typing.py') diff --git a/changelog.d/7374.misc b/changelog.d/7374.misc new file mode 100644 index 0000000000..676f285377 --- /dev/null +++ b/changelog.d/7374.misc @@ -0,0 +1 @@ +Move catchup of replication streams logic to worker. diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py index e1700ca8aa..6fbacf6a3e 100644 --- a/synapse/federation/send_queue.py +++ b/synapse/federation/send_queue.py @@ -31,6 +31,7 @@ Events are replicated via a separate events stream. import logging from collections import namedtuple +from typing import List, Tuple from six import iteritems @@ -69,7 +70,11 @@ class FederationRemoteSendQueue(object): self.edus = SortedDict() # stream position -> Edu + # stream ID for the next entry into presence_changed/keyed_edu_changed/edus. self.pos = 1 + + # map from stream ID to the time that stream entry was generated, so that we + # can clear out entries after a while self.pos_time = SortedDict() # EVERYTHING IS SAD. In particular, python only makes new scopes when @@ -250,19 +255,23 @@ class FederationRemoteSendQueue(object): self._clear_queue_before_pos(token) async def get_replication_rows( - self, from_token, to_token, limit, federation_ack=None - ): + self, instance_name: str, from_token: int, to_token: int, target_row_count: int + ) -> Tuple[List[Tuple[int, Tuple]], int, bool]: """Get rows to be sent over federation between the two tokens Args: - from_token (int) - to_token(int) - limit (int) - federation_ack (int): Optional. The position where the worker is - explicitly acknowledged it has handled. Allows us to drop - data from before that point + instance_name: the name of the current process + from_token: the previous stream token: the starting point for fetching the + updates + to_token: the new stream token: the point to get updates up to + target_row_count: a target for the number of rows to be returned. + + Returns: a triplet `(updates, new_last_token, limited)`, where: + * `updates` is a list of `(token, row)` entries. + * `new_last_token` is the new position in stream. + * `limited` is whether there are more updates to fetch. """ - # TODO: Handle limit. + # TODO: Handle target_row_count. # To handle restarts where we wrap around if from_token > self.pos: @@ -270,12 +279,7 @@ class FederationRemoteSendQueue(object): # list of tuple(int, BaseFederationRow), where the first is the position # of the federation stream. - rows = [] - - # There should be only one reader, so lets delete everything its - # acknowledged its seen. - if federation_ack: - self._clear_queue_before_pos(federation_ack) + rows = [] # type: List[Tuple[int, BaseFederationRow]] # Fetch changed presence i = self.presence_changed.bisect_right(from_token) @@ -332,7 +336,11 @@ class FederationRemoteSendQueue(object): # Sort rows based on pos rows.sort() - return [(pos, row.TypeId, row.to_data()) for pos, row in rows] + return ( + [(pos, (row.TypeId, row.to_data())) for pos, row in rows], + to_token, + False, + ) class BaseFederationRow(object): diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index a477578e44..d473576902 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -14,7 +14,7 @@ # limitations under the License. import logging -from typing import Dict, Hashable, Iterable, List, Optional, Set +from typing import Dict, Hashable, Iterable, List, Optional, Set, Tuple from six import itervalues @@ -498,14 +498,16 @@ class FederationSender(object): self._get_per_destination_queue(destination).attempt_new_transaction() - def get_current_token(self) -> int: + @staticmethod + def get_current_token() -> int: # Dummy implementation for case where federation sender isn't offloaded # to a worker. return 0 + @staticmethod async def get_replication_rows( - self, from_token, to_token, limit, federation_ack=None - ): + instance_name: str, from_token: int, to_token: int, target_row_count: int + ) -> Tuple[List[Tuple[int, Tuple]], int, bool]: # Dummy implementation for case where federation sender isn't offloaded # to a worker. - return [] + return [], 0, False diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index e13cd20ffa..276a2b596f 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -15,11 +15,10 @@ # limitations under the License. import datetime import logging -from typing import Dict, Hashable, Iterable, List, Tuple +from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Tuple from prometheus_client import Counter -import synapse.server from synapse.api.errors import ( FederationDeniedError, HttpResponseException, @@ -34,6 +33,9 @@ from synapse.storage.presence import UserPresenceState from synapse.types import ReadReceipt from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter +if TYPE_CHECKING: + import synapse.server + # This is defined in the Matrix spec and enforced by the receiver. MAX_EDUS_PER_TRANSACTION = 100 diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py index 3c2a02a3b3..a2752a54a5 100644 --- a/synapse/federation/sender/transaction_manager.py +++ b/synapse/federation/sender/transaction_manager.py @@ -13,11 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import List +from typing import TYPE_CHECKING, List from canonicaljson import json -import synapse.server from synapse.api.errors import HttpResponseException from synapse.events import EventBase from synapse.federation.persistence import TransactionActions @@ -31,6 +30,9 @@ from synapse.logging.opentracing import ( ) from synapse.util.metrics import measure_func +if TYPE_CHECKING: + import synapse.server + logger = logging.getLogger(__name__) diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index 33d2f589ac..b690abedad 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -80,7 +80,7 @@ class ReplicationStreamer(object): for stream in STREAMS_MAP.values(): if stream == FederationStream and hs.config.send_federation: # We only support federation stream if federation sending - # hase been disabled on the master. + # has been disabled on the master. continue self.streams.append(stream(hs)) diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index b0f87c365b..084604e8b0 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -104,7 +104,8 @@ class Stream(object): implemented by subclasses. current_token_function is called to get the current token of the underlying - stream. + stream. It is only meaningful on the process that is the source of the + replication stream (ie, usually the master). update_function is called to get updates for this stream between a pair of stream tokens. See the UpdateFunction type definition for more info. diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py index e8bd52e389..b0505b8a2c 100644 --- a/synapse/replication/tcp/streams/federation.py +++ b/synapse/replication/tcp/streams/federation.py @@ -15,7 +15,7 @@ # limitations under the License. from collections import namedtuple -from synapse.replication.tcp.streams._base import Stream, db_query_to_update_function +from synapse.replication.tcp.streams._base import Stream, make_http_update_function class FederationStream(Stream): @@ -35,21 +35,33 @@ class FederationStream(Stream): ROW_TYPE = FederationStreamRow def __init__(self, hs): - # Not all synapse instances will have a federation sender instance, - # whether that's a `FederationSender` or a `FederationRemoteSendQueue`, - # so we stub the stream out when that is the case. - if hs.config.worker_app is None or hs.should_send_federation(): + if hs.config.worker_app is None: + # master process: get updates from the FederationRemoteSendQueue. + # (if the master is configured to send federation itself, federation_sender + # will be a real FederationSender, which has stubs for current_token and + # get_replication_rows.) federation_sender = hs.get_federation_sender() current_token = federation_sender.get_current_token - update_function = db_query_to_update_function( - federation_sender.get_replication_rows - ) + update_function = federation_sender.get_replication_rows + + elif hs.should_send_federation(): + # federation sender: Query master process + update_function = make_http_update_function(hs, self.NAME) + current_token = self._stub_current_token + else: - current_token = lambda: 0 + # other worker: stub out the update function (we're not interested in + # any updates so when we get a POSITION we do nothing) update_function = self._stub_update_function + current_token = self._stub_current_token super().__init__(hs.get_instance_name(), current_token, update_function) + @staticmethod + def _stub_current_token(): + # dummy current-token method for use on workers + return 0 + @staticmethod async def _stub_update_function(instance_name, from_token, upto_token, limit): return [], upto_token, False diff --git a/tests/replication/tcp/streams/_base.py b/tests/replication/tcp/streams/_base.py index 7b56d2028d..9d4f0bbe44 100644 --- a/tests/replication/tcp/streams/_base.py +++ b/tests/replication/tcp/streams/_base.py @@ -27,6 +27,7 @@ from synapse.app.generic_worker import ( GenericWorkerServer, ) from synapse.http.site import SynapseRequest +from synapse.replication.http import streams from synapse.replication.tcp.handler import ReplicationCommandHandler from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory @@ -42,6 +43,10 @@ logger = logging.getLogger(__name__) class BaseStreamTestCase(unittest.HomeserverTestCase): """Base class for tests of the replication streams""" + servlets = [ + streams.register_servlets, + ] + def prepare(self, reactor, clock, hs): # build a replication server server_factory = ReplicationStreamProtocolFactory(hs) @@ -49,17 +54,11 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): self.server = server_factory.buildProtocol(None) # Make a new HomeServer object for the worker - config = self.default_config() - config["worker_app"] = "synapse.app.generic_worker" - config["worker_replication_host"] = "testserv" - config["worker_replication_http_port"] = "8765" - self.reactor.lookups["testserv"] = "1.2.3.4" - self.worker_hs = self.setup_test_homeserver( http_client=None, homeserverToUse=GenericWorkerServer, - config=config, + config=self._get_worker_hs_config(), reactor=self.reactor, ) @@ -78,6 +77,13 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): self._client_transport = None self._server_transport = None + def _get_worker_hs_config(self) -> dict: + config = self.default_config() + config["worker_app"] = "synapse.app.generic_worker" + config["worker_replication_host"] = "testserv" + config["worker_replication_http_port"] = "8765" + return config + def _build_replication_data_handler(self): return TestReplicationDataHandler(self.worker_hs) diff --git a/tests/replication/tcp/streams/test_federation.py b/tests/replication/tcp/streams/test_federation.py new file mode 100644 index 0000000000..eea4565da3 --- /dev/null +++ b/tests/replication/tcp/streams/test_federation.py @@ -0,0 +1,81 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.federation.send_queue import EduRow +from synapse.replication.tcp.streams.federation import FederationStream + +from tests.replication.tcp.streams._base import BaseStreamTestCase + + +class FederationStreamTestCase(BaseStreamTestCase): + def _get_worker_hs_config(self) -> dict: + # enable federation sending on the worker + config = super()._get_worker_hs_config() + # TODO: make it so we don't need both of these + config["send_federation"] = True + config["worker_app"] = "synapse.app.federation_sender" + return config + + def test_catchup(self): + """Basic test of catchup on reconnect + + Makes sure that updates sent while we are offline are received later. + """ + fed_sender = self.hs.get_federation_sender() + received_rows = self.test_handler.received_rdata_rows + + fed_sender.build_and_send_edu("testdest", "m.test_edu", {"a": "b"}) + + self.reconnect() + self.reactor.advance(0) + + # check we're testing what we think we are: no rows should yet have been + # received + self.assertEqual(received_rows, []) + + # We should now see an attempt to connect to the master + request = self.handle_http_replication_attempt() + self.assert_request_is_get_repl_stream_updates(request, "federation") + + # we should have received an update row + stream_name, token, row = received_rows.pop() + self.assertEqual(stream_name, "federation") + self.assertIsInstance(row, FederationStream.FederationStreamRow) + self.assertEqual(row.type, EduRow.TypeId) + edurow = EduRow.from_data(row.data) + self.assertEqual(edurow.edu.edu_type, "m.test_edu") + self.assertEqual(edurow.edu.origin, self.hs.hostname) + self.assertEqual(edurow.edu.destination, "testdest") + self.assertEqual(edurow.edu.content, {"a": "b"}) + + self.assertEqual(received_rows, []) + + # additional updates should be transferred without an HTTP hit + fed_sender.build_and_send_edu("testdest", "m.test1", {"c": "d"}) + self.reactor.advance(0) + # there should be no http hit + self.assertEqual(len(self.reactor.tcpClients), 0) + # ... but we should have a row + self.assertEqual(len(received_rows), 1) + + stream_name, token, row = received_rows.pop() + self.assertEqual(stream_name, "federation") + self.assertIsInstance(row, FederationStream.FederationStreamRow) + self.assertEqual(row.type, EduRow.TypeId) + edurow = EduRow.from_data(row.data) + self.assertEqual(edurow.edu.edu_type, "m.test1") + self.assertEqual(edurow.edu.origin, self.hs.hostname) + self.assertEqual(edurow.edu.destination, "testdest") + self.assertEqual(edurow.edu.content, {"c": "d"}) diff --git a/tests/replication/tcp/streams/test_typing.py b/tests/replication/tcp/streams/test_typing.py index d25a7b194e..125c63dab5 100644 --- a/tests/replication/tcp/streams/test_typing.py +++ b/tests/replication/tcp/streams/test_typing.py @@ -15,7 +15,6 @@ from mock import Mock from synapse.handlers.typing import RoomMember -from synapse.replication.http import streams from synapse.replication.tcp.streams import TypingStream from tests.replication.tcp.streams._base import BaseStreamTestCase @@ -24,10 +23,6 @@ USER_ID = "@feeling:blue" class TypingStreamTestCase(BaseStreamTestCase): - servlets = [ - streams.register_servlets, - ] - def _build_replication_data_handler(self): return Mock(wraps=super()._build_replication_data_handler()) -- cgit 1.5.1