summary refs log tree commit diff
path: root/synapse/replication/tcp/resource.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication/tcp/resource.py')
-rw-r--r--synapse/replication/tcp/resource.py22
1 files changed, 19 insertions, 3 deletions
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index b690abedad..002171ce7c 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -25,7 +25,12 @@ from twisted.internet.protocol import Factory
 
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.replication.tcp.protocol import ServerReplicationStreamProtocol
-from synapse.replication.tcp.streams import STREAMS_MAP, FederationStream, Stream
+from synapse.replication.tcp.streams import (
+    STREAMS_MAP,
+    CachesStream,
+    FederationStream,
+    Stream,
+)
 from synapse.util.metrics import Measure
 
 stream_updates_counter = Counter(
@@ -71,11 +76,16 @@ class ReplicationStreamer(object):
         self.store = hs.get_datastore()
         self.clock = hs.get_clock()
         self.notifier = hs.get_notifier()
+        self._instance_name = hs.get_instance_name()
 
         self._replication_torture_level = hs.config.replication_torture_level
 
         # Work out list of streams that this instance is the source of.
         self.streams = []  # type: List[Stream]
+
+        # All workers can write to the cache invalidation stream.
+        self.streams.append(CachesStream(hs))
+
         if hs.config.worker_app is None:
             for stream in STREAMS_MAP.values():
                 if stream == FederationStream and hs.config.send_federation:
@@ -83,6 +93,10 @@ class ReplicationStreamer(object):
                     # has been disabled on the master.
                     continue
 
+                if stream == CachesStream:
+                    # We've already added it above.
+                    continue
+
                 self.streams.append(stream(hs))
 
         self.streams_by_name = {stream.NAME: stream for stream in self.streams}
@@ -145,7 +159,9 @@ class ReplicationStreamer(object):
                         random.shuffle(all_streams)
 
                     for stream in all_streams:
-                        if stream.last_token == stream.current_token():
+                        if stream.last_token == stream.current_token(
+                            self._instance_name
+                        ):
                             continue
 
                         if self._replication_torture_level:
@@ -157,7 +173,7 @@ class ReplicationStreamer(object):
                             "Getting stream: %s: %s -> %s",
                             stream.NAME,
                             stream.last_token,
-                            stream.current_token(),
+                            stream.current_token(self._instance_name),
                         )
                         try:
                             updates, current_token, limited = await stream.get_updates()