diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index b03824925a..3716c41bea 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -85,9 +85,9 @@ class Stream:
time it was called.
"""
- NAME = None # type: str # The name of the stream
+ NAME: str # The name of the stream
# The type of the row. Used by the default impl of parse_row.
- ROW_TYPE = None # type: Any
+ ROW_TYPE: Any = None
@classmethod
def parse_row(cls, row: StreamRow):
@@ -283,9 +283,7 @@ class PresenceStream(Stream):
assert isinstance(presence_handler, PresenceHandler)
- update_function = (
- presence_handler.get_all_presence_updates
- ) # type: UpdateFunction
+ update_function: UpdateFunction = presence_handler.get_all_presence_updates
else:
# Query presence writer process
update_function = make_http_update_function(hs, self.NAME)
@@ -334,9 +332,9 @@ class TypingStream(Stream):
if writer_instance == hs.get_instance_name():
# On the writer, query the typing handler
typing_writer_handler = hs.get_typing_writer_handler()
- update_function = (
- typing_writer_handler.get_all_typing_updates
- ) # type: Callable[[str, int, int, int], Awaitable[Tuple[List[Tuple[int, Any]], int, bool]]]
+ update_function: Callable[
+ [str, int, int, int], Awaitable[Tuple[List[Tuple[int, Any]], int, bool]]
+ ] = typing_writer_handler.get_all_typing_updates
current_token_function = typing_writer_handler.get_current_token
else:
# Query the typing writer process
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index e7e87bac92..a030e9299e 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -65,7 +65,7 @@ class BaseEventsStreamRow:
"""
# Unique string that ids the type. Must be overridden in sub classes.
- TypeId = None # type: str
+ TypeId: str
@classmethod
def from_data(cls, data):
@@ -103,10 +103,10 @@ class EventsStreamCurrentStateRow(BaseEventsStreamRow):
event_id = attr.ib() # str, optional
-_EventRows = (
+_EventRows: Tuple[Type[BaseEventsStreamRow], ...] = (
EventsStreamEventRow,
EventsStreamCurrentStateRow,
-) # type: Tuple[Type[BaseEventsStreamRow], ...]
+)
TypeToRow = {Row.TypeId: Row for Row in _EventRows}
@@ -157,9 +157,9 @@ class EventsStream(Stream):
# now we fetch up to that many rows from the events table
- event_rows = await self._store.get_all_new_forward_event_rows(
+ event_rows: List[Tuple] = await self._store.get_all_new_forward_event_rows(
instance_name, from_token, current_token, target_row_count
- ) # type: List[Tuple]
+ )
# we rely on get_all_new_forward_event_rows strictly honouring the limit, so
# that we know it is safe to just take upper_limit = event_rows[-1][0].
@@ -172,7 +172,7 @@ class EventsStream(Stream):
if len(event_rows) == target_row_count:
limited = True
- upper_limit = event_rows[-1][0] # type: int
+ upper_limit: int = event_rows[-1][0]
else:
limited = False
upper_limit = current_token
@@ -191,30 +191,30 @@ class EventsStream(Stream):
# finally, fetch the ex-outliers rows. We assume there are few enough of these
# not to bother with the limit.
- ex_outliers_rows = await self._store.get_ex_outlier_stream_rows(
+ ex_outliers_rows: List[Tuple] = await self._store.get_ex_outlier_stream_rows(
instance_name, from_token, upper_limit
- ) # type: List[Tuple]
+ )
# we now need to turn the raw database rows returned into tuples suitable
# for the replication protocol (basically, we add an identifier to
# distinguish the row type). At the same time, we can limit the event_rows
# to the max stream_id from state_rows.
- event_updates = (
+ event_updates: Iterable[Tuple[int, Tuple]] = (
(stream_id, (EventsStreamEventRow.TypeId, rest))
for (stream_id, *rest) in event_rows
if stream_id <= upper_limit
- ) # type: Iterable[Tuple[int, Tuple]]
+ )
- state_updates = (
+ state_updates: Iterable[Tuple[int, Tuple]] = (
(stream_id, (EventsStreamCurrentStateRow.TypeId, rest))
for (stream_id, *rest) in state_rows
- ) # type: Iterable[Tuple[int, Tuple]]
+ )
- ex_outliers_updates = (
+ ex_outliers_updates: Iterable[Tuple[int, Tuple]] = (
(stream_id, (EventsStreamEventRow.TypeId, rest))
for (stream_id, *rest) in ex_outliers_rows
- ) # type: Iterable[Tuple[int, Tuple]]
+ )
# we need to return a sorted list, so merge them together.
updates = list(heapq.merge(event_updates, state_updates, ex_outliers_updates))
diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py
index 096a85d363..c445af9bd9 100644
--- a/synapse/replication/tcp/streams/federation.py
+++ b/synapse/replication/tcp/streams/federation.py
@@ -51,9 +51,9 @@ class FederationStream(Stream):
current_token = current_token_without_instance(
federation_sender.get_current_token
)
- update_function = (
- federation_sender.get_replication_rows
- ) # type: Callable[[str, int, int, int], Awaitable[Tuple[List[Tuple[int, Any]], int, bool]]]
+ update_function: Callable[
+ [str, int, int, int], Awaitable[Tuple[List[Tuple[int, Any]], int, bool]]
+ ] = federation_sender.get_replication_rows
elif hs.should_send_federation():
# federation sender: Query master process
|