diff --git a/changelog.d/7291.misc b/changelog.d/7291.misc
new file mode 100644
index 0000000000..02e7ae3fa2
--- /dev/null
+++ b/changelog.d/7291.misc
@@ -0,0 +1 @@
+Improve typing annotations in `synapse.replication.tcp.streams.Stream`.
diff --git a/synapse/replication/tcp/streams/__init__.py b/synapse/replication/tcp/streams/__init__.py
index 37bcd3de66..d1a61c3314 100644
--- a/synapse/replication/tcp/streams/__init__.py
+++ b/synapse/replication/tcp/streams/__init__.py
@@ -25,8 +25,6 @@ Each stream is defined by the following information:
update_function: The function that returns a list of updates between two tokens
"""
-from typing import Dict, Type
-
from synapse.replication.tcp.streams._base import (
AccountDataStream,
BackfillStream,
@@ -67,8 +65,7 @@ STREAMS_MAP = {
GroupServerStream,
UserSignatureStream,
)
-} # type: Dict[str, Type[Stream]]
-
+}
__all__ = [
"STREAMS_MAP",
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 0d3f050776..a860072ccf 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -16,12 +16,11 @@
import logging
from collections import namedtuple
-from typing import Any, Awaitable, Callable, List, Optional, Tuple
+from typing import Any, Awaitable, Callable, Iterable, List, Optional, Tuple
import attr
from synapse.replication.http.streams import ReplicationGetStreamUpdates
-from synapse.types import JsonDict
logger = logging.getLogger(__name__)
@@ -34,8 +33,32 @@ MAX_EVENTS_BEHIND = 500000
# A stream position token
Token = int
-# A pair of position in stream and args used to create an instance of `ROW_TYPE`.
-StreamRow = Tuple[Token, tuple]
+# The type of a stream update row, after JSON deserialisation, but before
+# parsing with Stream.parse_row (which turns it into a `ROW_TYPE`). Normally it's
+# just a row from a database query, though this is dependent on the stream in question.
+#
+StreamRow = Tuple
+
+# The type returned by the update_function of a stream, as well as get_updates(),
+# get_updates_since, etc.
+#
+# It consists of a triplet `(updates, new_last_token, limited)`, where:
+# * `updates` is a list of `(token, row)` entries.
+# * `new_last_token` is the new position in stream.
+# * `limited` is whether there are more updates to fetch.
+#
+StreamUpdateResult = Tuple[List[Tuple[Token, StreamRow]], Token, bool]
+
+# The type of an update_function for a stream
+#
+# The arguments are:
+#
+# * from_token: the previous stream token: the starting point for fetching the
+# updates
+# * to_token: the new stream token: the point to get updates up to
+# * limit: the maximum number of rows to return
+#
+UpdateFunction = Callable[[Token, Token, int], Awaitable[StreamUpdateResult]]
class Stream(object):
@@ -50,7 +73,7 @@ class Stream(object):
ROW_TYPE = None # type: Any
@classmethod
- def parse_row(cls, row):
+ def parse_row(cls, row: StreamRow):
"""Parse a row received over replication
By default, assumes that the row data is an array object and passes its contents
@@ -64,7 +87,28 @@ class Stream(object):
"""
return cls.ROW_TYPE(*row)
- def __init__(self, hs):
+ def __init__(
+ self,
+ current_token_function: Callable[[], Token],
+ update_function: UpdateFunction,
+ ):
+ """Instantiate a Stream
+
+ current_token_function and update_function are callbacks which should be
+ implemented by subclasses.
+
+ current_token_function is called to get the current token of the underlying
+ stream.
+
+ update_function is called to get updates for this stream between a pair of
+ stream tokens. See the UpdateFunction type definition for more info.
+
+ Args:
+ current_token_function: callback to get the current token, as above
+ update_function: callback go get stream updates, as above
+ """
+ self.current_token = current_token_function
+ self.update_function = update_function
# The token from which we last asked for updates
self.last_token = self.current_token()
@@ -75,7 +119,7 @@ class Stream(object):
"""
self.last_token = self.current_token()
- async def get_updates(self) -> Tuple[List[Tuple[Token, JsonDict]], Token, bool]:
+ async def get_updates(self) -> StreamUpdateResult:
"""Gets all updates since the last time this function was called (or
since the stream was constructed if it hadn't been called before).
@@ -95,7 +139,7 @@ class Stream(object):
async def get_updates_since(
self, from_token: Token, upto_token: Token, limit: int = 100
- ) -> Tuple[List[Tuple[Token, JsonDict]], Token, bool]:
+ ) -> StreamUpdateResult:
"""Like get_updates except allows specifying from when we should
stream updates
@@ -112,33 +156,14 @@ class Stream(object):
return [], upto_token, False
updates, upto_token, limited = await self.update_function(
- from_token, upto_token, limit=limit,
+ from_token, upto_token, limit,
)
return updates, upto_token, limited
- def current_token(self):
- """Gets the current token of the underlying streams. Should be provided
- by the sub classes
-
- Returns:
- int
- """
- raise NotImplementedError()
-
- def update_function(self, from_token, current_token, limit):
- """Get updates between from_token and to_token.
-
- Returns:
- Deferred(list(tuple)): the first entry in the tuple is the token for
- that update, and the rest of the tuple gets used to construct
- a ``ROW_TYPE`` instance
- """
- raise NotImplementedError()
-
def db_query_to_update_function(
- query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]]
-) -> Callable[[Token, Token, int], Awaitable[Tuple[List[StreamRow], Token, bool]]]:
+ query_function: Callable[[Token, Token, int], Awaitable[Iterable[tuple]]]
+) -> UpdateFunction:
"""Wraps a db query function which returns a list of rows to make it
suitable for use as an `update_function` for the Stream class
"""
@@ -157,9 +182,7 @@ def db_query_to_update_function(
return update_function
-def make_http_update_function(
- hs, stream_name: str
-) -> Callable[[Token, Token, Token], Awaitable[Tuple[List[StreamRow], Token, bool]]]:
+def make_http_update_function(hs, stream_name: str) -> UpdateFunction:
"""Makes a suitable function for use as an `update_function` that queries
the master process for updates.
"""
@@ -168,7 +191,7 @@ def make_http_update_function(
async def update_function(
from_token: int, upto_token: int, limit: int
- ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
+ ) -> StreamUpdateResult:
result = await client(
stream_name=stream_name,
from_token=from_token,
@@ -202,10 +225,10 @@ class BackfillStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
- self.current_token = store.get_current_backfill_token # type: ignore
- self.update_function = db_query_to_update_function(store.get_all_new_backfill_event_rows) # type: ignore
-
- super(BackfillStream, self).__init__(hs)
+ super().__init__(
+ store.get_current_backfill_token,
+ db_query_to_update_function(store.get_all_new_backfill_event_rows),
+ )
class PresenceStream(Stream):
@@ -227,19 +250,18 @@ class PresenceStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
- presence_handler = hs.get_presence_handler()
-
- self._is_worker = hs.config.worker_app is not None
-
- self.current_token = store.get_current_presence_token # type: ignore
if hs.config.worker_app is None:
- self.update_function = db_query_to_update_function(presence_handler.get_all_presence_updates) # type: ignore
+ # on the master, query the presence handler
+ presence_handler = hs.get_presence_handler()
+ update_function = db_query_to_update_function(
+ presence_handler.get_all_presence_updates
+ )
else:
# Query master process
- self.update_function = make_http_update_function(hs, self.NAME) # type: ignore
+ update_function = make_http_update_function(hs, self.NAME)
- super(PresenceStream, self).__init__(hs)
+ super().__init__(store.get_current_presence_token, update_function)
class TypingStream(Stream):
@@ -253,15 +275,16 @@ class TypingStream(Stream):
def __init__(self, hs):
typing_handler = hs.get_typing_handler()
- self.current_token = typing_handler.get_current_token # type: ignore
-
if hs.config.worker_app is None:
- self.update_function = db_query_to_update_function(typing_handler.get_all_typing_updates) # type: ignore
+ # on the master, query the typing handler
+ update_function = db_query_to_update_function(
+ typing_handler.get_all_typing_updates
+ )
else:
# Query master process
- self.update_function = make_http_update_function(hs, self.NAME) # type: ignore
+ update_function = make_http_update_function(hs, self.NAME)
- super(TypingStream, self).__init__(hs)
+ super().__init__(typing_handler.get_current_token, update_function)
class ReceiptsStream(Stream):
@@ -281,11 +304,10 @@ class ReceiptsStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
-
- self.current_token = store.get_max_receipt_stream_id # type: ignore
- self.update_function = db_query_to_update_function(store.get_all_updated_receipts) # type: ignore
-
- super(ReceiptsStream, self).__init__(hs)
+ super().__init__(
+ store.get_max_receipt_stream_id,
+ db_query_to_update_function(store.get_all_updated_receipts),
+ )
class PushRulesStream(Stream):
@@ -299,13 +321,15 @@ class PushRulesStream(Stream):
def __init__(self, hs):
self.store = hs.get_datastore()
- super(PushRulesStream, self).__init__(hs)
+ super(PushRulesStream, self).__init__(
+ self._current_token, self._update_function
+ )
- def current_token(self):
+ def _current_token(self) -> int:
push_rules_token, _ = self.store.get_push_rules_stream_token()
return push_rules_token
- async def update_function(self, from_token, to_token, limit):
+ async def _update_function(self, from_token: Token, to_token: Token, limit: int):
rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit)
limited = False
@@ -331,10 +355,10 @@ class PushersStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
- self.current_token = store.get_pushers_stream_token # type: ignore
- self.update_function = db_query_to_update_function(store.get_all_updated_pushers_rows) # type: ignore
-
- super(PushersStream, self).__init__(hs)
+ super().__init__(
+ store.get_pushers_stream_token,
+ db_query_to_update_function(store.get_all_updated_pushers_rows),
+ )
class CachesStream(Stream):
@@ -362,11 +386,10 @@ class CachesStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
-
- self.current_token = store.get_cache_stream_token # type: ignore
- self.update_function = db_query_to_update_function(store.get_all_updated_caches) # type: ignore
-
- super(CachesStream, self).__init__(hs)
+ super().__init__(
+ store.get_cache_stream_token,
+ db_query_to_update_function(store.get_all_updated_caches),
+ )
class PublicRoomsStream(Stream):
@@ -388,11 +411,10 @@ class PublicRoomsStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
-
- self.current_token = store.get_current_public_room_stream_id # type: ignore
- self.update_function = db_query_to_update_function(store.get_all_new_public_rooms) # type: ignore
-
- super(PublicRoomsStream, self).__init__(hs)
+ super().__init__(
+ store.get_current_public_room_stream_id,
+ db_query_to_update_function(store.get_all_new_public_rooms),
+ )
class DeviceListsStream(Stream):
@@ -409,11 +431,10 @@ class DeviceListsStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
-
- self.current_token = store.get_device_stream_token # type: ignore
- self.update_function = db_query_to_update_function(store.get_all_device_list_changes_for_remotes) # type: ignore
-
- super(DeviceListsStream, self).__init__(hs)
+ super().__init__(
+ store.get_device_stream_token,
+ db_query_to_update_function(store.get_all_device_list_changes_for_remotes),
+ )
class ToDeviceStream(Stream):
@@ -427,11 +448,10 @@ class ToDeviceStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
-
- self.current_token = store.get_to_device_stream_token # type: ignore
- self.update_function = db_query_to_update_function(store.get_all_new_device_messages) # type: ignore
-
- super(ToDeviceStream, self).__init__(hs)
+ super().__init__(
+ store.get_to_device_stream_token,
+ db_query_to_update_function(store.get_all_new_device_messages),
+ )
class TagAccountDataStream(Stream):
@@ -447,11 +467,10 @@ class TagAccountDataStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
-
- self.current_token = store.get_max_account_data_stream_id # type: ignore
- self.update_function = db_query_to_update_function(store.get_all_updated_tags) # type: ignore
-
- super(TagAccountDataStream, self).__init__(hs)
+ super().__init__(
+ store.get_max_account_data_stream_id,
+ db_query_to_update_function(store.get_all_updated_tags),
+ )
class AccountDataStream(Stream):
@@ -467,11 +486,10 @@ class AccountDataStream(Stream):
def __init__(self, hs):
self.store = hs.get_datastore()
-
- self.current_token = self.store.get_max_account_data_stream_id # type: ignore
- self.update_function = db_query_to_update_function(self._update_function) # type: ignore
-
- super(AccountDataStream, self).__init__(hs)
+ super().__init__(
+ self.store.get_max_account_data_stream_id,
+ db_query_to_update_function(self._update_function),
+ )
async def _update_function(self, from_token, to_token, limit):
global_results, room_results = await self.store.get_all_updated_account_data(
@@ -498,11 +516,10 @@ class GroupServerStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
-
- self.current_token = store.get_group_stream_token # type: ignore
- self.update_function = db_query_to_update_function(store.get_all_groups_changes) # type: ignore
-
- super(GroupServerStream, self).__init__(hs)
+ super().__init__(
+ store.get_group_stream_token,
+ db_query_to_update_function(store.get_all_groups_changes),
+ )
class UserSignatureStream(Stream):
@@ -516,8 +533,9 @@ class UserSignatureStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
-
- self.current_token = store.get_device_stream_token # type: ignore
- self.update_function = db_query_to_update_function(store.get_all_user_signature_changes_for_remotes) # type: ignore
-
- super(UserSignatureStream, self).__init__(hs)
+ super().__init__(
+ store.get_device_stream_token,
+ db_query_to_update_function(
+ store.get_all_user_signature_changes_for_remotes
+ ),
+ )
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index c6a595629f..051114596b 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -15,11 +15,11 @@
# limitations under the License.
import heapq
-from typing import Tuple, Type
+from typing import Iterable, Tuple, Type
import attr
-from ._base import Stream, db_query_to_update_function
+from ._base import Stream, Token, db_query_to_update_function
"""Handling of the 'events' replication stream
@@ -116,12 +116,14 @@ class EventsStream(Stream):
def __init__(self, hs):
self._store = hs.get_datastore()
- self.current_token = self._store.get_current_events_token # type: ignore
- self.update_function = db_query_to_update_function(self._update_function) # type: ignore
-
- super(EventsStream, self).__init__(hs)
+ super().__init__(
+ self._store.get_current_events_token,
+ db_query_to_update_function(self._update_function),
+ )
- async def _update_function(self, from_token, current_token, limit=None):
+ async def _update_function(
+ self, from_token: Token, current_token: Token, limit: int
+ ) -> Iterable[tuple]:
event_rows = await self._store.get_all_new_forward_event_rows(
from_token, current_token, limit
)
diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py
index 48c1d45718..75133d7e40 100644
--- a/synapse/replication/tcp/streams/federation.py
+++ b/synapse/replication/tcp/streams/federation.py
@@ -15,8 +15,6 @@
# limitations under the License.
from collections import namedtuple
-from twisted.internet import defer
-
from synapse.replication.tcp.streams._base import Stream, db_query_to_update_function
@@ -35,7 +33,6 @@ class FederationStream(Stream):
NAME = "federation"
ROW_TYPE = FederationStreamRow
- _QUERY_MASTER = True
def __init__(self, hs):
# Not all synapse instances will have a federation sender instance,
@@ -43,10 +40,16 @@ class FederationStream(Stream):
# so we stub the stream out when that is the case.
if hs.config.worker_app is None or hs.should_send_federation():
federation_sender = hs.get_federation_sender()
- self.current_token = federation_sender.get_current_token # type: ignore
- self.update_function = db_query_to_update_function(federation_sender.get_replication_rows) # type: ignore
+ current_token = federation_sender.get_current_token
+ update_function = db_query_to_update_function(
+ federation_sender.get_replication_rows
+ )
else:
- self.current_token = lambda: 0 # type: ignore
- self.update_function = lambda from_token, upto_token, limit: defer.succeed(([], upto_token, bool)) # type: ignore
+ current_token = lambda: 0
+ update_function = self._stub_update_function
+
+ super().__init__(current_token, update_function)
- super(FederationStream, self).__init__(hs)
+ @staticmethod
+ async def _stub_update_function(from_token, upto_token, limit):
+ return [], upto_token, False
|