diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index 4f4f1ad453..50c4a5ba03 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -13,12 +13,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import heapq
-from collections.abc import Iterable
-from typing import TYPE_CHECKING, Optional, Tuple, Type
+from typing import TYPE_CHECKING, Iterable, Optional, Tuple, Type, TypeVar, cast
import attr
-from ._base import Stream, StreamUpdateResult, Token
+from synapse.replication.tcp.streams._base import (
+ Stream,
+ StreamRow,
+ StreamUpdateResult,
+ Token,
+)
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -58,6 +62,9 @@ class EventsStreamRow:
data: "BaseEventsStreamRow"
+T = TypeVar("T", bound="BaseEventsStreamRow")
+
+
class BaseEventsStreamRow:
"""Base class for rows to be sent in the events stream.
@@ -68,7 +75,7 @@ class BaseEventsStreamRow:
TypeId: str
@classmethod
- def from_data(cls, data):
+ def from_data(cls: Type[T], data: Iterable[Optional[str]]) -> T:
"""Parse the data from the replication stream into a row.
By default we just call the constructor with the data list as arguments
@@ -221,7 +228,7 @@ class EventsStream(Stream):
return updates, upper_limit, limited
@classmethod
- def parse_row(cls, row):
- (typ, data) = row
- data = TypeToRow[typ].from_data(data)
- return EventsStreamRow(typ, data)
+ def parse_row(cls, row: StreamRow) -> "EventsStreamRow":
+ (typ, data) = cast(Tuple[str, Iterable[Optional[str]]], row)
+ event_stream_row_data = TypeToRow[typ].from_data(data)
+ return EventsStreamRow(typ, event_stream_row_data)
|