From 4cff617df1ba6f241fee6957cc44859f57edcc0e Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 25 Mar 2020 14:54:01 +0000 Subject: Move catchup of replication streams to worker. (#7024) This changes the replication protocol so that the server does not send down `RDATA` for rows that happened before the client connected. Instead, the server will send a `POSITION` and clients then query the database (or master out of band) to get up to date. --- synapse/replication/tcp/streams/events.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) (limited to 'synapse/replication/tcp/streams/events.py') diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index b3afabb8cd..c6a595629f 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -19,7 +19,7 @@ from typing import Tuple, Type import attr -from ._base import Stream +from ._base import Stream, db_query_to_update_function """Handling of the 'events' replication stream @@ -117,10 +117,11 @@ class EventsStream(Stream): def __init__(self, hs): self._store = hs.get_datastore() self.current_token = self._store.get_current_events_token # type: ignore + self.update_function = db_query_to_update_function(self._update_function) # type: ignore super(EventsStream, self).__init__(hs) - async def update_function(self, from_token, current_token, limit=None): + async def _update_function(self, from_token, current_token, limit=None): event_rows = await self._store.get_all_new_forward_event_rows( from_token, current_token, limit ) -- cgit 1.5.1 From 67ff7b8ba0d3647f3c370341dff3f035b3a1160a Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 17 Apr 2020 14:49:55 +0100 Subject: Improve type checking in `replication.tcp.Stream` (#7291) The general idea here is to get rid of the type: ignore annotations on all of the current_token and update_function assignments, which would have caught #7290. After a bit of experimentation, it seems like the least-awful way to do this is to pass the offending functions in as parameters to the Stream constructor. Unfortunately that means that the concrete implementations no longer have the same constructor signature as Stream itself, which means that it gets hard to correctly annotate STREAMS_MAP. I've also introduced a couple of new types, to take out some duplication. --- changelog.d/7291.misc | 1 + synapse/replication/tcp/streams/__init__.py | 5 +- synapse/replication/tcp/streams/_base.py | 224 ++++++++++++++------------ synapse/replication/tcp/streams/events.py | 16 +- synapse/replication/tcp/streams/federation.py | 19 ++- 5 files changed, 143 insertions(+), 122 deletions(-) create mode 100644 changelog.d/7291.misc (limited to 'synapse/replication/tcp/streams/events.py') diff --git a/changelog.d/7291.misc b/changelog.d/7291.misc new file mode 100644 index 0000000000..02e7ae3fa2 --- /dev/null +++ b/changelog.d/7291.misc @@ -0,0 +1 @@ +Improve typing annotations in `synapse.replication.tcp.streams.Stream`. 
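(Illustrative sketch, not part of the patch: under the pattern this commit introduces, a concrete stream no longer assigns `current_token` / `update_function` on itself with `# type: ignore`; it passes the two callbacks to the `Stream` constructor. The class and store methods below are placeholder names invented for the example — the real conversions follow in the diffs.)

# Hypothetical example of the constructor-injection pattern described above.
# "ExampleStream", "get_example_stream_token" and "get_all_example_rows" are
# placeholders, not real Synapse names.
from synapse.replication.tcp.streams._base import Stream, db_query_to_update_function


class ExampleStream(Stream):
    NAME = "example"

    def __init__(self, hs):
        store = hs.get_datastore()
        super().__init__(
            store.get_example_stream_token,  # current_token_function
            db_query_to_update_function(store.get_all_example_rows),  # update_function
        )
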
diff --git a/synapse/replication/tcp/streams/__init__.py b/synapse/replication/tcp/streams/__init__.py index 37bcd3de66..d1a61c3314 100644 --- a/synapse/replication/tcp/streams/__init__.py +++ b/synapse/replication/tcp/streams/__init__.py @@ -25,8 +25,6 @@ Each stream is defined by the following information: update_function: The function that returns a list of updates between two tokens """ -from typing import Dict, Type - from synapse.replication.tcp.streams._base import ( AccountDataStream, BackfillStream, @@ -67,8 +65,7 @@ STREAMS_MAP = { GroupServerStream, UserSignatureStream, ) -} # type: Dict[str, Type[Stream]] - +} __all__ = [ "STREAMS_MAP", diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 0d3f050776..a860072ccf 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -16,12 +16,11 @@ import logging from collections import namedtuple -from typing import Any, Awaitable, Callable, List, Optional, Tuple +from typing import Any, Awaitable, Callable, Iterable, List, Optional, Tuple import attr from synapse.replication.http.streams import ReplicationGetStreamUpdates -from synapse.types import JsonDict logger = logging.getLogger(__name__) @@ -34,8 +33,32 @@ MAX_EVENTS_BEHIND = 500000 # A stream position token Token = int -# A pair of position in stream and args used to create an instance of `ROW_TYPE`. -StreamRow = Tuple[Token, tuple] +# The type of a stream update row, after JSON deserialisation, but before +# parsing with Stream.parse_row (which turns it into a `ROW_TYPE`). Normally it's +# just a row from a database query, though this is dependent on the stream in question. +# +StreamRow = Tuple + +# The type returned by the update_function of a stream, as well as get_updates(), +# get_updates_since, etc. +# +# It consists of a triplet `(updates, new_last_token, limited)`, where: +# * `updates` is a list of `(token, row)` entries. +# * `new_last_token` is the new position in stream. +# * `limited` is whether there are more updates to fetch. +# +StreamUpdateResult = Tuple[List[Tuple[Token, StreamRow]], Token, bool] + +# The type of an update_function for a stream +# +# The arguments are: +# +# * from_token: the previous stream token: the starting point for fetching the +# updates +# * to_token: the new stream token: the point to get updates up to +# * limit: the maximum number of rows to return +# +UpdateFunction = Callable[[Token, Token, int], Awaitable[StreamUpdateResult]] class Stream(object): @@ -50,7 +73,7 @@ class Stream(object): ROW_TYPE = None # type: Any @classmethod - def parse_row(cls, row): + def parse_row(cls, row: StreamRow): """Parse a row received over replication By default, assumes that the row data is an array object and passes its contents @@ -64,7 +87,28 @@ class Stream(object): """ return cls.ROW_TYPE(*row) - def __init__(self, hs): + def __init__( + self, + current_token_function: Callable[[], Token], + update_function: UpdateFunction, + ): + """Instantiate a Stream + + current_token_function and update_function are callbacks which should be + implemented by subclasses. + + current_token_function is called to get the current token of the underlying + stream. + + update_function is called to get updates for this stream between a pair of + stream tokens. See the UpdateFunction type definition for more info. 
+ + Args: + current_token_function: callback to get the current token, as above + update_function: callback go get stream updates, as above + """ + self.current_token = current_token_function + self.update_function = update_function # The token from which we last asked for updates self.last_token = self.current_token() @@ -75,7 +119,7 @@ class Stream(object): """ self.last_token = self.current_token() - async def get_updates(self) -> Tuple[List[Tuple[Token, JsonDict]], Token, bool]: + async def get_updates(self) -> StreamUpdateResult: """Gets all updates since the last time this function was called (or since the stream was constructed if it hadn't been called before). @@ -95,7 +139,7 @@ class Stream(object): async def get_updates_since( self, from_token: Token, upto_token: Token, limit: int = 100 - ) -> Tuple[List[Tuple[Token, JsonDict]], Token, bool]: + ) -> StreamUpdateResult: """Like get_updates except allows specifying from when we should stream updates @@ -112,33 +156,14 @@ class Stream(object): return [], upto_token, False updates, upto_token, limited = await self.update_function( - from_token, upto_token, limit=limit, + from_token, upto_token, limit, ) return updates, upto_token, limited - def current_token(self): - """Gets the current token of the underlying streams. Should be provided - by the sub classes - - Returns: - int - """ - raise NotImplementedError() - - def update_function(self, from_token, current_token, limit): - """Get updates between from_token and to_token. - - Returns: - Deferred(list(tuple)): the first entry in the tuple is the token for - that update, and the rest of the tuple gets used to construct - a ``ROW_TYPE`` instance - """ - raise NotImplementedError() - def db_query_to_update_function( - query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]] -) -> Callable[[Token, Token, int], Awaitable[Tuple[List[StreamRow], Token, bool]]]: + query_function: Callable[[Token, Token, int], Awaitable[Iterable[tuple]]] +) -> UpdateFunction: """Wraps a db query function which returns a list of rows to make it suitable for use as an `update_function` for the Stream class """ @@ -157,9 +182,7 @@ def db_query_to_update_function( return update_function -def make_http_update_function( - hs, stream_name: str -) -> Callable[[Token, Token, Token], Awaitable[Tuple[List[StreamRow], Token, bool]]]: +def make_http_update_function(hs, stream_name: str) -> UpdateFunction: """Makes a suitable function for use as an `update_function` that queries the master process for updates. 
""" @@ -168,7 +191,7 @@ def make_http_update_function( async def update_function( from_token: int, upto_token: int, limit: int - ) -> Tuple[List[Tuple[int, tuple]], int, bool]: + ) -> StreamUpdateResult: result = await client( stream_name=stream_name, from_token=from_token, @@ -202,10 +225,10 @@ class BackfillStream(Stream): def __init__(self, hs): store = hs.get_datastore() - self.current_token = store.get_current_backfill_token # type: ignore - self.update_function = db_query_to_update_function(store.get_all_new_backfill_event_rows) # type: ignore - - super(BackfillStream, self).__init__(hs) + super().__init__( + store.get_current_backfill_token, + db_query_to_update_function(store.get_all_new_backfill_event_rows), + ) class PresenceStream(Stream): @@ -227,19 +250,18 @@ class PresenceStream(Stream): def __init__(self, hs): store = hs.get_datastore() - presence_handler = hs.get_presence_handler() - - self._is_worker = hs.config.worker_app is not None - - self.current_token = store.get_current_presence_token # type: ignore if hs.config.worker_app is None: - self.update_function = db_query_to_update_function(presence_handler.get_all_presence_updates) # type: ignore + # on the master, query the presence handler + presence_handler = hs.get_presence_handler() + update_function = db_query_to_update_function( + presence_handler.get_all_presence_updates + ) else: # Query master process - self.update_function = make_http_update_function(hs, self.NAME) # type: ignore + update_function = make_http_update_function(hs, self.NAME) - super(PresenceStream, self).__init__(hs) + super().__init__(store.get_current_presence_token, update_function) class TypingStream(Stream): @@ -253,15 +275,16 @@ class TypingStream(Stream): def __init__(self, hs): typing_handler = hs.get_typing_handler() - self.current_token = typing_handler.get_current_token # type: ignore - if hs.config.worker_app is None: - self.update_function = db_query_to_update_function(typing_handler.get_all_typing_updates) # type: ignore + # on the master, query the typing handler + update_function = db_query_to_update_function( + typing_handler.get_all_typing_updates + ) else: # Query master process - self.update_function = make_http_update_function(hs, self.NAME) # type: ignore + update_function = make_http_update_function(hs, self.NAME) - super(TypingStream, self).__init__(hs) + super().__init__(typing_handler.get_current_token, update_function) class ReceiptsStream(Stream): @@ -281,11 +304,10 @@ class ReceiptsStream(Stream): def __init__(self, hs): store = hs.get_datastore() - - self.current_token = store.get_max_receipt_stream_id # type: ignore - self.update_function = db_query_to_update_function(store.get_all_updated_receipts) # type: ignore - - super(ReceiptsStream, self).__init__(hs) + super().__init__( + store.get_max_receipt_stream_id, + db_query_to_update_function(store.get_all_updated_receipts), + ) class PushRulesStream(Stream): @@ -299,13 +321,15 @@ class PushRulesStream(Stream): def __init__(self, hs): self.store = hs.get_datastore() - super(PushRulesStream, self).__init__(hs) + super(PushRulesStream, self).__init__( + self._current_token, self._update_function + ) - def current_token(self): + def _current_token(self) -> int: push_rules_token, _ = self.store.get_push_rules_stream_token() return push_rules_token - async def update_function(self, from_token, to_token, limit): + async def _update_function(self, from_token: Token, to_token: Token, limit: int): rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit) 
limited = False @@ -331,10 +355,10 @@ class PushersStream(Stream): def __init__(self, hs): store = hs.get_datastore() - self.current_token = store.get_pushers_stream_token # type: ignore - self.update_function = db_query_to_update_function(store.get_all_updated_pushers_rows) # type: ignore - - super(PushersStream, self).__init__(hs) + super().__init__( + store.get_pushers_stream_token, + db_query_to_update_function(store.get_all_updated_pushers_rows), + ) class CachesStream(Stream): @@ -362,11 +386,10 @@ class CachesStream(Stream): def __init__(self, hs): store = hs.get_datastore() - - self.current_token = store.get_cache_stream_token # type: ignore - self.update_function = db_query_to_update_function(store.get_all_updated_caches) # type: ignore - - super(CachesStream, self).__init__(hs) + super().__init__( + store.get_cache_stream_token, + db_query_to_update_function(store.get_all_updated_caches), + ) class PublicRoomsStream(Stream): @@ -388,11 +411,10 @@ class PublicRoomsStream(Stream): def __init__(self, hs): store = hs.get_datastore() - - self.current_token = store.get_current_public_room_stream_id # type: ignore - self.update_function = db_query_to_update_function(store.get_all_new_public_rooms) # type: ignore - - super(PublicRoomsStream, self).__init__(hs) + super().__init__( + store.get_current_public_room_stream_id, + db_query_to_update_function(store.get_all_new_public_rooms), + ) class DeviceListsStream(Stream): @@ -409,11 +431,10 @@ class DeviceListsStream(Stream): def __init__(self, hs): store = hs.get_datastore() - - self.current_token = store.get_device_stream_token # type: ignore - self.update_function = db_query_to_update_function(store.get_all_device_list_changes_for_remotes) # type: ignore - - super(DeviceListsStream, self).__init__(hs) + super().__init__( + store.get_device_stream_token, + db_query_to_update_function(store.get_all_device_list_changes_for_remotes), + ) class ToDeviceStream(Stream): @@ -427,11 +448,10 @@ class ToDeviceStream(Stream): def __init__(self, hs): store = hs.get_datastore() - - self.current_token = store.get_to_device_stream_token # type: ignore - self.update_function = db_query_to_update_function(store.get_all_new_device_messages) # type: ignore - - super(ToDeviceStream, self).__init__(hs) + super().__init__( + store.get_to_device_stream_token, + db_query_to_update_function(store.get_all_new_device_messages), + ) class TagAccountDataStream(Stream): @@ -447,11 +467,10 @@ class TagAccountDataStream(Stream): def __init__(self, hs): store = hs.get_datastore() - - self.current_token = store.get_max_account_data_stream_id # type: ignore - self.update_function = db_query_to_update_function(store.get_all_updated_tags) # type: ignore - - super(TagAccountDataStream, self).__init__(hs) + super().__init__( + store.get_max_account_data_stream_id, + db_query_to_update_function(store.get_all_updated_tags), + ) class AccountDataStream(Stream): @@ -467,11 +486,10 @@ class AccountDataStream(Stream): def __init__(self, hs): self.store = hs.get_datastore() - - self.current_token = self.store.get_max_account_data_stream_id # type: ignore - self.update_function = db_query_to_update_function(self._update_function) # type: ignore - - super(AccountDataStream, self).__init__(hs) + super().__init__( + self.store.get_max_account_data_stream_id, + db_query_to_update_function(self._update_function), + ) async def _update_function(self, from_token, to_token, limit): global_results, room_results = await self.store.get_all_updated_account_data( @@ -498,11 +516,10 @@ class 
GroupServerStream(Stream): def __init__(self, hs): store = hs.get_datastore() - - self.current_token = store.get_group_stream_token # type: ignore - self.update_function = db_query_to_update_function(store.get_all_groups_changes) # type: ignore - - super(GroupServerStream, self).__init__(hs) + super().__init__( + store.get_group_stream_token, + db_query_to_update_function(store.get_all_groups_changes), + ) class UserSignatureStream(Stream): @@ -516,8 +533,9 @@ class UserSignatureStream(Stream): def __init__(self, hs): store = hs.get_datastore() - - self.current_token = store.get_device_stream_token # type: ignore - self.update_function = db_query_to_update_function(store.get_all_user_signature_changes_for_remotes) # type: ignore - - super(UserSignatureStream, self).__init__(hs) + super().__init__( + store.get_device_stream_token, + db_query_to_update_function( + store.get_all_user_signature_changes_for_remotes + ), + ) diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index c6a595629f..051114596b 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -15,11 +15,11 @@ # limitations under the License. import heapq -from typing import Tuple, Type +from typing import Iterable, Tuple, Type import attr -from ._base import Stream, db_query_to_update_function +from ._base import Stream, Token, db_query_to_update_function """Handling of the 'events' replication stream @@ -116,12 +116,14 @@ class EventsStream(Stream): def __init__(self, hs): self._store = hs.get_datastore() - self.current_token = self._store.get_current_events_token # type: ignore - self.update_function = db_query_to_update_function(self._update_function) # type: ignore - - super(EventsStream, self).__init__(hs) + super().__init__( + self._store.get_current_events_token, + db_query_to_update_function(self._update_function), + ) - async def _update_function(self, from_token, current_token, limit=None): + async def _update_function( + self, from_token: Token, current_token: Token, limit: int + ) -> Iterable[tuple]: event_rows = await self._store.get_all_new_forward_event_rows( from_token, current_token, limit ) diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py index 48c1d45718..75133d7e40 100644 --- a/synapse/replication/tcp/streams/federation.py +++ b/synapse/replication/tcp/streams/federation.py @@ -15,8 +15,6 @@ # limitations under the License. from collections import namedtuple -from twisted.internet import defer - from synapse.replication.tcp.streams._base import Stream, db_query_to_update_function @@ -35,7 +33,6 @@ class FederationStream(Stream): NAME = "federation" ROW_TYPE = FederationStreamRow - _QUERY_MASTER = True def __init__(self, hs): # Not all synapse instances will have a federation sender instance, @@ -43,10 +40,16 @@ class FederationStream(Stream): # so we stub the stream out when that is the case. 
if hs.config.worker_app is None or hs.should_send_federation(): federation_sender = hs.get_federation_sender() - self.current_token = federation_sender.get_current_token # type: ignore - self.update_function = db_query_to_update_function(federation_sender.get_replication_rows) # type: ignore + current_token = federation_sender.get_current_token + update_function = db_query_to_update_function( + federation_sender.get_replication_rows + ) else: - self.current_token = lambda: 0 # type: ignore - self.update_function = lambda from_token, upto_token, limit: defer.succeed(([], upto_token, bool)) # type: ignore + current_token = lambda: 0 + update_function = self._stub_update_function + + super().__init__(current_token, update_function) - super(FederationStream, self).__init__(hs) + @staticmethod + async def _stub_update_function(from_token, upto_token, limit): + return [], upto_token, False -- cgit 1.5.1 From ce428a1abe6aae25e236baf268f56b1811cba333 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Thu, 23 Apr 2020 18:19:08 +0100 Subject: Fix EventsStream raising assertions when it falls behind Figuring out how to correctly limit updates from this stream without dropping entries is far more complicated than just counting the number of rows being returned. We need to consider each query separately and, if any one query hits the limit, truncate the results from the others. I think this also fixes some potentially long-standing bugs where events or state changes could get missed if we hit the limit on either query. --- synapse/replication/tcp/streams/events.py | 113 ++++++++++++++++++---- synapse/storage/data_stores/main/events_worker.py | 46 ++++++--- 2 files changed, 129 insertions(+), 30 deletions(-) (limited to 'synapse/replication/tcp/streams/events.py') diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index 051114596b..aa50492569 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -15,11 +15,12 @@ # limitations under the License. import heapq -from typing import Iterable, Tuple, Type +from collections import Iterable +from typing import List, Tuple, Type import attr -from ._base import Stream, Token, db_query_to_update_function +from ._base import Stream, StreamUpdateResult, Token """Handling of the 'events' replication stream @@ -117,30 +118,106 @@ class EventsStream(Stream): def __init__(self, hs): self._store = hs.get_datastore() super().__init__( - self._store.get_current_events_token, - db_query_to_update_function(self._update_function), + self._store.get_current_events_token, self._update_function, ) async def _update_function( - self, from_token: Token, current_token: Token, limit: int - ) -> Iterable[tuple]: + self, from_token: Token, current_token: Token, target_row_count: int + ) -> StreamUpdateResult: + + # the events stream merges together three separate sources: + # * new events + # * current_state changes + # * events which were previously outliers, but have now been de-outliered. + # + # The merge operation is complicated by the fact that we only have a single + # "stream token" which is supposed to indicate how far we have got through + # all three streams. It's therefore no good to return rows 1-1000 from the + # "new events" table if the state_deltas are limited to rows 1-100 by the + # target_row_count. + # + # In other words: we must pick a new upper limit, and must return *all* rows + # up to that point for each of the three sources. 
+ # + # Start by trying to split the target_row_count up. We expect to have a + # negligible number of ex-outliers, and a rough approximation based on recent + # traffic on sw1v.org shows that there are approximately the same number of + # event rows between a given pair of stream ids as there are state + # updates, so let's split our target_row_count among those two types. The target + # is only an approximation - it doesn't matter if we end up going a bit over it. + + target_row_count //= 2 + + # now we fetch up to that many rows from the events table + event_rows = await self._store.get_all_new_forward_event_rows( - from_token, current_token, limit - ) - event_updates = ( - (row[0], EventsStreamEventRow.TypeId, row[1:]) for row in event_rows - ) + from_token, current_token, target_row_count + ) # type: List[Tuple] + + # we rely on get_all_new_forward_event_rows strictly honouring the limit, so + # that we know it is safe to just take upper_limit = event_rows[-1][0]. + assert ( + len(event_rows) <= target_row_count + ), "get_all_new_forward_event_rows did not honour row limit" + + # if we hit the limit on event_updates, there's no point in going beyond the + # last stream_id in the batch for the other sources. + + if len(event_rows) == target_row_count: + limited = True + upper_limit = event_rows[-1][0] # type: int + else: + limited = False + upper_limit = current_token + + # next up is the state delta table state_rows = await self._store.get_all_updated_current_state_deltas( - from_token, current_token, limit - ) - state_updates = ( - (row[0], EventsStreamCurrentStateRow.TypeId, row[1:]) for row in state_rows - ) + from_token, upper_limit, target_row_count + ) # type: List[Tuple] + + # again, if we've hit the limit there, we'll need to limit the other sources + assert len(state_rows) < target_row_count + if len(state_rows) == target_row_count: + assert state_rows[-1][0] <= upper_limit + upper_limit = state_rows[-1][0] + limited = True + + # FIXME: is it a given that there is only one row per stream_id in the + # state_deltas table (so that we can be sure that we have got all of the + # rows for upper_limit)? + + # finally, fetch the ex-outliers rows. We assume there are few enough of these + # not to bother with the limit. - all_updates = heapq.merge(event_updates, state_updates) + ex_outliers_rows = await self._store.get_ex_outlier_stream_rows( + from_token, upper_limit + ) # type: List[Tuple] - return all_updates + # we now need to turn the raw database rows returned into tuples suitable + # for the replication protocol (basically, we add an identifier to + # distinguish the row type). At the same time, we can limit the event_rows + # to the max stream_id from state_rows. + + event_updates = ( + (stream_id, (EventsStreamEventRow.TypeId, rest)) + for (stream_id, *rest) in event_rows + if stream_id <= upper_limit + ) # type: Iterable[Tuple[int, Tuple]] + + state_updates = ( + (stream_id, (EventsStreamCurrentStateRow.TypeId, rest)) + for (stream_id, *rest) in state_rows + ) # type: Iterable[Tuple[int, Tuple]] + + ex_outliers_updates = ( + (stream_id, (EventsStreamEventRow.TypeId, rest)) + for (stream_id, *rest) in ex_outliers_rows + ) # type: Iterable[Tuple[int, Tuple]] + + # we need to return a sorted list, so merge them together. 
+ updates = list(heapq.merge(event_updates, state_updates, ex_outliers_updates)) + return updates, upper_limit, limited @classmethod def parse_row(cls, row): diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py index accde349a7..ce8be72bfe 100644 --- a/synapse/storage/data_stores/main/events_worker.py +++ b/synapse/storage/data_stores/main/events_worker.py @@ -973,8 +973,18 @@ class EventsWorkerStore(SQLBaseStore): return self._stream_id_gen.get_current_token() def get_all_new_forward_event_rows(self, last_id, current_id, limit): - if last_id == current_id: - return defer.succeed([]) + """Returns new events, for the Events replication stream + + Args: + last_id: the last stream_id from the previous batch. + current_id: the maximum stream_id to return up to + limit: the maximum number of rows to return + + Returns: Deferred[List[Tuple]] + a list of events stream rows. Each tuple consists of a stream id as + the first element, followed by fields suitable for casting into an + EventsStreamRow. + """ def get_all_new_forward_event_rows(txn): sql = ( @@ -989,13 +999,26 @@ class EventsWorkerStore(SQLBaseStore): " LIMIT ?" ) txn.execute(sql, (last_id, current_id, limit)) - new_event_updates = txn.fetchall() + return txn.fetchall() - if len(new_event_updates) == limit: - upper_bound = new_event_updates[-1][0] - else: - upper_bound = current_id + return self.db.runInteraction( + "get_all_new_forward_event_rows", get_all_new_forward_event_rows + ) + + def get_ex_outlier_stream_rows(self, last_id, current_id): + """Returns de-outliered events, for the Events replication stream + Args: + last_id: the last stream_id from the previous batch. + current_id: the maximum stream_id to return up to + + Returns: Deferred[List[Tuple]] + a list of events stream rows. Each tuple consists of a stream id as + the first element, followed by fields suitable for casting into an + EventsStreamRow. + """ + + def get_ex_outlier_stream_rows_txn(txn): sql = ( "SELECT event_stream_ordering, e.event_id, e.room_id, e.type," " state_key, redacts, relates_to_id" @@ -1006,15 +1029,14 @@ class EventsWorkerStore(SQLBaseStore): " LEFT JOIN event_relations USING (event_id)" " WHERE ? < event_stream_ordering" " AND event_stream_ordering <= ?" - " ORDER BY event_stream_ordering DESC" + " ORDER BY event_stream_ordering ASC" ) - txn.execute(sql, (last_id, upper_bound)) - new_event_updates.extend(txn) - return new_event_updates + txn.execute(sql, (last_id, current_id)) + return txn.fetchall() return self.db.runInteraction( - "get_all_new_forward_event_rows", get_all_new_forward_event_rows + "get_ex_outlier_stream_rows", get_ex_outlier_stream_rows_txn ) def get_all_new_backfill_event_rows(self, last_id, current_id, limit): -- cgit 1.5.1 From c2e1a2110fbe9ead26b4ecbb1afd504ed035a04d Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 29 Apr 2020 12:30:36 +0100 Subject: Fix limit logic for EventsStream (#7358) * Factor out functions for injecting events into database I want to add some more flexibility to the tools for injecting events into the database, and I don't want to clutter up HomeserverTestCase with them, so let's factor them out to a new file. * Rework TestReplicationDataHandler This wasn't very easy to work with: the mock wrapping was largely superfluous, and it's useful to be able to inspect the received rows, and clear out the received list. 
* Fix AssertionErrors being thrown by EventsStream Part of the problem was that there was an off-by-one error in the assertion, but also the limit logic was too simple. Fix it all up and add some tests. --- changelog.d/7358.bugfix | 1 + synapse/replication/tcp/handler.py | 4 +- synapse/replication/tcp/streams/events.py | 22 +- synapse/server.pyi | 5 + synapse/storage/data_stores/main/events_worker.py | 64 +++- tests/replication/tcp/streams/_base.py | 41 ++- tests/replication/tcp/streams/test_events.py | 417 ++++++++++++++++++++++ tests/replication/tcp/streams/test_receipts.py | 10 +- tests/replication/tcp/streams/test_typing.py | 11 +- tests/rest/client/v1/utils.py | 2 +- tests/test_utils/__init__.py | 20 ++ tests/test_utils/event_injection.py | 96 +++++ tests/unittest.py | 30 +- tox.ini | 2 + 14 files changed, 658 insertions(+), 67 deletions(-) create mode 100644 changelog.d/7358.bugfix create mode 100644 tests/replication/tcp/streams/test_events.py create mode 100644 tests/test_utils/event_injection.py (limited to 'synapse/replication/tcp/streams/events.py') diff --git a/changelog.d/7358.bugfix b/changelog.d/7358.bugfix new file mode 100644 index 0000000000..f49c600173 --- /dev/null +++ b/changelog.d/7358.bugfix @@ -0,0 +1 @@ +Fix a bug where event updates might not be sent over replication to worker processes after the stream falls behind. diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 0db5a3a24d..3a8c7c7e2d 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -87,7 +87,9 @@ class ReplicationCommandHandler: stream.NAME: stream(hs) for stream in STREAMS_MAP.values() } # type: Dict[str, Stream] - self._position_linearizer = Linearizer("replication_position") + self._position_linearizer = Linearizer( + "replication_position", clock=self._clock + ) # Map of stream to batched updates. See RdataCommand for info on how # batching works. diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index aa50492569..52df81b1bd 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -170,22 +170,16 @@ class EventsStream(Stream): limited = False upper_limit = current_token - # next up is the state delta table - - state_rows = await self._store.get_all_updated_current_state_deltas( + # next up is the state delta table. + ( + state_rows, + upper_limit, + state_rows_limited, + ) = await self._store.get_all_updated_current_state_deltas( from_token, upper_limit, target_row_count - ) # type: List[Tuple] - - # again, if we've hit the limit there, we'll need to limit the other sources - assert len(state_rows) < target_row_count - if len(state_rows) == target_row_count: - assert state_rows[-1][0] <= upper_limit - upper_limit = state_rows[-1][0] - limited = True + ) - # FIXME: is it a given that there is only one row per stream_id in the - # state_deltas table (so that we can be sure that we have got all of the - # rows for upper_limit)? + limited = limited or state_rows_limited # finally, fetch the ex-outliers rows. We assume there are few enough of these # not to bother with the limit. 
diff --git a/synapse/server.pyi b/synapse/server.pyi index f1a5717028..fc5886f762 100644 --- a/synapse/server.pyi +++ b/synapse/server.pyi @@ -25,6 +25,7 @@ import synapse.server_notices.server_notices_manager import synapse.server_notices.server_notices_sender import synapse.state import synapse.storage +from synapse.events.builder import EventBuilderFactory class HomeServer(object): @property @@ -121,3 +122,7 @@ class HomeServer(object): pass def get_instance_id(self) -> str: pass + def get_event_builder_factory(self) -> EventBuilderFactory: + pass + def get_storage(self) -> synapse.storage.Storage: + pass diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py index ce8be72bfe..73df6b33ba 100644 --- a/synapse/storage/data_stores/main/events_worker.py +++ b/synapse/storage/data_stores/main/events_worker.py @@ -19,7 +19,7 @@ import itertools import logging import threading from collections import namedtuple -from typing import List, Optional +from typing import List, Optional, Tuple from canonicaljson import json from constantly import NamedConstant, Names @@ -1084,7 +1084,28 @@ class EventsWorkerStore(SQLBaseStore): "get_all_new_backfill_event_rows", get_all_new_backfill_event_rows ) - def get_all_updated_current_state_deltas(self, from_token, to_token, limit): + async def get_all_updated_current_state_deltas( + self, from_token: int, to_token: int, target_row_count: int + ) -> Tuple[List[Tuple], int, bool]: + """Fetch updates from current_state_delta_stream + + Args: + from_token: The previous stream token. Updates from this stream id will + be excluded. + + to_token: The current stream token (ie the upper limit). Updates up to this + stream id will be included (modulo the 'limit' param) + + target_row_count: The number of rows to try to return. If more rows are + available, we will set 'limited' in the result. In the event of a large + batch, we may return more rows than this. + Returns: + A triplet `(updates, new_last_token, limited)`, where: + * `updates` is a list of database tuples. + * `new_last_token` is the new position in stream. + * `limited` is whether there are more updates to fetch. + """ + def get_all_updated_current_state_deltas_txn(txn): sql = """ SELECT stream_id, room_id, type, state_key, event_id @@ -1092,10 +1113,45 @@ class EventsWorkerStore(SQLBaseStore): WHERE ? < stream_id AND stream_id <= ? ORDER BY stream_id ASC LIMIT ? """ - txn.execute(sql, (from_token, to_token, limit)) + txn.execute(sql, (from_token, to_token, target_row_count)) return txn.fetchall() - return self.db.runInteraction( + def get_deltas_for_stream_id_txn(txn, stream_id): + sql = """ + SELECT stream_id, room_id, type, state_key, event_id + FROM current_state_delta_stream + WHERE stream_id = ? + """ + txn.execute(sql, [stream_id]) + return txn.fetchall() + + # we need to make sure that, for every stream id in the results, we get *all* + # the rows with that stream id. + + rows = await self.db.runInteraction( "get_all_updated_current_state_deltas", get_all_updated_current_state_deltas_txn, + ) # type: List[Tuple] + + # if we've got fewer rows than the limit, we're good + if len(rows) < target_row_count: + return rows, to_token, False + + # we hit the limit, so reduce the upper limit so that we exclude the stream id + # of the last row in the result. 
+ assert rows[-1][0] <= to_token + to_token = rows[-1][0] - 1 + + # search backwards through the list for the point to truncate + for idx in range(len(rows) - 1, 0, -1): + if rows[idx - 1][0] <= to_token: + return rows[:idx], to_token, True + + # bother. We didn't get a full set of changes for even a single + # stream id. let's run the query again, without a row limit, but for + # just one stream id. + to_token += 1 + rows = await self.db.runInteraction( + "get_deltas_for_stream_id", get_deltas_for_stream_id_txn, to_token ) + return rows, to_token, True diff --git a/tests/replication/tcp/streams/_base.py b/tests/replication/tcp/streams/_base.py index 82f15c64e0..83e16cfe3d 100644 --- a/tests/replication/tcp/streams/_base.py +++ b/tests/replication/tcp/streams/_base.py @@ -12,10 +12,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import logging -from typing import Optional -from mock import Mock +import logging +from typing import Any, Dict, List, Optional, Tuple import attr @@ -25,6 +24,7 @@ from twisted.web.http import HTTPChannel from synapse.app.generic_worker import GenericWorkerServer from synapse.http.site import SynapseRequest +from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.tcp.client import ReplicationDataHandler from synapse.replication.tcp.handler import ReplicationCommandHandler from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol @@ -65,9 +65,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): # databases objects are the same. self.worker_hs.get_datastore().db = hs.get_datastore().db - self.test_handler = Mock( - wraps=TestReplicationDataHandler(self.worker_hs.get_datastore()) - ) + self.test_handler = self._build_replication_data_handler() self.worker_hs.replication_data_handler = self.test_handler repl_handler = ReplicationCommandHandler(self.worker_hs) @@ -78,6 +76,9 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): self._client_transport = None self._server_transport = None + def _build_replication_data_handler(self): + return TestReplicationDataHandler(self.worker_hs.get_datastore()) + def reconnect(self): if self._client_transport: self.client.close() @@ -174,22 +175,28 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): class TestReplicationDataHandler(ReplicationDataHandler): """Drop-in for ReplicationDataHandler which just collects RDATA rows""" - def __init__(self, hs): - super().__init__(hs) - self.streams = set() - self._received_rdata_rows = [] + def __init__(self, store: BaseSlavedStore): + super().__init__(store) + + # streams to subscribe to: map from stream id to position + self.stream_positions = {} # type: Dict[str, int] + + # list of received (stream_name, token, row) tuples + self.received_rdata_rows = [] # type: List[Tuple[str, int, Any]] def get_streams_to_replicate(self): - positions = {s: 0 for s in self.streams} - for stream, token, _ in self._received_rdata_rows: - if stream in self.streams: - positions[stream] = max(token, positions.get(stream, 0)) - return positions + return self.stream_positions async def on_rdata(self, stream_name, token, rows): await super().on_rdata(stream_name, token, rows) for r in rows: - self._received_rdata_rows.append((stream_name, token, r)) + self.received_rdata_rows.append((stream_name, token, r)) + + if ( + stream_name in self.stream_positions + and token > self.stream_positions[stream_name] + 
): + self.stream_positions[stream_name] = token @attr.s() @@ -221,7 +228,7 @@ class _PushHTTPChannel(HTTPChannel): super().__init__() self.reactor = reactor - self._pull_to_push_producer = None + self._pull_to_push_producer = None # type: Optional[_PullToPushProducer] def registerProducer(self, producer, streaming): # Convert pull producers to push producer. diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py new file mode 100644 index 0000000000..1fa28084f9 --- /dev/null +++ b/tests/replication/tcp/streams/test_events.py @@ -0,0 +1,417 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional + +from synapse.api.constants import EventTypes, Membership +from synapse.events import EventBase +from synapse.replication.tcp.streams._base import _STREAM_UPDATE_TARGET_ROW_COUNT +from synapse.replication.tcp.streams.events import ( + EventsStreamCurrentStateRow, + EventsStreamEventRow, + EventsStreamRow, +) +from synapse.rest import admin +from synapse.rest.client.v1 import login, room + +from tests.replication.tcp.streams._base import BaseStreamTestCase +from tests.test_utils.event_injection import inject_event, inject_member_event + + +class EventsStreamTestCase(BaseStreamTestCase): + servlets = [ + admin.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + super().prepare(reactor, clock, hs) + self.user_id = self.register_user("u1", "pass") + self.user_tok = self.login("u1", "pass") + + self.reconnect() + self.test_handler.stream_positions["events"] = 0 + + self.room_id = self.helper.create_room_as(tok=self.user_tok) + self.test_handler.received_rdata_rows.clear() + + def test_update_function_event_row_limit(self): + """Test replication with many non-state events + + Checks that all events are correctly replicated when there are lots of + event rows to be replicated. + """ + # disconnect, so that we can stack up some changes + self.disconnect() + + # generate lots of non-state events. We inject them using inject_event + # so that they are not send out over replication until we call self.replicate(). + events = [ + self._inject_test_event() + for _ in range(_STREAM_UPDATE_TARGET_ROW_COUNT + 1) + ] + + # also one state event + state_event = self._inject_state_event() + + # tell the notifier to catch up to avoid duplicate rows. 
+ # workaround for https://github.com/matrix-org/synapse/issues/7360 + # FIXME remove this when the above is fixed + self.replicate() + + # check we're testing what we think we are: no rows should yet have been + # received + self.assertEqual([], self.test_handler.received_rdata_rows) + + # now reconnect to pull the updates + self.reconnect() + self.replicate() + + # we should have received all the expected rows in the right order + received_rows = self.test_handler.received_rdata_rows + for event in events: + stream_name, token, row = received_rows.pop(0) + self.assertEqual("events", stream_name) + self.assertIsInstance(row, EventsStreamRow) + self.assertEqual(row.type, "ev") + self.assertIsInstance(row.data, EventsStreamEventRow) + self.assertEqual(row.data.event_id, event.event_id) + + stream_name, token, row = received_rows.pop(0) + self.assertIsInstance(row, EventsStreamRow) + self.assertIsInstance(row.data, EventsStreamEventRow) + self.assertEqual(row.data.event_id, state_event.event_id) + + stream_name, token, row = received_rows.pop(0) + self.assertEqual("events", stream_name) + self.assertIsInstance(row, EventsStreamRow) + self.assertEqual(row.type, "state") + self.assertIsInstance(row.data, EventsStreamCurrentStateRow) + self.assertEqual(row.data.event_id, state_event.event_id) + + self.assertEqual([], received_rows) + + def test_update_function_huge_state_change(self): + """Test replication with many state events + + Ensures that all events are correctly replicated when there are lots of + state change rows to be replicated. + """ + + # we want to generate lots of state changes at a single stream ID. + # + # We do this by having two branches in the DAG. On one, we have a moderator + # which that generates lots of state; on the other, we de-op the moderator, + # thus invalidating all the state. + + OTHER_USER = "@other_user:localhost" + + # have the user join + inject_member_event(self.hs, self.room_id, OTHER_USER, Membership.JOIN) + + # Update existing power levels with mod at PL50 + pls = self.helper.get_state( + self.room_id, EventTypes.PowerLevels, tok=self.user_tok + ) + pls["users"][OTHER_USER] = 50 + self.helper.send_state( + self.room_id, EventTypes.PowerLevels, pls, tok=self.user_tok, + ) + + # this is the point in the DAG where we make a fork + fork_point = self.get_success( + self.hs.get_datastore().get_latest_event_ids_in_room(self.room_id) + ) # type: List[str] + + events = [ + self._inject_state_event(sender=OTHER_USER) + for _ in range(_STREAM_UPDATE_TARGET_ROW_COUNT) + ] + + self.replicate() + # all those events and state changes should have landed + self.assertGreaterEqual( + len(self.test_handler.received_rdata_rows), 2 * len(events) + ) + + # disconnect, so that we can stack up the changes + self.disconnect() + self.test_handler.received_rdata_rows.clear() + + # a state event which doesn't get rolled back, to check that the state + # before the huge update comes through ok + state1 = self._inject_state_event() + + # roll back all the state by de-modding the user + prev_events = fork_point + pls["users"][OTHER_USER] = 0 + pl_event = inject_event( + self.hs, + prev_event_ids=prev_events, + type=EventTypes.PowerLevels, + state_key="", + sender=self.user_id, + room_id=self.room_id, + content=pls, + ) + + # one more bit of state that doesn't get rolled back + state2 = self._inject_state_event() + + # tell the notifier to catch up to avoid duplicate rows. 
+ # workaround for https://github.com/matrix-org/synapse/issues/7360 + # FIXME remove this when the above is fixed + self.replicate() + + # check we're testing what we think we are: no rows should yet have been + # received + self.assertEqual([], self.test_handler.received_rdata_rows) + + # now reconnect to pull the updates + self.reconnect() + self.replicate() + + # now we should have received all the expected rows in the right order. + # + # we expect: + # + # - two rows for state1 + # - the PL event row, plus state rows for the PL event and each + # of the states that got reverted. + # - two rows for state2 + + received_rows = self.test_handler.received_rdata_rows + + # first check the first two rows, which should be state1 + + stream_name, token, row = received_rows.pop(0) + self.assertEqual("events", stream_name) + self.assertIsInstance(row, EventsStreamRow) + self.assertEqual(row.type, "ev") + self.assertIsInstance(row.data, EventsStreamEventRow) + self.assertEqual(row.data.event_id, state1.event_id) + + stream_name, token, row = received_rows.pop(0) + self.assertIsInstance(row, EventsStreamRow) + self.assertEqual(row.type, "state") + self.assertIsInstance(row.data, EventsStreamCurrentStateRow) + self.assertEqual(row.data.event_id, state1.event_id) + + # now the last two rows, which should be state2 + stream_name, token, row = received_rows.pop(-2) + self.assertEqual("events", stream_name) + self.assertIsInstance(row, EventsStreamRow) + self.assertEqual(row.type, "ev") + self.assertIsInstance(row.data, EventsStreamEventRow) + self.assertEqual(row.data.event_id, state2.event_id) + + stream_name, token, row = received_rows.pop(-1) + self.assertIsInstance(row, EventsStreamRow) + self.assertEqual(row.type, "state") + self.assertIsInstance(row.data, EventsStreamCurrentStateRow) + self.assertEqual(row.data.event_id, state2.event_id) + + # that should leave us with the rows for the PL event + self.assertEqual(len(received_rows), len(events) + 2) + + stream_name, token, row = received_rows.pop(0) + self.assertEqual("events", stream_name) + self.assertIsInstance(row, EventsStreamRow) + self.assertEqual(row.type, "ev") + self.assertIsInstance(row.data, EventsStreamEventRow) + self.assertEqual(row.data.event_id, pl_event.event_id) + + # the state rows are unsorted + state_rows = [] # type: List[EventsStreamCurrentStateRow] + for stream_name, token, row in received_rows: + self.assertEqual("events", stream_name) + self.assertIsInstance(row, EventsStreamRow) + self.assertEqual(row.type, "state") + self.assertIsInstance(row.data, EventsStreamCurrentStateRow) + state_rows.append(row.data) + + state_rows.sort(key=lambda r: r.state_key) + + sr = state_rows.pop(0) + self.assertEqual(sr.type, EventTypes.PowerLevels) + self.assertEqual(sr.event_id, pl_event.event_id) + for sr in state_rows: + self.assertEqual(sr.type, "test_state_event") + # "None" indicates the state has been deleted + self.assertIsNone(sr.event_id) + + def test_update_function_state_row_limit(self): + """Test replication with many state events over several stream ids. + """ + + # we want to generate lots of state changes, but for this test, we want to + # spread out the state changes over a few stream IDs. + # + # We do this by having two branches in the DAG. On one, we have four moderators, + # each of which that generates lots of state; on the other, we de-op the users, + # thus invalidating all the state. 
+ + NUM_USERS = 4 + STATES_PER_USER = _STREAM_UPDATE_TARGET_ROW_COUNT // 4 + 1 + + user_ids = ["@user%i:localhost" % (i,) for i in range(NUM_USERS)] + + # have the users join + for u in user_ids: + inject_member_event(self.hs, self.room_id, u, Membership.JOIN) + + # Update existing power levels with mod at PL50 + pls = self.helper.get_state( + self.room_id, EventTypes.PowerLevels, tok=self.user_tok + ) + pls["users"].update({u: 50 for u in user_ids}) + self.helper.send_state( + self.room_id, EventTypes.PowerLevels, pls, tok=self.user_tok, + ) + + # this is the point in the DAG where we make a fork + fork_point = self.get_success( + self.hs.get_datastore().get_latest_event_ids_in_room(self.room_id) + ) # type: List[str] + + events = [] # type: List[EventBase] + for user in user_ids: + events.extend( + self._inject_state_event(sender=user) for _ in range(STATES_PER_USER) + ) + + self.replicate() + + # all those events and state changes should have landed + self.assertGreaterEqual( + len(self.test_handler.received_rdata_rows), 2 * len(events) + ) + + # disconnect, so that we can stack up the changes + self.disconnect() + self.test_handler.received_rdata_rows.clear() + + # now roll back all that state by de-modding the users + prev_events = fork_point + pl_events = [] + for u in user_ids: + pls["users"][u] = 0 + e = inject_event( + self.hs, + prev_event_ids=prev_events, + type=EventTypes.PowerLevels, + state_key="", + sender=self.user_id, + room_id=self.room_id, + content=pls, + ) + prev_events = [e.event_id] + pl_events.append(e) + + # tell the notifier to catch up to avoid duplicate rows. + # workaround for https://github.com/matrix-org/synapse/issues/7360 + # FIXME remove this when the above is fixed + self.replicate() + + # check we're testing what we think we are: no rows should yet have been + # received + self.assertEqual([], self.test_handler.received_rdata_rows) + + # now reconnect to pull the updates + self.reconnect() + self.replicate() + + # we should have received all the expected rows in the right order + + received_rows = self.test_handler.received_rdata_rows + self.assertGreaterEqual(len(received_rows), len(events)) + for i in range(NUM_USERS): + # for each user, we expect the PL event row, followed by state rows for + # the PL event and each of the states that got reverted. 
+ stream_name, token, row = received_rows.pop(0) + self.assertEqual("events", stream_name) + self.assertIsInstance(row, EventsStreamRow) + self.assertEqual(row.type, "ev") + self.assertIsInstance(row.data, EventsStreamEventRow) + self.assertEqual(row.data.event_id, pl_events[i].event_id) + + # the state rows are unsorted + state_rows = [] # type: List[EventsStreamCurrentStateRow] + for j in range(STATES_PER_USER + 1): + stream_name, token, row = received_rows.pop(0) + self.assertEqual("events", stream_name) + self.assertIsInstance(row, EventsStreamRow) + self.assertEqual(row.type, "state") + self.assertIsInstance(row.data, EventsStreamCurrentStateRow) + state_rows.append(row.data) + + state_rows.sort(key=lambda r: r.state_key) + + sr = state_rows.pop(0) + self.assertEqual(sr.type, EventTypes.PowerLevels) + self.assertEqual(sr.event_id, pl_events[i].event_id) + for sr in state_rows: + self.assertEqual(sr.type, "test_state_event") + # "None" indicates the state has been deleted + self.assertIsNone(sr.event_id) + + self.assertEqual([], received_rows) + + event_count = 0 + + def _inject_test_event( + self, body: Optional[str] = None, sender: Optional[str] = None, **kwargs + ) -> EventBase: + if sender is None: + sender = self.user_id + + if body is None: + body = "event %i" % (self.event_count,) + self.event_count += 1 + + return inject_event( + self.hs, + room_id=self.room_id, + sender=sender, + type="test_event", + content={"body": body}, + **kwargs + ) + + def _inject_state_event( + self, + body: Optional[str] = None, + state_key: Optional[str] = None, + sender: Optional[str] = None, + ) -> EventBase: + if sender is None: + sender = self.user_id + + if state_key is None: + state_key = "state_%i" % (self.event_count,) + self.event_count += 1 + + if body is None: + body = "state event %s" % (state_key,) + + return inject_event( + self.hs, + room_id=self.room_id, + sender=sender, + type="test_state_event", + state_key=state_key, + content={"body": body}, + ) diff --git a/tests/replication/tcp/streams/test_receipts.py b/tests/replication/tcp/streams/test_receipts.py index a0206f7363..c122b8589c 100644 --- a/tests/replication/tcp/streams/test_receipts.py +++ b/tests/replication/tcp/streams/test_receipts.py @@ -12,6 +12,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +# type: ignore + +from mock import Mock + from synapse.replication.tcp.streams._base import ReceiptsStream from tests.replication.tcp.streams._base import BaseStreamTestCase @@ -20,11 +25,14 @@ USER_ID = "@feeling:blue" class ReceiptsStreamTestCase(BaseStreamTestCase): + def _build_replication_data_handler(self): + return Mock(wraps=super()._build_replication_data_handler()) + def test_receipt(self): self.reconnect() # make the client subscribe to the receipts stream - self.test_handler.streams.add("receipts") + self.test_handler.stream_positions.update({"receipts": 0}) # tell the master to send a new receipt self.get_success( diff --git a/tests/replication/tcp/streams/test_typing.py b/tests/replication/tcp/streams/test_typing.py index f0ad6402ae..4d354a9db8 100644 --- a/tests/replication/tcp/streams/test_typing.py +++ b/tests/replication/tcp/streams/test_typing.py @@ -12,6 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from mock import Mock + from synapse.handlers.typing import RoomMember from synapse.replication.http import streams from synapse.replication.tcp.streams import TypingStream @@ -26,6 +28,9 @@ class TypingStreamTestCase(BaseStreamTestCase): streams.register_servlets, ] + def _build_replication_data_handler(self): + return Mock(wraps=super()._build_replication_data_handler()) + def test_typing(self): typing = self.hs.get_typing_handler() @@ -33,8 +38,8 @@ class TypingStreamTestCase(BaseStreamTestCase): self.reconnect() - # make the client subscribe to the receipts stream - self.test_handler.streams.add("typing") + # make the client subscribe to the typing stream + self.test_handler.stream_positions.update({"typing": 0}) typing._push_update(member=RoomMember(room_id, USER_ID), typing=True) @@ -75,6 +80,6 @@ class TypingStreamTestCase(BaseStreamTestCase): stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0] self.assertEqual(stream_name, "typing") self.assertEqual(1, len(rdata_rows)) - row = rdata_rows[0] # type: TypingStream.TypingStreamRow + row = rdata_rows[0] self.assertEqual(room_id, row.room_id) self.assertEqual([], row.user_ids) diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index 371637618d..22d734e763 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -39,7 +39,7 @@ class RestHelper(object): resource = attr.ib() auth_user_id = attr.ib() - def create_room_as(self, room_creator, is_public=True, tok=None): + def create_room_as(self, room_creator=None, is_public=True, tok=None): temp_id = self.auth_user_id self.auth_user_id = room_creator path = "/_matrix/client/r0/createRoom" diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py index a7310cf12a..7b345b03bb 100644 --- a/tests/test_utils/__init__.py +++ b/tests/test_utils/__init__.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2019 New Vector Ltd +# Copyright 2020 The Matrix.org Foundation C.I.C # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,3 +17,22 @@ """ Utilities for running the unit tests """ +from typing import Awaitable, TypeVar + +TV = TypeVar("TV") + + +def get_awaitable_result(awaitable: Awaitable[TV]) -> TV: + """Get the result from an Awaitable which should have completed + + Asserts that the given awaitable has a result ready, and returns its value + """ + i = awaitable.__await__() + try: + next(i) + except StopIteration as e: + # awaitable returned a result + return e.value + + # if next didn't raise, the awaitable hasn't completed. + raise Exception("awaitable has not yet completed") diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py new file mode 100644 index 0000000000..8f6872761a --- /dev/null +++ b/tests/test_utils/event_injection.py @@ -0,0 +1,96 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# Copyright 2020 The Matrix.org Foundation C.I.C +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import synapse.server +from synapse.api.constants import EventTypes +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS +from synapse.events import EventBase +from synapse.types import Collection + +from tests.test_utils import get_awaitable_result + + +""" +Utility functions for poking events into the storage of the server under test. +""" + + +def inject_member_event( + hs: synapse.server.HomeServer, + room_id: str, + sender: str, + membership: str, + target: Optional[str] = None, + extra_content: Optional[dict] = None, + **kwargs +) -> EventBase: + """Inject a membership event into a room.""" + if target is None: + target = sender + + content = {"membership": membership} + if extra_content: + content.update(extra_content) + + return inject_event( + hs, + room_id=room_id, + type=EventTypes.Member, + sender=sender, + state_key=target, + content=content, + **kwargs + ) + + +def inject_event( + hs: synapse.server.HomeServer, + room_version: Optional[str] = None, + prev_event_ids: Optional[Collection[str]] = None, + **kwargs +) -> EventBase: + """Inject a generic event into a room + + Args: + hs: the homeserver under test + room_version: the version of the room we're inserting into. + if not specified, will be looked up + prev_event_ids: prev_events for the event. If not specified, will be looked up + kwargs: fields for the event to be created + """ + test_reactor = hs.get_reactor() + + if room_version is None: + d = hs.get_datastore().get_room_version_id(kwargs["room_id"]) + test_reactor.advance(0) + room_version = get_awaitable_result(d) + + builder = hs.get_event_builder_factory().for_room_version( + KNOWN_ROOM_VERSIONS[room_version], kwargs + ) + d = hs.get_event_creation_handler().create_new_client_event( + builder, prev_event_ids=prev_event_ids + ) + test_reactor.advance(0) + event, context = get_awaitable_result(d) + + d = hs.get_storage().persistence.persist_event(event, context) + test_reactor.advance(0) + get_awaitable_result(d) + + return event diff --git a/tests/unittest.py b/tests/unittest.py index 27af5228fe..6b6f224e9c 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -32,7 +32,6 @@ from twisted.python.threadpool import ThreadPool from twisted.trial import unittest from synapse.api.constants import EventTypes, Membership -from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.config.homeserver import HomeServerConfig from synapse.config.ratelimiting import FederationRateLimitConfig from synapse.federation.transport import server as federation_server @@ -55,6 +54,7 @@ from tests.server import ( render, setup_test_homeserver, ) +from tests.test_utils import event_injection from tests.test_utils.logging_setup import setup_logging from tests.utils import default_config, setupdb @@ -596,36 +596,14 @@ class HomeserverTestCase(TestCase): """ Inject a membership event into a room. + Deprecated: use event_injection.inject_room_member directly + Args: room: Room ID to inject the event into. user: MXID of the user to inject the membership for. membership: The membership type. 
""" - event_builder_factory = self.hs.get_event_builder_factory() - event_creation_handler = self.hs.get_event_creation_handler() - - room_version = self.get_success( - self.hs.get_datastore().get_room_version_id(room) - ) - - builder = event_builder_factory.for_room_version( - KNOWN_ROOM_VERSIONS[room_version], - { - "type": EventTypes.Member, - "sender": user, - "state_key": user, - "room_id": room, - "content": {"membership": membership}, - }, - ) - - event, context = self.get_success( - event_creation_handler.create_new_client_event(builder) - ) - - self.get_success( - self.hs.get_storage().persistence.persist_event(event, context) - ) + event_injection.inject_member_event(self.hs, room, user, membership) class FederatingHomeserverTestCase(HomeserverTestCase): diff --git a/tox.ini b/tox.ini index 31011d7436..2630857436 100644 --- a/tox.ini +++ b/tox.ini @@ -204,6 +204,8 @@ commands = mypy \ synapse/storage/database.py \ synapse/streams \ synapse/util/caches/stream_change_cache.py \ + tests/replication/tcp/streams \ + tests/test_utils \ tests/util/test_stream_change_cache.py # To find all folders that pass mypy you run: -- cgit 1.5.1 From 0e719f23981b8294df66ba7f38b8c7cc99fad228 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 1 May 2020 17:19:56 +0100 Subject: Thread through instance name to replication client. (#7369) For in memory streams when fetching updates on workers we need to query the source of the stream, which currently is hard coded to be master. This PR threads through the source instance we received via `POSITION` through to the update function in each stream, which can then be passed to the replication client for in memory streams. --- changelog.d/7369.misc | 1 + synapse/app/generic_worker.py | 10 +++--- synapse/replication/http/_base.py | 19 +++++++++- synapse/replication/http/streams.py | 4 ++- synapse/replication/tcp/client.py | 12 ++++--- synapse/replication/tcp/handler.py | 20 ++++++++--- synapse/replication/tcp/streams/_base.py | 50 +++++++++++++++++++------- synapse/replication/tcp/streams/events.py | 10 ++++-- synapse/replication/tcp/streams/federation.py | 4 +-- tests/replication/tcp/streams/_base.py | 4 +-- tests/replication/tcp/streams/test_receipts.py | 4 +-- tests/replication/tcp/streams/test_typing.py | 4 +-- 12 files changed, 101 insertions(+), 41 deletions(-) create mode 100644 changelog.d/7369.misc (limited to 'synapse/replication/tcp/streams/events.py') diff --git a/changelog.d/7369.misc b/changelog.d/7369.misc new file mode 100644 index 0000000000..060b09c888 --- /dev/null +++ b/changelog.d/7369.misc @@ -0,0 +1 @@ +Thread through instance name to replication client. 
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 97b9b81237..667ad20428 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -646,13 +646,11 @@ class GenericWorkerReplicationHandler(ReplicationDataHandler): else: self.send_handler = None - async def on_rdata(self, stream_name, token, rows): - await super(GenericWorkerReplicationHandler, self).on_rdata( - stream_name, token, rows - ) - await self.process_and_notify(stream_name, token, rows) + async def on_rdata(self, stream_name, instance_name, token, rows): + await super().on_rdata(stream_name, instance_name, token, rows) + await self._process_and_notify(stream_name, instance_name, token, rows) - async def process_and_notify(self, stream_name, token, rows): + async def _process_and_notify(self, stream_name, instance_name, token, rows): try: if self.send_handler: await self.send_handler.process_replication_rows( diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index 1be1ccbdf3..f88c80ae84 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -16,6 +16,7 @@ import abc import logging import re +from inspect import signature from typing import Dict, List, Tuple from six import raise_from @@ -60,6 +61,8 @@ class ReplicationEndpoint(object): must call `register` to register the path with the HTTP server. Requests can be sent by calling the client returned by `make_client`. + Requests are sent to master process by default, but can be sent to other + named processes by specifying an `instance_name` keyword argument. Attributes: NAME (str): A name for the endpoint, added to the path as well as used @@ -91,6 +94,16 @@ class ReplicationEndpoint(object): hs, "repl." + self.NAME, timeout_ms=30 * 60 * 1000 ) + # We reserve `instance_name` as a parameter to sending requests, so we + # assert here that sub classes don't try and use the name. + assert ( + "instance_name" not in self.PATH_ARGS + ), "`instance_name` is a reserved paramater name" + assert ( + "instance_name" + not in signature(self.__class__._serialize_payload).parameters + ), "`instance_name` is a reserved paramater name" + assert self.METHOD in ("PUT", "POST", "GET") @abc.abstractmethod @@ -135,7 +148,11 @@ class ReplicationEndpoint(object): @trace(opname="outgoing_replication_request") @defer.inlineCallbacks - def send_request(**kwargs): + def send_request(instance_name="master", **kwargs): + # Currently we only support sending requests to master process. + if instance_name != "master": + raise Exception("Unknown instance") + data = yield cls._serialize_payload(**kwargs) url_args = [ diff --git a/synapse/replication/http/streams.py b/synapse/replication/http/streams.py index f35cebc710..0459f582bf 100644 --- a/synapse/replication/http/streams.py +++ b/synapse/replication/http/streams.py @@ -50,6 +50,8 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint): def __init__(self, hs): super().__init__(hs) + self._instance_name = hs.get_instance_name() + # We pull the streams from the replication steamer (if we try and make # them ourselves we end up in an import loop). 
self.streams = hs.get_replication_streamer().get_streams() @@ -67,7 +69,7 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint): upto_token = parse_integer(request, "upto_token", required=True) updates, upto_token, limited = await stream.get_updates_since( - from_token, upto_token + self._instance_name, from_token, upto_token ) return ( diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 5c28fd4ac3..3bbf3c3569 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -86,17 +86,19 @@ class ReplicationDataHandler: def __init__(self, store: BaseSlavedStore): self.store = store - async def on_rdata(self, stream_name: str, token: int, rows: list): + async def on_rdata( + self, stream_name: str, instance_name: str, token: int, rows: list + ): """Called to handle a batch of replication data with a given stream token. By default this just pokes the slave store. Can be overridden in subclasses to handle more. Args: - stream_name (str): name of the replication stream for this batch of rows - token (int): stream token for this batch of rows - rows (list): a list of Stream.ROW_TYPE objects as returned by - Stream.parse_row. + stream_name: name of the replication stream for this batch of rows + instance_name: the instance that wrote the rows. + token: stream token for this batch of rows + rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row. """ self.store.process_replication_rows(stream_name, token, rows) diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index d72f3d0cf9..2d1d119c7c 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -278,19 +278,24 @@ class ReplicationCommandHandler: # Check if this is the last of a batch of updates rows = self._pending_batches.pop(stream_name, []) rows.append(row) - await self.on_rdata(stream_name, cmd.token, rows) + await self.on_rdata(stream_name, cmd.instance_name, cmd.token, rows) - async def on_rdata(self, stream_name: str, token: int, rows: list): + async def on_rdata( + self, stream_name: str, instance_name: str, token: int, rows: list + ): """Called to handle a batch of replication data with a given stream token. Args: stream_name: name of the replication stream for this batch of rows + instance_name: the instance that wrote the rows. token: stream token for this batch of rows rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row. """ logger.debug("Received rdata %s -> %s", stream_name, token) - await self._replication_data_handler.on_rdata(stream_name, token, rows) + await self._replication_data_handler.on_rdata( + stream_name, instance_name, token, rows + ) async def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand): if cmd.instance_name == self._instance_name: @@ -325,7 +330,9 @@ class ReplicationCommandHandler: updates, current_token, missing_updates, - ) = await stream.get_updates_since(current_token, cmd.token) + ) = await stream.get_updates_since( + cmd.instance_name, current_token, cmd.token + ) # TODO: add some tests for this @@ -334,7 +341,10 @@ class ReplicationCommandHandler: for token, rows in _batch_updates(updates): await self.on_rdata( - cmd.stream_name, token, [stream.parse_row(row) for row in rows], + cmd.stream_name, + cmd.instance_name, + token, + [stream.parse_row(row) for row in rows], ) # We've now caught up to position sent to us, notify handler. 
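The client.py and handler.py diffs above widen `on_rdata` so that it also carries the name of the instance that wrote the rows. A minimal sketch of what a consumer of the new callback shape looks like (illustrative names only; apart from the `on_rdata` signature itself, nothing below is part of this patch):

import logging
from typing import Any, List

logger = logging.getLogger(__name__)


class ExampleReplicationDataHandler:
    """Illustrative stand-in for a ReplicationDataHandler subclass; only the
    on_rdata signature is taken from the patch, the rest is hypothetical."""

    async def on_rdata(
        self, stream_name: str, instance_name: str, token: int, rows: List[Any]
    ) -> None:
        # instance_name identifies the process that wrote the rows, so a
        # worker knows which instance to ask when it needs to catch up.
        logger.info(
            "%d row(s) for stream %r up to token %d, written by %s",
            len(rows),
            stream_name,
            token,
            instance_name,
        )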
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 4af1afd119..b0f87c365b 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -16,7 +16,7 @@ import logging from collections import namedtuple -from typing import Any, Awaitable, Callable, Iterable, List, Optional, Tuple +from typing import Any, Awaitable, Callable, List, Optional, Tuple import attr @@ -53,6 +53,7 @@ StreamUpdateResult = Tuple[List[Tuple[Token, StreamRow]], Token, bool] # # The arguments are: # +# * instance_name: the writer of the stream # * from_token: the previous stream token: the starting point for fetching the # updates # * to_token: the new stream token: the point to get updates up to @@ -62,7 +63,7 @@ StreamUpdateResult = Tuple[List[Tuple[Token, StreamRow]], Token, bool] # If there are more updates available, it should set `limited` in the result, and # it will be called again to get the next batch. # -UpdateFunction = Callable[[Token, Token, int], Awaitable[StreamUpdateResult]] +UpdateFunction = Callable[[str, Token, Token, int], Awaitable[StreamUpdateResult]] class Stream(object): @@ -93,6 +94,7 @@ class Stream(object): def __init__( self, + local_instance_name: str, current_token_function: Callable[[], Token], update_function: UpdateFunction, ): @@ -108,9 +110,11 @@ class Stream(object): stream tokens. See the UpdateFunction type definition for more info. Args: + local_instance_name: The instance name of the current process current_token_function: callback to get the current token, as above update_function: callback go get stream updates, as above """ + self.local_instance_name = local_instance_name self.current_token = current_token_function self.update_function = update_function @@ -135,14 +139,14 @@ class Stream(object): """ current_token = self.current_token() updates, current_token, limited = await self.get_updates_since( - self.last_token, current_token + self.local_instance_name, self.last_token, current_token ) self.last_token = current_token return updates, current_token, limited async def get_updates_since( - self, from_token: Token, upto_token: Token + self, instance_name: str, from_token: Token, upto_token: Token ) -> StreamUpdateResult: """Like get_updates except allows specifying from when we should stream updates @@ -160,19 +164,19 @@ class Stream(object): return [], upto_token, False updates, upto_token, limited = await self.update_function( - from_token, upto_token, _STREAM_UPDATE_TARGET_ROW_COUNT, + instance_name, from_token, upto_token, _STREAM_UPDATE_TARGET_ROW_COUNT, ) return updates, upto_token, limited def db_query_to_update_function( - query_function: Callable[[Token, Token, int], Awaitable[Iterable[tuple]]] + query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]] ) -> UpdateFunction: """Wraps a db query function which returns a list of rows to make it suitable for use as an `update_function` for the Stream class """ - async def update_function(from_token, upto_token, limit): + async def update_function(instance_name, from_token, upto_token, limit): rows = await query_function(from_token, upto_token, limit) updates = [(row[0], row[1:]) for row in rows] limited = False @@ -193,10 +197,13 @@ def make_http_update_function(hs, stream_name: str) -> UpdateFunction: client = ReplicationGetStreamUpdates.make_client(hs) async def update_function( - from_token: int, upto_token: int, limit: int + instance_name: str, from_token: int, upto_token: int, limit: int ) -> StreamUpdateResult: 
result = await client( - stream_name=stream_name, from_token=from_token, upto_token=upto_token, + instance_name=instance_name, + stream_name=stream_name, + from_token=from_token, + upto_token=upto_token, ) return result["updates"], result["upto_token"], result["limited"] @@ -226,6 +233,7 @@ class BackfillStream(Stream): def __init__(self, hs): store = hs.get_datastore() super().__init__( + hs.get_instance_name(), store.get_current_backfill_token, db_query_to_update_function(store.get_all_new_backfill_event_rows), ) @@ -261,7 +269,9 @@ class PresenceStream(Stream): # Query master process update_function = make_http_update_function(hs, self.NAME) - super().__init__(store.get_current_presence_token, update_function) + super().__init__( + hs.get_instance_name(), store.get_current_presence_token, update_function + ) class TypingStream(Stream): @@ -284,7 +294,9 @@ class TypingStream(Stream): # Query master process update_function = make_http_update_function(hs, self.NAME) - super().__init__(typing_handler.get_current_token, update_function) + super().__init__( + hs.get_instance_name(), typing_handler.get_current_token, update_function + ) class ReceiptsStream(Stream): @@ -305,6 +317,7 @@ class ReceiptsStream(Stream): def __init__(self, hs): store = hs.get_datastore() super().__init__( + hs.get_instance_name(), store.get_max_receipt_stream_id, db_query_to_update_function(store.get_all_updated_receipts), ) @@ -322,14 +335,16 @@ class PushRulesStream(Stream): def __init__(self, hs): self.store = hs.get_datastore() super(PushRulesStream, self).__init__( - self._current_token, self._update_function + hs.get_instance_name(), self._current_token, self._update_function ) def _current_token(self) -> int: push_rules_token, _ = self.store.get_push_rules_stream_token() return push_rules_token - async def _update_function(self, from_token: Token, to_token: Token, limit: int): + async def _update_function( + self, instance_name: str, from_token: Token, to_token: Token, limit: int + ): rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit) limited = False @@ -356,6 +371,7 @@ class PushersStream(Stream): store = hs.get_datastore() super().__init__( + hs.get_instance_name(), store.get_pushers_stream_token, db_query_to_update_function(store.get_all_updated_pushers_rows), ) @@ -387,6 +403,7 @@ class CachesStream(Stream): def __init__(self, hs): store = hs.get_datastore() super().__init__( + hs.get_instance_name(), store.get_cache_stream_token, db_query_to_update_function(store.get_all_updated_caches), ) @@ -412,6 +429,7 @@ class PublicRoomsStream(Stream): def __init__(self, hs): store = hs.get_datastore() super().__init__( + hs.get_instance_name(), store.get_current_public_room_stream_id, db_query_to_update_function(store.get_all_new_public_rooms), ) @@ -432,6 +450,7 @@ class DeviceListsStream(Stream): def __init__(self, hs): store = hs.get_datastore() super().__init__( + hs.get_instance_name(), store.get_device_stream_token, db_query_to_update_function(store.get_all_device_list_changes_for_remotes), ) @@ -449,6 +468,7 @@ class ToDeviceStream(Stream): def __init__(self, hs): store = hs.get_datastore() super().__init__( + hs.get_instance_name(), store.get_to_device_stream_token, db_query_to_update_function(store.get_all_new_device_messages), ) @@ -468,6 +488,7 @@ class TagAccountDataStream(Stream): def __init__(self, hs): store = hs.get_datastore() super().__init__( + hs.get_instance_name(), store.get_max_account_data_stream_id, db_query_to_update_function(store.get_all_updated_tags), ) @@ 
-487,6 +508,7 @@ class AccountDataStream(Stream): def __init__(self, hs): self.store = hs.get_datastore() super().__init__( + hs.get_instance_name(), self.store.get_max_account_data_stream_id, db_query_to_update_function(self._update_function), ) @@ -517,6 +539,7 @@ class GroupServerStream(Stream): def __init__(self, hs): store = hs.get_datastore() super().__init__( + hs.get_instance_name(), store.get_group_stream_token, db_query_to_update_function(store.get_all_groups_changes), ) @@ -534,6 +557,7 @@ class UserSignatureStream(Stream): def __init__(self, hs): store = hs.get_datastore() super().__init__( + hs.get_instance_name(), store.get_device_stream_token, db_query_to_update_function( store.get_all_user_signature_changes_for_remotes diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index 52df81b1bd..890e75d827 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -118,11 +118,17 @@ class EventsStream(Stream): def __init__(self, hs): self._store = hs.get_datastore() super().__init__( - self._store.get_current_events_token, self._update_function, + hs.get_instance_name(), + self._store.get_current_events_token, + self._update_function, ) async def _update_function( - self, from_token: Token, current_token: Token, target_row_count: int + self, + instance_name: str, + from_token: Token, + current_token: Token, + target_row_count: int, ) -> StreamUpdateResult: # the events stream merges together three separate sources: diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py index 75133d7e40..e8bd52e389 100644 --- a/synapse/replication/tcp/streams/federation.py +++ b/synapse/replication/tcp/streams/federation.py @@ -48,8 +48,8 @@ class FederationStream(Stream): current_token = lambda: 0 update_function = self._stub_update_function - super().__init__(current_token, update_function) + super().__init__(hs.get_instance_name(), current_token, update_function) @staticmethod - async def _stub_update_function(from_token, upto_token, limit): + async def _stub_update_function(instance_name, from_token, upto_token, limit): return [], upto_token, False diff --git a/tests/replication/tcp/streams/_base.py b/tests/replication/tcp/streams/_base.py index 8c104f8d1d..7b56d2028d 100644 --- a/tests/replication/tcp/streams/_base.py +++ b/tests/replication/tcp/streams/_base.py @@ -183,8 +183,8 @@ class TestReplicationDataHandler(GenericWorkerReplicationHandler): # list of received (stream_name, token, row) tuples self.received_rdata_rows = [] # type: List[Tuple[str, int, Any]] - async def on_rdata(self, stream_name, token, rows): - await super().on_rdata(stream_name, token, rows) + async def on_rdata(self, stream_name, instance_name, token, rows): + await super().on_rdata(stream_name, instance_name, token, rows) for r in rows: self.received_rdata_rows.append((stream_name, token, r)) diff --git a/tests/replication/tcp/streams/test_receipts.py b/tests/replication/tcp/streams/test_receipts.py index df332ee679..5853314fd4 100644 --- a/tests/replication/tcp/streams/test_receipts.py +++ b/tests/replication/tcp/streams/test_receipts.py @@ -41,7 +41,7 @@ class ReceiptsStreamTestCase(BaseStreamTestCase): # there should be one RDATA command self.test_handler.on_rdata.assert_called_once() - stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0] + stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0] self.assertEqual(stream_name, 
"receipts") self.assertEqual(1, len(rdata_rows)) row = rdata_rows[0] # type: ReceiptsStream.ReceiptsStreamRow @@ -71,7 +71,7 @@ class ReceiptsStreamTestCase(BaseStreamTestCase): # We should now have caught up and get the missing data self.test_handler.on_rdata.assert_called_once() - stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0] + stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0] self.assertEqual(stream_name, "receipts") self.assertEqual(token, 3) self.assertEqual(1, len(rdata_rows)) diff --git a/tests/replication/tcp/streams/test_typing.py b/tests/replication/tcp/streams/test_typing.py index e8d17ca68a..d25a7b194e 100644 --- a/tests/replication/tcp/streams/test_typing.py +++ b/tests/replication/tcp/streams/test_typing.py @@ -47,7 +47,7 @@ class TypingStreamTestCase(BaseStreamTestCase): self.assert_request_is_get_repl_stream_updates(request, "typing") self.test_handler.on_rdata.assert_called_once() - stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0] + stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0] self.assertEqual(stream_name, "typing") self.assertEqual(1, len(rdata_rows)) row = rdata_rows[0] # type: TypingStream.TypingStreamRow @@ -74,7 +74,7 @@ class TypingStreamTestCase(BaseStreamTestCase): self.assertEqual(int(request.args[b"from_token"][0]), token) self.test_handler.on_rdata.assert_called_once() - stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0] + stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0] self.assertEqual(stream_name, "typing") self.assertEqual(1, len(rdata_rows)) row = rdata_rows[0] -- cgit 1.5.1