summary refs log tree commit diff
path: root/synapse/replication/tcp/streams/_base.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication/tcp/streams/_base.py')
-rw-r--r--synapse/replication/tcp/streams/_base.py66
1 files changed, 26 insertions, 40 deletions
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 7a8b6e9df1..abf5c6c6a8 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -17,10 +17,12 @@
 import itertools
 import logging
 from collections import namedtuple
-from typing import Any, List, Optional
+from typing import Any, List, Optional, Tuple
 
 import attr
 
+from synapse.types import JsonDict
+
 logger = logging.getLogger(__name__)
 
 
@@ -119,13 +121,12 @@ class Stream(object):
     """Base class for the streams.
 
     Provides a `get_updates()` function that returns new updates since the last
-    time it was called up until the point `advance_current_token` was called.
+    time it was called.
     """
 
     NAME = None  # type: str  # The name of the stream
     # The type of the row. Used by the default impl of parse_row.
     ROW_TYPE = None  # type: Any
-    _LIMITED = True  # Whether the update function takes a limit
 
     @classmethod
     def parse_row(cls, row):
@@ -146,26 +147,15 @@ class Stream(object):
         # The token from which we last asked for updates
         self.last_token = self.current_token()
 
-        # The token that we will get updates up to
-        self.upto_token = self.current_token()
-
-    def advance_current_token(self):
-        """Updates `upto_token` to "now", which updates up until which point
-        get_updates[_since] will fetch rows till.
-        """
-        self.upto_token = self.current_token()
-
     def discard_updates_and_advance(self):
         """Called when the stream should advance but the updates would be discarded,
         e.g. when there are no currently connected workers.
         """
-        self.upto_token = self.current_token()
-        self.last_token = self.upto_token
+        self.last_token = self.current_token()
 
     async def get_updates(self):
         """Gets all updates since the last time this function was called (or
-        since the stream was constructed if it hadn't been called before),
-        until the `upto_token`
+        since the stream was constructed if it hadn't been called before).
 
         Returns:
             Deferred[Tuple[List[Tuple[int, Any]], int]:
@@ -178,44 +168,45 @@ class Stream(object):
 
         return updates, current_token
 
-    async def get_updates_since(self, from_token):
+    async def get_updates_since(
+        self, from_token: int
+    ) -> Tuple[List[Tuple[int, JsonDict]], int]:
         """Like get_updates except allows specifying from when we should
         stream updates
 
         Returns:
-            Deferred[Tuple[List[Tuple[int, Any]], int]:
-                Resolves to a pair ``(updates, current_token)``, where ``updates`` is a
-                list of ``(token, row)`` entries. ``row`` will be json-serialised and
-                sent over the replication steam.
+            Resolves to a pair `(updates, new_last_token)`, where `updates` is
+            a list of `(token, row)` entries and `new_last_token` is the new
+            position in stream.
         """
+
         if from_token in ("NOW", "now"):
-            return [], self.upto_token
+            return [], self.current_token()
 
-        current_token = self.upto_token
+        current_token = self.current_token()
 
         from_token = int(from_token)
 
         if from_token == current_token:
             return [], current_token
 
-        logger.info("get_updates_since: %s", self.__class__)
-        if self._LIMITED:
-            rows = await self.update_function(
-                from_token, current_token, limit=MAX_EVENTS_BEHIND + 1
-            )
+        rows = await self.update_function(
+            from_token, current_token, limit=MAX_EVENTS_BEHIND + 1
+        )
 
-            # never turn more than MAX_EVENTS_BEHIND + 1 into updates.
-            rows = itertools.islice(rows, MAX_EVENTS_BEHIND + 1)
-        else:
-            rows = await self.update_function(from_token, current_token)
+        # never turn more than MAX_EVENTS_BEHIND + 1 into updates.
+        rows = itertools.islice(rows, MAX_EVENTS_BEHIND + 1)
 
         updates = [(row[0], row[1:]) for row in rows]
 
         # check we didn't get more rows than the limit.
         # doing it like this allows the update_function to be a generator.
-        if self._LIMITED and len(updates) >= MAX_EVENTS_BEHIND:
+        if len(updates) >= MAX_EVENTS_BEHIND:
             raise Exception("stream %s has fallen behind" % (self.NAME))
 
+        # The update function didn't hit the limit, so we must have got all
+        # the updates to `current_token`, and can return that as our new
+        # stream position.
         return updates, current_token
 
     def current_token(self):
@@ -227,9 +218,8 @@ class Stream(object):
         """
         raise NotImplementedError()
 
-    def update_function(self, from_token, current_token, limit=None):
-        """Get updates between from_token and to_token. If Stream._LIMITED is
-        True then limit is provided, otherwise it's not.
+    def update_function(self, from_token, current_token, limit):
+        """Get updates between from_token and to_token.
 
         Returns:
             Deferred(list(tuple)): the first entry in the tuple is the token for
@@ -257,7 +247,6 @@ class BackfillStream(Stream):
 
 class PresenceStream(Stream):
     NAME = "presence"
-    _LIMITED = False
     ROW_TYPE = PresenceStreamRow
 
     def __init__(self, hs):
@@ -272,7 +261,6 @@ class PresenceStream(Stream):
 
 class TypingStream(Stream):
     NAME = "typing"
-    _LIMITED = False
     ROW_TYPE = TypingStreamRow
 
     def __init__(self, hs):
@@ -372,7 +360,6 @@ class DeviceListsStream(Stream):
     """
 
     NAME = "device_lists"
-    _LIMITED = False
     ROW_TYPE = DeviceListsStreamRow
 
     def __init__(self, hs):
@@ -462,7 +449,6 @@ class UserSignatureStream(Stream):
     """
 
     NAME = "user_signature"
-    _LIMITED = False
     ROW_TYPE = UserSignatureStreamRow
 
     def __init__(self, hs):