summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/7291.misc1
-rw-r--r--synapse/replication/tcp/streams/__init__.py5
-rw-r--r--synapse/replication/tcp/streams/_base.py224
-rw-r--r--synapse/replication/tcp/streams/events.py16
-rw-r--r--synapse/replication/tcp/streams/federation.py19
5 files changed, 143 insertions, 122 deletions
diff --git a/changelog.d/7291.misc b/changelog.d/7291.misc
new file mode 100644
index 0000000000..02e7ae3fa2
--- /dev/null
+++ b/changelog.d/7291.misc
@@ -0,0 +1 @@
+Improve typing annotations in `synapse.replication.tcp.streams.Stream`.
diff --git a/synapse/replication/tcp/streams/__init__.py b/synapse/replication/tcp/streams/__init__.py
index 37bcd3de66..d1a61c3314 100644
--- a/synapse/replication/tcp/streams/__init__.py
+++ b/synapse/replication/tcp/streams/__init__.py
@@ -25,8 +25,6 @@ Each stream is defined by the following information:
     update_function:    The function that returns a list of updates between two tokens
 """
 
-from typing import Dict, Type
-
 from synapse.replication.tcp.streams._base import (
     AccountDataStream,
     BackfillStream,
@@ -67,8 +65,7 @@ STREAMS_MAP = {
         GroupServerStream,
         UserSignatureStream,
     )
-}  # type: Dict[str, Type[Stream]]
-
+}
 
 __all__ = [
     "STREAMS_MAP",
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 0d3f050776..a860072ccf 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -16,12 +16,11 @@
 
 import logging
 from collections import namedtuple
-from typing import Any, Awaitable, Callable, List, Optional, Tuple
+from typing import Any, Awaitable, Callable, Iterable, List, Optional, Tuple
 
 import attr
 
 from synapse.replication.http.streams import ReplicationGetStreamUpdates
-from synapse.types import JsonDict
 
 logger = logging.getLogger(__name__)
 
@@ -34,8 +33,32 @@ MAX_EVENTS_BEHIND = 500000
 # A stream position token
 Token = int
 
-# A pair of position in stream and args used to create an instance of `ROW_TYPE`.
-StreamRow = Tuple[Token, tuple]
+# The type of a stream update row, after JSON deserialisation, but before
+# parsing with Stream.parse_row (which turns it into a `ROW_TYPE`). Normally it's
+# just a row from a database query, though this is dependent on the stream in question.
+#
+StreamRow = Tuple
+
+# The type returned by the update_function of a stream, as well as get_updates(),
+# get_updates_since, etc.
+#
+# It consists of a triplet `(updates, new_last_token, limited)`, where:
+#   * `updates` is a list of `(token, row)` entries.
+#   * `new_last_token` is the new position in stream.
+#   * `limited` is whether there are more updates to fetch.
+#
+StreamUpdateResult = Tuple[List[Tuple[Token, StreamRow]], Token, bool]
+
+# The type of an update_function for a stream
+#
+# The arguments are:
+#
+#  * from_token: the previous stream token: the starting point for fetching the
+#    updates
+#  * to_token: the new stream token: the point to get updates up to
+#  * limit: the maximum number of rows to return
+#
+UpdateFunction = Callable[[Token, Token, int], Awaitable[StreamUpdateResult]]
 
 
 class Stream(object):
@@ -50,7 +73,7 @@ class Stream(object):
     ROW_TYPE = None  # type: Any
 
     @classmethod
-    def parse_row(cls, row):
+    def parse_row(cls, row: StreamRow):
         """Parse a row received over replication
 
         By default, assumes that the row data is an array object and passes its contents
@@ -64,7 +87,28 @@ class Stream(object):
         """
         return cls.ROW_TYPE(*row)
 
-    def __init__(self, hs):
+    def __init__(
+        self,
+        current_token_function: Callable[[], Token],
+        update_function: UpdateFunction,
+    ):
+        """Instantiate a Stream
+
+        current_token_function and update_function are callbacks which should be
+        implemented by subclasses.
+
+        current_token_function is called to get the current token of the underlying
+        stream.
+
+        update_function is called to get updates for this stream between a pair of
+        stream tokens. See the UpdateFunction type definition for more info.
+
+        Args:
+            current_token_function: callback to get the current token, as above
+            update_function: callback go get stream updates, as above
+        """
+        self.current_token = current_token_function
+        self.update_function = update_function
 
         # The token from which we last asked for updates
         self.last_token = self.current_token()
@@ -75,7 +119,7 @@ class Stream(object):
         """
         self.last_token = self.current_token()
 
-    async def get_updates(self) -> Tuple[List[Tuple[Token, JsonDict]], Token, bool]:
+    async def get_updates(self) -> StreamUpdateResult:
         """Gets all updates since the last time this function was called (or
         since the stream was constructed if it hadn't been called before).
 
@@ -95,7 +139,7 @@ class Stream(object):
 
     async def get_updates_since(
         self, from_token: Token, upto_token: Token, limit: int = 100
-    ) -> Tuple[List[Tuple[Token, JsonDict]], Token, bool]:
+    ) -> StreamUpdateResult:
         """Like get_updates except allows specifying from when we should
         stream updates
 
@@ -112,33 +156,14 @@ class Stream(object):
             return [], upto_token, False
 
         updates, upto_token, limited = await self.update_function(
-            from_token, upto_token, limit=limit,
+            from_token, upto_token, limit,
         )
         return updates, upto_token, limited
 
-    def current_token(self):
-        """Gets the current token of the underlying streams. Should be provided
-        by the sub classes
-
-        Returns:
-            int
-        """
-        raise NotImplementedError()
-
-    def update_function(self, from_token, current_token, limit):
-        """Get updates between from_token and to_token.
-
-        Returns:
-            Deferred(list(tuple)): the first entry in the tuple is the token for
-                that update, and the rest of the tuple gets used to construct
-                a ``ROW_TYPE`` instance
-        """
-        raise NotImplementedError()
-
 
 def db_query_to_update_function(
-    query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]]
-) -> Callable[[Token, Token, int], Awaitable[Tuple[List[StreamRow], Token, bool]]]:
+    query_function: Callable[[Token, Token, int], Awaitable[Iterable[tuple]]]
+) -> UpdateFunction:
     """Wraps a db query function which returns a list of rows to make it
     suitable for use as an `update_function` for the Stream class
     """
@@ -157,9 +182,7 @@ def db_query_to_update_function(
     return update_function
 
 
-def make_http_update_function(
-    hs, stream_name: str
-) -> Callable[[Token, Token, Token], Awaitable[Tuple[List[StreamRow], Token, bool]]]:
+def make_http_update_function(hs, stream_name: str) -> UpdateFunction:
     """Makes a suitable function for use as an `update_function` that queries
     the master process for updates.
     """
@@ -168,7 +191,7 @@ def make_http_update_function(
 
     async def update_function(
         from_token: int, upto_token: int, limit: int
-    ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
+    ) -> StreamUpdateResult:
         result = await client(
             stream_name=stream_name,
             from_token=from_token,
@@ -202,10 +225,10 @@ class BackfillStream(Stream):
 
     def __init__(self, hs):
         store = hs.get_datastore()
-        self.current_token = store.get_current_backfill_token  # type: ignore
-        self.update_function = db_query_to_update_function(store.get_all_new_backfill_event_rows)  # type: ignore
-
-        super(BackfillStream, self).__init__(hs)
+        super().__init__(
+            store.get_current_backfill_token,
+            db_query_to_update_function(store.get_all_new_backfill_event_rows),
+        )
 
 
 class PresenceStream(Stream):
@@ -227,19 +250,18 @@ class PresenceStream(Stream):
 
     def __init__(self, hs):
         store = hs.get_datastore()
-        presence_handler = hs.get_presence_handler()
-
-        self._is_worker = hs.config.worker_app is not None
-
-        self.current_token = store.get_current_presence_token  # type: ignore
 
         if hs.config.worker_app is None:
-            self.update_function = db_query_to_update_function(presence_handler.get_all_presence_updates)  # type: ignore
+            # on the master, query the presence handler
+            presence_handler = hs.get_presence_handler()
+            update_function = db_query_to_update_function(
+                presence_handler.get_all_presence_updates
+            )
         else:
             # Query master process
-            self.update_function = make_http_update_function(hs, self.NAME)  # type: ignore
+            update_function = make_http_update_function(hs, self.NAME)
 
-        super(PresenceStream, self).__init__(hs)
+        super().__init__(store.get_current_presence_token, update_function)
 
 
 class TypingStream(Stream):
@@ -253,15 +275,16 @@ class TypingStream(Stream):
     def __init__(self, hs):
         typing_handler = hs.get_typing_handler()
 
-        self.current_token = typing_handler.get_current_token  # type: ignore
-
         if hs.config.worker_app is None:
-            self.update_function = db_query_to_update_function(typing_handler.get_all_typing_updates)  # type: ignore
+            # on the master, query the typing handler
+            update_function = db_query_to_update_function(
+                typing_handler.get_all_typing_updates
+            )
         else:
             # Query master process
-            self.update_function = make_http_update_function(hs, self.NAME)  # type: ignore
+            update_function = make_http_update_function(hs, self.NAME)
 
-        super(TypingStream, self).__init__(hs)
+        super().__init__(typing_handler.get_current_token, update_function)
 
 
 class ReceiptsStream(Stream):
@@ -281,11 +304,10 @@ class ReceiptsStream(Stream):
 
     def __init__(self, hs):
         store = hs.get_datastore()
-
-        self.current_token = store.get_max_receipt_stream_id  # type: ignore
-        self.update_function = db_query_to_update_function(store.get_all_updated_receipts)  # type: ignore
-
-        super(ReceiptsStream, self).__init__(hs)
+        super().__init__(
+            store.get_max_receipt_stream_id,
+            db_query_to_update_function(store.get_all_updated_receipts),
+        )
 
 
 class PushRulesStream(Stream):
@@ -299,13 +321,15 @@ class PushRulesStream(Stream):
 
     def __init__(self, hs):
         self.store = hs.get_datastore()
-        super(PushRulesStream, self).__init__(hs)
+        super(PushRulesStream, self).__init__(
+            self._current_token, self._update_function
+        )
 
-    def current_token(self):
+    def _current_token(self) -> int:
         push_rules_token, _ = self.store.get_push_rules_stream_token()
         return push_rules_token
 
-    async def update_function(self, from_token, to_token, limit):
+    async def _update_function(self, from_token: Token, to_token: Token, limit: int):
         rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit)
 
         limited = False
@@ -331,10 +355,10 @@ class PushersStream(Stream):
     def __init__(self, hs):
         store = hs.get_datastore()
 
-        self.current_token = store.get_pushers_stream_token  # type: ignore
-        self.update_function = db_query_to_update_function(store.get_all_updated_pushers_rows)  # type: ignore
-
-        super(PushersStream, self).__init__(hs)
+        super().__init__(
+            store.get_pushers_stream_token,
+            db_query_to_update_function(store.get_all_updated_pushers_rows),
+        )
 
 
 class CachesStream(Stream):
@@ -362,11 +386,10 @@ class CachesStream(Stream):
 
     def __init__(self, hs):
         store = hs.get_datastore()
-
-        self.current_token = store.get_cache_stream_token  # type: ignore
-        self.update_function = db_query_to_update_function(store.get_all_updated_caches)  # type: ignore
-
-        super(CachesStream, self).__init__(hs)
+        super().__init__(
+            store.get_cache_stream_token,
+            db_query_to_update_function(store.get_all_updated_caches),
+        )
 
 
 class PublicRoomsStream(Stream):
@@ -388,11 +411,10 @@ class PublicRoomsStream(Stream):
 
     def __init__(self, hs):
         store = hs.get_datastore()
-
-        self.current_token = store.get_current_public_room_stream_id  # type: ignore
-        self.update_function = db_query_to_update_function(store.get_all_new_public_rooms)  # type: ignore
-
-        super(PublicRoomsStream, self).__init__(hs)
+        super().__init__(
+            store.get_current_public_room_stream_id,
+            db_query_to_update_function(store.get_all_new_public_rooms),
+        )
 
 
 class DeviceListsStream(Stream):
@@ -409,11 +431,10 @@ class DeviceListsStream(Stream):
 
     def __init__(self, hs):
         store = hs.get_datastore()
-
-        self.current_token = store.get_device_stream_token  # type: ignore
-        self.update_function = db_query_to_update_function(store.get_all_device_list_changes_for_remotes)  # type: ignore
-
-        super(DeviceListsStream, self).__init__(hs)
+        super().__init__(
+            store.get_device_stream_token,
+            db_query_to_update_function(store.get_all_device_list_changes_for_remotes),
+        )
 
 
 class ToDeviceStream(Stream):
@@ -427,11 +448,10 @@ class ToDeviceStream(Stream):
 
     def __init__(self, hs):
         store = hs.get_datastore()
-
-        self.current_token = store.get_to_device_stream_token  # type: ignore
-        self.update_function = db_query_to_update_function(store.get_all_new_device_messages)  # type: ignore
-
-        super(ToDeviceStream, self).__init__(hs)
+        super().__init__(
+            store.get_to_device_stream_token,
+            db_query_to_update_function(store.get_all_new_device_messages),
+        )
 
 
 class TagAccountDataStream(Stream):
@@ -447,11 +467,10 @@ class TagAccountDataStream(Stream):
 
     def __init__(self, hs):
         store = hs.get_datastore()
-
-        self.current_token = store.get_max_account_data_stream_id  # type: ignore
-        self.update_function = db_query_to_update_function(store.get_all_updated_tags)  # type: ignore
-
-        super(TagAccountDataStream, self).__init__(hs)
+        super().__init__(
+            store.get_max_account_data_stream_id,
+            db_query_to_update_function(store.get_all_updated_tags),
+        )
 
 
 class AccountDataStream(Stream):
@@ -467,11 +486,10 @@ class AccountDataStream(Stream):
 
     def __init__(self, hs):
         self.store = hs.get_datastore()
-
-        self.current_token = self.store.get_max_account_data_stream_id  # type: ignore
-        self.update_function = db_query_to_update_function(self._update_function)  # type: ignore
-
-        super(AccountDataStream, self).__init__(hs)
+        super().__init__(
+            self.store.get_max_account_data_stream_id,
+            db_query_to_update_function(self._update_function),
+        )
 
     async def _update_function(self, from_token, to_token, limit):
         global_results, room_results = await self.store.get_all_updated_account_data(
@@ -498,11 +516,10 @@ class GroupServerStream(Stream):
 
     def __init__(self, hs):
         store = hs.get_datastore()
-
-        self.current_token = store.get_group_stream_token  # type: ignore
-        self.update_function = db_query_to_update_function(store.get_all_groups_changes)  # type: ignore
-
-        super(GroupServerStream, self).__init__(hs)
+        super().__init__(
+            store.get_group_stream_token,
+            db_query_to_update_function(store.get_all_groups_changes),
+        )
 
 
 class UserSignatureStream(Stream):
@@ -516,8 +533,9 @@ class UserSignatureStream(Stream):
 
     def __init__(self, hs):
         store = hs.get_datastore()
-
-        self.current_token = store.get_device_stream_token  # type: ignore
-        self.update_function = db_query_to_update_function(store.get_all_user_signature_changes_for_remotes)  # type: ignore
-
-        super(UserSignatureStream, self).__init__(hs)
+        super().__init__(
+            store.get_device_stream_token,
+            db_query_to_update_function(
+                store.get_all_user_signature_changes_for_remotes
+            ),
+        )
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index c6a595629f..051114596b 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -15,11 +15,11 @@
 # limitations under the License.
 
 import heapq
-from typing import Tuple, Type
+from typing import Iterable, Tuple, Type
 
 import attr
 
-from ._base import Stream, db_query_to_update_function
+from ._base import Stream, Token, db_query_to_update_function
 
 
 """Handling of the 'events' replication stream
@@ -116,12 +116,14 @@ class EventsStream(Stream):
 
     def __init__(self, hs):
         self._store = hs.get_datastore()
-        self.current_token = self._store.get_current_events_token  # type: ignore
-        self.update_function = db_query_to_update_function(self._update_function)  # type: ignore
-
-        super(EventsStream, self).__init__(hs)
+        super().__init__(
+            self._store.get_current_events_token,
+            db_query_to_update_function(self._update_function),
+        )
 
-    async def _update_function(self, from_token, current_token, limit=None):
+    async def _update_function(
+        self, from_token: Token, current_token: Token, limit: int
+    ) -> Iterable[tuple]:
         event_rows = await self._store.get_all_new_forward_event_rows(
             from_token, current_token, limit
         )
diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py
index 48c1d45718..75133d7e40 100644
--- a/synapse/replication/tcp/streams/federation.py
+++ b/synapse/replication/tcp/streams/federation.py
@@ -15,8 +15,6 @@
 # limitations under the License.
 from collections import namedtuple
 
-from twisted.internet import defer
-
 from synapse.replication.tcp.streams._base import Stream, db_query_to_update_function
 
 
@@ -35,7 +33,6 @@ class FederationStream(Stream):
 
     NAME = "federation"
     ROW_TYPE = FederationStreamRow
-    _QUERY_MASTER = True
 
     def __init__(self, hs):
         # Not all synapse instances will have a federation sender instance,
@@ -43,10 +40,16 @@ class FederationStream(Stream):
         # so we stub the stream out when that is the case.
         if hs.config.worker_app is None or hs.should_send_federation():
             federation_sender = hs.get_federation_sender()
-            self.current_token = federation_sender.get_current_token  # type: ignore
-            self.update_function = db_query_to_update_function(federation_sender.get_replication_rows)  # type: ignore
+            current_token = federation_sender.get_current_token
+            update_function = db_query_to_update_function(
+                federation_sender.get_replication_rows
+            )
         else:
-            self.current_token = lambda: 0  # type: ignore
-            self.update_function = lambda from_token, upto_token, limit: defer.succeed(([], upto_token, bool))  # type: ignore
+            current_token = lambda: 0
+            update_function = self._stub_update_function
+
+        super().__init__(current_token, update_function)
 
-        super(FederationStream, self).__init__(hs)
+    @staticmethod
+    async def _stub_update_function(from_token, upto_token, limit):
+        return [], upto_token, False