diff --git a/synapse/replication/tcp/external_cache.py b/synapse/replication/tcp/external_cache.py
index b402f82810..aaf91e5e02 100644
--- a/synapse/replication/tcp/external_cache.py
+++ b/synapse/replication/tcp/external_cache.py
@@ -21,6 +21,8 @@ from synapse.logging.context import make_deferred_yieldable
from synapse.util import json_decoder, json_encoder
if TYPE_CHECKING:
+ from txredisapi import RedisProtocol
+
from synapse.server import HomeServer
set_counter = Counter(
@@ -59,7 +61,12 @@ class ExternalCache:
"""
def __init__(self, hs: "HomeServer"):
- self._redis_connection = hs.get_outbound_redis_connection()
+ if hs.config.redis.redis_enabled:
+ self._redis_connection: Optional[
+ "RedisProtocol"
+ ] = hs.get_outbound_redis_connection()
+ else:
+ self._redis_connection = None
def _get_redis_key(self, cache_name: str, key: str) -> str:
return "cache_v1:%s:%s" % (cache_name, key)
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 6aa9318027..06fd06fdf3 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -294,7 +294,7 @@ class ReplicationCommandHandler:
# This shouldn't be possible
raise Exception("Unrecognised command %s in stream queue", cmd.NAME)
- def start_replication(self, hs):
+ def start_replication(self, hs: "HomeServer"):
"""Helper method to start a replication connection to the remote server
using TCP.
"""
@@ -321,6 +321,8 @@ class ReplicationCommandHandler:
hs.config.redis.redis_host, # type: ignore[arg-type]
hs.config.redis.redis_port,
self._factory,
+ timeout=30,
+ bindAddress=None,
)
else:
client_name = hs.get_instance_name()
@@ -331,6 +333,8 @@ class ReplicationCommandHandler:
host, # type: ignore[arg-type]
port,
self._factory,
+ timeout=30,
+ bindAddress=None,
)
def get_streams(self) -> Dict[str, Stream]:
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index 8c80153ab6..7bae36db16 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -182,9 +182,13 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
# a logcontext which we use for processing incoming commands. We declare it as a
# background process so that the CPU stats get reported to prometheus.
- self._logging_context = BackgroundProcessLoggingContext(
- "replication-conn", self.conn_id
- )
+ with PreserveLoggingContext():
+ # thanks to `PreserveLoggingContext()`, the new logcontext is guaranteed to
+ # capture the sentinel context as its containing context and won't prevent
+ # GC of / unintentionally reactivate what would be the current context.
+ self._logging_context = BackgroundProcessLoggingContext(
+ "replication-conn", self.conn_id
+ )
def connectionMade(self):
logger.info("[%s] Connection established", self.id())
@@ -434,8 +438,12 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
if self.transport:
self.transport.unregisterProducer()
- # mark the logging context as finished
- self._logging_context.__exit__(None, None, None)
+ # mark the logging context as finished by triggering `__exit__()`
+ with PreserveLoggingContext():
+ with self._logging_context:
+ pass
+ # the sentinel context is now active, which may not be correct.
+ # PreserveLoggingContext() will restore the correct logging context.
def __str__(self):
addr = None
diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index 062fe2f33e..8d28bd3f3f 100644
--- a/synapse/replication/tcp/redis.py
+++ b/synapse/replication/tcp/redis.py
@@ -100,9 +100,13 @@ class RedisSubscriber(txredisapi.SubscriberProtocol):
# a logcontext which we use for processing incoming commands. We declare it as a
# background process so that the CPU stats get reported to prometheus.
- self._logging_context = BackgroundProcessLoggingContext(
- "replication_command_handler"
- )
+ with PreserveLoggingContext():
+ # thanks to `PreserveLoggingContext()`, the new logcontext is guaranteed to
+ # capture the sentinel context as its containing context and won't prevent
+ # GC of / unintentionally reactivate what would be the current context.
+ self._logging_context = BackgroundProcessLoggingContext(
+ "replication_command_handler"
+ )
def connectionMade(self):
logger.info("Connected to redis")
@@ -182,8 +186,12 @@ class RedisSubscriber(txredisapi.SubscriberProtocol):
super().connectionLost(reason)
self.synapse_handler.lost_connection(self)
- # mark the logging context as finished
- self._logging_context.__exit__(None, None, None)
+ # mark the logging context as finished by triggering `__exit__()`
+ with PreserveLoggingContext():
+ with self._logging_context:
+ pass
+ # the sentinel context is now active, which may not be correct.
+ # PreserveLoggingContext() will restore the correct logging context.
def send_command(self, cmd: Command):
"""Send a command if connection has been established.
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index 80f9b23bfd..55326877fd 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -16,6 +16,7 @@
import logging
import random
+from typing import TYPE_CHECKING
from prometheus_client import Counter
@@ -27,6 +28,9 @@ from synapse.replication.tcp.protocol import ServerReplicationStreamProtocol
from synapse.replication.tcp.streams import EventsStream
from synapse.util.metrics import Measure
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
stream_updates_counter = Counter(
"synapse_replication_tcp_resource_stream_updates", "", ["stream_name"]
)
@@ -37,7 +41,7 @@ logger = logging.getLogger(__name__)
class ReplicationStreamProtocolFactory(Factory):
"""Factory for new replication connections."""
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.command_handler = hs.get_tcp_replication()
self.clock = hs.get_clock()
self.server_name = hs.config.server.server_name
@@ -65,7 +69,7 @@ class ReplicationStreamer:
data is available it will propagate to all connected clients.
"""
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.clock = hs.get_clock()
self.notifier = hs.get_notifier()
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 9b905aba9d..c8b188ae4e 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -241,7 +241,7 @@ class BackfillStream(Stream):
NAME = "backfill"
ROW_TYPE = BackfillStreamRow
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
@@ -363,7 +363,7 @@ class ReceiptsStream(Stream):
NAME = "receipts"
ROW_TYPE = ReceiptsStreamRow
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
@@ -380,7 +380,7 @@ class PushRulesStream(Stream):
NAME = "push_rules"
ROW_TYPE = PushRulesStreamRow
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
super().__init__(
@@ -405,7 +405,7 @@ class PushersStream(Stream):
NAME = "pushers"
ROW_TYPE = PushersStreamRow
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
store = hs.get_datastore()
super().__init__(
@@ -438,7 +438,7 @@ class CachesStream(Stream):
NAME = "caches"
ROW_TYPE = CachesStreamRow
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
@@ -459,7 +459,7 @@ class DeviceListsStream(Stream):
NAME = "device_lists"
ROW_TYPE = DeviceListsStreamRow
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
@@ -476,7 +476,7 @@ class ToDeviceStream(Stream):
NAME = "to_device"
ROW_TYPE = ToDeviceStreamRow
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
@@ -495,7 +495,7 @@ class TagAccountDataStream(Stream):
NAME = "tag_account_data"
ROW_TYPE = TagAccountDataStreamRow
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
@@ -582,7 +582,7 @@ class GroupServerStream(Stream):
NAME = "groups"
ROW_TYPE = GroupsStreamRow
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
@@ -599,7 +599,7 @@ class UserSignatureStream(Stream):
NAME = "user_signature"
ROW_TYPE = UserSignatureStreamRow
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
|