summary refs log tree commit diff
path: root/synapse/replication/tcp/streams/_base.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication/tcp/streams/_base.py')
-rw-r--r--synapse/replication/tcp/streams/_base.py99
1 files changed, 54 insertions, 45 deletions
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index d5b9c2831b..2699e466bc 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -16,7 +16,7 @@
 
 import logging
 from collections import namedtuple
-from typing import Any, List, Optional, Tuple, Union
+from typing import Any, Awaitable, Callable, List, Optional, Tuple, Union
 
 import attr
 
@@ -40,10 +40,6 @@ class Stream(object):
     # The type of the row. Used by the default impl of parse_row.
     ROW_TYPE = None  # type: Any
 
-    # Whether the update function is only available on master. If True then
-    # calls to get updates are proxied to the master via a HTTP call.
-    _QUERY_MASTER = False
-
     @classmethod
     def parse_row(cls, row):
         """Parse a row received over replication
@@ -60,10 +56,6 @@ class Stream(object):
         return cls.ROW_TYPE(*row)
 
     def __init__(self, hs):
-        self._is_worker = hs.config.worker_app is not None
-
-        if self._QUERY_MASTER and self._is_worker:
-            self._replication_client = ReplicationGetStreamUpdates.make_client(hs)
 
         # The token from which we last asked for updates
         self.last_token = self.current_token()
@@ -105,31 +97,15 @@ class Stream(object):
             to fetch.
         """
 
-        if from_token in ("NOW", "now"):
-            return [], upto_token, False
-
         from_token = int(from_token)
 
         if from_token == upto_token:
             return [], upto_token, False
 
-        if self._is_worker and self._QUERY_MASTER:
-            result = await self._replication_client(
-                stream_name=self.NAME,
-                from_token=from_token,
-                upto_token=upto_token,
-                limit=limit,
-            )
-            return result["updates"], result["upto_token"], result["limited"]
-        else:
-            limited = False
-            rows = await self.update_function(from_token, upto_token, limit=limit)
-            updates = [(row[0], row[1:]) for row in rows]
-            if len(updates) == limit:
-                upto_token = rows[-1][0]
-                limited = True
-
-            return updates, upto_token, limited
+        updates, upto_token, limited = await self.update_function(
+            from_token, upto_token, limit=limit,
+        )
+        return updates, upto_token, limited
 
     def current_token(self):
         """Gets the current token of the underlying streams. Should be provided
@@ -151,6 +127,26 @@ class Stream(object):
         raise NotImplementedError()
 
 
+def db_query_to_update_function(
+    query_function: Callable[[int, int, int], Awaitable[List[tuple]]]
+) -> Callable[[int, int, int], Awaitable[Tuple[List[Tuple[int, tuple]], int, bool]]]:
+    """Wraps a db query function which returns a list of rows to make it
+    suitable for use as an `update_function` for the Stream class
+    """
+
+    async def update_function(from_token, upto_token, limit):
+        rows = await query_function(from_token, upto_token, limit)
+        updates = [(row[0], row[1:]) for row in rows]
+        limited = False
+        if len(updates) == limit:
+            upto_token = rows[-1][0]
+            limited = True
+
+        return updates, upto_token, limited
+
+    return update_function
+
+
 class BackfillStream(Stream):
     """We fetched some old events and either we had never seen that event before
     or it went from being an outlier to not.
@@ -174,7 +170,7 @@ class BackfillStream(Stream):
     def __init__(self, hs):
         store = hs.get_datastore()
         self.current_token = store.get_current_backfill_token  # type: ignore
-        self.update_function = store.get_all_new_backfill_event_rows  # type: ignore
+        self.update_function = db_query_to_update_function(store.get_all_new_backfill_event_rows)  # type: ignore
 
         super(BackfillStream, self).__init__(hs)
 
@@ -195,16 +191,20 @@ class PresenceStream(Stream):
 
     NAME = "presence"
     ROW_TYPE = PresenceStreamRow
-    _QUERY_MASTER = True
 
     def __init__(self, hs):
         store = hs.get_datastore()
         presence_handler = hs.get_presence_handler()
 
+        self._is_worker = hs.config.worker_app is not None
+
         self.current_token = store.get_current_presence_token  # type: ignore
 
         if hs.config.worker_app is None:
-            self.update_function = presence_handler.get_all_presence_updates  # type: ignore
+            self.update_function = db_query_to_update_function(presence_handler.get_all_presence_updates)  # type: ignore
+        else:
+            # Query master process
+            self.update_function = ReplicationGetStreamUpdates.make_client(hs)  # type: ignore
 
         super(PresenceStream, self).__init__(hs)
 
@@ -216,7 +216,6 @@ class TypingStream(Stream):
 
     NAME = "typing"
     ROW_TYPE = TypingStreamRow
-    _QUERY_MASTER = True
 
     def __init__(self, hs):
         typing_handler = hs.get_typing_handler()
@@ -224,7 +223,10 @@ class TypingStream(Stream):
         self.current_token = typing_handler.get_current_token  # type: ignore
 
         if hs.config.worker_app is None:
-            self.update_function = typing_handler.get_all_typing_updates  # type: ignore
+            self.update_function = db_query_to_update_function(typing_handler.get_all_typing_updates)  # type: ignore
+        else:
+            # Query master process
+            self.update_function = ReplicationGetStreamUpdates.make_client(hs)  # type: ignore
 
         super(TypingStream, self).__init__(hs)
 
@@ -248,7 +250,7 @@ class ReceiptsStream(Stream):
         store = hs.get_datastore()
 
         self.current_token = store.get_max_receipt_stream_id  # type: ignore
-        self.update_function = store.get_all_updated_receipts  # type: ignore
+        self.update_function = db_query_to_update_function(store.get_all_updated_receipts)  # type: ignore
 
         super(ReceiptsStream, self).__init__(hs)
 
@@ -272,7 +274,13 @@ class PushRulesStream(Stream):
 
     async def update_function(self, from_token, to_token, limit):
         rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit)
-        return [(row[0], row[2]) for row in rows]
+
+        limited = False
+        if len(rows) == limit:
+            to_token = rows[-1][0]
+            limited = True
+
+        return [(row[0], (row[2],)) for row in rows], to_token, limited
 
 
 class PushersStream(Stream):
@@ -291,7 +299,7 @@ class PushersStream(Stream):
         store = hs.get_datastore()
 
         self.current_token = store.get_pushers_stream_token  # type: ignore
-        self.update_function = store.get_all_updated_pushers_rows  # type: ignore
+        self.update_function = db_query_to_update_function(store.get_all_updated_pushers_rows)  # type: ignore
 
         super(PushersStream, self).__init__(hs)
 
@@ -323,7 +331,7 @@ class CachesStream(Stream):
         store = hs.get_datastore()
 
         self.current_token = store.get_cache_stream_token  # type: ignore
-        self.update_function = store.get_all_updated_caches  # type: ignore
+        self.update_function = db_query_to_update_function(store.get_all_updated_caches)  # type: ignore
 
         super(CachesStream, self).__init__(hs)
 
@@ -349,7 +357,7 @@ class PublicRoomsStream(Stream):
         store = hs.get_datastore()
 
         self.current_token = store.get_current_public_room_stream_id  # type: ignore
-        self.update_function = store.get_all_new_public_rooms  # type: ignore
+        self.update_function = db_query_to_update_function(store.get_all_new_public_rooms)  # type: ignore
 
         super(PublicRoomsStream, self).__init__(hs)
 
@@ -370,7 +378,7 @@ class DeviceListsStream(Stream):
         store = hs.get_datastore()
 
         self.current_token = store.get_device_stream_token  # type: ignore
-        self.update_function = store.get_all_device_list_changes_for_remotes  # type: ignore
+        self.update_function = db_query_to_update_function(store.get_all_device_list_changes_for_remotes)  # type: ignore
 
         super(DeviceListsStream, self).__init__(hs)
 
@@ -388,7 +396,7 @@ class ToDeviceStream(Stream):
         store = hs.get_datastore()
 
         self.current_token = store.get_to_device_stream_token  # type: ignore
-        self.update_function = store.get_all_new_device_messages  # type: ignore
+        self.update_function = db_query_to_update_function(store.get_all_new_device_messages)  # type: ignore
 
         super(ToDeviceStream, self).__init__(hs)
 
@@ -408,7 +416,7 @@ class TagAccountDataStream(Stream):
         store = hs.get_datastore()
 
         self.current_token = store.get_max_account_data_stream_id  # type: ignore
-        self.update_function = store.get_all_updated_tags  # type: ignore
+        self.update_function = db_query_to_update_function(store.get_all_updated_tags)  # type: ignore
 
         super(TagAccountDataStream, self).__init__(hs)
 
@@ -428,10 +436,11 @@ class AccountDataStream(Stream):
         self.store = hs.get_datastore()
 
         self.current_token = self.store.get_max_account_data_stream_id  # type: ignore
+        self.update_function = db_query_to_update_function(self._update_function)  # type: ignore
 
         super(AccountDataStream, self).__init__(hs)
 
-    async def update_function(self, from_token, to_token, limit):
+    async def _update_function(self, from_token, to_token, limit):
         global_results, room_results = await self.store.get_all_updated_account_data(
             from_token, from_token, to_token, limit
         )
@@ -458,7 +467,7 @@ class GroupServerStream(Stream):
         store = hs.get_datastore()
 
         self.current_token = store.get_group_stream_token  # type: ignore
-        self.update_function = store.get_all_groups_changes  # type: ignore
+        self.update_function = db_query_to_update_function(store.get_all_groups_changes)  # type: ignore
 
         super(GroupServerStream, self).__init__(hs)
 
@@ -476,6 +485,6 @@ class UserSignatureStream(Stream):
         store = hs.get_datastore()
 
         self.current_token = store.get_device_stream_token  # type: ignore
-        self.update_function = store.get_all_user_signature_changes_for_remotes  # type: ignore
+        self.update_function = db_query_to_update_function(store.get_all_user_signature_changes_for_remotes)  # type: ignore
 
         super(UserSignatureStream, self).__init__(hs)