diff --git a/synapse/config/redis.py b/synapse/config/redis.py
index e6a75be434..636cb450b8 100644
--- a/synapse/config/redis.py
+++ b/synapse/config/redis.py
@@ -35,3 +35,9 @@ class RedisConfig(Config):
self.redis_port = redis_config.get("port", 6379)
self.redis_dbid = redis_config.get("dbid", None)
self.redis_password = redis_config.get("password")
+
+ self.redis_use_tls = redis_config.get("use_tls", False)
+ self.redis_certificate = redis_config.get("certificate_file", None)
+ self.redis_private_key = redis_config.get("private_key_file", None)
+ self.redis_ca_file = redis_config.get("ca_file", None)
+ self.redis_ca_path = redis_config.get("ca_path", None)
diff --git a/synapse/replication/tcp/context.py b/synapse/replication/tcp/context.py
new file mode 100644
index 0000000000..4688b2200b
--- /dev/null
+++ b/synapse/replication/tcp/context.py
@@ -0,0 +1,34 @@
+# Copyright 2023 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from OpenSSL.SSL import Context
+from twisted.internet import ssl
+
+from synapse.config.redis import RedisConfig
+
+
+class ClientContextFactory(ssl.ClientContextFactory):
+ def __init__(self, redis_config: RedisConfig):
+ self.redis_config = redis_config
+
+ def getContext(self) -> Context:
+ ctx = super().getContext()
+ if self.redis_config.redis_certificate:
+ ctx.use_certificate_file(self.redis_config.redis_certificate)
+ if self.redis_config.redis_private_key:
+ ctx.use_privatekey_file(self.redis_config.redis_private_key)
+ if self.redis_config.redis_ca_file:
+ ctx.load_verify_locations(cafile=self.redis_config.redis_ca_file)
+ elif self.redis_config.redis_ca_path:
+ ctx.load_verify_locations(capath=self.redis_config.redis_ca_path)
+ return ctx
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 2290b3e6fe..233ad61d49 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -46,6 +46,7 @@ from synapse.replication.tcp.commands import (
UserIpCommand,
UserSyncCommand,
)
+from synapse.replication.tcp.context import ClientContextFactory
from synapse.replication.tcp.protocol import IReplicationConnection
from synapse.replication.tcp.streams import (
STREAMS_MAP,
@@ -348,13 +349,27 @@ class ReplicationCommandHandler:
outbound_redis_connection,
channel_names=self._channels_to_subscribe_to,
)
- hs.get_reactor().connectTCP(
- hs.config.redis.redis_host,
- hs.config.redis.redis_port,
- self._factory,
- timeout=30,
- bindAddress=None,
- )
+
+ reactor = hs.get_reactor()
+ redis_config = hs.config.redis
+ if hs.config.redis.redis_use_tls:
+ ssl_context_factory = ClientContextFactory(hs.config.redis)
+ reactor.connectSSL(
+ redis_config.redis_host,
+ redis_config.redis_port,
+ self._factory,
+ ssl_context_factory,
+ timeout=30,
+ bindAddress=None,
+ )
+ else:
+ reactor.connectTCP(
+ redis_config.redis_host,
+ redis_config.redis_port,
+ self._factory,
+ timeout=30,
+ bindAddress=None,
+ )
def get_streams(self) -> Dict[str, Stream]:
"""Get a map from stream name to all streams."""
diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index dfc061eb5e..c8f4bf8b27 100644
--- a/synapse/replication/tcp/redis.py
+++ b/synapse/replication/tcp/redis.py
@@ -35,6 +35,7 @@ from synapse.replication.tcp.commands import (
ReplicateCommand,
parse_command_from_line,
)
+from synapse.replication.tcp.context import ClientContextFactory
from synapse.replication.tcp.protocol import (
IReplicationConnection,
tcp_inbound_commands_counter,
@@ -386,12 +387,24 @@ def lazyConnection(
factory.continueTrying = reconnect
reactor = hs.get_reactor()
- reactor.connectTCP(
- host,
- port,
- factory,
- timeout=30,
- bindAddress=None,
- )
+
+ if hs.config.redis.redis_use_tls:
+ ssl_context_factory = ClientContextFactory(hs.config.redis)
+ reactor.connectSSL(
+ host,
+ port,
+ factory,
+ ssl_context_factory,
+ timeout=30,
+ bindAddress=None,
+ )
+ else:
+ reactor.connectTCP(
+ host,
+ port,
+ factory,
+ timeout=30,
+ bindAddress=None,
+ )
return factory.handler
|