summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/federation/test_federation_catch_up.py3
-rw-r--r--tests/handlers/test_register.py31
-rw-r--r--tests/http/test_client.py126
-rw-r--r--tests/replication/_base.py44
-rw-r--r--tests/replication/test_federation_ack.py8
-rw-r--r--tests/rest/client/v1/test_login.py43
-rw-r--r--tests/server.py2
-rw-r--r--tests/storage/test_event_federation.py76
8 files changed, 280 insertions, 53 deletions
diff --git a/tests/federation/test_federation_catch_up.py b/tests/federation/test_federation_catch_up.py

index 1a3ccb263d..6f96cd7940 100644 --- a/tests/federation/test_federation_catch_up.py +++ b/tests/federation/test_federation_catch_up.py
@@ -7,6 +7,7 @@ from synapse.federation.sender import PerDestinationQueue, TransactionManager from synapse.federation.units import Edu from synapse.rest import admin from synapse.rest.client.v1 import login, room +from synapse.util.retryutils import NotRetryingDestination from tests.test_utils import event_injection, make_awaitable from tests.unittest import FederatingHomeserverTestCase, override_config @@ -49,7 +50,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase): else: data = json_cb() self.failed_pdus.extend(data["pdus"]) - raise IOError("Failed to connect because this is a test!") + raise NotRetryingDestination(0, 24 * 60 * 60 * 1000, txn.destination) def get_destination_room(self, room: str, destination: str = "host2") -> dict: """ diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index bdf3d0a8a2..94b6903594 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py
@@ -517,6 +517,37 @@ class RegistrationTestCase(unittest.HomeserverTestCase): self.assertTrue(requester.shadow_banned) + def test_spam_checker_receives_sso_type(self): + """Test rejecting registration based on SSO type""" + + class BanBadIdPUser: + def check_registration_for_spam( + self, email_threepid, username, request_info, auth_provider_id=None + ): + # Reject any user coming from CAS and whose username contains profanity + if auth_provider_id == "cas" and "flimflob" in username: + return RegistrationBehaviour.DENY + return RegistrationBehaviour.ALLOW + + # Configure a spam checker that denies a certain user on a specific IdP + spam_checker = self.hs.get_spam_checker() + spam_checker.spam_checkers = [BanBadIdPUser()] + + f = self.get_failure( + self.handler.register_user(localpart="bobflimflob", auth_provider_id="cas"), + SynapseError, + ) + exception = f.value + + # We return 429 from the spam checker for denied registrations + self.assertIsInstance(exception, SynapseError) + self.assertEqual(exception.code, 429) + + # Check the same username can register using SAML + self.get_success( + self.handler.register_user(localpart="bobflimflob", auth_provider_id="saml") + ) + async def get_or_create_user( self, requester, localpart, displayname, password_hash=None ): diff --git a/tests/http/test_client.py b/tests/http/test_client.py
index 21ecb81c99..0ce181a51e 100644 --- a/tests/http/test_client.py +++ b/tests/http/test_client.py
@@ -16,12 +16,23 @@ from io import BytesIO from mock import Mock +from netaddr import IPSet + +from twisted.internet.error import DNSLookupError from twisted.python.failure import Failure -from twisted.web.client import ResponseDone +from twisted.test.proto_helpers import AccumulatingProtocol +from twisted.web.client import Agent, ResponseDone from twisted.web.iweb import UNKNOWN_LENGTH -from synapse.http.client import BodyExceededMaxSize, read_body_with_max_size +from synapse.api.errors import SynapseError +from synapse.http.client import ( + BlacklistingAgentWrapper, + BlacklistingReactorWrapper, + BodyExceededMaxSize, + read_body_with_max_size, +) +from tests.server import FakeTransport, get_clock from tests.unittest import TestCase @@ -119,3 +130,114 @@ class ReadBodyWithMaxSizeTests(TestCase): # The data is never consumed. self.assertEqual(result.getvalue(), b"") + + +class BlacklistingAgentTest(TestCase): + def setUp(self): + self.reactor, self.clock = get_clock() + + self.safe_domain, self.safe_ip = b"safe.test", b"1.2.3.4" + self.unsafe_domain, self.unsafe_ip = b"danger.test", b"5.6.7.8" + self.allowed_domain, self.allowed_ip = b"allowed.test", b"5.1.1.1" + + # Configure the reactor's DNS resolver. + for (domain, ip) in ( + (self.safe_domain, self.safe_ip), + (self.unsafe_domain, self.unsafe_ip), + (self.allowed_domain, self.allowed_ip), + ): + self.reactor.lookups[domain.decode()] = ip.decode() + self.reactor.lookups[ip.decode()] = ip.decode() + + self.ip_whitelist = IPSet([self.allowed_ip.decode()]) + self.ip_blacklist = IPSet(["5.0.0.0/8"]) + + def test_reactor(self): + """Apply the blacklisting reactor and ensure it properly blocks connections to particular domains and IPs.""" + agent = Agent( + BlacklistingReactorWrapper( + self.reactor, + ip_whitelist=self.ip_whitelist, + ip_blacklist=self.ip_blacklist, + ), + ) + + # The unsafe domains and IPs should be rejected. + for domain in (self.unsafe_domain, self.unsafe_ip): + self.failureResultOf( + agent.request(b"GET", b"http://" + domain), DNSLookupError + ) + + # The safe domains IPs should be accepted. + for domain in ( + self.safe_domain, + self.allowed_domain, + self.safe_ip, + self.allowed_ip, + ): + d = agent.request(b"GET", b"http://" + domain) + + # Grab the latest TCP connection. + ( + host, + port, + client_factory, + _timeout, + _bindAddress, + ) = self.reactor.tcpClients[-1] + + # Make the connection and pump data through it. + client = client_factory.buildProtocol(None) + server = AccumulatingProtocol() + server.makeConnection(FakeTransport(client, self.reactor)) + client.makeConnection(FakeTransport(server, self.reactor)) + client.dataReceived( + b"HTTP/1.0 200 OK\r\nContent-Length: 0\r\nContent-Type: text/html\r\n\r\n" + ) + + response = self.successResultOf(d) + self.assertEqual(response.code, 200) + + def test_agent(self): + """Apply the blacklisting agent and ensure it properly blocks connections to particular IPs.""" + agent = BlacklistingAgentWrapper( + Agent(self.reactor), + ip_whitelist=self.ip_whitelist, + ip_blacklist=self.ip_blacklist, + ) + + # The unsafe IPs should be rejected. + self.failureResultOf( + agent.request(b"GET", b"http://" + self.unsafe_ip), SynapseError + ) + + # The safe and unsafe domains and safe IPs should be accepted. + for domain in ( + self.safe_domain, + self.unsafe_domain, + self.allowed_domain, + self.safe_ip, + self.allowed_ip, + ): + d = agent.request(b"GET", b"http://" + domain) + + # Grab the latest TCP connection. + ( + host, + port, + client_factory, + _timeout, + _bindAddress, + ) = self.reactor.tcpClients[-1] + + # Make the connection and pump data through it. + client = client_factory.buildProtocol(None) + server = AccumulatingProtocol() + server.makeConnection(FakeTransport(client, self.reactor)) + client.makeConnection(FakeTransport(server, self.reactor)) + client.dataReceived( + b"HTTP/1.0 200 OK\r\nContent-Length: 0\r\nContent-Type: text/html\r\n\r\n" + ) + + response = self.successResultOf(d) + self.assertEqual(response.code, 200) diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index 20940c8107..67b7913666 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py
@@ -13,9 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Any, Callable, Dict, List, Optional, Tuple - -import attr +from typing import Any, Callable, Dict, List, Optional, Tuple, Type from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime from twisted.internet.protocol import Protocol @@ -158,10 +156,8 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): # Set up client side protocol client_protocol = client_factory.buildProtocol(None) - request_factory = OneShotRequestFactory() - # Set up the server side protocol - channel = _PushHTTPChannel(self.reactor, request_factory, self.site) + channel = _PushHTTPChannel(self.reactor, SynapseRequest, self.site) # Connect client to server and vice versa. client_to_server_transport = FakeTransport( @@ -183,7 +179,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): server_to_client_transport.loseConnection() client_to_server_transport.loseConnection() - return request_factory.request + return channel.request def assert_request_is_get_repl_stream_updates( self, request: SynapseRequest, stream_name: str @@ -237,7 +233,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): if self.hs.config.redis.redis_enabled: # Handle attempts to connect to fake redis server. self.reactor.add_tcp_client_callback( - "localhost", + b"localhost", 6379, self.connect_any_redis_attempts, ) @@ -392,10 +388,8 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): # Set up client side protocol client_protocol = client_factory.buildProtocol(None) - request_factory = OneShotRequestFactory() - # Set up the server side protocol - channel = _PushHTTPChannel(self.reactor, request_factory, self._hs_to_site[hs]) + channel = _PushHTTPChannel(self.reactor, SynapseRequest, self._hs_to_site[hs]) # Connect client to server and vice versa. client_to_server_transport = FakeTransport( @@ -421,7 +415,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): clients = self.reactor.tcpClients while clients: (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) - self.assertEqual(host, "localhost") + self.assertEqual(host, b"localhost") self.assertEqual(port, 6379) client_protocol = client_factory.buildProtocol(None) @@ -453,21 +447,6 @@ class TestReplicationDataHandler(GenericWorkerReplicationHandler): self.received_rdata_rows.append((stream_name, token, r)) -@attr.s() -class OneShotRequestFactory: - """A simple request factory that generates a single `SynapseRequest` and - stores it for future use. Can only be used once. - """ - - request = attr.ib(default=None) - - def __call__(self, *args, **kwargs): - assert self.request is None - - self.request = SynapseRequest(*args, **kwargs) - return self.request - - class _PushHTTPChannel(HTTPChannel): """A HTTPChannel that wraps pull producers to push producers. @@ -479,7 +458,7 @@ class _PushHTTPChannel(HTTPChannel): """ def __init__( - self, reactor: IReactorTime, request_factory: Callable[..., Request], site: Site + self, reactor: IReactorTime, request_factory: Type[Request], site: Site ): super().__init__() self.reactor = reactor @@ -510,6 +489,11 @@ class _PushHTTPChannel(HTTPChannel): request.responseHeaders.setRawHeaders(b"connection", [b"close"]) return False + def requestDone(self, request): + # Store the request for inspection. + self.request = request + super().requestDone(request) + class _PullToPushProducer: """A push producer that wraps a pull producer.""" @@ -597,6 +581,8 @@ class FakeRedisPubSubServer: class FakeRedisPubSubProtocol(Protocol): """A connection from a client talking to the fake Redis server.""" + transport = None # type: Optional[FakeTransport] + def __init__(self, server: FakeRedisPubSubServer): self._server = server self._reader = hiredis.Reader() @@ -641,6 +627,8 @@ class FakeRedisPubSubProtocol(Protocol): def send(self, msg): """Send a message back to the client.""" + assert self.transport is not None + raw = self.encode(msg).encode("utf-8") self.transport.write(raw) diff --git a/tests/replication/test_federation_ack.py b/tests/replication/test_federation_ack.py
index f235f1bd83..0d9e3bb11d 100644 --- a/tests/replication/test_federation_ack.py +++ b/tests/replication/test_federation_ack.py
@@ -17,7 +17,7 @@ import mock from synapse.app.generic_worker import GenericWorkerServer from synapse.replication.tcp.commands import FederationAckCommand -from synapse.replication.tcp.protocol import AbstractConnection +from synapse.replication.tcp.protocol import IReplicationConnection from synapse.replication.tcp.streams.federation import FederationStream from tests.unittest import HomeserverTestCase @@ -51,8 +51,10 @@ class FederationAckTestCase(HomeserverTestCase): """ rch = self.hs.get_tcp_replication() - # wire up the ReplicationCommandHandler to a mock connection - mock_connection = mock.Mock(spec=AbstractConnection) + # wire up the ReplicationCommandHandler to a mock connection, which needs + # to implement IReplicationConnection. (Note that Mock doesn't understand + # interfaces, but casing an interface to a list gives the attributes.) + mock_connection = mock.Mock(spec=list(IReplicationConnection)) rch.new_connection(mock_connection) # tell it it received an RDATA row diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index 20af3285bd..988821b16f 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py
@@ -437,14 +437,16 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): channel = self.make_request("GET", "/_matrix/client/r0/login") self.assertEqual(channel.code, 200, channel.result) - expected_flows = [ - {"type": "m.login.cas"}, - {"type": "m.login.sso"}, - {"type": "m.login.token"}, - {"type": "m.login.password"}, - ] + ADDITIONAL_LOGIN_FLOWS + expected_flow_types = [ + "m.login.cas", + "m.login.sso", + "m.login.token", + "m.login.password", + ] + [f["type"] for f in ADDITIONAL_LOGIN_FLOWS] - self.assertCountEqual(channel.json_body["flows"], expected_flows) + self.assertCountEqual( + [f["type"] for f in channel.json_body["flows"]], expected_flow_types + ) @override_config({"experimental_features": {"msc2858_enabled": True}}) def test_get_msc2858_login_flows(self): @@ -636,22 +638,25 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.code, 400, channel.result) - def test_client_idp_redirect_msc2858_disabled(self): - """If the client tries to pick an IdP but MSC2858 is disabled, return a 400""" - channel = self._make_sso_redirect_request(True, "oidc") - self.assertEqual(channel.code, 400, channel.result) - self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED") - - @override_config({"experimental_features": {"msc2858_enabled": True}}) def test_client_idp_redirect_to_unknown(self): """If the client tries to pick an unknown IdP, return a 404""" - channel = self._make_sso_redirect_request(True, "xxx") + channel = self._make_sso_redirect_request(False, "xxx") self.assertEqual(channel.code, 404, channel.result) self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND") - @override_config({"experimental_features": {"msc2858_enabled": True}}) def test_client_idp_redirect_to_oidc(self): """If the client pick a known IdP, redirect to it""" + channel = self._make_sso_redirect_request(False, "oidc") + self.assertEqual(channel.code, 302, channel.result) + oidc_uri = channel.headers.getRawHeaders("Location")[0] + oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1) + + # it should redirect us to the auth page of the OIDC server + self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT) + + @override_config({"experimental_features": {"msc2858_enabled": True}}) + def test_client_msc2858_redirect_to_oidc(self): + """Test the unstable API""" channel = self._make_sso_redirect_request(True, "oidc") self.assertEqual(channel.code, 302, channel.result) oidc_uri = channel.headers.getRawHeaders("Location")[0] @@ -660,6 +665,12 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): # it should redirect us to the auth page of the OIDC server self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT) + def test_client_idp_redirect_msc2858_disabled(self): + """If the client tries to use the MSC2858 endpoint but MSC2858 is disabled, return a 400""" + channel = self._make_sso_redirect_request(True, "oidc") + self.assertEqual(channel.code, 400, channel.result) + self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED") + def _make_sso_redirect_request( self, unstable_endpoint: bool = False, idp_prov: Optional[str] = None ): diff --git a/tests/server.py b/tests/server.py
index 863f6da738..2287d20076 100644 --- a/tests/server.py +++ b/tests/server.py
@@ -16,6 +16,7 @@ from twisted.internet.interfaces import ( IReactorPluggableNameResolver, IReactorTCP, IResolverSimple, + ITransport, ) from twisted.python.failure import Failure from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock @@ -467,6 +468,7 @@ def get_clock(): return clock, hs_clock +@implementer(ITransport) @attr.s(cmp=False) class FakeTransport: """ diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index 06000f81a6..d597d712d6 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py
@@ -118,8 +118,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): r = self.get_success(self.store.get_rooms_with_many_extremities(5, 1, [room1])) self.assertTrue(r == [room2] or r == [room3]) - @parameterized.expand([(True,), (False,)]) - def test_auth_difference(self, use_chain_cover_index: bool): + def _setup_auth_chain(self, use_chain_cover_index: bool) -> str: room_id = "@ROOM:local" # The silly auth graph we use to test the auth difference algorithm, @@ -165,7 +164,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): "j": 1, } - # Mark the room as not having a cover index + # Mark the room as maybe having a cover index. def store_room(txn): self.store.db_pool.simple_insert_txn( @@ -222,6 +221,77 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): ) ) + return room_id + + @parameterized.expand([(True,), (False,)]) + def test_auth_chain_ids(self, use_chain_cover_index: bool): + room_id = self._setup_auth_chain(use_chain_cover_index) + + # a and b have the same auth chain. + auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["a"])) + self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"]) + auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["b"])) + self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"]) + auth_chain_ids = self.get_success( + self.store.get_auth_chain_ids(room_id, ["a", "b"]) + ) + self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"]) + + auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["c"])) + self.assertCountEqual(auth_chain_ids, ["g", "h", "i", "j", "k"]) + + # d and e have the same auth chain. + auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["d"])) + self.assertCountEqual(auth_chain_ids, ["f", "g", "h", "i", "j", "k"]) + auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["e"])) + self.assertCountEqual(auth_chain_ids, ["f", "g", "h", "i", "j", "k"]) + + auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["f"])) + self.assertCountEqual(auth_chain_ids, ["g", "h", "i", "j", "k"]) + + auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["g"])) + self.assertCountEqual(auth_chain_ids, ["h", "i", "j", "k"]) + + auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["h"])) + self.assertEqual(auth_chain_ids, ["k"]) + + auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["i"])) + self.assertEqual(auth_chain_ids, ["j"]) + + # j and k have no parents. + auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["j"])) + self.assertEqual(auth_chain_ids, []) + auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["k"])) + self.assertEqual(auth_chain_ids, []) + + # More complex input sequences. + auth_chain_ids = self.get_success( + self.store.get_auth_chain_ids(room_id, ["b", "c", "d"]) + ) + self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"]) + + auth_chain_ids = self.get_success( + self.store.get_auth_chain_ids(room_id, ["h", "i"]) + ) + self.assertCountEqual(auth_chain_ids, ["k", "j"]) + + # e gets returned even though include_given is false, but it is in the + # auth chain of b. + auth_chain_ids = self.get_success( + self.store.get_auth_chain_ids(room_id, ["b", "e"]) + ) + self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"]) + + # Test include_given. + auth_chain_ids = self.get_success( + self.store.get_auth_chain_ids(room_id, ["i"], include_given=True) + ) + self.assertCountEqual(auth_chain_ids, ["i", "j"]) + + @parameterized.expand([(True,), (False,)]) + def test_auth_difference(self, use_chain_cover_index: bool): + room_id = self._setup_auth_chain(use_chain_cover_index) + # Now actually test that various combinations give the right result: difference = self.get_success(