summary refs log tree commit diff
path: root/synapse/replication/tcp/streams/events.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication/tcp/streams/events.py')
-rw-r--r--synapse/replication/tcp/streams/events.py23
1 files changed, 15 insertions, 8 deletions
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index 4f4f1ad453..50c4a5ba03 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -13,12 +13,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import heapq
-from collections.abc import Iterable
-from typing import TYPE_CHECKING, Optional, Tuple, Type
+from typing import TYPE_CHECKING, Iterable, Optional, Tuple, Type, TypeVar, cast
 
 import attr
 
-from ._base import Stream, StreamUpdateResult, Token
+from synapse.replication.tcp.streams._base import (
+    Stream,
+    StreamRow,
+    StreamUpdateResult,
+    Token,
+)
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -58,6 +62,9 @@ class EventsStreamRow:
     data: "BaseEventsStreamRow"
 
 
+T = TypeVar("T", bound="BaseEventsStreamRow")
+
+
 class BaseEventsStreamRow:
     """Base class for rows to be sent in the events stream.
 
@@ -68,7 +75,7 @@ class BaseEventsStreamRow:
     TypeId: str
 
     @classmethod
-    def from_data(cls, data):
+    def from_data(cls: Type[T], data: Iterable[Optional[str]]) -> T:
         """Parse the data from the replication stream into a row.
 
         By default we just call the constructor with the data list as arguments
@@ -221,7 +228,7 @@ class EventsStream(Stream):
         return updates, upper_limit, limited
 
     @classmethod
-    def parse_row(cls, row):
-        (typ, data) = row
-        data = TypeToRow[typ].from_data(data)
-        return EventsStreamRow(typ, data)
+    def parse_row(cls, row: StreamRow) -> "EventsStreamRow":
+        (typ, data) = cast(Tuple[str, Iterable[Optional[str]]], row)
+        event_stream_row_data = TypeToRow[typ].from_data(data)
+        return EventsStreamRow(typ, event_stream_row_data)