summary refs log tree commit diff
path: root/synapse/replication
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication')
-rw-r--r--synapse/replication/http/_base.py5
-rw-r--r--synapse/replication/http/account_data.py6
-rw-r--r--synapse/replication/http/membership.py5
-rw-r--r--synapse/replication/http/register.py6
-rw-r--r--synapse/replication/tcp/commands.py3
-rw-r--r--synapse/replication/tcp/external_cache.py10
-rw-r--r--synapse/replication/tcp/handler.py30
-rw-r--r--synapse/replication/tcp/protocol.py27
-rw-r--r--synapse/replication/tcp/redis.py32
-rw-r--r--synapse/replication/tcp/resource.py6
-rw-r--r--synapse/replication/tcp/streams/_base.py26
-rw-r--r--synapse/replication/tcp/streams/events.py3
12 files changed, 86 insertions, 73 deletions
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index 288727a566..8a3f113e76 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -273,7 +273,10 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
         pattern = re.compile("^/_synapse/replication/%s/%s$" % (self.NAME, args))
 
         http_server.register_paths(
-            method, [pattern], self._check_auth_and_handle, self.__class__.__name__,
+            method,
+            [pattern],
+            self._check_auth_and_handle,
+            self.__class__.__name__,
         )
 
     def _check_auth_and_handle(self, request, **kwargs):
diff --git a/synapse/replication/http/account_data.py b/synapse/replication/http/account_data.py
index 52d32528ee..60899b6ad6 100644
--- a/synapse/replication/http/account_data.py
+++ b/synapse/replication/http/account_data.py
@@ -175,7 +175,11 @@ class ReplicationRemoveTagRestServlet(ReplicationEndpoint):
         return {}
 
     async def _handle_request(self, request, user_id, room_id, tag):
-        max_stream_id = await self.handler.remove_tag_from_room(user_id, room_id, tag,)
+        max_stream_id = await self.handler.remove_tag_from_room(
+            user_id,
+            room_id,
+            tag,
+        )
 
         return 200, {"max_stream_id": max_stream_id}
 
diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py
index 84e002f934..439881be67 100644
--- a/synapse/replication/http/membership.py
+++ b/synapse/replication/http/membership.py
@@ -160,7 +160,10 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
 
         # hopefully we're now on the master, so this won't recurse!
         event_id, stream_id = await self.member_handler.remote_reject_invite(
-            invite_event_id, txn_id, requester, event_content,
+            invite_event_id,
+            txn_id,
+            requester,
+            event_content,
         )
 
         return 200, {"event_id": event_id, "stream_id": stream_id}
diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py
index 7b12ec9060..d005f38767 100644
--- a/synapse/replication/http/register.py
+++ b/synapse/replication/http/register.py
@@ -22,8 +22,7 @@ logger = logging.getLogger(__name__)
 
 
 class ReplicationRegisterServlet(ReplicationEndpoint):
-    """Register a new user
-    """
+    """Register a new user"""
 
     NAME = "register_user"
     PATH_ARGS = ("user_id",)
@@ -97,8 +96,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
 
 
 class ReplicationPostRegisterActionsServlet(ReplicationEndpoint):
-    """Run any post registration actions
-    """
+    """Run any post registration actions"""
 
     NAME = "post_register"
     PATH_ARGS = ("user_id",)
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index ac532ed588..0a9da79c32 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -196,8 +196,7 @@ class ErrorCommand(_SimpleCommand):
 
 
 class PingCommand(_SimpleCommand):
-    """Sent by either side as a keep alive. The data is arbitrary (often timestamp)
-    """
+    """Sent by either side as a keep alive. The data is arbitrary (often timestamp)"""
 
     NAME = "PING"
 
diff --git a/synapse/replication/tcp/external_cache.py b/synapse/replication/tcp/external_cache.py
index 34fa3ff5b3..d89a36f25a 100644
--- a/synapse/replication/tcp/external_cache.py
+++ b/synapse/replication/tcp/external_cache.py
@@ -60,8 +60,7 @@ class ExternalCache:
         return self._redis_connection is not None
 
     async def set(self, cache_name: str, key: str, value: Any, expiry_ms: int) -> None:
-        """Add the key/value to the named cache, with the expiry time given.
-        """
+        """Add the key/value to the named cache, with the expiry time given."""
 
         if self._redis_connection is None:
             return
@@ -76,13 +75,14 @@ class ExternalCache:
 
         return await make_deferred_yieldable(
             self._redis_connection.set(
-                self._get_redis_key(cache_name, key), encoded_value, pexpire=expiry_ms,
+                self._get_redis_key(cache_name, key),
+                encoded_value,
+                pexpire=expiry_ms,
             )
         )
 
     async def get(self, cache_name: str, key: str) -> Optional[Any]:
-        """Look up a key/value in the named cache.
-        """
+        """Look up a key/value in the named cache."""
 
         if self._redis_connection is None:
             return None
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 8ea8dcd587..d1d00c3717 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -303,7 +303,9 @@ class ReplicationCommandHandler:
                 hs, outbound_redis_connection
             )
             hs.get_reactor().connectTCP(
-                hs.config.redis.redis_host, hs.config.redis.redis_port, self._factory,
+                hs.config.redis.redis_host,
+                hs.config.redis.redis_port,
+                self._factory,
             )
         else:
             client_name = hs.get_instance_name()
@@ -313,13 +315,11 @@ class ReplicationCommandHandler:
             hs.get_reactor().connectTCP(host, port, self._factory)
 
     def get_streams(self) -> Dict[str, Stream]:
-        """Get a map from stream name to all streams.
-        """
+        """Get a map from stream name to all streams."""
         return self._streams
 
     def get_streams_to_replicate(self) -> List[Stream]:
-        """Get a list of streams that this instances replicates.
-        """
+        """Get a list of streams that this instances replicates."""
         return self._streams_to_replicate
 
     def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand):
@@ -340,7 +340,10 @@ class ReplicationCommandHandler:
             current_token = stream.current_token(self._instance_name)
             self.send_command(
                 PositionCommand(
-                    stream.NAME, self._instance_name, current_token, current_token,
+                    stream.NAME,
+                    self._instance_name,
+                    current_token,
+                    current_token,
                 )
             )
 
@@ -592,8 +595,7 @@ class ReplicationCommandHandler:
         self.send_command(cmd, ignore_conn=conn)
 
     def new_connection(self, connection: AbstractConnection):
-        """Called when we have a new connection.
-        """
+        """Called when we have a new connection."""
         self._connections.append(connection)
 
         # If we are connected to replication as a client (rather than a server)
@@ -620,8 +622,7 @@ class ReplicationCommandHandler:
             )
 
     def lost_connection(self, connection: AbstractConnection):
-        """Called when a connection is closed/lost.
-        """
+        """Called when a connection is closed/lost."""
         # we no longer need _streams_by_connection for this connection.
         streams = self._streams_by_connection.pop(connection, None)
         if streams:
@@ -678,15 +679,13 @@ class ReplicationCommandHandler:
     def send_user_sync(
         self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int
     ):
-        """Poke the master that a user has started/stopped syncing.
-        """
+        """Poke the master that a user has started/stopped syncing."""
         self.send_command(
             UserSyncCommand(instance_id, user_id, is_syncing, last_sync_ms)
         )
 
     def send_remove_pusher(self, app_id: str, push_key: str, user_id: str):
-        """Poke the master to remove a pusher for a user
-        """
+        """Poke the master to remove a pusher for a user"""
         cmd = RemovePusherCommand(app_id, push_key, user_id)
         self.send_command(cmd)
 
@@ -699,8 +698,7 @@ class ReplicationCommandHandler:
         device_id: str,
         last_seen: int,
     ):
-        """Tell the master that the user made a request.
-        """
+        """Tell the master that the user made a request."""
         cmd = UserIpCommand(user_id, access_token, ip, user_agent, device_id, last_seen)
         self.send_command(cmd)
 
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index 804da994ea..e0b4ad314d 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -222,8 +222,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
                 self.send_error("ping timeout")
 
     def lineReceived(self, line: bytes):
-        """Called when we've received a line
-        """
+        """Called when we've received a line"""
         with PreserveLoggingContext(self._logging_context):
             self._parse_and_dispatch_line(line)
 
@@ -299,8 +298,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
         self.on_connection_closed()
 
     def send_error(self, error_string, *args):
-        """Send an error to remote and close the connection.
-        """
+        """Send an error to remote and close the connection."""
         self.send_command(ErrorCommand(error_string % args))
         self.close()
 
@@ -341,8 +339,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
         self.last_sent_command = self.clock.time_msec()
 
     def _queue_command(self, cmd):
-        """Queue the command until the connection is ready to write to again.
-        """
+        """Queue the command until the connection is ready to write to again."""
         logger.debug("[%s] Queueing as conn %r, cmd: %r", self.id(), self.state, cmd)
         self.pending_commands.append(cmd)
 
@@ -355,8 +352,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
             self.close()
 
     def _send_pending_commands(self):
-        """Send any queued commandes
-        """
+        """Send any queued commandes"""
         pending = self.pending_commands
         self.pending_commands = []
         for cmd in pending:
@@ -380,8 +376,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
         self.state = ConnectionStates.PAUSED
 
     def resumeProducing(self):
-        """The remote has caught up after we started buffering!
-        """
+        """The remote has caught up after we started buffering!"""
         logger.info("[%s] Resume producing", self.id())
         self.state = ConnectionStates.ESTABLISHED
         self._send_pending_commands()
@@ -440,8 +435,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
         return "%s-%s" % (self.name, self.conn_id)
 
     def lineLengthExceeded(self, line):
-        """Called when we receive a line that is above the maximum line length
-        """
+        """Called when we receive a line that is above the maximum line length"""
         self.send_error("Line length exceeded")
 
 
@@ -495,21 +489,18 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
             self.send_error("Wrong remote")
 
     def replicate(self):
-        """Send the subscription request to the server
-        """
+        """Send the subscription request to the server"""
         logger.info("[%s] Subscribing to replication streams", self.id())
 
         self.send_command(ReplicateCommand())
 
 
 class AbstractConnection(abc.ABC):
-    """An interface for replication connections.
-    """
+    """An interface for replication connections."""
 
     @abc.abstractmethod
     def send_command(self, cmd: Command):
-        """Send the command down the connection
-        """
+        """Send the command down the connection"""
         pass
 
 
diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index fdd087683b..0e6155cf53 100644
--- a/synapse/replication/tcp/redis.py
+++ b/synapse/replication/tcp/redis.py
@@ -15,8 +15,9 @@
 
 import logging
 from inspect import isawaitable
-from typing import TYPE_CHECKING, Optional, Type, cast
+from typing import TYPE_CHECKING, Generic, Optional, Type, TypeVar, cast
 
+import attr
 import txredisapi
 
 from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
@@ -42,6 +43,24 @@ if TYPE_CHECKING:
 
 logger = logging.getLogger(__name__)
 
+T = TypeVar("T")
+V = TypeVar("V")
+
+
+@attr.s
+class ConstantProperty(Generic[T, V]):
+    """A descriptor that returns the given constant, ignoring attempts to set
+    it.
+    """
+
+    constant = attr.ib()  # type: V
+
+    def __get__(self, obj: Optional[T], objtype: Type[T] = None) -> V:
+        return self.constant
+
+    def __set__(self, obj: Optional[T], value: V):
+        pass
+
 
 class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
     """Connection to redis subscribed to replication stream.
@@ -104,8 +123,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
         self.synapse_handler.send_positions_to_connection(self)
 
     def messageReceived(self, pattern: str, channel: str, message: str):
-        """Received a message from redis.
-        """
+        """Received a message from redis."""
         with PreserveLoggingContext(self._logging_context):
             self._parse_and_dispatch_message(message)
 
@@ -118,7 +136,8 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
             cmd = parse_command_from_line(message)
         except Exception:
             logger.exception(
-                "Failed to parse replication line: %r", message,
+                "Failed to parse replication line: %r",
+                message,
             )
             return
 
@@ -195,6 +214,10 @@ class SynapseRedisFactory(txredisapi.RedisFactory):
     we detect dead connections.
     """
 
+    # We want to *always* retry connecting, txredisapi will stop if there is a
+    # failure during certain operations, e.g. during AUTH.
+    continueTrying = cast(bool, ConstantProperty(True))
+
     def __init__(
         self,
         hs: "HomeServer",
@@ -243,7 +266,6 @@ class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory):
     """
 
     maxDelay = 5
-    continueTrying = True
     protocol = RedisSubscriber
 
     def __init__(
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index 1d4ceac0f1..2018f9f29e 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -36,8 +36,7 @@ logger = logging.getLogger(__name__)
 
 
 class ReplicationStreamProtocolFactory(Factory):
-    """Factory for new replication connections.
-    """
+    """Factory for new replication connections."""
 
     def __init__(self, hs):
         self.command_handler = hs.get_tcp_replication()
@@ -181,7 +180,8 @@ class ReplicationStreamer:
                             raise
 
                         logger.debug(
-                            "Sending %d updates", len(updates),
+                            "Sending %d updates",
+                            len(updates),
                         )
 
                         if updates:
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 61b282ab2d..38809b5b7c 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -183,7 +183,10 @@ class Stream:
             return [], upto_token, False
 
         updates, upto_token, limited = await self.update_function(
-            instance_name, from_token, upto_token, _STREAM_UPDATE_TARGET_ROW_COUNT,
+            instance_name,
+            from_token,
+            upto_token,
+            _STREAM_UPDATE_TARGET_ROW_COUNT,
         )
         return updates, upto_token, limited
 
@@ -339,8 +342,7 @@ class ReceiptsStream(Stream):
 
 
 class PushRulesStream(Stream):
-    """A user has changed their push rules
-    """
+    """A user has changed their push rules"""
 
     PushRulesStreamRow = namedtuple("PushRulesStreamRow", ("user_id",))  # str
 
@@ -362,8 +364,7 @@ class PushRulesStream(Stream):
 
 
 class PushersStream(Stream):
-    """A user has added/changed/removed a pusher
-    """
+    """A user has added/changed/removed a pusher"""
 
     PushersStreamRow = namedtuple(
         "PushersStreamRow",
@@ -416,8 +417,7 @@ class CachesStream(Stream):
 
 
 class PublicRoomsStream(Stream):
-    """The public rooms list changed
-    """
+    """The public rooms list changed"""
 
     PublicRoomsStreamRow = namedtuple(
         "PublicRoomsStreamRow",
@@ -463,8 +463,7 @@ class DeviceListsStream(Stream):
 
 
 class ToDeviceStream(Stream):
-    """New to_device messages for a client
-    """
+    """New to_device messages for a client"""
 
     ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", ("entity",))  # str
 
@@ -481,8 +480,7 @@ class ToDeviceStream(Stream):
 
 
 class TagAccountDataStream(Stream):
-    """Someone added/removed a tag for a room
-    """
+    """Someone added/removed a tag for a room"""
 
     TagAccountDataStreamRow = namedtuple(
         "TagAccountDataStreamRow", ("user_id", "room_id", "data")  # str  # str  # dict
@@ -501,8 +499,7 @@ class TagAccountDataStream(Stream):
 
 
 class AccountDataStream(Stream):
-    """Global or per room account data was changed
-    """
+    """Global or per room account data was changed"""
 
     AccountDataStreamRow = namedtuple(
         "AccountDataStream",
@@ -589,8 +586,7 @@ class GroupServerStream(Stream):
 
 
 class UserSignatureStream(Stream):
-    """A user has signed their own device with their user-signing key
-    """
+    """A user has signed their own device with their user-signing key"""
 
     UserSignatureStreamRow = namedtuple("UserSignatureStreamRow", ("user_id"))  # str
 
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index 86a62b71eb..fa5e37ba7b 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -113,8 +113,7 @@ TypeToRow = {Row.TypeId: Row for Row in _EventRows}
 
 
 class EventsStream(Stream):
-    """We received a new event, or an event went from being an outlier to not
-    """
+    """We received a new event, or an event went from being an outlier to not"""
 
     NAME = "events"