diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index ae60874ec3..295c5d58a6 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -12,13 +12,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
from typing import Any, Callable, List, Optional, Tuple
import attr
from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
+from twisted.internet.protocol import Protocol
from twisted.internet.task import LoopingCall
from twisted.web.http import HTTPChannel
@@ -27,7 +27,7 @@ from synapse.app.generic_worker import (
GenericWorkerServer,
)
from synapse.http.server import JsonResource
-from synapse.http.site import SynapseRequest
+from synapse.http.site import SynapseRequest, SynapseSite
from synapse.replication.http import ReplicationRestResource, streams
from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
@@ -36,7 +36,12 @@ from synapse.server import HomeServer
from synapse.util import Clock
from tests import unittest
-from tests.server import FakeTransport, render
+from tests.server import FakeTransport
+
+try:
+ import hiredis
+except ImportError:
+ hiredis = None
logger = logging.getLogger(__name__)
@@ -44,6 +49,11 @@ logger = logging.getLogger(__name__)
class BaseStreamTestCase(unittest.HomeserverTestCase):
"""Base class for tests of the replication streams"""
+ # hiredis is an optional dependency so we don't want to require it for running
+ # the tests.
+ if not hiredis:
+ skip = "Requires hiredis"
+
servlets = [
streams.register_servlets,
]
@@ -58,7 +68,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
self.reactor.lookups["testserv"] = "1.2.3.4"
self.worker_hs = self.setup_test_homeserver(
http_client=None,
- homeserverToUse=GenericWorkerServer,
+ homeserver_to_use=GenericWorkerServer,
config=self._get_worker_hs_config(),
reactor=self.reactor,
)
@@ -68,7 +78,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
self.worker_hs.get_datastore().db_pool = hs.get_datastore().db_pool
self.test_handler = self._build_replication_data_handler()
- self.worker_hs.replication_data_handler = self.test_handler
+ self.worker_hs._replication_data_handler = self.test_handler
repl_handler = ReplicationCommandHandler(self.worker_hs)
self.client = ClientReplicationStreamProtocol(
@@ -197,23 +207,41 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
self.server_factory = ReplicationStreamProtocolFactory(self.hs)
self.streamer = self.hs.get_replication_streamer()
+        # Fake in-memory Redis server that servers can connect to.
+ self._redis_server = FakeRedisPubSubServer()
+
store = self.hs.get_datastore()
self.database_pool = store.db_pool
self.reactor.lookups["testserv"] = "1.2.3.4"
+ self.reactor.lookups["localhost"] = "127.0.0.1"
- self._worker_hs_to_resource = {}
+ # A map from a HS instance to the associated HTTP Site to use for
+ # handling inbound HTTP requests to that instance.
+ self._hs_to_site = {self.hs: self.site}
+
+ if self.hs.config.redis.redis_enabled:
+ # Handle attempts to connect to fake redis server.
+ self.reactor.add_tcp_client_callback(
+ "localhost", 6379, self.connect_any_redis_attempts,
+ )
+
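+            # Tell the master HS to connect out to the (fake) Redis server,
+            # via the `connect_any_redis_attempts` callback above.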
+ self.hs.get_tcp_replication().start_replication(self.hs)
# When we see a connection attempt to the master replication listener we
# automatically set up the connection. This is so that tests don't
# manually have to go and explicitly set it up each time (plus sometimes
# it is impossible to write the handling explicitly in the tests).
+ #
+ # Register the master replication listener:
self.reactor.add_tcp_client_callback(
- "1.2.3.4", 8765, self._handle_http_replication_attempt
+ "1.2.3.4",
+ 8765,
+ lambda: self._handle_http_replication_attempt(self.hs, 8765),
)
- def create_test_json_resource(self):
- """Overrides `HomeserverTestCase.create_test_json_resource`.
+ def create_test_resource(self):
+ """Overrides `HomeserverTestCase.create_test_resource`.
"""
# We override this so that it automatically registers all the HTTP
# replication servlets, without having to explicitly do that in all
@@ -247,34 +275,69 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
config.update(extra_config)
worker_hs = self.setup_test_homeserver(
- homeserverToUse=GenericWorkerServer,
+ homeserver_to_use=GenericWorkerServer,
config=config,
reactor=self.reactor,
- **kwargs
+ **kwargs,
)
+        # If the instance is in the `instance_map` config then workers may try
+        # to send HTTP requests to it, so we register it with
+ # `_handle_http_replication_attempt` like we do with the master HS.
+ instance_name = worker_hs.get_instance_name()
+ instance_loc = worker_hs.config.worker.instance_map.get(instance_name)
+ if instance_loc:
+ # Ensure the host is one that has a fake DNS entry.
+ if instance_loc.host not in self.reactor.lookups:
+ raise Exception(
+ "Host does not have an IP for instance_map[%r].host = %r"
+ % (instance_name, instance_loc.host,)
+ )
+
+ self.reactor.add_tcp_client_callback(
+ self.reactor.lookups[instance_loc.host],
+ instance_loc.port,
+ lambda: self._handle_http_replication_attempt(
+ worker_hs, instance_loc.port
+ ),
+ )
+
store = worker_hs.get_datastore()
store.db_pool._db_pool = self.database_pool._db_pool
- repl_handler = ReplicationCommandHandler(worker_hs)
- client = ClientReplicationStreamProtocol(
- worker_hs, "client", "test", self.clock, repl_handler,
- )
- server = self.server_factory.buildProtocol(None)
+ # Set up TCP replication between master and the new worker if we don't
+ # have Redis support enabled.
+        if not worker_hs.config.redis.redis_enabled:
+ repl_handler = ReplicationCommandHandler(worker_hs)
+ client = ClientReplicationStreamProtocol(
+ worker_hs, "client", "test", self.clock, repl_handler,
+ )
+ server = self.server_factory.buildProtocol(None)
- client_transport = FakeTransport(server, self.reactor)
- client.makeConnection(client_transport)
+ client_transport = FakeTransport(server, self.reactor)
+ client.makeConnection(client_transport)
- server_transport = FakeTransport(client, self.reactor)
- server.makeConnection(server_transport)
+ server_transport = FakeTransport(client, self.reactor)
+ server.makeConnection(server_transport)
# Set up a resource for the worker
- resource = ReplicationRestResource(self.hs)
+ resource = ReplicationRestResource(worker_hs)
for servlet in self.servlets:
servlet(worker_hs, resource)
- self._worker_hs_to_resource[worker_hs] = resource
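+        # Wrap the resource in a SynapseSite so that inbound requests can be
+        # served via `_handle_http_replication_attempt`.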
+ self._hs_to_site[worker_hs] = SynapseSite(
+ logger_name="synapse.access.http.fake",
+ site_tag="{}-{}".format(
+ worker_hs.config.server.server_name, worker_hs.get_instance_name()
+ ),
+ config=worker_hs.config.server.listeners[0],
+ resource=resource,
+ server_version_string="1",
+ )
+
+ if worker_hs.config.redis.redis_enabled:
+ worker_hs.get_tcp_replication().start_replication(worker_hs)
return worker_hs
@@ -284,9 +347,6 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
config["worker_replication_http_port"] = "8765"
return config
- def render_on_worker(self, worker_hs: HomeServer, request: SynapseRequest):
- render(request, self._worker_hs_to_resource[worker_hs], self.reactor)
-
def replicate(self):
"""Tell the master side of replication that something has happened, and then
wait for the replication to occur.
@@ -294,9 +354,9 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
self.streamer.on_notifier_poke()
self.pump()
- def _handle_http_replication_attempt(self):
- """Handles a connection attempt to the master replication HTTP
- listener.
+ def _handle_http_replication_attempt(self, hs, repl_port):
+ """Handles a connection attempt to the given HS replication HTTP
+ listener on the given port.
"""
# We should have at least one outbound connection attempt, where the
@@ -305,7 +365,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
self.assertGreaterEqual(len(clients), 1)
(host, port, client_factory, _timeout, _bindAddress) = clients.pop()
self.assertEqual(host, "1.2.3.4")
- self.assertEqual(port, 8765)
+ self.assertEqual(port, repl_port)
# Set up client side protocol
client_protocol = client_factory.buildProtocol(None)
@@ -315,7 +375,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
# Set up the server side protocol
channel = _PushHTTPChannel(self.reactor)
channel.requestFactory = request_factory
- channel.site = self.site
+ channel.site = self._hs_to_site[hs]
# Connect client to server and vice versa.
client_to_server_transport = FakeTransport(
@@ -333,6 +393,32 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
        # inside `connectTCP` before the connection has been passed back to the
# code that requested the TCP connection.
+ def connect_any_redis_attempts(self):
+ """If redis is enabled we need to deal with workers connecting to a
+ redis server. We don't want to use a real Redis server so we use a
+ fake one.
+ """
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
+ self.assertEqual(host, "localhost")
+ self.assertEqual(port, 6379)
+
+ client_protocol = client_factory.buildProtocol(None)
+ server_protocol = self._redis_server.buildProtocol(None)
+
+ client_to_server_transport = FakeTransport(
+ server_protocol, self.reactor, client_protocol
+ )
+ client_protocol.makeConnection(client_to_server_transport)
+
+ server_to_client_transport = FakeTransport(
+ client_protocol, self.reactor, server_protocol
+ )
+ server_protocol.makeConnection(server_to_client_transport)
+
+ return client_to_server_transport, server_to_client_transport
+
class TestReplicationDataHandler(GenericWorkerReplicationHandler):
"""Drop-in for ReplicationDataHandler which just collects RDATA rows"""
@@ -467,3 +553,105 @@ class _PullToPushProducer:
pass
self.stopProducing()
+
+
+class FakeRedisPubSubServer:
+ """A fake Redis server for pub/sub.
+ """
+
+ def __init__(self):
+ self._subscribers = set()
+
+ def add_subscriber(self, conn):
+ """A connection has called SUBSCRIBE
+ """
+ self._subscribers.add(conn)
+
+ def remove_subscriber(self, conn):
+ """A connection has called UNSUBSCRIBE
+ """
+ self._subscribers.discard(conn)
+
+ def publish(self, conn, channel, msg) -> int:
+ """A connection want to publish a message to subscribers.
+ """
+ for sub in self._subscribers:
+ sub.send(["message", channel, msg])
+
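+        # Real Redis replies to PUBLISH with the number of clients that
+        # received the message, so we do the same.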
+ return len(self._subscribers)
+
+ def buildProtocol(self, addr):
+ return FakeRedisPubSubProtocol(self)
+
+
+class FakeRedisPubSubProtocol(Protocol):
+ """A connection from a client talking to the fake Redis server.
+ """
+
+ def __init__(self, server: FakeRedisPubSubServer):
+ self._server = server
+ self._reader = hiredis.Reader()
+
+ def dataReceived(self, data):
+ self._reader.feed(data)
+
+ # We might get multiple messages in one packet.
+ while True:
+ msg = self._reader.gets()
+
+ if msg is False:
+ # No more messages.
+ return
+
+ if not isinstance(msg, list):
+ # Inbound commands should always be a list
+ raise Exception("Expected redis list")
+
+ self.handle_command(msg[0], *msg[1:])
+
+ def handle_command(self, command, *args):
+ """Received a Redis command from the client.
+ """
+
+ # We currently only support pub/sub.
+ if command == b"PUBLISH":
+ channel, message = args
+ num_subscribers = self._server.publish(self, channel, message)
+ self.send(num_subscribers)
+ elif command == b"SUBSCRIBE":
+ (channel,) = args
+ self._server.add_subscriber(self)
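+            # Redis acknowledges a SUBSCRIBE with a [kind, channel, count]
+            # array reply.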
+ self.send(["subscribe", channel, 1])
+ else:
+ raise Exception("Unknown command")
+
+ def send(self, msg):
+ """Send a message back to the client.
+ """
+ raw = self.encode(msg).encode("utf-8")
+
+ self.transport.write(raw)
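+        # (`flush` is specific to the FakeTransport used by these tests; real
+        # Twisted transports don't have it.)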
+ self.transport.flush()
+
+ def encode(self, obj):
+ """Encode an object to its Redis format.
+
+ Supports: strings/bytes, integers and list/tuples.
+ """
+
+        if isinstance(obj, bytes):
+            # We assume any bytes are UTF-8 encoded strings.
+            obj = obj.decode("utf-8")
+
+        if isinstance(obj, str):
+            # RESP bulk string: `$<length>\r\n<data>\r\n`
+            return "${len}\r\n{str}\r\n".format(len=len(obj), str=obj)
+        if isinstance(obj, int):
+            # RESP integer: `:<value>\r\n`
+            return ":{val}\r\n".format(val=obj)
+        if isinstance(obj, (list, tuple)):
+            # RESP array: `*<count>\r\n` followed by each encoded element.
+            items = "".join(self.encode(a) for a in obj)
+            return "*{len}\r\n{items}".format(len=len(obj), items=items)
+
+        raise Exception(
+            "Unrecognized type for encoding redis: %r: %r" % (type(obj), obj)
+        )
+
+ def connectionLost(self, reason):
+ self._server.remove_subscriber(self)
diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py
index c9998e88e6..bad0df08cf 100644
--- a/tests/replication/tcp/streams/test_events.py
+++ b/tests/replication/tcp/streams/test_events.py
@@ -449,7 +449,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
sender=sender,
type="test_event",
content={"body": body},
- **kwargs
+ **kwargs,
)
)
diff --git a/tests/replication/test_client_reader_shard.py b/tests/replication/test_client_reader_shard.py
index 86c03fd89c..96801db473 100644
--- a/tests/replication/test_client_reader_shard.py
+++ b/tests/replication/test_client_reader_shard.py
@@ -20,7 +20,7 @@ from synapse.rest.client.v2_alpha import register
from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.rest.client.v2_alpha.test_auth import DummyRecaptchaChecker
-from tests.server import FakeChannel
+from tests.server import FakeChannel, make_request
logger = logging.getLogger(__name__)
@@ -46,23 +46,28 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
"""Test that registration works when using a single client reader worker.
"""
worker_hs = self.make_worker_hs("synapse.app.client_reader")
+ site = self._hs_to_site[worker_hs]
- request_1, channel_1 = self.make_request(
+ request_1, channel_1 = make_request(
+ self.reactor,
+ site,
"POST",
"register",
{"username": "user", "type": "m.login.password", "password": "bar"},
) # type: SynapseRequest, FakeChannel
- self.render_on_worker(worker_hs, request_1)
self.assertEqual(request_1.code, 401)
# Grab the session
session = channel_1.json_body["session"]
# also complete the dummy auth
- request_2, channel_2 = self.make_request(
- "POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}}
+ request_2, channel_2 = make_request(
+ self.reactor,
+ site,
+ "POST",
+ "register",
+ {"auth": {"session": session, "type": "m.login.dummy"}},
) # type: SynapseRequest, FakeChannel
- self.render_on_worker(worker_hs, request_2)
self.assertEqual(request_2.code, 200)
# We're given a registered user.
@@ -74,22 +79,28 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
worker_hs_1 = self.make_worker_hs("synapse.app.client_reader")
worker_hs_2 = self.make_worker_hs("synapse.app.client_reader")
- request_1, channel_1 = self.make_request(
+ site_1 = self._hs_to_site[worker_hs_1]
+ request_1, channel_1 = make_request(
+ self.reactor,
+ site_1,
"POST",
"register",
{"username": "user", "type": "m.login.password", "password": "bar"},
) # type: SynapseRequest, FakeChannel
- self.render_on_worker(worker_hs_1, request_1)
self.assertEqual(request_1.code, 401)
# Grab the session
session = channel_1.json_body["session"]
# also complete the dummy auth
- request_2, channel_2 = self.make_request(
- "POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}}
+ site_2 = self._hs_to_site[worker_hs_2]
+ request_2, channel_2 = make_request(
+ self.reactor,
+ site_2,
+ "POST",
+ "register",
+ {"auth": {"session": session, "type": "m.login.dummy"}},
) # type: SynapseRequest, FakeChannel
- self.render_on_worker(worker_hs_2, request_2)
self.assertEqual(request_2.code, 200)
# We're given a registered user.
diff --git a/tests/replication/test_federation_ack.py b/tests/replication/test_federation_ack.py
index 23be1167a3..1853667558 100644
--- a/tests/replication/test_federation_ack.py
+++ b/tests/replication/test_federation_ack.py
@@ -31,7 +31,7 @@ class FederationAckTestCase(HomeserverTestCase):
return config
def make_homeserver(self, reactor, clock):
- hs = self.setup_test_homeserver(homeserverToUse=GenericWorkerServer)
+ hs = self.setup_test_homeserver(homeserver_to_use=GenericWorkerServer)
return hs
diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py
index 1d7edee5ba..779745ae9d 100644
--- a/tests/replication/test_federation_sender_shard.py
+++ b/tests/replication/test_federation_sender_shard.py
@@ -207,7 +207,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
def create_room_with_remote_server(self, user, token, remote_server="other_server"):
room = self.helper.create_room_as(user, tok=token)
store = self.hs.get_datastore()
- federation = self.hs.get_handlers().federation_handler
+ federation = self.hs.get_federation_handler()
prev_event_ids = self.get_success(store.get_latest_event_ids_in_room(room))
room_version = self.get_success(store.get_room_version(room))
@@ -226,7 +226,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
}
builder = factory.for_room_version(room_version, event_dict)
- join_event = self.get_success(builder.build(prev_event_ids))
+ join_event = self.get_success(builder.build(prev_event_ids, None))
self.get_success(federation.on_send_join_request(remote_server, join_event))
self.replicate()
diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py
new file mode 100644
index 0000000000..48b574ccbe
--- /dev/null
+++ b/tests/replication/test_multi_media_repo.py
@@ -0,0 +1,279 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+import os
+from binascii import unhexlify
+from typing import Tuple
+
+from twisted.internet.protocol import Factory
+from twisted.protocols.tls import TLSMemoryBIOFactory
+from twisted.web.http import HTTPChannel
+from twisted.web.server import Request
+
+from synapse.rest import admin
+from synapse.rest.client.v1 import login
+from synapse.server import HomeServer
+
+from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file
+from tests.replication._base import BaseMultiWorkerStreamTestCase
+from tests.server import FakeChannel, FakeSite, FakeTransport, make_request
+
+logger = logging.getLogger(__name__)
+
+test_server_connection_factory = None
+
+
+class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
+ """Checks running multiple media repos work correctly.
+ """
+
+ servlets = [
+ admin.register_servlets_for_client_rest_resource,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.user_id = self.register_user("user", "pass")
+ self.access_token = self.login("user", "pass")
+
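+        # Map the remote server's name to an IP so that outbound requests to
+        # it know where to "connect".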
+ self.reactor.lookups["example.com"] = "127.0.0.2"
+
+ def default_config(self):
+ conf = super().default_config()
+ conf["federation_custom_ca_list"] = [get_test_ca_cert_file()]
+ return conf
+
+ def _get_media_req(
+ self, hs: HomeServer, target: str, media_id: str
+ ) -> Tuple[FakeChannel, Request]:
+ """Request some remote media from the given HS by calling the download
+ API.
+
+ This then triggers an outbound request from the HS to the target.
+
+ Returns:
+ The channel for the *client* request and the *outbound* request for
+ the media which the caller should respond to.
+ """
+ resource = hs.get_media_repository_resource().children[b"download"]
+ _, channel = make_request(
+ self.reactor,
+ FakeSite(resource),
+ "GET",
+ "/{}/{}".format(target, media_id),
+ shorthand=False,
+ access_token=self.access_token,
+ await_result=False,
+ )
+ self.pump()
+
+ clients = self.reactor.tcpClients
+ self.assertGreaterEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients.pop()
+
+ # build the test server
+ server_tls_protocol = _build_test_server(get_connection_factory())
+
+ # now, tell the client protocol factory to build the client protocol (it will be a
+ # _WrappingProtocol, around a TLSMemoryBIOProtocol, around an
+ # HTTP11ClientProtocol) and wire the output of said protocol up to the server via
+ # a FakeTransport.
+ #
+ # Normally this would be done by the TCP socket code in Twisted, but we are
+ # stubbing that out here.
+ client_protocol = client_factory.buildProtocol(None)
+ client_protocol.makeConnection(
+ FakeTransport(server_tls_protocol, self.reactor, client_protocol)
+ )
+
+ # tell the server tls protocol to send its stuff back to the client, too
+ server_tls_protocol.makeConnection(
+ FakeTransport(client_protocol, self.reactor, server_tls_protocol)
+ )
+
+ # fish the test server back out of the server-side TLS protocol.
+ http_server = server_tls_protocol.wrappedProtocol
+
+ # give the reactor a pump to get the TLS juices flowing.
+ self.reactor.pump((0.1,))
+
+ self.assertEqual(len(http_server.requests), 1)
+ request = http_server.requests[0]
+
+ self.assertEqual(request.method, b"GET")
+ self.assertEqual(
+ request.path,
+ "/_matrix/media/r0/download/{}/{}".format(target, media_id).encode("utf-8"),
+ )
+ self.assertEqual(
+ request.requestHeaders.getRawHeaders(b"host"), [target.encode("utf-8")]
+ )
+
+ return channel, request
+
+ def test_basic(self):
+ """Test basic fetching of remote media from a single worker.
+ """
+ hs1 = self.make_worker_hs("synapse.app.generic_worker")
+
+ channel, request = self._get_media_req(hs1, "example.com:443", "ABC123")
+
+ request.setResponseCode(200)
+ request.responseHeaders.setRawHeaders(b"Content-Type", [b"text/plain"])
+ request.write(b"Hello!")
+ request.finish()
+
+ self.pump(0.1)
+
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.result["body"], b"Hello!")
+
+ def test_download_simple_file_race(self):
+ """Test that fetching remote media from two different processes at the
+ same time works.
+ """
+ hs1 = self.make_worker_hs("synapse.app.generic_worker")
+ hs2 = self.make_worker_hs("synapse.app.generic_worker")
+
+ start_count = self._count_remote_media()
+
+ # Make two requests without responding to the outbound media requests.
+ channel1, request1 = self._get_media_req(hs1, "example.com:443", "ABC123")
+ channel2, request2 = self._get_media_req(hs2, "example.com:443", "ABC123")
+
+ # Respond to the first outbound media request and check that the client
+ # request is successful
+ request1.setResponseCode(200)
+ request1.responseHeaders.setRawHeaders(b"Content-Type", [b"text/plain"])
+ request1.write(b"Hello!")
+ request1.finish()
+
+ self.pump(0.1)
+
+ self.assertEqual(channel1.code, 200, channel1.result["body"])
+ self.assertEqual(channel1.result["body"], b"Hello!")
+
+ # Now respond to the second with the same content.
+ request2.setResponseCode(200)
+ request2.responseHeaders.setRawHeaders(b"Content-Type", [b"text/plain"])
+ request2.write(b"Hello!")
+ request2.finish()
+
+ self.pump(0.1)
+
+ self.assertEqual(channel2.code, 200, channel2.result["body"])
+ self.assertEqual(channel2.result["body"], b"Hello!")
+
+ # We expect only one new file to have been persisted.
+ self.assertEqual(start_count + 1, self._count_remote_media())
+
+ def test_download_image_race(self):
+ """Test that fetching remote *images* from two different processes at
+ the same time works.
+
+ This checks that races generating thumbnails are handled correctly.
+ """
+ hs1 = self.make_worker_hs("synapse.app.generic_worker")
+ hs2 = self.make_worker_hs("synapse.app.generic_worker")
+
+ start_count = self._count_remote_thumbnails()
+
+ channel1, request1 = self._get_media_req(hs1, "example.com:443", "PIC1")
+ channel2, request2 = self._get_media_req(hs2, "example.com:443", "PIC1")
+
+ png_data = unhexlify(
+ b"89504e470d0a1a0a0000000d4948445200000001000000010806"
+ b"0000001f15c4890000000a49444154789c63000100000500010d"
+ b"0a2db40000000049454e44ae426082"
+ )
+
+ request1.setResponseCode(200)
+ request1.responseHeaders.setRawHeaders(b"Content-Type", [b"image/png"])
+ request1.write(png_data)
+ request1.finish()
+
+ self.pump(0.1)
+
+ self.assertEqual(channel1.code, 200, channel1.result["body"])
+ self.assertEqual(channel1.result["body"], png_data)
+
+ request2.setResponseCode(200)
+ request2.responseHeaders.setRawHeaders(b"Content-Type", [b"image/png"])
+ request2.write(png_data)
+ request2.finish()
+
+ self.pump(0.1)
+
+ self.assertEqual(channel2.code, 200, channel2.result["body"])
+ self.assertEqual(channel2.result["body"], png_data)
+
+ # We expect only three new thumbnails to have been persisted.
+ self.assertEqual(start_count + 3, self._count_remote_thumbnails())
+
+ def _count_remote_media(self) -> int:
+ """Count the number of files in our remote media directory.
+ """
+ path = os.path.join(
+ self.hs.get_media_repository().primary_base_path, "remote_content"
+ )
+ return sum(len(files) for _, _, files in os.walk(path))
+
+ def _count_remote_thumbnails(self) -> int:
+ """Count the number of files in our remote thumbnails directory.
+ """
+ path = os.path.join(
+ self.hs.get_media_repository().primary_base_path, "remote_thumbnail"
+ )
+ return sum(len(files) for _, _, files in os.walk(path))
+
+
+def get_connection_factory():
+ # this needs to happen once, but not until we are ready to run the first test
+ global test_server_connection_factory
+ if test_server_connection_factory is None:
+ test_server_connection_factory = TestServerTLSConnectionFactory(
+ sanlist=[b"DNS:example.com"]
+ )
+ return test_server_connection_factory
+
+
+def _build_test_server(connection_creator):
+ """Construct a test server
+
+ This builds an HTTP channel, wrapped with a TLSMemoryBIOProtocol
+
+ Args:
+ connection_creator (IOpenSSLServerConnectionCreator): thing to build
+ SSL connections
+
+ Returns:
+ TLSMemoryBIOProtocol
+ """
+ server_factory = Factory.forProtocol(HTTPChannel)
+ # Request.finish expects the factory to have a 'log' method.
+ server_factory.log = _log_request
+
+ server_tls_factory = TLSMemoryBIOFactory(
+ connection_creator, isClient=False, wrappedFactory=server_factory
+ )
+
+ return server_tls_factory.buildProtocol(None)
+
+
+def _log_request(request):
+ """Implements Factory.log, which is expected by Request.finish"""
+ logger.info("Completed request %s", request)
diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py
index 2bdc6edbb1..67c27a089f 100644
--- a/tests/replication/test_pusher_shard.py
+++ b/tests/replication/test_pusher_shard.py
@@ -55,7 +55,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
user_dict = self.get_success(
self.hs.get_datastore().get_user_by_access_token(access_token)
)
- token_id = user_dict["token_id"]
+ token_id = user_dict.token_id
self.get_success(
self.hs.get_pusherpool().add_pusher(
diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py
new file mode 100644
index 0000000000..77fc3856d5
--- /dev/null
+++ b/tests/replication/test_sharded_event_persister.py
@@ -0,0 +1,333 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from mock import patch
+
+from synapse.api.room_versions import RoomVersion
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+from synapse.rest.client.v2_alpha import sync
+
+from tests.replication._base import BaseMultiWorkerStreamTestCase
+from tests.server import make_request
+from tests.utils import USE_POSTGRES_FOR_TESTS
+
+logger = logging.getLogger(__name__)
+
+
+class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
+ """Checks event persisting sharding works
+ """
+
+ # Event persister sharding requires postgres (due to needing
+    # `MultiWriterIdGenerator`).
+ if not USE_POSTGRES_FOR_TESTS:
+ skip = "Requires Postgres"
+
+ servlets = [
+ admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ login.register_servlets,
+ sync.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ # Register a user who sends a message that we'll get notified about
+ self.other_user_id = self.register_user("otheruser", "pass")
+ self.other_access_token = self.login("otheruser", "pass")
+
+ self.room_creator = self.hs.get_room_creation_handler()
+ self.store = hs.get_datastore()
+
+ def default_config(self):
+ conf = super().default_config()
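+        # Sharding events between multiple writers requires Redis-based
+        # replication.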
+ conf["redis"] = {"enabled": "true"}
+ conf["stream_writers"] = {"events": ["worker1", "worker2"]}
+ conf["instance_map"] = {
+ "worker1": {"host": "testserv", "port": 1001},
+ "worker2": {"host": "testserv", "port": 1002},
+ }
+ return conf
+
+ def _create_room(self, room_id: str, user_id: str, tok: str):
+ """Create a room with given room_id
+ """
+
+ # We control the room ID generation by patching out the
+ # `_generate_room_id` method
+ async def generate_room(
+ creator_id: str, is_public: bool, room_version: RoomVersion
+ ):
+ await self.store.store_room(
+ room_id=room_id,
+ room_creator_user_id=creator_id,
+ is_public=is_public,
+ room_version=room_version,
+ )
+ return room_id
+
+ with patch(
+ "synapse.handlers.room.RoomCreationHandler._generate_room_id"
+ ) as mock:
+ mock.side_effect = generate_room
+ self.helper.create_room_as(user_id, tok=tok)
+
+ def test_basic(self):
+ """Simple test to ensure that multiple rooms can be created and joined,
+ and that different rooms get handled by different instances.
+ """
+
+ self.make_worker_hs(
+ "synapse.app.generic_worker", {"worker_name": "worker1"},
+ )
+
+ self.make_worker_hs(
+ "synapse.app.generic_worker", {"worker_name": "worker2"},
+ )
+
+ persisted_on_1 = False
+ persisted_on_2 = False
+
+ store = self.hs.get_datastore()
+
+ user_id = self.register_user("user", "pass")
+ access_token = self.login("user", "pass")
+
+ # Keep making new rooms until we see rooms being persisted on both
+ # workers.
+ for _ in range(10):
+ # Create a room
+ room = self.helper.create_room_as(user_id, tok=access_token)
+
+ # The other user joins
+ self.helper.join(
+ room=room, user=self.other_user_id, tok=self.other_access_token
+ )
+
+ # The other user sends some messages
+            response = self.helper.send(room, body="Hi!", tok=self.other_access_token)
+            event_id = response["event_id"]
+
+ # The event position includes which instance persisted the event.
+ pos = self.get_success(store.get_position_for_event(event_id))
+
+ persisted_on_1 |= pos.instance_name == "worker1"
+ persisted_on_2 |= pos.instance_name == "worker2"
+
+ if persisted_on_1 and persisted_on_2:
+ break
+
+ self.assertTrue(persisted_on_1)
+ self.assertTrue(persisted_on_2)
+
+ def test_vector_clock_token(self):
+ """Tests that using a stream token with a vector clock component works
+ correctly with basic /sync and /messages usage.
+ """
+
+ self.make_worker_hs(
+ "synapse.app.generic_worker", {"worker_name": "worker1"},
+ )
+
+ worker_hs2 = self.make_worker_hs(
+ "synapse.app.generic_worker", {"worker_name": "worker2"},
+ )
+
+ sync_hs = self.make_worker_hs(
+ "synapse.app.generic_worker", {"worker_name": "sync"},
+ )
+ sync_hs_site = self._hs_to_site[sync_hs]
+
+ # Specially selected room IDs that get persisted on different workers.
+ room_id1 = "!foo:test"
+ room_id2 = "!baz:test"
+
+ self.assertEqual(
+ self.hs.config.worker.events_shard_config.get_instance(room_id1), "worker1"
+ )
+ self.assertEqual(
+ self.hs.config.worker.events_shard_config.get_instance(room_id2), "worker2"
+ )
+
+ user_id = self.register_user("user", "pass")
+ access_token = self.login("user", "pass")
+
+ store = self.hs.get_datastore()
+
+        # Create two rooms, one on each worker.
+ self._create_room(room_id1, user_id, access_token)
+ self._create_room(room_id2, user_id, access_token)
+
+ # The other user joins
+ self.helper.join(
+ room=room_id1, user=self.other_user_id, tok=self.other_access_token
+ )
+ self.helper.join(
+ room=room_id2, user=self.other_user_id, tok=self.other_access_token
+ )
+
+ # Do an initial sync so that we're up to date.
+ request, channel = make_request(
+ self.reactor, sync_hs_site, "GET", "/sync", access_token=access_token
+ )
+ next_batch = channel.json_body["next_batch"]
+
+        # We now gut wrench into the events stream's MultiWriterIdGenerator on
+        # worker2 to mimic it getting stuck persisting an event. This ensures
+        # that when we send an event on worker1, worker2's events stream
+        # position lags that on worker1, resulting in a RoomStreamToken with a
+        # non-empty instance map component.
+        #
+        # Worker2's events stream position will not advance until we call
+        # `__aexit__` below.
+ actx = worker_hs2.get_datastore()._stream_id_gen.get_next()
+ self.get_success(actx.__aenter__())
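+        # (`get_next` returns an async context manager that reserves the next
+        # stream ID; holding it open stops worker2's advertised position from
+        # advancing.)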
+
+ response = self.helper.send(room_id1, body="Hi!", tok=self.other_access_token)
+ first_event_in_room1 = response["event_id"]
+
+ # Assert that the current stream token has an instance map component, as
+ # we are trying to test vector clock tokens.
+ room_stream_token = store.get_room_max_token()
+ self.assertNotEqual(len(room_stream_token.instance_map), 0)
+
+ # Check that syncing still gets the new event, despite the gap in the
+ # stream IDs.
+ request, channel = make_request(
+ self.reactor,
+ sync_hs_site,
+ "GET",
+ "/sync?since={}".format(next_batch),
+ access_token=access_token,
+ )
+
+ # We should only see the new event and nothing else
+ self.assertIn(room_id1, channel.json_body["rooms"]["join"])
+ self.assertNotIn(room_id2, channel.json_body["rooms"]["join"])
+
+ events = channel.json_body["rooms"]["join"][room_id1]["timeline"]["events"]
+ self.assertListEqual(
+ [first_event_in_room1], [event["event_id"] for event in events]
+ )
+
+        # Get the next batch and make sure it's a vector clock style token.
+ vector_clock_token = channel.json_body["next_batch"]
+ self.assertTrue(vector_clock_token.startswith("m"))
+
+        # Now that we've got a vector clock token we finish the fake event
+        # persist that we started above.
+ self.get_success(actx.__aexit__(None, None, None))
+
+        # Now try to send an event to the other room so that we can test that
+        # the vector clock style token works as a `since` token.
+ response = self.helper.send(room_id2, body="Hi!", tok=self.other_access_token)
+ first_event_in_room2 = response["event_id"]
+
+ request, channel = make_request(
+ self.reactor,
+ sync_hs_site,
+ "GET",
+ "/sync?since={}".format(vector_clock_token),
+ access_token=access_token,
+ )
+
+ self.assertNotIn(room_id1, channel.json_body["rooms"]["join"])
+ self.assertIn(room_id2, channel.json_body["rooms"]["join"])
+
+ events = channel.json_body["rooms"]["join"][room_id2]["timeline"]["events"]
+ self.assertListEqual(
+ [first_event_in_room2], [event["event_id"] for event in events]
+ )
+
+ next_batch = channel.json_body["next_batch"]
+
+ # We also want to test that the vector clock style token works with
+        # pagination. We do this by sending a couple of new events into the rooms
+ # and syncing again to get a prev_batch token for each room, then
+ # paginating from there back to the vector clock token.
+ self.helper.send(room_id1, body="Hi again!", tok=self.other_access_token)
+ self.helper.send(room_id2, body="Hi again!", tok=self.other_access_token)
+
+ request, channel = make_request(
+ self.reactor,
+ sync_hs_site,
+ "GET",
+ "/sync?since={}".format(next_batch),
+ access_token=access_token,
+ )
+
+ prev_batch1 = channel.json_body["rooms"]["join"][room_id1]["timeline"][
+ "prev_batch"
+ ]
+ prev_batch2 = channel.json_body["rooms"]["join"][room_id2]["timeline"][
+ "prev_batch"
+ ]
+
+        # Paginating back in the first room should not produce any results, as
+        # no events have happened in it between the vector clock token and the
+        # prev_batch token. This tests that we are correctly filtering results
+        # based on the vector clock portion.
+ request, channel = make_request(
+ self.reactor,
+ sync_hs_site,
+ "GET",
+ "/rooms/{}/messages?from={}&to={}&dir=b".format(
+ room_id1, prev_batch1, vector_clock_token
+ ),
+ access_token=access_token,
+ )
+ self.assertListEqual([], channel.json_body["chunk"])
+
+ # Paginating back on the second room should produce the first event
+ # again. This tests that pagination isn't completely broken.
+ request, channel = make_request(
+ self.reactor,
+ sync_hs_site,
+ "GET",
+ "/rooms/{}/messages?from={}&to={}&dir=b".format(
+ room_id2, prev_batch2, vector_clock_token
+ ),
+ access_token=access_token,
+ )
+ self.assertEqual(len(channel.json_body["chunk"]), 1)
+ self.assertEqual(
+ channel.json_body["chunk"][0]["event_id"], first_event_in_room2
+ )
+
+ # Paginating forwards should give the same results
+ request, channel = make_request(
+ self.reactor,
+ sync_hs_site,
+ "GET",
+ "/rooms/{}/messages?from={}&to={}&dir=f".format(
+ room_id1, vector_clock_token, prev_batch1
+ ),
+ access_token=access_token,
+ )
+ self.assertListEqual([], channel.json_body["chunk"])
+
+ request, channel = make_request(
+ self.reactor,
+ sync_hs_site,
+ "GET",
+ "/rooms/{}/messages?from={}&to={}&dir=f".format(
+ room_id2, vector_clock_token, prev_batch2,
+ ),
+ access_token=access_token,
+ )
+ self.assertEqual(len(channel.json_body["chunk"]), 1)
+ self.assertEqual(
+ channel.json_body["chunk"][0]["event_id"], first_event_in_room2
+ )
|