summary refs log tree commit diff
path: root/synapse/replication/tcp
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication/tcp')
-rw-r--r--synapse/replication/tcp/client.py10
-rw-r--r--synapse/replication/tcp/commands.py6
-rw-r--r--synapse/replication/tcp/handler.py16
-rw-r--r--synapse/replication/tcp/protocol.py14
-rw-r--r--synapse/replication/tcp/redis.py8
-rw-r--r--synapse/replication/tcp/streams/_base.py14
-rw-r--r--synapse/replication/tcp/streams/events.py28
-rw-r--r--synapse/replication/tcp/streams/federation.py6
8 files changed, 50 insertions, 52 deletions
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 62d7809175..9d4859798b 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -121,13 +121,13 @@ class ReplicationDataHandler:
         self._pusher_pool = hs.get_pusherpool()
         self._presence_handler = hs.get_presence_handler()
 
-        self.send_handler = None  # type: Optional[FederationSenderHandler]
+        self.send_handler: Optional[FederationSenderHandler] = None
         if hs.should_send_federation():
             self.send_handler = FederationSenderHandler(hs)
 
         # Map from stream to list of deferreds waiting for the stream to
         # arrive at a particular position. The lists are sorted by stream position.
-        self._streams_to_waiters = {}  # type: Dict[str, List[Tuple[int, Deferred]]]
+        self._streams_to_waiters: Dict[str, List[Tuple[int, Deferred]]] = {}
 
     async def on_rdata(
         self, stream_name: str, instance_name: str, token: int, rows: list
@@ -173,7 +173,7 @@ class ReplicationDataHandler:
             if entities:
                 self.notifier.on_new_event("to_device_key", token, users=entities)
         elif stream_name == DeviceListsStream.NAME:
-            all_room_ids = set()  # type: Set[str]
+            all_room_ids: Set[str] = set()
             for row in rows:
                 if row.entity.startswith("@"):
                     room_ids = await self.store.get_rooms_for_user(row.entity)
@@ -201,7 +201,7 @@ class ReplicationDataHandler:
                 if row.data.rejected:
                     continue
 
-                extra_users = ()  # type: Tuple[UserID, ...]
+                extra_users: Tuple[UserID, ...] = ()
                 if row.data.type == EventTypes.Member and row.data.state_key:
                     extra_users = (UserID.from_string(row.data.state_key),)
 
@@ -348,7 +348,7 @@ class FederationSenderHandler:
 
         # Stores the latest position in the federation stream we've gotten up
         # to. This is always set before we use it.
-        self.federation_position = None  # type: Optional[int]
+        self.federation_position: Optional[int] = None
 
         self._fed_position_linearizer = Linearizer(name="_fed_position_linearizer")
 
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index 505d450e19..1311b013da 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -34,7 +34,7 @@ class Command(metaclass=abc.ABCMeta):
     A full command line on the wire is constructed from `NAME + " " + to_line()`
     """
 
-    NAME = None  # type: str
+    NAME: str
 
     @classmethod
     @abc.abstractmethod
@@ -380,7 +380,7 @@ class RemoteServerUpCommand(_SimpleCommand):
     NAME = "REMOTE_SERVER_UP"
 
 
-_COMMANDS = (
+_COMMANDS: Tuple[Type[Command], ...] = (
     ServerCommand,
     RdataCommand,
     PositionCommand,
@@ -393,7 +393,7 @@ _COMMANDS = (
     UserIpCommand,
     RemoteServerUpCommand,
     ClearUserSyncsCommand,
-)  # type: Tuple[Type[Command], ...]
+)
 
 # Map of command name to command type.
 COMMAND_MAP = {cmd.NAME: cmd for cmd in _COMMANDS}
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 2ad7a200bb..eae4515363 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -105,12 +105,12 @@ class ReplicationCommandHandler:
             hs.get_instance_name() in hs.config.worker.writers.presence
         )
 
-        self._streams = {
+        self._streams: Dict[str, Stream] = {
             stream.NAME: stream(hs) for stream in STREAMS_MAP.values()
-        }  # type: Dict[str, Stream]
+        }
 
         # List of streams that this instance is the source of
-        self._streams_to_replicate = []  # type: List[Stream]
+        self._streams_to_replicate: List[Stream] = []
 
         for stream in self._streams.values():
             if hs.config.redis.redis_enabled and stream.NAME == CachesStream.NAME:
@@ -180,14 +180,14 @@ class ReplicationCommandHandler:
 
         # Map of stream name to batched updates. See RdataCommand for info on
         # how batching works.
-        self._pending_batches = {}  # type: Dict[str, List[Any]]
+        self._pending_batches: Dict[str, List[Any]] = {}
 
         # The factory used to create connections.
-        self._factory = None  # type: Optional[ReconnectingClientFactory]
+        self._factory: Optional[ReconnectingClientFactory] = None
 
         # The currently connected connections. (The list of places we need to send
         # outgoing replication commands to.)
-        self._connections = []  # type: List[IReplicationConnection]
+        self._connections: List[IReplicationConnection] = []
 
         LaterGauge(
             "synapse_replication_tcp_resource_total_connections",
@@ -200,7 +200,7 @@ class ReplicationCommandHandler:
         # them in order in a separate background process.
 
         # the streams which are currently being processed by _unsafe_process_queue
-        self._processing_streams = set()  # type: Set[str]
+        self._processing_streams: Set[str] = set()
 
         # for each stream, a queue of commands that are awaiting processing, and the
         # connection that they arrived on.
@@ -210,7 +210,7 @@ class ReplicationCommandHandler:
 
         # For each connection, the incoming stream names that have received a POSITION
         # from that connection.
-        self._streams_by_connection = {}  # type: Dict[IReplicationConnection, Set[str]]
+        self._streams_by_connection: Dict[IReplicationConnection, Set[str]] = {}
 
         LaterGauge(
             "synapse_replication_tcp_command_queue",
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index 6e3705364f..8c80153ab6 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -102,7 +102,7 @@ tcp_outbound_commands_counter = Counter(
 
 # A list of all connected protocols. This allows us to send metrics about the
 # connections.
-connected_connections = []  # type: List[BaseReplicationStreamProtocol]
+connected_connections: "List[BaseReplicationStreamProtocol]" = []
 
 
 logger = logging.getLogger(__name__)
@@ -146,15 +146,15 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
 
     # The transport is going to be an ITCPTransport, but that doesn't have the
     # (un)registerProducer methods, those are only on the implementation.
-    transport = None  # type: Connection
+    transport: Connection
 
     delimiter = b"\n"
 
     # Valid commands we expect to receive
-    VALID_INBOUND_COMMANDS = []  # type: Collection[str]
+    VALID_INBOUND_COMMANDS: Collection[str] = []
 
     # Valid commands we can send
-    VALID_OUTBOUND_COMMANDS = []  # type: Collection[str]
+    VALID_OUTBOUND_COMMANDS: Collection[str] = []
 
     max_line_buffer = 10000
 
@@ -165,7 +165,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
         self.last_received_command = self.clock.time_msec()
         self.last_sent_command = 0
         # When we requested the connection be closed
-        self.time_we_closed = None  # type: Optional[int]
+        self.time_we_closed: Optional[int] = None
 
         self.received_ping = False  # Have we received a ping from the other side
 
@@ -175,10 +175,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
         self.conn_id = random_string(5)  # To dedupe in case of name clashes.
 
         # List of pending commands to send once we've established the connection
-        self.pending_commands = []  # type: List[Command]
+        self.pending_commands: List[Command] = []
 
         # The LoopingCall for sending pings.
-        self._send_ping_loop = None  # type: Optional[task.LoopingCall]
+        self._send_ping_loop: Optional[task.LoopingCall] = None
 
         # a logcontext which we use for processing incoming commands. We declare it as a
         # background process so that the CPU stats get reported to prometheus.
diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index 6a2c2655e4..8c0df627c8 100644
--- a/synapse/replication/tcp/redis.py
+++ b/synapse/replication/tcp/redis.py
@@ -57,7 +57,7 @@ class ConstantProperty(Generic[T, V]):
     it.
     """
 
-    constant = attr.ib()  # type: V
+    constant: V = attr.ib()
 
     def __get__(self, obj: Optional[T], objtype: Optional[Type[T]] = None) -> V:
         return self.constant
@@ -91,9 +91,9 @@ class RedisSubscriber(txredisapi.SubscriberProtocol):
             commands.
     """
 
-    synapse_handler = None  # type: ReplicationCommandHandler
-    synapse_stream_name = None  # type: str
-    synapse_outbound_redis_connection = None  # type: txredisapi.RedisProtocol
+    synapse_handler: "ReplicationCommandHandler"
+    synapse_stream_name: str
+    synapse_outbound_redis_connection: txredisapi.RedisProtocol
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index b03824925a..3716c41bea 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -85,9 +85,9 @@ class Stream:
     time it was called.
     """
 
-    NAME = None  # type: str  # The name of the stream
+    NAME: str  # The name of the stream
     # The type of the row. Used by the default impl of parse_row.
-    ROW_TYPE = None  # type: Any
+    ROW_TYPE: Any = None
 
     @classmethod
     def parse_row(cls, row: StreamRow):
@@ -283,9 +283,7 @@ class PresenceStream(Stream):
 
             assert isinstance(presence_handler, PresenceHandler)
 
-            update_function = (
-                presence_handler.get_all_presence_updates
-            )  # type: UpdateFunction
+            update_function: UpdateFunction = presence_handler.get_all_presence_updates
         else:
             # Query presence writer process
             update_function = make_http_update_function(hs, self.NAME)
@@ -334,9 +332,9 @@ class TypingStream(Stream):
         if writer_instance == hs.get_instance_name():
             # On the writer, query the typing handler
             typing_writer_handler = hs.get_typing_writer_handler()
-            update_function = (
-                typing_writer_handler.get_all_typing_updates
-            )  # type: Callable[[str, int, int, int], Awaitable[Tuple[List[Tuple[int, Any]], int, bool]]]
+            update_function: Callable[
+                [str, int, int, int], Awaitable[Tuple[List[Tuple[int, Any]], int, bool]]
+            ] = typing_writer_handler.get_all_typing_updates
             current_token_function = typing_writer_handler.get_current_token
         else:
             # Query the typing writer process
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index e7e87bac92..a030e9299e 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -65,7 +65,7 @@ class BaseEventsStreamRow:
     """
 
     # Unique string that ids the type. Must be overridden in sub classes.
-    TypeId = None  # type: str
+    TypeId: str
 
     @classmethod
     def from_data(cls, data):
@@ -103,10 +103,10 @@ class EventsStreamCurrentStateRow(BaseEventsStreamRow):
     event_id = attr.ib()  # str, optional
 
 
-_EventRows = (
+_EventRows: Tuple[Type[BaseEventsStreamRow], ...] = (
     EventsStreamEventRow,
     EventsStreamCurrentStateRow,
-)  # type: Tuple[Type[BaseEventsStreamRow], ...]
+)
 
 TypeToRow = {Row.TypeId: Row for Row in _EventRows}
 
@@ -157,9 +157,9 @@ class EventsStream(Stream):
 
         # now we fetch up to that many rows from the events table
 
-        event_rows = await self._store.get_all_new_forward_event_rows(
+        event_rows: List[Tuple] = await self._store.get_all_new_forward_event_rows(
             instance_name, from_token, current_token, target_row_count
-        )  # type: List[Tuple]
+        )
 
         # we rely on get_all_new_forward_event_rows strictly honouring the limit, so
         # that we know it is safe to just take upper_limit = event_rows[-1][0].
@@ -172,7 +172,7 @@ class EventsStream(Stream):
 
         if len(event_rows) == target_row_count:
             limited = True
-            upper_limit = event_rows[-1][0]  # type: int
+            upper_limit: int = event_rows[-1][0]
         else:
             limited = False
             upper_limit = current_token
@@ -191,30 +191,30 @@ class EventsStream(Stream):
         # finally, fetch the ex-outliers rows. We assume there are few enough of these
         # not to bother with the limit.
 
-        ex_outliers_rows = await self._store.get_ex_outlier_stream_rows(
+        ex_outliers_rows: List[Tuple] = await self._store.get_ex_outlier_stream_rows(
             instance_name, from_token, upper_limit
-        )  # type: List[Tuple]
+        )
 
         # we now need to turn the raw database rows returned into tuples suitable
         # for the replication protocol (basically, we add an identifier to
         # distinguish the row type). At the same time, we can limit the event_rows
         # to the max stream_id from state_rows.
 
-        event_updates = (
+        event_updates: Iterable[Tuple[int, Tuple]] = (
             (stream_id, (EventsStreamEventRow.TypeId, rest))
             for (stream_id, *rest) in event_rows
             if stream_id <= upper_limit
-        )  # type: Iterable[Tuple[int, Tuple]]
+        )
 
-        state_updates = (
+        state_updates: Iterable[Tuple[int, Tuple]] = (
             (stream_id, (EventsStreamCurrentStateRow.TypeId, rest))
             for (stream_id, *rest) in state_rows
-        )  # type: Iterable[Tuple[int, Tuple]]
+        )
 
-        ex_outliers_updates = (
+        ex_outliers_updates: Iterable[Tuple[int, Tuple]] = (
             (stream_id, (EventsStreamEventRow.TypeId, rest))
             for (stream_id, *rest) in ex_outliers_rows
-        )  # type: Iterable[Tuple[int, Tuple]]
+        )
 
         # we need to return a sorted list, so merge them together.
         updates = list(heapq.merge(event_updates, state_updates, ex_outliers_updates))
diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py
index 096a85d363..c445af9bd9 100644
--- a/synapse/replication/tcp/streams/federation.py
+++ b/synapse/replication/tcp/streams/federation.py
@@ -51,9 +51,9 @@ class FederationStream(Stream):
             current_token = current_token_without_instance(
                 federation_sender.get_current_token
             )
-            update_function = (
-                federation_sender.get_replication_rows
-            )  # type: Callable[[str, int, int, int], Awaitable[Tuple[List[Tuple[int, Any]], int, bool]]]
+            update_function: Callable[
+                [str, int, int, int], Awaitable[Tuple[List[Tuple[int, Any]], int, bool]]
+            ] = federation_sender.get_replication_rows
 
         elif hs.should_send_federation():
             # federation sender: Query master process