summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2020-03-24 17:21:26 +0000
committerErik Johnston <erik@matrix.org>2020-03-24 17:21:26 +0000
commit604f57f1bd8ab251a9d0452f96197328cfe53d05 (patch)
tree8603ab2cca479f4f043e9818ffe926da6b83c1f7 /synapse
parentShuffle around code typing handlers (diff)
parentFixup push rules stream (diff)
downloadsynapse-604f57f1bd8ab251a9d0452f96197328cfe53d05.tar.xz
Merge branch 'erikj/catchup_on_worker' into erikj/split_out_typing
Diffstat (limited to 'synapse')
-rw-r--r--synapse/replication/http/streams.py6
-rw-r--r--synapse/replication/tcp/protocol.py6
-rw-r--r--synapse/replication/tcp/resource.py7
-rw-r--r--synapse/replication/tcp/streams/_base.py99
-rw-r--r--synapse/replication/tcp/streams/events.py5
-rw-r--r--synapse/replication/tcp/streams/federation.py6
6 files changed, 73 insertions, 56 deletions
diff --git a/synapse/replication/http/streams.py b/synapse/replication/http/streams.py
index 141df68787..ffd4c61993 100644
--- a/synapse/replication/http/streams.py
+++ b/synapse/replication/http/streams.py
@@ -47,9 +47,9 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint):
     def __init__(self, hs):
         super().__init__(hs)
 
-        from synapse.replication.tcp.streams import STREAMS_MAP
-
-        self.streams = {stream.NAME: stream(hs) for stream in STREAMS_MAP.values()}
+        # We pull the streams from the replication streamer (if we try and make
+        # them ourselves we end up in an import loop).
+        self.streams = hs.get_replication_streamer().get_streams()
 
     @staticmethod
     def _serialize_payload(stream_name, from_token, upto_token, limit):
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index d4456f42f3..de6abfc82e 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -72,12 +72,14 @@ from synapse.replication.tcp.commands import (
     ServerCommand,
 )
 from synapse.replication.tcp.streams import STREAMS_MAP, Stream
+from synapse.types import Collection
 from synapse.util import Clock
 from synapse.util.stringutils import random_string
 
 MYPY = False
 if MYPY:
-    import synapse.server
+    from synapse.server import HomeServer
+
 
 connection_close_counter = Counter(
     "synapse_replication_tcp_protocol_close_reason", "", ["reason_type"]
@@ -423,7 +425,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
 class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
     def __init__(
         self,
-        hs: "synapse.server.HomeServer",
+        hs: "HomeServer",
         client_name: str,
         server_name: str,
         clock: Clock,
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index c9d671210b..2ce171edbd 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -17,7 +17,7 @@
 
 import logging
 import random
-from typing import List
+from typing import List, Dict
 
 from prometheus_client import Counter
 
@@ -97,6 +97,11 @@ class ReplicationStreamer(object):
 
         self.client = hs.get_tcp_replication()
 
+    def get_streams(self) -> Dict[str, Stream]:
+        """Get a mapp from stream name to stream instance.
+        """
+        return self.streams_by_name
+
     def on_notifier_poke(self):
         """Checks if there is actually any new data and sends it to the
         connections if there are.
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index d5b9c2831b..2699e466bc 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -16,7 +16,7 @@
 
 import logging
 from collections import namedtuple
-from typing import Any, List, Optional, Tuple, Union
+from typing import Any, Awaitable, Callable, List, Optional, Tuple, Union
 
 import attr
 
@@ -40,10 +40,6 @@ class Stream(object):
     # The type of the row. Used by the default impl of parse_row.
     ROW_TYPE = None  # type: Any
 
-    # Whether the update function is only available on master. If True then
-    # calls to get updates are proxied to the master via a HTTP call.
-    _QUERY_MASTER = False
-
     @classmethod
     def parse_row(cls, row):
         """Parse a row received over replication
@@ -60,10 +56,6 @@ class Stream(object):
         return cls.ROW_TYPE(*row)
 
     def __init__(self, hs):
-        self._is_worker = hs.config.worker_app is not None
-
-        if self._QUERY_MASTER and self._is_worker:
-            self._replication_client = ReplicationGetStreamUpdates.make_client(hs)
 
         # The token from which we last asked for updates
         self.last_token = self.current_token()
@@ -105,31 +97,15 @@ class Stream(object):
             to fetch.
         """
 
-        if from_token in ("NOW", "now"):
-            return [], upto_token, False
-
         from_token = int(from_token)
 
         if from_token == upto_token:
             return [], upto_token, False
 
-        if self._is_worker and self._QUERY_MASTER:
-            result = await self._replication_client(
-                stream_name=self.NAME,
-                from_token=from_token,
-                upto_token=upto_token,
-                limit=limit,
-            )
-            return result["updates"], result["upto_token"], result["limited"]
-        else:
-            limited = False
-            rows = await self.update_function(from_token, upto_token, limit=limit)
-            updates = [(row[0], row[1:]) for row in rows]
-            if len(updates) == limit:
-                upto_token = rows[-1][0]
-                limited = True
-
-            return updates, upto_token, limited
+        updates, upto_token, limited = await self.update_function(
+            from_token, upto_token, limit=limit,
+        )
+        return updates, upto_token, limited
 
     def current_token(self):
         """Gets the current token of the underlying streams. Should be provided
@@ -151,6 +127,26 @@ class Stream(object):
         raise NotImplementedError()
 
 
+def db_query_to_update_function(
+    query_function: Callable[[int, int, int], Awaitable[List[tuple]]]
+) -> Callable[[int, int, int], Awaitable[Tuple[List[Tuple[int, tuple]], int, bool]]]:
+    """Wraps a db query function which returns a list of rows to make it
+    suitable for use as an `update_function` for the Stream class
+    """
+
+    async def update_function(from_token, upto_token, limit):
+        rows = await query_function(from_token, upto_token, limit)
+        updates = [(row[0], row[1:]) for row in rows]
+        limited = False
+        if len(updates) == limit:
+            upto_token = rows[-1][0]
+            limited = True
+
+        return updates, upto_token, limited
+
+    return update_function
+
+
 class BackfillStream(Stream):
     """We fetched some old events and either we had never seen that event before
     or it went from being an outlier to not.
@@ -174,7 +170,7 @@ class BackfillStream(Stream):
     def __init__(self, hs):
         store = hs.get_datastore()
         self.current_token = store.get_current_backfill_token  # type: ignore
-        self.update_function = store.get_all_new_backfill_event_rows  # type: ignore
+        self.update_function = db_query_to_update_function(store.get_all_new_backfill_event_rows)  # type: ignore
 
         super(BackfillStream, self).__init__(hs)
 
@@ -195,16 +191,20 @@ class PresenceStream(Stream):
 
     NAME = "presence"
     ROW_TYPE = PresenceStreamRow
-    _QUERY_MASTER = True
 
     def __init__(self, hs):
         store = hs.get_datastore()
         presence_handler = hs.get_presence_handler()
 
+        self._is_worker = hs.config.worker_app is not None
+
         self.current_token = store.get_current_presence_token  # type: ignore
 
         if hs.config.worker_app is None:
-            self.update_function = presence_handler.get_all_presence_updates  # type: ignore
+            self.update_function = db_query_to_update_function(presence_handler.get_all_presence_updates)  # type: ignore
+        else:
+            # Query master process
+            self.update_function = ReplicationGetStreamUpdates.make_client(hs)  # type: ignore
 
         super(PresenceStream, self).__init__(hs)
 
@@ -216,7 +216,6 @@ class TypingStream(Stream):
 
     NAME = "typing"
     ROW_TYPE = TypingStreamRow
-    _QUERY_MASTER = True
 
     def __init__(self, hs):
         typing_handler = hs.get_typing_handler()
@@ -224,7 +223,10 @@ class TypingStream(Stream):
         self.current_token = typing_handler.get_current_token  # type: ignore
 
         if hs.config.worker_app is None:
-            self.update_function = typing_handler.get_all_typing_updates  # type: ignore
+            self.update_function = db_query_to_update_function(typing_handler.get_all_typing_updates)  # type: ignore
+        else:
+            # Query master process
+            self.update_function = ReplicationGetStreamUpdates.make_client(hs)  # type: ignore
 
         super(TypingStream, self).__init__(hs)
 
@@ -248,7 +250,7 @@ class ReceiptsStream(Stream):
         store = hs.get_datastore()
 
         self.current_token = store.get_max_receipt_stream_id  # type: ignore
-        self.update_function = store.get_all_updated_receipts  # type: ignore
+        self.update_function = db_query_to_update_function(store.get_all_updated_receipts)  # type: ignore
 
         super(ReceiptsStream, self).__init__(hs)
 
@@ -272,7 +274,13 @@ class PushRulesStream(Stream):
 
     async def update_function(self, from_token, to_token, limit):
         rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit)
-        return [(row[0], row[2]) for row in rows]
+
+        limited = False
+        if len(rows) == limit:
+            to_token = rows[-1][0]
+            limited = True
+
+        return [(row[0], (row[2],)) for row in rows], to_token, limited
 
 
 class PushersStream(Stream):
@@ -291,7 +299,7 @@ class PushersStream(Stream):
         store = hs.get_datastore()
 
         self.current_token = store.get_pushers_stream_token  # type: ignore
-        self.update_function = store.get_all_updated_pushers_rows  # type: ignore
+        self.update_function = db_query_to_update_function(store.get_all_updated_pushers_rows)  # type: ignore
 
         super(PushersStream, self).__init__(hs)
 
@@ -323,7 +331,7 @@ class CachesStream(Stream):
         store = hs.get_datastore()
 
         self.current_token = store.get_cache_stream_token  # type: ignore
-        self.update_function = store.get_all_updated_caches  # type: ignore
+        self.update_function = db_query_to_update_function(store.get_all_updated_caches)  # type: ignore
 
         super(CachesStream, self).__init__(hs)
 
@@ -349,7 +357,7 @@ class PublicRoomsStream(Stream):
         store = hs.get_datastore()
 
         self.current_token = store.get_current_public_room_stream_id  # type: ignore
-        self.update_function = store.get_all_new_public_rooms  # type: ignore
+        self.update_function = db_query_to_update_function(store.get_all_new_public_rooms)  # type: ignore
 
         super(PublicRoomsStream, self).__init__(hs)
 
@@ -370,7 +378,7 @@ class DeviceListsStream(Stream):
         store = hs.get_datastore()
 
         self.current_token = store.get_device_stream_token  # type: ignore
-        self.update_function = store.get_all_device_list_changes_for_remotes  # type: ignore
+        self.update_function = db_query_to_update_function(store.get_all_device_list_changes_for_remotes)  # type: ignore
 
         super(DeviceListsStream, self).__init__(hs)
 
@@ -388,7 +396,7 @@ class ToDeviceStream(Stream):
         store = hs.get_datastore()
 
         self.current_token = store.get_to_device_stream_token  # type: ignore
-        self.update_function = store.get_all_new_device_messages  # type: ignore
+        self.update_function = db_query_to_update_function(store.get_all_new_device_messages)  # type: ignore
 
         super(ToDeviceStream, self).__init__(hs)
 
@@ -408,7 +416,7 @@ class TagAccountDataStream(Stream):
         store = hs.get_datastore()
 
         self.current_token = store.get_max_account_data_stream_id  # type: ignore
-        self.update_function = store.get_all_updated_tags  # type: ignore
+        self.update_function = db_query_to_update_function(store.get_all_updated_tags)  # type: ignore
 
         super(TagAccountDataStream, self).__init__(hs)
 
@@ -428,10 +436,11 @@ class AccountDataStream(Stream):
         self.store = hs.get_datastore()
 
         self.current_token = self.store.get_max_account_data_stream_id  # type: ignore
+        self.update_function = db_query_to_update_function(self._update_function)  # type: ignore
 
         super(AccountDataStream, self).__init__(hs)
 
-    async def update_function(self, from_token, to_token, limit):
+    async def _update_function(self, from_token, to_token, limit):
         global_results, room_results = await self.store.get_all_updated_account_data(
             from_token, from_token, to_token, limit
         )
@@ -458,7 +467,7 @@ class GroupServerStream(Stream):
         store = hs.get_datastore()
 
         self.current_token = store.get_group_stream_token  # type: ignore
-        self.update_function = store.get_all_groups_changes  # type: ignore
+        self.update_function = db_query_to_update_function(store.get_all_groups_changes)  # type: ignore
 
         super(GroupServerStream, self).__init__(hs)
 
@@ -476,6 +485,6 @@ class UserSignatureStream(Stream):
         store = hs.get_datastore()
 
         self.current_token = store.get_device_stream_token  # type: ignore
-        self.update_function = store.get_all_user_signature_changes_for_remotes  # type: ignore
+        self.update_function = db_query_to_update_function(store.get_all_user_signature_changes_for_remotes)  # type: ignore
 
         super(UserSignatureStream, self).__init__(hs)
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index b3afabb8cd..c6a595629f 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -19,7 +19,7 @@ from typing import Tuple, Type
 
 import attr
 
-from ._base import Stream
+from ._base import Stream, db_query_to_update_function
 
 
 """Handling of the 'events' replication stream
@@ -117,10 +117,11 @@ class EventsStream(Stream):
     def __init__(self, hs):
         self._store = hs.get_datastore()
         self.current_token = self._store.get_current_events_token  # type: ignore
+        self.update_function = db_query_to_update_function(self._update_function)  # type: ignore
 
         super(EventsStream, self).__init__(hs)
 
-    async def update_function(self, from_token, current_token, limit=None):
+    async def _update_function(self, from_token, current_token, limit=None):
         event_rows = await self._store.get_all_new_forward_event_rows(
             from_token, current_token, limit
         )
diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py
index 67e0eaa262..48c1d45718 100644
--- a/synapse/replication/tcp/streams/federation.py
+++ b/synapse/replication/tcp/streams/federation.py
@@ -17,7 +17,7 @@ from collections import namedtuple
 
 from twisted.internet import defer
 
-from synapse.replication.tcp.streams._base import Stream
+from synapse.replication.tcp.streams._base import Stream, db_query_to_update_function
 
 
 class FederationStream(Stream):
@@ -44,9 +44,9 @@ class FederationStream(Stream):
         if hs.config.worker_app is None or hs.should_send_federation():
             federation_sender = hs.get_federation_sender()
             self.current_token = federation_sender.get_current_token  # type: ignore
-            self.update_function = federation_sender.get_replication_rows  # type: ignore
+            self.update_function = db_query_to_update_function(federation_sender.get_replication_rows)  # type: ignore
         else:
             self.current_token = lambda: 0  # type: ignore
-            self.update_function = lambda *args, **kwargs: defer.succeed([])  # type: ignore
+            self.update_function = lambda from_token, upto_token, limit: defer.succeed(([], upto_token, bool))  # type: ignore
 
         super(FederationStream, self).__init__(hs)