diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 5cf408f21f..8ff1460c0d 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -41,6 +41,7 @@ from synapse.storage.keys import FetchKeyResult
from tests import unittest
from tests.test_utils import make_awaitable
+from tests.unittest import logcontext_clean
class MockPerspectiveServer:
@@ -67,6 +68,7 @@ class MockPerspectiveServer:
signedjson.sign.sign_json(res, self.server_name, self.key)
+@logcontext_clean
class KeyringTestCase(unittest.HomeserverTestCase):
def check_context(self, val, expected):
self.assertEquals(getattr(current_context(), "request", None), expected)
@@ -309,6 +311,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
mock_fetcher2.get_keys.assert_called_once()
+@logcontext_clean
class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
self.http_client = Mock()
diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py
index 5604af3795..212484a7fe 100644
--- a/tests/http/test_fedclient.py
+++ b/tests/http/test_fedclient.py
@@ -318,14 +318,14 @@ class FederationClientTests(HomeserverTestCase):
r = self.successResultOf(d)
self.assertEqual(r.code, 200)
- def test_client_headers_no_body(self):
+ @parameterized.expand(["get_json", "post_json", "delete_json", "put_json"])
+ def test_timeout_reading_body(self, method_name: str):
"""
If the HTTP request is connected, but gets no response before being
- timed out, it'll give a ResponseNeverReceived.
+ timed out, it'll give a RequestSendFailed with can_retry.
"""
- d = defer.ensureDeferred(
- self.cl.post_json("testserv:8008", "foo/bar", timeout=10000)
- )
+ method = getattr(self.cl, method_name)
+ d = defer.ensureDeferred(method("testserv:8008", "foo/bar", timeout=10000))
self.pump()
@@ -349,7 +349,9 @@ class FederationClientTests(HomeserverTestCase):
self.reactor.advance(10.5)
f = self.failureResultOf(d)
- self.assertIsInstance(f.value, TimeoutError)
+ self.assertIsInstance(f.value, RequestSendFailed)
+ self.assertTrue(f.value.can_retry)
+ self.assertIsInstance(f.value.inner_exception, defer.TimeoutError)
def test_client_requires_trailing_slashes(self):
"""
diff --git a/tests/http/test_simple_client.py b/tests/http/test_simple_client.py
new file mode 100644
index 0000000000..a1cf0862d4
--- /dev/null
+++ b/tests/http/test_simple_client.py
@@ -0,0 +1,180 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from mock import Mock
+
+from netaddr import IPSet
+
+from twisted.internet import defer
+from twisted.internet.error import DNSLookupError
+
+from synapse.http import RequestTimedOutError
+from synapse.http.client import SimpleHttpClient
+from synapse.server import HomeServer
+
+from tests.unittest import HomeserverTestCase
+
+
+class SimpleHttpClientTests(HomeserverTestCase):
+ def prepare(self, reactor, clock, hs: "HomeServer"):
+ # Add a DNS entry for a test server
+ self.reactor.lookups["testserv"] = "1.2.3.4"
+
+ self.cl = hs.get_simple_http_client()
+
+ def test_dns_error(self):
+ """
+ If the DNS lookup returns an error, it will bubble up.
+ """
+ d = defer.ensureDeferred(self.cl.get_json("http://testserv2:8008/foo/bar"))
+ self.pump()
+
+ f = self.failureResultOf(d)
+ self.assertIsInstance(f.value, DNSLookupError)
+
+ def test_client_connection_refused(self):
+ d = defer.ensureDeferred(self.cl.get_json("http://testserv:8008/foo/bar"))
+
+ self.pump()
+
+ # Nothing happened yet
+ self.assertNoResult(d)
+
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, factory, _timeout, _bindAddress) = clients[0]
+ self.assertEqual(host, "1.2.3.4")
+ self.assertEqual(port, 8008)
+ e = Exception("go away")
+ factory.clientConnectionFailed(None, e)
+ self.pump(0.5)
+
+ f = self.failureResultOf(d)
+
+ self.assertIs(f.value, e)
+
+ def test_client_never_connect(self):
+ """
+ If the HTTP request is not connected and is timed out, it'll give a
+ ConnectingCancelledError or TimeoutError.
+ """
+ d = defer.ensureDeferred(self.cl.get_json("http://testserv:8008/foo/bar"))
+
+ self.pump()
+
+ # Nothing happened yet
+ self.assertNoResult(d)
+
+ # Make sure treq is trying to connect
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ self.assertEqual(clients[0][0], "1.2.3.4")
+ self.assertEqual(clients[0][1], 8008)
+
+ # Deferred is still without a result
+ self.assertNoResult(d)
+
+ # Push by enough to time it out
+ self.reactor.advance(120)
+ f = self.failureResultOf(d)
+
+ self.assertIsInstance(f.value, RequestTimedOutError)
+
+ def test_client_connect_no_response(self):
+ """
+ If the HTTP request is connected, but gets no response before being
+ timed out, it'll give a ResponseNeverReceived.
+ """
+ d = defer.ensureDeferred(self.cl.get_json("http://testserv:8008/foo/bar"))
+
+ self.pump()
+
+ # Nothing happened yet
+ self.assertNoResult(d)
+
+ # Make sure treq is trying to connect
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ self.assertEqual(clients[0][0], "1.2.3.4")
+ self.assertEqual(clients[0][1], 8008)
+
+ conn = Mock()
+ client = clients[0][2].buildProtocol(None)
+ client.makeConnection(conn)
+
+ # Deferred is still without a result
+ self.assertNoResult(d)
+
+ # Push by enough to time it out
+ self.reactor.advance(120)
+ f = self.failureResultOf(d)
+
+ self.assertIsInstance(f.value, RequestTimedOutError)
+
+ def test_client_ip_range_blacklist(self):
+ """Ensure that Synapse does not try to connect to blacklisted IPs"""
+
+ # Add some DNS entries we'll blacklist
+ self.reactor.lookups["internal"] = "127.0.0.1"
+ self.reactor.lookups["internalv6"] = "fe80:0:0:0:0:8a2e:370:7337"
+ ip_blacklist = IPSet(["127.0.0.0/8", "fe80::/64"])
+
+ cl = SimpleHttpClient(self.hs, ip_blacklist=ip_blacklist)
+
+ # Try making a GET request to a blacklisted IPv4 address
+ # ------------------------------------------------------
+ # Make the request
+ d = defer.ensureDeferred(cl.get_json("http://internal:8008/foo/bar"))
+ self.pump(1)
+
+ # Check that it was unable to resolve the address
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 0)
+
+ self.failureResultOf(d, DNSLookupError)
+
+ # Try making a POST request to a blacklisted IPv6 address
+ # -------------------------------------------------------
+ # Make the request
+ d = defer.ensureDeferred(
+ cl.post_json_get_json("http://internalv6:8008/foo/bar", {})
+ )
+
+ # Move the reactor forwards
+ self.pump(1)
+
+ # Check that it was unable to resolve the address
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 0)
+
+ # Check that it was due to a blacklisted DNS lookup
+ self.failureResultOf(d, DNSLookupError)
+
+ # Try making a GET request to a non-blacklisted IPv4 address
+ # ----------------------------------------------------------
+ # Make the request
+ d = defer.ensureDeferred(cl.get_json("http://testserv:8008/foo/bar"))
+
+ # Nothing has happened yet
+ self.assertNoResult(d)
+
+ # Move the reactor forwards
+ self.pump(1)
+
+ # Check that it was able to resolve the address
+ clients = self.reactor.tcpClients
+ self.assertNotEqual(len(clients), 0)
+
+ # Connection will still fail as this IP address does not resolve to anything
+ self.failureResultOf(d, RequestTimedOutError)
diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py
index 93f899d861..ae2cd67f35 100644
--- a/tests/rest/client/v2_alpha/test_account.py
+++ b/tests/rest/client/v2_alpha/test_account.py
@@ -732,6 +732,12 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
@override_config({"next_link_domain_whitelist": ["example.com", "example.org"]})
def test_next_link_domain_whitelist(self):
"""Tests next_link parameters must fit the whitelist if provided"""
+
+ # Ensure not providing a next_link parameter still works
+ self._request_token(
+ "something@example.com", "some_secret", next_link=None, expect_code=200,
+ )
+
self._request_token(
"something@example.com",
"some_secret",
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index d4ff55fbff..4558bee7be 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -12,9 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
-
from synapse.storage.database import DatabasePool
+from synapse.storage.engines import IncorrectDatabaseSetup
from synapse.storage.util.id_generators import MultiWriterIdGenerator
from tests.unittest import HomeserverTestCase
@@ -59,7 +58,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
writers=writers,
)
- return self.get_success(self.db_pool.runWithConnection(_create))
+ return self.get_success_or_raise(self.db_pool.runWithConnection(_create))
def _insert_rows(self, instance_name: str, number: int):
"""Insert N rows as the given instance, inserting with stream IDs pulled
@@ -411,6 +410,23 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.get_success(_get_next_async())
self.assertEqual(id_gen_3.get_persisted_upto_position(), 6)
+ def test_sequence_consistency(self):
+ """Test that we error out if the table and sequence diverges.
+ """
+
+ # Prefill with some rows
+ self._insert_row_with_id("master", 3)
+
+ # Now we add a row *without* updating the stream ID
+ def _insert(txn):
+ txn.execute("INSERT INTO foobar VALUES (26, 'master')")
+
+ self.get_success(self.db_pool.runInteraction("_insert", _insert))
+
+ # Creating the ID gen should error
+ with self.assertRaises(IncorrectDatabaseSetup):
+ self._create_id_generator("first")
+
class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
"""Tests MultiWriterIdGenerator that produce *negative* stream IDs.
diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py
index 643072bbaf..8d97b6d4cd 100644
--- a/tests/storage/test_monthly_active_users.py
+++ b/tests/storage/test_monthly_active_users.py
@@ -137,6 +137,21 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
count = self.get_success(self.store.get_monthly_active_count())
self.assertEqual(count, 1)
+ def test_appservice_user_not_counted_in_mau(self):
+ self.get_success(
+ self.store.register_user(
+ user_id="@appservice_user:server", appservice_id="wibble"
+ )
+ )
+ count = self.get_success(self.store.get_monthly_active_count())
+ self.assertEqual(count, 0)
+
+ d = self.store.upsert_monthly_active_user("@appservice_user:server")
+ self.get_success(d)
+
+ count = self.get_success(self.store.get_monthly_active_count())
+ self.assertEqual(count, 0)
+
def test_user_last_seen_monthly_active(self):
user_id1 = "@user1:server"
user_id2 = "@user2:server"
@@ -383,7 +398,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.get_success(self.store.upsert_monthly_active_user(appservice2_user1))
count = self.get_success(self.store.get_monthly_active_count())
- self.assertEqual(count, 4)
+ self.assertEqual(count, 1)
d = self.store.get_monthly_active_count_by_service()
result = self.get_success(d)
diff --git a/tests/unittest.py b/tests/unittest.py
index dabf69cff4..e654c0442d 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -14,7 +14,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import gc
import hashlib
import hmac
@@ -23,11 +22,12 @@ import logging
import time
from typing import Optional, Tuple, Type, TypeVar, Union
-from mock import Mock
+from mock import Mock, patch
from canonicaljson import json
from twisted.internet.defer import Deferred, ensureDeferred, succeed
+from twisted.python.failure import Failure
from twisted.python.threadpool import ThreadPool
from twisted.trial import unittest
@@ -169,6 +169,19 @@ def INFO(target):
return target
+def logcontext_clean(target):
+ """A decorator which marks the TestCase or method as 'logcontext_clean'
+
+ ... ie, any logcontext errors should cause a test failure
+ """
+
+ def logcontext_error(msg):
+ raise AssertionError("logcontext error: %s" % (msg))
+
+ patcher = patch("synapse.logging.context.logcontext_error", new=logcontext_error)
+ return patcher(target)
+
+
class HomeserverTestCase(TestCase):
"""
A base TestCase that reduces boilerplate for HomeServer-using test cases.
@@ -463,6 +476,35 @@ class HomeserverTestCase(TestCase):
self.pump()
return self.failureResultOf(d, exc)
+ def get_success_or_raise(self, d, by=0.0):
+ """Drive deferred to completion and return result or raise exception
+ on failure.
+ """
+
+ if inspect.isawaitable(d):
+ deferred = ensureDeferred(d)
+ if not isinstance(deferred, Deferred):
+ return d
+
+ results = [] # type: list
+ deferred.addBoth(results.append)
+
+ self.pump(by=by)
+
+ if not results:
+ self.fail(
+ "Success result expected on {!r}, found no result instead".format(
+ deferred
+ )
+ )
+
+ result = results[0]
+
+ if isinstance(result, Failure):
+ result.raiseException()
+
+ return result
+
def register_user(self, username, password, admin=False):
"""
Register a user. Requires the Admin API be registered.
|