summary refs log tree commit diff
path: root/synapse/replication/tcp/resource.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication/tcp/resource.py')
-rw-r--r--synapse/replication/tcp/resource.py19
1 files changed, 12 insertions, 7 deletions
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index 6e2ebaf614..4374e99e32 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -17,7 +17,7 @@
 
 import logging
 import random
-from typing import Any, List
+from typing import Any, Dict, List
 
 from six import itervalues
 
@@ -30,7 +30,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.util.metrics import Measure, measure_func
 
 from .protocol import ServerReplicationStreamProtocol
-from .streams import STREAMS_MAP
+from .streams import STREAMS_MAP, Stream
 from .streams.federation import FederationStream
 
 stream_updates_counter = Counter(
@@ -52,7 +52,7 @@ class ReplicationStreamProtocolFactory(Factory):
     """
 
     def __init__(self, hs):
-        self.streamer = ReplicationStreamer(hs)
+        self.streamer = hs.get_replication_streamer()
         self.clock = hs.get_clock()
         self.server_name = hs.config.server_name
 
@@ -133,6 +133,11 @@ class ReplicationStreamer(object):
         for conn in self.connections:
             conn.send_error("server shutting down")
 
+    def get_streams(self) -> Dict[str, Stream]:
+        """Get a mapp from stream name to stream instance.
+        """
+        return self.streams_by_name
+
     def on_notifier_poke(self):
         """Checks if there is actually any new data and sends it to the
         connections if there are.
@@ -190,7 +195,8 @@ class ReplicationStreamer(object):
                             stream.current_token(),
                         )
                         try:
-                            updates, current_token = await stream.get_updates()
+                            updates, current_token, limited = await stream.get_updates()
+                            self.pending_updates |= limited
                         except Exception:
                             logger.info("Failed to handle stream %s", stream.NAME)
                             raise
@@ -226,8 +232,7 @@ class ReplicationStreamer(object):
             self.pending_updates = False
             self.is_looping = False
 
-    @measure_func("repl.get_stream_updates")
-    async def get_stream_updates(self, stream_name, token):
+    def get_stream_token(self, stream_name):
         """For a given stream get all updates since token. This is called when
         a client first subscribes to a stream.
         """
@@ -235,7 +240,7 @@ class ReplicationStreamer(object):
         if not stream:
             raise Exception("unknown stream %s", stream_name)
 
-        return await stream.get_updates_since(token)
+        return stream.current_token()
 
     @measure_func("repl.federation_ack")
     def federation_ack(self, token):