diff options
Diffstat (limited to 'tests/replication/_base.py')
-rw-r--r-- | tests/replication/_base.py | 224 |
1 files changed, 203 insertions, 21 deletions
diff --git a/tests/replication/_base.py b/tests/replication/_base.py index ae60874ec3..81ea985b9f 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -12,13 +12,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import logging from typing import Any, Callable, List, Optional, Tuple import attr +import hiredis from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime +from twisted.internet.protocol import Protocol from twisted.internet.task import LoopingCall from twisted.web.http import HTTPChannel @@ -27,7 +28,7 @@ from synapse.app.generic_worker import ( GenericWorkerServer, ) from synapse.http.server import JsonResource -from synapse.http.site import SynapseRequest +from synapse.http.site import SynapseRequest, SynapseSite from synapse.replication.http import ReplicationRestResource, streams from synapse.replication.tcp.handler import ReplicationCommandHandler from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol @@ -197,19 +198,37 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): self.server_factory = ReplicationStreamProtocolFactory(self.hs) self.streamer = self.hs.get_replication_streamer() + # Fake in memory Redis server that servers can connect to. + self._redis_server = FakeRedisPubSubServer() + store = self.hs.get_datastore() self.database_pool = store.db_pool self.reactor.lookups["testserv"] = "1.2.3.4" + self.reactor.lookups["localhost"] = "127.0.0.1" + + # A map from a HS instance to the associated HTTP Site to use for + # handling inbound HTTP requests to that instance. + self._hs_to_site = {self.hs: self.site} + + if self.hs.config.redis.redis_enabled: + # Handle attempts to connect to fake redis server. + self.reactor.add_tcp_client_callback( + "localhost", 6379, self.connect_any_redis_attempts, + ) - self._worker_hs_to_resource = {} + self.hs.get_tcp_replication().start_replication(self.hs) # When we see a connection attempt to the master replication listener we # automatically set up the connection. This is so that tests don't # manually have to go and explicitly set it up each time (plus sometimes # it is impossible to write the handling explicitly in the tests). + # + # Register the master replication listener: self.reactor.add_tcp_client_callback( - "1.2.3.4", 8765, self._handle_http_replication_attempt + "1.2.3.4", + 8765, + lambda: self._handle_http_replication_attempt(self.hs, 8765), ) def create_test_json_resource(self): @@ -253,28 +272,63 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): **kwargs ) + # If the instance is in the `instance_map` config then workers may try + # and send HTTP requests to it, so we register it with + # `_handle_http_replication_attempt` like we do with the master HS. + instance_name = worker_hs.get_instance_name() + instance_loc = worker_hs.config.worker.instance_map.get(instance_name) + if instance_loc: + # Ensure the host is one that has a fake DNS entry. + if instance_loc.host not in self.reactor.lookups: + raise Exception( + "Host does not have an IP for instance_map[%r].host = %r" + % (instance_name, instance_loc.host,) + ) + + self.reactor.add_tcp_client_callback( + self.reactor.lookups[instance_loc.host], + instance_loc.port, + lambda: self._handle_http_replication_attempt( + worker_hs, instance_loc.port + ), + ) + store = worker_hs.get_datastore() store.db_pool._db_pool = self.database_pool._db_pool - repl_handler = ReplicationCommandHandler(worker_hs) - client = ClientReplicationStreamProtocol( - worker_hs, "client", "test", self.clock, repl_handler, - ) - server = self.server_factory.buildProtocol(None) + # Set up TCP replication between master and the new worker if we don't + # have Redis support enabled. + if not worker_hs.config.redis_enabled: + repl_handler = ReplicationCommandHandler(worker_hs) + client = ClientReplicationStreamProtocol( + worker_hs, "client", "test", self.clock, repl_handler, + ) + server = self.server_factory.buildProtocol(None) - client_transport = FakeTransport(server, self.reactor) - client.makeConnection(client_transport) + client_transport = FakeTransport(server, self.reactor) + client.makeConnection(client_transport) - server_transport = FakeTransport(client, self.reactor) - server.makeConnection(server_transport) + server_transport = FakeTransport(client, self.reactor) + server.makeConnection(server_transport) # Set up a resource for the worker - resource = ReplicationRestResource(self.hs) + resource = ReplicationRestResource(worker_hs) for servlet in self.servlets: servlet(worker_hs, resource) - self._worker_hs_to_resource[worker_hs] = resource + self._hs_to_site[worker_hs] = SynapseSite( + logger_name="synapse.access.http.fake", + site_tag="{}-{}".format( + worker_hs.config.server.server_name, worker_hs.get_instance_name() + ), + config=worker_hs.config.server.listeners[0], + resource=resource, + server_version_string="1", + ) + + if worker_hs.config.redis.redis_enabled: + worker_hs.get_tcp_replication().start_replication(worker_hs) return worker_hs @@ -285,7 +339,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): return config def render_on_worker(self, worker_hs: HomeServer, request: SynapseRequest): - render(request, self._worker_hs_to_resource[worker_hs], self.reactor) + render(request, self._hs_to_site[worker_hs].resource, self.reactor) def replicate(self): """Tell the master side of replication that something has happened, and then @@ -294,9 +348,9 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): self.streamer.on_notifier_poke() self.pump() - def _handle_http_replication_attempt(self): - """Handles a connection attempt to the master replication HTTP - listener. + def _handle_http_replication_attempt(self, hs, repl_port): + """Handles a connection attempt to the given HS replication HTTP + listener on the given port. """ # We should have at least one outbound connection attempt, where the @@ -305,7 +359,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): self.assertGreaterEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients.pop() self.assertEqual(host, "1.2.3.4") - self.assertEqual(port, 8765) + self.assertEqual(port, repl_port) # Set up client side protocol client_protocol = client_factory.buildProtocol(None) @@ -315,7 +369,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): # Set up the server side protocol channel = _PushHTTPChannel(self.reactor) channel.requestFactory = request_factory - channel.site = self.site + channel.site = self._hs_to_site[hs] # Connect client to server and vice versa. client_to_server_transport = FakeTransport( @@ -333,6 +387,32 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): # inside `connecTCP` before the connection has been passed back to the # code that requested the TCP connection. + def connect_any_redis_attempts(self): + """If redis is enabled we need to deal with workers connecting to a + redis server. We don't want to use a real Redis server so we use a + fake one. + """ + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) + self.assertEqual(host, "localhost") + self.assertEqual(port, 6379) + + client_protocol = client_factory.buildProtocol(None) + server_protocol = self._redis_server.buildProtocol(None) + + client_to_server_transport = FakeTransport( + server_protocol, self.reactor, client_protocol + ) + client_protocol.makeConnection(client_to_server_transport) + + server_to_client_transport = FakeTransport( + client_protocol, self.reactor, server_protocol + ) + server_protocol.makeConnection(server_to_client_transport) + + return client_to_server_transport, server_to_client_transport + class TestReplicationDataHandler(GenericWorkerReplicationHandler): """Drop-in for ReplicationDataHandler which just collects RDATA rows""" @@ -467,3 +547,105 @@ class _PullToPushProducer: pass self.stopProducing() + + +class FakeRedisPubSubServer: + """A fake Redis server for pub/sub. + """ + + def __init__(self): + self._subscribers = set() + + def add_subscriber(self, conn): + """A connection has called SUBSCRIBE + """ + self._subscribers.add(conn) + + def remove_subscriber(self, conn): + """A connection has called UNSUBSCRIBE + """ + self._subscribers.discard(conn) + + def publish(self, conn, channel, msg) -> int: + """A connection want to publish a message to subscribers. + """ + for sub in self._subscribers: + sub.send(["message", channel, msg]) + + return len(self._subscribers) + + def buildProtocol(self, addr): + return FakeRedisPubSubProtocol(self) + + +class FakeRedisPubSubProtocol(Protocol): + """A connection from a client talking to the fake Redis server. + """ + + def __init__(self, server: FakeRedisPubSubServer): + self._server = server + self._reader = hiredis.Reader() + + def dataReceived(self, data): + self._reader.feed(data) + + # We might get multiple messages in one packet. + while True: + msg = self._reader.gets() + + if msg is False: + # No more messages. + return + + if not isinstance(msg, list): + # Inbound commands should always be a list + raise Exception("Expected redis list") + + self.handle_command(msg[0], *msg[1:]) + + def handle_command(self, command, *args): + """Received a Redis command from the client. + """ + + # We currently only support pub/sub. + if command == b"PUBLISH": + channel, message = args + num_subscribers = self._server.publish(self, channel, message) + self.send(num_subscribers) + elif command == b"SUBSCRIBE": + (channel,) = args + self._server.add_subscriber(self) + self.send(["subscribe", channel, 1]) + else: + raise Exception("Unknown command") + + def send(self, msg): + """Send a message back to the client. + """ + raw = self.encode(msg).encode("utf-8") + + self.transport.write(raw) + self.transport.flush() + + def encode(self, obj): + """Encode an object to its Redis format. + + Supports: strings/bytes, integers and list/tuples. + """ + + if isinstance(obj, bytes): + # We assume bytes are just unicode strings. + obj = obj.decode("utf-8") + + if isinstance(obj, str): + return "${len}\r\n{str}\r\n".format(len=len(obj), str=obj) + if isinstance(obj, int): + return ":{val}\r\n".format(val=obj) + if isinstance(obj, (list, tuple)): + items = "".join(self.encode(a) for a in obj) + return "*{len}\r\n{items}".format(len=len(obj), items=items) + + raise Exception("Unrecognized type for encoding redis: %r: %r", type(obj), obj) + + def connectionLost(self, reason): + self._server.remove_subscriber(self) |