summary refs log tree commit diff
path: root/synapse/replication/tcp/streams/_base.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication/tcp/streams/_base.py')
-rw-r--r--synapse/replication/tcp/streams/_base.py61
1 files changed, 26 insertions, 35 deletions
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index abf5c6c6a8..99cef97532 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -14,10 +14,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import itertools
 import logging
 from collections import namedtuple
-from typing import Any, List, Optional, Tuple
+from typing import Any, List, Optional, Tuple, Union
 
 import attr
 
@@ -153,61 +152,53 @@ class Stream(object):
         """
         self.last_token = self.current_token()
 
-    async def get_updates(self):
+    async def get_updates(self) -> Tuple[List[Tuple[int, JsonDict]], int, bool]:
         """Gets all updates since the last time this function was called (or
         since the stream was constructed if it hadn't been called before).
 
         Returns:
-            Deferred[Tuple[List[Tuple[int, Any]], int]:
-                Resolves to a pair ``(updates, current_token)``, where ``updates`` is a
-                list of ``(token, row)`` entries. ``row`` will be json-serialised and
-                sent over the replication steam.
+            Resolves to a pair `(updates, new_last_token, limited)`, where
+            `updates` is a list of `(token, row)` entries, `new_last_token` is
+            the new position in stream, and `limited` is whether there are
+            more updates to fetch.
         """
-        updates, current_token = await self.get_updates_since(self.last_token)
+        current_token = self.current_token()
+        updates, current_token, limited = await self.get_updates_since(
+            self.last_token, current_token
+        )
         self.last_token = current_token
 
-        return updates, current_token
+        return updates, current_token, limited
 
     async def get_updates_since(
-        self, from_token: int
-    ) -> Tuple[List[Tuple[int, JsonDict]], int]:
+        self, from_token: Union[int, str], upto_token: int, limit: int = 100
+    ) -> Tuple[List[Tuple[int, JsonDict]], int, bool]:
         """Like get_updates except allows specifying from when we should
         stream updates
 
         Returns:
-            Resolves to a pair `(updates, new_last_token)`, where `updates` is
-            a list of `(token, row)` entries and `new_last_token` is the new
-            position in stream.
+            Resolves to a pair `(updates, new_last_token, limited)`, where
+            `updates` is a list of `(token, row)` entries, `new_last_token` is
+            the new position in stream, and `limited` is whether there are
+            more updates to fetch.
         """
 
         if from_token in ("NOW", "now"):
-            return [], self.current_token()
-
-        current_token = self.current_token()
+            return [], upto_token, False
 
         from_token = int(from_token)
 
-        if from_token == current_token:
-            return [], current_token
-
-        rows = await self.update_function(
-            from_token, current_token, limit=MAX_EVENTS_BEHIND + 1
-        )
-
-        # never turn more than MAX_EVENTS_BEHIND + 1 into updates.
-        rows = itertools.islice(rows, MAX_EVENTS_BEHIND + 1)
+        if from_token == upto_token:
+            return [], upto_token, False
 
+        limited = False
+        rows = await self.update_function(from_token, upto_token, limit=limit)
         updates = [(row[0], row[1:]) for row in rows]
+        if len(updates) == limit:
+            upto_token = rows[-1][0]
+            limited = True
 
-        # check we didn't get more rows than the limit.
-        # doing it like this allows the update_function to be a generator.
-        if len(updates) >= MAX_EVENTS_BEHIND:
-            raise Exception("stream %s has fallen behind" % (self.NAME))
-
-        # The update function didn't hit the limit, so we must have got all
-        # the updates to `current_token`, and can return that as our new
-        # stream position.
-        return updates, current_token
+        return updates, upto_token, limited
 
     def current_token(self):
         """Gets the current token of the underlying streams. Should be provided