summary refs log tree commit diff
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2020-08-19 10:39:31 +0100
committerGitHub <noreply@github.com>2020-08-19 10:39:31 +0100
commit76d21d14a042756b0c8a8f520dfd9ea09cf092c7 (patch)
tree8d24252c5e3bdcff215a8536da4bec429fd6fbbd
parentConvert events worker database to async/await. (#8071) (diff)
downloadsynapse-76d21d14a042756b0c8a8f520dfd9ea09cf092c7.tar.xz
Separate `get_current_token` into two. (#8113)
The function is used for two purposes: 1) for subscribers of streams to
get a token they can use to get further updates with, and 2) for
replication to track position of the writers of the stream.

For streams with a single writer the two scenarios produce the same
result, however the situation becomes complicated for streams with
multiple writers. The current `MultiWriterIdGenerator` does not
correctly handle the first case (which is not an issue as its only used
for the `caches` stream which nothing subscribes to outside of
replication).
-rw-r--r--changelog.d/8113.misc1
-rw-r--r--synapse/replication/slave/storage/_slaved_id_tracker.py8
-rw-r--r--synapse/replication/tcp/streams/_base.py2
-rw-r--r--synapse/storage/databases/main/cache.py4
-rw-r--r--synapse/storage/util/id_generators.py36
-rw-r--r--tests/storage/test_id_generators.py16
6 files changed, 47 insertions, 20 deletions
diff --git a/changelog.d/8113.misc b/changelog.d/8113.misc
new file mode 100644
index 0000000000..00bec4f8ef
--- /dev/null
+++ b/changelog.d/8113.misc
@@ -0,0 +1 @@
+Separate `get_current_token` into two since there are two different use cases for it.
diff --git a/synapse/replication/slave/storage/_slaved_id_tracker.py b/synapse/replication/slave/storage/_slaved_id_tracker.py
index 9d1d173b2f..d43eaf3a29 100644
--- a/synapse/replication/slave/storage/_slaved_id_tracker.py
+++ b/synapse/replication/slave/storage/_slaved_id_tracker.py
@@ -33,3 +33,11 @@ class SlavedIdTracker(object):
             int
         """
         return self._current
+
+    def get_current_token_for_writer(self, instance_name: str) -> int:
+        """Returns the position of the given writer.
+
+        For streams with single writers this is equivalent to
+        `get_current_token`.
+        """
+        return self.get_current_token()
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 7a42de3f7d..1e92d52165 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -405,7 +405,7 @@ class CachesStream(Stream):
         store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
-            store.get_cache_stream_token,
+            store.get_cache_stream_token_for_writer,
             store.get_all_updated_caches,
         )
 
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 10de446065..1e7637a6f5 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -299,8 +299,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
                 },
             )
 
-    def get_cache_stream_token(self, instance_name):
+    def get_cache_stream_token_for_writer(self, instance_name: str) -> int:
         if self._cache_id_gen:
-            return self._cache_id_gen.get_current_token(instance_name)
+            return self._cache_id_gen.get_current_token_for_writer(instance_name)
         else:
             return 0
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index e2ddd01290..8276a755e5 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -158,6 +158,14 @@ class StreamIdGenerator(object):
 
             return self._current
 
+    def get_current_token_for_writer(self, instance_name: str) -> int:
+        """Returns the position of the given writer.
+
+        For streams with single writers this is equivalent to
+        `get_current_token`.
+        """
+        return self.get_current_token()
+
 
 class ChainedIdGenerator(object):
     """Used to generate new stream ids where the stream must be kept in sync
@@ -216,6 +224,14 @@ class ChainedIdGenerator(object):
             "Attempted to advance token on source for table %r", self._table
         )
 
+    def get_current_token_for_writer(self, instance_name: str) -> Tuple[int, int]:
+        """Returns the position of the given writer.
+
+        For streams with single writers this is equivalent to
+        `get_current_token`.
+        """
+        return self.get_current_token()
+
 
 class MultiWriterIdGenerator:
     """An ID generator that tracks a stream that can have multiple writers.
@@ -298,7 +314,7 @@ class MultiWriterIdGenerator:
         # Assert the fetched ID is actually greater than what we currently
         # believe the ID to be. If not, then the sequence and table have got
         # out of sync somehow.
-        assert self.get_current_token() < next_id
+        assert self.get_current_token_for_writer(self._instance_name) < next_id
 
         with self._lock:
             self._unfinished_ids.add(next_id)
@@ -344,16 +360,18 @@ class MultiWriterIdGenerator:
                 curr = self._current_positions.get(self._instance_name, 0)
                 self._current_positions[self._instance_name] = max(curr, next_id)
 
-    def get_current_token(self, instance_name: str = None) -> int:
-        """Gets the current position of a named writer (defaults to current
-        instance).
-
-        Returns 0 if we don't have a position for the named writer (likely due
-        to it being a new writer).
+    def get_current_token(self) -> int:
+        """Returns the maximum stream id such that all stream ids less than or
+        equal to it have been successfully persisted.
         """
 
-        if instance_name is None:
-            instance_name = self._instance_name
+        # Currently we don't support this operation, as it's not obvious how to
+        # condense the stream positions of multiple writers into a single int.
+        raise NotImplementedError()
+
+    def get_current_token_for_writer(self, instance_name: str) -> int:
+        """Returns the position of the given writer.
+        """
 
         with self._lock:
             return self._current_positions.get(instance_name, 0)
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index e845410dae..7a05194653 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -88,7 +88,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         id_gen = self._create_id_generator()
 
         self.assertEqual(id_gen.get_positions(), {"master": 7})
-        self.assertEqual(id_gen.get_current_token("master"), 7)
+        self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
 
         # Try allocating a new ID gen and check that we only see position
         # advanced after we leave the context manager.
@@ -98,12 +98,12 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
                 self.assertEqual(stream_id, 8)
 
                 self.assertEqual(id_gen.get_positions(), {"master": 7})
-                self.assertEqual(id_gen.get_current_token("master"), 7)
+                self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
 
         self.get_success(_get_next_async())
 
         self.assertEqual(id_gen.get_positions(), {"master": 8})
-        self.assertEqual(id_gen.get_current_token("master"), 8)
+        self.assertEqual(id_gen.get_current_token_for_writer("master"), 8)
 
     def test_multi_instance(self):
         """Test that reads and writes from multiple processes are handled
@@ -116,8 +116,8 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         second_id_gen = self._create_id_generator("second")
 
         self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
-        self.assertEqual(first_id_gen.get_current_token("first"), 3)
-        self.assertEqual(first_id_gen.get_current_token("second"), 7)
+        self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 3)
+        self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7)
 
         # Try allocating a new ID gen and check that we only see position
         # advanced after we leave the context manager.
@@ -166,7 +166,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         id_gen = self._create_id_generator()
 
         self.assertEqual(id_gen.get_positions(), {"master": 7})
-        self.assertEqual(id_gen.get_current_token("master"), 7)
+        self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
 
         # Try allocating a new ID gen and check that we only see position
         # advanced after we leave the context manager.
@@ -176,9 +176,9 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
             self.assertEqual(stream_id, 8)
 
             self.assertEqual(id_gen.get_positions(), {"master": 7})
-            self.assertEqual(id_gen.get_current_token("master"), 7)
+            self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
 
         self.get_success(self.db_pool.runInteraction("test", _get_next_txn))
 
         self.assertEqual(id_gen.get_positions(), {"master": 8})
-        self.assertEqual(id_gen.get_current_token("master"), 8)
+        self.assertEqual(id_gen.get_current_token_for_writer("master"), 8)