summary refs log tree commit diff
path: root/synapse/replication/tcp/streams/_base.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication/tcp/streams/_base.py')
-rw-r--r--synapse/replication/tcp/streams/_base.py87
1 files changed, 60 insertions, 27 deletions
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 084604e8b0..b48a6a3e91 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -95,20 +95,25 @@ class Stream(object):
     def __init__(
         self,
         local_instance_name: str,
-        current_token_function: Callable[[], Token],
+        current_token_function: Callable[[str], Token],
         update_function: UpdateFunction,
     ):
         """Instantiate a Stream
 
-        current_token_function and update_function are callbacks which should be
-        implemented by subclasses.
+        `current_token_function` and `update_function` are callbacks which
+        should be implemented by subclasses.
 
-        current_token_function is called to get the current token of the underlying
-        stream. It is only meaningful on the process that is the source of the
-        replication stream (ie, usually the master).
+        `current_token_function` takes an instance name, which is a writer to
+        the stream, and returns the position in the stream of the writer (as
+        viewed from the current process). On the writer process this is where
+        the writer has successfully written up to, whereas on other processes
+        this is the position which we have received updates up to over
+        replication. (Note that most streams have a single writer and so their
+        implementations ignore the instance name passed in).
 
-        update_function is called to get updates for this stream between a pair of
-        stream tokens. See the UpdateFunction type definition for more info.
+        `update_function` is called to get updates for this stream between a
+        pair of stream tokens. See the `UpdateFunction` type definition for more
+        info.
 
         Args:
             local_instance_name: The instance name of the current process
@@ -120,13 +125,13 @@ class Stream(object):
         self.update_function = update_function
 
         # The token from which we last asked for updates
-        self.last_token = self.current_token()
+        self.last_token = self.current_token(self.local_instance_name)
 
     def discard_updates_and_advance(self):
         """Called when the stream should advance but the updates would be discarded,
         e.g. when there are no currently connected workers.
         """
-        self.last_token = self.current_token()
+        self.last_token = self.current_token(self.local_instance_name)
 
     async def get_updates(self) -> StreamUpdateResult:
         """Gets all updates since the last time this function was called (or
@@ -138,7 +143,7 @@ class Stream(object):
             position in stream, and `limited` is whether there are more updates
             to fetch.
         """
-        current_token = self.current_token()
+        current_token = self.current_token(self.local_instance_name)
         updates, current_token, limited = await self.get_updates_since(
             self.local_instance_name, self.last_token, current_token
         )
@@ -170,6 +175,16 @@ class Stream(object):
         return updates, upto_token, limited
 
 
+def current_token_without_instance(
+    current_token: Callable[[], int]
+) -> Callable[[str], int]:
+    """Takes a current token callback function for a single writer stream
+    that doesn't take an instance name parameter and wraps it in a function that
+    does accept an instance name parameter but ignores it.
+    """
+    return lambda instance_name: current_token()
+
+
 def db_query_to_update_function(
     query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]]
 ) -> UpdateFunction:
@@ -235,7 +250,7 @@ class BackfillStream(Stream):
         store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
-            store.get_current_backfill_token,
+            current_token_without_instance(store.get_current_backfill_token),
             db_query_to_update_function(store.get_all_new_backfill_event_rows),
         )
 
@@ -271,7 +286,9 @@ class PresenceStream(Stream):
             update_function = make_http_update_function(hs, self.NAME)
 
         super().__init__(
-            hs.get_instance_name(), store.get_current_presence_token, update_function
+            hs.get_instance_name(),
+            current_token_without_instance(store.get_current_presence_token),
+            update_function,
         )
 
 
@@ -296,7 +313,9 @@ class TypingStream(Stream):
             update_function = make_http_update_function(hs, self.NAME)
 
         super().__init__(
-            hs.get_instance_name(), typing_handler.get_current_token, update_function
+            hs.get_instance_name(),
+            current_token_without_instance(typing_handler.get_current_token),
+            update_function,
         )
 
 
@@ -319,7 +338,7 @@ class ReceiptsStream(Stream):
         store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
-            store.get_max_receipt_stream_id,
+            current_token_without_instance(store.get_max_receipt_stream_id),
             db_query_to_update_function(store.get_all_updated_receipts),
         )
 
@@ -339,7 +358,7 @@ class PushRulesStream(Stream):
             hs.get_instance_name(), self._current_token, self._update_function
         )
 
-    def _current_token(self) -> int:
+    def _current_token(self, instance_name: str) -> int:
         push_rules_token, _ = self.store.get_push_rules_stream_token()
         return push_rules_token
 
@@ -373,7 +392,7 @@ class PushersStream(Stream):
 
         super().__init__(
             hs.get_instance_name(),
-            store.get_pushers_stream_token,
+            current_token_without_instance(store.get_pushers_stream_token),
             db_query_to_update_function(store.get_all_updated_pushers_rows),
         )
 
@@ -402,12 +421,26 @@ class CachesStream(Stream):
     ROW_TYPE = CachesStreamRow
 
     def __init__(self, hs):
-        store = hs.get_datastore()
+        self.store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
-            store.get_cache_stream_token,
-            db_query_to_update_function(store.get_all_updated_caches),
+            self.store.get_cache_stream_token,
+            self._update_function,
+        )
+
+    async def _update_function(
+        self, instance_name: str, from_token: int, upto_token: int, limit: int
+    ):
+        rows = await self.store.get_all_updated_caches(
+            instance_name, from_token, upto_token, limit
         )
+        updates = [(row[0], row[1:]) for row in rows]
+        limited = False
+        if len(updates) >= limit:
+            upto_token = updates[-1][0]
+            limited = True
+
+        return updates, upto_token, limited
 
 
 class PublicRoomsStream(Stream):
@@ -431,7 +464,7 @@ class PublicRoomsStream(Stream):
         store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
-            store.get_current_public_room_stream_id,
+            current_token_without_instance(store.get_current_public_room_stream_id),
             db_query_to_update_function(store.get_all_new_public_rooms),
         )
 
@@ -452,7 +485,7 @@ class DeviceListsStream(Stream):
         store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
-            store.get_device_stream_token,
+            current_token_without_instance(store.get_device_stream_token),
             db_query_to_update_function(store.get_all_device_list_changes_for_remotes),
         )
 
@@ -470,7 +503,7 @@ class ToDeviceStream(Stream):
         store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
-            store.get_to_device_stream_token,
+            current_token_without_instance(store.get_to_device_stream_token),
             db_query_to_update_function(store.get_all_new_device_messages),
         )
 
@@ -490,7 +523,7 @@ class TagAccountDataStream(Stream):
         store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
-            store.get_max_account_data_stream_id,
+            current_token_without_instance(store.get_max_account_data_stream_id),
             db_query_to_update_function(store.get_all_updated_tags),
         )
 
@@ -510,7 +543,7 @@ class AccountDataStream(Stream):
         self.store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
-            self.store.get_max_account_data_stream_id,
+            current_token_without_instance(self.store.get_max_account_data_stream_id),
             db_query_to_update_function(self._update_function),
         )
 
@@ -541,7 +574,7 @@ class GroupServerStream(Stream):
         store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
-            store.get_group_stream_token,
+            current_token_without_instance(store.get_group_stream_token),
             db_query_to_update_function(store.get_all_groups_changes),
         )
 
@@ -559,7 +592,7 @@ class UserSignatureStream(Stream):
         store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
-            store.get_device_stream_token,
+            current_token_without_instance(store.get_device_stream_token),
             db_query_to_update_function(
                 store.get_all_user_signature_changes_for_remotes
             ),