diff options
Diffstat (limited to 'synapse/replication/tcp/streams/_base.py')
-rw-r--r-- | synapse/replication/tcp/streams/_base.py | 99 |
1 files changed, 54 insertions, 45 deletions
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index d5b9c2831b..2699e466bc 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -16,7 +16,7 @@ import logging from collections import namedtuple -from typing import Any, List, Optional, Tuple, Union +from typing import Any, Awaitable, Callable, List, Optional, Tuple, Union import attr @@ -40,10 +40,6 @@ class Stream(object): # The type of the row. Used by the default impl of parse_row. ROW_TYPE = None # type: Any - # Whether the update function is only available on master. If True then - # calls to get updates are proxied to the master via a HTTP call. - _QUERY_MASTER = False - @classmethod def parse_row(cls, row): """Parse a row received over replication @@ -60,10 +56,6 @@ class Stream(object): return cls.ROW_TYPE(*row) def __init__(self, hs): - self._is_worker = hs.config.worker_app is not None - - if self._QUERY_MASTER and self._is_worker: - self._replication_client = ReplicationGetStreamUpdates.make_client(hs) # The token from which we last asked for updates self.last_token = self.current_token() @@ -105,31 +97,15 @@ class Stream(object): to fetch. """ - if from_token in ("NOW", "now"): - return [], upto_token, False - from_token = int(from_token) if from_token == upto_token: return [], upto_token, False - if self._is_worker and self._QUERY_MASTER: - result = await self._replication_client( - stream_name=self.NAME, - from_token=from_token, - upto_token=upto_token, - limit=limit, - ) - return result["updates"], result["upto_token"], result["limited"] - else: - limited = False - rows = await self.update_function(from_token, upto_token, limit=limit) - updates = [(row[0], row[1:]) for row in rows] - if len(updates) == limit: - upto_token = rows[-1][0] - limited = True - - return updates, upto_token, limited + updates, upto_token, limited = await self.update_function( + from_token, upto_token, limit=limit, + ) + return updates, upto_token, limited def current_token(self): """Gets the current token of the underlying streams. Should be provided @@ -151,6 +127,26 @@ class Stream(object): raise NotImplementedError() +def db_query_to_update_function( + query_function: Callable[[int, int, int], Awaitable[List[tuple]]] +) -> Callable[[int, int, int], Awaitable[Tuple[List[Tuple[int, tuple]], int, bool]]]: + """Wraps a db query function which returns a list of rows to make it + suitable for use as an `update_function` for the Stream class + """ + + async def update_function(from_token, upto_token, limit): + rows = await query_function(from_token, upto_token, limit) + updates = [(row[0], row[1:]) for row in rows] + limited = False + if len(updates) == limit: + upto_token = rows[-1][0] + limited = True + + return updates, upto_token, limited + + return update_function + + class BackfillStream(Stream): """We fetched some old events and either we had never seen that event before or it went from being an outlier to not. @@ -174,7 +170,7 @@ class BackfillStream(Stream): def __init__(self, hs): store = hs.get_datastore() self.current_token = store.get_current_backfill_token # type: ignore - self.update_function = store.get_all_new_backfill_event_rows # type: ignore + self.update_function = db_query_to_update_function(store.get_all_new_backfill_event_rows) # type: ignore super(BackfillStream, self).__init__(hs) @@ -195,16 +191,20 @@ class PresenceStream(Stream): NAME = "presence" ROW_TYPE = PresenceStreamRow - _QUERY_MASTER = True def __init__(self, hs): store = hs.get_datastore() presence_handler = hs.get_presence_handler() + self._is_worker = hs.config.worker_app is not None + self.current_token = store.get_current_presence_token # type: ignore if hs.config.worker_app is None: - self.update_function = presence_handler.get_all_presence_updates # type: ignore + self.update_function = db_query_to_update_function(presence_handler.get_all_presence_updates) # type: ignore + else: + # Query master process + self.update_function = ReplicationGetStreamUpdates.make_client(hs) # type: ignore super(PresenceStream, self).__init__(hs) @@ -216,7 +216,6 @@ class TypingStream(Stream): NAME = "typing" ROW_TYPE = TypingStreamRow - _QUERY_MASTER = True def __init__(self, hs): typing_handler = hs.get_typing_handler() @@ -224,7 +223,10 @@ class TypingStream(Stream): self.current_token = typing_handler.get_current_token # type: ignore if hs.config.worker_app is None: - self.update_function = typing_handler.get_all_typing_updates # type: ignore + self.update_function = db_query_to_update_function(typing_handler.get_all_typing_updates) # type: ignore + else: + # Query master process + self.update_function = ReplicationGetStreamUpdates.make_client(hs) # type: ignore super(TypingStream, self).__init__(hs) @@ -248,7 +250,7 @@ class ReceiptsStream(Stream): store = hs.get_datastore() self.current_token = store.get_max_receipt_stream_id # type: ignore - self.update_function = store.get_all_updated_receipts # type: ignore + self.update_function = db_query_to_update_function(store.get_all_updated_receipts) # type: ignore super(ReceiptsStream, self).__init__(hs) @@ -272,7 +274,13 @@ class PushRulesStream(Stream): async def update_function(self, from_token, to_token, limit): rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit) - return [(row[0], row[2]) for row in rows] + + limited = False + if len(rows) == limit: + to_token = rows[-1][0] + limited = True + + return [(row[0], (row[2],)) for row in rows], to_token, limited class PushersStream(Stream): @@ -291,7 +299,7 @@ class PushersStream(Stream): store = hs.get_datastore() self.current_token = store.get_pushers_stream_token # type: ignore - self.update_function = store.get_all_updated_pushers_rows # type: ignore + self.update_function = db_query_to_update_function(store.get_all_updated_pushers_rows) # type: ignore super(PushersStream, self).__init__(hs) @@ -323,7 +331,7 @@ class CachesStream(Stream): store = hs.get_datastore() self.current_token = store.get_cache_stream_token # type: ignore - self.update_function = store.get_all_updated_caches # type: ignore + self.update_function = db_query_to_update_function(store.get_all_updated_caches) # type: ignore super(CachesStream, self).__init__(hs) @@ -349,7 +357,7 @@ class PublicRoomsStream(Stream): store = hs.get_datastore() self.current_token = store.get_current_public_room_stream_id # type: ignore - self.update_function = store.get_all_new_public_rooms # type: ignore + self.update_function = db_query_to_update_function(store.get_all_new_public_rooms) # type: ignore super(PublicRoomsStream, self).__init__(hs) @@ -370,7 +378,7 @@ class DeviceListsStream(Stream): store = hs.get_datastore() self.current_token = store.get_device_stream_token # type: ignore - self.update_function = store.get_all_device_list_changes_for_remotes # type: ignore + self.update_function = db_query_to_update_function(store.get_all_device_list_changes_for_remotes) # type: ignore super(DeviceListsStream, self).__init__(hs) @@ -388,7 +396,7 @@ class ToDeviceStream(Stream): store = hs.get_datastore() self.current_token = store.get_to_device_stream_token # type: ignore - self.update_function = store.get_all_new_device_messages # type: ignore + self.update_function = db_query_to_update_function(store.get_all_new_device_messages) # type: ignore super(ToDeviceStream, self).__init__(hs) @@ -408,7 +416,7 @@ class TagAccountDataStream(Stream): store = hs.get_datastore() self.current_token = store.get_max_account_data_stream_id # type: ignore - self.update_function = store.get_all_updated_tags # type: ignore + self.update_function = db_query_to_update_function(store.get_all_updated_tags) # type: ignore super(TagAccountDataStream, self).__init__(hs) @@ -428,10 +436,11 @@ class AccountDataStream(Stream): self.store = hs.get_datastore() self.current_token = self.store.get_max_account_data_stream_id # type: ignore + self.update_function = db_query_to_update_function(self._update_function) # type: ignore super(AccountDataStream, self).__init__(hs) - async def update_function(self, from_token, to_token, limit): + async def _update_function(self, from_token, to_token, limit): global_results, room_results = await self.store.get_all_updated_account_data( from_token, from_token, to_token, limit ) @@ -458,7 +467,7 @@ class GroupServerStream(Stream): store = hs.get_datastore() self.current_token = store.get_group_stream_token # type: ignore - self.update_function = store.get_all_groups_changes # type: ignore + self.update_function = db_query_to_update_function(store.get_all_groups_changes) # type: ignore super(GroupServerStream, self).__init__(hs) @@ -476,6 +485,6 @@ class UserSignatureStream(Stream): store = hs.get_datastore() self.current_token = store.get_device_stream_token # type: ignore - self.update_function = store.get_all_user_signature_changes_for_remotes # type: ignore + self.update_function = db_query_to_update_function(store.get_all_user_signature_changes_for_remotes) # type: ignore super(UserSignatureStream, self).__init__(hs) |