summary refs log tree commit diff
path: root/synapse/replication/tcp/streams/federation.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication/tcp/streams/federation.py')
-rw-r--r--synapse/replication/tcp/streams/federation.py15
1 files changed, 11 insertions, 4 deletions
diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py
index 4046bdec69..7f5af5852c 100644
--- a/synapse/replication/tcp/streams/federation.py
+++ b/synapse/replication/tcp/streams/federation.py
@@ -18,6 +18,7 @@ import attr
 
 from synapse.replication.tcp.streams._base import (
     Stream,
+    Token,
     current_token_without_instance,
     make_http_update_function,
 )
@@ -47,7 +48,7 @@ class FederationStream(Stream):
             # will be a real FederationSender, which has stubs for current_token and
             # get_replication_rows.)
             federation_sender = hs.get_federation_sender()
-            current_token = current_token_without_instance(
+            self.current_token_func = current_token_without_instance(
                 federation_sender.get_current_token
             )
             update_function: Callable[
@@ -57,15 +58,21 @@ class FederationStream(Stream):
         elif hs.should_send_federation():
             # federation sender: Query master process
             update_function = make_http_update_function(hs, self.NAME)
-            current_token = self._stub_current_token
+            self.current_token_func = self._stub_current_token
 
         else:
             # other worker: stub out the update function (we're not interested in
             # any updates so when we get a POSITION we do nothing)
             update_function = self._stub_update_function
-            current_token = self._stub_current_token
+            self.current_token_func = self._stub_current_token
 
-        super().__init__(hs.get_instance_name(), current_token, update_function)
+        super().__init__(hs.get_instance_name(), update_function)
+
+    def current_token(self, instance_name: str) -> Token:
+        return self.current_token_func(instance_name)
+
+    def minimal_local_current_token(self) -> Token:
+        return self.current_token(self.local_instance_name)
 
     @staticmethod
     def _stub_current_token(instance_name: str) -> int: