summary refs log tree commit diff
path: root/synapse/replication/tcp/streams/events.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication/tcp/streams/events.py')
-rw-r--r--synapse/replication/tcp/streams/events.py27
1 files changed, 16 insertions, 11 deletions
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index ccc7ca30d8..86a62b71eb 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -15,12 +15,15 @@
 # limitations under the License.
 import heapq
 from collections.abc import Iterable
-from typing import List, Tuple, Type
+from typing import TYPE_CHECKING, List, Optional, Tuple, Type
 
 import attr
 
 from ._base import Stream, StreamUpdateResult, Token
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 """Handling of the 'events' replication stream
 
 This stream contains rows of various types. Each row therefore contains a 'type'
@@ -81,12 +84,14 @@ class BaseEventsStreamRow:
 class EventsStreamEventRow(BaseEventsStreamRow):
     TypeId = "ev"
 
-    event_id = attr.ib()  # str
-    room_id = attr.ib()  # str
-    type = attr.ib()  # str
-    state_key = attr.ib()  # str, optional
-    redacts = attr.ib()  # str, optional
-    relates_to = attr.ib()  # str, optional
+    event_id = attr.ib(type=str)
+    room_id = attr.ib(type=str)
+    type = attr.ib(type=str)
+    state_key = attr.ib(type=Optional[str])
+    redacts = attr.ib(type=Optional[str])
+    relates_to = attr.ib(type=Optional[str])
+    membership = attr.ib(type=Optional[str])
+    rejected = attr.ib(type=bool)
 
 
 @attr.s(slots=True, frozen=True)
@@ -113,7 +118,7 @@ class EventsStream(Stream):
 
     NAME = "events"
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self._store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
@@ -155,7 +160,7 @@ class EventsStream(Stream):
         # now we fetch up to that many rows from the events table
 
         event_rows = await self._store.get_all_new_forward_event_rows(
-            from_token, current_token, target_row_count
+            instance_name, from_token, current_token, target_row_count
         )  # type: List[Tuple]
 
         # we rely on get_all_new_forward_event_rows strictly honouring the limit, so
@@ -180,7 +185,7 @@ class EventsStream(Stream):
             upper_limit,
             state_rows_limited,
         ) = await self._store.get_all_updated_current_state_deltas(
-            from_token, upper_limit, target_row_count
+            instance_name, from_token, upper_limit, target_row_count
         )
 
         limited = limited or state_rows_limited
@@ -189,7 +194,7 @@ class EventsStream(Stream):
         # not to bother with the limit.
 
         ex_outliers_rows = await self._store.get_ex_outlier_stream_rows(
-            from_token, upper_limit
+            instance_name, from_token, upper_limit
         )  # type: List[Tuple]
 
         # we now need to turn the raw database rows returned into tuples suitable