summary refs log tree commit diff
path: root/synapse/replication/tcp/streams
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2020-05-01 17:19:56 +0100
committerGitHub <noreply@github.com>2020-05-01 17:19:56 +0100
commit0e719f23981b8294df66ba7f38b8c7cc99fad228 (patch)
tree42d9aa97954cdbea46b0969bceefd88d2953a623 /synapse/replication/tcp/streams
parentUse `stream.current_token()` and remove `stream_positions()` (#7172) (diff)
downloadsynapse-0e719f23981b8294df66ba7f38b8c7cc99fad228.tar.xz
Thread through instance name to replication client. (#7369)
For in memory streams when fetching updates on workers we need to query the source of the stream, which currently is hard coded to be master. This PR threads through the source instance we received via `POSITION` through to the update function in each stream, which can then be passed to the replication client for in memory streams.
Diffstat (limited to 'synapse/replication/tcp/streams')
-rw-r--r--synapse/replication/tcp/streams/_base.py50
-rw-r--r--synapse/replication/tcp/streams/events.py10
-rw-r--r--synapse/replication/tcp/streams/federation.py4
3 files changed, 47 insertions, 17 deletions
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 4af1afd119..b0f87c365b 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -16,7 +16,7 @@
 
 import logging
 from collections import namedtuple
-from typing import Any, Awaitable, Callable, Iterable, List, Optional, Tuple
+from typing import Any, Awaitable, Callable, List, Optional, Tuple
 
 import attr
 
@@ -53,6 +53,7 @@ StreamUpdateResult = Tuple[List[Tuple[Token, StreamRow]], Token, bool]
 #
 # The arguments are:
 #
+#  * instance_name: the writer of the stream
 #  * from_token: the previous stream token: the starting point for fetching the
 #    updates
 #  * to_token: the new stream token: the point to get updates up to
@@ -62,7 +63,7 @@ StreamUpdateResult = Tuple[List[Tuple[Token, StreamRow]], Token, bool]
 # If there are more updates available, it should set `limited` in the result, and
 # it will be called again to get the next batch.
 #
-UpdateFunction = Callable[[Token, Token, int], Awaitable[StreamUpdateResult]]
+UpdateFunction = Callable[[str, Token, Token, int], Awaitable[StreamUpdateResult]]
 
 
 class Stream(object):
@@ -93,6 +94,7 @@ class Stream(object):
 
     def __init__(
         self,
+        local_instance_name: str,
         current_token_function: Callable[[], Token],
         update_function: UpdateFunction,
     ):
@@ -108,9 +110,11 @@ class Stream(object):
         stream tokens. See the UpdateFunction type definition for more info.
 
         Args:
+            local_instance_name: The instance name of the current process
             current_token_function: callback to get the current token, as above
             update_function: callback go get stream updates, as above
         """
+        self.local_instance_name = local_instance_name
         self.current_token = current_token_function
         self.update_function = update_function
 
@@ -135,14 +139,14 @@ class Stream(object):
         """
         current_token = self.current_token()
         updates, current_token, limited = await self.get_updates_since(
-            self.last_token, current_token
+            self.local_instance_name, self.last_token, current_token
         )
         self.last_token = current_token
 
         return updates, current_token, limited
 
     async def get_updates_since(
-        self, from_token: Token, upto_token: Token
+        self, instance_name: str, from_token: Token, upto_token: Token
     ) -> StreamUpdateResult:
         """Like get_updates except allows specifying from when we should
         stream updates
@@ -160,19 +164,19 @@ class Stream(object):
             return [], upto_token, False
 
         updates, upto_token, limited = await self.update_function(
-            from_token, upto_token, _STREAM_UPDATE_TARGET_ROW_COUNT,
+            instance_name, from_token, upto_token, _STREAM_UPDATE_TARGET_ROW_COUNT,
         )
         return updates, upto_token, limited
 
 
 def db_query_to_update_function(
-    query_function: Callable[[Token, Token, int], Awaitable[Iterable[tuple]]]
+    query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]]
 ) -> UpdateFunction:
     """Wraps a db query function which returns a list of rows to make it
     suitable for use as an `update_function` for the Stream class
     """
 
-    async def update_function(from_token, upto_token, limit):
+    async def update_function(instance_name, from_token, upto_token, limit):
         rows = await query_function(from_token, upto_token, limit)
         updates = [(row[0], row[1:]) for row in rows]
         limited = False
@@ -193,10 +197,13 @@ def make_http_update_function(hs, stream_name: str) -> UpdateFunction:
     client = ReplicationGetStreamUpdates.make_client(hs)
 
     async def update_function(
-        from_token: int, upto_token: int, limit: int
+        instance_name: str, from_token: int, upto_token: int, limit: int
     ) -> StreamUpdateResult:
         result = await client(
-            stream_name=stream_name, from_token=from_token, upto_token=upto_token,
+            instance_name=instance_name,
+            stream_name=stream_name,
+            from_token=from_token,
+            upto_token=upto_token,
         )
         return result["updates"], result["upto_token"], result["limited"]
 
@@ -226,6 +233,7 @@ class BackfillStream(Stream):
     def __init__(self, hs):
         store = hs.get_datastore()
         super().__init__(
+            hs.get_instance_name(),
             store.get_current_backfill_token,
             db_query_to_update_function(store.get_all_new_backfill_event_rows),
         )
@@ -261,7 +269,9 @@ class PresenceStream(Stream):
             # Query master process
             update_function = make_http_update_function(hs, self.NAME)
 
-        super().__init__(store.get_current_presence_token, update_function)
+        super().__init__(
+            hs.get_instance_name(), store.get_current_presence_token, update_function
+        )
 
 
 class TypingStream(Stream):
@@ -284,7 +294,9 @@ class TypingStream(Stream):
             # Query master process
             update_function = make_http_update_function(hs, self.NAME)
 
-        super().__init__(typing_handler.get_current_token, update_function)
+        super().__init__(
+            hs.get_instance_name(), typing_handler.get_current_token, update_function
+        )
 
 
 class ReceiptsStream(Stream):
@@ -305,6 +317,7 @@ class ReceiptsStream(Stream):
     def __init__(self, hs):
         store = hs.get_datastore()
         super().__init__(
+            hs.get_instance_name(),
             store.get_max_receipt_stream_id,
             db_query_to_update_function(store.get_all_updated_receipts),
         )
@@ -322,14 +335,16 @@ class PushRulesStream(Stream):
     def __init__(self, hs):
         self.store = hs.get_datastore()
         super(PushRulesStream, self).__init__(
-            self._current_token, self._update_function
+            hs.get_instance_name(), self._current_token, self._update_function
         )
 
     def _current_token(self) -> int:
         push_rules_token, _ = self.store.get_push_rules_stream_token()
         return push_rules_token
 
-    async def _update_function(self, from_token: Token, to_token: Token, limit: int):
+    async def _update_function(
+        self, instance_name: str, from_token: Token, to_token: Token, limit: int
+    ):
         rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit)
 
         limited = False
@@ -356,6 +371,7 @@ class PushersStream(Stream):
         store = hs.get_datastore()
 
         super().__init__(
+            hs.get_instance_name(),
             store.get_pushers_stream_token,
             db_query_to_update_function(store.get_all_updated_pushers_rows),
         )
@@ -387,6 +403,7 @@ class CachesStream(Stream):
     def __init__(self, hs):
         store = hs.get_datastore()
         super().__init__(
+            hs.get_instance_name(),
             store.get_cache_stream_token,
             db_query_to_update_function(store.get_all_updated_caches),
         )
@@ -412,6 +429,7 @@ class PublicRoomsStream(Stream):
     def __init__(self, hs):
         store = hs.get_datastore()
         super().__init__(
+            hs.get_instance_name(),
             store.get_current_public_room_stream_id,
             db_query_to_update_function(store.get_all_new_public_rooms),
         )
@@ -432,6 +450,7 @@ class DeviceListsStream(Stream):
     def __init__(self, hs):
         store = hs.get_datastore()
         super().__init__(
+            hs.get_instance_name(),
             store.get_device_stream_token,
             db_query_to_update_function(store.get_all_device_list_changes_for_remotes),
         )
@@ -449,6 +468,7 @@ class ToDeviceStream(Stream):
     def __init__(self, hs):
         store = hs.get_datastore()
         super().__init__(
+            hs.get_instance_name(),
             store.get_to_device_stream_token,
             db_query_to_update_function(store.get_all_new_device_messages),
         )
@@ -468,6 +488,7 @@ class TagAccountDataStream(Stream):
     def __init__(self, hs):
         store = hs.get_datastore()
         super().__init__(
+            hs.get_instance_name(),
             store.get_max_account_data_stream_id,
             db_query_to_update_function(store.get_all_updated_tags),
         )
@@ -487,6 +508,7 @@ class AccountDataStream(Stream):
     def __init__(self, hs):
         self.store = hs.get_datastore()
         super().__init__(
+            hs.get_instance_name(),
             self.store.get_max_account_data_stream_id,
             db_query_to_update_function(self._update_function),
         )
@@ -517,6 +539,7 @@ class GroupServerStream(Stream):
     def __init__(self, hs):
         store = hs.get_datastore()
         super().__init__(
+            hs.get_instance_name(),
             store.get_group_stream_token,
             db_query_to_update_function(store.get_all_groups_changes),
         )
@@ -534,6 +557,7 @@ class UserSignatureStream(Stream):
     def __init__(self, hs):
         store = hs.get_datastore()
         super().__init__(
+            hs.get_instance_name(),
             store.get_device_stream_token,
             db_query_to_update_function(
                 store.get_all_user_signature_changes_for_remotes
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index 52df81b1bd..890e75d827 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -118,11 +118,17 @@ class EventsStream(Stream):
     def __init__(self, hs):
         self._store = hs.get_datastore()
         super().__init__(
-            self._store.get_current_events_token, self._update_function,
+            hs.get_instance_name(),
+            self._store.get_current_events_token,
+            self._update_function,
         )
 
     async def _update_function(
-        self, from_token: Token, current_token: Token, target_row_count: int
+        self,
+        instance_name: str,
+        from_token: Token,
+        current_token: Token,
+        target_row_count: int,
     ) -> StreamUpdateResult:
 
         # the events stream merges together three separate sources:
diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py
index 75133d7e40..e8bd52e389 100644
--- a/synapse/replication/tcp/streams/federation.py
+++ b/synapse/replication/tcp/streams/federation.py
@@ -48,8 +48,8 @@ class FederationStream(Stream):
             current_token = lambda: 0
             update_function = self._stub_update_function
 
-        super().__init__(current_token, update_function)
+        super().__init__(hs.get_instance_name(), current_token, update_function)
 
     @staticmethod
-    async def _stub_update_function(from_token, upto_token, limit):
+    async def _stub_update_function(instance_name, from_token, upto_token, limit):
         return [], upto_token, False