diff options
Diffstat (limited to 'synapse/replication/tcp/streams/federation.py')
-rw-r--r-- | synapse/replication/tcp/streams/federation.py | 42 |
1 files changed, 27 insertions, 15 deletions
diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py index 615f3dc9ac..e8bd52e389 100644 --- a/synapse/replication/tcp/streams/federation.py +++ b/synapse/replication/tcp/streams/federation.py @@ -15,15 +15,7 @@ # limitations under the License. from collections import namedtuple -from ._base import Stream - -FederationStreamRow = namedtuple( - "FederationStreamRow", - ( - "type", # str, the type of data as defined in the BaseFederationRows - "data", # dict, serialization of a federation.send_queue.BaseFederationRow - ), -) +from synapse.replication.tcp.streams._base import Stream, db_query_to_update_function class FederationStream(Stream): @@ -31,13 +23,33 @@ class FederationStream(Stream): sending disabled. """ + FederationStreamRow = namedtuple( + "FederationStreamRow", + ( + "type", # str, the type of data as defined in the BaseFederationRows + "data", # dict, serialization of a federation.send_queue.BaseFederationRow + ), + ) + NAME = "federation" ROW_TYPE = FederationStreamRow def __init__(self, hs): - federation_sender = hs.get_federation_sender() - - self.current_token = federation_sender.get_current_token # type: ignore - self.update_function = federation_sender.get_replication_rows # type: ignore - - super(FederationStream, self).__init__(hs) + # Not all synapse instances will have a federation sender instance, + # whether that's a `FederationSender` or a `FederationRemoteSendQueue`, + # so we stub the stream out when that is the case. + if hs.config.worker_app is None or hs.should_send_federation(): + federation_sender = hs.get_federation_sender() + current_token = federation_sender.get_current_token + update_function = db_query_to_update_function( + federation_sender.get_replication_rows + ) + else: + current_token = lambda: 0 + update_function = self._stub_update_function + + super().__init__(hs.get_instance_name(), current_token, update_function) + + @staticmethod + async def _stub_update_function(instance_name, from_token, upto_token, limit): + return [], upto_token, False |