diff options
author | Erik Johnston <erik@matrix.org> | 2020-03-24 17:21:26 +0000 |
---|---|---|
committer | Erik Johnston <erik@matrix.org> | 2020-03-24 17:21:26 +0000 |
commit | 604f57f1bd8ab251a9d0452f96197328cfe53d05 (patch) | |
tree | 8603ab2cca479f4f043e9818ffe926da6b83c1f7 /synapse | |
parent | Shuffle around code typing handlers (diff) | |
parent | Fixup push rules stream (diff) | |
download | synapse-604f57f1bd8ab251a9d0452f96197328cfe53d05.tar.xz |
Merge branch 'erikj/catchup_on_worker' into erikj/split_out_typing
Diffstat (limited to 'synapse')
-rw-r--r-- | synapse/replication/http/streams.py | 6 | ||||
-rw-r--r-- | synapse/replication/tcp/protocol.py | 6 | ||||
-rw-r--r-- | synapse/replication/tcp/resource.py | 7 | ||||
-rw-r--r-- | synapse/replication/tcp/streams/_base.py | 99 | ||||
-rw-r--r-- | synapse/replication/tcp/streams/events.py | 5 | ||||
-rw-r--r-- | synapse/replication/tcp/streams/federation.py | 6 |
6 files changed, 73 insertions, 56 deletions
diff --git a/synapse/replication/http/streams.py b/synapse/replication/http/streams.py index 141df68787..ffd4c61993 100644 --- a/synapse/replication/http/streams.py +++ b/synapse/replication/http/streams.py @@ -47,9 +47,9 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint): def __init__(self, hs): super().__init__(hs) - from synapse.replication.tcp.streams import STREAMS_MAP - - self.streams = {stream.NAME: stream(hs) for stream in STREAMS_MAP.values()} + # We pull the streams from the replication steamer (if we try and make + # them ourselves we end up in an import loop). + self.streams = hs.get_replication_streamer().get_streams() @staticmethod def _serialize_payload(stream_name, from_token, upto_token, limit): diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index d4456f42f3..de6abfc82e 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -72,12 +72,14 @@ from synapse.replication.tcp.commands import ( ServerCommand, ) from synapse.replication.tcp.streams import STREAMS_MAP, Stream +from synapse.types import Collection from synapse.util import Clock from synapse.util.stringutils import random_string MYPY = False if MYPY: - import synapse.server + from synapse.server import HomeServer + connection_close_counter = Counter( "synapse_replication_tcp_protocol_close_reason", "", ["reason_type"] @@ -423,7 +425,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): def __init__( self, - hs: "synapse.server.HomeServer", + hs: "HomeServer", client_name: str, server_name: str, clock: Clock, diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index c9d671210b..2ce171edbd 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -17,7 +17,7 @@ import logging import random -from typing import List +from typing import List, Dict from prometheus_client import Counter @@ -97,6 +97,11 @@ class ReplicationStreamer(object): self.client = hs.get_tcp_replication() + def get_streams(self) -> Dict[str, Stream]: + """Get a mapp from stream name to stream instance. + """ + return self.streams_by_name + def on_notifier_poke(self): """Checks if there is actually any new data and sends it to the connections if there are. diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index d5b9c2831b..2699e466bc 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -16,7 +16,7 @@ import logging from collections import namedtuple -from typing import Any, List, Optional, Tuple, Union +from typing import Any, Awaitable, Callable, List, Optional, Tuple, Union import attr @@ -40,10 +40,6 @@ class Stream(object): # The type of the row. Used by the default impl of parse_row. ROW_TYPE = None # type: Any - # Whether the update function is only available on master. If True then - # calls to get updates are proxied to the master via a HTTP call. - _QUERY_MASTER = False - @classmethod def parse_row(cls, row): """Parse a row received over replication @@ -60,10 +56,6 @@ class Stream(object): return cls.ROW_TYPE(*row) def __init__(self, hs): - self._is_worker = hs.config.worker_app is not None - - if self._QUERY_MASTER and self._is_worker: - self._replication_client = ReplicationGetStreamUpdates.make_client(hs) # The token from which we last asked for updates self.last_token = self.current_token() @@ -105,31 +97,15 @@ class Stream(object): to fetch. """ - if from_token in ("NOW", "now"): - return [], upto_token, False - from_token = int(from_token) if from_token == upto_token: return [], upto_token, False - if self._is_worker and self._QUERY_MASTER: - result = await self._replication_client( - stream_name=self.NAME, - from_token=from_token, - upto_token=upto_token, - limit=limit, - ) - return result["updates"], result["upto_token"], result["limited"] - else: - limited = False - rows = await self.update_function(from_token, upto_token, limit=limit) - updates = [(row[0], row[1:]) for row in rows] - if len(updates) == limit: - upto_token = rows[-1][0] - limited = True - - return updates, upto_token, limited + updates, upto_token, limited = await self.update_function( + from_token, upto_token, limit=limit, + ) + return updates, upto_token, limited def current_token(self): """Gets the current token of the underlying streams. Should be provided @@ -151,6 +127,26 @@ class Stream(object): raise NotImplementedError() +def db_query_to_update_function( + query_function: Callable[[int, int, int], Awaitable[List[tuple]]] +) -> Callable[[int, int, int], Awaitable[Tuple[List[Tuple[int, tuple]], int, bool]]]: + """Wraps a db query function which returns a list of rows to make it + suitable for use as an `update_function` for the Stream class + """ + + async def update_function(from_token, upto_token, limit): + rows = await query_function(from_token, upto_token, limit) + updates = [(row[0], row[1:]) for row in rows] + limited = False + if len(updates) == limit: + upto_token = rows[-1][0] + limited = True + + return updates, upto_token, limited + + return update_function + + class BackfillStream(Stream): """We fetched some old events and either we had never seen that event before or it went from being an outlier to not. @@ -174,7 +170,7 @@ class BackfillStream(Stream): def __init__(self, hs): store = hs.get_datastore() self.current_token = store.get_current_backfill_token # type: ignore - self.update_function = store.get_all_new_backfill_event_rows # type: ignore + self.update_function = db_query_to_update_function(store.get_all_new_backfill_event_rows) # type: ignore super(BackfillStream, self).__init__(hs) @@ -195,16 +191,20 @@ class PresenceStream(Stream): NAME = "presence" ROW_TYPE = PresenceStreamRow - _QUERY_MASTER = True def __init__(self, hs): store = hs.get_datastore() presence_handler = hs.get_presence_handler() + self._is_worker = hs.config.worker_app is not None + self.current_token = store.get_current_presence_token # type: ignore if hs.config.worker_app is None: - self.update_function = presence_handler.get_all_presence_updates # type: ignore + self.update_function = db_query_to_update_function(presence_handler.get_all_presence_updates) # type: ignore + else: + # Query master process + self.update_function = ReplicationGetStreamUpdates.make_client(hs) # type: ignore super(PresenceStream, self).__init__(hs) @@ -216,7 +216,6 @@ class TypingStream(Stream): NAME = "typing" ROW_TYPE = TypingStreamRow - _QUERY_MASTER = True def __init__(self, hs): typing_handler = hs.get_typing_handler() @@ -224,7 +223,10 @@ class TypingStream(Stream): self.current_token = typing_handler.get_current_token # type: ignore if hs.config.worker_app is None: - self.update_function = typing_handler.get_all_typing_updates # type: ignore + self.update_function = db_query_to_update_function(typing_handler.get_all_typing_updates) # type: ignore + else: + # Query master process + self.update_function = ReplicationGetStreamUpdates.make_client(hs) # type: ignore super(TypingStream, self).__init__(hs) @@ -248,7 +250,7 @@ class ReceiptsStream(Stream): store = hs.get_datastore() self.current_token = store.get_max_receipt_stream_id # type: ignore - self.update_function = store.get_all_updated_receipts # type: ignore + self.update_function = db_query_to_update_function(store.get_all_updated_receipts) # type: ignore super(ReceiptsStream, self).__init__(hs) @@ -272,7 +274,13 @@ class PushRulesStream(Stream): async def update_function(self, from_token, to_token, limit): rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit) - return [(row[0], row[2]) for row in rows] + + limited = False + if len(rows) == limit: + to_token = rows[-1][0] + limited = True + + return [(row[0], (row[2],)) for row in rows], to_token, limited class PushersStream(Stream): @@ -291,7 +299,7 @@ class PushersStream(Stream): store = hs.get_datastore() self.current_token = store.get_pushers_stream_token # type: ignore - self.update_function = store.get_all_updated_pushers_rows # type: ignore + self.update_function = db_query_to_update_function(store.get_all_updated_pushers_rows) # type: ignore super(PushersStream, self).__init__(hs) @@ -323,7 +331,7 @@ class CachesStream(Stream): store = hs.get_datastore() self.current_token = store.get_cache_stream_token # type: ignore - self.update_function = store.get_all_updated_caches # type: ignore + self.update_function = db_query_to_update_function(store.get_all_updated_caches) # type: ignore super(CachesStream, self).__init__(hs) @@ -349,7 +357,7 @@ class PublicRoomsStream(Stream): store = hs.get_datastore() self.current_token = store.get_current_public_room_stream_id # type: ignore - self.update_function = store.get_all_new_public_rooms # type: ignore + self.update_function = db_query_to_update_function(store.get_all_new_public_rooms) # type: ignore super(PublicRoomsStream, self).__init__(hs) @@ -370,7 +378,7 @@ class DeviceListsStream(Stream): store = hs.get_datastore() self.current_token = store.get_device_stream_token # type: ignore - self.update_function = store.get_all_device_list_changes_for_remotes # type: ignore + self.update_function = db_query_to_update_function(store.get_all_device_list_changes_for_remotes) # type: ignore super(DeviceListsStream, self).__init__(hs) @@ -388,7 +396,7 @@ class ToDeviceStream(Stream): store = hs.get_datastore() self.current_token = store.get_to_device_stream_token # type: ignore - self.update_function = store.get_all_new_device_messages # type: ignore + self.update_function = db_query_to_update_function(store.get_all_new_device_messages) # type: ignore super(ToDeviceStream, self).__init__(hs) @@ -408,7 +416,7 @@ class TagAccountDataStream(Stream): store = hs.get_datastore() self.current_token = store.get_max_account_data_stream_id # type: ignore - self.update_function = store.get_all_updated_tags # type: ignore + self.update_function = db_query_to_update_function(store.get_all_updated_tags) # type: ignore super(TagAccountDataStream, self).__init__(hs) @@ -428,10 +436,11 @@ class AccountDataStream(Stream): self.store = hs.get_datastore() self.current_token = self.store.get_max_account_data_stream_id # type: ignore + self.update_function = db_query_to_update_function(self._update_function) # type: ignore super(AccountDataStream, self).__init__(hs) - async def update_function(self, from_token, to_token, limit): + async def _update_function(self, from_token, to_token, limit): global_results, room_results = await self.store.get_all_updated_account_data( from_token, from_token, to_token, limit ) @@ -458,7 +467,7 @@ class GroupServerStream(Stream): store = hs.get_datastore() self.current_token = store.get_group_stream_token # type: ignore - self.update_function = store.get_all_groups_changes # type: ignore + self.update_function = db_query_to_update_function(store.get_all_groups_changes) # type: ignore super(GroupServerStream, self).__init__(hs) @@ -476,6 +485,6 @@ class UserSignatureStream(Stream): store = hs.get_datastore() self.current_token = store.get_device_stream_token # type: ignore - self.update_function = store.get_all_user_signature_changes_for_remotes # type: ignore + self.update_function = db_query_to_update_function(store.get_all_user_signature_changes_for_remotes) # type: ignore super(UserSignatureStream, self).__init__(hs) diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index b3afabb8cd..c6a595629f 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -19,7 +19,7 @@ from typing import Tuple, Type import attr -from ._base import Stream +from ._base import Stream, db_query_to_update_function """Handling of the 'events' replication stream @@ -117,10 +117,11 @@ class EventsStream(Stream): def __init__(self, hs): self._store = hs.get_datastore() self.current_token = self._store.get_current_events_token # type: ignore + self.update_function = db_query_to_update_function(self._update_function) # type: ignore super(EventsStream, self).__init__(hs) - async def update_function(self, from_token, current_token, limit=None): + async def _update_function(self, from_token, current_token, limit=None): event_rows = await self._store.get_all_new_forward_event_rows( from_token, current_token, limit ) diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py index 67e0eaa262..48c1d45718 100644 --- a/synapse/replication/tcp/streams/federation.py +++ b/synapse/replication/tcp/streams/federation.py @@ -17,7 +17,7 @@ from collections import namedtuple from twisted.internet import defer -from synapse.replication.tcp.streams._base import Stream +from synapse.replication.tcp.streams._base import Stream, db_query_to_update_function class FederationStream(Stream): @@ -44,9 +44,9 @@ class FederationStream(Stream): if hs.config.worker_app is None or hs.should_send_federation(): federation_sender = hs.get_federation_sender() self.current_token = federation_sender.get_current_token # type: ignore - self.update_function = federation_sender.get_replication_rows # type: ignore + self.update_function = db_query_to_update_function(federation_sender.get_replication_rows) # type: ignore else: self.current_token = lambda: 0 # type: ignore - self.update_function = lambda *args, **kwargs: defer.succeed([]) # type: ignore + self.update_function = lambda from_token, upto_token, limit: defer.succeed(([], upto_token, bool)) # type: ignore super(FederationStream, self).__init__(hs) |