diff --git a/tests/federation/test_federation_catch_up.py b/tests/federation/test_federation_catch_up.py
index 1a3ccb263d..6f96cd7940 100644
--- a/tests/federation/test_federation_catch_up.py
+++ b/tests/federation/test_federation_catch_up.py
@@ -7,6 +7,7 @@ from synapse.federation.sender import PerDestinationQueue, TransactionManager
from synapse.federation.units import Edu
from synapse.rest import admin
from synapse.rest.client.v1 import login, room
+from synapse.util.retryutils import NotRetryingDestination
from tests.test_utils import event_injection, make_awaitable
from tests.unittest import FederatingHomeserverTestCase, override_config
@@ -49,7 +50,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
else:
data = json_cb()
self.failed_pdus.extend(data["pdus"])
- raise IOError("Failed to connect because this is a test!")
+ raise NotRetryingDestination(0, 24 * 60 * 60 * 1000, txn.destination)
def get_destination_room(self, room: str, destination: str = "host2") -> dict:
"""
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index bdf3d0a8a2..94b6903594 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -517,6 +517,37 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.assertTrue(requester.shadow_banned)
+ def test_spam_checker_receives_sso_type(self):
+ """Test rejecting registration based on SSO type"""
+
+ class BanBadIdPUser:
+ def check_registration_for_spam(
+ self, email_threepid, username, request_info, auth_provider_id=None
+ ):
+ # Reject any user coming from CAS and whose username contains profanity
+ if auth_provider_id == "cas" and "flimflob" in username:
+ return RegistrationBehaviour.DENY
+ return RegistrationBehaviour.ALLOW
+
+ # Configure a spam checker that denies a certain user on a specific IdP
+ spam_checker = self.hs.get_spam_checker()
+ spam_checker.spam_checkers = [BanBadIdPUser()]
+
+ f = self.get_failure(
+ self.handler.register_user(localpart="bobflimflob", auth_provider_id="cas"),
+ SynapseError,
+ )
+ exception = f.value
+
+ # We return 429 from the spam checker for denied registrations
+ self.assertIsInstance(exception, SynapseError)
+ self.assertEqual(exception.code, 429)
+
+ # Check the same username can register using SAML
+ self.get_success(
+ self.handler.register_user(localpart="bobflimflob", auth_provider_id="saml")
+ )
+
async def get_or_create_user(
self, requester, localpart, displayname, password_hash=None
):
diff --git a/tests/http/test_client.py b/tests/http/test_client.py
index 21ecb81c99..0ce181a51e 100644
--- a/tests/http/test_client.py
+++ b/tests/http/test_client.py
@@ -16,12 +16,23 @@ from io import BytesIO
from mock import Mock
+from netaddr import IPSet
+
+from twisted.internet.error import DNSLookupError
from twisted.python.failure import Failure
-from twisted.web.client import ResponseDone
+from twisted.test.proto_helpers import AccumulatingProtocol
+from twisted.web.client import Agent, ResponseDone
from twisted.web.iweb import UNKNOWN_LENGTH
-from synapse.http.client import BodyExceededMaxSize, read_body_with_max_size
+from synapse.api.errors import SynapseError
+from synapse.http.client import (
+ BlacklistingAgentWrapper,
+ BlacklistingReactorWrapper,
+ BodyExceededMaxSize,
+ read_body_with_max_size,
+)
+from tests.server import FakeTransport, get_clock
from tests.unittest import TestCase
@@ -119,3 +130,114 @@ class ReadBodyWithMaxSizeTests(TestCase):
# The data is never consumed.
self.assertEqual(result.getvalue(), b"")
+
+
+class BlacklistingAgentTest(TestCase):
+ def setUp(self):
+ self.reactor, self.clock = get_clock()
+
+ self.safe_domain, self.safe_ip = b"safe.test", b"1.2.3.4"
+ self.unsafe_domain, self.unsafe_ip = b"danger.test", b"5.6.7.8"
+ self.allowed_domain, self.allowed_ip = b"allowed.test", b"5.1.1.1"
+
+ # Configure the reactor's DNS resolver.
+ for (domain, ip) in (
+ (self.safe_domain, self.safe_ip),
+ (self.unsafe_domain, self.unsafe_ip),
+ (self.allowed_domain, self.allowed_ip),
+ ):
+ self.reactor.lookups[domain.decode()] = ip.decode()
+ self.reactor.lookups[ip.decode()] = ip.decode()
+
+ self.ip_whitelist = IPSet([self.allowed_ip.decode()])
+ self.ip_blacklist = IPSet(["5.0.0.0/8"])
+
+ def test_reactor(self):
+ """Apply the blacklisting reactor and ensure it properly blocks connections to particular domains and IPs."""
+ agent = Agent(
+ BlacklistingReactorWrapper(
+ self.reactor,
+ ip_whitelist=self.ip_whitelist,
+ ip_blacklist=self.ip_blacklist,
+ ),
+ )
+
+ # The unsafe domains and IPs should be rejected.
+ for domain in (self.unsafe_domain, self.unsafe_ip):
+ self.failureResultOf(
+ agent.request(b"GET", b"http://" + domain), DNSLookupError
+ )
+
+ # The safe domains IPs should be accepted.
+ for domain in (
+ self.safe_domain,
+ self.allowed_domain,
+ self.safe_ip,
+ self.allowed_ip,
+ ):
+ d = agent.request(b"GET", b"http://" + domain)
+
+ # Grab the latest TCP connection.
+ (
+ host,
+ port,
+ client_factory,
+ _timeout,
+ _bindAddress,
+ ) = self.reactor.tcpClients[-1]
+
+ # Make the connection and pump data through it.
+ client = client_factory.buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+ client.dataReceived(
+ b"HTTP/1.0 200 OK\r\nContent-Length: 0\r\nContent-Type: text/html\r\n\r\n"
+ )
+
+ response = self.successResultOf(d)
+ self.assertEqual(response.code, 200)
+
+ def test_agent(self):
+ """Apply the blacklisting agent and ensure it properly blocks connections to particular IPs."""
+ agent = BlacklistingAgentWrapper(
+ Agent(self.reactor),
+ ip_whitelist=self.ip_whitelist,
+ ip_blacklist=self.ip_blacklist,
+ )
+
+ # The unsafe IPs should be rejected.
+ self.failureResultOf(
+ agent.request(b"GET", b"http://" + self.unsafe_ip), SynapseError
+ )
+
+ # The safe and unsafe domains and safe IPs should be accepted.
+ for domain in (
+ self.safe_domain,
+ self.unsafe_domain,
+ self.allowed_domain,
+ self.safe_ip,
+ self.allowed_ip,
+ ):
+ d = agent.request(b"GET", b"http://" + domain)
+
+ # Grab the latest TCP connection.
+ (
+ host,
+ port,
+ client_factory,
+ _timeout,
+ _bindAddress,
+ ) = self.reactor.tcpClients[-1]
+
+ # Make the connection and pump data through it.
+ client = client_factory.buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+ client.dataReceived(
+ b"HTTP/1.0 200 OK\r\nContent-Length: 0\r\nContent-Type: text/html\r\n\r\n"
+ )
+
+ response = self.successResultOf(d)
+ self.assertEqual(response.code, 200)
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index 20940c8107..67b7913666 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -13,9 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import Any, Callable, Dict, List, Optional, Tuple
-
-import attr
+from typing import Any, Callable, Dict, List, Optional, Tuple, Type
from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
from twisted.internet.protocol import Protocol
@@ -158,10 +156,8 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
# Set up client side protocol
client_protocol = client_factory.buildProtocol(None)
- request_factory = OneShotRequestFactory()
-
# Set up the server side protocol
- channel = _PushHTTPChannel(self.reactor, request_factory, self.site)
+ channel = _PushHTTPChannel(self.reactor, SynapseRequest, self.site)
# Connect client to server and vice versa.
client_to_server_transport = FakeTransport(
@@ -183,7 +179,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
server_to_client_transport.loseConnection()
client_to_server_transport.loseConnection()
- return request_factory.request
+ return channel.request
def assert_request_is_get_repl_stream_updates(
self, request: SynapseRequest, stream_name: str
@@ -237,7 +233,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
if self.hs.config.redis.redis_enabled:
# Handle attempts to connect to fake redis server.
self.reactor.add_tcp_client_callback(
- "localhost",
+ b"localhost",
6379,
self.connect_any_redis_attempts,
)
@@ -392,10 +388,8 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
# Set up client side protocol
client_protocol = client_factory.buildProtocol(None)
- request_factory = OneShotRequestFactory()
-
# Set up the server side protocol
- channel = _PushHTTPChannel(self.reactor, request_factory, self._hs_to_site[hs])
+ channel = _PushHTTPChannel(self.reactor, SynapseRequest, self._hs_to_site[hs])
# Connect client to server and vice versa.
client_to_server_transport = FakeTransport(
@@ -421,7 +415,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
clients = self.reactor.tcpClients
while clients:
(host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
- self.assertEqual(host, "localhost")
+ self.assertEqual(host, b"localhost")
self.assertEqual(port, 6379)
client_protocol = client_factory.buildProtocol(None)
@@ -453,21 +447,6 @@ class TestReplicationDataHandler(GenericWorkerReplicationHandler):
self.received_rdata_rows.append((stream_name, token, r))
-@attr.s()
-class OneShotRequestFactory:
- """A simple request factory that generates a single `SynapseRequest` and
- stores it for future use. Can only be used once.
- """
-
- request = attr.ib(default=None)
-
- def __call__(self, *args, **kwargs):
- assert self.request is None
-
- self.request = SynapseRequest(*args, **kwargs)
- return self.request
-
-
class _PushHTTPChannel(HTTPChannel):
"""A HTTPChannel that wraps pull producers to push producers.
@@ -479,7 +458,7 @@ class _PushHTTPChannel(HTTPChannel):
"""
def __init__(
- self, reactor: IReactorTime, request_factory: Callable[..., Request], site: Site
+ self, reactor: IReactorTime, request_factory: Type[Request], site: Site
):
super().__init__()
self.reactor = reactor
@@ -510,6 +489,11 @@ class _PushHTTPChannel(HTTPChannel):
request.responseHeaders.setRawHeaders(b"connection", [b"close"])
return False
+ def requestDone(self, request):
+ # Store the request for inspection.
+ self.request = request
+ super().requestDone(request)
+
class _PullToPushProducer:
"""A push producer that wraps a pull producer."""
@@ -597,6 +581,8 @@ class FakeRedisPubSubServer:
class FakeRedisPubSubProtocol(Protocol):
"""A connection from a client talking to the fake Redis server."""
+ transport = None # type: Optional[FakeTransport]
+
def __init__(self, server: FakeRedisPubSubServer):
self._server = server
self._reader = hiredis.Reader()
@@ -641,6 +627,8 @@ class FakeRedisPubSubProtocol(Protocol):
def send(self, msg):
"""Send a message back to the client."""
+ assert self.transport is not None
+
raw = self.encode(msg).encode("utf-8")
self.transport.write(raw)
diff --git a/tests/replication/test_federation_ack.py b/tests/replication/test_federation_ack.py
index f235f1bd83..0d9e3bb11d 100644
--- a/tests/replication/test_federation_ack.py
+++ b/tests/replication/test_federation_ack.py
@@ -17,7 +17,7 @@ import mock
from synapse.app.generic_worker import GenericWorkerServer
from synapse.replication.tcp.commands import FederationAckCommand
-from synapse.replication.tcp.protocol import AbstractConnection
+from synapse.replication.tcp.protocol import IReplicationConnection
from synapse.replication.tcp.streams.federation import FederationStream
from tests.unittest import HomeserverTestCase
@@ -51,8 +51,10 @@ class FederationAckTestCase(HomeserverTestCase):
"""
rch = self.hs.get_tcp_replication()
- # wire up the ReplicationCommandHandler to a mock connection
- mock_connection = mock.Mock(spec=AbstractConnection)
+ # wire up the ReplicationCommandHandler to a mock connection, which needs
+ # to implement IReplicationConnection. (Note that Mock doesn't understand
+ # interfaces, but casing an interface to a list gives the attributes.)
+ mock_connection = mock.Mock(spec=list(IReplicationConnection))
rch.new_connection(mock_connection)
# tell it it received an RDATA row
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index 20af3285bd..988821b16f 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -437,14 +437,16 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
channel = self.make_request("GET", "/_matrix/client/r0/login")
self.assertEqual(channel.code, 200, channel.result)
- expected_flows = [
- {"type": "m.login.cas"},
- {"type": "m.login.sso"},
- {"type": "m.login.token"},
- {"type": "m.login.password"},
- ] + ADDITIONAL_LOGIN_FLOWS
+ expected_flow_types = [
+ "m.login.cas",
+ "m.login.sso",
+ "m.login.token",
+ "m.login.password",
+ ] + [f["type"] for f in ADDITIONAL_LOGIN_FLOWS]
- self.assertCountEqual(channel.json_body["flows"], expected_flows)
+ self.assertCountEqual(
+ [f["type"] for f in channel.json_body["flows"]], expected_flow_types
+ )
@override_config({"experimental_features": {"msc2858_enabled": True}})
def test_get_msc2858_login_flows(self):
@@ -636,22 +638,25 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(channel.code, 400, channel.result)
- def test_client_idp_redirect_msc2858_disabled(self):
- """If the client tries to pick an IdP but MSC2858 is disabled, return a 400"""
- channel = self._make_sso_redirect_request(True, "oidc")
- self.assertEqual(channel.code, 400, channel.result)
- self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED")
-
- @override_config({"experimental_features": {"msc2858_enabled": True}})
def test_client_idp_redirect_to_unknown(self):
"""If the client tries to pick an unknown IdP, return a 404"""
- channel = self._make_sso_redirect_request(True, "xxx")
+ channel = self._make_sso_redirect_request(False, "xxx")
self.assertEqual(channel.code, 404, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND")
- @override_config({"experimental_features": {"msc2858_enabled": True}})
def test_client_idp_redirect_to_oidc(self):
"""If the client pick a known IdP, redirect to it"""
+ channel = self._make_sso_redirect_request(False, "oidc")
+ self.assertEqual(channel.code, 302, channel.result)
+ oidc_uri = channel.headers.getRawHeaders("Location")[0]
+ oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1)
+
+ # it should redirect us to the auth page of the OIDC server
+ self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT)
+
+ @override_config({"experimental_features": {"msc2858_enabled": True}})
+ def test_client_msc2858_redirect_to_oidc(self):
+ """Test the unstable API"""
channel = self._make_sso_redirect_request(True, "oidc")
self.assertEqual(channel.code, 302, channel.result)
oidc_uri = channel.headers.getRawHeaders("Location")[0]
@@ -660,6 +665,12 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
# it should redirect us to the auth page of the OIDC server
self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT)
+ def test_client_idp_redirect_msc2858_disabled(self):
+ """If the client tries to use the MSC2858 endpoint but MSC2858 is disabled, return a 400"""
+ channel = self._make_sso_redirect_request(True, "oidc")
+ self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED")
+
def _make_sso_redirect_request(
self, unstable_endpoint: bool = False, idp_prov: Optional[str] = None
):
diff --git a/tests/server.py b/tests/server.py
index 863f6da738..2287d20076 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -16,6 +16,7 @@ from twisted.internet.interfaces import (
IReactorPluggableNameResolver,
IReactorTCP,
IResolverSimple,
+ ITransport,
)
from twisted.python.failure import Failure
from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock
@@ -467,6 +468,7 @@ def get_clock():
return clock, hs_clock
+@implementer(ITransport)
@attr.s(cmp=False)
class FakeTransport:
"""
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index 06000f81a6..d597d712d6 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -118,8 +118,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
r = self.get_success(self.store.get_rooms_with_many_extremities(5, 1, [room1]))
self.assertTrue(r == [room2] or r == [room3])
- @parameterized.expand([(True,), (False,)])
- def test_auth_difference(self, use_chain_cover_index: bool):
+ def _setup_auth_chain(self, use_chain_cover_index: bool) -> str:
room_id = "@ROOM:local"
# The silly auth graph we use to test the auth difference algorithm,
@@ -165,7 +164,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
"j": 1,
}
- # Mark the room as not having a cover index
+ # Mark the room as maybe having a cover index.
def store_room(txn):
self.store.db_pool.simple_insert_txn(
@@ -222,6 +221,77 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
)
)
+ return room_id
+
+ @parameterized.expand([(True,), (False,)])
+ def test_auth_chain_ids(self, use_chain_cover_index: bool):
+ room_id = self._setup_auth_chain(use_chain_cover_index)
+
+ # a and b have the same auth chain.
+ auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["a"]))
+ self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"])
+ auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["b"]))
+ self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"])
+ auth_chain_ids = self.get_success(
+ self.store.get_auth_chain_ids(room_id, ["a", "b"])
+ )
+ self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"])
+
+ auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["c"]))
+ self.assertCountEqual(auth_chain_ids, ["g", "h", "i", "j", "k"])
+
+ # d and e have the same auth chain.
+ auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["d"]))
+ self.assertCountEqual(auth_chain_ids, ["f", "g", "h", "i", "j", "k"])
+ auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["e"]))
+ self.assertCountEqual(auth_chain_ids, ["f", "g", "h", "i", "j", "k"])
+
+ auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["f"]))
+ self.assertCountEqual(auth_chain_ids, ["g", "h", "i", "j", "k"])
+
+ auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["g"]))
+ self.assertCountEqual(auth_chain_ids, ["h", "i", "j", "k"])
+
+ auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["h"]))
+ self.assertEqual(auth_chain_ids, ["k"])
+
+ auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["i"]))
+ self.assertEqual(auth_chain_ids, ["j"])
+
+ # j and k have no parents.
+ auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["j"]))
+ self.assertEqual(auth_chain_ids, [])
+ auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["k"]))
+ self.assertEqual(auth_chain_ids, [])
+
+ # More complex input sequences.
+ auth_chain_ids = self.get_success(
+ self.store.get_auth_chain_ids(room_id, ["b", "c", "d"])
+ )
+ self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"])
+
+ auth_chain_ids = self.get_success(
+ self.store.get_auth_chain_ids(room_id, ["h", "i"])
+ )
+ self.assertCountEqual(auth_chain_ids, ["k", "j"])
+
+ # e gets returned even though include_given is false, but it is in the
+ # auth chain of b.
+ auth_chain_ids = self.get_success(
+ self.store.get_auth_chain_ids(room_id, ["b", "e"])
+ )
+ self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"])
+
+ # Test include_given.
+ auth_chain_ids = self.get_success(
+ self.store.get_auth_chain_ids(room_id, ["i"], include_given=True)
+ )
+ self.assertCountEqual(auth_chain_ids, ["i", "j"])
+
+ @parameterized.expand([(True,), (False,)])
+ def test_auth_difference(self, use_chain_cover_index: bool):
+ room_id = self._setup_auth_chain(use_chain_cover_index)
+
# Now actually test that various combinations give the right result:
difference = self.get_success(
|