diff --git a/tests/app/test_frontend_proxy.py b/tests/app/test_frontend_proxy.py
index a83f567ebd..8bdbc608a9 100644
--- a/tests/app/test_frontend_proxy.py
+++ b/tests/app/test_frontend_proxy.py
@@ -59,7 +59,7 @@ class FrontendProxyTests(HomeserverTestCase):
def test_listen_http_with_presence_disabled(self):
"""
- When presence is on, the stub servlet will register.
+ When presence is off, the stub servlet will register.
"""
# Presence is off
self.hs.config.use_presence = False
diff --git a/tests/app/test_openid_listener.py b/tests/app/test_openid_listener.py
new file mode 100644
index 0000000000..590abc1e92
--- /dev/null
+++ b/tests/app/test_openid_listener.py
@@ -0,0 +1,119 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from mock import Mock, patch
+
+from parameterized import parameterized
+
+from synapse.app.federation_reader import FederationReaderServer
+from synapse.app.homeserver import SynapseHomeServer
+
+from tests.unittest import HomeserverTestCase
+
+
+class FederationReaderOpenIDListenerTests(HomeserverTestCase):
+ def make_homeserver(self, reactor, clock):
+ hs = self.setup_test_homeserver(
+ http_client=None, homeserverToUse=FederationReaderServer,
+ )
+ return hs
+
+ @parameterized.expand([
+ (["federation"], "auth_fail"),
+ ([], "no_resource"),
+ (["openid", "federation"], "auth_fail"),
+ (["openid"], "auth_fail"),
+ ])
+ def test_openid_listener(self, names, expectation):
+ """
+ Test different openid listener configurations.
+
+ 401 is success here since it means we hit the handler and auth failed.
+ """
+ config = {
+ "port": 8080,
+ "bind_addresses": ["0.0.0.0"],
+ "resources": [{"names": names}],
+ }
+
+ # Listen with the config
+ self.hs._listen_http(config)
+
+ # Grab the resource from the site that was told to listen
+ site = self.reactor.tcpServers[0][1]
+ try:
+ self.resource = (
+ site.resource.children[b"_matrix"].children[b"federation"]
+ )
+ except KeyError:
+ if expectation == "no_resource":
+ return
+ raise
+
+ request, channel = self.make_request(
+ "GET",
+ "/_matrix/federation/v1/openid/userinfo",
+ )
+ self.render(request)
+
+ self.assertEqual(channel.code, 401)
+
+
+@patch("synapse.app.homeserver.KeyApiV2Resource", new=Mock())
+class SynapseHomeserverOpenIDListenerTests(HomeserverTestCase):
+ def make_homeserver(self, reactor, clock):
+ hs = self.setup_test_homeserver(
+ http_client=None, homeserverToUse=SynapseHomeServer,
+ )
+ return hs
+
+ @parameterized.expand([
+ (["federation"], "auth_fail"),
+ ([], "no_resource"),
+ (["openid", "federation"], "auth_fail"),
+ (["openid"], "auth_fail"),
+ ])
+ def test_openid_listener(self, names, expectation):
+ """
+ Test different openid listener configurations.
+
+ 401 is success here since it means we hit the handler and auth failed.
+ """
+ config = {
+ "port": 8080,
+ "bind_addresses": ["0.0.0.0"],
+ "resources": [{"names": names}],
+ }
+
+ # Listen with the config
+ self.hs._listener_http(config, config)
+
+ # Grab the resource from the site that was told to listen
+ site = self.reactor.tcpServers[0][1]
+ try:
+ self.resource = (
+ site.resource.children[b"_matrix"].children[b"federation"]
+ )
+ except KeyError:
+ if expectation == "no_resource":
+ return
+ raise
+
+ request, channel = self.make_request(
+ "GET",
+ "/_matrix/federation/v1/openid/userinfo",
+ )
+ self.render(request)
+
+ self.assertEqual(channel.code, 401)
diff --git a/tests/config/test_generate.py b/tests/config/test_generate.py
index b5ad99348d..795b4c298d 100644
--- a/tests/config/test_generate.py
+++ b/tests/config/test_generate.py
@@ -50,8 +50,6 @@ class ConfigGenerationTestCase(unittest.TestCase):
"homeserver.yaml",
"lemurs.win.log.config",
"lemurs.win.signing.key",
- "lemurs.win.tls.crt",
- "lemurs.win.tls.key",
]
),
set(os.listdir(self.dir)),
diff --git a/tests/config/test_tls.py b/tests/config/test_tls.py
new file mode 100644
index 0000000000..c260d3359f
--- /dev/null
+++ b/tests/config/test_tls.py
@@ -0,0 +1,79 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+from synapse.config.tls import TlsConfig
+
+from tests.unittest import TestCase
+
+
+class TestConfig(TlsConfig):
+ def has_tls_listener(self):
+ return False
+
+
+class TLSConfigTests(TestCase):
+
+ def test_warn_self_signed(self):
+ """
+ Synapse will give a warning when it loads a self-signed certificate.
+ """
+ config_dir = self.mktemp()
+ os.mkdir(config_dir)
+ with open(os.path.join(config_dir, "cert.pem"), 'w') as f:
+ f.write("""-----BEGIN CERTIFICATE-----
+MIID6DCCAtACAws9CjANBgkqhkiG9w0BAQUFADCBtzELMAkGA1UEBhMCVFIxDzAN
+BgNVBAgMBsOHb3J1bTEUMBIGA1UEBwwLQmHFn21ha8OnxLExEjAQBgNVBAMMCWxv
+Y2FsaG9zdDEcMBoGA1UECgwTVHdpc3RlZCBNYXRyaXggTGFiczEkMCIGA1UECwwb
+QXV0b21hdGVkIFRlc3RpbmcgQXV0aG9yaXR5MSkwJwYJKoZIhvcNAQkBFhpzZWN1
+cml0eUB0d2lzdGVkbWF0cml4LmNvbTAgFw0xNzA3MTIxNDAxNTNaGA8yMTE3MDYx
+ODE0MDE1M1owgbcxCzAJBgNVBAYTAlRSMQ8wDQYDVQQIDAbDh29ydW0xFDASBgNV
+BAcMC0JhxZ9tYWvDp8SxMRIwEAYDVQQDDAlsb2NhbGhvc3QxHDAaBgNVBAoME1R3
+aXN0ZWQgTWF0cml4IExhYnMxJDAiBgNVBAsMG0F1dG9tYXRlZCBUZXN0aW5nIEF1
+dGhvcml0eTEpMCcGCSqGSIb3DQEJARYac2VjdXJpdHlAdHdpc3RlZG1hdHJpeC5j
+b20wggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDwT6kbqtMUI0sMkx4h
+I+L780dA59KfksZCqJGmOsMD6hte9EguasfkZzvCF3dk3NhwCjFSOvKx6rCwiteo
+WtYkVfo+rSuVNmt7bEsOUDtuTcaxTzIFB+yHOYwAaoz3zQkyVW0c4pzioiLCGCmf
+FLdiDBQGGp74tb+7a0V6kC3vMLFoM3L6QWq5uYRB5+xLzlPJ734ltyvfZHL3Us6p
+cUbK+3WTWvb4ER0W2RqArAj6Bc/ERQKIAPFEiZi9bIYTwvBH27OKHRz+KoY/G8zY
++l+WZoJqDhupRAQAuh7O7V/y6bSP+KNxJRie9QkZvw1PSaGSXtGJI3WWdO12/Ulg
+epJpAgMBAAEwDQYJKoZIhvcNAQEFBQADggEBAJXEq5P9xwvP9aDkXIqzcD0L8sf8
+ewlhlxTQdeqt2Nace0Yk18lIo2oj1t86Y8jNbpAnZJeI813Rr5M7FbHCXoRc/SZG
+I8OtG1xGwcok53lyDuuUUDexnK4O5BkjKiVlNPg4HPim5Kuj2hRNFfNt/F2BVIlj
+iZupikC5MT1LQaRwidkSNxCku1TfAyueiBwhLnFwTmIGNnhuDCutEVAD9kFmcJN2
+SznugAcPk4doX2+rL+ila+ThqgPzIkwTUHtnmjI0TI6xsDUlXz5S3UyudrE2Qsfz
+s4niecZKPBizL6aucT59CsunNmmb5Glq8rlAcU+1ZTZZzGYqVYhF6axB9Qg=
+-----END CERTIFICATE-----""")
+
+ config = {
+ "tls_certificate_path": os.path.join(config_dir, "cert.pem"),
+ "tls_fingerprints": []
+ }
+
+ t = TestConfig()
+ t.read_config(config)
+ t.read_certificate_from_disk(require_cert_and_key=False)
+
+ warnings = self.flushWarnings()
+ self.assertEqual(len(warnings), 1)
+ self.assertEqual(
+ warnings[0]["message"],
+ (
+ "Self-signed TLS certificates will not be accepted by "
+ "Synapse 1.0. Please either provide a valid certificate, "
+ "or use Synapse's ACME support to provision one."
+ )
+ )
diff --git a/tests/crypto/test_event_signing.py b/tests/crypto/test_event_signing.py
index b2536c1e69..71aa731439 100644
--- a/tests/crypto/test_event_signing.py
+++ b/tests/crypto/test_event_signing.py
@@ -18,7 +18,7 @@ import nacl.signing
from unpaddedbase64 import decode_base64
from synapse.crypto.event_signing import add_hashes_and_signatures
-from synapse.events.builder import EventBuilder
+from synapse.events import FrozenEvent
from tests import unittest
@@ -40,20 +40,18 @@ class EventSigningTestCase(unittest.TestCase):
self.signing_key.version = KEY_VER
def test_sign_minimal(self):
- builder = EventBuilder(
- {
- 'event_id': "$0:domain",
- 'origin': "domain",
- 'origin_server_ts': 1000000,
- 'signatures': {},
- 'type': "X",
- 'unsigned': {'age_ts': 1000000},
- }
- )
+ event_dict = {
+ 'event_id': "$0:domain",
+ 'origin': "domain",
+ 'origin_server_ts': 1000000,
+ 'signatures': {},
+ 'type': "X",
+ 'unsigned': {'age_ts': 1000000},
+ }
- add_hashes_and_signatures(builder, HOSTNAME, self.signing_key)
+ add_hashes_and_signatures(event_dict, HOSTNAME, self.signing_key)
- event = builder.build()
+ event = FrozenEvent(event_dict)
self.assertTrue(hasattr(event, 'hashes'))
self.assertIn('sha256', event.hashes)
@@ -71,23 +69,21 @@ class EventSigningTestCase(unittest.TestCase):
)
def test_sign_message(self):
- builder = EventBuilder(
- {
- 'content': {'body': "Here is the message content"},
- 'event_id': "$0:domain",
- 'origin': "domain",
- 'origin_server_ts': 1000000,
- 'type': "m.room.message",
- 'room_id': "!r:domain",
- 'sender': "@u:domain",
- 'signatures': {},
- 'unsigned': {'age_ts': 1000000},
- }
- )
-
- add_hashes_and_signatures(builder, HOSTNAME, self.signing_key)
-
- event = builder.build()
+ event_dict = {
+ 'content': {'body': "Here is the message content"},
+ 'event_id': "$0:domain",
+ 'origin': "domain",
+ 'origin_server_ts': 1000000,
+ 'type': "m.room.message",
+ 'room_id': "!r:domain",
+ 'sender': "@u:domain",
+ 'signatures': {},
+ 'unsigned': {'age_ts': 1000000},
+ }
+
+ add_hashes_and_signatures(event_dict, HOSTNAME, self.signing_key)
+
+ event = FrozenEvent(event_dict)
self.assertTrue(hasattr(event, 'hashes'))
self.assertIn('sha256', event.hashes)
diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py
index c8994f416e..1c49bbbc3c 100644
--- a/tests/handlers/test_e2e_room_keys.py
+++ b/tests/handlers/test_e2e_room_keys.py
@@ -126,6 +126,78 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
})
@defer.inlineCallbacks
+ def test_update_version(self):
+ """Check that we can update versions.
+ """
+ version = yield self.handler.create_version(self.local_user, {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "first_version_auth_data",
+ })
+ self.assertEqual(version, "1")
+
+ res = yield self.handler.update_version(self.local_user, version, {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "revised_first_version_auth_data",
+ "version": version
+ })
+ self.assertDictEqual(res, {})
+
+ # check we can retrieve it as the current version
+ res = yield self.handler.get_version_info(self.local_user)
+ self.assertDictEqual(res, {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "revised_first_version_auth_data",
+ "version": version
+ })
+
+ @defer.inlineCallbacks
+ def test_update_missing_version(self):
+ """Check that we get a 404 on updating nonexistent versions
+ """
+ res = None
+ try:
+ yield self.handler.update_version(self.local_user, "1", {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "revised_first_version_auth_data",
+ "version": "1"
+ })
+ except errors.SynapseError as e:
+ res = e.code
+ self.assertEqual(res, 404)
+
+ @defer.inlineCallbacks
+ def test_update_bad_version(self):
+ """Check that we get a 400 if the version in the body is missing or
+ doesn't match
+ """
+ version = yield self.handler.create_version(self.local_user, {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "first_version_auth_data",
+ })
+ self.assertEqual(version, "1")
+
+ res = None
+ try:
+ yield self.handler.update_version(self.local_user, version, {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "revised_first_version_auth_data"
+ })
+ except errors.SynapseError as e:
+ res = e.code
+ self.assertEqual(res, 400)
+
+ res = None
+ try:
+ yield self.handler.update_version(self.local_user, version, {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "revised_first_version_auth_data",
+ "version": "incorrect"
+ })
+ except errors.SynapseError as e:
+ res = e.code
+ self.assertEqual(res, 400)
+
+ @defer.inlineCallbacks
def test_delete_missing_version(self):
"""Check that we get a 404 on deleting nonexistent versions
"""
diff --git a/tests/http/__init__.py b/tests/http/__init__.py
index e69de29bb2..ee8010f598 100644
--- a/tests/http/__init__.py
+++ b/tests/http/__init__.py
@@ -0,0 +1,42 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os.path
+
+from OpenSSL import SSL
+
+
+def get_test_cert_file():
+ """get the path to the test cert"""
+
+ # the cert file itself is made with:
+ #
+ # openssl req -x509 -newkey rsa:4096 -keyout server.pem -out server.pem -days 36500 \
+ # -nodes -subj '/CN=testserv'
+ return os.path.join(
+ os.path.dirname(__file__),
+ 'server.pem',
+ )
+
+
+class ServerTLSContext(object):
+ """A TLS Context which presents our test cert."""
+ def __init__(self):
+ self.filename = get_test_cert_file()
+
+ def getContext(self):
+ ctx = SSL.Context(SSL.TLSv1_METHOD)
+ ctx.use_certificate_file(self.filename)
+ ctx.use_privatekey_file(self.filename)
+ return ctx
diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
index eb963d80fb..dcf184d3cf 100644
--- a/tests/http/federation/test_matrix_federation_agent.py
+++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -17,17 +17,26 @@ import logging
from mock import Mock
import treq
+from zope.interface import implementer
from twisted.internet import defer
+from twisted.internet._sslverify import ClientTLSOptions, OpenSSLCertificateOptions
from twisted.internet.protocol import Factory
from twisted.protocols.tls import TLSMemoryBIOFactory
-from twisted.test.ssl_helpers import ServerTLSContext
from twisted.web.http import HTTPChannel
+from twisted.web.http_headers import Headers
+from twisted.web.iweb import IPolicyForHTTPS
from synapse.crypto.context_factory import ClientTLSOptionsFactory
-from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
+from synapse.http.federation.matrix_federation_agent import (
+ MatrixFederationAgent,
+ _cache_period_from_headers,
+)
+from synapse.http.federation.srv_resolver import Server
+from synapse.util.caches.ttlcache import TTLCache
from synapse.util.logcontext import LoggingContext
+from tests.http import ServerTLSContext
from tests.server import FakeTransport, ThreadedMemoryReactorClock
from tests.unittest import TestCase
@@ -40,13 +49,17 @@ class MatrixFederationAgentTests(TestCase):
self.mock_resolver = Mock()
+ self.well_known_cache = TTLCache("test_cache", timer=self.reactor.seconds)
+
self.agent = MatrixFederationAgent(
reactor=self.reactor,
tls_client_options_factory=ClientTLSOptionsFactory(None),
+ _well_known_tls_policy=TrustingTLSPolicyForHTTPS(),
_srv_resolver=self.mock_resolver,
+ _well_known_cache=self.well_known_cache,
)
- def _make_connection(self, client_factory):
+ def _make_connection(self, client_factory, expected_sni):
"""Builds a test server, and completes the outgoing client connection
Returns:
@@ -64,14 +77,26 @@ class MatrixFederationAgentTests(TestCase):
# Normally this would be done by the TCP socket code in Twisted, but we are
# stubbing that out here.
client_protocol = client_factory.buildProtocol(None)
- client_protocol.makeConnection(FakeTransport(server_tls_protocol, self.reactor))
+ client_protocol.makeConnection(
+ FakeTransport(server_tls_protocol, self.reactor, client_protocol),
+ )
# tell the server tls protocol to send its stuff back to the client, too
- server_tls_protocol.makeConnection(FakeTransport(client_protocol, self.reactor))
+ server_tls_protocol.makeConnection(
+ FakeTransport(client_protocol, self.reactor, server_tls_protocol),
+ )
- # finally, give the reactor a pump to get the TLS juices flowing.
+ # give the reactor a pump to get the TLS juices flowing.
self.reactor.pump((0.1,))
+ # check the SNI
+ server_name = server_tls_protocol._tlsConnection.get_servername()
+ self.assertEqual(
+ server_name,
+ expected_sni,
+ "Expected SNI %s but got %s" % (expected_sni, server_name),
+ )
+
# fish the test server back out of the server-side TLS protocol.
return server_tls_protocol.wrappedProtocol
@@ -92,12 +117,57 @@ class MatrixFederationAgentTests(TestCase):
try:
fetch_res = yield fetch_d
defer.returnValue(fetch_res)
+ except Exception as e:
+ logger.info("Fetch of %s failed: %s", uri.decode("ascii"), e)
+ raise
finally:
_check_logcontext(context)
+ def _handle_well_known_connection(
+ self, client_factory, expected_sni, content, response_headers={},
+ ):
+ """Handle an outgoing HTTPs connection: wire it up to a server, check that the
+ request is for a .well-known, and send the response.
+
+ Args:
+ client_factory (IProtocolFactory): outgoing connection
+ expected_sni (bytes): SNI that we expect the outgoing connection to send
+ content (bytes): content to send back as the .well-known
+ Returns:
+ HTTPChannel: server impl
+ """
+ # make the connection for .well-known
+ well_known_server = self._make_connection(
+ client_factory,
+ expected_sni=expected_sni,
+ )
+ # check the .well-known request and send a response
+ self.assertEqual(len(well_known_server.requests), 1)
+ request = well_known_server.requests[0]
+ self._send_well_known_response(request, content, headers=response_headers)
+ return well_known_server
+
+ def _send_well_known_response(self, request, content, headers={}):
+ """Check that an incoming request looks like a valid .well-known request, and
+ send back the response.
+ """
+ self.assertEqual(request.method, b'GET')
+ self.assertEqual(request.path, b'/.well-known/matrix/server')
+ self.assertEqual(
+ request.requestHeaders.getRawHeaders(b'host'),
+ [b'testserv'],
+ )
+ # send back a response
+ for k, v in headers.items():
+ request.setHeader(k, v)
+ request.write(content)
+ request.finish()
+
+ self.reactor.pump((0.1, ))
+
def test_get(self):
"""
- happy-path test of a GET request
+ happy-path test of a GET request with an explicit port
"""
self.reactor.lookups["testserv"] = "1.2.3.4"
test_d = self._make_get_request(b"matrix://testserv:8448/foo/bar")
@@ -113,7 +183,10 @@ class MatrixFederationAgentTests(TestCase):
self.assertEqual(port, 8448)
# make a test server, and wire up the client
- http_server = self._make_connection(client_factory)
+ http_server = self._make_connection(
+ client_factory,
+ expected_sni=b"testserv",
+ )
self.assertEqual(len(http_server.requests), 1)
request = http_server.requests[0]
@@ -150,6 +223,733 @@ class MatrixFederationAgentTests(TestCase):
json = self.successResultOf(treq.json_content(response))
self.assertEqual(json, {"a": 1})
+ def test_get_ip_address(self):
+ """
+ Test the behaviour when the server name contains an explicit IP (with no port)
+ """
+ # there will be a getaddrinfo on the IP
+ self.reactor.lookups["1.2.3.4"] = "1.2.3.4"
+
+ test_d = self._make_get_request(b"matrix://1.2.3.4/foo/bar")
+
+ # Nothing happened yet
+ self.assertNoResult(test_d)
+
+ # Make sure treq is trying to connect
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients[0]
+ self.assertEqual(host, '1.2.3.4')
+ self.assertEqual(port, 8448)
+
+ # make a test server, and wire up the client
+ http_server = self._make_connection(
+ client_factory,
+ expected_sni=None,
+ )
+
+ self.assertEqual(len(http_server.requests), 1)
+ request = http_server.requests[0]
+ self.assertEqual(request.method, b'GET')
+ self.assertEqual(request.path, b'/foo/bar')
+ self.assertEqual(
+ request.requestHeaders.getRawHeaders(b'host'),
+ [b'1.2.3.4'],
+ )
+
+ # finish the request
+ request.finish()
+ self.reactor.pump((0.1,))
+ self.successResultOf(test_d)
+
+ def test_get_ipv6_address(self):
+ """
+ Test the behaviour when the server name contains an explicit IPv6 address
+ (with no port)
+ """
+
+ # there will be a getaddrinfo on the IP
+ self.reactor.lookups["::1"] = "::1"
+
+ test_d = self._make_get_request(b"matrix://[::1]/foo/bar")
+
+ # Nothing happened yet
+ self.assertNoResult(test_d)
+
+ # Make sure treq is trying to connect
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients[0]
+ self.assertEqual(host, '::1')
+ self.assertEqual(port, 8448)
+
+ # make a test server, and wire up the client
+ http_server = self._make_connection(
+ client_factory,
+ expected_sni=None,
+ )
+
+ self.assertEqual(len(http_server.requests), 1)
+ request = http_server.requests[0]
+ self.assertEqual(request.method, b'GET')
+ self.assertEqual(request.path, b'/foo/bar')
+ self.assertEqual(
+ request.requestHeaders.getRawHeaders(b'host'),
+ [b'[::1]'],
+ )
+
+ # finish the request
+ request.finish()
+ self.reactor.pump((0.1,))
+ self.successResultOf(test_d)
+
+ def test_get_ipv6_address_with_port(self):
+ """
+ Test the behaviour when the server name contains an explicit IPv6 address
+ (with explicit port)
+ """
+
+ # there will be a getaddrinfo on the IP
+ self.reactor.lookups["::1"] = "::1"
+
+ test_d = self._make_get_request(b"matrix://[::1]:80/foo/bar")
+
+ # Nothing happened yet
+ self.assertNoResult(test_d)
+
+ # Make sure treq is trying to connect
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients[0]
+ self.assertEqual(host, '::1')
+ self.assertEqual(port, 80)
+
+ # make a test server, and wire up the client
+ http_server = self._make_connection(
+ client_factory,
+ expected_sni=None,
+ )
+
+ self.assertEqual(len(http_server.requests), 1)
+ request = http_server.requests[0]
+ self.assertEqual(request.method, b'GET')
+ self.assertEqual(request.path, b'/foo/bar')
+ self.assertEqual(
+ request.requestHeaders.getRawHeaders(b'host'),
+ [b'[::1]:80'],
+ )
+
+ # finish the request
+ request.finish()
+ self.reactor.pump((0.1,))
+ self.successResultOf(test_d)
+
+ def test_get_no_srv_no_well_known(self):
+ """
+ Test the behaviour when the server name has no port, no SRV, and no well-known
+ """
+
+ self.mock_resolver.resolve_service.side_effect = lambda _: []
+ self.reactor.lookups["testserv"] = "1.2.3.4"
+
+ test_d = self._make_get_request(b"matrix://testserv/foo/bar")
+
+ # Nothing happened yet
+ self.assertNoResult(test_d)
+
+ # No SRV record lookup yet
+ self.mock_resolver.resolve_service.assert_not_called()
+
+ # there should be an attempt to connect on port 443 for the .well-known
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients[0]
+ self.assertEqual(host, '1.2.3.4')
+ self.assertEqual(port, 443)
+
+ # fonx the connection
+ client_factory.clientConnectionFailed(None, Exception("nope"))
+
+ # attemptdelay on the hostnameendpoint is 0.3, so takes that long before the
+ # .well-known request fails.
+ self.reactor.pump((0.4,))
+
+ # now there should be a SRV lookup
+ self.mock_resolver.resolve_service.assert_called_once_with(
+ b"_matrix._tcp.testserv",
+ )
+
+ # we should fall back to a direct connection
+ self.assertEqual(len(clients), 2)
+ (host, port, client_factory, _timeout, _bindAddress) = clients[1]
+ self.assertEqual(host, '1.2.3.4')
+ self.assertEqual(port, 8448)
+
+ # make a test server, and wire up the client
+ http_server = self._make_connection(
+ client_factory,
+ expected_sni=b'testserv',
+ )
+
+ self.assertEqual(len(http_server.requests), 1)
+ request = http_server.requests[0]
+ self.assertEqual(request.method, b'GET')
+ self.assertEqual(request.path, b'/foo/bar')
+ self.assertEqual(
+ request.requestHeaders.getRawHeaders(b'host'),
+ [b'testserv'],
+ )
+
+ # finish the request
+ request.finish()
+ self.reactor.pump((0.1,))
+ self.successResultOf(test_d)
+
+ def test_get_well_known(self):
+ """Test the behaviour when the .well-known delegates elsewhere
+ """
+
+ self.mock_resolver.resolve_service.side_effect = lambda _: []
+ self.reactor.lookups["testserv"] = "1.2.3.4"
+ self.reactor.lookups["target-server"] = "1::f"
+
+ test_d = self._make_get_request(b"matrix://testserv/foo/bar")
+
+ # Nothing happened yet
+ self.assertNoResult(test_d)
+
+ # there should be an attempt to connect on port 443 for the .well-known
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients[0]
+ self.assertEqual(host, '1.2.3.4')
+ self.assertEqual(port, 443)
+
+ self._handle_well_known_connection(
+ client_factory, expected_sni=b"testserv",
+ content=b'{ "m.server": "target-server" }',
+ )
+
+ # there should be a SRV lookup
+ self.mock_resolver.resolve_service.assert_called_once_with(
+ b"_matrix._tcp.target-server",
+ )
+
+ # now we should get a connection to the target server
+ self.assertEqual(len(clients), 2)
+ (host, port, client_factory, _timeout, _bindAddress) = clients[1]
+ self.assertEqual(host, '1::f')
+ self.assertEqual(port, 8448)
+
+ # make a test server, and wire up the client
+ http_server = self._make_connection(
+ client_factory,
+ expected_sni=b'target-server',
+ )
+
+ self.assertEqual(len(http_server.requests), 1)
+ request = http_server.requests[0]
+ self.assertEqual(request.method, b'GET')
+ self.assertEqual(request.path, b'/foo/bar')
+ self.assertEqual(
+ request.requestHeaders.getRawHeaders(b'host'),
+ [b'target-server'],
+ )
+
+ # finish the request
+ request.finish()
+ self.reactor.pump((0.1,))
+ self.successResultOf(test_d)
+
+ self.assertEqual(self.well_known_cache[b"testserv"], b"target-server")
+
+ # check the cache expires
+ self.reactor.pump((25 * 3600,))
+ self.well_known_cache.expire()
+ self.assertNotIn(b"testserv", self.well_known_cache)
+
+ def test_get_well_known_redirect(self):
+ """Test the behaviour when the server name has no port and no SRV record, but
+ the .well-known has a 300 redirect
+ """
+ self.mock_resolver.resolve_service.side_effect = lambda _: []
+ self.reactor.lookups["testserv"] = "1.2.3.4"
+ self.reactor.lookups["target-server"] = "1::f"
+
+ test_d = self._make_get_request(b"matrix://testserv/foo/bar")
+
+ # Nothing happened yet
+ self.assertNoResult(test_d)
+
+ # there should be an attempt to connect on port 443 for the .well-known
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients.pop()
+ self.assertEqual(host, '1.2.3.4')
+ self.assertEqual(port, 443)
+
+ redirect_server = self._make_connection(
+ client_factory,
+ expected_sni=b"testserv",
+ )
+
+ # send a 302 redirect
+ self.assertEqual(len(redirect_server.requests), 1)
+ request = redirect_server.requests[0]
+ request.redirect(b'https://testserv/even_better_known')
+ request.finish()
+
+ self.reactor.pump((0.1, ))
+
+ # now there should be another connection
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients.pop()
+ self.assertEqual(host, '1.2.3.4')
+ self.assertEqual(port, 443)
+
+ well_known_server = self._make_connection(
+ client_factory,
+ expected_sni=b"testserv",
+ )
+
+ self.assertEqual(len(well_known_server.requests), 1, "No request after 302")
+ request = well_known_server.requests[0]
+ self.assertEqual(request.method, b'GET')
+ self.assertEqual(request.path, b'/even_better_known')
+ request.write(b'{ "m.server": "target-server" }')
+ request.finish()
+
+ self.reactor.pump((0.1, ))
+
+ # there should be a SRV lookup
+ self.mock_resolver.resolve_service.assert_called_once_with(
+ b"_matrix._tcp.target-server",
+ )
+
+ # now we should get a connection to the target server
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients[0]
+ self.assertEqual(host, '1::f')
+ self.assertEqual(port, 8448)
+
+ # make a test server, and wire up the client
+ http_server = self._make_connection(
+ client_factory,
+ expected_sni=b'target-server',
+ )
+
+ self.assertEqual(len(http_server.requests), 1)
+ request = http_server.requests[0]
+ self.assertEqual(request.method, b'GET')
+ self.assertEqual(request.path, b'/foo/bar')
+ self.assertEqual(
+ request.requestHeaders.getRawHeaders(b'host'),
+ [b'target-server'],
+ )
+
+ # finish the request
+ request.finish()
+ self.reactor.pump((0.1,))
+ self.successResultOf(test_d)
+
+ self.assertEqual(self.well_known_cache[b"testserv"], b"target-server")
+
+ # check the cache expires
+ self.reactor.pump((25 * 3600,))
+ self.well_known_cache.expire()
+ self.assertNotIn(b"testserv", self.well_known_cache)
+
+ def test_get_invalid_well_known(self):
+ """
+ Test the behaviour when the server name has an *invalid* well-known (and no SRV)
+ """
+
+ self.mock_resolver.resolve_service.side_effect = lambda _: []
+ self.reactor.lookups["testserv"] = "1.2.3.4"
+
+ test_d = self._make_get_request(b"matrix://testserv/foo/bar")
+
+ # Nothing happened yet
+ self.assertNoResult(test_d)
+
+ # No SRV record lookup yet
+ self.mock_resolver.resolve_service.assert_not_called()
+
+ # there should be an attempt to connect on port 443 for the .well-known
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients.pop()
+ self.assertEqual(host, '1.2.3.4')
+ self.assertEqual(port, 443)
+
+ self._handle_well_known_connection(
+ client_factory, expected_sni=b"testserv", content=b'NOT JSON',
+ )
+
+ # now there should be a SRV lookup
+ self.mock_resolver.resolve_service.assert_called_once_with(
+ b"_matrix._tcp.testserv",
+ )
+
+ # we should fall back to a direct connection
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients.pop()
+ self.assertEqual(host, '1.2.3.4')
+ self.assertEqual(port, 8448)
+
+ # make a test server, and wire up the client
+ http_server = self._make_connection(
+ client_factory,
+ expected_sni=b'testserv',
+ )
+
+ self.assertEqual(len(http_server.requests), 1)
+ request = http_server.requests[0]
+ self.assertEqual(request.method, b'GET')
+ self.assertEqual(request.path, b'/foo/bar')
+ self.assertEqual(
+ request.requestHeaders.getRawHeaders(b'host'),
+ [b'testserv'],
+ )
+
+ # finish the request
+ request.finish()
+ self.reactor.pump((0.1,))
+ self.successResultOf(test_d)
+
+ def test_get_hostname_srv(self):
+ """
+ Test the behaviour when there is a single SRV record
+ """
+ self.mock_resolver.resolve_service.side_effect = lambda _: [
+ Server(host=b"srvtarget", port=8443)
+ ]
+ self.reactor.lookups["srvtarget"] = "1.2.3.4"
+
+ test_d = self._make_get_request(b"matrix://testserv/foo/bar")
+
+ # Nothing happened yet
+ self.assertNoResult(test_d)
+
+ # the request for a .well-known will have failed with a DNS lookup error.
+ self.mock_resolver.resolve_service.assert_called_once_with(
+ b"_matrix._tcp.testserv",
+ )
+
+ # Make sure treq is trying to connect
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients[0]
+ self.assertEqual(host, '1.2.3.4')
+ self.assertEqual(port, 8443)
+
+ # make a test server, and wire up the client
+ http_server = self._make_connection(
+ client_factory,
+ expected_sni=b'testserv',
+ )
+
+ self.assertEqual(len(http_server.requests), 1)
+ request = http_server.requests[0]
+ self.assertEqual(request.method, b'GET')
+ self.assertEqual(request.path, b'/foo/bar')
+ self.assertEqual(
+ request.requestHeaders.getRawHeaders(b'host'),
+ [b'testserv'],
+ )
+
+ # finish the request
+ request.finish()
+ self.reactor.pump((0.1,))
+ self.successResultOf(test_d)
+
+ def test_get_well_known_srv(self):
+ """Test the behaviour when the .well-known redirects to a place where there
+ is a SRV.
+ """
+ self.reactor.lookups["testserv"] = "1.2.3.4"
+ self.reactor.lookups["srvtarget"] = "5.6.7.8"
+
+ test_d = self._make_get_request(b"matrix://testserv/foo/bar")
+
+ # Nothing happened yet
+ self.assertNoResult(test_d)
+
+ # there should be an attempt to connect on port 443 for the .well-known
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients[0]
+ self.assertEqual(host, '1.2.3.4')
+ self.assertEqual(port, 443)
+
+ self.mock_resolver.resolve_service.side_effect = lambda _: [
+ Server(host=b"srvtarget", port=8443),
+ ]
+
+ self._handle_well_known_connection(
+ client_factory, expected_sni=b"testserv",
+ content=b'{ "m.server": "target-server" }',
+ )
+
+ # there should be a SRV lookup
+ self.mock_resolver.resolve_service.assert_called_once_with(
+ b"_matrix._tcp.target-server",
+ )
+
+ # now we should get a connection to the target of the SRV record
+ self.assertEqual(len(clients), 2)
+ (host, port, client_factory, _timeout, _bindAddress) = clients[1]
+ self.assertEqual(host, '5.6.7.8')
+ self.assertEqual(port, 8443)
+
+ # make a test server, and wire up the client
+ http_server = self._make_connection(
+ client_factory,
+ expected_sni=b'target-server',
+ )
+
+ self.assertEqual(len(http_server.requests), 1)
+ request = http_server.requests[0]
+ self.assertEqual(request.method, b'GET')
+ self.assertEqual(request.path, b'/foo/bar')
+ self.assertEqual(
+ request.requestHeaders.getRawHeaders(b'host'),
+ [b'target-server'],
+ )
+
+ # finish the request
+ request.finish()
+ self.reactor.pump((0.1,))
+ self.successResultOf(test_d)
+
+ def test_idna_servername(self):
+ """test the behaviour when the server name has idna chars in"""
+
+ self.mock_resolver.resolve_service.side_effect = lambda _: []
+
+ # the resolver is always called with the IDNA hostname as a native string.
+ self.reactor.lookups["xn--bcher-kva.com"] = "1.2.3.4"
+
+ # this is idna for bücher.com
+ test_d = self._make_get_request(b"matrix://xn--bcher-kva.com/foo/bar")
+
+ # Nothing happened yet
+ self.assertNoResult(test_d)
+
+ # No SRV record lookup yet
+ self.mock_resolver.resolve_service.assert_not_called()
+
+ # there should be an attempt to connect on port 443 for the .well-known
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients[0]
+ self.assertEqual(host, '1.2.3.4')
+ self.assertEqual(port, 443)
+
+ # fonx the connection
+ client_factory.clientConnectionFailed(None, Exception("nope"))
+
+ # attemptdelay on the hostnameendpoint is 0.3, so takes that long before the
+ # .well-known request fails.
+ self.reactor.pump((0.4,))
+
+ # now there should have been a SRV lookup
+ self.mock_resolver.resolve_service.assert_called_once_with(
+ b"_matrix._tcp.xn--bcher-kva.com",
+ )
+
+ # We should fall back to port 8448
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 2)
+ (host, port, client_factory, _timeout, _bindAddress) = clients[1]
+ self.assertEqual(host, '1.2.3.4')
+ self.assertEqual(port, 8448)
+
+ # make a test server, and wire up the client
+ http_server = self._make_connection(
+ client_factory,
+ expected_sni=b'xn--bcher-kva.com',
+ )
+
+ self.assertEqual(len(http_server.requests), 1)
+ request = http_server.requests[0]
+ self.assertEqual(request.method, b'GET')
+ self.assertEqual(request.path, b'/foo/bar')
+ self.assertEqual(
+ request.requestHeaders.getRawHeaders(b'host'),
+ [b'xn--bcher-kva.com'],
+ )
+
+ # finish the request
+ request.finish()
+ self.reactor.pump((0.1,))
+ self.successResultOf(test_d)
+
+ def test_idna_srv_target(self):
+ """test the behaviour when the target of a SRV record has idna chars"""
+
+ self.mock_resolver.resolve_service.side_effect = lambda _: [
+ Server(host=b"xn--trget-3qa.com", port=8443) # târget.com
+ ]
+ self.reactor.lookups["xn--trget-3qa.com"] = "1.2.3.4"
+
+ test_d = self._make_get_request(b"matrix://xn--bcher-kva.com/foo/bar")
+
+ # Nothing happened yet
+ self.assertNoResult(test_d)
+
+ self.mock_resolver.resolve_service.assert_called_once_with(
+ b"_matrix._tcp.xn--bcher-kva.com",
+ )
+
+ # Make sure treq is trying to connect
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients[0]
+ self.assertEqual(host, '1.2.3.4')
+ self.assertEqual(port, 8443)
+
+ # make a test server, and wire up the client
+ http_server = self._make_connection(
+ client_factory,
+ expected_sni=b'xn--bcher-kva.com',
+ )
+
+ self.assertEqual(len(http_server.requests), 1)
+ request = http_server.requests[0]
+ self.assertEqual(request.method, b'GET')
+ self.assertEqual(request.path, b'/foo/bar')
+ self.assertEqual(
+ request.requestHeaders.getRawHeaders(b'host'),
+ [b'xn--bcher-kva.com'],
+ )
+
+ # finish the request
+ request.finish()
+ self.reactor.pump((0.1,))
+ self.successResultOf(test_d)
+
+ @defer.inlineCallbacks
+ def do_get_well_known(self, serv):
+ try:
+ result = yield self.agent._get_well_known(serv)
+ logger.info("Result from well-known fetch: %s", result)
+ except Exception as e:
+ logger.warning("Error fetching well-known: %s", e)
+ raise
+ defer.returnValue(result)
+
+ def test_well_known_cache(self):
+ self.reactor.lookups["testserv"] = "1.2.3.4"
+
+ fetch_d = self.do_get_well_known(b'testserv')
+
+ # there should be an attempt to connect on port 443 for the .well-known
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
+ self.assertEqual(host, '1.2.3.4')
+ self.assertEqual(port, 443)
+
+ well_known_server = self._handle_well_known_connection(
+ client_factory,
+ expected_sni=b"testserv",
+ response_headers={b'Cache-Control': b'max-age=10'},
+ content=b'{ "m.server": "target-server" }',
+ )
+
+ r = self.successResultOf(fetch_d)
+ self.assertEqual(r, b'target-server')
+
+ # close the tcp connection
+ well_known_server.loseConnection()
+
+ # repeat the request: it should hit the cache
+ fetch_d = self.do_get_well_known(b'testserv')
+ r = self.successResultOf(fetch_d)
+ self.assertEqual(r, b'target-server')
+
+ # expire the cache
+ self.reactor.pump((10.0,))
+
+ # now it should connect again
+ fetch_d = self.do_get_well_known(b'testserv')
+
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
+ self.assertEqual(host, '1.2.3.4')
+ self.assertEqual(port, 443)
+
+ self._handle_well_known_connection(
+ client_factory,
+ expected_sni=b"testserv",
+ content=b'{ "m.server": "other-server" }',
+ )
+
+ r = self.successResultOf(fetch_d)
+ self.assertEqual(r, b'other-server')
+
+
+class TestCachePeriodFromHeaders(TestCase):
+ def test_cache_control(self):
+ # uppercase
+ self.assertEqual(
+ _cache_period_from_headers(
+ Headers({b'Cache-Control': [b'foo, Max-Age = 100, bar']}),
+ ), 100,
+ )
+
+ # missing value
+ self.assertIsNone(_cache_period_from_headers(
+ Headers({b'Cache-Control': [b'max-age=, bar']}),
+ ))
+
+ # hackernews: bogus due to semicolon
+ self.assertIsNone(_cache_period_from_headers(
+ Headers({b'Cache-Control': [b'private; max-age=0']}),
+ ))
+
+ # github
+ self.assertEqual(
+ _cache_period_from_headers(
+ Headers({b'Cache-Control': [b'max-age=0, private, must-revalidate']}),
+ ), 0,
+ )
+
+ # google
+ self.assertEqual(
+ _cache_period_from_headers(
+ Headers({b'cache-control': [b'private, max-age=0']}),
+ ), 0,
+ )
+
+ def test_expires(self):
+ self.assertEqual(
+ _cache_period_from_headers(
+ Headers({b'Expires': [b'Wed, 30 Jan 2019 07:35:33 GMT']}),
+ time_now=lambda: 1548833700
+ ), 33,
+ )
+
+ # cache-control overrides expires
+ self.assertEqual(
+ _cache_period_from_headers(
+ Headers({
+ b'cache-control': [b'max-age=10'],
+ b'Expires': [b'Wed, 30 Jan 2019 07:35:33 GMT']
+ }),
+ time_now=lambda: 1548833700
+ ), 10,
+ )
+
+ # invalid expires means immediate expiry
+ self.assertEqual(
+ _cache_period_from_headers(
+ Headers({b'Expires': [b'0']}),
+ ), 0,
+ )
+
def _check_logcontext(context):
current = LoggingContext.current_context()
@@ -181,3 +981,11 @@ def _build_test_server():
def _log_request(request):
"""Implements Factory.log, which is expected by Request.finish"""
logger.info("Completed request %s", request)
+
+
+@implementer(IPolicyForHTTPS)
+class TrustingTLSPolicyForHTTPS(object):
+ """An IPolicyForHTTPS which doesn't do any certificate verification"""
+ def creatorForNetloc(self, hostname, port):
+ certificateOptions = OpenSSLCertificateOptions()
+ return ClientTLSOptions(hostname, certificateOptions.getContext())
diff --git a/tests/http/server.pem b/tests/http/server.pem
new file mode 100644
index 0000000000..0584cf1a80
--- /dev/null
+++ b/tests/http/server.pem
@@ -0,0 +1,81 @@
+-----BEGIN PRIVATE KEY-----
+MIIJQgIBADANBgkqhkiG9w0BAQEFAASCCSwwggkoAgEAAoICAQCgF43/3lAgJ+p0
+x7Rn8UcL8a4fctvdkikvZrCngw96LkB34Evfq8YGWlOVjU+f9naUJLAKMatmAfEN
+r+rMX4VOXmpTwuu6iLtqwreUrRFMESyrmvQxa15p+y85gkY0CFmXMblv6ORbxHTG
+ncBGwST4WK4Poewcgt6jcISFCESTUKu1zc3cw1ANIDRyDLB5K44KwIe36dcKckyN
+Kdtv4BJ+3fcIZIkPJH62zqCypgFF1oiFt40uJzClxgHdJZlKYpgkfnDTckw4Y/Mx
+9k8BbE310KAzUNMV9H7I1eEolzrNr66FQj1eN64X/dqO8lTbwCqAd4diCT4sIUk0
+0SVsAUjNd3g8j651hx+Qb1t8fuOjrny8dmeMxtUgIBHoQcpcj76R55Fs7KZ9uar0
+8OFTyGIze51W1jG2K/7/5M1zxIqrA+7lsXu5OR81s7I+Ng/UUAhiHA/z+42/aiNa
+qEuk6tqj3rHfLctnCbtZ+JrRNqSSwEi8F0lMA021ivEd2eJV+284OyJjhXOmKHrX
+QADHrmS7Sh4syTZvRNm9n+qWID0KdDr2Sji/KnS3Enp44HDQ4xriT6/xhwEGsyuX
+oH5aAkdLznulbWkHBbyx1SUQSTLpOqzaioF9m1vRrLsFvrkrY3D253mPJ5eU9HM/
+dilduFcUgj4rz+6cdXUAh+KK/v95zwIDAQABAoICAFG5tJPaOa0ws0/KYx5s3YgL
+aIhFalhCNSQtmCDrlwsYcXDA3/rfBchYdDL0YKGYgBBAal3J3WXFt/j0xThvyu2m
+5UC9UPl4s7RckrsjXqEmY1d3UxGnbhtMT19cUdpeKN42VCP9EBaIw9Rg07dLAkSF
+gNYaIx6q8F0fI4eGIPvTQtUcqur4CfWpaxyNvckdovV6M85/YXfDwbCOnacPDGIX
+jfSK3i0MxGMuOHr6o8uzKR6aBUh6WStHWcw7VXXTvzdiFNbckmx3Gb93rf1b/LBw
+QFfx+tBKcC62gKroCOzXso/0sL9YTVeSD/DJZOiJwSiz3Dj/3u1IUMbVvfTU8wSi
+CYS7Z+jHxwSOCSSNTXm1wO/MtDsNKbI1+R0cohr/J9pOMQvrVh1+2zSDOFvXAQ1S
+yvjn+uqdmijRoV2VEGVHd+34C+ci7eJGAhL/f92PohuuFR2shUETgGWzpACZSJwg
+j1d90Hs81hj07vWRb+xCeDh00vimQngz9AD8vYvv/S4mqRGQ6TZdfjLoUwSTg0JD
+6sQgRXX026gQhLhn687vLKZfHwzQPZkpQdxOR0dTZ/ho/RyGGRJXH4kN4cA2tPr+
+AKYQ29YXGlEzGG7OqikaZcprNWG6UFgEpuXyBxCgp9r4ladZo3J+1Rhgus8ZYatd
+uO98q3WEBmP6CZ2n32mBAoIBAQDS/c/ybFTos0YpGHakwdmSfj5OOQJto2y8ywfG
+qDHwO0ebcpNnS1+MA+7XbKUQb/3Iq7iJljkkzJG2DIJ6rpKynYts1ViYpM7M/t0T
+W3V1gvUcUL62iqkgws4pnpWmubFkqV31cPSHcfIIclnzeQ1aOEGsGHNAvhty0ciC
+DnkJACbqApvopFLOR5f6UFTtKExE+hDH0WqgpsCAKJ1L4g6pBzZatI32/CN9JEVU
+tDbxLV75hHlFFjUrG7nT1rPyr/gI8Ceh9/2xeXPfjJUR0PrG3U1nwLqUCZkvFzO6
+XpN2+A+/v4v5xqMjKDKDFy1oq6SCMomwv/viw6wl/84TMbolAoIBAQDCPiMecnR8
+REik6tqVzQO/uSe9ZHjz6J15t5xdwaI6HpSwLlIkQPkLTjyXtFpemK5DOYRxrJvQ
+remfrZrN2qtLlb/DKpuGPWRsPOvWCrSuNEp48ivUehtclljrzxAFfy0sM+fWeJ48
+nTnR+td9KNhjNtZixzWdAy/mE+jdaMsXVnk66L73Uz+2WsnvVMW2R6cpCR0F2eP/
+B4zDWRqlT2w47sePAB81mFYSQLvPC6Xcgg1OqMubfiizJI49c8DO6Jt+FFYdsxhd
+kG52Eqa/Net6rN3ueiS6yXL5TU3Y6g96bPA2KyNCypucGcddcBfqaiVx/o4AH6yT
+NrdsrYtyvk/jAoIBAQDHUwKVeeRJJbvdbQAArCV4MI155n+1xhMe1AuXkCQFWGtQ
+nlBE4D72jmyf1UKnIbW2Uwv15xY6/ouVWYIWlj9+QDmMaozVP7Uiko+WDuwLRNl8
+k4dn+dzHV2HejbPBG2JLv3lFOx23q1zEwArcaXrExaq9Ayg2fKJ/uVHcFAIiD6Oz
+pR1XDY4w1A/uaN+iYFSVQUyDCQLbnEz1hej73CaPZoHh9Pq83vxD5/UbjVjuRTeZ
+L55FNzKpc/r89rNvTPBcuUwnxplDhYKDKVNWzn9rSXwrzTY2Tk8J3rh+k4RqevSd
+6D47jH1n5Dy7/TRn0ueKHGZZtTUnyEUkbOJo3ayFAoIBAHKDyZaQqaX9Z8p6fwWj
+yVsFoK0ih8BcWkLBAdmwZ6DWGJjJpjmjaG/G3ygc9s4gO1R8m12dAnuDnGE8KzDD
+gwtbrKM2Alyg4wyA2hTlWOH/CAzH0RlCJ9Fs/d1/xJVJBeuyajLiB3/6vXTS6qnq
+I7BSSxAPG8eGcn21LSsjNeB7ZZtaTgNnu/8ZBUYo9yrgkWc67TZe3/ChldYxOOlO
+qqHh/BqNWtjxB4VZTp/g4RbgQVInZ2ozdXEv0v/dt0UEk29ANAjsZif7F3RayJ2f
+/0TilzCaJ/9K9pKNhaClVRy7Dt8QjYg6BIWCGSw4ApF7pLnQ9gySn95mersCkVzD
+YDsCggEAb0E/TORjQhKfNQvahyLfQFm151e+HIoqBqa4WFyfFxe/IJUaLH/JSSFw
+VohbQqPdCmaAeuQ8ERL564DdkcY5BgKcax79fLLCOYP5bT11aQx6uFpfl2Dcm6Z9
+QdCRI4jzPftsd5fxLNH1XtGyC4t6vTic4Pji2O71WgWzx0j5v4aeDY4sZQeFxqCV
+/q7Ee8hem1Rn5RFHu14FV45RS4LAWl6wvf5pQtneSKzx8YL0GZIRRytOzdEfnGKr
+FeUlAj5uL+5/p0ZEgM7gPsEBwdm8scF79qSUn8UWSoXNeIauF9D4BDg8RZcFFxka
+KILVFsq3cQC+bEnoM4eVbjEQkGs1RQ==
+-----END PRIVATE KEY-----
+-----BEGIN CERTIFICATE-----
+MIIE/jCCAuagAwIBAgIJANFtVaGvJWZlMA0GCSqGSIb3DQEBCwUAMBMxETAPBgNV
+BAMMCHRlc3RzZXJ2MCAXDTE5MDEyNzIyMDIzNloYDzIxMTkwMTAzMjIwMjM2WjAT
+MREwDwYDVQQDDAh0ZXN0c2VydjCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoC
+ggIBAKAXjf/eUCAn6nTHtGfxRwvxrh9y292SKS9msKeDD3ouQHfgS9+rxgZaU5WN
+T5/2dpQksAoxq2YB8Q2v6sxfhU5ealPC67qIu2rCt5StEUwRLKua9DFrXmn7LzmC
+RjQIWZcxuW/o5FvEdMadwEbBJPhYrg+h7ByC3qNwhIUIRJNQq7XNzdzDUA0gNHIM
+sHkrjgrAh7fp1wpyTI0p22/gEn7d9whkiQ8kfrbOoLKmAUXWiIW3jS4nMKXGAd0l
+mUpimCR+cNNyTDhj8zH2TwFsTfXQoDNQ0xX0fsjV4SiXOs2vroVCPV43rhf92o7y
+VNvAKoB3h2IJPiwhSTTRJWwBSM13eDyPrnWHH5BvW3x+46OufLx2Z4zG1SAgEehB
+ylyPvpHnkWzspn25qvTw4VPIYjN7nVbWMbYr/v/kzXPEiqsD7uWxe7k5HzWzsj42
+D9RQCGIcD/P7jb9qI1qoS6Tq2qPesd8ty2cJu1n4mtE2pJLASLwXSUwDTbWK8R3Z
+4lX7bzg7ImOFc6YoetdAAMeuZLtKHizJNm9E2b2f6pYgPQp0OvZKOL8qdLcSenjg
+cNDjGuJPr/GHAQazK5egfloCR0vOe6VtaQcFvLHVJRBJMuk6rNqKgX2bW9GsuwW+
+uStjcPbneY8nl5T0cz92KV24VxSCPivP7px1dQCH4or+/3nPAgMBAAGjUzBRMB0G
+A1UdDgQWBBQcQZpzLzTk5KdS/Iz7sGCV7gTd/zAfBgNVHSMEGDAWgBQcQZpzLzTk
+5KdS/Iz7sGCV7gTd/zAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3DQEBCwUAA4IC
+AQAr/Pgha57jqYsDDX1LyRrVdqoVBpLBeB7x/p9dKYm7S6tBTDFNMZ0SZyQP8VEG
+7UoC9/OQ9nCdEMoR7ZKpQsmipwcIqpXHS6l4YOkf5EEq5jpMgvlEesHmBJJeJew/
+FEPDl1bl8d0tSrmWaL3qepmwzA+2lwAAouWk2n+rLiP8CZ3jZeoTXFqYYrUlEqO9
+fHMvuWqTV4KCSyNY+GWCrnHetulgKHlg+W2J1mZnrCKcBhWf9C2DesTJO+JldIeM
+ornTFquSt21hZi+k3aySuMn2N3MWiNL8XsZVsAnPSs0zA+2fxjJkShls8Gc7cCvd
+a6XrNC+PY6pONguo7rEU4HiwbvnawSTngFFglmH/ImdA/HkaAekW6o82aI8/UxFx
+V9fFMO3iKDQdOrg77hI1bx9RlzKNZZinE2/Pu26fWd5d2zqDWCjl8ykGQRAfXgYN
+H3BjgyXLl+ao5/pOUYYtzm3ruTXTgRcy5hhL6hVTYhSrf9vYh4LNIeXNKnZ78tyG
+TX77/kU2qXhBGCFEUUMqUNV/+ITir2lmoxVjknt19M07aGr8C7SgYt6Rs+qDpMiy
+JurgvRh8LpVq4pHx1efxzxCFmo58DMrG40I0+CF3y/niNpOb1gp2wAqByRiORkds
+f0ytW6qZ0TpHbD6gOtQLYDnhx3ISuX+QYSekVwQUpffeWQ==
+-----END CERTIFICATE-----
diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py
index d37f8f9981..b03b37affe 100644
--- a/tests/http/test_fedclient.py
+++ b/tests/http/test_fedclient.py
@@ -43,14 +43,11 @@ def check_logcontext(context):
class FederationClientTests(HomeserverTestCase):
def make_homeserver(self, reactor, clock):
-
hs = self.setup_test_homeserver(reactor=reactor, clock=clock)
- hs.tls_client_options_factory = None
return hs
def prepare(self, reactor, clock, homeserver):
-
- self.cl = MatrixFederationHttpClient(self.hs)
+ self.cl = MatrixFederationHttpClient(self.hs, None)
self.reactor.lookups["testserv"] = "1.2.3.4"
def test_client_get(self):
@@ -95,6 +92,7 @@ class FederationClientTests(HomeserverTestCase):
# that should have made it send the request to the transport
self.assertRegex(transport.value(), b"^GET /foo/bar")
+ self.assertRegex(transport.value(), b"Host: testserv:8008")
# Deferred is still without a result
self.assertNoResult(test_d)
diff --git a/tests/rest/client/v2_alpha/test_capabilities.py b/tests/rest/client/v2_alpha/test_capabilities.py
new file mode 100644
index 0000000000..d3d43970fb
--- /dev/null
+++ b/tests/rest/client/v2_alpha/test_capabilities.py
@@ -0,0 +1,78 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.api.constants import DEFAULT_ROOM_VERSION, KNOWN_ROOM_VERSIONS
+from synapse.rest.client.v1 import admin, login
+from synapse.rest.client.v2_alpha import capabilities
+
+from tests import unittest
+
+
+class CapabilitiesTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ admin.register_servlets,
+ capabilities.register_servlets,
+ login.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ self.url = b"/_matrix/client/r0/capabilities"
+ hs = self.setup_test_homeserver()
+ self.store = hs.get_datastore()
+ return hs
+
+ def test_check_auth_required(self):
+ request, channel = self.make_request("GET", self.url)
+ self.render(request)
+
+ self.assertEqual(channel.code, 401)
+
+ def test_get_room_version_capabilities(self):
+ self.register_user("user", "pass")
+ access_token = self.login("user", "pass")
+
+ request, channel = self.make_request("GET", self.url, access_token=access_token)
+ self.render(request)
+ capabilities = channel.json_body['capabilities']
+
+ self.assertEqual(channel.code, 200)
+ for room_version in capabilities['m.room_versions']['available'].keys():
+ self.assertTrue(room_version in KNOWN_ROOM_VERSIONS, "" + room_version)
+ self.assertEqual(
+ DEFAULT_ROOM_VERSION, capabilities['m.room_versions']['default']
+ )
+
+ def test_get_change_password_capabilities(self):
+ localpart = "user"
+ password = "pass"
+ user = self.register_user(localpart, password)
+ access_token = self.login(user, password)
+
+ request, channel = self.make_request("GET", self.url, access_token=access_token)
+ self.render(request)
+ capabilities = channel.json_body['capabilities']
+
+ self.assertEqual(channel.code, 200)
+
+ # Test case where password is handled outside of Synapse
+ self.assertTrue(capabilities['m.change_password']['enabled'])
+ self.get_success(self.store.user_set_password_hash(user, None))
+ request, channel = self.make_request("GET", self.url, access_token=access_token)
+ self.render(request)
+ capabilities = channel.json_body['capabilities']
+
+ self.assertEqual(channel.code, 200)
+ self.assertFalse(capabilities['m.change_password']['enabled'])
diff --git a/tests/server.py b/tests/server.py
index ed2a046ae6..fc1e76d146 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -8,11 +8,10 @@ import attr
from zope.interface import implementer
from twisted.internet import address, threads, udp
-from twisted.internet._resolver import HostResolution
-from twisted.internet.address import IPv4Address
-from twisted.internet.defer import Deferred
+from twisted.internet._resolver import SimpleResolverComplexifier
+from twisted.internet.defer import Deferred, fail, succeed
from twisted.internet.error import DNSLookupError
-from twisted.internet.interfaces import IReactorPluggableNameResolver
+from twisted.internet.interfaces import IReactorPluggableNameResolver, IResolverSimple
from twisted.python.failure import Failure
from twisted.test.proto_helpers import MemoryReactorClock
from twisted.web.http import unquote
@@ -227,30 +226,16 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
def __init__(self):
self._udp = []
- self.lookups = {}
-
- class Resolver(object):
- def resolveHostName(
- _self,
- resolutionReceiver,
- hostName,
- portNumber=0,
- addressTypes=None,
- transportSemantics='TCP',
- ):
-
- resolution = HostResolution(hostName)
- resolutionReceiver.resolutionBegan(resolution)
- if hostName not in self.lookups:
- raise DNSLookupError("OH NO")
-
- resolutionReceiver.addressResolved(
- IPv4Address('TCP', self.lookups[hostName], portNumber)
- )
- resolutionReceiver.resolutionComplete()
- return resolution
-
- self.nameResolver = Resolver()
+ lookups = self.lookups = {}
+
+ @implementer(IResolverSimple)
+ class FakeResolver(object):
+ def getHostByName(self, name, timeout=None):
+ if name not in lookups:
+ return fail(DNSLookupError("OH NO: unknown %s" % (name, )))
+ return succeed(lookups[name])
+
+ self.nameResolver = SimpleResolverComplexifier(FakeResolver())
super(ThreadedMemoryReactorClock, self).__init__()
def listenUDP(self, port, protocol, interface='', maxPacketSize=8196):
@@ -369,7 +354,13 @@ class FakeTransport(object):
:type: twisted.internet.interfaces.IReactorTime
"""
+ _protocol = attr.ib(default=None)
+ """The Protocol which is producing data for this transport. Optional, but if set
+ will get called back for connectionLost() notifications etc.
+ """
+
disconnecting = False
+ disconnected = False
buffer = attr.ib(default=b'')
producer = attr.ib(default=None)
@@ -379,11 +370,17 @@ class FakeTransport(object):
def getHost(self):
return None
- def loseConnection(self):
- self.disconnecting = True
+ def loseConnection(self, reason=None):
+ if not self.disconnecting:
+ logger.info("FakeTransport: loseConnection(%s)", reason)
+ self.disconnecting = True
+ if self._protocol:
+ self._protocol.connectionLost(reason)
+ self.disconnected = True
def abortConnection(self):
- self.disconnecting = True
+ logger.info("FakeTransport: abortConnection()")
+ self.loseConnection()
def pauseProducing(self):
if not self.producer:
@@ -422,9 +419,16 @@ class FakeTransport(object):
# TLSMemoryBIOProtocol
return
+ if self.disconnected:
+ return
+ logger.info("%s->%s: %s", self._protocol, self.other, self.buffer)
+
if getattr(self.other, "transport") is not None:
- self.other.dataReceived(self.buffer)
- self.buffer = b""
+ try:
+ self.other.dataReceived(self.buffer)
+ self.buffer = b""
+ except Exception as e:
+ logger.warning("Exception writing to protocol: %s", e)
return
self._reactor.callLater(0.0, _write)
diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py
index 2e073a3afc..9a5c816927 100644
--- a/tests/state/test_v2.py
+++ b/tests/state/test_v2.py
@@ -19,7 +19,7 @@ from six.moves import zip
import attr
-from synapse.api.constants import EventTypes, JoinRules, Membership
+from synapse.api.constants import EventTypes, JoinRules, Membership, RoomVersions
from synapse.event_auth import auth_types_for_event
from synapse.events import FrozenEvent
from synapse.state.v2 import lexicographical_topological_sort, resolve_events_with_store
@@ -539,6 +539,7 @@ class StateTestCase(unittest.TestCase):
state_before = dict(state_at_event[prev_events[0]])
else:
state_d = resolve_events_with_store(
+ RoomVersions.V2,
[state_at_event[n] for n in prev_events],
event_map=event_map,
state_res_store=TestStateResolutionStore(event_map),
@@ -685,6 +686,7 @@ class SimpleParamStateTestCase(unittest.TestCase):
# Test that we correctly handle passing `None` as the event_map
state_d = resolve_events_with_store(
+ RoomVersions.V2,
[self.state_at_bob, self.state_at_charlie],
event_map=None,
state_res_store=TestStateResolutionStore(self.event_map),
diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py
index 81403727c5..5568a607c7 100644
--- a/tests/storage/test_background_update.py
+++ b/tests/storage/test_background_update.py
@@ -11,7 +11,7 @@ class BackgroundUpdateTestCase(unittest.TestCase):
def setUp(self):
hs = yield setup_test_homeserver(
self.addCleanup
- ) # type: synapse.server.HomeServer
+ )
self.store = hs.get_datastore()
self.clock = hs.get_clock()
diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index 829f47d2e8..f18db8c384 100644
--- a/tests/storage/test_base.py
+++ b/tests/storage/test_base.py
@@ -49,13 +49,17 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.db_pool.runWithConnection = runWithConnection
config = Mock()
+ config._disable_native_upserts = True
config.event_cache_size = 1
config.database_config = {"name": "sqlite3"}
+ engine = create_engine(config.database_config)
+ fake_engine = Mock(wraps=engine)
+ fake_engine.can_native_upsert = False
hs = TestHomeServer(
"test",
db_pool=self.db_pool,
config=config,
- database_engine=create_engine(config.database_config),
+ database_engine=fake_engine,
)
self.datastore = SQLBaseStore(None, hs)
diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py
index b83f7336d3..11fb8c0c19 100644
--- a/tests/storage/test_end_to_end_keys.py
+++ b/tests/storage/test_end_to_end_keys.py
@@ -20,9 +20,6 @@ import tests.utils
class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
- def __init__(self, *args, **kwargs):
- super(EndToEndKeyStoreTestCase, self).__init__(*args, **kwargs)
- self.store = None # type: synapse.storage.DataStore
@defer.inlineCallbacks
def setUp(self):
diff --git a/tests/storage/test_keys.py b/tests/storage/test_keys.py
index 47f4a8ceac..0d2dc9f325 100644
--- a/tests/storage/test_keys.py
+++ b/tests/storage/test_keys.py
@@ -22,9 +22,6 @@ import tests.utils
class KeyStoreTestCase(tests.unittest.TestCase):
- def __init__(self, *args, **kwargs):
- super(KeyStoreTestCase, self).__init__(*args, **kwargs)
- self.store = None # type: synapse.storage.keys.KeyStore
@defer.inlineCallbacks
def setUp(self):
diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py
index 9605301b59..d6569a82bb 100644
--- a/tests/storage/test_monthly_active_users.py
+++ b/tests/storage/test_monthly_active_users.py
@@ -18,12 +18,12 @@ from twisted.internet import defer
from synapse.api.constants import UserTypes
-from tests.unittest import HomeserverTestCase
+from tests import unittest
FORTY_DAYS = 40 * 24 * 60 * 60
-class MonthlyActiveUsersTestCase(HomeserverTestCase):
+class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver()
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index 02bf975fbf..3957561b1e 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -18,7 +18,7 @@ from mock import Mock
from twisted.internet import defer
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import EventTypes, Membership, RoomVersions
from synapse.types import RoomID, UserID
from tests import unittest
@@ -52,6 +52,7 @@ class RedactionTestCase(unittest.TestCase):
content = {"membership": membership}
content.update(extra_content)
builder = self.event_builder_factory.new(
+ RoomVersions.V1,
{
"type": EventTypes.Member,
"sender": user.to_string(),
@@ -74,6 +75,7 @@ class RedactionTestCase(unittest.TestCase):
self.depth += 1
builder = self.event_builder_factory.new(
+ RoomVersions.V1,
{
"type": EventTypes.Message,
"sender": user.to_string(),
@@ -94,6 +96,7 @@ class RedactionTestCase(unittest.TestCase):
@defer.inlineCallbacks
def inject_redaction(self, room, event_id, user, reason):
builder = self.event_builder_factory.new(
+ RoomVersions.V1,
{
"type": EventTypes.Redaction,
"sender": user.to_string(),
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index 978c66133d..7fa2f4fd70 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -18,7 +18,7 @@ from mock import Mock
from twisted.internet import defer
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import EventTypes, Membership, RoomVersions
from synapse.types import RoomID, UserID
from tests import unittest
@@ -50,6 +50,7 @@ class RoomMemberStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def inject_room_member(self, room, user, membership, replaces_state=None):
builder = self.event_builder_factory.new(
+ RoomVersions.V1,
{
"type": EventTypes.Member,
"sender": user.to_string(),
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index 086a39d834..99cd3e09eb 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -17,7 +17,7 @@ import logging
from twisted.internet import defer
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import EventTypes, Membership, RoomVersions
from synapse.storage.state import StateFilter
from synapse.types import RoomID, UserID
@@ -28,9 +28,6 @@ logger = logging.getLogger(__name__)
class StateStoreTestCase(tests.unittest.TestCase):
- def __init__(self, *args, **kwargs):
- super(StateStoreTestCase, self).__init__(*args, **kwargs)
- self.store = None # type: synapse.storage.DataStore
@defer.inlineCallbacks
def setUp(self):
@@ -52,6 +49,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
@defer.inlineCallbacks
def inject_state_event(self, room, sender, typ, state_key, content):
builder = self.event_builder_factory.new(
+ RoomVersions.V1,
{
"type": typ,
"sender": sender.to_string(),
diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py
index 411b4a9f86..7ee318e4e8 100644
--- a/tests/test_event_auth.py
+++ b/tests/test_event_auth.py
@@ -16,6 +16,7 @@
import unittest
from synapse import event_auth
+from synapse.api.constants import RoomVersions
from synapse.api.errors import AuthError
from synapse.events import FrozenEvent
@@ -35,12 +36,16 @@ class EventAuthTestCase(unittest.TestCase):
}
# creator should be able to send state
- event_auth.check(_random_state_event(creator), auth_events, do_sig_check=False)
+ event_auth.check(
+ RoomVersions.V1, _random_state_event(creator), auth_events,
+ do_sig_check=False,
+ )
# joiner should not be able to send state
self.assertRaises(
AuthError,
event_auth.check,
+ RoomVersions.V1,
_random_state_event(joiner),
auth_events,
do_sig_check=False,
@@ -69,13 +74,17 @@ class EventAuthTestCase(unittest.TestCase):
self.assertRaises(
AuthError,
event_auth.check,
+ RoomVersions.V1,
_random_state_event(pleb),
auth_events,
do_sig_check=False,
),
# king should be able to send state
- event_auth.check(_random_state_event(king), auth_events, do_sig_check=False)
+ event_auth.check(
+ RoomVersions.V1, _random_state_event(king), auth_events,
+ do_sig_check=False,
+ )
# helpers for making events
diff --git a/tests/test_server.py b/tests/test_server.py
index 634a8fbca5..08fb3fe02f 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -19,7 +19,7 @@ from six import StringIO
from twisted.internet.defer import Deferred
from twisted.python.failure import Failure
-from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock
+from twisted.test.proto_helpers import AccumulatingProtocol
from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET
@@ -30,12 +30,18 @@ from synapse.util import Clock
from synapse.util.logcontext import make_deferred_yieldable
from tests import unittest
-from tests.server import FakeTransport, make_request, render, setup_test_homeserver
+from tests.server import (
+ FakeTransport,
+ ThreadedMemoryReactorClock,
+ make_request,
+ render,
+ setup_test_homeserver,
+)
class JsonResourceTests(unittest.TestCase):
def setUp(self):
- self.reactor = MemoryReactorClock()
+ self.reactor = ThreadedMemoryReactorClock()
self.hs_clock = Clock(self.reactor)
self.homeserver = setup_test_homeserver(
self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.reactor
diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py
new file mode 100644
index 0000000000..a7310cf12a
--- /dev/null
+++ b/tests/test_utils/__init__.py
@@ -0,0 +1,18 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Utilities for running the unit tests
+"""
diff --git a/tests/test_utils/logging_setup.py b/tests/test_utils/logging_setup.py
new file mode 100644
index 0000000000..d0bc8e2112
--- /dev/null
+++ b/tests/test_utils/logging_setup.py
@@ -0,0 +1,54 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+import os
+
+import twisted.logger
+
+from synapse.util.logcontext import LoggingContextFilter
+
+
+class ToTwistedHandler(logging.Handler):
+ """logging handler which sends the logs to the twisted log"""
+ tx_log = twisted.logger.Logger()
+
+ def emit(self, record):
+ log_entry = self.format(record)
+ log_level = record.levelname.lower().replace('warning', 'warn')
+ self.tx_log.emit(
+ twisted.logger.LogLevel.levelWithName(log_level),
+ log_entry.replace("{", r"(").replace("}", r")"),
+ )
+
+
+def setup_logging():
+ """Configure the python logging appropriately for the tests.
+
+ (Logs will end up in _trial_temp.)
+ """
+ root_logger = logging.getLogger()
+
+ log_format = (
+ "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s - %(message)s"
+ )
+
+ handler = ToTwistedHandler()
+ formatter = logging.Formatter(log_format)
+ handler.setFormatter(formatter)
+ handler.addFilter(LoggingContextFilter(request=""))
+ root_logger.addHandler(handler)
+
+ log_level = os.environ.get("SYNAPSE_TEST_LOG_LEVEL", "ERROR")
+ root_logger.setLevel(log_level)
diff --git a/tests/test_visibility.py b/tests/test_visibility.py
index 2eea3b098b..455db9f276 100644
--- a/tests/test_visibility.py
+++ b/tests/test_visibility.py
@@ -17,6 +17,7 @@ import logging
from twisted.internet import defer
from twisted.internet.defer import succeed
+from synapse.api.constants import RoomVersions
from synapse.events import FrozenEvent
from synapse.visibility import filter_events_for_server
@@ -124,6 +125,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
def inject_visibility(self, user_id, visibility):
content = {"history_visibility": visibility}
builder = self.event_builder_factory.new(
+ RoomVersions.V1,
{
"type": "m.room.history_visibility",
"sender": user_id,
@@ -144,6 +146,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
content = {"membership": membership}
content.update(extra_content)
builder = self.event_builder_factory.new(
+ RoomVersions.V1,
{
"type": "m.room.member",
"sender": user_id,
@@ -163,8 +166,9 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
@defer.inlineCallbacks
def inject_message(self, user_id, content=None):
if content is None:
- content = {"body": "testytest"}
+ content = {"body": "testytest", "msgtype": "m.text"}
builder = self.event_builder_factory.new(
+ RoomVersions.V1,
{
"type": "m.room.message",
"sender": user_id,
diff --git a/tests/unittest.py b/tests/unittest.py
index 78d2f740f9..fac254ff10 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -31,38 +31,14 @@ from synapse.http.server import JsonResource
from synapse.http.site import SynapseRequest
from synapse.server import HomeServer
from synapse.types import UserID, create_requester
-from synapse.util.logcontext import LoggingContext, LoggingContextFilter
+from synapse.util.logcontext import LoggingContext
from tests.server import get_clock, make_request, render, setup_test_homeserver
+from tests.test_utils.logging_setup import setup_logging
from tests.utils import default_config, setupdb
setupdb()
-
-# Set up putting Synapse's logs into Trial's.
-rootLogger = logging.getLogger()
-
-log_format = (
- "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s - %(message)s"
-)
-
-
-class ToTwistedHandler(logging.Handler):
- tx_log = twisted.logger.Logger()
-
- def emit(self, record):
- log_entry = self.format(record)
- log_level = record.levelname.lower().replace('warning', 'warn')
- self.tx_log.emit(
- twisted.logger.LogLevel.levelWithName(log_level),
- log_entry.replace("{", r"(").replace("}", r")"),
- )
-
-
-handler = ToTwistedHandler()
-formatter = logging.Formatter(log_format)
-handler.setFormatter(formatter)
-handler.addFilter(LoggingContextFilter(request=""))
-rootLogger.addHandler(handler)
+setup_logging()
def around(target):
@@ -96,7 +72,7 @@ class TestCase(unittest.TestCase):
method = getattr(self, methodName)
- level = getattr(method, "loglevel", getattr(self, "loglevel", logging.ERROR))
+ level = getattr(method, "loglevel", getattr(self, "loglevel", None))
@around(self)
def setUp(orig):
@@ -114,7 +90,7 @@ class TestCase(unittest.TestCase):
)
old_level = logging.getLogger().level
- if old_level != level:
+ if level is not None and old_level != level:
@around(self)
def tearDown(orig):
@@ -122,7 +98,8 @@ class TestCase(unittest.TestCase):
logging.getLogger().setLevel(old_level)
return ret
- logging.getLogger().setLevel(level)
+ logging.getLogger().setLevel(level)
+
return orig()
@around(self)
@@ -333,7 +310,15 @@ class HomeserverTestCase(TestCase):
"""
kwargs = dict(kwargs)
kwargs.update(self._hs_args)
- return setup_test_homeserver(self.addCleanup, *args, **kwargs)
+ hs = setup_test_homeserver(self.addCleanup, *args, **kwargs)
+ stor = hs.get_datastore()
+
+ # Run the database background updates.
+ if hasattr(stor, "do_next_background_update"):
+ while not self.get_success(stor.has_completed_background_updates()):
+ self.get_success(stor.do_next_background_update(1))
+
+ return hs
def pump(self, by=0.0):
"""
diff --git a/tests/util/caches/test_ttlcache.py b/tests/util/caches/test_ttlcache.py
new file mode 100644
index 0000000000..03b3c15db6
--- /dev/null
+++ b/tests/util/caches/test_ttlcache.py
@@ -0,0 +1,83 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from mock import Mock
+
+from synapse.util.caches.ttlcache import TTLCache
+
+from tests import unittest
+
+
+class CacheTestCase(unittest.TestCase):
+ def setUp(self):
+ self.mock_timer = Mock(side_effect=lambda: 100.0)
+ self.cache = TTLCache("test_cache", self.mock_timer)
+
+ def test_get(self):
+ """simple set/get tests"""
+ self.cache.set('one', '1', 10)
+ self.cache.set('two', '2', 20)
+ self.cache.set('three', '3', 30)
+
+ self.assertEqual(len(self.cache), 3)
+
+ self.assertTrue('one' in self.cache)
+ self.assertEqual(self.cache.get('one'), '1')
+ self.assertEqual(self.cache['one'], '1')
+ self.assertEqual(self.cache.get_with_expiry('one'), ('1', 110))
+ self.assertEqual(self.cache._metrics.hits, 3)
+ self.assertEqual(self.cache._metrics.misses, 0)
+
+ self.cache.set('two', '2.5', 20)
+ self.assertEqual(self.cache['two'], '2.5')
+ self.assertEqual(self.cache._metrics.hits, 4)
+
+ # non-existent-item tests
+ self.assertEqual(self.cache.get('four', '4'), '4')
+ self.assertIs(self.cache.get('four', None), None)
+
+ with self.assertRaises(KeyError):
+ self.cache['four']
+
+ with self.assertRaises(KeyError):
+ self.cache.get('four')
+
+ with self.assertRaises(KeyError):
+ self.cache.get_with_expiry('four')
+
+ self.assertEqual(self.cache._metrics.hits, 4)
+ self.assertEqual(self.cache._metrics.misses, 5)
+
+ def test_expiry(self):
+ self.cache.set('one', '1', 10)
+ self.cache.set('two', '2', 20)
+ self.cache.set('three', '3', 30)
+
+ self.assertEqual(len(self.cache), 3)
+ self.assertEqual(self.cache['one'], '1')
+ self.assertEqual(self.cache['two'], '2')
+
+ # enough for the first entry to expire, but not the rest
+ self.mock_timer.side_effect = lambda: 110.0
+
+ self.assertEqual(len(self.cache), 2)
+ self.assertFalse('one' in self.cache)
+ self.assertEqual(self.cache['two'], '2')
+ self.assertEqual(self.cache['three'], '3')
+
+ self.assertEqual(self.cache.get_with_expiry('two'), ('2', 120))
+
+ self.assertEqual(self.cache._metrics.hits, 5)
+ self.assertEqual(self.cache._metrics.misses, 0)
diff --git a/tests/utils.py b/tests/utils.py
index 08d6faa0a6..2dfcb70a93 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -26,7 +26,7 @@ from six.moves.urllib import parse as urlparse
from twisted.internet import defer, reactor
-from synapse.api.constants import EventTypes
+from synapse.api.constants import EventTypes, RoomVersions
from synapse.api.errors import CodeMessageException, cs_error
from synapse.config.server import ServerConfig
from synapse.federation.transport import server
@@ -154,7 +154,9 @@ def default_config(name):
config.update_user_directory = False
def is_threepid_reserved(threepid):
- return ServerConfig.is_threepid_reserved(config, threepid)
+ return ServerConfig.is_threepid_reserved(
+ config.mau_limits_reserved_threepids, threepid
+ )
config.is_threepid_reserved.side_effect = is_threepid_reserved
@@ -622,6 +624,7 @@ def create_room(hs, room_id, creator_id):
event_creation_handler = hs.get_event_creation_handler()
builder = event_builder_factory.new(
+ RoomVersions.V1,
{
"type": EventTypes.Create,
"state_key": "",
|