diff options
Diffstat (limited to 'synapse/replication/tcp/streams/_base.py')
-rw-r--r-- | synapse/replication/tcp/streams/_base.py | 629 |
1 files changed, 419 insertions, 210 deletions
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index f03111c259..4acefc8a96 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -14,102 +14,84 @@ # See the License for the specific language governing permissions and # limitations under the License. - -import itertools +import heapq import logging from collections import namedtuple +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + List, + Optional, + Tuple, + TypeVar, +) + +import attr + +from synapse.replication.http.streams import ReplicationGetStreamUpdates -from twisted.internet import defer +if TYPE_CHECKING: + import synapse.server logger = logging.getLogger(__name__) +# the number of rows to request from an update_function. +_STREAM_UPDATE_TARGET_ROW_COUNT = 100 -MAX_EVENTS_BEHIND = 10000 -BackfillStreamRow = namedtuple( - "BackfillStreamRow", - ( - "event_id", # str - "room_id", # str - "type", # str - "state_key", # str, optional - "redacts", # str, optional - "relates_to", # str, optional - ), -) -PresenceStreamRow = namedtuple( - "PresenceStreamRow", - ( - "user_id", # str - "state", # str - "last_active_ts", # int - "last_federation_update_ts", # int - "last_user_sync_ts", # int - "status_msg", # str - "currently_active", # bool - ), -) -TypingStreamRow = namedtuple( - "TypingStreamRow", ("room_id", "user_ids") # str # list(str) -) -ReceiptsStreamRow = namedtuple( - "ReceiptsStreamRow", - ( - "room_id", # str - "receipt_type", # str - "user_id", # str - "event_id", # str - "data", # dict - ), -) -PushRulesStreamRow = namedtuple("PushRulesStreamRow", ("user_id",)) # str -PushersStreamRow = namedtuple( - "PushersStreamRow", - ("user_id", "app_id", "pushkey", "deleted"), # str # str # str # bool -) -CachesStreamRow = namedtuple( - "CachesStreamRow", - ("cache_func", "keys", "invalidation_ts"), # str # list(str) # int -) -PublicRoomsStreamRow = namedtuple( - "PublicRoomsStreamRow", - ( - "room_id", # str - "visibility", # str - "appservice_id", # str, optional - "network_id", # str, optional - ), -) -DeviceListsStreamRow = namedtuple( - "DeviceListsStreamRow", ("user_id", "destination") # str # str -) -ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", ("entity",)) # str -TagAccountDataStreamRow = namedtuple( - "TagAccountDataStreamRow", ("user_id", "room_id", "data") # str # str # dict -) -AccountDataStreamRow = namedtuple( - "AccountDataStream", - ("user_id", "room_id", "data_type", "data"), # str # str # str # dict -) -GroupsStreamRow = namedtuple( - "GroupsStreamRow", - ("group_id", "user_id", "type", "content"), # str # str # str # dict -) +# Some type aliases to make things a bit easier. + +# A stream position token +Token = int + +# The type of a stream update row, after JSON deserialisation, but before +# parsing with Stream.parse_row (which turns it into a `ROW_TYPE`). Normally it's +# just a row from a database query, though this is dependent on the stream in question. +# +StreamRow = TypeVar("StreamRow", bound=Tuple) + +# The type returned by the update_function of a stream, as well as get_updates(), +# get_updates_since, etc. +# +# It consists of a triplet `(updates, new_last_token, limited)`, where: +# * `updates` is a list of `(token, row)` entries. +# * `new_last_token` is the new position in stream. +# * `limited` is whether there are more updates to fetch. +# +StreamUpdateResult = Tuple[List[Tuple[Token, StreamRow]], Token, bool] + +# The type of an update_function for a stream +# +# The arguments are: +# +# * instance_name: the writer of the stream +# * from_token: the previous stream token: the starting point for fetching the +# updates +# * to_token: the new stream token: the point to get updates up to +# * target_row_count: a target for the number of rows to be returned. +# +# The update_function is expected to return up to _approximately_ target_row_count rows. +# If there are more updates available, it should set `limited` in the result, and +# it will be called again to get the next batch. +# +UpdateFunction = Callable[[str, Token, Token, int], Awaitable[StreamUpdateResult]] class Stream(object): """Base class for the streams. Provides a `get_updates()` function that returns new updates since the last - time it was called up until the point `advance_current_token` was called. + time it was called. """ - NAME = None # The name of the stream - ROW_TYPE = None # The type of the row. Used by the default impl of parse_row. - _LIMITED = True # Whether the update function takes a limit + NAME = None # type: str # The name of the stream + # The type of the row. Used by the default impl of parse_row. + ROW_TYPE = None # type: Any @classmethod - def parse_row(cls, row): + def parse_row(cls, row: StreamRow): """Parse a row received over replication By default, assumes that the row data is an array object and passes its contents @@ -123,102 +105,138 @@ class Stream(object): """ return cls.ROW_TYPE(*row) - def __init__(self, hs): - # The token from which we last asked for updates - self.last_token = self.current_token() - - # The token that we will get updates up to - self.upto_token = self.current_token() + def __init__( + self, + local_instance_name: str, + current_token_function: Callable[[str], Token], + update_function: UpdateFunction, + ): + """Instantiate a Stream + + `current_token_function` and `update_function` are callbacks which + should be implemented by subclasses. + + `current_token_function` takes an instance name, which is a writer to + the stream, and returns the position in the stream of the writer (as + viewed from the current process). On the writer process this is where + the writer has successfully written up to, whereas on other processes + this is the position which we have received updates up to over + replication. (Note that most streams have a single writer and so their + implementations ignore the instance name passed in). + + `update_function` is called to get updates for this stream between a + pair of stream tokens. See the `UpdateFunction` type definition for more + info. - def advance_current_token(self): - """Updates `upto_token` to "now", which updates up until which point - get_updates[_since] will fetch rows till. + Args: + local_instance_name: The instance name of the current process + current_token_function: callback to get the current token, as above + update_function: callback go get stream updates, as above """ - self.upto_token = self.current_token() + self.local_instance_name = local_instance_name + self.current_token = current_token_function + self.update_function = update_function + + # The token from which we last asked for updates + self.last_token = self.current_token(self.local_instance_name) def discard_updates_and_advance(self): """Called when the stream should advance but the updates would be discarded, e.g. when there are no currently connected workers. """ - self.upto_token = self.current_token() - self.last_token = self.upto_token + self.last_token = self.current_token(self.local_instance_name) - @defer.inlineCallbacks - def get_updates(self): + async def get_updates(self) -> StreamUpdateResult: """Gets all updates since the last time this function was called (or - since the stream was constructed if it hadn't been called before), - until the `upto_token` + since the stream was constructed if it hadn't been called before). Returns: - Deferred[Tuple[List[Tuple[int, Any]], int]: - Resolves to a pair ``(updates, current_token)``, where ``updates`` is a - list of ``(token, row)`` entries. ``row`` will be json-serialised and - sent over the replication steam. + A triplet `(updates, new_last_token, limited)`, where `updates` is + a list of `(token, row)` entries, `new_last_token` is the new + position in stream, and `limited` is whether there are more updates + to fetch. """ - updates, current_token = yield self.get_updates_since(self.last_token) + current_token = self.current_token(self.local_instance_name) + updates, current_token, limited = await self.get_updates_since( + self.local_instance_name, self.last_token, current_token + ) self.last_token = current_token - return updates, current_token + return updates, current_token, limited - @defer.inlineCallbacks - def get_updates_since(self, from_token): + async def get_updates_since( + self, instance_name: str, from_token: Token, upto_token: Token + ) -> StreamUpdateResult: """Like get_updates except allows specifying from when we should stream updates Returns: - Deferred[Tuple[List[Tuple[int, Any]], int]: - Resolves to a pair ``(updates, current_token)``, where ``updates`` is a - list of ``(token, row)`` entries. ``row`` will be json-serialised and - sent over the replication steam. + A triplet `(updates, new_last_token, limited)`, where `updates` is + a list of `(token, row)` entries, `new_last_token` is the new + position in stream, and `limited` is whether there are more updates + to fetch. """ - if from_token in ("NOW", "now"): - return [], self.upto_token - - current_token = self.upto_token from_token = int(from_token) - if from_token == current_token: - return [], current_token + if from_token == upto_token: + return [], upto_token, False - if self._LIMITED: - rows = yield self.update_function( - from_token, current_token, limit=MAX_EVENTS_BEHIND + 1 - ) + updates, upto_token, limited = await self.update_function( + instance_name, from_token, upto_token, _STREAM_UPDATE_TARGET_ROW_COUNT, + ) + return updates, upto_token, limited - # never turn more than MAX_EVENTS_BEHIND + 1 into updates. - rows = itertools.islice(rows, MAX_EVENTS_BEHIND + 1) - else: - rows = yield self.update_function(from_token, current_token) +def current_token_without_instance( + current_token: Callable[[], int] +) -> Callable[[str], int]: + """Takes a current token callback function for a single writer stream + that doesn't take an instance name parameter and wraps it in a function that + does accept an instance name parameter but ignores it. + """ + return lambda instance_name: current_token() + + +def db_query_to_update_function( + query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]] +) -> UpdateFunction: + """Wraps a db query function which returns a list of rows to make it + suitable for use as an `update_function` for the Stream class + """ + + async def update_function(instance_name, from_token, upto_token, limit): + rows = await query_function(from_token, upto_token, limit) updates = [(row[0], row[1:]) for row in rows] + limited = False + if len(updates) >= limit: + upto_token = updates[-1][0] + limited = True - # check we didn't get more rows than the limit. - # doing it like this allows the update_function to be a generator. - if self._LIMITED and len(updates) >= MAX_EVENTS_BEHIND: - raise Exception("stream %s has fallen behind" % (self.NAME)) + return updates, upto_token, limited - return updates, current_token + return update_function - def current_token(self): - """Gets the current token of the underlying streams. Should be provided - by the sub classes - Returns: - int - """ - raise NotImplementedError() +def make_http_update_function(hs, stream_name: str) -> UpdateFunction: + """Makes a suitable function for use as an `update_function` that queries + the master process for updates. + """ - def update_function(self, from_token, current_token, limit=None): - """Get updates between from_token and to_token. If Stream._LIMITED is - True then limit is provided, otherwise it's not. + client = ReplicationGetStreamUpdates.make_client(hs) - Returns: - Deferred(list(tuple)): the first entry in the tuple is the token for - that update, and the rest of the tuple gets used to construct - a ``ROW_TYPE`` instance - """ - raise NotImplementedError() + async def update_function( + instance_name: str, from_token: int, upto_token: int, limit: int + ) -> StreamUpdateResult: + result = await client( + instance_name=instance_name, + stream_name=stream_name, + from_token=from_token, + upto_token=upto_token, + ) + return result["updates"], result["upto_token"], result["limited"] + + return update_function class BackfillStream(Stream): @@ -226,94 +244,170 @@ class BackfillStream(Stream): or it went from being an outlier to not. """ + BackfillStreamRow = namedtuple( + "BackfillStreamRow", + ( + "event_id", # str + "room_id", # str + "type", # str + "state_key", # str, optional + "redacts", # str, optional + "relates_to", # str, optional + ), + ) + NAME = "backfill" ROW_TYPE = BackfillStreamRow def __init__(self, hs): store = hs.get_datastore() - self.current_token = store.get_current_backfill_token - self.update_function = store.get_all_new_backfill_event_rows - - super(BackfillStream, self).__init__(hs) + super().__init__( + hs.get_instance_name(), + current_token_without_instance(store.get_current_backfill_token), + db_query_to_update_function(store.get_all_new_backfill_event_rows), + ) class PresenceStream(Stream): + PresenceStreamRow = namedtuple( + "PresenceStreamRow", + ( + "user_id", # str + "state", # str + "last_active_ts", # int + "last_federation_update_ts", # int + "last_user_sync_ts", # int + "status_msg", # str + "currently_active", # bool + ), + ) + NAME = "presence" - _LIMITED = False ROW_TYPE = PresenceStreamRow def __init__(self, hs): store = hs.get_datastore() - presence_handler = hs.get_presence_handler() - self.current_token = store.get_current_presence_token - self.update_function = presence_handler.get_all_presence_updates + if hs.config.worker_app is None: + # on the master, query the presence handler + presence_handler = hs.get_presence_handler() + update_function = db_query_to_update_function( + presence_handler.get_all_presence_updates + ) + else: + # Query master process + update_function = make_http_update_function(hs, self.NAME) - super(PresenceStream, self).__init__(hs) + super().__init__( + hs.get_instance_name(), + current_token_without_instance(store.get_current_presence_token), + update_function, + ) class TypingStream(Stream): + TypingStreamRow = namedtuple( + "TypingStreamRow", ("room_id", "user_ids") # str # list(str) + ) + NAME = "typing" - _LIMITED = False ROW_TYPE = TypingStreamRow def __init__(self, hs): typing_handler = hs.get_typing_handler() - self.current_token = typing_handler.get_current_token - self.update_function = typing_handler.get_all_typing_updates + if hs.config.worker_app is None: + # on the master, query the typing handler + update_function = db_query_to_update_function( + typing_handler.get_all_typing_updates + ) + else: + # Query master process + update_function = make_http_update_function(hs, self.NAME) - super(TypingStream, self).__init__(hs) + super().__init__( + hs.get_instance_name(), + current_token_without_instance(typing_handler.get_current_token), + update_function, + ) class ReceiptsStream(Stream): + ReceiptsStreamRow = namedtuple( + "ReceiptsStreamRow", + ( + "room_id", # str + "receipt_type", # str + "user_id", # str + "event_id", # str + "data", # dict + ), + ) + NAME = "receipts" ROW_TYPE = ReceiptsStreamRow def __init__(self, hs): store = hs.get_datastore() - - self.current_token = store.get_max_receipt_stream_id - self.update_function = store.get_all_updated_receipts - - super(ReceiptsStream, self).__init__(hs) + super().__init__( + hs.get_instance_name(), + current_token_without_instance(store.get_max_receipt_stream_id), + db_query_to_update_function(store.get_all_updated_receipts), + ) class PushRulesStream(Stream): """A user has changed their push rules """ + PushRulesStreamRow = namedtuple("PushRulesStreamRow", ("user_id",)) # str + NAME = "push_rules" ROW_TYPE = PushRulesStreamRow def __init__(self, hs): self.store = hs.get_datastore() - super(PushRulesStream, self).__init__(hs) + super(PushRulesStream, self).__init__( + hs.get_instance_name(), self._current_token, self._update_function + ) - def current_token(self): + def _current_token(self, instance_name: str) -> int: push_rules_token, _ = self.store.get_push_rules_stream_token() return push_rules_token - @defer.inlineCallbacks - def update_function(self, from_token, to_token, limit): - rows = yield self.store.get_all_push_rule_updates(from_token, to_token, limit) - return [(row[0], row[2]) for row in rows] + async def _update_function( + self, instance_name: str, from_token: Token, to_token: Token, limit: int + ): + rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit) + + limited = False + if len(rows) == limit: + to_token = rows[-1][0] + limited = True + + return [(row[0], (row[2],)) for row in rows], to_token, limited class PushersStream(Stream): """A user has added/changed/removed a pusher """ + PushersStreamRow = namedtuple( + "PushersStreamRow", + ("user_id", "app_id", "pushkey", "deleted"), # str # str # str # bool + ) + NAME = "pushers" ROW_TYPE = PushersStreamRow def __init__(self, hs): store = hs.get_datastore() - self.current_token = store.get_pushers_stream_token - self.update_function = store.get_all_updated_pushers_rows - - super(PushersStream, self).__init__(hs) + super().__init__( + hs.get_instance_name(), + current_token_without_instance(store.get_pushers_stream_token), + db_query_to_update_function(store.get_all_updated_pushers_rows), + ) class CachesStream(Stream): @@ -321,120 +415,235 @@ class CachesStream(Stream): the cache on the workers """ + @attr.s + class CachesStreamRow: + """Stream to inform workers they should invalidate their cache. + + Attributes: + cache_func: Name of the cached function. + keys: The entry in the cache to invalidate. If None then will + invalidate all. + invalidation_ts: Timestamp of when the invalidation took place. + """ + + cache_func = attr.ib(type=str) + keys = attr.ib(type=Optional[List[Any]]) + invalidation_ts = attr.ib(type=int) + NAME = "caches" ROW_TYPE = CachesStreamRow def __init__(self, hs): - store = hs.get_datastore() + self.store = hs.get_datastore() + super().__init__( + hs.get_instance_name(), + self.store.get_cache_stream_token, + self._update_function, + ) - self.current_token = store.get_cache_stream_token - self.update_function = store.get_all_updated_caches + async def _update_function( + self, instance_name: str, from_token: int, upto_token: int, limit: int + ): + rows = await self.store.get_all_updated_caches( + instance_name, from_token, upto_token, limit + ) + updates = [(row[0], row[1:]) for row in rows] + limited = False + if len(updates) >= limit: + upto_token = updates[-1][0] + limited = True - super(CachesStream, self).__init__(hs) + return updates, upto_token, limited class PublicRoomsStream(Stream): """The public rooms list changed """ + PublicRoomsStreamRow = namedtuple( + "PublicRoomsStreamRow", + ( + "room_id", # str + "visibility", # str + "appservice_id", # str, optional + "network_id", # str, optional + ), + ) + NAME = "public_rooms" ROW_TYPE = PublicRoomsStreamRow def __init__(self, hs): store = hs.get_datastore() - - self.current_token = store.get_current_public_room_stream_id - self.update_function = store.get_all_new_public_rooms - - super(PublicRoomsStream, self).__init__(hs) + super().__init__( + hs.get_instance_name(), + current_token_without_instance(store.get_current_public_room_stream_id), + db_query_to_update_function(store.get_all_new_public_rooms), + ) class DeviceListsStream(Stream): - """Someone added/changed/removed a device + """Either a user has updated their devices or a remote server needs to be + told about a device update. """ + @attr.s + class DeviceListsStreamRow: + entity = attr.ib(type=str) + NAME = "device_lists" - _LIMITED = False ROW_TYPE = DeviceListsStreamRow def __init__(self, hs): store = hs.get_datastore() - - self.current_token = store.get_device_stream_token - self.update_function = store.get_all_device_list_changes_for_remotes - - super(DeviceListsStream, self).__init__(hs) + super().__init__( + hs.get_instance_name(), + current_token_without_instance(store.get_device_stream_token), + db_query_to_update_function(store.get_all_device_list_changes_for_remotes), + ) class ToDeviceStream(Stream): """New to_device messages for a client """ + ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", ("entity",)) # str + NAME = "to_device" ROW_TYPE = ToDeviceStreamRow def __init__(self, hs): store = hs.get_datastore() - - self.current_token = store.get_to_device_stream_token - self.update_function = store.get_all_new_device_messages - - super(ToDeviceStream, self).__init__(hs) + super().__init__( + hs.get_instance_name(), + current_token_without_instance(store.get_to_device_stream_token), + db_query_to_update_function(store.get_all_new_device_messages), + ) class TagAccountDataStream(Stream): """Someone added/removed a tag for a room """ + TagAccountDataStreamRow = namedtuple( + "TagAccountDataStreamRow", ("user_id", "room_id", "data") # str # str # dict + ) + NAME = "tag_account_data" ROW_TYPE = TagAccountDataStreamRow def __init__(self, hs): store = hs.get_datastore() - - self.current_token = store.get_max_account_data_stream_id - self.update_function = store.get_all_updated_tags - - super(TagAccountDataStream, self).__init__(hs) + super().__init__( + hs.get_instance_name(), + current_token_without_instance(store.get_max_account_data_stream_id), + db_query_to_update_function(store.get_all_updated_tags), + ) class AccountDataStream(Stream): """Global or per room account data was changed """ + AccountDataStreamRow = namedtuple( + "AccountDataStream", + ("user_id", "room_id", "data_type"), # str # Optional[str] # str + ) + NAME = "account_data" ROW_TYPE = AccountDataStreamRow - def __init__(self, hs): + def __init__(self, hs: "synapse.server.HomeServer"): self.store = hs.get_datastore() + super().__init__( + hs.get_instance_name(), + current_token_without_instance(self.store.get_max_account_data_stream_id), + self._update_function, + ) + + async def _update_function( + self, instance_name: str, from_token: int, to_token: int, limit: int + ) -> StreamUpdateResult: + limited = False + global_results = await self.store.get_updated_global_account_data( + from_token, to_token, limit + ) - self.current_token = self.store.get_max_account_data_stream_id + # if the global results hit the limit, we'll need to limit the room results to + # the same stream token. + if len(global_results) >= limit: + to_token = global_results[-1][0] + limited = True - super(AccountDataStream, self).__init__(hs) + room_results = await self.store.get_updated_room_account_data( + from_token, to_token, limit + ) - @defer.inlineCallbacks - def update_function(self, from_token, to_token, limit): - global_results, room_results = yield self.store.get_all_updated_account_data( - from_token, from_token, to_token, limit + # likewise, if the room results hit the limit, limit the global results to + # the same stream token. + if len(room_results) >= limit: + to_token = room_results[-1][0] + limited = True + + # convert the global results to the right format, and limit them to the to_token + # at the same time + global_rows = ( + (stream_id, (user_id, None, account_data_type)) + for stream_id, user_id, account_data_type in global_results + if stream_id <= to_token ) - results = list(room_results) - results.extend( - (stream_id, user_id, None, account_data_type, content) - for stream_id, user_id, account_data_type, content in global_results + # we know that the room_results are already limited to `to_token` so no need + # for a check on `stream_id` here. + room_rows = ( + (stream_id, (user_id, room_id, account_data_type)) + for stream_id, user_id, room_id, account_data_type in room_results ) - return results + # We need to return a sorted list, so merge them together. + # + # Note: We order only by the stream ID to work around a bug where the + # same stream ID could appear in both `global_rows` and `room_rows`, + # leading to a comparison between the data tuples. The comparison could + # fail due to attempting to compare the `room_id` which results in a + # `TypeError` from comparing a `str` vs `None`. + updates = list(heapq.merge(room_rows, global_rows, key=lambda row: row[0])) + return updates, to_token, limited class GroupServerStream(Stream): + GroupsStreamRow = namedtuple( + "GroupsStreamRow", + ("group_id", "user_id", "type", "content"), # str # str # str # dict + ) + NAME = "groups" ROW_TYPE = GroupsStreamRow def __init__(self, hs): store = hs.get_datastore() + super().__init__( + hs.get_instance_name(), + current_token_without_instance(store.get_group_stream_token), + db_query_to_update_function(store.get_all_groups_changes), + ) - self.current_token = store.get_group_stream_token - self.update_function = store.get_all_groups_changes - super(GroupServerStream, self).__init__(hs) +class UserSignatureStream(Stream): + """A user has signed their own device with their user-signing key + """ + + UserSignatureStreamRow = namedtuple("UserSignatureStreamRow", ("user_id")) # str + + NAME = "user_signature" + ROW_TYPE = UserSignatureStreamRow + + def __init__(self, hs): + store = hs.get_datastore() + super().__init__( + hs.get_instance_name(), + current_token_without_instance(store.get_device_stream_token), + db_query_to_update_function( + store.get_all_user_signature_changes_for_remotes + ), + ) |