diff --git a/changelog.d/8433.misc b/changelog.d/8433.misc
new file mode 100644
index 0000000000..05f8b5bbf4
--- /dev/null
+++ b/changelog.d/8433.misc
@@ -0,0 +1 @@
+Add unit test for event persister sharding.
diff --git a/mypy.ini b/mypy.ini
index c283f15b21..e84ad04e41 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -143,3 +143,6 @@ ignore_missing_imports = True
[mypy-nacl.*]
ignore_missing_imports = True
+
+[mypy-hiredis]
+ignore_missing_imports = True
diff --git a/stubs/txredisapi.pyi b/stubs/txredisapi.pyi
index c66413f003..522244bb57 100644
--- a/stubs/txredisapi.pyi
+++ b/stubs/txredisapi.pyi
@@ -16,7 +16,7 @@
"""Contains *incomplete* type hints for txredisapi.
"""
-from typing import List, Optional, Union
+from typing import List, Optional, Type, Union
class RedisProtocol:
def publish(self, channel: str, message: bytes): ...
@@ -42,3 +42,21 @@ def lazyConnection(
class SubscriberFactory:
def buildProtocol(self, addr): ...
+
+class ConnectionHandler: ...
+
+class RedisFactory:
+ continueTrying: bool
+ handler: RedisProtocol
+ def __init__(
+ self,
+ uuid: str,
+ dbid: Optional[int],
+ poolsize: int,
+ isLazy: bool = False,
+ handler: Type = ConnectionHandler,
+ charset: str = "utf-8",
+ password: Optional[str] = None,
+ replyTimeout: Optional[int] = None,
+ convertNumbers: Optional[bool] = True,
+ ): ...
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index b323841f73..e92da7b263 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -251,10 +251,9 @@ class ReplicationCommandHandler:
using TCP.
"""
if hs.config.redis.redis_enabled:
- import txredisapi
-
from synapse.replication.tcp.redis import (
RedisDirectTcpReplicationClientFactory,
+ lazyConnection,
)
logger.info(
@@ -271,7 +270,8 @@ class ReplicationCommandHandler:
# connection after SUBSCRIBE is called).
# First create the connection for sending commands.
- outbound_redis_connection = txredisapi.lazyConnection(
+ outbound_redis_connection = lazyConnection(
+ reactor=hs.get_reactor(),
host=hs.config.redis_host,
port=hs.config.redis_port,
password=hs.config.redis.redis_password,
diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index f225e533de..de19705c1f 100644
--- a/synapse/replication/tcp/redis.py
+++ b/synapse/replication/tcp/redis.py
@@ -15,7 +15,7 @@
import logging
from inspect import isawaitable
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Optional
import txredisapi
@@ -228,3 +228,41 @@ class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory):
p.password = self.password
return p
+
+
+def lazyConnection(
+ reactor,
+ host: str = "localhost",
+ port: int = 6379,
+ dbid: Optional[int] = None,
+ reconnect: bool = True,
+ charset: str = "utf-8",
+ password: Optional[str] = None,
+ connectTimeout: Optional[int] = None,
+ replyTimeout: Optional[int] = None,
+ convertNumbers: bool = True,
+) -> txredisapi.RedisProtocol:
+ """Equivalent to `txredisapi.lazyConnection`, except allows specifying a
+ reactor.
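+
+ A purely illustrative example call, assuming the test reactor is in scope
+ as `reactor`:
+
+ connection = lazyConnection(reactor, host="localhost", port=6379)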
+ """
+
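+ # These match what `txredisapi.lazyConnection` itself uses: a single,
+ # lazily-initialised connection.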
+ isLazy = True
+ poolsize = 1
+
+ uuid = "%s:%d" % (host, port)
+ factory = txredisapi.RedisFactory(
+ uuid,
+ dbid,
+ poolsize,
+ isLazy,
+ txredisapi.ConnectionHandler,
+ charset,
+ password,
+ replyTimeout,
+ convertNumbers,
+ )
+ factory.continueTrying = reconnect
+ for _ in range(poolsize):
+ reactor.connectTCP(host, port, factory, connectTimeout)
+
+ return factory.handler
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index ae60874ec3..81ea985b9f 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -12,13 +12,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
from typing import Any, Callable, List, Optional, Tuple
import attr
+import hiredis
from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
+from twisted.internet.protocol import Protocol
from twisted.internet.task import LoopingCall
from twisted.web.http import HTTPChannel
@@ -27,7 +28,7 @@ from synapse.app.generic_worker import (
GenericWorkerServer,
)
from synapse.http.server import JsonResource
-from synapse.http.site import SynapseRequest
+from synapse.http.site import SynapseRequest, SynapseSite
from synapse.replication.http import ReplicationRestResource, streams
from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
@@ -197,19 +198,37 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
self.server_factory = ReplicationStreamProtocolFactory(self.hs)
self.streamer = self.hs.get_replication_streamer()
+ # Fake in-memory Redis server that the homeservers can connect to.
+ self._redis_server = FakeRedisPubSubServer()
+
store = self.hs.get_datastore()
self.database_pool = store.db_pool
self.reactor.lookups["testserv"] = "1.2.3.4"
+ self.reactor.lookups["localhost"] = "127.0.0.1"
+
+ # A map from a HS instance to the associated HTTP Site to use for
+ # handling inbound HTTP requests to that instance.
+ self._hs_to_site = {self.hs: self.site}
+
+ if self.hs.config.redis.redis_enabled:
+ # Handle attempts to connect to fake redis server.
+ self.reactor.add_tcp_client_callback(
+ "localhost", 6379, self.connect_any_redis_attempts,
+ )
- self._worker_hs_to_resource = {}
+ self.hs.get_tcp_replication().start_replication(self.hs)
# When we see a connection attempt to the master replication listener we
# automatically set up the connection. This is so that tests don't
# manually have to go and explicitly set it up each time (plus sometimes
# it is impossible to write the handling explicitly in the tests).
+ #
+ # Register the master replication listener:
self.reactor.add_tcp_client_callback(
- "1.2.3.4", 8765, self._handle_http_replication_attempt
+ "1.2.3.4",
+ 8765,
+ lambda: self._handle_http_replication_attempt(self.hs, 8765),
)
def create_test_json_resource(self):
@@ -253,28 +272,63 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
**kwargs
)
+ # If the instance is in the `instance_map` config then workers may try
+ # to send HTTP requests to it, so we register it with
+ # `_handle_http_replication_attempt` like we do with the master HS.
+ instance_name = worker_hs.get_instance_name()
+ instance_loc = worker_hs.config.worker.instance_map.get(instance_name)
+ if instance_loc:
+ # Ensure the host is one that has a fake DNS entry.
+ if instance_loc.host not in self.reactor.lookups:
+ raise Exception(
+ "Host does not have an IP for instance_map[%r].host = %r"
+ % (instance_name, instance_loc.host,)
+ )
+
+ self.reactor.add_tcp_client_callback(
+ self.reactor.lookups[instance_loc.host],
+ instance_loc.port,
+ lambda: self._handle_http_replication_attempt(
+ worker_hs, instance_loc.port
+ ),
+ )
+
store = worker_hs.get_datastore()
store.db_pool._db_pool = self.database_pool._db_pool
- repl_handler = ReplicationCommandHandler(worker_hs)
- client = ClientReplicationStreamProtocol(
- worker_hs, "client", "test", self.clock, repl_handler,
- )
- server = self.server_factory.buildProtocol(None)
+ # Set up TCP replication between master and the new worker if we don't
+ # have Redis support enabled.
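+ # (If Redis is enabled the worker instead connects to the fake Redis
+ # server via `start_replication`, below.)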
+ if not worker_hs.config.redis_enabled:
+ repl_handler = ReplicationCommandHandler(worker_hs)
+ client = ClientReplicationStreamProtocol(
+ worker_hs, "client", "test", self.clock, repl_handler,
+ )
+ server = self.server_factory.buildProtocol(None)
- client_transport = FakeTransport(server, self.reactor)
- client.makeConnection(client_transport)
+ client_transport = FakeTransport(server, self.reactor)
+ client.makeConnection(client_transport)
- server_transport = FakeTransport(client, self.reactor)
- server.makeConnection(server_transport)
+ server_transport = FakeTransport(client, self.reactor)
+ server.makeConnection(server_transport)
# Set up a resource for the worker
- resource = ReplicationRestResource(self.hs)
+ resource = ReplicationRestResource(worker_hs)
for servlet in self.servlets:
servlet(worker_hs, resource)
- self._worker_hs_to_resource[worker_hs] = resource
+ self._hs_to_site[worker_hs] = SynapseSite(
+ logger_name="synapse.access.http.fake",
+ site_tag="{}-{}".format(
+ worker_hs.config.server.server_name, worker_hs.get_instance_name()
+ ),
+ config=worker_hs.config.server.listeners[0],
+ resource=resource,
+ server_version_string="1",
+ )
+
+ if worker_hs.config.redis.redis_enabled:
+ worker_hs.get_tcp_replication().start_replication(worker_hs)
return worker_hs
@@ -285,7 +339,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
return config
def render_on_worker(self, worker_hs: HomeServer, request: SynapseRequest):
- render(request, self._worker_hs_to_resource[worker_hs], self.reactor)
+ render(request, self._hs_to_site[worker_hs].resource, self.reactor)
def replicate(self):
"""Tell the master side of replication that something has happened, and then
@@ -294,9 +348,9 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
self.streamer.on_notifier_poke()
self.pump()
- def _handle_http_replication_attempt(self):
- """Handles a connection attempt to the master replication HTTP
- listener.
+ def _handle_http_replication_attempt(self, hs, repl_port):
+ """Handles a connection attempt to the given HS replication HTTP
+ listener on the given port.
"""
# We should have at least one outbound connection attempt, where the
@@ -305,7 +359,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
self.assertGreaterEqual(len(clients), 1)
(host, port, client_factory, _timeout, _bindAddress) = clients.pop()
self.assertEqual(host, "1.2.3.4")
- self.assertEqual(port, 8765)
+ self.assertEqual(port, repl_port)
# Set up client side protocol
client_protocol = client_factory.buildProtocol(None)
@@ -315,7 +369,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
# Set up the server side protocol
channel = _PushHTTPChannel(self.reactor)
channel.requestFactory = request_factory
- channel.site = self.site
+ channel.site = self._hs_to_site[hs]
# Connect client to server and vice versa.
client_to_server_transport = FakeTransport(
@@ -333,6 +387,32 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
# inside `connecTCP` before the connection has been passed back to the
# code that requested the TCP connection.
+ def connect_any_redis_attempts(self):
+ """If redis is enabled we need to deal with workers connecting to a
+ redis server. We don't want to use a real Redis server so we use a
+ fake one.
+ """
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
+ self.assertEqual(host, "localhost")
+ self.assertEqual(port, 6379)
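+ # Hook the client (the HS's outbound Redis connection) up to the fake
+ # server with a pair of FakeTransports, mirroring what we do for the
+ # HTTP replication connections above.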
+
+ client_protocol = client_factory.buildProtocol(None)
+ server_protocol = self._redis_server.buildProtocol(None)
+
+ client_to_server_transport = FakeTransport(
+ server_protocol, self.reactor, client_protocol
+ )
+ client_protocol.makeConnection(client_to_server_transport)
+
+ server_to_client_transport = FakeTransport(
+ client_protocol, self.reactor, server_protocol
+ )
+ server_protocol.makeConnection(server_to_client_transport)
+
+ return client_to_server_transport, server_to_client_transport
+
class TestReplicationDataHandler(GenericWorkerReplicationHandler):
"""Drop-in for ReplicationDataHandler which just collects RDATA rows"""
@@ -467,3 +547,105 @@ class _PullToPushProducer:
pass
self.stopProducing()
+
+
+class FakeRedisPubSubServer:
+ """A fake Redis server for pub/sub.
+ """
+
+ def __init__(self):
+ self._subscribers = set()
+
+ def add_subscriber(self, conn):
+ """A connection has called SUBSCRIBE
+ """
+ self._subscribers.add(conn)
+
+ def remove_subscriber(self, conn):
+ """A connection has called UNSUBSCRIBE
+ """
+ self._subscribers.discard(conn)
+
+ def publish(self, conn, channel, msg) -> int:
+ """A connection want to publish a message to subscribers.
+ """
+ for sub in self._subscribers:
+ sub.send(["message", channel, msg])
+
+ return len(self._subscribers)
+
+ def buildProtocol(self, addr):
+ return FakeRedisPubSubProtocol(self)
+
+
+class FakeRedisPubSubProtocol(Protocol):
+ """A connection from a client talking to the fake Redis server.
+ """
+
+ def __init__(self, server: FakeRedisPubSubServer):
+ self._server = server
+ self._reader = hiredis.Reader()
+
+ def dataReceived(self, data):
+ self._reader.feed(data)
+
+ # We might get multiple messages in one packet.
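+ # hiredis buffers the bytes we feed it; each call to gets() returns the
+ # next fully parsed command, or False once the buffer is exhausted.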
+ while True:
+ msg = self._reader.gets()
+
+ if msg is False:
+ # No more messages.
+ return
+
+ if not isinstance(msg, list):
+ # Inbound commands should always be a list
+ raise Exception("Expected redis list")
+
+ self.handle_command(msg[0], *msg[1:])
+
+ def handle_command(self, command, *args):
+ """Received a Redis command from the client.
+ """
+
+ # We currently only support pub/sub.
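+ # hiredis parses each inbound command into a list of byte strings, e.g. a
+ # publish arrives as [b"PUBLISH", b"<channel>", b"<payload>"].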
+ if command == b"PUBLISH":
+ channel, message = args
+ num_subscribers = self._server.publish(self, channel, message)
+ self.send(num_subscribers)
+ elif command == b"SUBSCRIBE":
+ (channel,) = args
+ self._server.add_subscriber(self)
+ self.send(["subscribe", channel, 1])
+ else:
+ raise Exception("Unknown command")
+
+ def send(self, msg):
+ """Send a message back to the client.
+ """
+ raw = self.encode(msg).encode("utf-8")
+
+ self.transport.write(raw)
+ self.transport.flush()
+
+ def encode(self, obj):
+ """Encode an object to its Redis format.
+
+ Supports: strings/bytes, integers and lists/tuples.
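+
+ For example (illustrative only), encode(["subscribe", "x", 1]) produces
+ "*3\r\n$9\r\nsubscribe\r\n$1\r\nx\r\n:1\r\n", i.e. the RESP wire format.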
+ """
+
+ if isinstance(obj, bytes):
+ # We assume bytes are just unicode strings.
+ obj = obj.decode("utf-8")
+
+ if isinstance(obj, str):
+ return "${len}\r\n{str}\r\n".format(len=len(obj), str=obj)
+ if isinstance(obj, int):
+ return ":{val}\r\n".format(val=obj)
+ if isinstance(obj, (list, tuple)):
+ items = "".join(self.encode(a) for a in obj)
+ return "*{len}\r\n{items}".format(len=len(obj), items=items)
+
+ raise Exception("Unrecognized type for encoding redis: %r: %r", type(obj), obj)
+
+ def connectionLost(self, reason):
+ self._server.remove_subscriber(self)
diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py
new file mode 100644
index 0000000000..6068d14905
--- /dev/null
+++ b/tests/replication/test_sharded_event_persister.py
@@ -0,0 +1,102 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+
+from tests.replication._base import BaseMultiWorkerStreamTestCase
+from tests.utils import USE_POSTGRES_FOR_TESTS
+
+logger = logging.getLogger(__name__)
+
+
+class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
+ """Checks event persisting sharding works
+ """
+
+ # Event persister sharding requires Postgres (due to needing
+ # `MultiWriterIdGenerator`).
+ if not USE_POSTGRES_FOR_TESTS:
+ skip = "Requires Postgres"
+
+ servlets = [
+ admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ # Register a second user who will join rooms and send the messages we
+ # persist in the tests below.
+ self.other_user_id = self.register_user("otheruser", "pass")
+ self.other_access_token = self.login("otheruser", "pass")
+
+ def default_config(self):
+ conf = super().default_config()
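+ # Enable Redis-based replication and shard the `events` stream across two
+ # writer workers; each room's events should then be persisted by whichever
+ # of worker1/worker2 that room is assigned to.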
+ conf["redis"] = {"enabled": "true"}
+ conf["stream_writers"] = {"events": ["worker1", "worker2"]}
+ conf["instance_map"] = {
+ "worker1": {"host": "testserv", "port": 1001},
+ "worker2": {"host": "testserv", "port": 1002},
+ }
+ return conf
+
+ def test_basic(self):
+ """Simple test to ensure that multiple rooms can be created and joined,
+ and that different rooms get handled by different instances.
+ """
+
+ self.make_worker_hs(
+ "synapse.app.generic_worker", {"worker_name": "worker1"},
+ )
+
+ self.make_worker_hs(
+ "synapse.app.generic_worker", {"worker_name": "worker2"},
+ )
+
+ persisted_on_1 = False
+ persisted_on_2 = False
+
+ store = self.hs.get_datastore()
+
+ user_id = self.register_user("user", "pass")
+ access_token = self.login("user", "pass")
+
+ # Keep making new rooms until we see rooms being persisted on both
+ # workers.
+ for _ in range(10):
+ # Create a room
+ room = self.helper.create_room_as(user_id, tok=access_token)
+
+ # The other user joins
+ self.helper.join(
+ room=room, user=self.other_user_id, tok=self.other_access_token
+ )
+
+ # The other user sends some messages
+ response = self.helper.send(room, body="Hi!", tok=self.other_access_token)
+ event_id = response["event_id"]
+
+ # The event position includes which instance persisted the event.
+ pos = self.get_success(store.get_position_for_event(event_id))
+
+ persisted_on_1 |= pos.instance_name == "worker1"
+ persisted_on_2 |= pos.instance_name == "worker2"
+
+ if persisted_on_1 and persisted_on_2:
+ break
+
+ self.assertTrue(persisted_on_1)
+ self.assertTrue(persisted_on_2)
diff --git a/tests/unittest.py b/tests/unittest.py
index e654c0442d..82ede9de34 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -241,7 +241,7 @@ class HomeserverTestCase(TestCase):
# create a site to wrap the resource.
self.site = SynapseSite(
logger_name="synapse.access.http.fake",
- site_tag="test",
+ site_tag=self.hs.config.server.server_name,
config=self.hs.config.server.listeners[0],
resource=self.resource,
server_version_string="1",
|