diff options
-rw-r--r-- | synapse/replication/tcp/streams/_base.py | 96 | ||||
-rw-r--r-- | synapse/replication/tcp/streams/events.py | 5 | ||||
-rw-r--r-- | synapse/replication/tcp/streams/federation.py | 6 |
3 files changed, 60 insertions, 47 deletions
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index d7e9371a00..d64cbc5cc8 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -16,7 +16,7 @@ import logging from collections import namedtuple -from typing import Any, List, Optional, Tuple, Union +from typing import Any, Awaitable, Callable, List, Optional, Tuple, Union import attr @@ -40,10 +40,6 @@ class Stream(object): # The type of the row. Used by the default impl of parse_row. ROW_TYPE = None # type: Any - # Whether the update function is only available on master. If True then - # calls to get updates are proxied to the master via a HTTP call. - _QUERY_MASTER = False - @classmethod def parse_row(cls, row): """Parse a row received over replication @@ -60,10 +56,6 @@ class Stream(object): return cls.ROW_TYPE(*row) def __init__(self, hs): - self._is_worker = hs.config.worker_app is not None - - if self._QUERY_MASTER and self._is_worker: - self._replication_client = ReplicationGetStreamUpdates.make_client(hs) # The token from which we last asked for updates self.last_token = self.current_token() @@ -110,23 +102,10 @@ class Stream(object): if from_token == upto_token: return [], upto_token, False - if self._is_worker and self._QUERY_MASTER: - result = await self._replication_client( - stream_name=self.NAME, - from_token=from_token, - upto_token=upto_token, - limit=limit, - ) - return result["updates"], result["upto_token"], result["limited"] - else: - limited = False - rows = await self.update_function(from_token, upto_token, limit=limit) - updates = [(row[0], row[1:]) for row in rows] - if len(updates) == limit: - upto_token = rows[-1][0] - limited = True - - return updates, upto_token, limited + updates, upto_token, limited = await self.update_function( + from_token, upto_token, limit=limit, + ) + return updates, upto_token, limited def current_token(self): """Gets the current token of the underlying streams. Should be provided @@ -148,6 +127,26 @@ class Stream(object): raise NotImplementedError() +def db_query_to_update_function( + query_function: Callable[[int, int, int], Awaitable[List[tuple]]] +) -> Callable[[int, int, int], Awaitable[Tuple[List[Tuple[int, tuple]], int, bool]]]: + """Wraps a db query function which returns a list of rows to make it + suitable for use as an `update_function` for the Stream class + """ + + async def update_function(from_token, upto_token, limit): + rows = await query_function(from_token, upto_token, limit) + updates = [(row[0], row[1:]) for row in rows] + limited = False + if len(updates) == limit: + upto_token = rows[-1][0] + limited = True + + return updates, upto_token, limited + + return update_function + + class BackfillStream(Stream): """We fetched some old events and either we had never seen that event before or it went from being an outlier to not. @@ -171,7 +170,7 @@ class BackfillStream(Stream): def __init__(self, hs): store = hs.get_datastore() self.current_token = store.get_current_backfill_token # type: ignore - self.update_function = store.get_all_new_backfill_event_rows # type: ignore + self.update_function = db_query_to_update_function(store.get_all_new_backfill_event_rows) # type: ignore super(BackfillStream, self).__init__(hs) @@ -192,16 +191,20 @@ class PresenceStream(Stream): NAME = "presence" ROW_TYPE = PresenceStreamRow - _QUERY_MASTER = True def __init__(self, hs): store = hs.get_datastore() presence_handler = hs.get_presence_handler() + self._is_worker = hs.config.worker_app is not None + self.current_token = store.get_current_presence_token # type: ignore if hs.config.worker_app is None: - self.update_function = presence_handler.get_all_presence_updates # type: ignore + self.update_function = db_query_to_update_function(presence_handler.get_all_presence_updates) # type: ignore + else: + # Query master process + self.update_function = ReplicationGetStreamUpdates.make_client(hs) # type: ignore super(PresenceStream, self).__init__(hs) @@ -213,7 +216,6 @@ class TypingStream(Stream): NAME = "typing" ROW_TYPE = TypingStreamRow - _QUERY_MASTER = True def __init__(self, hs): typing_handler = hs.get_typing_handler() @@ -221,7 +223,10 @@ class TypingStream(Stream): self.current_token = typing_handler.get_current_token # type: ignore if hs.config.worker_app is None: - self.update_function = typing_handler.get_all_typing_updates # type: ignore + self.update_function = db_query_to_update_function(typing_handler.get_all_typing_updates) # type: ignore + else: + # Query master process + self.update_function = ReplicationGetStreamUpdates.make_client(hs) # type: ignore super(TypingStream, self).__init__(hs) @@ -245,7 +250,7 @@ class ReceiptsStream(Stream): store = hs.get_datastore() self.current_token = store.get_max_receipt_stream_id # type: ignore - self.update_function = store.get_all_updated_receipts # type: ignore + self.update_function = db_query_to_update_function(store.get_all_updated_receipts) # type: ignore super(ReceiptsStream, self).__init__(hs) @@ -269,7 +274,13 @@ class PushRulesStream(Stream): async def update_function(self, from_token, to_token, limit): rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit) - return [(row[0], row[2]) for row in rows] + + limited = False + if len(rows) == limit: + to_token = rows[-1][0] + limited = True + + return [(row[0], row[2]) for row in rows], to_token, limited class PushersStream(Stream): @@ -288,7 +299,7 @@ class PushersStream(Stream): store = hs.get_datastore() self.current_token = store.get_pushers_stream_token # type: ignore - self.update_function = store.get_all_updated_pushers_rows # type: ignore + self.update_function = db_query_to_update_function(store.get_all_updated_pushers_rows) # type: ignore super(PushersStream, self).__init__(hs) @@ -320,7 +331,7 @@ class CachesStream(Stream): store = hs.get_datastore() self.current_token = store.get_cache_stream_token # type: ignore - self.update_function = store.get_all_updated_caches # type: ignore + self.update_function = db_query_to_update_function(store.get_all_updated_caches) # type: ignore super(CachesStream, self).__init__(hs) @@ -346,7 +357,7 @@ class PublicRoomsStream(Stream): store = hs.get_datastore() self.current_token = store.get_current_public_room_stream_id # type: ignore - self.update_function = store.get_all_new_public_rooms # type: ignore + self.update_function = db_query_to_update_function(store.get_all_new_public_rooms) # type: ignore super(PublicRoomsStream, self).__init__(hs) @@ -367,7 +378,7 @@ class DeviceListsStream(Stream): store = hs.get_datastore() self.current_token = store.get_device_stream_token # type: ignore - self.update_function = store.get_all_device_list_changes_for_remotes # type: ignore + self.update_function = db_query_to_update_function(store.get_all_device_list_changes_for_remotes) # type: ignore super(DeviceListsStream, self).__init__(hs) @@ -385,7 +396,7 @@ class ToDeviceStream(Stream): store = hs.get_datastore() self.current_token = store.get_to_device_stream_token # type: ignore - self.update_function = store.get_all_new_device_messages # type: ignore + self.update_function = db_query_to_update_function(store.get_all_new_device_messages) # type: ignore super(ToDeviceStream, self).__init__(hs) @@ -405,7 +416,7 @@ class TagAccountDataStream(Stream): store = hs.get_datastore() self.current_token = store.get_max_account_data_stream_id # type: ignore - self.update_function = store.get_all_updated_tags # type: ignore + self.update_function = db_query_to_update_function(store.get_all_updated_tags) # type: ignore super(TagAccountDataStream, self).__init__(hs) @@ -425,10 +436,11 @@ class AccountDataStream(Stream): self.store = hs.get_datastore() self.current_token = self.store.get_max_account_data_stream_id # type: ignore + self.update_function = db_query_to_update_function(self._update_function) # type: ignore super(AccountDataStream, self).__init__(hs) - async def update_function(self, from_token, to_token, limit): + async def _update_function(self, from_token, to_token, limit): global_results, room_results = await self.store.get_all_updated_account_data( from_token, from_token, to_token, limit ) @@ -455,7 +467,7 @@ class GroupServerStream(Stream): store = hs.get_datastore() self.current_token = store.get_group_stream_token # type: ignore - self.update_function = store.get_all_groups_changes # type: ignore + self.update_function = db_query_to_update_function(store.get_all_groups_changes) # type: ignore super(GroupServerStream, self).__init__(hs) @@ -473,6 +485,6 @@ class UserSignatureStream(Stream): store = hs.get_datastore() self.current_token = store.get_device_stream_token # type: ignore - self.update_function = store.get_all_user_signature_changes_for_remotes # type: ignore + self.update_function = db_query_to_update_function(store.get_all_user_signature_changes_for_remotes) # type: ignore super(UserSignatureStream, self).__init__(hs) diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index b3afabb8cd..c6a595629f 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -19,7 +19,7 @@ from typing import Tuple, Type import attr -from ._base import Stream +from ._base import Stream, db_query_to_update_function """Handling of the 'events' replication stream @@ -117,10 +117,11 @@ class EventsStream(Stream): def __init__(self, hs): self._store = hs.get_datastore() self.current_token = self._store.get_current_events_token # type: ignore + self.update_function = db_query_to_update_function(self._update_function) # type: ignore super(EventsStream, self).__init__(hs) - async def update_function(self, from_token, current_token, limit=None): + async def _update_function(self, from_token, current_token, limit=None): event_rows = await self._store.get_all_new_forward_event_rows( from_token, current_token, limit ) diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py index 67e0eaa262..48c1d45718 100644 --- a/synapse/replication/tcp/streams/federation.py +++ b/synapse/replication/tcp/streams/federation.py @@ -17,7 +17,7 @@ from collections import namedtuple from twisted.internet import defer -from synapse.replication.tcp.streams._base import Stream +from synapse.replication.tcp.streams._base import Stream, db_query_to_update_function class FederationStream(Stream): @@ -44,9 +44,9 @@ class FederationStream(Stream): if hs.config.worker_app is None or hs.should_send_federation(): federation_sender = hs.get_federation_sender() self.current_token = federation_sender.get_current_token # type: ignore - self.update_function = federation_sender.get_replication_rows # type: ignore + self.update_function = db_query_to_update_function(federation_sender.get_replication_rows) # type: ignore else: self.current_token = lambda: 0 # type: ignore - self.update_function = lambda *args, **kwargs: defer.succeed([]) # type: ignore + self.update_function = lambda from_token, upto_token, limit: defer.succeed(([], upto_token, bool)) # type: ignore super(FederationStream, self).__init__(hs) |