summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/http/test_fedclient.py14
-rw-r--r--tests/http/test_simple_client.py180
-rw-r--r--tests/rest/client/v2_alpha/test_account.py6
-rw-r--r--tests/storage/test_monthly_active_users.py17
4 files changed, 210 insertions, 7 deletions
diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py
index 5604af3795..212484a7fe 100644
--- a/tests/http/test_fedclient.py
+++ b/tests/http/test_fedclient.py
@@ -318,14 +318,14 @@ class FederationClientTests(HomeserverTestCase):
         r = self.successResultOf(d)
         self.assertEqual(r.code, 200)
 
-    def test_client_headers_no_body(self):
+    @parameterized.expand(["get_json", "post_json", "delete_json", "put_json"])
+    def test_timeout_reading_body(self, method_name: str):
         """
         If the HTTP request is connected, but gets no response before being
-        timed out, it'll give a ResponseNeverReceived.
+        timed out, it'll give a RequestSendFailed with can_retry.
         """
-        d = defer.ensureDeferred(
-            self.cl.post_json("testserv:8008", "foo/bar", timeout=10000)
-        )
+        method = getattr(self.cl, method_name)
+        d = defer.ensureDeferred(method("testserv:8008", "foo/bar", timeout=10000))
 
         self.pump()
 
@@ -349,7 +349,9 @@ class FederationClientTests(HomeserverTestCase):
         self.reactor.advance(10.5)
         f = self.failureResultOf(d)
 
-        self.assertIsInstance(f.value, TimeoutError)
+        self.assertIsInstance(f.value, RequestSendFailed)
+        self.assertTrue(f.value.can_retry)
+        self.assertIsInstance(f.value.inner_exception, defer.TimeoutError)
 
     def test_client_requires_trailing_slashes(self):
         """
diff --git a/tests/http/test_simple_client.py b/tests/http/test_simple_client.py
new file mode 100644
index 0000000000..a1cf0862d4
--- /dev/null
+++ b/tests/http/test_simple_client.py
@@ -0,0 +1,180 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from mock import Mock
+
+from netaddr import IPSet
+
+from twisted.internet import defer
+from twisted.internet.error import DNSLookupError
+
+from synapse.http import RequestTimedOutError
+from synapse.http.client import SimpleHttpClient
+from synapse.server import HomeServer
+
+from tests.unittest import HomeserverTestCase
+
+
+class SimpleHttpClientTests(HomeserverTestCase):
+    def prepare(self, reactor, clock, hs: "HomeServer"):
+        # Add a DNS entry for a test server
+        self.reactor.lookups["testserv"] = "1.2.3.4"
+
+        self.cl = hs.get_simple_http_client()
+
+    def test_dns_error(self):
+        """
+        If the DNS lookup returns an error, it will bubble up.
+        """
+        d = defer.ensureDeferred(self.cl.get_json("http://testserv2:8008/foo/bar"))
+        self.pump()
+
+        f = self.failureResultOf(d)
+        self.assertIsInstance(f.value, DNSLookupError)
+
+    def test_client_connection_refused(self):
+        d = defer.ensureDeferred(self.cl.get_json("http://testserv:8008/foo/bar"))
+
+        self.pump()
+
+        # Nothing happened yet
+        self.assertNoResult(d)
+
+        clients = self.reactor.tcpClients
+        self.assertEqual(len(clients), 1)
+        (host, port, factory, _timeout, _bindAddress) = clients[0]
+        self.assertEqual(host, "1.2.3.4")
+        self.assertEqual(port, 8008)
+        e = Exception("go away")
+        factory.clientConnectionFailed(None, e)
+        self.pump(0.5)
+
+        f = self.failureResultOf(d)
+
+        self.assertIs(f.value, e)
+
+    def test_client_never_connect(self):
+        """
+        If the HTTP request is not connected and is timed out, it'll give a
+        ConnectingCancelledError or TimeoutError.
+        """
+        d = defer.ensureDeferred(self.cl.get_json("http://testserv:8008/foo/bar"))
+
+        self.pump()
+
+        # Nothing happened yet
+        self.assertNoResult(d)
+
+        # Make sure treq is trying to connect
+        clients = self.reactor.tcpClients
+        self.assertEqual(len(clients), 1)
+        self.assertEqual(clients[0][0], "1.2.3.4")
+        self.assertEqual(clients[0][1], 8008)
+
+        # Deferred is still without a result
+        self.assertNoResult(d)
+
+        # Push by enough to time it out
+        self.reactor.advance(120)
+        f = self.failureResultOf(d)
+
+        self.assertIsInstance(f.value, RequestTimedOutError)
+
+    def test_client_connect_no_response(self):
+        """
+        If the HTTP request is connected, but gets no response before being
+        timed out, it'll give a ResponseNeverReceived.
+        """
+        d = defer.ensureDeferred(self.cl.get_json("http://testserv:8008/foo/bar"))
+
+        self.pump()
+
+        # Nothing happened yet
+        self.assertNoResult(d)
+
+        # Make sure treq is trying to connect
+        clients = self.reactor.tcpClients
+        self.assertEqual(len(clients), 1)
+        self.assertEqual(clients[0][0], "1.2.3.4")
+        self.assertEqual(clients[0][1], 8008)
+
+        conn = Mock()
+        client = clients[0][2].buildProtocol(None)
+        client.makeConnection(conn)
+
+        # Deferred is still without a result
+        self.assertNoResult(d)
+
+        # Push by enough to time it out
+        self.reactor.advance(120)
+        f = self.failureResultOf(d)
+
+        self.assertIsInstance(f.value, RequestTimedOutError)
+
+    def test_client_ip_range_blacklist(self):
+        """Ensure that Synapse does not try to connect to blacklisted IPs"""
+
+        # Add some DNS entries we'll blacklist
+        self.reactor.lookups["internal"] = "127.0.0.1"
+        self.reactor.lookups["internalv6"] = "fe80:0:0:0:0:8a2e:370:7337"
+        ip_blacklist = IPSet(["127.0.0.0/8", "fe80::/64"])
+
+        cl = SimpleHttpClient(self.hs, ip_blacklist=ip_blacklist)
+
+        # Try making a GET request to a blacklisted IPv4 address
+        # ------------------------------------------------------
+        # Make the request
+        d = defer.ensureDeferred(cl.get_json("http://internal:8008/foo/bar"))
+        self.pump(1)
+
+        # Check that it was unable to resolve the address
+        clients = self.reactor.tcpClients
+        self.assertEqual(len(clients), 0)
+
+        self.failureResultOf(d, DNSLookupError)
+
+        # Try making a POST request to a blacklisted IPv6 address
+        # -------------------------------------------------------
+        # Make the request
+        d = defer.ensureDeferred(
+            cl.post_json_get_json("http://internalv6:8008/foo/bar", {})
+        )
+
+        # Move the reactor forwards
+        self.pump(1)
+
+        # Check that it was unable to resolve the address
+        clients = self.reactor.tcpClients
+        self.assertEqual(len(clients), 0)
+
+        # Check that it was due to a blacklisted DNS lookup
+        self.failureResultOf(d, DNSLookupError)
+
+        # Try making a GET request to a non-blacklisted IPv4 address
+        # ----------------------------------------------------------
+        # Make the request
+        d = defer.ensureDeferred(cl.get_json("http://testserv:8008/foo/bar"))
+
+        # Nothing has happened yet
+        self.assertNoResult(d)
+
+        # Move the reactor forwards
+        self.pump(1)
+
+        # Check that it was able to resolve the address
+        clients = self.reactor.tcpClients
+        self.assertNotEqual(len(clients), 0)
+
+        # Connection will still fail as this IP address does not resolve to anything
+        self.failureResultOf(d, RequestTimedOutError)
diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py
index 93f899d861..ae2cd67f35 100644
--- a/tests/rest/client/v2_alpha/test_account.py
+++ b/tests/rest/client/v2_alpha/test_account.py
@@ -732,6 +732,12 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
     @override_config({"next_link_domain_whitelist": ["example.com", "example.org"]})
     def test_next_link_domain_whitelist(self):
         """Tests next_link parameters must fit the whitelist if provided"""
+
+        # Ensure not providing a next_link parameter still works
+        self._request_token(
+            "something@example.com", "some_secret", next_link=None, expect_code=200,
+        )
+
         self._request_token(
             "something@example.com",
             "some_secret",
diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py
index 643072bbaf..8d97b6d4cd 100644
--- a/tests/storage/test_monthly_active_users.py
+++ b/tests/storage/test_monthly_active_users.py
@@ -137,6 +137,21 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
         count = self.get_success(self.store.get_monthly_active_count())
         self.assertEqual(count, 1)
 
+    def test_appservice_user_not_counted_in_mau(self):
+        self.get_success(
+            self.store.register_user(
+                user_id="@appservice_user:server", appservice_id="wibble"
+            )
+        )
+        count = self.get_success(self.store.get_monthly_active_count())
+        self.assertEqual(count, 0)
+
+        d = self.store.upsert_monthly_active_user("@appservice_user:server")
+        self.get_success(d)
+
+        count = self.get_success(self.store.get_monthly_active_count())
+        self.assertEqual(count, 0)
+
     def test_user_last_seen_monthly_active(self):
         user_id1 = "@user1:server"
         user_id2 = "@user2:server"
@@ -383,7 +398,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
         self.get_success(self.store.upsert_monthly_active_user(appservice2_user1))
 
         count = self.get_success(self.store.get_monthly_active_count())
-        self.assertEqual(count, 4)
+        self.assertEqual(count, 1)
 
         d = self.store.get_monthly_active_count_by_service()
         result = self.get_success(d)