summary refs log tree commit diff
diff options
context:
space:
mode:
authorRichard van der Hoff <1389908+richvdh@users.noreply.github.com>2020-04-17 14:49:55 +0100
committerGitHub <noreply@github.com>2020-04-17 14:49:55 +0100
commit67ff7b8ba0d3647f3c370341dff3f035b3a1160a (patch)
tree3a4bed20db43138a514ea1ff06d7d0bf1f584891
parentClarify the comments for media_storage_providers options (#7272) (diff)
downloadsynapse-67ff7b8ba0d3647f3c370341dff3f035b3a1160a.tar.xz
Improve type checking in `replication.tcp.Stream` (#7291)
The general idea here is to get rid of the type: ignore annotations on all of the current_token and update_function assignments, which would have caught #7290.

After a bit of experimentation, it seems like the least-awful way to do this is to pass the offending functions in as parameters to the Stream constructor. Unfortunately that means that the concrete implementations no longer have the same constructor signature as Stream itself, which means that it gets hard to correctly annotate STREAMS_MAP.

I've also introduced a couple of new types, to take out some duplication.
Diffstat (limited to '')
-rw-r--r--changelog.d/7291.misc1
-rw-r--r--synapse/replication/tcp/streams/__init__.py5
-rw-r--r--synapse/replication/tcp/streams/_base.py224
-rw-r--r--synapse/replication/tcp/streams/events.py16
-rw-r--r--synapse/replication/tcp/streams/federation.py19
5 files changed, 143 insertions, 122 deletions
diff --git a/changelog.d/7291.misc b/changelog.d/7291.misc
new file mode 100644
index 0000000000..02e7ae3fa2
--- /dev/null
+++ b/changelog.d/7291.misc
@@ -0,0 +1 @@
+Improve typing annotations in `synapse.replication.tcp.streams.Stream`.
diff --git a/synapse/replication/tcp/streams/__init__.py b/synapse/replication/tcp/streams/__init__.py
index 37bcd3de66..d1a61c3314 100644
--- a/synapse/replication/tcp/streams/__init__.py
+++ b/synapse/replication/tcp/streams/__init__.py
@@ -25,8 +25,6 @@ Each stream is defined by the following information:
     update_function:    The function that returns a list of updates between two tokens
 """
 
-from typing import Dict, Type
-
 from synapse.replication.tcp.streams._base import (
     AccountDataStream,
     BackfillStream,
@@ -67,8 +65,7 @@ STREAMS_MAP = {
         GroupServerStream,
         UserSignatureStream,
     )
-}  # type: Dict[str, Type[Stream]]
-
+}
 
 __all__ = [
     "STREAMS_MAP",
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 0d3f050776..a860072ccf 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -16,12 +16,11 @@
 
 import logging
 from collections import namedtuple
-from typing import Any, Awaitable, Callable, List, Optional, Tuple
+from typing import Any, Awaitable, Callable, Iterable, List, Optional, Tuple
 
 import attr
 
 from synapse.replication.http.streams import ReplicationGetStreamUpdates
-from synapse.types import JsonDict
 
 logger = logging.getLogger(__name__)
 
@@ -34,8 +33,32 @@ MAX_EVENTS_BEHIND = 500000
 # A stream position token
 Token = int
 
-# A pair of position in stream and args used to create an instance of `ROW_TYPE`.
-StreamRow = Tuple[Token, tuple]
+# The type of a stream update row, after JSON deserialisation, but before
+# parsing with Stream.parse_row (which turns it into a `ROW_TYPE`). Normally it's
+# just a row from a database query, though this is dependent on the stream in question.
+#
+StreamRow = Tuple
+
+# The type returned by the update_function of a stream, as well as get_updates(),
+# get_updates_since, etc.
+#
+# It consists of a triplet `(updates, new_last_token, limited)`, where:
+#   * `updates` is a list of `(token, row)` entries.
+#   * `new_last_token` is the new position in stream.
+#   * `limited` is whether there are more updates to fetch.
+#
+StreamUpdateResult = Tuple[List[Tuple[Token, StreamRow]], Token, bool]
+
+# The type of an update_function for a stream
+#
+# The arguments are:
+#
+#  * from_token: the previous stream token: the starting point for fetching the
+#    updates
+#  * to_token: the new stream token: the point to get updates up to
+#  * limit: the maximum number of rows to return
+#
+UpdateFunction = Callable[[Token, Token, int], Awaitable[StreamUpdateResult]]
 
 
 class Stream(object):
@@ -50,7 +73,7 @@ class Stream(object):
     ROW_TYPE = None  # type: Any
 
     @classmethod
-    def parse_row(cls, row):
+    def parse_row(cls, row: StreamRow):
         """Parse a row received over replication
 
         By default, assumes that the row data is an array object and passes its contents
@@ -64,7 +87,28 @@ class Stream(object):
         """
         return cls.ROW_TYPE(*row)
 
-    def __init__(self, hs):
+    def __init__(
+        self,
+        current_token_function: Callable[[], Token],
+        update_function: UpdateFunction,
+    ):
+        """Instantiate a Stream
+
+        current_token_function and update_function are callbacks which should be
+        implemented by subclasses.
+
+        current_token_function is called to get the current token of the underlying
+        stream.
+
+        update_function is called to get updates for this stream between a pair of
+        stream tokens. See the UpdateFunction type definition for more info.
+
+        Args:
+            current_token_function: callback to get the current token, as above
+            update_function: callback go get stream updates, as above
+        """
+        self.current_token = current_token_function
+        self.update_function = update_function
 
         # The token from which we last asked for updates
         self.last_token = self.current_token()
@@ -75,7 +119,7 @@ class Stream(object):
         """
         self.last_token = self.current_token()
 
-    async def get_updates(self) -> Tuple[List[Tuple[Token, JsonDict]], Token, bool]:
+    async def get_updates(self) -> StreamUpdateResult:
         """Gets all updates since the last time this function was called (or
         since the stream was constructed if it hadn't been called before).
 
@@ -95,7 +139,7 @@ class Stream(object):
 
     async def get_updates_since(
         self, from_token: Token, upto_token: Token, limit: int = 100
-    ) -> Tuple[List[Tuple[Token, JsonDict]], Token, bool]:
+    ) -> StreamUpdateResult:
         """Like get_updates except allows specifying from when we should
         stream updates
 
@@ -112,33 +156,14 @@ class Stream(object):
             return [], upto_token, False
 
         updates, upto_token, limited = await self.update_function(
-            from_token, upto_token, limit=limit,
+            from_token, upto_token, limit,
         )
         return updates, upto_token, limited
 
-    def current_token(self):
-        """Gets the current token of the underlying streams. Should be provided
-        by the sub classes
-
-        Returns:
-            int
-        """
-        raise NotImplementedError()
-
-    def update_function(self, from_token, current_token, limit):
-        """Get updates between from_token and to_token.
-
-        Returns:
-            Deferred(list(tuple)): the first entry in the tuple is the token for
-                that update, and the rest of the tuple gets used to construct
-                a ``ROW_TYPE`` instance
-        """
-        raise NotImplementedError()
-
 
 def db_query_to_update_function(
-    query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]]
-) -> Callable[[Token, Token, int], Awaitable[Tuple[List[StreamRow], Token, bool]]]:
+    query_function: Callable[[Token, Token, int], Awaitable[Iterable[tuple]]]
+) -> UpdateFunction:
     """Wraps a db query function which returns a list of rows to make it
     suitable for use as an `update_function` for the Stream class
     """
@@ -157,9 +182,7 @@ def db_query_to_update_function(
     return update_function
 
 
-def make_http_update_function(
-    hs, stream_name: str
-) -> Callable[[Token, Token, Token], Awaitable[Tuple[List[StreamRow], Token, bool]]]:
+def make_http_update_function(hs, stream_name: str) -> UpdateFunction:
     """Makes a suitable function for use as an `update_function` that queries
     the master process for updates.
     """
@@ -168,7 +191,7 @@ def make_http_update_function(
 
     async def update_function(
         from_token: int, upto_token: int, limit: int
-    ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
+    ) -> StreamUpdateResult:
         result = await client(
             stream_name=stream_name,
             from_token=from_token,
@@ -202,10 +225,10 @@ class BackfillStream(Stream):
 
     def __init__(self, hs):
         store = hs.get_datastore()
-        self.current_token = store.get_current_backfill_token  # type: ignore
-        self.update_function = db_query_to_update_function(store.get_all_new_backfill_event_rows)  # type: ignore
-
-        super(BackfillStream, self).__init__(hs)
+        super().__init__(
+            store.get_current_backfill_token,
+            db_query_to_update_function(store.get_all_new_backfill_event_rows),
+        )
 
 
 class PresenceStream(Stream):
@@ -227,19 +250,18 @@ class PresenceStream(Stream):
 
     def __init__(self, hs):
         store = hs.get_datastore()
-        presence_handler = hs.get_presence_handler()
-
-        self._is_worker = hs.config.worker_app is not None
-
-        self.current_token = store.get_current_presence_token  # type: ignore
 
         if hs.config.worker_app is None:
-            self.update_function = db_query_to_update_function(presence_handler.get_all_presence_updates)  # type: ignore
+            # on the master, query the presence handler
+            presence_handler = hs.get_presence_handler()
+            update_function = db_query_to_update_function(
+                presence_handler.get_all_presence_updates
+            )
         else:
             # Query master process
-            self.update_function = make_http_update_function(hs, self.NAME)  # type: ignore
+            update_function = make_http_update_function(hs, self.NAME)
 
-        super(PresenceStream, self).__init__(hs)
+        super().__init__(store.get_current_presence_token, update_function)
 
 
 class TypingStream(Stream):
@@ -253,15 +275,16 @@ class TypingStream(Stream):
     def __init__(self, hs):
         typing_handler = hs.get_typing_handler()
 
-        self.current_token = typing_handler.get_current_token  # type: ignore
-
         if hs.config.worker_app is None:
-            self.update_function = db_query_to_update_function(typing_handler.get_all_typing_updates)  # type: ignore
+            # on the master, query the typing handler
+            update_function = db_query_to_update_function(
+                typing_handler.get_all_typing_updates
+            )
         else:
             # Query master process
-            self.update_function = make_http_update_function(hs, self.NAME)  # type: ignore
+            update_function = make_http_update_function(hs, self.NAME)
 
-        super(TypingStream, self).__init__(hs)
+        super().__init__(typing_handler.get_current_token, update_function)
 
 
 class ReceiptsStream(Stream):
@@ -281,11 +304,10 @@ class ReceiptsStream(Stream):
 
     def __init__(self, hs):
         store = hs.get_datastore()
-
-        self.current_token = store.get_max_receipt_stream_id  # type: ignore
-        self.update_function = db_query_to_update_function(store.get_all_updated_receipts)  # type: ignore
-
-        super(ReceiptsStream, self).__init__(hs)
+        super().__init__(
+            store.get_max_receipt_stream_id,
+            db_query_to_update_function(store.get_all_updated_receipts),
+        )
 
 
 class PushRulesStream(Stream):
@@ -299,13 +321,15 @@ class PushRulesStream(Stream):
 
     def __init__(self, hs):
         self.store = hs.get_datastore()
-        super(PushRulesStream, self).__init__(hs)
+        super(PushRulesStream, self).__init__(
+            self._current_token, self._update_function
+        )
 
-    def current_token(self):
+    def _current_token(self) -> int:
         push_rules_token, _ = self.store.get_push_rules_stream_token()
         return push_rules_token
 
-    async def update_function(self, from_token, to_token, limit):
+    async def _update_function(self, from_token: Token, to_token: Token, limit: int):
         rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit)
 
         limited = False
@@ -331,10 +355,10 @@ class PushersStream(Stream):
     def __init__(self, hs):
         store = hs.get_datastore()
 
-        self.current_token = store.get_pushers_stream_token  # type: ignore
-        self.update_function = db_query_to_update_function(store.get_all_updated_pushers_rows)  # type: ignore
-
-        super(PushersStream, self).__init__(hs)
+        super().__init__(
+            store.get_pushers_stream_token,
+            db_query_to_update_function(store.get_all_updated_pushers_rows),
+        )
 
 
 class CachesStream(Stream):
@@ -362,11 +386,10 @@ class CachesStream(Stream):
 
     def __init__(self, hs):
         store = hs.get_datastore()
-
-        self.current_token = store.get_cache_stream_token  # type: ignore
-        self.update_function = db_query_to_update_function(store.get_all_updated_caches)  # type: ignore
-
-        super(CachesStream, self).__init__(hs)
+        super().__init__(
+            store.get_cache_stream_token,
+            db_query_to_update_function(store.get_all_updated_caches),
+        )
 
 
 class PublicRoomsStream(Stream):
@@ -388,11 +411,10 @@ class PublicRoomsStream(Stream):
 
     def __init__(self, hs):
         store = hs.get_datastore()
-
-        self.current_token = store.get_current_public_room_stream_id  # type: ignore
-        self.update_function = db_query_to_update_function(store.get_all_new_public_rooms)  # type: ignore
-
-        super(PublicRoomsStream, self).__init__(hs)
+        super().__init__(
+            store.get_current_public_room_stream_id,
+            db_query_to_update_function(store.get_all_new_public_rooms),
+        )
 
 
 class DeviceListsStream(Stream):
@@ -409,11 +431,10 @@ class DeviceListsStream(Stream):
 
     def __init__(self, hs):
         store = hs.get_datastore()
-
-        self.current_token = store.get_device_stream_token  # type: ignore
-        self.update_function = db_query_to_update_function(store.get_all_device_list_changes_for_remotes)  # type: ignore
-
-        super(DeviceListsStream, self).__init__(hs)
+        super().__init__(
+            store.get_device_stream_token,
+            db_query_to_update_function(store.get_all_device_list_changes_for_remotes),
+        )
 
 
 class ToDeviceStream(Stream):
@@ -427,11 +448,10 @@ class ToDeviceStream(Stream):
 
     def __init__(self, hs):
         store = hs.get_datastore()
-
-        self.current_token = store.get_to_device_stream_token  # type: ignore
-        self.update_function = db_query_to_update_function(store.get_all_new_device_messages)  # type: ignore
-
-        super(ToDeviceStream, self).__init__(hs)
+        super().__init__(
+            store.get_to_device_stream_token,
+            db_query_to_update_function(store.get_all_new_device_messages),
+        )
 
 
 class TagAccountDataStream(Stream):
@@ -447,11 +467,10 @@ class TagAccountDataStream(Stream):
 
     def __init__(self, hs):
         store = hs.get_datastore()
-
-        self.current_token = store.get_max_account_data_stream_id  # type: ignore
-        self.update_function = db_query_to_update_function(store.get_all_updated_tags)  # type: ignore
-
-        super(TagAccountDataStream, self).__init__(hs)
+        super().__init__(
+            store.get_max_account_data_stream_id,
+            db_query_to_update_function(store.get_all_updated_tags),
+        )
 
 
 class AccountDataStream(Stream):
@@ -467,11 +486,10 @@ class AccountDataStream(Stream):
 
     def __init__(self, hs):
         self.store = hs.get_datastore()
-
-        self.current_token = self.store.get_max_account_data_stream_id  # type: ignore
-        self.update_function = db_query_to_update_function(self._update_function)  # type: ignore
-
-        super(AccountDataStream, self).__init__(hs)
+        super().__init__(
+            self.store.get_max_account_data_stream_id,
+            db_query_to_update_function(self._update_function),
+        )
 
     async def _update_function(self, from_token, to_token, limit):
         global_results, room_results = await self.store.get_all_updated_account_data(
@@ -498,11 +516,10 @@ class GroupServerStream(Stream):
 
     def __init__(self, hs):
         store = hs.get_datastore()
-
-        self.current_token = store.get_group_stream_token  # type: ignore
-        self.update_function = db_query_to_update_function(store.get_all_groups_changes)  # type: ignore
-
-        super(GroupServerStream, self).__init__(hs)
+        super().__init__(
+            store.get_group_stream_token,
+            db_query_to_update_function(store.get_all_groups_changes),
+        )
 
 
 class UserSignatureStream(Stream):
@@ -516,8 +533,9 @@ class UserSignatureStream(Stream):
 
     def __init__(self, hs):
         store = hs.get_datastore()
-
-        self.current_token = store.get_device_stream_token  # type: ignore
-        self.update_function = db_query_to_update_function(store.get_all_user_signature_changes_for_remotes)  # type: ignore
-
-        super(UserSignatureStream, self).__init__(hs)
+        super().__init__(
+            store.get_device_stream_token,
+            db_query_to_update_function(
+                store.get_all_user_signature_changes_for_remotes
+            ),
+        )
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index c6a595629f..051114596b 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -15,11 +15,11 @@
 # limitations under the License.
 
 import heapq
-from typing import Tuple, Type
+from typing import Iterable, Tuple, Type
 
 import attr
 
-from ._base import Stream, db_query_to_update_function
+from ._base import Stream, Token, db_query_to_update_function
 
 
 """Handling of the 'events' replication stream
@@ -116,12 +116,14 @@ class EventsStream(Stream):
 
     def __init__(self, hs):
         self._store = hs.get_datastore()
-        self.current_token = self._store.get_current_events_token  # type: ignore
-        self.update_function = db_query_to_update_function(self._update_function)  # type: ignore
-
-        super(EventsStream, self).__init__(hs)
+        super().__init__(
+            self._store.get_current_events_token,
+            db_query_to_update_function(self._update_function),
+        )
 
-    async def _update_function(self, from_token, current_token, limit=None):
+    async def _update_function(
+        self, from_token: Token, current_token: Token, limit: int
+    ) -> Iterable[tuple]:
         event_rows = await self._store.get_all_new_forward_event_rows(
             from_token, current_token, limit
         )
diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py
index 48c1d45718..75133d7e40 100644
--- a/synapse/replication/tcp/streams/federation.py
+++ b/synapse/replication/tcp/streams/federation.py
@@ -15,8 +15,6 @@
 # limitations under the License.
 from collections import namedtuple
 
-from twisted.internet import defer
-
 from synapse.replication.tcp.streams._base import Stream, db_query_to_update_function
 
 
@@ -35,7 +33,6 @@ class FederationStream(Stream):
 
     NAME = "federation"
     ROW_TYPE = FederationStreamRow
-    _QUERY_MASTER = True
 
     def __init__(self, hs):
         # Not all synapse instances will have a federation sender instance,
@@ -43,10 +40,16 @@ class FederationStream(Stream):
         # so we stub the stream out when that is the case.
         if hs.config.worker_app is None or hs.should_send_federation():
             federation_sender = hs.get_federation_sender()
-            self.current_token = federation_sender.get_current_token  # type: ignore
-            self.update_function = db_query_to_update_function(federation_sender.get_replication_rows)  # type: ignore
+            current_token = federation_sender.get_current_token
+            update_function = db_query_to_update_function(
+                federation_sender.get_replication_rows
+            )
         else:
-            self.current_token = lambda: 0  # type: ignore
-            self.update_function = lambda from_token, upto_token, limit: defer.succeed(([], upto_token, bool))  # type: ignore
+            current_token = lambda: 0
+            update_function = self._stub_update_function
+
+        super().__init__(current_token, update_function)
 
-        super(FederationStream, self).__init__(hs)
+    @staticmethod
+    async def _stub_update_function(from_token, upto_token, limit):
+        return [], upto_token, False