summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/http/federation/test_matrix_federation_agent.py148
-rw-r--r--tests/storage/test_base.py1
-rw-r--r--tests/test_server.py12
-rw-r--r--tests/unittest.py12
-rw-r--r--tests/utils.py4
5 files changed, 163 insertions, 14 deletions
diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
index eb963d80fb..b32d7566a5 100644
--- a/tests/http/federation/test_matrix_federation_agent.py
+++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -26,6 +26,7 @@ from twisted.web.http import HTTPChannel
 
 from synapse.crypto.context_factory import ClientTLSOptionsFactory
 from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
+from synapse.http.federation.srv_resolver import Server
 from synapse.util.logcontext import LoggingContext
 
 from tests.server import FakeTransport, ThreadedMemoryReactorClock
@@ -46,7 +47,7 @@ class MatrixFederationAgentTests(TestCase):
             _srv_resolver=self.mock_resolver,
         )
 
-    def _make_connection(self, client_factory):
+    def _make_connection(self, client_factory, expected_sni):
         """Builds a test server, and completes the outgoing client connection
 
         Returns:
@@ -69,9 +70,17 @@ class MatrixFederationAgentTests(TestCase):
         # tell the server tls protocol to send its stuff back to the client, too
         server_tls_protocol.makeConnection(FakeTransport(client_protocol, self.reactor))
 
-        # finally, give the reactor a pump to get the TLS juices flowing.
+        # give the reactor a pump to get the TLS juices flowing.
         self.reactor.pump((0.1,))
 
+        # check the SNI
+        server_name = server_tls_protocol._tlsConnection.get_servername()
+        self.assertEqual(
+            server_name,
+            expected_sni,
+            "Expected SNI %s but got %s" % (expected_sni, server_name),
+        )
+
         # fish the test server back out of the server-side TLS protocol.
         return server_tls_protocol.wrappedProtocol
 
@@ -97,7 +106,7 @@ class MatrixFederationAgentTests(TestCase):
 
     def test_get(self):
         """
-        happy-path test of a GET request
+        happy-path test of a GET request with an explicit port
         """
         self.reactor.lookups["testserv"] = "1.2.3.4"
         test_d = self._make_get_request(b"matrix://testserv:8448/foo/bar")
@@ -113,16 +122,15 @@ class MatrixFederationAgentTests(TestCase):
         self.assertEqual(port, 8448)
 
         # make a test server, and wire up the client
-        http_server = self._make_connection(client_factory)
+        http_server = self._make_connection(
+            client_factory,
+            expected_sni=b"testserv",
+        )
 
         self.assertEqual(len(http_server.requests), 1)
         request = http_server.requests[0]
         self.assertEqual(request.method, b'GET')
         self.assertEqual(request.path, b'/foo/bar')
-        self.assertEqual(
-            request.requestHeaders.getRawHeaders(b'host'),
-            [b'testserv:8448']
-        )
         content = request.content.read()
         self.assertEqual(content, b'')
 
@@ -150,6 +158,130 @@ class MatrixFederationAgentTests(TestCase):
         json = self.successResultOf(treq.json_content(response))
         self.assertEqual(json, {"a": 1})
 
+    def test_get_ip_address(self):
+        """
+        Test the behaviour when the server name contains an explicit IP (with no port)
+        """
+
+        # the SRV lookup will return an empty list (XXX: why do we even do an SRV lookup?)
+        self.mock_resolver.resolve_service.side_effect = lambda _: []
+
+        # then there will be a getaddrinfo on the IP
+        self.reactor.lookups["1.2.3.4"] = "1.2.3.4"
+
+        test_d = self._make_get_request(b"matrix://1.2.3.4/foo/bar")
+
+        # Nothing happened yet
+        self.assertNoResult(test_d)
+
+        self.mock_resolver.resolve_service.assert_called_once_with(
+            b"_matrix._tcp.1.2.3.4",
+        )
+
+        # Make sure treq is trying to connect
+        clients = self.reactor.tcpClients
+        self.assertEqual(len(clients), 1)
+        (host, port, client_factory, _timeout, _bindAddress) = clients[0]
+        self.assertEqual(host, '1.2.3.4')
+        self.assertEqual(port, 8448)
+
+        # make a test server, and wire up the client
+        http_server = self._make_connection(
+            client_factory,
+            expected_sni=None,
+        )
+
+        self.assertEqual(len(http_server.requests), 1)
+        request = http_server.requests[0]
+        self.assertEqual(request.method, b'GET')
+        self.assertEqual(request.path, b'/foo/bar')
+
+        # finish the request
+        request.finish()
+        self.reactor.pump((0.1,))
+        self.successResultOf(test_d)
+
+    def test_get_hostname_no_srv(self):
+        """
+        Test the behaviour when the server name has no port, and no SRV record
+        """
+
+        self.mock_resolver.resolve_service.side_effect = lambda _: []
+        self.reactor.lookups["testserv"] = "1.2.3.4"
+
+        test_d = self._make_get_request(b"matrix://testserv/foo/bar")
+
+        # Nothing happened yet
+        self.assertNoResult(test_d)
+
+        self.mock_resolver.resolve_service.assert_called_once_with(
+            b"_matrix._tcp.testserv",
+        )
+
+        # Make sure treq is trying to connect
+        clients = self.reactor.tcpClients
+        self.assertEqual(len(clients), 1)
+        (host, port, client_factory, _timeout, _bindAddress) = clients[0]
+        self.assertEqual(host, '1.2.3.4')
+        self.assertEqual(port, 8448)
+
+        # make a test server, and wire up the client
+        http_server = self._make_connection(
+            client_factory,
+            expected_sni=b'testserv',
+        )
+
+        self.assertEqual(len(http_server.requests), 1)
+        request = http_server.requests[0]
+        self.assertEqual(request.method, b'GET')
+        self.assertEqual(request.path, b'/foo/bar')
+
+        # finish the request
+        request.finish()
+        self.reactor.pump((0.1,))
+        self.successResultOf(test_d)
+
+    def test_get_hostname_srv(self):
+        """
+        Test the behaviour when there is a single SRV record
+        """
+        self.mock_resolver.resolve_service.side_effect = lambda _: [
+            Server(host="srvtarget", port=8443)
+        ]
+        self.reactor.lookups["srvtarget"] = "1.2.3.4"
+
+        test_d = self._make_get_request(b"matrix://testserv/foo/bar")
+
+        # Nothing happened yet
+        self.assertNoResult(test_d)
+
+        self.mock_resolver.resolve_service.assert_called_once_with(
+            b"_matrix._tcp.testserv",
+        )
+
+        # Make sure treq is trying to connect
+        clients = self.reactor.tcpClients
+        self.assertEqual(len(clients), 1)
+        (host, port, client_factory, _timeout, _bindAddress) = clients[0]
+        self.assertEqual(host, '1.2.3.4')
+        self.assertEqual(port, 8443)
+
+        # make a test server, and wire up the client
+        http_server = self._make_connection(
+            client_factory,
+            expected_sni=b'testserv',
+        )
+
+        self.assertEqual(len(http_server.requests), 1)
+        request = http_server.requests[0]
+        self.assertEqual(request.method, b'GET')
+        self.assertEqual(request.path, b'/foo/bar')
+
+        # finish the request
+        request.finish()
+        self.reactor.pump((0.1,))
+        self.successResultOf(test_d)
+
 
 def _check_logcontext(context):
     current = LoggingContext.current_context()
diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index 829f47d2e8..452d76ddd5 100644
--- a/tests/storage/test_base.py
+++ b/tests/storage/test_base.py
@@ -49,6 +49,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
         self.db_pool.runWithConnection = runWithConnection
 
         config = Mock()
+        config._enable_native_upserts = False
         config.event_cache_size = 1
         config.database_config = {"name": "sqlite3"}
         hs = TestHomeServer(
diff --git a/tests/test_server.py b/tests/test_server.py
index 634a8fbca5..08fb3fe02f 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -19,7 +19,7 @@ from six import StringIO
 
 from twisted.internet.defer import Deferred
 from twisted.python.failure import Failure
-from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock
+from twisted.test.proto_helpers import AccumulatingProtocol
 from twisted.web.resource import Resource
 from twisted.web.server import NOT_DONE_YET
 
@@ -30,12 +30,18 @@ from synapse.util import Clock
 from synapse.util.logcontext import make_deferred_yieldable
 
 from tests import unittest
-from tests.server import FakeTransport, make_request, render, setup_test_homeserver
+from tests.server import (
+    FakeTransport,
+    ThreadedMemoryReactorClock,
+    make_request,
+    render,
+    setup_test_homeserver,
+)
 
 
 class JsonResourceTests(unittest.TestCase):
     def setUp(self):
-        self.reactor = MemoryReactorClock()
+        self.reactor = ThreadedMemoryReactorClock()
         self.hs_clock = Clock(self.reactor)
         self.homeserver = setup_test_homeserver(
             self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.reactor
diff --git a/tests/unittest.py b/tests/unittest.py
index 78d2f740f9..cda549c783 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -96,7 +96,7 @@ class TestCase(unittest.TestCase):
 
         method = getattr(self, methodName)
 
-        level = getattr(method, "loglevel", getattr(self, "loglevel", logging.ERROR))
+        level = getattr(method, "loglevel", getattr(self, "loglevel", logging.WARNING))
 
         @around(self)
         def setUp(orig):
@@ -333,7 +333,15 @@ class HomeserverTestCase(TestCase):
         """
         kwargs = dict(kwargs)
         kwargs.update(self._hs_args)
-        return setup_test_homeserver(self.addCleanup, *args, **kwargs)
+        hs = setup_test_homeserver(self.addCleanup, *args, **kwargs)
+        stor = hs.get_datastore()
+
+        # Run the database background updates.
+        if hasattr(stor, "do_next_background_update"):
+            while not self.get_success(stor.has_completed_background_updates()):
+                self.get_success(stor.do_next_background_update(1))
+
+        return hs
 
     def pump(self, by=0.0):
         """
diff --git a/tests/utils.py b/tests/utils.py
index 08d6faa0a6..df73c539c3 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -154,7 +154,9 @@ def default_config(name):
     config.update_user_directory = False
 
     def is_threepid_reserved(threepid):
-        return ServerConfig.is_threepid_reserved(config, threepid)
+        return ServerConfig.is_threepid_reserved(
+            config.mau_limits_reserved_threepids, threepid
+        )
 
     config.is_threepid_reserved.side_effect = is_threepid_reserved