From 6c5d5e507e629cf57ae8c1034879e8ffaef33e9f Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 2 Oct 2020 09:57:12 +0100 Subject: Add unit test for event persister sharding (#8433) --- tests/replication/_base.py | 224 ++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 203 insertions(+), 21 deletions(-) (limited to 'tests/replication/_base.py') diff --git a/tests/replication/_base.py b/tests/replication/_base.py index ae60874ec3..81ea985b9f 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -12,13 +12,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import logging from typing import Any, Callable, List, Optional, Tuple import attr +import hiredis from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime +from twisted.internet.protocol import Protocol from twisted.internet.task import LoopingCall from twisted.web.http import HTTPChannel @@ -27,7 +28,7 @@ from synapse.app.generic_worker import ( GenericWorkerServer, ) from synapse.http.server import JsonResource -from synapse.http.site import SynapseRequest +from synapse.http.site import SynapseRequest, SynapseSite from synapse.replication.http import ReplicationRestResource, streams from synapse.replication.tcp.handler import ReplicationCommandHandler from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol @@ -197,19 +198,37 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): self.server_factory = ReplicationStreamProtocolFactory(self.hs) self.streamer = self.hs.get_replication_streamer() + # Fake in memory Redis server that servers can connect to. + self._redis_server = FakeRedisPubSubServer() + store = self.hs.get_datastore() self.database_pool = store.db_pool self.reactor.lookups["testserv"] = "1.2.3.4" + self.reactor.lookups["localhost"] = "127.0.0.1" + + # A map from a HS instance to the associated HTTP Site to use for + # handling inbound HTTP requests to that instance. + self._hs_to_site = {self.hs: self.site} + + if self.hs.config.redis.redis_enabled: + # Handle attempts to connect to fake redis server. + self.reactor.add_tcp_client_callback( + "localhost", 6379, self.connect_any_redis_attempts, + ) - self._worker_hs_to_resource = {} + self.hs.get_tcp_replication().start_replication(self.hs) # When we see a connection attempt to the master replication listener we # automatically set up the connection. This is so that tests don't # manually have to go and explicitly set it up each time (plus sometimes # it is impossible to write the handling explicitly in the tests). + # + # Register the master replication listener: self.reactor.add_tcp_client_callback( - "1.2.3.4", 8765, self._handle_http_replication_attempt + "1.2.3.4", + 8765, + lambda: self._handle_http_replication_attempt(self.hs, 8765), ) def create_test_json_resource(self): @@ -253,28 +272,63 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): **kwargs ) + # If the instance is in the `instance_map` config then workers may try + # and send HTTP requests to it, so we register it with + # `_handle_http_replication_attempt` like we do with the master HS. + instance_name = worker_hs.get_instance_name() + instance_loc = worker_hs.config.worker.instance_map.get(instance_name) + if instance_loc: + # Ensure the host is one that has a fake DNS entry. + if instance_loc.host not in self.reactor.lookups: + raise Exception( + "Host does not have an IP for instance_map[%r].host = %r" + % (instance_name, instance_loc.host,) + ) + + self.reactor.add_tcp_client_callback( + self.reactor.lookups[instance_loc.host], + instance_loc.port, + lambda: self._handle_http_replication_attempt( + worker_hs, instance_loc.port + ), + ) + store = worker_hs.get_datastore() store.db_pool._db_pool = self.database_pool._db_pool - repl_handler = ReplicationCommandHandler(worker_hs) - client = ClientReplicationStreamProtocol( - worker_hs, "client", "test", self.clock, repl_handler, - ) - server = self.server_factory.buildProtocol(None) + # Set up TCP replication between master and the new worker if we don't + # have Redis support enabled. + if not worker_hs.config.redis_enabled: + repl_handler = ReplicationCommandHandler(worker_hs) + client = ClientReplicationStreamProtocol( + worker_hs, "client", "test", self.clock, repl_handler, + ) + server = self.server_factory.buildProtocol(None) - client_transport = FakeTransport(server, self.reactor) - client.makeConnection(client_transport) + client_transport = FakeTransport(server, self.reactor) + client.makeConnection(client_transport) - server_transport = FakeTransport(client, self.reactor) - server.makeConnection(server_transport) + server_transport = FakeTransport(client, self.reactor) + server.makeConnection(server_transport) # Set up a resource for the worker - resource = ReplicationRestResource(self.hs) + resource = ReplicationRestResource(worker_hs) for servlet in self.servlets: servlet(worker_hs, resource) - self._worker_hs_to_resource[worker_hs] = resource + self._hs_to_site[worker_hs] = SynapseSite( + logger_name="synapse.access.http.fake", + site_tag="{}-{}".format( + worker_hs.config.server.server_name, worker_hs.get_instance_name() + ), + config=worker_hs.config.server.listeners[0], + resource=resource, + server_version_string="1", + ) + + if worker_hs.config.redis.redis_enabled: + worker_hs.get_tcp_replication().start_replication(worker_hs) return worker_hs @@ -285,7 +339,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): return config def render_on_worker(self, worker_hs: HomeServer, request: SynapseRequest): - render(request, self._worker_hs_to_resource[worker_hs], self.reactor) + render(request, self._hs_to_site[worker_hs].resource, self.reactor) def replicate(self): """Tell the master side of replication that something has happened, and then @@ -294,9 +348,9 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): self.streamer.on_notifier_poke() self.pump() - def _handle_http_replication_attempt(self): - """Handles a connection attempt to the master replication HTTP - listener. + def _handle_http_replication_attempt(self, hs, repl_port): + """Handles a connection attempt to the given HS replication HTTP + listener on the given port. """ # We should have at least one outbound connection attempt, where the @@ -305,7 +359,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): self.assertGreaterEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients.pop() self.assertEqual(host, "1.2.3.4") - self.assertEqual(port, 8765) + self.assertEqual(port, repl_port) # Set up client side protocol client_protocol = client_factory.buildProtocol(None) @@ -315,7 +369,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): # Set up the server side protocol channel = _PushHTTPChannel(self.reactor) channel.requestFactory = request_factory - channel.site = self.site + channel.site = self._hs_to_site[hs] # Connect client to server and vice versa. client_to_server_transport = FakeTransport( @@ -333,6 +387,32 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): # inside `connecTCP` before the connection has been passed back to the # code that requested the TCP connection. + def connect_any_redis_attempts(self): + """If redis is enabled we need to deal with workers connecting to a + redis server. We don't want to use a real Redis server so we use a + fake one. + """ + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) + self.assertEqual(host, "localhost") + self.assertEqual(port, 6379) + + client_protocol = client_factory.buildProtocol(None) + server_protocol = self._redis_server.buildProtocol(None) + + client_to_server_transport = FakeTransport( + server_protocol, self.reactor, client_protocol + ) + client_protocol.makeConnection(client_to_server_transport) + + server_to_client_transport = FakeTransport( + client_protocol, self.reactor, server_protocol + ) + server_protocol.makeConnection(server_to_client_transport) + + return client_to_server_transport, server_to_client_transport + class TestReplicationDataHandler(GenericWorkerReplicationHandler): """Drop-in for ReplicationDataHandler which just collects RDATA rows""" @@ -467,3 +547,105 @@ class _PullToPushProducer: pass self.stopProducing() + + +class FakeRedisPubSubServer: + """A fake Redis server for pub/sub. + """ + + def __init__(self): + self._subscribers = set() + + def add_subscriber(self, conn): + """A connection has called SUBSCRIBE + """ + self._subscribers.add(conn) + + def remove_subscriber(self, conn): + """A connection has called UNSUBSCRIBE + """ + self._subscribers.discard(conn) + + def publish(self, conn, channel, msg) -> int: + """A connection want to publish a message to subscribers. + """ + for sub in self._subscribers: + sub.send(["message", channel, msg]) + + return len(self._subscribers) + + def buildProtocol(self, addr): + return FakeRedisPubSubProtocol(self) + + +class FakeRedisPubSubProtocol(Protocol): + """A connection from a client talking to the fake Redis server. + """ + + def __init__(self, server: FakeRedisPubSubServer): + self._server = server + self._reader = hiredis.Reader() + + def dataReceived(self, data): + self._reader.feed(data) + + # We might get multiple messages in one packet. + while True: + msg = self._reader.gets() + + if msg is False: + # No more messages. + return + + if not isinstance(msg, list): + # Inbound commands should always be a list + raise Exception("Expected redis list") + + self.handle_command(msg[0], *msg[1:]) + + def handle_command(self, command, *args): + """Received a Redis command from the client. + """ + + # We currently only support pub/sub. + if command == b"PUBLISH": + channel, message = args + num_subscribers = self._server.publish(self, channel, message) + self.send(num_subscribers) + elif command == b"SUBSCRIBE": + (channel,) = args + self._server.add_subscriber(self) + self.send(["subscribe", channel, 1]) + else: + raise Exception("Unknown command") + + def send(self, msg): + """Send a message back to the client. + """ + raw = self.encode(msg).encode("utf-8") + + self.transport.write(raw) + self.transport.flush() + + def encode(self, obj): + """Encode an object to its Redis format. + + Supports: strings/bytes, integers and list/tuples. + """ + + if isinstance(obj, bytes): + # We assume bytes are just unicode strings. + obj = obj.decode("utf-8") + + if isinstance(obj, str): + return "${len}\r\n{str}\r\n".format(len=len(obj), str=obj) + if isinstance(obj, int): + return ":{val}\r\n".format(val=obj) + if isinstance(obj, (list, tuple)): + items = "".join(self.encode(a) for a in obj) + return "*{len}\r\n{items}".format(len=len(obj), items=items) + + raise Exception("Unrecognized type for encoding redis: %r: %r", type(obj), obj) + + def connectionLost(self, reason): + self._server.remove_subscriber(self) -- cgit 1.5.1 From 6b5a115c0a0f9036444cd8686b32afbdf5334915 Mon Sep 17 00:00:00 2001 From: Jonathan de Jong Date: Thu, 15 Oct 2020 21:29:13 +0200 Subject: Solidify the HomeServer constructor. (#8515) This implements a more standard API for instantiating a homeserver and moves some of the dependency injection into the test suite. More concretely this stops using `setattr` on all `kwargs` passed to `HomeServer`. --- changelog.d/8515.misc | 1 + synapse/server.py | 14 +++++++++----- tests/app/test_frontend_proxy.py | 2 +- tests/app/test_openid_listener.py | 4 ++-- tests/replication/_base.py | 4 ++-- tests/replication/test_federation_ack.py | 2 +- tests/utils.py | 31 +++++++++++++++++-------------- 7 files changed, 33 insertions(+), 25 deletions(-) create mode 100644 changelog.d/8515.misc (limited to 'tests/replication/_base.py') diff --git a/changelog.d/8515.misc b/changelog.d/8515.misc new file mode 100644 index 0000000000..1f8aa292d8 --- /dev/null +++ b/changelog.d/8515.misc @@ -0,0 +1 @@ +Apply some internal fixes to the `HomeServer` class to make its code more idiomatic and statically-verifiable. diff --git a/synapse/server.py b/synapse/server.py index f921ee4b53..21a232bbd9 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -205,7 +205,13 @@ class HomeServer(metaclass=abc.ABCMeta): # instantiated during setup() for future return by get_datastore() DATASTORE_CLASS = abc.abstractproperty() - def __init__(self, hostname: str, config: HomeServerConfig, reactor=None, **kwargs): + def __init__( + self, + hostname: str, + config: HomeServerConfig, + reactor=None, + version_string="Synapse", + ): """ Args: hostname : The hostname for the server. @@ -236,11 +242,9 @@ class HomeServer(metaclass=abc.ABCMeta): burst_count=config.rc_registration.burst_count, ) - self.datastores = None # type: Optional[Databases] + self.version_string = version_string - # Other kwargs are explicit dependencies - for depname in kwargs: - setattr(self, depname, kwargs[depname]) + self.datastores = None # type: Optional[Databases] def get_instance_id(self) -> str: """A unique ID for this synapse process instance. diff --git a/tests/app/test_frontend_proxy.py b/tests/app/test_frontend_proxy.py index 641093d349..4a301b84e1 100644 --- a/tests/app/test_frontend_proxy.py +++ b/tests/app/test_frontend_proxy.py @@ -22,7 +22,7 @@ class FrontendProxyTests(HomeserverTestCase): def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver( - http_client=None, homeserverToUse=GenericWorkerServer + http_client=None, homeserver_to_use=GenericWorkerServer ) return hs diff --git a/tests/app/test_openid_listener.py b/tests/app/test_openid_listener.py index 0f016c32eb..c2b10d2c70 100644 --- a/tests/app/test_openid_listener.py +++ b/tests/app/test_openid_listener.py @@ -26,7 +26,7 @@ from tests.unittest import HomeserverTestCase class FederationReaderOpenIDListenerTests(HomeserverTestCase): def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver( - http_client=None, homeserverToUse=GenericWorkerServer + http_client=None, homeserver_to_use=GenericWorkerServer ) return hs @@ -84,7 +84,7 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase): class SynapseHomeserverOpenIDListenerTests(HomeserverTestCase): def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver( - http_client=None, homeserverToUse=SynapseHomeServer + http_client=None, homeserver_to_use=SynapseHomeServer ) return hs diff --git a/tests/replication/_base.py b/tests/replication/_base.py index 81ea985b9f..093e2faac7 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -59,7 +59,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): self.reactor.lookups["testserv"] = "1.2.3.4" self.worker_hs = self.setup_test_homeserver( http_client=None, - homeserverToUse=GenericWorkerServer, + homeserver_to_use=GenericWorkerServer, config=self._get_worker_hs_config(), reactor=self.reactor, ) @@ -266,7 +266,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): config.update(extra_config) worker_hs = self.setup_test_homeserver( - homeserverToUse=GenericWorkerServer, + homeserver_to_use=GenericWorkerServer, config=config, reactor=self.reactor, **kwargs diff --git a/tests/replication/test_federation_ack.py b/tests/replication/test_federation_ack.py index 23be1167a3..1853667558 100644 --- a/tests/replication/test_federation_ack.py +++ b/tests/replication/test_federation_ack.py @@ -31,7 +31,7 @@ class FederationAckTestCase(HomeserverTestCase): return config def make_homeserver(self, reactor, clock): - hs = self.setup_test_homeserver(homeserverToUse=GenericWorkerServer) + hs = self.setup_test_homeserver(homeserver_to_use=GenericWorkerServer) return hs diff --git a/tests/utils.py b/tests/utils.py index 0c09f5457f..acec74e9e9 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -21,6 +21,7 @@ import time import uuid import warnings from inspect import getcallargs +from typing import Type from urllib import parse as urlparse from mock import Mock, patch @@ -194,8 +195,8 @@ def setup_test_homeserver( name="test", config=None, reactor=None, - homeserverToUse=TestHomeServer, - **kargs + homeserver_to_use: Type[HomeServer] = TestHomeServer, + **kwargs ): """ Setup a homeserver suitable for running tests against. Keyword arguments @@ -218,8 +219,8 @@ def setup_test_homeserver( config.ldap_enabled = False - if "clock" not in kargs: - kargs["clock"] = MockClock() + if "clock" not in kwargs: + kwargs["clock"] = MockClock() if USE_POSTGRES_FOR_TESTS: test_db = "synapse_test_%s" % uuid.uuid4().hex @@ -264,18 +265,20 @@ def setup_test_homeserver( cur.close() db_conn.close() - hs = homeserverToUse( - name, - config=config, - version_string="Synapse/tests", - tls_server_context_factory=Mock(), - tls_client_options_factory=Mock(), - reactor=reactor, - **kargs + hs = homeserver_to_use( + name, config=config, version_string="Synapse/tests", reactor=reactor, ) + # Install @cache_in_self attributes + for key, val in kwargs.items(): + setattr(hs, key, val) + + # Mock TLS + hs.tls_server_context_factory = Mock() + hs.tls_client_options_factory = Mock() + hs.setup() - if homeserverToUse.__name__ == "TestHomeServer": + if homeserver_to_use == TestHomeServer: hs.setup_background_tasks() if isinstance(db_engine, PostgresEngine): @@ -339,7 +342,7 @@ def setup_test_homeserver( hs.get_auth_handler().validate_hash = validate_hash - fed = kargs.get("resource_for_federation", None) + fed = kwargs.get("resource_for_federation", None) if fed: register_federation_servlets(hs, fed) -- cgit 1.5.1 From aff1eb7c671b0a3813407321d2702ec46c71fa56 Mon Sep 17 00:00:00 2001 From: Dan Callahan Date: Tue, 27 Oct 2020 23:26:36 +0000 Subject: Tell Black to format code for Python 3.5 (#8664) This allows trailing commas in multi-line arg lists. Minor, but we might as well keep our formatting current with regard to our minimum supported Python version. Signed-off-by: Dan Callahan --- changelog.d/8664.misc | 1 + pyproject.toml | 2 +- synapse/http/client.py | 2 +- synapse/storage/database.py | 4 ++-- synapse/util/retryutils.py | 2 +- tests/replication/_base.py | 2 +- tests/replication/tcp/streams/test_events.py | 2 +- tests/server.py | 4 ++-- tests/storage/test_client_ips.py | 2 +- tests/test_utils/event_injection.py | 2 +- 10 files changed, 12 insertions(+), 11 deletions(-) create mode 100644 changelog.d/8664.misc (limited to 'tests/replication/_base.py') diff --git a/changelog.d/8664.misc b/changelog.d/8664.misc new file mode 100644 index 0000000000..278cf53adc --- /dev/null +++ b/changelog.d/8664.misc @@ -0,0 +1 @@ +Tell Black to format code for Python 3.5. diff --git a/pyproject.toml b/pyproject.toml index db4a2e41e4..cd880d4e39 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,7 +35,7 @@ showcontent = true [tool.black] -target-version = ['py34'] +target-version = ['py35'] exclude = ''' ( diff --git a/synapse/http/client.py b/synapse/http/client.py index 8324632cb6..f409368802 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -359,7 +359,7 @@ class SimpleHttpClient: agent=self.agent, data=body_producer, headers=headers, - **self._extra_treq_args + **self._extra_treq_args, ) # type: defer.Deferred # we use our own timeout mechanism rather than treq's as a workaround diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 0217e63108..a0572b2952 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -94,7 +94,7 @@ def make_pool( cp_openfun=lambda conn: engine.on_new_connection( LoggingDatabaseConnection(conn, engine, "on_new_connection") ), - **db_config.config.get("args", {}) + **db_config.config.get("args", {}), ) @@ -632,7 +632,7 @@ class DatabasePool: func, *args, db_autocommit=db_autocommit, - **kwargs + **kwargs, ) for after_callback, after_args, after_kwargs in after_callbacks: diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py index a5cc9d0551..4ab379e429 100644 --- a/synapse/util/retryutils.py +++ b/synapse/util/retryutils.py @@ -110,7 +110,7 @@ async def get_retry_limiter(destination, clock, store, ignore_backoff=False, **k failure_ts, retry_interval, backoff_on_failure=backoff_on_failure, - **kwargs + **kwargs, ) diff --git a/tests/replication/_base.py b/tests/replication/_base.py index 093e2faac7..f1e53f33cd 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -269,7 +269,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): homeserver_to_use=GenericWorkerServer, config=config, reactor=self.reactor, - **kwargs + **kwargs, ) # If the instance is in the `instance_map` config then workers may try diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py index c9998e88e6..bad0df08cf 100644 --- a/tests/replication/tcp/streams/test_events.py +++ b/tests/replication/tcp/streams/test_events.py @@ -449,7 +449,7 @@ class EventsStreamTestCase(BaseStreamTestCase): sender=sender, type="test_event", content={"body": body}, - **kwargs + **kwargs, ) ) diff --git a/tests/server.py b/tests/server.py index 4d33b84097..ea9c22bc51 100644 --- a/tests/server.py +++ b/tests/server.py @@ -380,7 +380,7 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs): pool._runWithConnection, func, *args, - **kwargs + **kwargs, ) def runInteraction(interaction, *args, **kwargs): @@ -390,7 +390,7 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs): pool._runInteraction, interaction, *args, - **kwargs + **kwargs, ) pool.runWithConnection = runWithConnection diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index 755c70db31..e96ca1c8ca 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -412,7 +412,7 @@ class ClientIpAuthTestCase(unittest.HomeserverTestCase): "GET", "/_matrix/client/r0/admin/users/" + self.user_id, access_token=access_token, - **make_request_args + **make_request_args, ) request.requestHeaders.addRawHeader(b"User-Agent", b"Mozzila pizza") diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py index e93aa84405..c3c4a93e1f 100644 --- a/tests/test_utils/event_injection.py +++ b/tests/test_utils/event_injection.py @@ -50,7 +50,7 @@ async def inject_member_event( sender=sender, state_key=target, content=content, - **kwargs + **kwargs, ) -- cgit 1.5.1 From 9a7e0d2ea675278510b79fd9b87260c13cc27a32 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 29 Oct 2020 11:17:35 +0000 Subject: Don't require hiredis to run unit tests (#8680) --- changelog.d/8680.misc | 1 + tests/replication/_base.py | 11 ++++++++++- 2 files changed, 11 insertions(+), 1 deletion(-) create mode 100644 changelog.d/8680.misc (limited to 'tests/replication/_base.py') diff --git a/changelog.d/8680.misc b/changelog.d/8680.misc new file mode 100644 index 0000000000..2ca2975464 --- /dev/null +++ b/changelog.d/8680.misc @@ -0,0 +1 @@ +Don't require `hiredis` package to be installed to run unit tests. diff --git a/tests/replication/_base.py b/tests/replication/_base.py index f1e53f33cd..5c633ac6df 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -16,7 +16,6 @@ import logging from typing import Any, Callable, List, Optional, Tuple import attr -import hiredis from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime from twisted.internet.protocol import Protocol @@ -39,12 +38,22 @@ from synapse.util import Clock from tests import unittest from tests.server import FakeTransport, render +try: + import hiredis +except ImportError: + hiredis = None + logger = logging.getLogger(__name__) class BaseStreamTestCase(unittest.HomeserverTestCase): """Base class for tests of the replication streams""" + # hiredis is an optional dependency so we don't want to require it for running + # the tests. + if not hiredis: + skip = "Requires hiredis" + servlets = [ streams.register_servlets, ] -- cgit 1.5.1 From 791d7cd6f065e166576538b9cba3d90febf83ea4 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 16 Nov 2020 14:45:52 +0000 Subject: Rename `create_test_json_resource` to `create_test_resource` (#8759) The root resource isn't necessarily a JsonResource, so rename this method accordingly, and update a couple of test classes to use the method rather than directly manipulating self.resource. --- changelog.d/8759.misc | 1 + tests/replication/_base.py | 4 ++-- tests/rest/admin/test_admin.py | 2 +- tests/rest/key/v2/test_remote_key_resource.py | 2 +- tests/rest/test_health.py | 6 ++---- tests/rest/test_well_known.py | 6 ++---- tests/unittest.py | 18 +++++++----------- 7 files changed, 16 insertions(+), 23 deletions(-) create mode 100644 changelog.d/8759.misc (limited to 'tests/replication/_base.py') diff --git a/changelog.d/8759.misc b/changelog.d/8759.misc new file mode 100644 index 0000000000..54502e9b90 --- /dev/null +++ b/changelog.d/8759.misc @@ -0,0 +1 @@ +Refactor test utilities for injecting HTTP requests. diff --git a/tests/replication/_base.py b/tests/replication/_base.py index 5c633ac6df..bc56b13dcd 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -240,8 +240,8 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): lambda: self._handle_http_replication_attempt(self.hs, 8765), ) - def create_test_json_resource(self): - """Overrides `HomeserverTestCase.create_test_json_resource`. + def create_test_resource(self): + """Overrides `HomeserverTestCase.create_test_resource`. """ # We override this so that it automatically registers all the HTTP # replication servlets, without having to explicitly do that in all diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index 0f1144fe1e..6804f9337f 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -35,7 +35,7 @@ from tests import unittest class VersionTestCase(unittest.HomeserverTestCase): url = "/_synapse/admin/v1/server_version" - def create_test_json_resource(self): + def create_test_resource(self): resource = JsonResource(self.hs) VersionServlet(self.hs).register(resource) return resource diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py index 6850c666be..6671cbd32d 100644 --- a/tests/rest/key/v2/test_remote_key_resource.py +++ b/tests/rest/key/v2/test_remote_key_resource.py @@ -41,7 +41,7 @@ class BaseRemoteKeyResourceTestCase(unittest.HomeserverTestCase): self.http_client = Mock() return self.setup_test_homeserver(http_client=self.http_client) - def create_test_json_resource(self): + def create_test_resource(self): return create_resource_tree( {"/_matrix/key/v2": KeyApiV2Resource(self.hs)}, root_resource=NoResource() ) diff --git a/tests/rest/test_health.py b/tests/rest/test_health.py index 2d021f6565..f4d06e2200 100644 --- a/tests/rest/test_health.py +++ b/tests/rest/test_health.py @@ -20,11 +20,9 @@ from tests import unittest class HealthCheckTests(unittest.HomeserverTestCase): - def setUp(self): - super().setUp() - + def create_test_resource(self): # replace the JsonResource with a HealthResource. - self.resource = HealthResource() + return HealthResource() def test_health(self): request, channel = self.make_request("GET", "/health", shorthand=False) diff --git a/tests/rest/test_well_known.py b/tests/rest/test_well_known.py index dcd65c2a50..a3746e7130 100644 --- a/tests/rest/test_well_known.py +++ b/tests/rest/test_well_known.py @@ -20,11 +20,9 @@ from tests import unittest class WellKnownTests(unittest.HomeserverTestCase): - def setUp(self): - super().setUp() - + def create_test_resource(self): # replace the JsonResource with a WellKnownResource - self.resource = WellKnownResource(self.hs) + return WellKnownResource(self.hs) def test_well_known(self): self.hs.config.public_baseurl = "https://tesths" diff --git a/tests/unittest.py b/tests/unittest.py index e36ac89196..c630760e51 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -30,6 +30,7 @@ from twisted.internet.defer import Deferred, ensureDeferred, succeed from twisted.python.failure import Failure from twisted.python.threadpool import ThreadPool from twisted.trial import unittest +from twisted.web.resource import Resource from synapse.api.constants import EventTypes, Membership from synapse.config.homeserver import HomeServerConfig @@ -239,10 +240,8 @@ class HomeserverTestCase(TestCase): if not isinstance(self.hs, HomeServer): raise Exception("A homeserver wasn't returned, but %r" % (self.hs,)) - # Register the resources - self.resource = self.create_test_json_resource() - - # create a site to wrap the resource. + # create the root resource, and a site to wrap it. + self.resource = self.create_test_resource() self.site = SynapseSite( logger_name="synapse.access.http.fake", site_tag=self.hs.config.server.server_name, @@ -323,15 +322,12 @@ class HomeserverTestCase(TestCase): hs = self.setup_test_homeserver() return hs - def create_test_json_resource(self): + def create_test_resource(self) -> Resource: """ - Create a test JsonResource, with the relevant servlets registerd to it - - The default implementation calls each function in `servlets` to do the - registration. + Create a the root resource for the test server. - Returns: - JsonResource: + The default implementation creates a JsonResource and calls each function in + `servlets` to register servletes against it """ resource = JsonResource(self.hs) -- cgit 1.5.1 From be8fa65d0baddcc0a64954e21d38a854e4ee00d7 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Sun, 15 Nov 2020 22:49:21 +0000 Subject: Remove redundant calls to `render()` --- tests/app/test_frontend_proxy.py | 10 ++--- tests/app/test_openid_listener.py | 12 +++--- tests/http/test_additional_resource.py | 4 +- tests/replication/_base.py | 5 +-- tests/replication/test_client_reader_shard.py | 4 -- tests/replication/test_sharded_event_persister.py | 8 ---- tests/rest/client/test_consent.py | 6 +-- tests/rest/client/v1/utils.py | 16 +++----- tests/rest/client/v2_alpha/test_sync.py | 6 +-- tests/server.py | 5 --- tests/storage/test_client_ips.py | 3 +- tests/test_server.py | 49 +++++------------------ tests/unittest.py | 10 +---- 13 files changed, 32 insertions(+), 106 deletions(-) (limited to 'tests/replication/_base.py') diff --git a/tests/app/test_frontend_proxy.py b/tests/app/test_frontend_proxy.py index 0bac7995e8..40abe9d72d 100644 --- a/tests/app/test_frontend_proxy.py +++ b/tests/app/test_frontend_proxy.py @@ -15,7 +15,7 @@ from synapse.app.generic_worker import GenericWorkerServer -from tests.server import make_request, render +from tests.server import make_request from tests.unittest import HomeserverTestCase @@ -56,10 +56,8 @@ class FrontendProxyTests(HomeserverTestCase): # Grab the resource from the site that was told to listen self.assertEqual(len(self.reactor.tcpServers), 1) site = self.reactor.tcpServers[0][1] - resource = site.resource.children[b"_matrix"].children[b"client"] - request, channel = make_request(self.reactor, site, "PUT", "presence/a/status") - render(request, resource, self.reactor) + _, channel = make_request(self.reactor, site, "PUT", "presence/a/status") # 400 + unrecognised, because nothing is registered self.assertEqual(channel.code, 400) @@ -78,10 +76,8 @@ class FrontendProxyTests(HomeserverTestCase): # Grab the resource from the site that was told to listen self.assertEqual(len(self.reactor.tcpServers), 1) site = self.reactor.tcpServers[0][1] - resource = site.resource.children[b"_matrix"].children[b"client"] - request, channel = make_request(self.reactor, site, "PUT", "presence/a/status") - render(request, resource, self.reactor) + _, channel = make_request(self.reactor, site, "PUT", "presence/a/status") # 401, because the stub servlet still checks authentication self.assertEqual(channel.code, 401) diff --git a/tests/app/test_openid_listener.py b/tests/app/test_openid_listener.py index 1292145890..ea3be95cf1 100644 --- a/tests/app/test_openid_listener.py +++ b/tests/app/test_openid_listener.py @@ -20,7 +20,7 @@ from synapse.app.generic_worker import GenericWorkerServer from synapse.app.homeserver import SynapseHomeServer from synapse.config.server import parse_listener_def -from tests.server import make_request, render +from tests.server import make_request from tests.unittest import HomeserverTestCase @@ -67,16 +67,15 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase): # Grab the resource from the site that was told to listen site = self.reactor.tcpServers[0][1] try: - resource = site.resource.children[b"_matrix"].children[b"federation"] + site.resource.children[b"_matrix"].children[b"federation"] except KeyError: if expectation == "no_resource": return raise - request, channel = make_request( + _, channel = make_request( self.reactor, site, "GET", "/_matrix/federation/v1/openid/userinfo" ) - render(request, resource, self.reactor) self.assertEqual(channel.code, 401) @@ -116,15 +115,14 @@ class SynapseHomeserverOpenIDListenerTests(HomeserverTestCase): # Grab the resource from the site that was told to listen site = self.reactor.tcpServers[0][1] try: - resource = site.resource.children[b"_matrix"].children[b"federation"] + site.resource.children[b"_matrix"].children[b"federation"] except KeyError: if expectation == "no_resource": return raise - request, channel = make_request( + _, channel = make_request( self.reactor, site, "GET", "/_matrix/federation/v1/openid/userinfo" ) - render(request, resource, self.reactor) self.assertEqual(channel.code, 401) diff --git a/tests/http/test_additional_resource.py b/tests/http/test_additional_resource.py index e835512a41..05e9c449be 100644 --- a/tests/http/test_additional_resource.py +++ b/tests/http/test_additional_resource.py @@ -17,7 +17,7 @@ from synapse.http.additional_resource import AdditionalResource from synapse.http.server import respond_with_json -from tests.server import FakeSite, make_request, render +from tests.server import FakeSite, make_request from tests.unittest import HomeserverTestCase @@ -47,7 +47,6 @@ class AdditionalResourceTests(HomeserverTestCase): resource = AdditionalResource(self.hs, handler) request, channel = make_request(self.reactor, FakeSite(resource), "GET", "/") - render(request, resource, self.reactor) self.assertEqual(request.code, 200) self.assertEqual(channel.json_body, {"some_key": "some_value_async"}) @@ -57,7 +56,6 @@ class AdditionalResourceTests(HomeserverTestCase): resource = AdditionalResource(self.hs, handler) request, channel = make_request(self.reactor, FakeSite(resource), "GET", "/") - render(request, resource, self.reactor) self.assertEqual(request.code, 200) self.assertEqual(channel.json_body, {"some_key": "some_value_sync"}) diff --git a/tests/replication/_base.py b/tests/replication/_base.py index bc56b13dcd..516db4c30a 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -36,7 +36,7 @@ from synapse.server import HomeServer from synapse.util import Clock from tests import unittest -from tests.server import FakeTransport, render +from tests.server import FakeTransport try: import hiredis @@ -347,9 +347,6 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): config["worker_replication_http_port"] = "8765" return config - def render_on_worker(self, worker_hs: HomeServer, request: SynapseRequest): - render(request, self._hs_to_site[worker_hs].resource, self.reactor) - def replicate(self): """Tell the master side of replication that something has happened, and then wait for the replication to occur. diff --git a/tests/replication/test_client_reader_shard.py b/tests/replication/test_client_reader_shard.py index 90172bd377..96801db473 100644 --- a/tests/replication/test_client_reader_shard.py +++ b/tests/replication/test_client_reader_shard.py @@ -55,7 +55,6 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase): "register", {"username": "user", "type": "m.login.password", "password": "bar"}, ) # type: SynapseRequest, FakeChannel - self.render_on_worker(worker_hs, request_1) self.assertEqual(request_1.code, 401) # Grab the session @@ -69,7 +68,6 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase): "register", {"auth": {"session": session, "type": "m.login.dummy"}}, ) # type: SynapseRequest, FakeChannel - self.render_on_worker(worker_hs, request_2) self.assertEqual(request_2.code, 200) # We're given a registered user. @@ -89,7 +87,6 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase): "register", {"username": "user", "type": "m.login.password", "password": "bar"}, ) # type: SynapseRequest, FakeChannel - self.render_on_worker(worker_hs_1, request_1) self.assertEqual(request_1.code, 401) # Grab the session @@ -104,7 +101,6 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase): "register", {"auth": {"session": session, "type": "m.login.dummy"}}, ) # type: SynapseRequest, FakeChannel - self.render_on_worker(worker_hs_2, request_2) self.assertEqual(request_2.code, 200) # We're given a registered user. diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py index 2820dd622f..77fc3856d5 100644 --- a/tests/replication/test_sharded_event_persister.py +++ b/tests/replication/test_sharded_event_persister.py @@ -183,7 +183,6 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): request, channel = make_request( self.reactor, sync_hs_site, "GET", "/sync", access_token=access_token ) - self.render_on_worker(sync_hs, request) next_batch = channel.json_body["next_batch"] # We now gut wrench into the events stream MultiWriterIdGenerator on @@ -214,7 +213,6 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): "/sync?since={}".format(next_batch), access_token=access_token, ) - self.render_on_worker(sync_hs, request) # We should only see the new event and nothing else self.assertIn(room_id1, channel.json_body["rooms"]["join"]) @@ -245,7 +243,6 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): "/sync?since={}".format(vector_clock_token), access_token=access_token, ) - self.render_on_worker(sync_hs, request) self.assertNotIn(room_id1, channel.json_body["rooms"]["join"]) self.assertIn(room_id2, channel.json_body["rooms"]["join"]) @@ -271,7 +268,6 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): "/sync?since={}".format(next_batch), access_token=access_token, ) - self.render_on_worker(sync_hs, request) prev_batch1 = channel.json_body["rooms"]["join"][room_id1]["timeline"][ "prev_batch" @@ -292,7 +288,6 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): ), access_token=access_token, ) - self.render_on_worker(sync_hs, request) self.assertListEqual([], channel.json_body["chunk"]) # Paginating back on the second room should produce the first event @@ -306,7 +301,6 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): ), access_token=access_token, ) - self.render_on_worker(sync_hs, request) self.assertEqual(len(channel.json_body["chunk"]), 1) self.assertEqual( channel.json_body["chunk"][0]["event_id"], first_event_in_room2 @@ -322,7 +316,6 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): ), access_token=access_token, ) - self.render_on_worker(sync_hs, request) self.assertListEqual([], channel.json_body["chunk"]) request, channel = make_request( @@ -334,7 +327,6 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): ), access_token=access_token, ) - self.render_on_worker(sync_hs, request) self.assertEqual(len(channel.json_body["chunk"]), 1) self.assertEqual( channel.json_body["chunk"][0]["event_id"], first_event_in_room2 diff --git a/tests/rest/client/test_consent.py b/tests/rest/client/test_consent.py index 2931859f25..e2e6a5e16d 100644 --- a/tests/rest/client/test_consent.py +++ b/tests/rest/client/test_consent.py @@ -21,7 +21,7 @@ from synapse.rest.client.v1 import login, room from synapse.rest.consent import consent_resource from tests import unittest -from tests.server import FakeSite, make_request, render +from tests.server import FakeSite, make_request class ConsentResourceTestCase(unittest.HomeserverTestCase): @@ -64,7 +64,6 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase): request, channel = make_request( self.reactor, FakeSite(resource), "GET", "/consent?v=1", shorthand=False ) - render(request, resource, self.reactor) self.assertEqual(channel.code, 200) def test_accept_consent(self): @@ -91,7 +90,6 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase): access_token=access_token, shorthand=False, ) - render(request, resource, self.reactor) self.assertEqual(channel.code, 200) # Get the version from the body, and whether we've consented @@ -107,7 +105,6 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase): access_token=access_token, shorthand=False, ) - render(request, resource, self.reactor) self.assertEqual(channel.code, 200) # Fetch the consent page, to get the consent version -- it should have @@ -120,7 +117,6 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase): access_token=access_token, shorthand=False, ) - render(request, resource, self.reactor) self.assertEqual(channel.code, 200) # Get the version from the body, and check that it's the version we diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index 040a92d6f0..b58768675b 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -27,7 +27,7 @@ from twisted.web.server import Site from synapse.api.constants import Membership -from tests.server import FakeSite, make_request, render +from tests.server import FakeSite, make_request @attr.s @@ -52,14 +52,13 @@ class RestHelper: if tok: path = path + "?access_token=%s" % tok - request, channel = make_request( + _, channel = make_request( self.hs.get_reactor(), self.site, "POST", path, json.dumps(content).encode("utf8"), ) - render(request, self.site.resource, self.hs.get_reactor()) assert channel.result["code"] == b"%d" % expect_code, channel.result self.auth_user_id = temp_id @@ -129,7 +128,7 @@ class RestHelper: data = {"membership": membership} data.update(extra_data) - request, channel = make_request( + _, channel = make_request( self.hs.get_reactor(), self.site, "PUT", @@ -137,8 +136,6 @@ class RestHelper: json.dumps(data).encode("utf8"), ) - render(request, self.site.resource, self.hs.get_reactor()) - assert int(channel.result["code"]) == expect_code, ( "Expected: %d, got: %d, resp: %r" % (expect_code, int(channel.result["code"]), channel.result["body"]) @@ -166,14 +163,13 @@ class RestHelper: if tok: path = path + "?access_token=%s" % tok - request, channel = make_request( + _, channel = make_request( self.hs.get_reactor(), self.site, "PUT", path, json.dumps(content).encode("utf8"), ) - render(request, self.site.resource, self.hs.get_reactor()) assert int(channel.result["code"]) == expect_code, ( "Expected: %d, got: %d, resp: %r" @@ -223,12 +219,10 @@ class RestHelper: if body is not None: content = json.dumps(body).encode("utf8") - request, channel = make_request( + _, channel = make_request( self.hs.get_reactor(), self.site, method, path, content ) - render(request, self.site.resource, self.hs.get_reactor()) - assert int(channel.result["code"]) == expect_code, ( "Expected: %d, got: %d, resp: %r" % (expect_code, int(channel.result["code"]), channel.result["body"]) diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py index a31e44c97e..f74d611943 100644 --- a/tests/rest/client/v2_alpha/test_sync.py +++ b/tests/rest/client/v2_alpha/test_sync.py @@ -320,10 +320,8 @@ class SyncTypingTests(unittest.HomeserverTestCase): typing._reset() # Now it SHOULD fail as it never completes! - request, channel = self.make_request( - "GET", sync_url % (access_token, next_batch) - ) - self.assertRaises(TimedOutException, self.render, request) + with self.assertRaises(TimedOutException): + self.make_request("GET", sync_url % (access_token, next_batch)) class UnreadMessagesTestCase(unittest.HomeserverTestCase): diff --git a/tests/server.py b/tests/server.py index de7cb1d8b3..a51ad0c14e 100644 --- a/tests/server.py +++ b/tests/server.py @@ -19,7 +19,6 @@ from twisted.internet.interfaces import ( ) from twisted.python.failure import Failure from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock -from twisted.web.http import unquote from twisted.web.http_headers import Headers from twisted.web.resource import IResource from twisted.web.server import Site @@ -267,10 +266,6 @@ def make_request( return req, channel -def render(request, resource, clock): - pass - - @implementer(IReactorPluggableNameResolver) class ThreadedMemoryReactorClock(MemoryReactorClock): """ diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index 583addb5b5..6bdde1a2ba 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -412,7 +412,7 @@ class ClientIpAuthTestCase(unittest.HomeserverTestCase): headers1 = {b"User-Agent": b"Mozzila pizza"} headers1.update(headers) - request, channel = make_request( + make_request( self.reactor, self.site, "GET", @@ -421,7 +421,6 @@ class ClientIpAuthTestCase(unittest.HomeserverTestCase): custom_headers=headers1.items(), **make_request_args, ) - self.render(request) # Advance so the save loop occurs self.reactor.advance(100) diff --git a/tests/test_server.py b/tests/test_server.py index 300d13ac95..c387a85f2e 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -29,7 +29,6 @@ from tests.server import ( FakeSite, ThreadedMemoryReactorClock, make_request, - render, setup_test_homeserver, ) @@ -65,7 +64,6 @@ class JsonResourceTests(unittest.TestCase): request, channel = make_request( self.reactor, FakeSite(res), b"GET", b"/_matrix/foo/%E2%98%83?a=%E2%98%83" ) - render(request, res, self.reactor) self.assertEqual(request.args, {b"a": ["\N{SNOWMAN}".encode("utf8")]}) self.assertEqual(got_kwargs, {"room_id": "\N{SNOWMAN}"}) @@ -84,10 +82,7 @@ class JsonResourceTests(unittest.TestCase): "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet" ) - request, channel = make_request( - self.reactor, FakeSite(res), b"GET", b"/_matrix/foo" - ) - render(request, res, self.reactor) + _, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foo") self.assertEqual(channel.result["code"], b"500") @@ -111,10 +106,7 @@ class JsonResourceTests(unittest.TestCase): "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet" ) - request, channel = make_request( - self.reactor, FakeSite(res), b"GET", b"/_matrix/foo" - ) - render(request, res, self.reactor) + _, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foo") self.assertEqual(channel.result["code"], b"500") @@ -132,10 +124,7 @@ class JsonResourceTests(unittest.TestCase): "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet" ) - request, channel = make_request( - self.reactor, FakeSite(res), b"GET", b"/_matrix/foo" - ) - render(request, res, self.reactor) + _, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foo") self.assertEqual(channel.result["code"], b"403") self.assertEqual(channel.json_body["error"], "Forbidden!!one!") @@ -157,10 +146,9 @@ class JsonResourceTests(unittest.TestCase): "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet" ) - request, channel = make_request( + _, channel = make_request( self.reactor, FakeSite(res), b"GET", b"/_matrix/foobar" ) - render(request, res, self.reactor) self.assertEqual(channel.result["code"], b"400") self.assertEqual(channel.json_body["error"], "Unrecognized request") @@ -182,10 +170,7 @@ class JsonResourceTests(unittest.TestCase): ) # The path was registered as GET, but this is a HEAD request. - request, channel = make_request( - self.reactor, FakeSite(res), b"HEAD", b"/_matrix/foo" - ) - render(request, res, self.reactor) + _, channel = make_request(self.reactor, FakeSite(res), b"HEAD", b"/_matrix/foo") self.assertEqual(channel.result["code"], b"200") self.assertNotIn("body", channel.result) @@ -216,16 +201,8 @@ class OptionsResourceTests(unittest.TestCase): "1.0", ) - request, channel = make_request( - self.reactor, site, method, path, shorthand=False - ) - request.prepath = [] # This doesn't get set properly by make_request. - - request.site = site - resource = site.getResourceFor(request) - - # Finally, render the resource and return the channel. - render(request, resource, self.reactor) + # render the request and return the channel + _, channel = make_request(self.reactor, site, method, path, shorthand=False) return channel def test_unknown_options_request(self): @@ -298,8 +275,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase): res = WrapHtmlRequestHandlerTests.TestResource() res.callback = callback - request, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path") - render(request, res, self.reactor) + _, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path") self.assertEqual(channel.result["code"], b"200") body = channel.result["body"] @@ -317,8 +293,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase): res = WrapHtmlRequestHandlerTests.TestResource() res.callback = callback - request, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path") - render(request, res, self.reactor) + _, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path") self.assertEqual(channel.result["code"], b"301") headers = channel.result["headers"] @@ -339,8 +314,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase): res = WrapHtmlRequestHandlerTests.TestResource() res.callback = callback - request, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path") - render(request, res, self.reactor) + _, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path") self.assertEqual(channel.result["code"], b"304") headers = channel.result["headers"] @@ -359,8 +333,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase): res = WrapHtmlRequestHandlerTests.TestResource() res.callback = callback - request, channel = make_request(self.reactor, FakeSite(res), b"HEAD", b"/path") - render(request, res, self.reactor) + _, channel = make_request(self.reactor, FakeSite(res), b"HEAD", b"/path") self.assertEqual(channel.result["code"], b"200") self.assertNotIn("body", channel.result) diff --git a/tests/unittest.py b/tests/unittest.py index 9c7eca3b6e..8a49bb5262 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -48,13 +48,7 @@ from synapse.server import HomeServer from synapse.types import UserID, create_requester from synapse.util.ratelimitutils import FederationRateLimiter -from tests.server import ( - FakeChannel, - get_clock, - make_request, - render, - setup_test_homeserver, -) +from tests.server import FakeChannel, get_clock, make_request, setup_test_homeserver from tests.test_utils import event_injection, setup_awaitable_errors from tests.test_utils.logging_setup import setup_logging from tests.utils import default_config, setupdb @@ -454,7 +448,7 @@ class HomeserverTestCase(TestCase): Args: request (synapse.http.site.SynapseRequest): The request to render. """ - render(request, self.resource, self.reactor) + pass def setup_test_homeserver(self, *args, **kwargs): """ -- cgit 1.5.1 From ca60822b34416e524a80c5689734d870ceb6e130 Mon Sep 17 00:00:00 2001 From: Jonathan de Jong Date: Mon, 30 Nov 2020 19:28:44 +0100 Subject: Simplify the way the `HomeServer` object caches its internal attributes. (#8565) Changes `@cache_in_self` to use underscore-prefixed attributes. --- changelog.d/8565.misc | 1 + synapse/handlers/identity.py | 3 ++- synapse/rest/client/v2_alpha/account.py | 6 +++--- synapse/rest/client/v2_alpha/register.py | 4 ++-- synapse/rest/key/v2/local_key_resource.py | 2 +- synapse/server.py | 27 +++++++++++++-------------- tests/handlers/test_auth.py | 6 +++--- tests/replication/_base.py | 2 +- tests/rest/client/v1/test_presence.py | 15 +++++++++------ tests/rest/client/v2_alpha/test_register.py | 6 +++--- tests/utils.py | 2 +- 11 files changed, 39 insertions(+), 35 deletions(-) create mode 100644 changelog.d/8565.misc (limited to 'tests/replication/_base.py') diff --git a/changelog.d/8565.misc b/changelog.d/8565.misc new file mode 100644 index 0000000000..7bef422618 --- /dev/null +++ b/changelog.d/8565.misc @@ -0,0 +1 @@ +Simplify the way the `HomeServer` object caches its internal attributes. diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index bc3e9607ca..9b3c6b4551 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -354,7 +354,8 @@ class IdentityHandler(BaseHandler): raise SynapseError(500, "An error was encountered when sending the email") token_expires = ( - self.hs.clock.time_msec() + self.hs.config.email_validation_token_lifetime + self.hs.get_clock().time_msec() + + self.hs.config.email_validation_token_lifetime ) await self.store.start_or_continue_validation_session( diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index a54e1011f7..eebee44a44 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -115,7 +115,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): # comments for request_token_inhibit_3pid_errors. # Also wait for some random amount of time between 100ms and 1s to make it # look like we did something. - await self.hs.clock.sleep(random.randint(1, 10) / 10) + await self.hs.get_clock().sleep(random.randint(1, 10) / 10) return 200, {"sid": random_string(16)} raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND) @@ -387,7 +387,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): # comments for request_token_inhibit_3pid_errors. # Also wait for some random amount of time between 100ms and 1s to make it # look like we did something. - await self.hs.clock.sleep(random.randint(1, 10) / 10) + await self.hs.get_clock().sleep(random.randint(1, 10) / 10) return 200, {"sid": random_string(16)} raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE) @@ -466,7 +466,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): # comments for request_token_inhibit_3pid_errors. # Also wait for some random amount of time between 100ms and 1s to make it # look like we did something. - await self.hs.clock.sleep(random.randint(1, 10) / 10) + await self.hs.get_clock().sleep(random.randint(1, 10) / 10) return 200, {"sid": random_string(16)} raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index ea68114026..a89ae6ddf9 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -135,7 +135,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): # comments for request_token_inhibit_3pid_errors. # Also wait for some random amount of time between 100ms and 1s to make it # look like we did something. - await self.hs.clock.sleep(random.randint(1, 10) / 10) + await self.hs.get_clock().sleep(random.randint(1, 10) / 10) return 200, {"sid": random_string(16)} raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE) @@ -214,7 +214,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): # comments for request_token_inhibit_3pid_errors. # Also wait for some random amount of time between 100ms and 1s to make it # look like we did something. - await self.hs.clock.sleep(random.randint(1, 10) / 10) + await self.hs.get_clock().sleep(random.randint(1, 10) / 10) return 200, {"sid": random_string(16)} raise SynapseError( diff --git a/synapse/rest/key/v2/local_key_resource.py b/synapse/rest/key/v2/local_key_resource.py index c16280f668..d8e8e48c1c 100644 --- a/synapse/rest/key/v2/local_key_resource.py +++ b/synapse/rest/key/v2/local_key_resource.py @@ -66,7 +66,7 @@ class LocalKey(Resource): def __init__(self, hs): self.config = hs.config - self.clock = hs.clock + self.clock = hs.get_clock() self.update_response_body(self.clock.time_msec()) Resource.__init__(self) diff --git a/synapse/server.py b/synapse/server.py index c82d8f9fad..b017e3489f 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -147,7 +147,8 @@ def cache_in_self(builder: T) -> T: "@cache_in_self can only be used on functions starting with `get_`" ) - depname = builder.__name__[len("get_") :] + # get_attr -> _attr + depname = builder.__name__[len("get") :] building = [False] @@ -235,15 +236,6 @@ class HomeServer(metaclass=abc.ABCMeta): self._instance_id = random_string(5) self._instance_name = config.worker_name or "master" - self.clock = Clock(reactor) - self.distributor = Distributor() - - self.registration_ratelimiter = Ratelimiter( - clock=self.clock, - rate_hz=config.rc_registration.per_second, - burst_count=config.rc_registration.burst_count, - ) - self.version_string = version_string self.datastores = None # type: Optional[Databases] @@ -301,8 +293,9 @@ class HomeServer(metaclass=abc.ABCMeta): def is_mine_id(self, string: str) -> bool: return string.split(":", 1)[1] == self.hostname + @cache_in_self def get_clock(self) -> Clock: - return self.clock + return Clock(self._reactor) def get_datastore(self) -> DataStore: if not self.datastores: @@ -319,11 +312,17 @@ class HomeServer(metaclass=abc.ABCMeta): def get_config(self) -> HomeServerConfig: return self.config + @cache_in_self def get_distributor(self) -> Distributor: - return self.distributor + return Distributor() + @cache_in_self def get_registration_ratelimiter(self) -> Ratelimiter: - return self.registration_ratelimiter + return Ratelimiter( + clock=self.get_clock(), + rate_hz=self.config.rc_registration.per_second, + burst_count=self.config.rc_registration.burst_count, + ) @cache_in_self def get_federation_client(self) -> FederationClient: @@ -687,7 +686,7 @@ class HomeServer(metaclass=abc.ABCMeta): @cache_in_self def get_federation_ratelimiter(self) -> FederationRateLimiter: - return FederationRateLimiter(self.clock, config=self.config.rc_federation) + return FederationRateLimiter(self.get_clock(), config=self.config.rc_federation) @cache_in_self def get_module_api(self) -> ModuleApi: diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index b5055e018c..e24ce81284 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -52,7 +52,7 @@ class AuthTestCase(unittest.TestCase): self.fail("some_user was not in %s" % macaroon.inspect()) def test_macaroon_caveats(self): - self.hs.clock.now = 5000 + self.hs.get_clock().now = 5000 token = self.macaroon_generator.generate_access_token("a_user") macaroon = pymacaroons.Macaroon.deserialize(token) @@ -78,7 +78,7 @@ class AuthTestCase(unittest.TestCase): @defer.inlineCallbacks def test_short_term_login_token_gives_user_id(self): - self.hs.clock.now = 1000 + self.hs.get_clock().now = 1000 token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000) user_id = yield defer.ensureDeferred( @@ -87,7 +87,7 @@ class AuthTestCase(unittest.TestCase): self.assertEqual("a_user", user_id) # when we advance the clock, the token should be rejected - self.hs.clock.now = 6000 + self.hs.get_clock().now = 6000 with self.assertRaises(synapse.api.errors.AuthError): yield defer.ensureDeferred( self.auth_handler.validate_short_term_login_token_and_get_user_id(token) diff --git a/tests/replication/_base.py b/tests/replication/_base.py index 516db4c30a..295c5d58a6 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -78,7 +78,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): self.worker_hs.get_datastore().db_pool = hs.get_datastore().db_pool self.test_handler = self._build_replication_data_handler() - self.worker_hs.replication_data_handler = self.test_handler + self.worker_hs._replication_data_handler = self.test_handler repl_handler = ReplicationCommandHandler(self.worker_hs) self.client = ClientReplicationStreamProtocol( diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py index b84f86d28c..5d5c24d01c 100644 --- a/tests/rest/client/v1/test_presence.py +++ b/tests/rest/client/v1/test_presence.py @@ -33,13 +33,16 @@ class PresenceTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): + presence_handler = Mock() + presence_handler.set_state.return_value = defer.succeed(None) + hs = self.setup_test_homeserver( - "red", http_client=None, federation_client=Mock() + "red", + http_client=None, + federation_client=Mock(), + presence_handler=presence_handler, ) - hs.presence_handler = Mock() - hs.presence_handler.set_state.return_value = defer.succeed(None) - return hs def test_put_presence(self): @@ -55,7 +58,7 @@ class PresenceTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.code, 200) - self.assertEqual(self.hs.presence_handler.set_state.call_count, 1) + self.assertEqual(self.hs.get_presence_handler().set_state.call_count, 1) def test_put_presence_disabled(self): """ @@ -70,4 +73,4 @@ class PresenceTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.code, 200) - self.assertEqual(self.hs.presence_handler.set_state.call_count, 0) + self.assertEqual(self.hs.get_presence_handler().set_state.call_count, 0) diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index 699a40c3df..8f0c2430e8 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -569,7 +569,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): tok = self.login("kermit", "monkey") # We need to manually add an email address otherwise the handler will do # nothing. - now = self.hs.clock.time_msec() + now = self.hs.get_clock().time_msec() self.get_success( self.store.user_add_threepid( user_id=user_id, @@ -587,7 +587,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): # We need to manually add an email address otherwise the handler will do # nothing. - now = self.hs.clock.time_msec() + now = self.hs.get_clock().time_msec() self.get_success( self.store.user_add_threepid( user_id=user_id, @@ -646,7 +646,7 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase): self.hs.config.account_validity.startup_job_max_delta = self.max_delta - now_ms = self.hs.clock.time_msec() + now_ms = self.hs.get_clock().time_msec() self.get_success(self.store._set_expiration_date_when_missing()) res = self.get_success(self.store.get_expiration_ts_for_user(user_id)) diff --git a/tests/utils.py b/tests/utils.py index acec74e9e9..c8d3ffbaba 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -271,7 +271,7 @@ def setup_test_homeserver( # Install @cache_in_self attributes for key, val in kwargs.items(): - setattr(hs, key, val) + setattr(hs, "_" + key, val) # Mock TLS hs.tls_server_context_factory = Mock() -- cgit 1.5.1