summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/replication/tcp/streams/_base.py96
-rw-r--r--synapse/replication/tcp/streams/events.py5
-rw-r--r--synapse/replication/tcp/streams/federation.py6
3 files changed, 60 insertions, 47 deletions
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index d7e9371a00..d64cbc5cc8 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -16,7 +16,7 @@
 
 import logging
 from collections import namedtuple
-from typing import Any, List, Optional, Tuple, Union
+from typing import Any, Awaitable, Callable, List, Optional, Tuple, Union
 
 import attr
 
@@ -40,10 +40,6 @@ class Stream(object):
     # The type of the row. Used by the default impl of parse_row.
     ROW_TYPE = None  # type: Any
 
-    # Whether the update function is only available on master. If True then
-    # calls to get updates are proxied to the master via a HTTP call.
-    _QUERY_MASTER = False
-
     @classmethod
     def parse_row(cls, row):
         """Parse a row received over replication
@@ -60,10 +56,6 @@ class Stream(object):
         return cls.ROW_TYPE(*row)
 
     def __init__(self, hs):
-        self._is_worker = hs.config.worker_app is not None
-
-        if self._QUERY_MASTER and self._is_worker:
-            self._replication_client = ReplicationGetStreamUpdates.make_client(hs)
 
         # The token from which we last asked for updates
         self.last_token = self.current_token()
@@ -110,23 +102,10 @@ class Stream(object):
         if from_token == upto_token:
             return [], upto_token, False
 
-        if self._is_worker and self._QUERY_MASTER:
-            result = await self._replication_client(
-                stream_name=self.NAME,
-                from_token=from_token,
-                upto_token=upto_token,
-                limit=limit,
-            )
-            return result["updates"], result["upto_token"], result["limited"]
-        else:
-            limited = False
-            rows = await self.update_function(from_token, upto_token, limit=limit)
-            updates = [(row[0], row[1:]) for row in rows]
-            if len(updates) == limit:
-                upto_token = rows[-1][0]
-                limited = True
-
-            return updates, upto_token, limited
+        updates, upto_token, limited = await self.update_function(
+            from_token, upto_token, limit=limit,
+        )
+        return updates, upto_token, limited
 
     def current_token(self):
         """Gets the current token of the underlying streams. Should be provided
@@ -148,6 +127,26 @@ class Stream(object):
         raise NotImplementedError()
 
 
+def db_query_to_update_function(
+    query_function: Callable[[int, int, int], Awaitable[List[tuple]]]
+) -> Callable[[int, int, int], Awaitable[Tuple[List[Tuple[int, tuple]], int, bool]]]:
+    """Wraps a db query function which returns a list of rows to make it
+    suitable for use as an `update_function` for the Stream class
+    """
+
+    async def update_function(from_token, upto_token, limit):
+        rows = await query_function(from_token, upto_token, limit)
+        updates = [(row[0], row[1:]) for row in rows]
+        limited = False
+        if len(updates) == limit:
+            upto_token = rows[-1][0]
+            limited = True
+
+        return updates, upto_token, limited
+
+    return update_function
+
+
 class BackfillStream(Stream):
     """We fetched some old events and either we had never seen that event before
     or it went from being an outlier to not.
@@ -171,7 +170,7 @@ class BackfillStream(Stream):
     def __init__(self, hs):
         store = hs.get_datastore()
         self.current_token = store.get_current_backfill_token  # type: ignore
-        self.update_function = store.get_all_new_backfill_event_rows  # type: ignore
+        self.update_function = db_query_to_update_function(store.get_all_new_backfill_event_rows)  # type: ignore
 
         super(BackfillStream, self).__init__(hs)
 
@@ -192,16 +191,20 @@ class PresenceStream(Stream):
 
     NAME = "presence"
     ROW_TYPE = PresenceStreamRow
-    _QUERY_MASTER = True
 
     def __init__(self, hs):
         store = hs.get_datastore()
         presence_handler = hs.get_presence_handler()
 
+        self._is_worker = hs.config.worker_app is not None
+
         self.current_token = store.get_current_presence_token  # type: ignore
 
         if hs.config.worker_app is None:
-            self.update_function = presence_handler.get_all_presence_updates  # type: ignore
+            self.update_function = db_query_to_update_function(presence_handler.get_all_presence_updates)  # type: ignore
+        else:
+            # Query master process
+            self.update_function = ReplicationGetStreamUpdates.make_client(hs)  # type: ignore
 
         super(PresenceStream, self).__init__(hs)
 
@@ -213,7 +216,6 @@ class TypingStream(Stream):
 
     NAME = "typing"
     ROW_TYPE = TypingStreamRow
-    _QUERY_MASTER = True
 
     def __init__(self, hs):
         typing_handler = hs.get_typing_handler()
@@ -221,7 +223,10 @@ class TypingStream(Stream):
         self.current_token = typing_handler.get_current_token  # type: ignore
 
         if hs.config.worker_app is None:
-            self.update_function = typing_handler.get_all_typing_updates  # type: ignore
+            self.update_function = db_query_to_update_function(typing_handler.get_all_typing_updates)  # type: ignore
+        else:
+            # Query master process
+            self.update_function = ReplicationGetStreamUpdates.make_client(hs)  # type: ignore
 
         super(TypingStream, self).__init__(hs)
 
@@ -245,7 +250,7 @@ class ReceiptsStream(Stream):
         store = hs.get_datastore()
 
         self.current_token = store.get_max_receipt_stream_id  # type: ignore
-        self.update_function = store.get_all_updated_receipts  # type: ignore
+        self.update_function = db_query_to_update_function(store.get_all_updated_receipts)  # type: ignore
 
         super(ReceiptsStream, self).__init__(hs)
 
@@ -269,7 +274,13 @@ class PushRulesStream(Stream):
 
     async def update_function(self, from_token, to_token, limit):
         rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit)
-        return [(row[0], row[2]) for row in rows]
+
+        limited = False
+        if len(rows) == limit:
+            to_token = rows[-1][0]
+            limited = True
+
+        return [(row[0], row[2]) for row in rows], to_token, limited
 
 
 class PushersStream(Stream):
@@ -288,7 +299,7 @@ class PushersStream(Stream):
         store = hs.get_datastore()
 
         self.current_token = store.get_pushers_stream_token  # type: ignore
-        self.update_function = store.get_all_updated_pushers_rows  # type: ignore
+        self.update_function = db_query_to_update_function(store.get_all_updated_pushers_rows)  # type: ignore
 
         super(PushersStream, self).__init__(hs)
 
@@ -320,7 +331,7 @@ class CachesStream(Stream):
         store = hs.get_datastore()
 
         self.current_token = store.get_cache_stream_token  # type: ignore
-        self.update_function = store.get_all_updated_caches  # type: ignore
+        self.update_function = db_query_to_update_function(store.get_all_updated_caches)  # type: ignore
 
         super(CachesStream, self).__init__(hs)
 
@@ -346,7 +357,7 @@ class PublicRoomsStream(Stream):
         store = hs.get_datastore()
 
         self.current_token = store.get_current_public_room_stream_id  # type: ignore
-        self.update_function = store.get_all_new_public_rooms  # type: ignore
+        self.update_function = db_query_to_update_function(store.get_all_new_public_rooms)  # type: ignore
 
         super(PublicRoomsStream, self).__init__(hs)
 
@@ -367,7 +378,7 @@ class DeviceListsStream(Stream):
         store = hs.get_datastore()
 
         self.current_token = store.get_device_stream_token  # type: ignore
-        self.update_function = store.get_all_device_list_changes_for_remotes  # type: ignore
+        self.update_function = db_query_to_update_function(store.get_all_device_list_changes_for_remotes)  # type: ignore
 
         super(DeviceListsStream, self).__init__(hs)
 
@@ -385,7 +396,7 @@ class ToDeviceStream(Stream):
         store = hs.get_datastore()
 
         self.current_token = store.get_to_device_stream_token  # type: ignore
-        self.update_function = store.get_all_new_device_messages  # type: ignore
+        self.update_function = db_query_to_update_function(store.get_all_new_device_messages)  # type: ignore
 
         super(ToDeviceStream, self).__init__(hs)
 
@@ -405,7 +416,7 @@ class TagAccountDataStream(Stream):
         store = hs.get_datastore()
 
         self.current_token = store.get_max_account_data_stream_id  # type: ignore
-        self.update_function = store.get_all_updated_tags  # type: ignore
+        self.update_function = db_query_to_update_function(store.get_all_updated_tags)  # type: ignore
 
         super(TagAccountDataStream, self).__init__(hs)
 
@@ -425,10 +436,11 @@ class AccountDataStream(Stream):
         self.store = hs.get_datastore()
 
         self.current_token = self.store.get_max_account_data_stream_id  # type: ignore
+        self.update_function = db_query_to_update_function(self._update_function)  # type: ignore
 
         super(AccountDataStream, self).__init__(hs)
 
-    async def update_function(self, from_token, to_token, limit):
+    async def _update_function(self, from_token, to_token, limit):
         global_results, room_results = await self.store.get_all_updated_account_data(
             from_token, from_token, to_token, limit
         )
@@ -455,7 +467,7 @@ class GroupServerStream(Stream):
         store = hs.get_datastore()
 
         self.current_token = store.get_group_stream_token  # type: ignore
-        self.update_function = store.get_all_groups_changes  # type: ignore
+        self.update_function = db_query_to_update_function(store.get_all_groups_changes)  # type: ignore
 
         super(GroupServerStream, self).__init__(hs)
 
@@ -473,6 +485,6 @@ class UserSignatureStream(Stream):
         store = hs.get_datastore()
 
         self.current_token = store.get_device_stream_token  # type: ignore
-        self.update_function = store.get_all_user_signature_changes_for_remotes  # type: ignore
+        self.update_function = db_query_to_update_function(store.get_all_user_signature_changes_for_remotes)  # type: ignore
 
         super(UserSignatureStream, self).__init__(hs)
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index b3afabb8cd..c6a595629f 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -19,7 +19,7 @@ from typing import Tuple, Type
 
 import attr
 
-from ._base import Stream
+from ._base import Stream, db_query_to_update_function
 
 
 """Handling of the 'events' replication stream
@@ -117,10 +117,11 @@ class EventsStream(Stream):
     def __init__(self, hs):
         self._store = hs.get_datastore()
         self.current_token = self._store.get_current_events_token  # type: ignore
+        self.update_function = db_query_to_update_function(self._update_function)  # type: ignore
 
         super(EventsStream, self).__init__(hs)
 
-    async def update_function(self, from_token, current_token, limit=None):
+    async def _update_function(self, from_token, current_token, limit=None):
         event_rows = await self._store.get_all_new_forward_event_rows(
             from_token, current_token, limit
         )
diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py
index 67e0eaa262..48c1d45718 100644
--- a/synapse/replication/tcp/streams/federation.py
+++ b/synapse/replication/tcp/streams/federation.py
@@ -17,7 +17,7 @@ from collections import namedtuple
 
 from twisted.internet import defer
 
-from synapse.replication.tcp.streams._base import Stream
+from synapse.replication.tcp.streams._base import Stream, db_query_to_update_function
 
 
 class FederationStream(Stream):
@@ -44,9 +44,9 @@ class FederationStream(Stream):
         if hs.config.worker_app is None or hs.should_send_federation():
             federation_sender = hs.get_federation_sender()
             self.current_token = federation_sender.get_current_token  # type: ignore
-            self.update_function = federation_sender.get_replication_rows  # type: ignore
+            self.update_function = db_query_to_update_function(federation_sender.get_replication_rows)  # type: ignore
         else:
             self.current_token = lambda: 0  # type: ignore
-            self.update_function = lambda *args, **kwargs: defer.succeed([])  # type: ignore
+            self.update_function = lambda from_token, upto_token, limit: defer.succeed(([], upto_token, bool))  # type: ignore
 
         super(FederationStream, self).__init__(hs)