From 32e7c9e7f20b57dd081023ac42d6931a8da9b3a3 Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Thu, 20 Jun 2019 19:32:02 +1000 Subject: Run Black. (#5482) --- .../federation/test_matrix_federation_agent.py | 216 ++++++++++----------- tests/http/federation/test_srv_resolver.py | 2 +- 2 files changed, 109 insertions(+), 109 deletions(-) (limited to 'tests/http/federation') diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index ecce473b01..b1094c1448 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -53,13 +53,15 @@ def get_connection_factory(): # this needs to happen once, but not until we are ready to run the first test global test_server_connection_factory if test_server_connection_factory is None: - test_server_connection_factory = TestServerTLSConnectionFactory(sanlist=[ - b'DNS:testserv', - b'DNS:target-server', - b'DNS:xn--bcher-kva.com', - b'IP:1.2.3.4', - b'IP:::1', - ]) + test_server_connection_factory = TestServerTLSConnectionFactory( + sanlist=[ + b"DNS:testserv", + b"DNS:target-server", + b"DNS:xn--bcher-kva.com", + b"IP:1.2.3.4", + b"IP:::1", + ] + ) return test_server_connection_factory @@ -133,7 +135,7 @@ class MatrixFederationAgentTests(TestCase): Sends a simple GET request via the agent, and checks its logcontext management """ with LoggingContext("one") as context: - fetch_d = self.agent.request(b'GET', uri) + fetch_d = self.agent.request(b"GET", uri) # Nothing happened yet self.assertNoResult(fetch_d) @@ -177,9 +179,9 @@ class MatrixFederationAgentTests(TestCase): """Check that an incoming request looks like a valid .well-known request, and send back the response. """ - self.assertEqual(request.method, b'GET') - self.assertEqual(request.path, b'/.well-known/matrix/server') - self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'testserv']) + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/.well-known/matrix/server") + self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"testserv"]) # send back a response for k, v in headers.items(): request.setHeader(k, v) @@ -202,7 +204,7 @@ class MatrixFederationAgentTests(TestCase): clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients[0] - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 8448) # make a test server, and wire up the client @@ -210,20 +212,20 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] - self.assertEqual(request.method, b'GET') - self.assertEqual(request.path, b'/foo/bar') + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/foo/bar") self.assertEqual( - request.requestHeaders.getRawHeaders(b'host'), [b'testserv:8448'] + request.requestHeaders.getRawHeaders(b"host"), [b"testserv:8448"] ) content = request.content.read() - self.assertEqual(content, b'') + self.assertEqual(content, b"") # Deferred is still without a result self.assertNoResult(test_d) # send the headers - request.responseHeaders.setRawHeaders(b'Content-Type', [b'application/json']) - request.write('') + request.responseHeaders.setRawHeaders(b"Content-Type", [b"application/json"]) + request.write("") self.reactor.pump((0.1,)) @@ -233,7 +235,7 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(response.code, 200) # Send the body - request.write('{ "a": 1 }'.encode('ascii')) + request.write('{ "a": 1 }'.encode("ascii")) request.finish() self.reactor.pump((0.1,)) @@ -258,7 +260,7 @@ class MatrixFederationAgentTests(TestCase): clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients[0] - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 8448) # make a test server, and wire up the client @@ -266,9 +268,9 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] - self.assertEqual(request.method, b'GET') - self.assertEqual(request.path, b'/foo/bar') - self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'1.2.3.4']) + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/foo/bar") + self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"1.2.3.4"]) # finish the request request.finish() @@ -293,7 +295,7 @@ class MatrixFederationAgentTests(TestCase): clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients[0] - self.assertEqual(host, '::1') + self.assertEqual(host, "::1") self.assertEqual(port, 8448) # make a test server, and wire up the client @@ -301,9 +303,9 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] - self.assertEqual(request.method, b'GET') - self.assertEqual(request.path, b'/foo/bar') - self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'[::1]']) + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/foo/bar") + self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"[::1]"]) # finish the request request.finish() @@ -328,7 +330,7 @@ class MatrixFederationAgentTests(TestCase): clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients[0] - self.assertEqual(host, '::1') + self.assertEqual(host, "::1") self.assertEqual(port, 80) # make a test server, and wire up the client @@ -336,9 +338,9 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] - self.assertEqual(request.method, b'GET') - self.assertEqual(request.path, b'/foo/bar') - self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'[::1]:80']) + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/foo/bar") + self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"[::1]:80"]) # finish the request request.finish() @@ -364,7 +366,7 @@ class MatrixFederationAgentTests(TestCase): clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients[0] - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 443) # fonx the connection @@ -382,11 +384,11 @@ class MatrixFederationAgentTests(TestCase): # we should fall back to a direct connection self.assertEqual(len(clients), 2) (host, port, client_factory, _timeout, _bindAddress) = clients[1] - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 8448) # make a test server, and wire up the client - http_server = self._make_connection(client_factory, expected_sni=b'testserv1') + http_server = self._make_connection(client_factory, expected_sni=b"testserv1") # there should be no requests self.assertEqual(len(http_server.requests), 0) @@ -413,7 +415,7 @@ class MatrixFederationAgentTests(TestCase): clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients[0] - self.assertEqual(host, '1.2.3.5') + self.assertEqual(host, "1.2.3.5") self.assertEqual(port, 8448) # make a test server, and wire up the client @@ -447,7 +449,7 @@ class MatrixFederationAgentTests(TestCase): clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients[0] - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 443) # fonx the connection @@ -465,17 +467,17 @@ class MatrixFederationAgentTests(TestCase): # we should fall back to a direct connection self.assertEqual(len(clients), 2) (host, port, client_factory, _timeout, _bindAddress) = clients[1] - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 8448) # make a test server, and wire up the client - http_server = self._make_connection(client_factory, expected_sni=b'testserv') + http_server = self._make_connection(client_factory, expected_sni=b"testserv") self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] - self.assertEqual(request.method, b'GET') - self.assertEqual(request.path, b'/foo/bar') - self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'testserv']) + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/foo/bar") + self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"testserv"]) # finish the request request.finish() @@ -499,7 +501,7 @@ class MatrixFederationAgentTests(TestCase): clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients[0] - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 443) self._handle_well_known_connection( @@ -516,20 +518,20 @@ class MatrixFederationAgentTests(TestCase): # now we should get a connection to the target server self.assertEqual(len(clients), 2) (host, port, client_factory, _timeout, _bindAddress) = clients[1] - self.assertEqual(host, '1::f') + self.assertEqual(host, "1::f") self.assertEqual(port, 8448) # make a test server, and wire up the client http_server = self._make_connection( - client_factory, expected_sni=b'target-server' + client_factory, expected_sni=b"target-server" ) self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] - self.assertEqual(request.method, b'GET') - self.assertEqual(request.path, b'/foo/bar') + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/foo/bar") self.assertEqual( - request.requestHeaders.getRawHeaders(b'host'), [b'target-server'] + request.requestHeaders.getRawHeaders(b"host"), [b"target-server"] ) # finish the request @@ -561,7 +563,7 @@ class MatrixFederationAgentTests(TestCase): clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients.pop() - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 443) redirect_server = self._make_connection( @@ -571,7 +573,7 @@ class MatrixFederationAgentTests(TestCase): # send a 302 redirect self.assertEqual(len(redirect_server.requests), 1) request = redirect_server.requests[0] - request.redirect(b'https://testserv/even_better_known') + request.redirect(b"https://testserv/even_better_known") request.finish() self.reactor.pump((0.1,)) @@ -580,7 +582,7 @@ class MatrixFederationAgentTests(TestCase): clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients.pop() - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 443) well_known_server = self._make_connection( @@ -589,8 +591,8 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(len(well_known_server.requests), 1, "No request after 302") request = well_known_server.requests[0] - self.assertEqual(request.method, b'GET') - self.assertEqual(request.path, b'/even_better_known') + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/even_better_known") request.write(b'{ "m.server": "target-server" }') request.finish() @@ -604,20 +606,20 @@ class MatrixFederationAgentTests(TestCase): # now we should get a connection to the target server self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients[0] - self.assertEqual(host, '1::f') + self.assertEqual(host, "1::f") self.assertEqual(port, 8448) # make a test server, and wire up the client http_server = self._make_connection( - client_factory, expected_sni=b'target-server' + client_factory, expected_sni=b"target-server" ) self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] - self.assertEqual(request.method, b'GET') - self.assertEqual(request.path, b'/foo/bar') + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/foo/bar") self.assertEqual( - request.requestHeaders.getRawHeaders(b'host'), [b'target-server'] + request.requestHeaders.getRawHeaders(b"host"), [b"target-server"] ) # finish the request @@ -652,11 +654,11 @@ class MatrixFederationAgentTests(TestCase): clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients.pop() - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 443) self._handle_well_known_connection( - client_factory, expected_sni=b"testserv", content=b'NOT JSON' + client_factory, expected_sni=b"testserv", content=b"NOT JSON" ) # now there should be a SRV lookup @@ -667,17 +669,17 @@ class MatrixFederationAgentTests(TestCase): # we should fall back to a direct connection self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients.pop() - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 8448) # make a test server, and wire up the client - http_server = self._make_connection(client_factory, expected_sni=b'testserv') + http_server = self._make_connection(client_factory, expected_sni=b"testserv") self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] - self.assertEqual(request.method, b'GET') - self.assertEqual(request.path, b'/foo/bar') - self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'testserv']) + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/foo/bar") + self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"testserv"]) # finish the request request.finish() @@ -712,12 +714,10 @@ class MatrixFederationAgentTests(TestCase): clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients[0] - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 443) - http_proto = self._make_connection( - client_factory, expected_sni=b"testserv", - ) + http_proto = self._make_connection(client_factory, expected_sni=b"testserv") # there should be no requests self.assertEqual(len(http_proto.requests), 0) @@ -750,17 +750,17 @@ class MatrixFederationAgentTests(TestCase): clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients[0] - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 8443) # make a test server, and wire up the client - http_server = self._make_connection(client_factory, expected_sni=b'testserv') + http_server = self._make_connection(client_factory, expected_sni=b"testserv") self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] - self.assertEqual(request.method, b'GET') - self.assertEqual(request.path, b'/foo/bar') - self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'testserv']) + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/foo/bar") + self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"testserv"]) # finish the request request.finish() @@ -783,7 +783,7 @@ class MatrixFederationAgentTests(TestCase): clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients[0] - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 443) self.mock_resolver.resolve_service.side_effect = lambda _: [ @@ -804,20 +804,20 @@ class MatrixFederationAgentTests(TestCase): # now we should get a connection to the target of the SRV record self.assertEqual(len(clients), 2) (host, port, client_factory, _timeout, _bindAddress) = clients[1] - self.assertEqual(host, '5.6.7.8') + self.assertEqual(host, "5.6.7.8") self.assertEqual(port, 8443) # make a test server, and wire up the client http_server = self._make_connection( - client_factory, expected_sni=b'target-server' + client_factory, expected_sni=b"target-server" ) self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] - self.assertEqual(request.method, b'GET') - self.assertEqual(request.path, b'/foo/bar') + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/foo/bar") self.assertEqual( - request.requestHeaders.getRawHeaders(b'host'), [b'target-server'] + request.requestHeaders.getRawHeaders(b"host"), [b"target-server"] ) # finish the request @@ -846,7 +846,7 @@ class MatrixFederationAgentTests(TestCase): clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients[0] - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 443) # fonx the connection @@ -865,20 +865,20 @@ class MatrixFederationAgentTests(TestCase): clients = self.reactor.tcpClients self.assertEqual(len(clients), 2) (host, port, client_factory, _timeout, _bindAddress) = clients[1] - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 8448) # make a test server, and wire up the client http_server = self._make_connection( - client_factory, expected_sni=b'xn--bcher-kva.com' + client_factory, expected_sni=b"xn--bcher-kva.com" ) self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] - self.assertEqual(request.method, b'GET') - self.assertEqual(request.path, b'/foo/bar') + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/foo/bar") self.assertEqual( - request.requestHeaders.getRawHeaders(b'host'), [b'xn--bcher-kva.com'] + request.requestHeaders.getRawHeaders(b"host"), [b"xn--bcher-kva.com"] ) # finish the request @@ -907,20 +907,20 @@ class MatrixFederationAgentTests(TestCase): clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients[0] - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 8443) # make a test server, and wire up the client http_server = self._make_connection( - client_factory, expected_sni=b'xn--bcher-kva.com' + client_factory, expected_sni=b"xn--bcher-kva.com" ) self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] - self.assertEqual(request.method, b'GET') - self.assertEqual(request.path, b'/foo/bar') + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/foo/bar") self.assertEqual( - request.requestHeaders.getRawHeaders(b'host'), [b'xn--bcher-kva.com'] + request.requestHeaders.getRawHeaders(b"host"), [b"xn--bcher-kva.com"] ) # finish the request @@ -941,42 +941,42 @@ class MatrixFederationAgentTests(TestCase): def test_well_known_cache(self): self.reactor.lookups["testserv"] = "1.2.3.4" - fetch_d = self.do_get_well_known(b'testserv') + fetch_d = self.do_get_well_known(b"testserv") # there should be an attempt to connect on port 443 for the .well-known clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 443) well_known_server = self._handle_well_known_connection( client_factory, expected_sni=b"testserv", - response_headers={b'Cache-Control': b'max-age=10'}, + response_headers={b"Cache-Control": b"max-age=10"}, content=b'{ "m.server": "target-server" }', ) r = self.successResultOf(fetch_d) - self.assertEqual(r, b'target-server') + self.assertEqual(r, b"target-server") # close the tcp connection well_known_server.loseConnection() # repeat the request: it should hit the cache - fetch_d = self.do_get_well_known(b'testserv') + fetch_d = self.do_get_well_known(b"testserv") r = self.successResultOf(fetch_d) - self.assertEqual(r, b'target-server') + self.assertEqual(r, b"target-server") # expire the cache self.reactor.pump((10.0,)) # now it should connect again - fetch_d = self.do_get_well_known(b'testserv') + fetch_d = self.do_get_well_known(b"testserv") self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 443) self._handle_well_known_connection( @@ -986,7 +986,7 @@ class MatrixFederationAgentTests(TestCase): ) r = self.successResultOf(fetch_d) - self.assertEqual(r, b'other-server') + self.assertEqual(r, b"other-server") class TestCachePeriodFromHeaders(TestCase): @@ -994,27 +994,27 @@ class TestCachePeriodFromHeaders(TestCase): # uppercase self.assertEqual( _cache_period_from_headers( - Headers({b'Cache-Control': [b'foo, Max-Age = 100, bar']}) + Headers({b"Cache-Control": [b"foo, Max-Age = 100, bar"]}) ), 100, ) # missing value self.assertIsNone( - _cache_period_from_headers(Headers({b'Cache-Control': [b'max-age=, bar']})) + _cache_period_from_headers(Headers({b"Cache-Control": [b"max-age=, bar"]})) ) # hackernews: bogus due to semicolon self.assertIsNone( _cache_period_from_headers( - Headers({b'Cache-Control': [b'private; max-age=0']}) + Headers({b"Cache-Control": [b"private; max-age=0"]}) ) ) # github self.assertEqual( _cache_period_from_headers( - Headers({b'Cache-Control': [b'max-age=0, private, must-revalidate']}) + Headers({b"Cache-Control": [b"max-age=0, private, must-revalidate"]}) ), 0, ) @@ -1022,7 +1022,7 @@ class TestCachePeriodFromHeaders(TestCase): # google self.assertEqual( _cache_period_from_headers( - Headers({b'cache-control': [b'private, max-age=0']}) + Headers({b"cache-control": [b"private, max-age=0"]}) ), 0, ) @@ -1030,7 +1030,7 @@ class TestCachePeriodFromHeaders(TestCase): def test_expires(self): self.assertEqual( _cache_period_from_headers( - Headers({b'Expires': [b'Wed, 30 Jan 2019 07:35:33 GMT']}), + Headers({b"Expires": [b"Wed, 30 Jan 2019 07:35:33 GMT"]}), time_now=lambda: 1548833700, ), 33, @@ -1041,8 +1041,8 @@ class TestCachePeriodFromHeaders(TestCase): _cache_period_from_headers( Headers( { - b'cache-control': [b'max-age=10'], - b'Expires': [b'Wed, 30 Jan 2019 07:35:33 GMT'], + b"cache-control": [b"max-age=10"], + b"Expires": [b"Wed, 30 Jan 2019 07:35:33 GMT"], } ), time_now=lambda: 1548833700, @@ -1051,7 +1051,7 @@ class TestCachePeriodFromHeaders(TestCase): ) # invalid expires means immediate expiry - self.assertEqual(_cache_period_from_headers(Headers({b'Expires': [b'0']})), 0) + self.assertEqual(_cache_period_from_headers(Headers({b"Expires": [b"0"]})), 0) def _check_logcontext(context): diff --git a/tests/http/federation/test_srv_resolver.py b/tests/http/federation/test_srv_resolver.py index 034c0db8d2..cf6c6e95b5 100644 --- a/tests/http/federation/test_srv_resolver.py +++ b/tests/http/federation/test_srv_resolver.py @@ -100,7 +100,7 @@ class SrvResolverTestCase(unittest.TestCase): def test_from_cache(self): clock = MockClock() - dns_client_mock = Mock(spec_set=['lookupService']) + dns_client_mock = Mock(spec_set=["lookupService"]) dns_client_mock.lookupService = Mock(spec_set=[]) service_name = b"test_service.example.com" -- cgit 1.5.1 From c3c6b00d956937ad50673cec75eab7989938a39b Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 24 Jun 2019 11:34:45 +0100 Subject: Pass config_dir_path and data_dir_path into Config.read_config. (#5522) * Pull config_dir_path and data_dir_path calculation out of read_config_files * Pass config_dir_path and data_dir_path into read_config --- changelog.d/5522.misc | 1 + synapse/config/_base.py | 104 ++++++++++++++------- synapse/config/api.py | 2 +- synapse/config/appservice.py | 2 +- synapse/config/captcha.py | 2 +- synapse/config/cas.py | 2 +- synapse/config/consent_config.py | 2 +- synapse/config/database.py | 2 +- synapse/config/emailconfig.py | 2 +- synapse/config/groups.py | 2 +- synapse/config/jwt_config.py | 2 +- synapse/config/key.py | 2 +- synapse/config/logger.py | 2 +- synapse/config/metrics.py | 2 +- synapse/config/password.py | 2 +- synapse/config/password_auth_providers.py | 2 +- synapse/config/push.py | 2 +- synapse/config/ratelimiting.py | 2 +- synapse/config/registration.py | 2 +- synapse/config/repository.py | 2 +- synapse/config/room_directory.py | 2 +- synapse/config/saml2_config.py | 2 +- synapse/config/server.py | 2 +- synapse/config/server_notices_config.py | 2 +- synapse/config/spam_checker.py | 2 +- synapse/config/stats.py | 2 +- synapse/config/third_party_event_rules.py | 2 +- synapse/config/tls.py | 2 +- synapse/config/user_directory.py | 2 +- synapse/config/voip.py | 2 +- synapse/config/workers.py | 2 +- tests/config/test_tls.py | 2 +- .../federation/test_matrix_federation_agent.py | 2 +- tests/unittest.py | 2 +- tests/utils.py | 2 +- 35 files changed, 104 insertions(+), 67 deletions(-) create mode 100644 changelog.d/5522.misc (limited to 'tests/http/federation') diff --git a/changelog.d/5522.misc b/changelog.d/5522.misc new file mode 100644 index 0000000000..17a7be5c99 --- /dev/null +++ b/changelog.d/5522.misc @@ -0,0 +1 @@ +Pass config_dir_path and data_dir_path into Config.read_config. diff --git a/synapse/config/_base.py b/synapse/config/_base.py index 36e9c04cee..6baa315874 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -1,5 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2017-2018 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -216,7 +218,7 @@ class Config(object): "--keys-directory", metavar="DIRECTORY", help="Where files such as certs and signing keys are stored when" - " their location is given explicitly in the config." + " their location is not given explicitly in the config." " Defaults to the directory containing the last config file", ) @@ -228,10 +230,22 @@ class Config(object): config_files = find_config_files(search_paths=config_args.config_path) + if not config_files: + config_parser.error("Must supply a config file.") + + if config_args.keys_directory: + config_dir_path = config_args.keys_directory + else: + config_dir_path = os.path.dirname(config_files[-1]) + config_dir_path = os.path.abspath(config_dir_path) + data_dir_path = os.getcwd() + config_dict = obj.read_config_files( - config_files, keys_directory=config_args.keys_directory + config_files, config_dir_path=config_dir_path, data_dir_path=data_dir_path + ) + obj.parse_config_dict( + config_dict, config_dir_path=config_dir_path, data_dir_path=data_dir_path ) - obj.parse_config_dict(config_dict) obj.invoke_all("read_arguments", config_args) @@ -282,7 +296,7 @@ class Config(object): metavar="DIRECTORY", help=( "Specify where additional config files such as signing keys and log" - " config should be stored. Defaults to the same directory as the main" + " config should be stored. Defaults to the same directory as the last" " config file." ), ) @@ -290,6 +304,20 @@ class Config(object): config_files = find_config_files(search_paths=config_args.config_path) + if not config_files: + config_parser.error( + "Must supply a config file.\nA config file can be automatically" + ' generated using "--generate-config -H SERVER_NAME' + ' -c CONFIG-FILE"' + ) + + if config_args.config_directory: + config_dir_path = config_args.config_directory + else: + config_dir_path = os.path.dirname(config_files[-1]) + config_dir_path = os.path.abspath(config_dir_path) + data_dir_path = os.getcwd() + generate_missing_configs = config_args.generate_missing_configs obj = cls() @@ -300,20 +328,10 @@ class Config(object): "Please specify either --report-stats=yes or --report-stats=no\n\n" + MISSING_REPORT_STATS_SPIEL ) - if not config_files: - config_parser.error( - "Must supply a config file.\nA config file can be automatically" - ' generated using "--generate-config -H SERVER_NAME' - ' -c CONFIG-FILE"' - ) + (config_path,) = config_files if not cls.path_exists(config_path): print("Generating config file %s" % (config_path,)) - if config_args.config_directory: - config_dir_path = config_args.config_directory - else: - config_dir_path = os.path.dirname(config_path) - config_dir_path = os.path.abspath(config_dir_path) server_name = config_args.server_name if not server_name: @@ -324,7 +342,7 @@ class Config(object): config_str = obj.generate_config( config_dir_path=config_dir_path, - data_dir_path=os.getcwd(), + data_dir_path=data_dir_path, server_name=server_name, report_stats=(config_args.report_stats == "yes"), generate_secrets=True, @@ -367,35 +385,37 @@ class Config(object): obj.invoke_all("add_arguments", parser) args = parser.parse_args(remaining_args) - if not config_files: - config_parser.error( - "Must supply a config file.\nA config file can be automatically" - ' generated using "--generate-config -H SERVER_NAME' - ' -c CONFIG-FILE"' - ) - config_dict = obj.read_config_files( - config_files, keys_directory=config_args.config_directory + config_files, config_dir_path=config_dir_path, data_dir_path=data_dir_path ) if generate_missing_configs: obj.generate_missing_files(config_dict) return None - obj.parse_config_dict(config_dict) + obj.parse_config_dict( + config_dict, config_dir_path=config_dir_path, data_dir_path=data_dir_path + ) obj.invoke_all("read_arguments", args) return obj - def read_config_files(self, config_files, keys_directory=None): + def read_config_files(self, config_files, config_dir_path, data_dir_path): """Read the config files into a dict + Args: + config_files (iterable[str]): A list of the config files to read + + config_dir_path (str): The path where the config files are kept. Used to + create filenames for things like the log config and the signing key. + + data_dir_path (str): The path where the data files are kept. Used to create + filenames for things like the database and media store. + Returns: dict """ - if not keys_directory: - keys_directory = os.path.dirname(config_files[-1]) - - self.config_dir_path = os.path.abspath(keys_directory) + # FIXME: get rid of this + self.config_dir_path = config_dir_path # first we read the config files into a dict specified_config = {} @@ -409,8 +429,8 @@ class Config(object): raise ConfigError(MISSING_SERVER_NAME) server_name = specified_config["server_name"] config_string = self.generate_config( - config_dir_path=self.config_dir_path, - data_dir_path=os.getcwd(), + config_dir_path=config_dir_path, + data_dir_path=data_dir_path, server_name=server_name, generate_secrets=False, ) @@ -430,8 +450,24 @@ class Config(object): ) return config - def parse_config_dict(self, config_dict): - self.invoke_all("read_config", config_dict) + def parse_config_dict(self, config_dict, config_dir_path, data_dir_path): + """Read the information from the config dict into this Config object. + + Args: + config_dict (dict): Configuration data, as read from the yaml + + config_dir_path (str): The path where the config files are kept. Used to + create filenames for things like the log config and the signing key. + + data_dir_path (str): The path where the data files are kept. Used to create + filenames for things like the database and media store. + """ + self.invoke_all( + "read_config", + config_dict, + config_dir_path=config_dir_path, + data_dir_path=data_dir_path, + ) def generate_missing_files(self, config_dict): self.invoke_all("generate_files", config_dict) diff --git a/synapse/config/api.py b/synapse/config/api.py index 23b0ea6962..d9eff9ae1f 100644 --- a/synapse/config/api.py +++ b/synapse/config/api.py @@ -18,7 +18,7 @@ from ._base import Config class ApiConfig(Config): - def read_config(self, config): + def read_config(self, config, **kwargs): self.room_invite_state_types = config.get( "room_invite_state_types", [ diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py index 679ee62480..b74cebfca9 100644 --- a/synapse/config/appservice.py +++ b/synapse/config/appservice.py @@ -29,7 +29,7 @@ logger = logging.getLogger(__name__) class AppServiceConfig(Config): - def read_config(self, config): + def read_config(self, config, **kwargs): self.app_service_config_files = config.get("app_service_config_files", []) self.notify_appservices = config.get("notify_appservices", True) self.track_appservice_user_ips = config.get("track_appservice_user_ips", False) diff --git a/synapse/config/captcha.py b/synapse/config/captcha.py index e2eb473a92..a08b08570b 100644 --- a/synapse/config/captcha.py +++ b/synapse/config/captcha.py @@ -16,7 +16,7 @@ from ._base import Config class CaptchaConfig(Config): - def read_config(self, config): + def read_config(self, config, **kwargs): self.recaptcha_private_key = config.get("recaptcha_private_key") self.recaptcha_public_key = config.get("recaptcha_public_key") self.enable_registration_captcha = config.get( diff --git a/synapse/config/cas.py b/synapse/config/cas.py index 609c0815c8..a5f0449955 100644 --- a/synapse/config/cas.py +++ b/synapse/config/cas.py @@ -22,7 +22,7 @@ class CasConfig(Config): cas_server_url: URL of CAS server """ - def read_config(self, config): + def read_config(self, config, **kwargs): cas_config = config.get("cas_config", None) if cas_config: self.cas_enabled = cas_config.get("enabled", True) diff --git a/synapse/config/consent_config.py b/synapse/config/consent_config.py index 5b0bf919c7..6fd4931681 100644 --- a/synapse/config/consent_config.py +++ b/synapse/config/consent_config.py @@ -84,7 +84,7 @@ class ConsentConfig(Config): self.user_consent_at_registration = False self.user_consent_policy_name = "Privacy Policy" - def read_config(self, config): + def read_config(self, config, **kwargs): consent_config = config.get("user_consent") if consent_config is None: return diff --git a/synapse/config/database.py b/synapse/config/database.py index adc0a47ddf..c8963e276a 100644 --- a/synapse/config/database.py +++ b/synapse/config/database.py @@ -18,7 +18,7 @@ from ._base import Config class DatabaseConfig(Config): - def read_config(self, config): + def read_config(self, config, **kwargs): self.event_cache_size = self.parse_size(config.get("event_cache_size", "10K")) self.database_config = config.get("database") diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py index 3a6cb07206..07df7b7173 100644 --- a/synapse/config/emailconfig.py +++ b/synapse/config/emailconfig.py @@ -27,7 +27,7 @@ from ._base import Config, ConfigError class EmailConfig(Config): - def read_config(self, config): + def read_config(self, config, **kwargs): # TODO: We should separate better the email configuration from the notification # and account validity config. diff --git a/synapse/config/groups.py b/synapse/config/groups.py index e4be172a79..d11f4d3b96 100644 --- a/synapse/config/groups.py +++ b/synapse/config/groups.py @@ -17,7 +17,7 @@ from ._base import Config class GroupsConfig(Config): - def read_config(self, config): + def read_config(self, config, **kwargs): self.enable_group_creation = config.get("enable_group_creation", False) self.group_creation_prefix = config.get("group_creation_prefix", "") diff --git a/synapse/config/jwt_config.py b/synapse/config/jwt_config.py index b190dcbe38..a2c97dea95 100644 --- a/synapse/config/jwt_config.py +++ b/synapse/config/jwt_config.py @@ -23,7 +23,7 @@ MISSING_JWT = """Missing jwt library. This is required for jwt login. class JWTConfig(Config): - def read_config(self, config): + def read_config(self, config, **kwargs): jwt_config = config.get("jwt_config", None) if jwt_config: self.jwt_enabled = jwt_config.get("enabled", False) diff --git a/synapse/config/key.py b/synapse/config/key.py index 21c4f5c51c..e58638f708 100644 --- a/synapse/config/key.py +++ b/synapse/config/key.py @@ -65,7 +65,7 @@ class TrustedKeyServer(object): class KeyConfig(Config): - def read_config(self, config): + def read_config(self, config, **kwargs): # the signing key can be specified inline or in a separate file if "signing_key" in config: self.signing_key = read_signing_keys([config["signing_key"]]) diff --git a/synapse/config/logger.py b/synapse/config/logger.py index 9db2e087e4..153a137517 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -74,7 +74,7 @@ root: class LoggingConfig(Config): - def read_config(self, config): + def read_config(self, config, **kwargs): self.verbosity = config.get("verbose", 0) self.no_redirect_stdio = config.get("no_redirect_stdio", False) self.log_config = self.abspath(config.get("log_config")) diff --git a/synapse/config/metrics.py b/synapse/config/metrics.py index c85e234d22..6af82e1329 100644 --- a/synapse/config/metrics.py +++ b/synapse/config/metrics.py @@ -21,7 +21,7 @@ MISSING_SENTRY = """Missing sentry-sdk library. This is required to enable sentr class MetricsConfig(Config): - def read_config(self, config): + def read_config(self, config, **kwargs): self.enable_metrics = config.get("enable_metrics", False) self.report_stats = config.get("report_stats", None) self.metrics_port = config.get("metrics_port") diff --git a/synapse/config/password.py b/synapse/config/password.py index eea59e772b..300b67f236 100644 --- a/synapse/config/password.py +++ b/synapse/config/password.py @@ -20,7 +20,7 @@ class PasswordConfig(Config): """Password login configuration """ - def read_config(self, config): + def read_config(self, config, **kwargs): password_config = config.get("password_config", {}) if password_config is None: password_config = {} diff --git a/synapse/config/password_auth_providers.py b/synapse/config/password_auth_providers.py index fcf279e8e1..8ffefd2639 100644 --- a/synapse/config/password_auth_providers.py +++ b/synapse/config/password_auth_providers.py @@ -21,7 +21,7 @@ LDAP_PROVIDER = "ldap_auth_provider.LdapAuthProvider" class PasswordAuthProviderConfig(Config): - def read_config(self, config): + def read_config(self, config, **kwargs): self.password_providers = [] providers = [] diff --git a/synapse/config/push.py b/synapse/config/push.py index 62c0060c9c..99d15e4461 100644 --- a/synapse/config/push.py +++ b/synapse/config/push.py @@ -18,7 +18,7 @@ from ._base import Config class PushConfig(Config): - def read_config(self, config): + def read_config(self, config, **kwargs): push_config = config.get("push", {}) self.push_include_content = push_config.get("include_content", True) diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py index 5a9adac480..b03047f2b5 100644 --- a/synapse/config/ratelimiting.py +++ b/synapse/config/ratelimiting.py @@ -36,7 +36,7 @@ class FederationRateLimitConfig(object): class RatelimitConfig(Config): - def read_config(self, config): + def read_config(self, config, **kwargs): # Load the new-style messages config if it exists. Otherwise fall back # to the old method. diff --git a/synapse/config/registration.py b/synapse/config/registration.py index a1e27ba66c..6d8a2df29b 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -46,7 +46,7 @@ class AccountValidityConfig(Config): class RegistrationConfig(Config): - def read_config(self, config): + def read_config(self, config, **kwargs): self.enable_registration = bool( strtobool(str(config.get("enable_registration", False))) ) diff --git a/synapse/config/repository.py b/synapse/config/repository.py index 9f9669ebb1..15a19e0911 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -86,7 +86,7 @@ def parse_thumbnail_requirements(thumbnail_sizes): class ContentRepositoryConfig(Config): - def read_config(self, config): + def read_config(self, config, **kwargs): self.max_upload_size = self.parse_size(config.get("max_upload_size", "10M")) self.max_image_pixels = self.parse_size(config.get("max_image_pixels", "32M")) self.max_spider_size = self.parse_size(config.get("max_spider_size", "10M")) diff --git a/synapse/config/room_directory.py b/synapse/config/room_directory.py index c1da0e20e0..24223db7a1 100644 --- a/synapse/config/room_directory.py +++ b/synapse/config/room_directory.py @@ -19,7 +19,7 @@ from ._base import Config, ConfigError class RoomDirectoryConfig(Config): - def read_config(self, config): + def read_config(self, config, **kwargs): self.enable_room_list_search = config.get("enable_room_list_search", True) alias_creation_rules = config.get("alias_creation_rules") diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py index 2ec38e48e9..d86cf0e6ee 100644 --- a/synapse/config/saml2_config.py +++ b/synapse/config/saml2_config.py @@ -17,7 +17,7 @@ from ._base import Config, ConfigError class SAML2Config(Config): - def read_config(self, config): + def read_config(self, config, **kwargs): self.saml2_enabled = False saml2_config = config.get("saml2_config") diff --git a/synapse/config/server.py b/synapse/config/server.py index 9ceca0a606..1e58b2e91b 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -40,7 +40,7 @@ DEFAULT_ROOM_VERSION = "4" class ServerConfig(Config): - def read_config(self, config): + def read_config(self, config, **kwargs): self.server_name = config["server_name"] self.server_context = config.get("server_context", None) diff --git a/synapse/config/server_notices_config.py b/synapse/config/server_notices_config.py index d930eb33b5..05110c17a6 100644 --- a/synapse/config/server_notices_config.py +++ b/synapse/config/server_notices_config.py @@ -66,7 +66,7 @@ class ServerNoticesConfig(Config): self.server_notices_mxid_avatar_url = None self.server_notices_room_name = None - def read_config(self, config): + def read_config(self, config, **kwargs): c = config.get("server_notices") if c is None: return diff --git a/synapse/config/spam_checker.py b/synapse/config/spam_checker.py index 1502e9faba..1968003cb3 100644 --- a/synapse/config/spam_checker.py +++ b/synapse/config/spam_checker.py @@ -19,7 +19,7 @@ from ._base import Config class SpamCheckerConfig(Config): - def read_config(self, config): + def read_config(self, config, **kwargs): self.spam_checker = None provider = config.get("spam_checker", None) diff --git a/synapse/config/stats.py b/synapse/config/stats.py index 80fc1b9dd0..73a87c73f2 100644 --- a/synapse/config/stats.py +++ b/synapse/config/stats.py @@ -25,7 +25,7 @@ class StatsConfig(Config): Configuration for the behaviour of synapse's stats engine """ - def read_config(self, config): + def read_config(self, config, **kwargs): self.stats_enabled = True self.stats_bucket_size = 86400 self.stats_retention = sys.maxsize diff --git a/synapse/config/third_party_event_rules.py b/synapse/config/third_party_event_rules.py index a89dd5f98a..1bedd607b6 100644 --- a/synapse/config/third_party_event_rules.py +++ b/synapse/config/third_party_event_rules.py @@ -19,7 +19,7 @@ from ._base import Config class ThirdPartyRulesConfig(Config): - def read_config(self, config): + def read_config(self, config, **kwargs): self.third_party_event_rules = None provider = config.get("third_party_event_rules", None) diff --git a/synapse/config/tls.py b/synapse/config/tls.py index 7951bf21fa..28be4366d6 100644 --- a/synapse/config/tls.py +++ b/synapse/config/tls.py @@ -33,7 +33,7 @@ logger = logging.getLogger(__name__) class TlsConfig(Config): - def read_config(self, config): + def read_config(self, config, **kwargs): acme_config = config.get("acme", None) if acme_config is None: diff --git a/synapse/config/user_directory.py b/synapse/config/user_directory.py index e031b11599..0665dc3fcf 100644 --- a/synapse/config/user_directory.py +++ b/synapse/config/user_directory.py @@ -21,7 +21,7 @@ class UserDirectoryConfig(Config): Configuration for the behaviour of the /user_directory API """ - def read_config(self, config): + def read_config(self, config, **kwargs): self.user_directory_search_enabled = True self.user_directory_search_all_users = False user_directory_config = config.get("user_directory", None) diff --git a/synapse/config/voip.py b/synapse/config/voip.py index 82cf8c53a8..01e0cb2e28 100644 --- a/synapse/config/voip.py +++ b/synapse/config/voip.py @@ -16,7 +16,7 @@ from ._base import Config class VoipConfig(Config): - def read_config(self, config): + def read_config(self, config, **kwargs): self.turn_uris = config.get("turn_uris", []) self.turn_shared_secret = config.get("turn_shared_secret") self.turn_username = config.get("turn_username") diff --git a/synapse/config/workers.py b/synapse/config/workers.py index 4f283a0c2f..3b75471d85 100644 --- a/synapse/config/workers.py +++ b/synapse/config/workers.py @@ -21,7 +21,7 @@ class WorkerConfig(Config): They have their own pid_file and listener configuration. They use the replication_url to talk to the main synapse process.""" - def read_config(self, config): + def read_config(self, config, **kwargs): self.worker_app = config.get("worker_app") # Canonicalise worker_app so that master always has None diff --git a/tests/config/test_tls.py b/tests/config/test_tls.py index 0cbbf4e885..a5d88d644a 100644 --- a/tests/config/test_tls.py +++ b/tests/config/test_tls.py @@ -65,7 +65,7 @@ s4niecZKPBizL6aucT59CsunNmmb5Glq8rlAcU+1ZTZZzGYqVYhF6axB9Qg= } t = TestConfig() - t.read_config(config) + t.read_config(config, config_dir_path="", data_dir_path="") t.read_certificate_from_disk(require_cert_and_key=False) warnings = self.flushWarnings() diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index b1094c1448..417fda3ab2 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -78,7 +78,7 @@ class MatrixFederationAgentTests(TestCase): # config_dict["trusted_key_servers"] = [] self._config = config = HomeServerConfig() - config.parse_config_dict(config_dict) + config.parse_config_dict(config_dict, "", "") self.agent = MatrixFederationAgent( reactor=self.reactor, diff --git a/tests/unittest.py b/tests/unittest.py index d64702b0c2..36df43c137 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -342,7 +342,7 @@ class HomeserverTestCase(TestCase): # Parse the config from a config dict into a HomeServerConfig config_obj = HomeServerConfig() - config_obj.parse_config_dict(config) + config_obj.parse_config_dict(config, "", "") kwargs["config"] = config_obj hs = setup_test_homeserver(self.addCleanup, *args, **kwargs) diff --git a/tests/utils.py b/tests/utils.py index bd2c7c954c..da43166f3a 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -182,7 +182,7 @@ def default_config(name, parse=False): if parse: config = HomeServerConfig() - config.parse_config_dict(config_dict) + config.parse_config_dict(config_dict, "", "") return config return config_dict -- cgit 1.5.1