diff options
Diffstat (limited to 'synapse/replication/tcp/streams')
-rw-r--r-- | synapse/replication/tcp/streams/__init__.py | 70 | ||||
-rw-r--r-- | synapse/replication/tcp/streams/_base.py | 629 | ||||
-rw-r--r-- | synapse/replication/tcp/streams/events.py | 133 | ||||
-rw-r--r-- | synapse/replication/tcp/streams/federation.py | 54 |
4 files changed, 623 insertions, 263 deletions
diff --git a/synapse/replication/tcp/streams/__init__.py b/synapse/replication/tcp/streams/__init__.py index 634f636dc9..d1a61c3314 100644 --- a/synapse/replication/tcp/streams/__init__.py +++ b/synapse/replication/tcp/streams/__init__.py @@ -25,25 +25,63 @@ Each stream is defined by the following information: update_function: The function that returns a list of updates between two tokens """ -from . import _base, events, federation +from synapse.replication.tcp.streams._base import ( + AccountDataStream, + BackfillStream, + CachesStream, + DeviceListsStream, + GroupServerStream, + PresenceStream, + PublicRoomsStream, + PushersStream, + PushRulesStream, + ReceiptsStream, + Stream, + TagAccountDataStream, + ToDeviceStream, + TypingStream, + UserSignatureStream, +) +from synapse.replication.tcp.streams.events import EventsStream +from synapse.replication.tcp.streams.federation import FederationStream STREAMS_MAP = { stream.NAME: stream for stream in ( - events.EventsStream, - _base.BackfillStream, - _base.PresenceStream, - _base.TypingStream, - _base.ReceiptsStream, - _base.PushRulesStream, - _base.PushersStream, - _base.CachesStream, - _base.PublicRoomsStream, - _base.DeviceListsStream, - _base.ToDeviceStream, - federation.FederationStream, - _base.TagAccountDataStream, - _base.AccountDataStream, - _base.GroupServerStream, + EventsStream, + BackfillStream, + PresenceStream, + TypingStream, + ReceiptsStream, + PushRulesStream, + PushersStream, + CachesStream, + PublicRoomsStream, + DeviceListsStream, + ToDeviceStream, + FederationStream, + TagAccountDataStream, + AccountDataStream, + GroupServerStream, + UserSignatureStream, ) } + +__all__ = [ + "STREAMS_MAP", + "Stream", + "BackfillStream", + "PresenceStream", + "TypingStream", + "ReceiptsStream", + "PushRulesStream", + "PushersStream", + "CachesStream", + "PublicRoomsStream", + "DeviceListsStream", + "ToDeviceStream", + "TagAccountDataStream", + "AccountDataStream", + "GroupServerStream", + "UserSignatureStream", +] diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index f03111c259..4acefc8a96 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -14,102 +14,84 @@ # See the License for the specific language governing permissions and # limitations under the License. - -import itertools +import heapq import logging from collections import namedtuple +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + List, + Optional, + Tuple, + TypeVar, +) + +import attr + +from synapse.replication.http.streams import ReplicationGetStreamUpdates -from twisted.internet import defer +if TYPE_CHECKING: + import synapse.server logger = logging.getLogger(__name__) +# the number of rows to request from an update_function. +_STREAM_UPDATE_TARGET_ROW_COUNT = 100 -MAX_EVENTS_BEHIND = 10000 -BackfillStreamRow = namedtuple( - "BackfillStreamRow", - ( - "event_id", # str - "room_id", # str - "type", # str - "state_key", # str, optional - "redacts", # str, optional - "relates_to", # str, optional - ), -) -PresenceStreamRow = namedtuple( - "PresenceStreamRow", - ( - "user_id", # str - "state", # str - "last_active_ts", # int - "last_federation_update_ts", # int - "last_user_sync_ts", # int - "status_msg", # str - "currently_active", # bool - ), -) -TypingStreamRow = namedtuple( - "TypingStreamRow", ("room_id", "user_ids") # str # list(str) -) -ReceiptsStreamRow = namedtuple( - "ReceiptsStreamRow", - ( - "room_id", # str - "receipt_type", # str - "user_id", # str - "event_id", # str - "data", # dict - ), -) -PushRulesStreamRow = namedtuple("PushRulesStreamRow", ("user_id",)) # str -PushersStreamRow = namedtuple( - "PushersStreamRow", - ("user_id", "app_id", "pushkey", "deleted"), # str # str # str # bool -) -CachesStreamRow = namedtuple( - "CachesStreamRow", - ("cache_func", "keys", "invalidation_ts"), # str # list(str) # int -) -PublicRoomsStreamRow = namedtuple( - "PublicRoomsStreamRow", - ( - "room_id", # str - "visibility", # str - "appservice_id", # str, optional - "network_id", # str, optional - ), -) -DeviceListsStreamRow = namedtuple( - "DeviceListsStreamRow", ("user_id", "destination") # str # str -) -ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", ("entity",)) # str -TagAccountDataStreamRow = namedtuple( - "TagAccountDataStreamRow", ("user_id", "room_id", "data") # str # str # dict -) -AccountDataStreamRow = namedtuple( - "AccountDataStream", - ("user_id", "room_id", "data_type", "data"), # str # str # str # dict -) -GroupsStreamRow = namedtuple( - "GroupsStreamRow", - ("group_id", "user_id", "type", "content"), # str # str # str # dict -) +# Some type aliases to make things a bit easier. + +# A stream position token +Token = int + +# The type of a stream update row, after JSON deserialisation, but before +# parsing with Stream.parse_row (which turns it into a `ROW_TYPE`). Normally it's +# just a row from a database query, though this is dependent on the stream in question. +# +StreamRow = TypeVar("StreamRow", bound=Tuple) + +# The type returned by the update_function of a stream, as well as get_updates(), +# get_updates_since, etc. +# +# It consists of a triplet `(updates, new_last_token, limited)`, where: +# * `updates` is a list of `(token, row)` entries. +# * `new_last_token` is the new position in stream. +# * `limited` is whether there are more updates to fetch. +# +StreamUpdateResult = Tuple[List[Tuple[Token, StreamRow]], Token, bool] + +# The type of an update_function for a stream +# +# The arguments are: +# +# * instance_name: the writer of the stream +# * from_token: the previous stream token: the starting point for fetching the +# updates +# * to_token: the new stream token: the point to get updates up to +# * target_row_count: a target for the number of rows to be returned. +# +# The update_function is expected to return up to _approximately_ target_row_count rows. +# If there are more updates available, it should set `limited` in the result, and +# it will be called again to get the next batch. +# +UpdateFunction = Callable[[str, Token, Token, int], Awaitable[StreamUpdateResult]] class Stream(object): """Base class for the streams. Provides a `get_updates()` function that returns new updates since the last - time it was called up until the point `advance_current_token` was called. + time it was called. """ - NAME = None # The name of the stream - ROW_TYPE = None # The type of the row. Used by the default impl of parse_row. - _LIMITED = True # Whether the update function takes a limit + NAME = None # type: str # The name of the stream + # The type of the row. Used by the default impl of parse_row. + ROW_TYPE = None # type: Any @classmethod - def parse_row(cls, row): + def parse_row(cls, row: StreamRow): """Parse a row received over replication By default, assumes that the row data is an array object and passes its contents @@ -123,102 +105,138 @@ class Stream(object): """ return cls.ROW_TYPE(*row) - def __init__(self, hs): - # The token from which we last asked for updates - self.last_token = self.current_token() - - # The token that we will get updates up to - self.upto_token = self.current_token() + def __init__( + self, + local_instance_name: str, + current_token_function: Callable[[str], Token], + update_function: UpdateFunction, + ): + """Instantiate a Stream + + `current_token_function` and `update_function` are callbacks which + should be implemented by subclasses. + + `current_token_function` takes an instance name, which is a writer to + the stream, and returns the position in the stream of the writer (as + viewed from the current process). On the writer process this is where + the writer has successfully written up to, whereas on other processes + this is the position which we have received updates up to over + replication. (Note that most streams have a single writer and so their + implementations ignore the instance name passed in). + + `update_function` is called to get updates for this stream between a + pair of stream tokens. See the `UpdateFunction` type definition for more + info. - def advance_current_token(self): - """Updates `upto_token` to "now", which updates up until which point - get_updates[_since] will fetch rows till. + Args: + local_instance_name: The instance name of the current process + current_token_function: callback to get the current token, as above + update_function: callback go get stream updates, as above """ - self.upto_token = self.current_token() + self.local_instance_name = local_instance_name + self.current_token = current_token_function + self.update_function = update_function + + # The token from which we last asked for updates + self.last_token = self.current_token(self.local_instance_name) def discard_updates_and_advance(self): """Called when the stream should advance but the updates would be discarded, e.g. when there are no currently connected workers. """ - self.upto_token = self.current_token() - self.last_token = self.upto_token + self.last_token = self.current_token(self.local_instance_name) - @defer.inlineCallbacks - def get_updates(self): + async def get_updates(self) -> StreamUpdateResult: """Gets all updates since the last time this function was called (or - since the stream was constructed if it hadn't been called before), - until the `upto_token` + since the stream was constructed if it hadn't been called before). Returns: - Deferred[Tuple[List[Tuple[int, Any]], int]: - Resolves to a pair ``(updates, current_token)``, where ``updates`` is a - list of ``(token, row)`` entries. ``row`` will be json-serialised and - sent over the replication steam. + A triplet `(updates, new_last_token, limited)`, where `updates` is + a list of `(token, row)` entries, `new_last_token` is the new + position in stream, and `limited` is whether there are more updates + to fetch. """ - updates, current_token = yield self.get_updates_since(self.last_token) + current_token = self.current_token(self.local_instance_name) + updates, current_token, limited = await self.get_updates_since( + self.local_instance_name, self.last_token, current_token + ) self.last_token = current_token - return updates, current_token + return updates, current_token, limited - @defer.inlineCallbacks - def get_updates_since(self, from_token): + async def get_updates_since( + self, instance_name: str, from_token: Token, upto_token: Token + ) -> StreamUpdateResult: """Like get_updates except allows specifying from when we should stream updates Returns: - Deferred[Tuple[List[Tuple[int, Any]], int]: - Resolves to a pair ``(updates, current_token)``, where ``updates`` is a - list of ``(token, row)`` entries. ``row`` will be json-serialised and - sent over the replication steam. + A triplet `(updates, new_last_token, limited)`, where `updates` is + a list of `(token, row)` entries, `new_last_token` is the new + position in stream, and `limited` is whether there are more updates + to fetch. """ - if from_token in ("NOW", "now"): - return [], self.upto_token - - current_token = self.upto_token from_token = int(from_token) - if from_token == current_token: - return [], current_token + if from_token == upto_token: + return [], upto_token, False - if self._LIMITED: - rows = yield self.update_function( - from_token, current_token, limit=MAX_EVENTS_BEHIND + 1 - ) + updates, upto_token, limited = await self.update_function( + instance_name, from_token, upto_token, _STREAM_UPDATE_TARGET_ROW_COUNT, + ) + return updates, upto_token, limited - # never turn more than MAX_EVENTS_BEHIND + 1 into updates. - rows = itertools.islice(rows, MAX_EVENTS_BEHIND + 1) - else: - rows = yield self.update_function(from_token, current_token) +def current_token_without_instance( + current_token: Callable[[], int] +) -> Callable[[str], int]: + """Takes a current token callback function for a single writer stream + that doesn't take an instance name parameter and wraps it in a function that + does accept an instance name parameter but ignores it. + """ + return lambda instance_name: current_token() + + +def db_query_to_update_function( + query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]] +) -> UpdateFunction: + """Wraps a db query function which returns a list of rows to make it + suitable for use as an `update_function` for the Stream class + """ + + async def update_function(instance_name, from_token, upto_token, limit): + rows = await query_function(from_token, upto_token, limit) updates = [(row[0], row[1:]) for row in rows] + limited = False + if len(updates) >= limit: + upto_token = updates[-1][0] + limited = True - # check we didn't get more rows than the limit. - # doing it like this allows the update_function to be a generator. - if self._LIMITED and len(updates) >= MAX_EVENTS_BEHIND: - raise Exception("stream %s has fallen behind" % (self.NAME)) + return updates, upto_token, limited - return updates, current_token + return update_function - def current_token(self): - """Gets the current token of the underlying streams. Should be provided - by the sub classes - Returns: - int - """ - raise NotImplementedError() +def make_http_update_function(hs, stream_name: str) -> UpdateFunction: + """Makes a suitable function for use as an `update_function` that queries + the master process for updates. + """ - def update_function(self, from_token, current_token, limit=None): - """Get updates between from_token and to_token. If Stream._LIMITED is - True then limit is provided, otherwise it's not. + client = ReplicationGetStreamUpdates.make_client(hs) - Returns: - Deferred(list(tuple)): the first entry in the tuple is the token for - that update, and the rest of the tuple gets used to construct - a ``ROW_TYPE`` instance - """ - raise NotImplementedError() + async def update_function( + instance_name: str, from_token: int, upto_token: int, limit: int + ) -> StreamUpdateResult: + result = await client( + instance_name=instance_name, + stream_name=stream_name, + from_token=from_token, + upto_token=upto_token, + ) + return result["updates"], result["upto_token"], result["limited"] + + return update_function class BackfillStream(Stream): @@ -226,94 +244,170 @@ class BackfillStream(Stream): or it went from being an outlier to not. """ + BackfillStreamRow = namedtuple( + "BackfillStreamRow", + ( + "event_id", # str + "room_id", # str + "type", # str + "state_key", # str, optional + "redacts", # str, optional + "relates_to", # str, optional + ), + ) + NAME = "backfill" ROW_TYPE = BackfillStreamRow def __init__(self, hs): store = hs.get_datastore() - self.current_token = store.get_current_backfill_token - self.update_function = store.get_all_new_backfill_event_rows - - super(BackfillStream, self).__init__(hs) + super().__init__( + hs.get_instance_name(), + current_token_without_instance(store.get_current_backfill_token), + db_query_to_update_function(store.get_all_new_backfill_event_rows), + ) class PresenceStream(Stream): + PresenceStreamRow = namedtuple( + "PresenceStreamRow", + ( + "user_id", # str + "state", # str + "last_active_ts", # int + "last_federation_update_ts", # int + "last_user_sync_ts", # int + "status_msg", # str + "currently_active", # bool + ), + ) + NAME = "presence" - _LIMITED = False ROW_TYPE = PresenceStreamRow def __init__(self, hs): store = hs.get_datastore() - presence_handler = hs.get_presence_handler() - self.current_token = store.get_current_presence_token - self.update_function = presence_handler.get_all_presence_updates + if hs.config.worker_app is None: + # on the master, query the presence handler + presence_handler = hs.get_presence_handler() + update_function = db_query_to_update_function( + presence_handler.get_all_presence_updates + ) + else: + # Query master process + update_function = make_http_update_function(hs, self.NAME) - super(PresenceStream, self).__init__(hs) + super().__init__( + hs.get_instance_name(), + current_token_without_instance(store.get_current_presence_token), + update_function, + ) class TypingStream(Stream): + TypingStreamRow = namedtuple( + "TypingStreamRow", ("room_id", "user_ids") # str # list(str) + ) + NAME = "typing" - _LIMITED = False ROW_TYPE = TypingStreamRow def __init__(self, hs): typing_handler = hs.get_typing_handler() - self.current_token = typing_handler.get_current_token - self.update_function = typing_handler.get_all_typing_updates + if hs.config.worker_app is None: + # on the master, query the typing handler + update_function = db_query_to_update_function( + typing_handler.get_all_typing_updates + ) + else: + # Query master process + update_function = make_http_update_function(hs, self.NAME) - super(TypingStream, self).__init__(hs) + super().__init__( + hs.get_instance_name(), + current_token_without_instance(typing_handler.get_current_token), + update_function, + ) class ReceiptsStream(Stream): + ReceiptsStreamRow = namedtuple( + "ReceiptsStreamRow", + ( + "room_id", # str + "receipt_type", # str + "user_id", # str + "event_id", # str + "data", # dict + ), + ) + NAME = "receipts" ROW_TYPE = ReceiptsStreamRow def __init__(self, hs): store = hs.get_datastore() - - self.current_token = store.get_max_receipt_stream_id - self.update_function = store.get_all_updated_receipts - - super(ReceiptsStream, self).__init__(hs) + super().__init__( + hs.get_instance_name(), + current_token_without_instance(store.get_max_receipt_stream_id), + db_query_to_update_function(store.get_all_updated_receipts), + ) class PushRulesStream(Stream): """A user has changed their push rules """ + PushRulesStreamRow = namedtuple("PushRulesStreamRow", ("user_id",)) # str + NAME = "push_rules" ROW_TYPE = PushRulesStreamRow def __init__(self, hs): self.store = hs.get_datastore() - super(PushRulesStream, self).__init__(hs) + super(PushRulesStream, self).__init__( + hs.get_instance_name(), self._current_token, self._update_function + ) - def current_token(self): + def _current_token(self, instance_name: str) -> int: push_rules_token, _ = self.store.get_push_rules_stream_token() return push_rules_token - @defer.inlineCallbacks - def update_function(self, from_token, to_token, limit): - rows = yield self.store.get_all_push_rule_updates(from_token, to_token, limit) - return [(row[0], row[2]) for row in rows] + async def _update_function( + self, instance_name: str, from_token: Token, to_token: Token, limit: int + ): + rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit) + + limited = False + if len(rows) == limit: + to_token = rows[-1][0] + limited = True + + return [(row[0], (row[2],)) for row in rows], to_token, limited class PushersStream(Stream): """A user has added/changed/removed a pusher """ + PushersStreamRow = namedtuple( + "PushersStreamRow", + ("user_id", "app_id", "pushkey", "deleted"), # str # str # str # bool + ) + NAME = "pushers" ROW_TYPE = PushersStreamRow def __init__(self, hs): store = hs.get_datastore() - self.current_token = store.get_pushers_stream_token - self.update_function = store.get_all_updated_pushers_rows - - super(PushersStream, self).__init__(hs) + super().__init__( + hs.get_instance_name(), + current_token_without_instance(store.get_pushers_stream_token), + db_query_to_update_function(store.get_all_updated_pushers_rows), + ) class CachesStream(Stream): @@ -321,120 +415,235 @@ class CachesStream(Stream): the cache on the workers """ + @attr.s + class CachesStreamRow: + """Stream to inform workers they should invalidate their cache. + + Attributes: + cache_func: Name of the cached function. + keys: The entry in the cache to invalidate. If None then will + invalidate all. + invalidation_ts: Timestamp of when the invalidation took place. + """ + + cache_func = attr.ib(type=str) + keys = attr.ib(type=Optional[List[Any]]) + invalidation_ts = attr.ib(type=int) + NAME = "caches" ROW_TYPE = CachesStreamRow def __init__(self, hs): - store = hs.get_datastore() + self.store = hs.get_datastore() + super().__init__( + hs.get_instance_name(), + self.store.get_cache_stream_token, + self._update_function, + ) - self.current_token = store.get_cache_stream_token - self.update_function = store.get_all_updated_caches + async def _update_function( + self, instance_name: str, from_token: int, upto_token: int, limit: int + ): + rows = await self.store.get_all_updated_caches( + instance_name, from_token, upto_token, limit + ) + updates = [(row[0], row[1:]) for row in rows] + limited = False + if len(updates) >= limit: + upto_token = updates[-1][0] + limited = True - super(CachesStream, self).__init__(hs) + return updates, upto_token, limited class PublicRoomsStream(Stream): """The public rooms list changed """ + PublicRoomsStreamRow = namedtuple( + "PublicRoomsStreamRow", + ( + "room_id", # str + "visibility", # str + "appservice_id", # str, optional + "network_id", # str, optional + ), + ) + NAME = "public_rooms" ROW_TYPE = PublicRoomsStreamRow def __init__(self, hs): store = hs.get_datastore() - - self.current_token = store.get_current_public_room_stream_id - self.update_function = store.get_all_new_public_rooms - - super(PublicRoomsStream, self).__init__(hs) + super().__init__( + hs.get_instance_name(), + current_token_without_instance(store.get_current_public_room_stream_id), + db_query_to_update_function(store.get_all_new_public_rooms), + ) class DeviceListsStream(Stream): - """Someone added/changed/removed a device + """Either a user has updated their devices or a remote server needs to be + told about a device update. """ + @attr.s + class DeviceListsStreamRow: + entity = attr.ib(type=str) + NAME = "device_lists" - _LIMITED = False ROW_TYPE = DeviceListsStreamRow def __init__(self, hs): store = hs.get_datastore() - - self.current_token = store.get_device_stream_token - self.update_function = store.get_all_device_list_changes_for_remotes - - super(DeviceListsStream, self).__init__(hs) + super().__init__( + hs.get_instance_name(), + current_token_without_instance(store.get_device_stream_token), + db_query_to_update_function(store.get_all_device_list_changes_for_remotes), + ) class ToDeviceStream(Stream): """New to_device messages for a client """ + ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", ("entity",)) # str + NAME = "to_device" ROW_TYPE = ToDeviceStreamRow def __init__(self, hs): store = hs.get_datastore() - - self.current_token = store.get_to_device_stream_token - self.update_function = store.get_all_new_device_messages - - super(ToDeviceStream, self).__init__(hs) + super().__init__( + hs.get_instance_name(), + current_token_without_instance(store.get_to_device_stream_token), + db_query_to_update_function(store.get_all_new_device_messages), + ) class TagAccountDataStream(Stream): """Someone added/removed a tag for a room """ + TagAccountDataStreamRow = namedtuple( + "TagAccountDataStreamRow", ("user_id", "room_id", "data") # str # str # dict + ) + NAME = "tag_account_data" ROW_TYPE = TagAccountDataStreamRow def __init__(self, hs): store = hs.get_datastore() - - self.current_token = store.get_max_account_data_stream_id - self.update_function = store.get_all_updated_tags - - super(TagAccountDataStream, self).__init__(hs) + super().__init__( + hs.get_instance_name(), + current_token_without_instance(store.get_max_account_data_stream_id), + db_query_to_update_function(store.get_all_updated_tags), + ) class AccountDataStream(Stream): """Global or per room account data was changed """ + AccountDataStreamRow = namedtuple( + "AccountDataStream", + ("user_id", "room_id", "data_type"), # str # Optional[str] # str + ) + NAME = "account_data" ROW_TYPE = AccountDataStreamRow - def __init__(self, hs): + def __init__(self, hs: "synapse.server.HomeServer"): self.store = hs.get_datastore() + super().__init__( + hs.get_instance_name(), + current_token_without_instance(self.store.get_max_account_data_stream_id), + self._update_function, + ) + + async def _update_function( + self, instance_name: str, from_token: int, to_token: int, limit: int + ) -> StreamUpdateResult: + limited = False + global_results = await self.store.get_updated_global_account_data( + from_token, to_token, limit + ) - self.current_token = self.store.get_max_account_data_stream_id + # if the global results hit the limit, we'll need to limit the room results to + # the same stream token. + if len(global_results) >= limit: + to_token = global_results[-1][0] + limited = True - super(AccountDataStream, self).__init__(hs) + room_results = await self.store.get_updated_room_account_data( + from_token, to_token, limit + ) - @defer.inlineCallbacks - def update_function(self, from_token, to_token, limit): - global_results, room_results = yield self.store.get_all_updated_account_data( - from_token, from_token, to_token, limit + # likewise, if the room results hit the limit, limit the global results to + # the same stream token. + if len(room_results) >= limit: + to_token = room_results[-1][0] + limited = True + + # convert the global results to the right format, and limit them to the to_token + # at the same time + global_rows = ( + (stream_id, (user_id, None, account_data_type)) + for stream_id, user_id, account_data_type in global_results + if stream_id <= to_token ) - results = list(room_results) - results.extend( - (stream_id, user_id, None, account_data_type, content) - for stream_id, user_id, account_data_type, content in global_results + # we know that the room_results are already limited to `to_token` so no need + # for a check on `stream_id` here. + room_rows = ( + (stream_id, (user_id, room_id, account_data_type)) + for stream_id, user_id, room_id, account_data_type in room_results ) - return results + # We need to return a sorted list, so merge them together. + # + # Note: We order only by the stream ID to work around a bug where the + # same stream ID could appear in both `global_rows` and `room_rows`, + # leading to a comparison between the data tuples. The comparison could + # fail due to attempting to compare the `room_id` which results in a + # `TypeError` from comparing a `str` vs `None`. + updates = list(heapq.merge(room_rows, global_rows, key=lambda row: row[0])) + return updates, to_token, limited class GroupServerStream(Stream): + GroupsStreamRow = namedtuple( + "GroupsStreamRow", + ("group_id", "user_id", "type", "content"), # str # str # str # dict + ) + NAME = "groups" ROW_TYPE = GroupsStreamRow def __init__(self, hs): store = hs.get_datastore() + super().__init__( + hs.get_instance_name(), + current_token_without_instance(store.get_group_stream_token), + db_query_to_update_function(store.get_all_groups_changes), + ) - self.current_token = store.get_group_stream_token - self.update_function = store.get_all_groups_changes - super(GroupServerStream, self).__init__(hs) +class UserSignatureStream(Stream): + """A user has signed their own device with their user-signing key + """ + + UserSignatureStreamRow = namedtuple("UserSignatureStreamRow", ("user_id")) # str + + NAME = "user_signature" + ROW_TYPE = UserSignatureStreamRow + + def __init__(self, hs): + store = hs.get_datastore() + super().__init__( + hs.get_instance_name(), + current_token_without_instance(store.get_device_stream_token), + db_query_to_update_function( + store.get_all_user_signature_changes_for_remotes + ), + ) diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index d97669c886..f370390331 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -13,13 +13,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import heapq +from collections import Iterable +from typing import List, Tuple, Type import attr -from twisted.internet import defer - -from ._base import Stream +from ._base import Stream, StreamUpdateResult, Token, current_token_without_instance """Handling of the 'events' replication stream @@ -63,7 +64,8 @@ class BaseEventsStreamRow(object): Specifies how to identify, serialize and deserialize the different types. """ - TypeId = None # Unique string that ids the type. Must be overriden in sub classes. + # Unique string that ids the type. Must be overriden in sub classes. + TypeId = None # type: str @classmethod def from_data(cls, data): @@ -99,9 +101,12 @@ class EventsStreamCurrentStateRow(BaseEventsStreamRow): event_id = attr.ib() # str, optional -TypeToRow = { - Row.TypeId: Row for Row in (EventsStreamEventRow, EventsStreamCurrentStateRow) -} +_EventRows = ( + EventsStreamEventRow, + EventsStreamCurrentStateRow, +) # type: Tuple[Type[BaseEventsStreamRow], ...] + +TypeToRow = {Row.TypeId: Row for Row in _EventRows} class EventsStream(Stream): @@ -112,29 +117,107 @@ class EventsStream(Stream): def __init__(self, hs): self._store = hs.get_datastore() - self.current_token = self._store.get_current_events_token - - super(EventsStream, self).__init__(hs) - - @defer.inlineCallbacks - def update_function(self, from_token, current_token, limit=None): - event_rows = yield self._store.get_all_new_forward_event_rows( - from_token, current_token, limit - ) - event_updates = ( - (row[0], EventsStreamEventRow.TypeId, row[1:]) for row in event_rows + super().__init__( + hs.get_instance_name(), + current_token_without_instance(self._store.get_current_events_token), + self._update_function, ) - state_rows = yield self._store.get_all_updated_current_state_deltas( - from_token, current_token, limit - ) - state_updates = ( - (row[0], EventsStreamCurrentStateRow.TypeId, row[1:]) for row in state_rows + async def _update_function( + self, + instance_name: str, + from_token: Token, + current_token: Token, + target_row_count: int, + ) -> StreamUpdateResult: + + # the events stream merges together three separate sources: + # * new events + # * current_state changes + # * events which were previously outliers, but have now been de-outliered. + # + # The merge operation is complicated by the fact that we only have a single + # "stream token" which is supposed to indicate how far we have got through + # all three streams. It's therefore no good to return rows 1-1000 from the + # "new events" table if the state_deltas are limited to rows 1-100 by the + # target_row_count. + # + # In other words: we must pick a new upper limit, and must return *all* rows + # up to that point for each of the three sources. + # + # Start by trying to split the target_row_count up. We expect to have a + # negligible number of ex-outliers, and a rough approximation based on recent + # traffic on sw1v.org shows that there are approximately the same number of + # event rows between a given pair of stream ids as there are state + # updates, so let's split our target_row_count among those two types. The target + # is only an approximation - it doesn't matter if we end up going a bit over it. + + target_row_count //= 2 + + # now we fetch up to that many rows from the events table + + event_rows = await self._store.get_all_new_forward_event_rows( + from_token, current_token, target_row_count + ) # type: List[Tuple] + + # we rely on get_all_new_forward_event_rows strictly honouring the limit, so + # that we know it is safe to just take upper_limit = event_rows[-1][0]. + assert ( + len(event_rows) <= target_row_count + ), "get_all_new_forward_event_rows did not honour row limit" + + # if we hit the limit on event_updates, there's no point in going beyond the + # last stream_id in the batch for the other sources. + + if len(event_rows) == target_row_count: + limited = True + upper_limit = event_rows[-1][0] # type: int + else: + limited = False + upper_limit = current_token + + # next up is the state delta table. + ( + state_rows, + upper_limit, + state_rows_limited, + ) = await self._store.get_all_updated_current_state_deltas( + from_token, upper_limit, target_row_count ) - all_updates = heapq.merge(event_updates, state_updates) + limited = limited or state_rows_limited + + # finally, fetch the ex-outliers rows. We assume there are few enough of these + # not to bother with the limit. + + ex_outliers_rows = await self._store.get_ex_outlier_stream_rows( + from_token, upper_limit + ) # type: List[Tuple] + + # we now need to turn the raw database rows returned into tuples suitable + # for the replication protocol (basically, we add an identifier to + # distinguish the row type). At the same time, we can limit the event_rows + # to the max stream_id from state_rows. - return all_updates + event_updates = ( + (stream_id, (EventsStreamEventRow.TypeId, rest)) + for (stream_id, *rest) in event_rows + if stream_id <= upper_limit + ) # type: Iterable[Tuple[int, Tuple]] + + state_updates = ( + (stream_id, (EventsStreamCurrentStateRow.TypeId, rest)) + for (stream_id, *rest) in state_rows + ) # type: Iterable[Tuple[int, Tuple]] + + ex_outliers_updates = ( + (stream_id, (EventsStreamEventRow.TypeId, rest)) + for (stream_id, *rest) in ex_outliers_rows + ) # type: Iterable[Tuple[int, Tuple]] + + # we need to return a sorted list, so merge them together. + updates = list(heapq.merge(event_updates, state_updates, ex_outliers_updates)) + return updates, upper_limit, limited @classmethod def parse_row(cls, row): diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py index dc2484109d..9bcd13b009 100644 --- a/synapse/replication/tcp/streams/federation.py +++ b/synapse/replication/tcp/streams/federation.py @@ -15,14 +15,10 @@ # limitations under the License. from collections import namedtuple -from ._base import Stream - -FederationStreamRow = namedtuple( - "FederationStreamRow", - ( - "type", # str, the type of data as defined in the BaseFederationRows - "data", # dict, serialization of a federation.send_queue.BaseFederationRow - ), +from synapse.replication.tcp.streams._base import ( + Stream, + current_token_without_instance, + make_http_update_function, ) @@ -31,13 +27,47 @@ class FederationStream(Stream): sending disabled. """ + FederationStreamRow = namedtuple( + "FederationStreamRow", + ( + "type", # str, the type of data as defined in the BaseFederationRows + "data", # dict, serialization of a federation.send_queue.BaseFederationRow + ), + ) + NAME = "federation" ROW_TYPE = FederationStreamRow def __init__(self, hs): - federation_sender = hs.get_federation_sender() + if hs.config.worker_app is None: + # master process: get updates from the FederationRemoteSendQueue. + # (if the master is configured to send federation itself, federation_sender + # will be a real FederationSender, which has stubs for current_token and + # get_replication_rows.) + federation_sender = hs.get_federation_sender() + current_token = current_token_without_instance( + federation_sender.get_current_token + ) + update_function = federation_sender.get_replication_rows + + elif hs.should_send_federation(): + # federation sender: Query master process + update_function = make_http_update_function(hs, self.NAME) + current_token = self._stub_current_token + + else: + # other worker: stub out the update function (we're not interested in + # any updates so when we get a POSITION we do nothing) + update_function = self._stub_update_function + current_token = self._stub_current_token + + super().__init__(hs.get_instance_name(), current_token, update_function) - self.current_token = federation_sender.get_current_token - self.update_function = federation_sender.get_replication_rows + @staticmethod + def _stub_current_token(instance_name: str) -> int: + # dummy current-token method for use on workers + return 0 - super(FederationStream, self).__init__(hs) + @staticmethod + async def _stub_update_function(instance_name, from_token, upto_token, limit): + return [], upto_token, False |