diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 5a2d90c530..914b9eae84 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -90,7 +90,7 @@ class Stream:
ROW_TYPE: Any = None
@classmethod
- def parse_row(cls, row: StreamRow):
+ def parse_row(cls, row: StreamRow) -> Any:
"""Parse a row received over replication
By default, assumes that the row data is an array object and passes its contents
@@ -139,7 +139,7 @@ class Stream:
# The token from which we last asked for updates
self.last_token = self.current_token(self.local_instance_name)
- def discard_updates_and_advance(self):
+ def discard_updates_and_advance(self) -> None:
"""Called when the stream should advance but the updates would be discarded,
e.g. when there are no currently connected workers.
"""
@@ -200,7 +200,7 @@ def current_token_without_instance(
return lambda instance_name: current_token()
-def make_http_update_function(hs, stream_name: str) -> UpdateFunction:
+def make_http_update_function(hs: "HomeServer", stream_name: str) -> UpdateFunction:
"""Makes a suitable function for use as an `update_function` that queries
the master process for updates.
"""
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index 4f4f1ad453..50c4a5ba03 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -13,12 +13,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import heapq
-from collections.abc import Iterable
-from typing import TYPE_CHECKING, Optional, Tuple, Type
+from typing import TYPE_CHECKING, Iterable, Optional, Tuple, Type, TypeVar, cast
import attr
-from ._base import Stream, StreamUpdateResult, Token
+from synapse.replication.tcp.streams._base import (
+ Stream,
+ StreamRow,
+ StreamUpdateResult,
+ Token,
+)
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -58,6 +62,9 @@ class EventsStreamRow:
data: "BaseEventsStreamRow"
+T = TypeVar("T", bound="BaseEventsStreamRow")
+
+
class BaseEventsStreamRow:
"""Base class for rows to be sent in the events stream.
@@ -68,7 +75,7 @@ class BaseEventsStreamRow:
TypeId: str
@classmethod
- def from_data(cls, data):
+ def from_data(cls: Type[T], data: Iterable[Optional[str]]) -> T:
"""Parse the data from the replication stream into a row.
By default we just call the constructor with the data list as arguments
@@ -221,7 +228,7 @@ class EventsStream(Stream):
return updates, upper_limit, limited
@classmethod
- def parse_row(cls, row):
- (typ, data) = row
- data = TypeToRow[typ].from_data(data)
- return EventsStreamRow(typ, data)
+ def parse_row(cls, row: StreamRow) -> "EventsStreamRow":
+ (typ, data) = cast(Tuple[str, Iterable[Optional[str]]], row)
+ event_stream_row_data = TypeToRow[typ].from_data(data)
+ return EventsStreamRow(typ, event_stream_row_data)
|