summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/replication/slave/storage/_slaved_id_tracker.py8
-rw-r--r--synapse/replication/tcp/streams/_base.py2
-rw-r--r--synapse/storage/databases/main/cache.py4
-rw-r--r--synapse/storage/util/id_generators.py36
4 files changed, 38 insertions, 12 deletions
diff --git a/synapse/replication/slave/storage/_slaved_id_tracker.py b/synapse/replication/slave/storage/_slaved_id_tracker.py
index 9d1d173b2f..d43eaf3a29 100644
--- a/synapse/replication/slave/storage/_slaved_id_tracker.py
+++ b/synapse/replication/slave/storage/_slaved_id_tracker.py
@@ -33,3 +33,11 @@ class SlavedIdTracker(object):
             int
         """
         return self._current
+
+    def get_current_token_for_writer(self, instance_name: str) -> int:
+        """Returns the position of the given writer.
+
+        For streams with single writers this is equivalent to
+        `get_current_token`.
+        """
+        return self.get_current_token()
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 7a42de3f7d..1e92d52165 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -405,7 +405,7 @@ class CachesStream(Stream):
         store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
-            store.get_cache_stream_token,
+            store.get_cache_stream_token_for_writer,
             store.get_all_updated_caches,
         )
 
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 10de446065..1e7637a6f5 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -299,8 +299,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
                 },
             )
 
-    def get_cache_stream_token(self, instance_name):
+    def get_cache_stream_token_for_writer(self, instance_name: str) -> int:
         if self._cache_id_gen:
-            return self._cache_id_gen.get_current_token(instance_name)
+            return self._cache_id_gen.get_current_token_for_writer(instance_name)
         else:
             return 0
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index e2ddd01290..8276a755e5 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -158,6 +158,14 @@ class StreamIdGenerator(object):
 
             return self._current
 
+    def get_current_token_for_writer(self, instance_name: str) -> int:
+        """Returns the position of the given writer.
+
+        For streams with single writers this is equivalent to
+        `get_current_token`.
+        """
+        return self.get_current_token()
+
 
 class ChainedIdGenerator(object):
     """Used to generate new stream ids where the stream must be kept in sync
@@ -216,6 +224,14 @@ class ChainedIdGenerator(object):
             "Attempted to advance token on source for table %r", self._table
         )
 
+    def get_current_token_for_writer(self, instance_name: str) -> Tuple[int, int]:
+        """Returns the position of the given writer.
+
+        For streams with single writers this is equivalent to
+        `get_current_token`.
+        """
+        return self.get_current_token()
+
 
 class MultiWriterIdGenerator:
     """An ID generator that tracks a stream that can have multiple writers.
@@ -298,7 +314,7 @@ class MultiWriterIdGenerator:
         # Assert the fetched ID is actually greater than what we currently
         # believe the ID to be. If not, then the sequence and table have got
         # out of sync somehow.
-        assert self.get_current_token() < next_id
+        assert self.get_current_token_for_writer(self._instance_name) < next_id
 
         with self._lock:
             self._unfinished_ids.add(next_id)
@@ -344,16 +360,18 @@ class MultiWriterIdGenerator:
                 curr = self._current_positions.get(self._instance_name, 0)
                 self._current_positions[self._instance_name] = max(curr, next_id)
 
-    def get_current_token(self, instance_name: str = None) -> int:
-        """Gets the current position of a named writer (defaults to current
-        instance).
-
-        Returns 0 if we don't have a position for the named writer (likely due
-        to it being a new writer).
+    def get_current_token(self) -> int:
+        """Returns the maximum stream id such that all stream ids less than or
+        equal to it have been successfully persisted.
         """
 
-        if instance_name is None:
-            instance_name = self._instance_name
+        # Currently we don't support this operation, as it's not obvious how to
+        # condense the stream positions of multiple writers into a single int.
+        raise NotImplementedError()
+
+    def get_current_token_for_writer(self, instance_name: str) -> int:
+        """Returns the position of the given writer.
+        """
 
         with self._lock:
             return self._current_positions.get(instance_name, 0)