diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py
index bcf1a8010e..279c94a03d 100644
--- a/tests/api/test_filtering.py
+++ b/tests/api/test_filtering.py
@@ -16,8 +16,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from mock import Mock
-
import jsonschema
from twisted.internet import defer
@@ -28,7 +26,7 @@ from synapse.api.filtering import Filter
from synapse.events import make_event_from_dict
from tests import unittest
-from tests.utils import DeferredMockCallable, MockHttpResource, setup_test_homeserver
+from tests.utils import setup_test_homeserver
user_localpart = "test_user"
@@ -42,21 +40,9 @@ def MockEvent(**kwargs):
class FilteringTestCase(unittest.TestCase):
- @defer.inlineCallbacks
def setUp(self):
- self.mock_federation_resource = MockHttpResource()
-
- self.mock_http_client = Mock(spec=[])
- self.mock_http_client.put_json = DeferredMockCallable()
-
- hs = yield setup_test_homeserver(
- self.addCleanup,
- federation_http_client=self.mock_http_client,
- keyring=Mock(),
- )
-
+ hs = setup_test_homeserver(self.addCleanup)
self.filtering = hs.get_filtering()
-
self.datastore = hs.get_datastore()
def test_errors_on_invalid_filters(self):
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index a308c46da9..1d99a45436 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -17,30 +17,15 @@ from urllib.parse import parse_qs, urlparse
from mock import Mock, patch
-import attr
import pymacaroons
-from twisted.python.failure import Failure
-from twisted.web._newclient import ResponseDone
-
from synapse.handlers.oidc_handler import OidcError, OidcMappingProvider
from synapse.handlers.sso import MappingException
from synapse.types import UserID
+from tests.test_utils import FakeResponse
from tests.unittest import HomeserverTestCase, override_config
-
-@attr.s
-class FakeResponse:
- code = attr.ib()
- body = attr.ib()
- phrase = attr.ib()
-
- def deliverBody(self, protocol):
- protocol.dataReceived(self.body)
- protocol.connectionLost(Failure(ResponseDone()))
-
-
# These are a few constants that are used as config parameters in the tests.
ISSUER = "https://issuer/"
CLIENT_ID = "test-client-id"
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index 300a625972..50955ade97 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -44,8 +44,6 @@ class ProfileTestCase(unittest.TestCase):
hs = yield setup_test_homeserver(
self.addCleanup,
- federation_http_client=None,
- resource_for_federation=Mock(),
federation_client=self.mock_federation,
federation_server=Mock(),
federation_registry=self.mock_registry,
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index d979aadcd6..0535592b60 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -580,7 +580,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
# Mock Synapse's http json post method to check for the internal bind call
post_json_get_json = Mock(return_value=make_awaitable(None))
- self.hs.get_simple_http_client().post_json_get_json = post_json_get_json
+ self.hs.get_identity_handler().http_client.post_json_get_json = post_json_get_json
# Retrieve a UIA session ID
channel = self.uia_register(
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 36086ca836..f21de958f1 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -15,18 +15,20 @@
import json
+from typing import Dict
from mock import ANY, Mock, call
from twisted.internet import defer
+from twisted.web.resource import Resource
from synapse.api.errors import AuthError
+from synapse.federation.transport.server import TransportLayerServer
from synapse.types import UserID, create_requester
from tests import unittest
from tests.test_utils import make_awaitable
from tests.unittest import override_config
-from tests.utils import register_federation_servlets
# Some local users to test with
U_APPLE = UserID.from_string("@apple:test")
@@ -53,8 +55,6 @@ def _make_edu_transaction_json(edu_type, content):
class TypingNotificationsTestCase(unittest.HomeserverTestCase):
- servlets = [register_federation_servlets]
-
def make_homeserver(self, reactor, clock):
# we mock out the keyring so as to skip the authentication check on the
# federation API call.
@@ -77,6 +77,11 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
return hs
+ def create_resource_dict(self) -> Dict[str, Resource]:
+ d = super().create_resource_dict()
+ d["/_matrix/federation"] = TransportLayerServer(self.hs)
+ return d
+
def prepare(self, reactor, clock, hs):
mock_notifier = hs.get_notifier()
self.on_new_event = mock_notifier.on_new_event
diff --git a/tests/push/test_http.py b/tests/push/test_http.py
index 870850d1ec..bdfb957a08 100644
--- a/tests/push/test_http.py
+++ b/tests/push/test_http.py
@@ -18,6 +18,7 @@ from twisted.internet.defer import Deferred
import synapse.rest.admin
from synapse.logging.context import make_deferred_yieldable
+from synapse.push import PusherConfigException
from synapse.rest.client.v1 import login, room
from synapse.rest.client.v2_alpha import receipts
@@ -34,6 +35,11 @@ class HTTPPusherTests(HomeserverTestCase):
user_id = True
hijack_auth = False
+ def default_config(self):
+ config = super().default_config()
+ config["start_pushers"] = True
+ return config
+
def make_homeserver(self, reactor, clock):
self.push_attempts = []
@@ -46,14 +52,48 @@ class HTTPPusherTests(HomeserverTestCase):
m.post_json_get_json = post_json_get_json
- config = self.default_config()
- config["start_pushers"] = True
+ hs = self.setup_test_homeserver(proxied_blacklisted_http_client=m)
+
+ return hs
- hs = self.setup_test_homeserver(
- config=config, proxied_blacklisted_http_client=m
+ def test_invalid_configuration(self):
+ """Invalid push configurations should be rejected."""
+ # Register the user who gets notified
+ user_id = self.register_user("user", "pass")
+ access_token = self.login("user", "pass")
+
+ # Register the pusher
+ user_tuple = self.get_success(
+ self.hs.get_datastore().get_user_by_access_token(access_token)
)
+ token_id = user_tuple.token_id
- return hs
+ def test_data(data):
+ self.get_failure(
+ self.hs.get_pusherpool().add_pusher(
+ user_id=user_id,
+ access_token=token_id,
+ kind="http",
+ app_id="m.http",
+ app_display_name="HTTP Push Notifications",
+ device_display_name="pushy push",
+ pushkey="a@example.com",
+ lang=None,
+ data=data,
+ ),
+ PusherConfigException,
+ )
+
+ # Data must be provided with a URL.
+ test_data(None)
+ test_data({})
+ test_data({"url": 1})
+ # A bare domain name isn't accepted.
+ test_data({"url": "example.com"})
+ # A URL without a path isn't accepted.
+ test_data({"url": "http://example.com"})
+ # A url with an incorrect path isn't accepted.
+ test_data({"url": "http://example.com/foo"})
def test_sends_http(self):
"""
@@ -84,7 +124,7 @@ class HTTPPusherTests(HomeserverTestCase):
device_display_name="pushy push",
pushkey="a@example.com",
lang=None,
- data={"url": "example.com"},
+ data={"url": "http://example.com/_matrix/push/v1/notify"},
)
)
@@ -119,7 +159,9 @@ class HTTPPusherTests(HomeserverTestCase):
# One push was attempted to be sent -- it'll be the first message
self.assertEqual(len(self.push_attempts), 1)
- self.assertEqual(self.push_attempts[0][1], "example.com")
+ self.assertEqual(
+ self.push_attempts[0][1], "http://example.com/_matrix/push/v1/notify"
+ )
self.assertEqual(
self.push_attempts[0][2]["notification"]["content"]["body"], "Hi!"
)
@@ -139,7 +181,9 @@ class HTTPPusherTests(HomeserverTestCase):
# Now it'll try and send the second push message, which will be the second one
self.assertEqual(len(self.push_attempts), 2)
- self.assertEqual(self.push_attempts[1][1], "example.com")
+ self.assertEqual(
+ self.push_attempts[1][1], "http://example.com/_matrix/push/v1/notify"
+ )
self.assertEqual(
self.push_attempts[1][2]["notification"]["content"]["body"], "There!"
)
@@ -196,7 +240,7 @@ class HTTPPusherTests(HomeserverTestCase):
device_display_name="pushy push",
pushkey="a@example.com",
lang=None,
- data={"url": "example.com"},
+ data={"url": "http://example.com/_matrix/push/v1/notify"},
)
)
@@ -232,7 +276,9 @@ class HTTPPusherTests(HomeserverTestCase):
# Check our push made it with high priority
self.assertEqual(len(self.push_attempts), 1)
- self.assertEqual(self.push_attempts[0][1], "example.com")
+ self.assertEqual(
+ self.push_attempts[0][1], "http://example.com/_matrix/push/v1/notify"
+ )
self.assertEqual(self.push_attempts[0][2]["notification"]["prio"], "high")
# Add yet another person — we want to make this room not a 1:1
@@ -270,7 +316,9 @@ class HTTPPusherTests(HomeserverTestCase):
# Advance time a bit, so the pusher will register something has happened
self.pump()
self.assertEqual(len(self.push_attempts), 2)
- self.assertEqual(self.push_attempts[1][1], "example.com")
+ self.assertEqual(
+ self.push_attempts[1][1], "http://example.com/_matrix/push/v1/notify"
+ )
self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "high")
def test_sends_high_priority_for_one_to_one_only(self):
@@ -312,7 +360,7 @@ class HTTPPusherTests(HomeserverTestCase):
device_display_name="pushy push",
pushkey="a@example.com",
lang=None,
- data={"url": "example.com"},
+ data={"url": "http://example.com/_matrix/push/v1/notify"},
)
)
@@ -328,7 +376,9 @@ class HTTPPusherTests(HomeserverTestCase):
# Check our push made it with high priority — this is a one-to-one room
self.assertEqual(len(self.push_attempts), 1)
- self.assertEqual(self.push_attempts[0][1], "example.com")
+ self.assertEqual(
+ self.push_attempts[0][1], "http://example.com/_matrix/push/v1/notify"
+ )
self.assertEqual(self.push_attempts[0][2]["notification"]["prio"], "high")
# Yet another user joins
@@ -347,7 +397,9 @@ class HTTPPusherTests(HomeserverTestCase):
# Advance time a bit, so the pusher will register something has happened
self.pump()
self.assertEqual(len(self.push_attempts), 2)
- self.assertEqual(self.push_attempts[1][1], "example.com")
+ self.assertEqual(
+ self.push_attempts[1][1], "http://example.com/_matrix/push/v1/notify"
+ )
# check that this is high-priority
self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "high")
@@ -394,7 +446,7 @@ class HTTPPusherTests(HomeserverTestCase):
device_display_name="pushy push",
pushkey="a@example.com",
lang=None,
- data={"url": "example.com"},
+ data={"url": "http://example.com/_matrix/push/v1/notify"},
)
)
@@ -410,7 +462,9 @@ class HTTPPusherTests(HomeserverTestCase):
# Check our push made it with high priority
self.assertEqual(len(self.push_attempts), 1)
- self.assertEqual(self.push_attempts[0][1], "example.com")
+ self.assertEqual(
+ self.push_attempts[0][1], "http://example.com/_matrix/push/v1/notify"
+ )
self.assertEqual(self.push_attempts[0][2]["notification"]["prio"], "high")
# Send another event, this time with no mention
@@ -419,7 +473,9 @@ class HTTPPusherTests(HomeserverTestCase):
# Advance time a bit, so the pusher will register something has happened
self.pump()
self.assertEqual(len(self.push_attempts), 2)
- self.assertEqual(self.push_attempts[1][1], "example.com")
+ self.assertEqual(
+ self.push_attempts[1][1], "http://example.com/_matrix/push/v1/notify"
+ )
# check that this is high-priority
self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "high")
@@ -467,7 +523,7 @@ class HTTPPusherTests(HomeserverTestCase):
device_display_name="pushy push",
pushkey="a@example.com",
lang=None,
- data={"url": "example.com"},
+ data={"url": "http://example.com/_matrix/push/v1/notify"},
)
)
@@ -487,7 +543,9 @@ class HTTPPusherTests(HomeserverTestCase):
# Check our push made it with high priority
self.assertEqual(len(self.push_attempts), 1)
- self.assertEqual(self.push_attempts[0][1], "example.com")
+ self.assertEqual(
+ self.push_attempts[0][1], "http://example.com/_matrix/push/v1/notify"
+ )
self.assertEqual(self.push_attempts[0][2]["notification"]["prio"], "high")
# Send another event, this time as someone without the power of @room
@@ -498,7 +556,9 @@ class HTTPPusherTests(HomeserverTestCase):
# Advance time a bit, so the pusher will register something has happened
self.pump()
self.assertEqual(len(self.push_attempts), 2)
- self.assertEqual(self.push_attempts[1][1], "example.com")
+ self.assertEqual(
+ self.push_attempts[1][1], "http://example.com/_matrix/push/v1/notify"
+ )
# check that this is high-priority
self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "high")
@@ -572,7 +632,7 @@ class HTTPPusherTests(HomeserverTestCase):
device_display_name="pushy push",
pushkey="a@example.com",
lang=None,
- data={"url": "example.com"},
+ data={"url": "http://example.com/_matrix/push/v1/notify"},
)
)
@@ -591,7 +651,9 @@ class HTTPPusherTests(HomeserverTestCase):
# Check our push made it
self.assertEqual(len(self.push_attempts), 1)
- self.assertEqual(self.push_attempts[0][1], "example.com")
+ self.assertEqual(
+ self.push_attempts[0][1], "http://example.com/_matrix/push/v1/notify"
+ )
# Check that the unread count for the room is 0
#
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index 728de28277..3379189785 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import Any, Callable, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple
import attr
@@ -21,6 +21,7 @@ from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
from twisted.internet.protocol import Protocol
from twisted.internet.task import LoopingCall
from twisted.web.http import HTTPChannel
+from twisted.web.resource import Resource
from synapse.app.generic_worker import (
GenericWorkerReplicationHandler,
@@ -28,7 +29,7 @@ from synapse.app.generic_worker import (
)
from synapse.http.server import JsonResource
from synapse.http.site import SynapseRequest, SynapseSite
-from synapse.replication.http import ReplicationRestResource, streams
+from synapse.replication.http import ReplicationRestResource
from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
@@ -54,10 +55,6 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
if not hiredis:
skip = "Requires hiredis"
- servlets = [
- streams.register_servlets,
- ]
-
def prepare(self, reactor, clock, hs):
# build a replication server
server_factory = ReplicationStreamProtocolFactory(hs)
@@ -88,6 +85,11 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
self._client_transport = None
self._server_transport = None
+ def create_resource_dict(self) -> Dict[str, Resource]:
+ d = super().create_resource_dict()
+ d["/_synapse/replication"] = ReplicationRestResource(self.hs)
+ return d
+
def _get_worker_hs_config(self) -> dict:
config = self.default_config()
config["worker_app"] = "synapse.app.generic_worker"
diff --git a/tests/replication/test_auth.py b/tests/replication/test_auth.py
new file mode 100644
index 0000000000..fe9e4d5f9a
--- /dev/null
+++ b/tests/replication/test_auth.py
@@ -0,0 +1,119 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from typing import Tuple
+
+from synapse.http.site import SynapseRequest
+from synapse.rest.client.v2_alpha import register
+
+from tests.replication._base import BaseMultiWorkerStreamTestCase
+from tests.server import FakeChannel, make_request
+from tests.unittest import override_config
+
+logger = logging.getLogger(__name__)
+
+
+class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase):
+ """Test the authentication of HTTP calls between workers."""
+
+ servlets = [register.register_servlets]
+
+ def make_homeserver(self, reactor, clock):
+ config = self.default_config()
+ # This isn't a real configuration option but is used to provide the main
+ # homeserver and worker homeserver different options.
+ main_replication_secret = config.pop("main_replication_secret", None)
+ if main_replication_secret:
+ config["worker_replication_secret"] = main_replication_secret
+ return self.setup_test_homeserver(config=config)
+
+ def _get_worker_hs_config(self) -> dict:
+ config = self.default_config()
+ config["worker_app"] = "synapse.app.client_reader"
+ config["worker_replication_host"] = "testserv"
+ config["worker_replication_http_port"] = "8765"
+
+ return config
+
+ def _test_register(self) -> Tuple[SynapseRequest, FakeChannel]:
+ """Run the actual test:
+
+ 1. Create a worker homeserver.
+ 2. Start registration by providing a user/password.
+ 3. Complete registration by providing dummy auth (this hits the main synapse).
+ 4. Return the final request.
+
+ """
+ worker_hs = self.make_worker_hs("synapse.app.client_reader")
+ site = self._hs_to_site[worker_hs]
+
+ request_1, channel_1 = make_request(
+ self.reactor,
+ site,
+ "POST",
+ "register",
+ {"username": "user", "type": "m.login.password", "password": "bar"},
+ ) # type: SynapseRequest, FakeChannel
+ self.assertEqual(request_1.code, 401)
+
+ # Grab the session
+ session = channel_1.json_body["session"]
+
+ # also complete the dummy auth
+ return make_request(
+ self.reactor,
+ site,
+ "POST",
+ "register",
+ {"auth": {"session": session, "type": "m.login.dummy"}},
+ )
+
+ def test_no_auth(self):
+ """With no authentication the request should finish.
+ """
+ request, channel = self._test_register()
+ self.assertEqual(request.code, 200)
+
+ # We're given a registered user.
+ self.assertEqual(channel.json_body["user_id"], "@user:test")
+
+ @override_config({"main_replication_secret": "my-secret"})
+ def test_missing_auth(self):
+ """If the main process expects a secret that is not provided, an error results.
+ """
+ request, channel = self._test_register()
+ self.assertEqual(request.code, 500)
+
+ @override_config(
+ {
+ "main_replication_secret": "my-secret",
+ "worker_replication_secret": "wrong-secret",
+ }
+ )
+ def test_unauthorized(self):
+ """If the main process receives the wrong secret, an error results.
+ """
+ request, channel = self._test_register()
+ self.assertEqual(request.code, 500)
+
+ @override_config({"worker_replication_secret": "my-secret"})
+ def test_authorized(self):
+ """The request should finish when the worker provides the authentication header.
+ """
+ request, channel = self._test_register()
+ self.assertEqual(request.code, 200)
+
+ # We're given a registered user.
+ self.assertEqual(channel.json_body["user_id"], "@user:test")
diff --git a/tests/replication/test_client_reader_shard.py b/tests/replication/test_client_reader_shard.py
index 96801db473..fdaad3d8ad 100644
--- a/tests/replication/test_client_reader_shard.py
+++ b/tests/replication/test_client_reader_shard.py
@@ -14,27 +14,20 @@
# limitations under the License.
import logging
-from synapse.api.constants import LoginType
from synapse.http.site import SynapseRequest
from synapse.rest.client.v2_alpha import register
from tests.replication._base import BaseMultiWorkerStreamTestCase
-from tests.rest.client.v2_alpha.test_auth import DummyRecaptchaChecker
from tests.server import FakeChannel, make_request
logger = logging.getLogger(__name__)
class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
- """Base class for tests of the replication streams"""
+ """Test using one or more client readers for registration."""
servlets = [register.register_servlets]
- def prepare(self, reactor, clock, hs):
- self.recaptcha_checker = DummyRecaptchaChecker(hs)
- auth_handler = hs.get_auth_handler()
- auth_handler.checkers[LoginType.RECAPTCHA] = self.recaptcha_checker
-
def _get_worker_hs_config(self) -> dict:
config = self.default_config()
config["worker_app"] = "synapse.app.client_reader"
diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py
index f894bcd6e7..800ad94a04 100644
--- a/tests/replication/test_pusher_shard.py
+++ b/tests/replication/test_pusher_shard.py
@@ -67,7 +67,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
device_display_name="pushy push",
pushkey="a@example.com",
lang=None,
- data={"url": "https://push.example.com/push"},
+ data={"url": "https://push.example.com/_matrix/push/v1/notify"},
)
)
@@ -109,7 +109,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
http_client_mock.post_json_get_json.assert_called_once()
self.assertEqual(
http_client_mock.post_json_get_json.call_args[0][0],
- "https://push.example.com/push",
+ "https://push.example.com/_matrix/push/v1/notify",
)
self.assertEqual(
event_id,
@@ -161,7 +161,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
http_client_mock2.post_json_get_json.assert_not_called()
self.assertEqual(
http_client_mock1.post_json_get_json.call_args[0][0],
- "https://push.example.com/push",
+ "https://push.example.com/_matrix/push/v1/notify",
)
self.assertEqual(
event_id,
@@ -183,7 +183,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
http_client_mock2.post_json_get_json.assert_called_once()
self.assertEqual(
http_client_mock2.post_json_get_json.call_args[0][0],
- "https://push.example.com/push",
+ "https://push.example.com/_matrix/push/v1/notify",
)
self.assertEqual(
event_id,
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 54d46f4bd3..ba1438cdc7 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -561,7 +561,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
"admin": True,
"displayname": "Bob's name",
"threepids": [{"medium": "email", "address": "bob@bob.bob"}],
- "avatar_url": None,
+ "avatar_url": "mxc://fibble/wibble",
}
)
@@ -578,6 +578,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
self.assertEqual(True, channel.json_body["admin"])
+ self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
# Get user
request, channel = self.make_request(
@@ -592,6 +593,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(True, channel.json_body["admin"])
self.assertEqual(False, channel.json_body["is_guest"])
self.assertEqual(False, channel.json_body["deactivated"])
+ self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
def test_create_user(self):
"""
@@ -606,6 +608,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
"admin": False,
"displayname": "Bob's name",
"threepids": [{"medium": "email", "address": "bob@bob.bob"}],
+ "avatar_url": "mxc://fibble/wibble",
}
)
@@ -622,6 +625,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
self.assertEqual(False, channel.json_body["admin"])
+ self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
# Get user
request, channel = self.make_request(
@@ -636,6 +640,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(False, channel.json_body["admin"])
self.assertEqual(False, channel.json_body["is_guest"])
self.assertEqual(False, channel.json_body["deactivated"])
+ self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
@override_config(
{"limit_usage_by_mau": True, "max_mau_value": 2, "mau_trial_days": 0}
@@ -1256,7 +1261,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
device_display_name="pushy push",
pushkey="a@example.com",
lang=None,
- data={"url": "example.com"},
+ data={"url": "https://example.com/_matrix/push/v1/notify"},
)
)
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index 737c38c396..5a18af8d34 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -2,7 +2,7 @@
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2017 Vector Creations Ltd
# Copyright 2018-2019 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019-2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,17 +17,23 @@
# limitations under the License.
import json
+import re
import time
+import urllib.parse
from typing import Any, Dict, Optional
+from mock import patch
+
import attr
from twisted.web.resource import Resource
from twisted.web.server import Site
from synapse.api.constants import Membership
+from synapse.types import JsonDict
from tests.server import FakeSite, make_request
+from tests.test_utils import FakeResponse
@attr.s
@@ -344,3 +350,111 @@ class RestHelper:
)
return channel.json_body
+
+ def login_via_oidc(self, remote_user_id: str) -> JsonDict:
+ """Log in (as a new user) via OIDC
+
+ Returns the result of the final token login.
+
+ Requires that "oidc_config" in the homeserver config be set appropriately
+ (TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a
+ "public_base_url".
+
+ Also requires the login servlet and the OIDC callback resource to be mounted at
+ the normal places.
+ """
+ client_redirect_url = "https://x"
+
+ # first hit the redirect url (which will issue a cookie and state)
+ _, channel = make_request(
+ self.hs.get_reactor(),
+ self.site,
+ "GET",
+ "/login/sso/redirect?redirectUrl=" + client_redirect_url,
+ )
+ # that will redirect to the OIDC IdP, but we skip that and go straight
+ # back to synapse's OIDC callback resource. However, we do need the "state"
+ # param that synapse passes to the IdP via query params, and the cookie that
+ # synapse passes to the client.
+ assert channel.code == 302
+ oauth_uri = channel.headers.getRawHeaders("Location")[0]
+ params = urllib.parse.parse_qs(urllib.parse.urlparse(oauth_uri).query)
+ redirect_uri = "%s?%s" % (
+ urllib.parse.urlparse(params["redirect_uri"][0]).path,
+ urllib.parse.urlencode({"state": params["state"][0], "code": "TEST_CODE"}),
+ )
+ cookies = {}
+ for h in channel.headers.getRawHeaders("Set-Cookie"):
+ parts = h.split(";")
+ k, v = parts[0].split("=", maxsplit=1)
+ cookies[k] = v
+
+ # before we hit the callback uri, stub out some methods in the http client so
+ # that we don't have to handle full HTTPS requests.
+
+ # (expected url, json response) pairs, in the order we expect them.
+ expected_requests = [
+ # first we get a hit to the token endpoint, which we tell to return
+ # a dummy OIDC access token
+ ("https://issuer.test/token", {"access_token": "TEST"}),
+ # and then one to the user_info endpoint, which returns our remote user id.
+ ("https://issuer.test/userinfo", {"sub": remote_user_id}),
+ ]
+
+ async def mock_req(method: str, uri: str, data=None, headers=None):
+ (expected_uri, resp_obj) = expected_requests.pop(0)
+ assert uri == expected_uri
+ resp = FakeResponse(
+ code=200, phrase=b"OK", body=json.dumps(resp_obj).encode("utf-8"),
+ )
+ return resp
+
+ with patch.object(self.hs.get_proxied_http_client(), "request", mock_req):
+ # now hit the callback URI with the right params and a made-up code
+ _, channel = make_request(
+ self.hs.get_reactor(),
+ self.site,
+ "GET",
+ redirect_uri,
+ custom_headers=[
+ ("Cookie", "%s=%s" % (k, v)) for (k, v) in cookies.items()
+ ],
+ )
+
+ # expect a confirmation page
+ assert channel.code == 200
+
+ # fish the matrix login token out of the body of the confirmation page
+ m = re.search(
+ 'a href="%s.*loginToken=([^"]*)"' % (client_redirect_url,),
+ channel.result["body"].decode("utf-8"),
+ )
+ assert m
+ login_token = m.group(1)
+
+ # finally, submit the matrix login token to the login API, which gives us our
+ # matrix access token and device id.
+ _, channel = make_request(
+ self.hs.get_reactor(),
+ self.site,
+ "POST",
+ "/login",
+ content={"type": "m.login.token", "token": login_token},
+ )
+ assert channel.code == 200
+ return channel.json_body
+
+
+# an 'oidc_config' suitable for login_with_oidc.
+TEST_OIDC_CONFIG = {
+ "enabled": True,
+ "discover": False,
+ "issuer": "https://issuer.test",
+ "client_id": "test-client-id",
+ "client_secret": "test-client-secret",
+ "scopes": ["profile"],
+ "authorization_endpoint": "https://z",
+ "token_endpoint": "https://issuer.test/token",
+ "userinfo_endpoint": "https://issuer.test/userinfo",
+ "user_mapping_provider": {"config": {"localpart_template": "{{ user.sub }}"}},
+}
diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py
index 77246e478f..ac67a9de29 100644
--- a/tests/rest/client/v2_alpha/test_auth.py
+++ b/tests/rest/client/v2_alpha/test_auth.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
from typing import List, Union
from twisted.internet.defer import succeed
@@ -22,9 +23,11 @@ from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
from synapse.http.site import SynapseRequest
from synapse.rest.client.v1 import login
from synapse.rest.client.v2_alpha import auth, devices, register
-from synapse.types import JsonDict
+from synapse.rest.oidc import OIDCResource
+from synapse.types import JsonDict, UserID
from tests import unittest
+from tests.rest.client.v1.utils import TEST_OIDC_CONFIG
from tests.server import FakeChannel
@@ -156,27 +159,45 @@ class UIAuthTests(unittest.HomeserverTestCase):
register.register_servlets,
]
+ def default_config(self):
+ config = super().default_config()
+
+ # we enable OIDC as a way of testing SSO flows
+ oidc_config = {}
+ oidc_config.update(TEST_OIDC_CONFIG)
+ oidc_config["allow_existing_users"] = True
+
+ config["oidc_config"] = oidc_config
+ config["public_baseurl"] = "https://synapse.test"
+ return config
+
+ def create_resource_dict(self):
+ resource_dict = super().create_resource_dict()
+ # mount the OIDC resource at /_synapse/oidc
+ resource_dict["/_synapse/oidc"] = OIDCResource(self.hs)
+ return resource_dict
+
def prepare(self, reactor, clock, hs):
self.user_pass = "pass"
self.user = self.register_user("test", self.user_pass)
self.user_tok = self.login("test", self.user_pass)
- def get_device_ids(self) -> List[str]:
+ def get_device_ids(self, access_token: str) -> List[str]:
# Get the list of devices so one can be deleted.
- request, channel = self.make_request(
- "GET", "devices", access_token=self.user_tok,
- ) # type: SynapseRequest, FakeChannel
-
- # Get the ID of the device.
- self.assertEqual(request.code, 200)
+ _, channel = self.make_request("GET", "devices", access_token=access_token,)
+ self.assertEqual(channel.code, 200)
return [d["device_id"] for d in channel.json_body["devices"]]
def delete_device(
- self, device: str, expected_response: int, body: Union[bytes, JsonDict] = b""
+ self,
+ access_token: str,
+ device: str,
+ expected_response: int,
+ body: Union[bytes, JsonDict] = b"",
) -> FakeChannel:
"""Delete an individual device."""
request, channel = self.make_request(
- "DELETE", "devices/" + device, body, access_token=self.user_tok
+ "DELETE", "devices/" + device, body, access_token=access_token,
) # type: SynapseRequest, FakeChannel
# Ensure the response is sane.
@@ -201,11 +222,11 @@ class UIAuthTests(unittest.HomeserverTestCase):
"""
Test user interactive authentication outside of registration.
"""
- device_id = self.get_device_ids()[0]
+ device_id = self.get_device_ids(self.user_tok)[0]
# Attempt to delete this device.
# Returns a 401 as per the spec
- channel = self.delete_device(device_id, 401)
+ channel = self.delete_device(self.user_tok, device_id, 401)
# Grab the session
session = channel.json_body["session"]
@@ -214,6 +235,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
# Make another request providing the UI auth flow.
self.delete_device(
+ self.user_tok,
device_id,
200,
{
@@ -233,12 +255,13 @@ class UIAuthTests(unittest.HomeserverTestCase):
UIA - check that still works.
"""
- device_id = self.get_device_ids()[0]
- channel = self.delete_device(device_id, 401)
+ device_id = self.get_device_ids(self.user_tok)[0]
+ channel = self.delete_device(self.user_tok, device_id, 401)
session = channel.json_body["session"]
# Make another request providing the UI auth flow.
self.delete_device(
+ self.user_tok,
device_id,
200,
{
@@ -264,7 +287,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
# Create a second login.
self.login("test", self.user_pass)
- device_ids = self.get_device_ids()
+ device_ids = self.get_device_ids(self.user_tok)
self.assertEqual(len(device_ids), 2)
# Attempt to delete the first device.
@@ -298,12 +321,12 @@ class UIAuthTests(unittest.HomeserverTestCase):
# Create a second login.
self.login("test", self.user_pass)
- device_ids = self.get_device_ids()
+ device_ids = self.get_device_ids(self.user_tok)
self.assertEqual(len(device_ids), 2)
# Attempt to delete the first device.
# Returns a 401 as per the spec
- channel = self.delete_device(device_ids[0], 401)
+ channel = self.delete_device(self.user_tok, device_ids[0], 401)
# Grab the session
session = channel.json_body["session"]
@@ -313,6 +336,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
# Make another request providing the UI auth flow, but try to delete the
# second device. This results in an error.
self.delete_device(
+ self.user_tok,
device_ids[1],
403,
{
@@ -324,3 +348,39 @@ class UIAuthTests(unittest.HomeserverTestCase):
},
},
)
+
+ def test_does_not_offer_password_for_sso_user(self):
+ login_resp = self.helper.login_via_oidc("username")
+ user_tok = login_resp["access_token"]
+ device_id = login_resp["device_id"]
+
+ # now call the device deletion API: we should get the option to auth with SSO
+ # and not password.
+ channel = self.delete_device(user_tok, device_id, 401)
+
+ flows = channel.json_body["flows"]
+ self.assertEqual(flows, [{"stages": ["m.login.sso"]}])
+
+ def test_does_not_offer_sso_for_password_user(self):
+ # now call the device deletion API: we should get the option to auth with SSO
+ # and not password.
+ device_ids = self.get_device_ids(self.user_tok)
+ channel = self.delete_device(self.user_tok, device_ids[0], 401)
+
+ flows = channel.json_body["flows"]
+ self.assertEqual(flows, [{"stages": ["m.login.password"]}])
+
+ def test_offers_both_flows_for_upgraded_user(self):
+ """A user that had a password and then logged in with SSO should get both flows
+ """
+ login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart)
+ self.assertEqual(login_resp["user_id"], self.user)
+
+ device_ids = self.get_device_ids(self.user_tok)
+ channel = self.delete_device(self.user_tok, device_ids[0], 401)
+
+ flows = channel.json_body["flows"]
+ # we have no particular expectations of ordering here
+ self.assertIn({"stages": ["m.login.password"]}, flows)
+ self.assertIn({"stages": ["m.login.sso"]}, flows)
+ self.assertEqual(len(flows), 2)
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index 610b263577..d6642b34eb 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -120,6 +120,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"403", channel.result)
self.assertEquals(channel.json_body["error"], "Registration has been disabled")
+ self.assertEquals(channel.json_body["errcode"], "M_FORBIDDEN")
def test_POST_guest_registration(self):
self.hs.config.macaroon_secret_key = "test"
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index 4c749f1a61..6f0677d335 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -362,3 +362,16 @@ class MediaRepoTests(unittest.HomeserverTestCase):
"error": "Not found [b'example.com', b'12345']",
},
)
+
+ def test_x_robots_tag_header(self):
+ """
+ Tests that the `X-Robots-Tag` header is present, which informs web crawlers
+ to not index, archive, or follow links in media.
+ """
+ channel = self._req(b"inline; filename=out" + self.test_image.extension)
+
+ headers = channel.headers
+ self.assertEqual(
+ headers.getRawHeaders(b"X-Robots-Tag"),
+ [b"noindex, nofollow, noarchive, noimageindex"],
+ )
diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py
index ccdc8c2ecf..529b6bcded 100644
--- a/tests/rest/media/v1/test_url_preview.py
+++ b/tests/rest/media/v1/test_url_preview.py
@@ -18,41 +18,15 @@ import re
from mock import patch
-import attr
-
from twisted.internet._resolver import HostResolution
from twisted.internet.address import IPv4Address, IPv6Address
from twisted.internet.error import DNSLookupError
-from twisted.python.failure import Failure
from twisted.test.proto_helpers import AccumulatingProtocol
-from twisted.web._newclient import ResponseDone
from tests import unittest
from tests.server import FakeTransport
-@attr.s
-class FakeResponse:
- version = attr.ib()
- code = attr.ib()
- phrase = attr.ib()
- headers = attr.ib()
- body = attr.ib()
- absoluteURI = attr.ib()
-
- @property
- def request(self):
- @attr.s
- class FakeTransport:
- absoluteURI = self.absoluteURI
-
- return FakeTransport()
-
- def deliverBody(self, protocol):
- protocol.dataReceived(self.body)
- protocol.connectionLost(Failure(ResponseDone()))
-
-
class URLPreviewTests(unittest.HomeserverTestCase):
hijack_auth = True
diff --git a/tests/server.py b/tests/server.py
index a51ad0c14e..4faf32e335 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -216,8 +216,9 @@ def make_request(
and not path.startswith(b"/_matrix")
and not path.startswith(b"/_synapse")
):
+ if path.startswith(b"/"):
+ path = path[1:]
path = b"/_matrix/client/r0/" + path
- path = path.replace(b"//", b"/")
if not path.startswith(b"/"):
path = b"/" + path
@@ -258,6 +259,7 @@ def make_request(
for k, v in custom_headers:
req.requestHeaders.addRawHeader(k, v)
+ req.parseCookies()
req.requestReceived(method, path, b"1.1")
if await_result:
diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py
index ad9bbef9d2..09f4f32a02 100644
--- a/tests/state/test_v2.py
+++ b/tests/state/test_v2.py
@@ -24,7 +24,11 @@ from synapse.api.constants import EventTypes, JoinRules, Membership
from synapse.api.room_versions import RoomVersions
from synapse.event_auth import auth_types_for_event
from synapse.events import make_event_from_dict
-from synapse.state.v2 import lexicographical_topological_sort, resolve_events_with_store
+from synapse.state.v2 import (
+ _get_auth_chain_difference,
+ lexicographical_topological_sort,
+ resolve_events_with_store,
+)
from synapse.types import EventID
from tests import unittest
@@ -587,6 +591,134 @@ class SimpleParamStateTestCase(unittest.TestCase):
self.assert_dict(self.expected_combined_state, state)
+class AuthChainDifferenceTestCase(unittest.TestCase):
+ """We test that `_get_auth_chain_difference` correctly handles unpersisted
+ events.
+ """
+
+ def test_simple(self):
+ # Test getting the auth difference for a simple chain with a single
+ # unpersisted event:
+ #
+ # Unpersisted | Persisted
+ # |
+ # C -|-> B -> A
+
+ a = FakeEvent(
+ id="A", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+ ).to_event([], [])
+
+ b = FakeEvent(
+ id="B", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+ ).to_event([a.event_id], [])
+
+ c = FakeEvent(
+ id="C", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+ ).to_event([b.event_id], [])
+
+ persisted_events = {a.event_id: a, b.event_id: b}
+ unpersited_events = {c.event_id: c}
+
+ state_sets = [{"a": a.event_id, "b": b.event_id}, {"c": c.event_id}]
+
+ store = TestStateResolutionStore(persisted_events)
+
+ diff_d = _get_auth_chain_difference(
+ ROOM_ID, state_sets, unpersited_events, store
+ )
+ difference = self.successResultOf(defer.ensureDeferred(diff_d))
+
+ self.assertEqual(difference, {c.event_id})
+
+ def test_multiple_unpersisted_chain(self):
+ # Test getting the auth difference for a simple chain with multiple
+ # unpersisted events:
+ #
+ # Unpersisted | Persisted
+ # |
+ # D -> C -|-> B -> A
+
+ a = FakeEvent(
+ id="A", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+ ).to_event([], [])
+
+ b = FakeEvent(
+ id="B", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+ ).to_event([a.event_id], [])
+
+ c = FakeEvent(
+ id="C", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+ ).to_event([b.event_id], [])
+
+ d = FakeEvent(
+ id="D", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+ ).to_event([c.event_id], [])
+
+ persisted_events = {a.event_id: a, b.event_id: b}
+ unpersited_events = {c.event_id: c, d.event_id: d}
+
+ state_sets = [
+ {"a": a.event_id, "b": b.event_id},
+ {"c": c.event_id, "d": d.event_id},
+ ]
+
+ store = TestStateResolutionStore(persisted_events)
+
+ diff_d = _get_auth_chain_difference(
+ ROOM_ID, state_sets, unpersited_events, store
+ )
+ difference = self.successResultOf(defer.ensureDeferred(diff_d))
+
+ self.assertEqual(difference, {d.event_id, c.event_id})
+
+ def test_unpersisted_events_different_sets(self):
+ # Test getting the auth difference for with multiple unpersisted events
+ # in different branches:
+ #
+ # Unpersisted | Persisted
+ # |
+ # D --> C -|-> B -> A
+ # E ----^ -|---^
+ # |
+
+ a = FakeEvent(
+ id="A", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+ ).to_event([], [])
+
+ b = FakeEvent(
+ id="B", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+ ).to_event([a.event_id], [])
+
+ c = FakeEvent(
+ id="C", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+ ).to_event([b.event_id], [])
+
+ d = FakeEvent(
+ id="D", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+ ).to_event([c.event_id], [])
+
+ e = FakeEvent(
+ id="E", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+ ).to_event([c.event_id, b.event_id], [])
+
+ persisted_events = {a.event_id: a, b.event_id: b}
+ unpersited_events = {c.event_id: c, d.event_id: d, e.event_id: e}
+
+ state_sets = [
+ {"a": a.event_id, "b": b.event_id, "e": e.event_id},
+ {"c": c.event_id, "d": d.event_id},
+ ]
+
+ store = TestStateResolutionStore(persisted_events)
+
+ diff_d = _get_auth_chain_difference(
+ ROOM_ID, state_sets, unpersited_events, store
+ )
+ difference = self.successResultOf(defer.ensureDeferred(diff_d))
+
+ self.assertEqual(difference, {d.event_id, e.event_id})
+
+
def pairwise(iterable):
"s -> (s0,s1), (s1,s2), (s2, s3), ..."
a, b = itertools.tee(iterable)
@@ -647,7 +779,7 @@ class TestStateResolutionStore:
return list(result)
- def get_auth_chain_difference(self, auth_sets):
+ def get_auth_chain_difference(self, room_id, auth_sets):
chains = [frozenset(self._get_auth_chain(a)) for a in auth_sets]
common = set(chains[0]).intersection(*chains[1:])
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index d4c3b867e3..482506d731 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -202,34 +202,41 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
# Now actually test that various combinations give the right result:
difference = self.get_success(
- self.store.get_auth_chain_difference([{"a"}, {"b"}])
+ self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}])
)
self.assertSetEqual(difference, {"a", "b"})
difference = self.get_success(
- self.store.get_auth_chain_difference([{"a"}, {"b"}, {"c"}])
+ self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"c"}])
)
self.assertSetEqual(difference, {"a", "b", "c", "e", "f"})
difference = self.get_success(
- self.store.get_auth_chain_difference([{"a", "c"}, {"b"}])
+ self.store.get_auth_chain_difference(room_id, [{"a", "c"}, {"b"}])
)
self.assertSetEqual(difference, {"a", "b", "c"})
difference = self.get_success(
- self.store.get_auth_chain_difference([{"a"}, {"b"}, {"d"}])
+ self.store.get_auth_chain_difference(room_id, [{"a", "c"}, {"b", "c"}])
+ )
+ self.assertSetEqual(difference, {"a", "b"})
+
+ difference = self.get_success(
+ self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"d"}])
)
self.assertSetEqual(difference, {"a", "b", "d", "e"})
difference = self.get_success(
- self.store.get_auth_chain_difference([{"a"}, {"b"}, {"c"}, {"d"}])
+ self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"c"}, {"d"}])
)
self.assertSetEqual(difference, {"a", "b", "c", "d", "e", "f"})
difference = self.get_success(
- self.store.get_auth_chain_difference([{"a"}, {"b"}, {"e"}])
+ self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"e"}])
)
self.assertSetEqual(difference, {"a", "b"})
- difference = self.get_success(self.store.get_auth_chain_difference([{"a"}]))
+ difference = self.get_success(
+ self.store.get_auth_chain_difference(room_id, [{"a"}])
+ )
self.assertSetEqual(difference, set())
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index fd0add5db3..a6303bf0ee 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -14,9 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
-from mock import Mock
-
from canonicaljson import json
from twisted.internet import defer
@@ -30,12 +27,10 @@ from tests.utils import create_room
class RedactionTestCase(unittest.HomeserverTestCase):
- def make_homeserver(self, reactor, clock):
- config = self.default_config()
+ def default_config(self):
+ config = super().default_config()
config["redaction_retention_period"] = "30d"
- return self.setup_test_homeserver(
- resource_for_federation=Mock(), federation_http_client=None, config=config
- )
+ return config
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index 5ba1db2332..d2aed66f6d 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -14,8 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from unittest.mock import Mock
-
from synapse.api.constants import Membership
from synapse.rest.admin import register_servlets_for_client_rest_resource
from synapse.rest.client.v1 import login, room
@@ -34,12 +32,6 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
room.register_servlets,
]
- def make_homeserver(self, reactor, clock):
- hs = self.setup_test_homeserver(
- resource_for_federation=Mock(), federation_http_client=None
- )
- return hs
-
def prepare(self, reactor, clock, hs: TestHomeServer):
# We can't test the RoomMemberStore on its own without the other event
diff --git a/tests/test_preview.py b/tests/test_preview.py
index 7f67ee9e1f..a883d707df 100644
--- a/tests/test_preview.py
+++ b/tests/test_preview.py
@@ -56,7 +56,7 @@ class PreviewTestCase(unittest.TestCase):
desc = summarize_paragraphs(example_paras, min_size=200, max_size=500)
- self.assertEquals(
+ self.assertEqual(
desc,
"Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:"
" Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in"
@@ -69,7 +69,7 @@ class PreviewTestCase(unittest.TestCase):
desc = summarize_paragraphs(example_paras[1:], min_size=200, max_size=500)
- self.assertEquals(
+ self.assertEqual(
desc,
"Tromsø lies in Northern Norway. The municipality has a population of"
" (2015) 72,066, but with an annual influx of students it has over 75,000"
@@ -96,7 +96,7 @@ class PreviewTestCase(unittest.TestCase):
desc = summarize_paragraphs(example_paras, min_size=200, max_size=500)
- self.assertEquals(
+ self.assertEqual(
desc,
"Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:"
" Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in"
@@ -122,7 +122,7 @@ class PreviewTestCase(unittest.TestCase):
]
desc = summarize_paragraphs(example_paras, min_size=200, max_size=500)
- self.assertEquals(
+ self.assertEqual(
desc,
"Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:"
" Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in"
@@ -149,7 +149,7 @@ class PreviewUrlTestCase(unittest.TestCase):
og = decode_and_calc_og(html, "http://example.com/test.html")
- self.assertEquals(og, {"og:title": "Foo", "og:description": "Some text."})
+ self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
def test_comment(self):
html = """
@@ -164,7 +164,7 @@ class PreviewUrlTestCase(unittest.TestCase):
og = decode_and_calc_og(html, "http://example.com/test.html")
- self.assertEquals(og, {"og:title": "Foo", "og:description": "Some text."})
+ self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
def test_comment2(self):
html = """
@@ -182,7 +182,7 @@ class PreviewUrlTestCase(unittest.TestCase):
og = decode_and_calc_og(html, "http://example.com/test.html")
- self.assertEquals(
+ self.assertEqual(
og,
{
"og:title": "Foo",
@@ -203,7 +203,7 @@ class PreviewUrlTestCase(unittest.TestCase):
og = decode_and_calc_og(html, "http://example.com/test.html")
- self.assertEquals(og, {"og:title": "Foo", "og:description": "Some text."})
+ self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
def test_missing_title(self):
html = """
@@ -216,7 +216,7 @@ class PreviewUrlTestCase(unittest.TestCase):
og = decode_and_calc_og(html, "http://example.com/test.html")
- self.assertEquals(og, {"og:title": None, "og:description": "Some text."})
+ self.assertEqual(og, {"og:title": None, "og:description": "Some text."})
def test_h1_as_title(self):
html = """
@@ -230,7 +230,7 @@ class PreviewUrlTestCase(unittest.TestCase):
og = decode_and_calc_og(html, "http://example.com/test.html")
- self.assertEquals(og, {"og:title": "Title", "og:description": "Some text."})
+ self.assertEqual(og, {"og:title": "Title", "og:description": "Some text."})
def test_missing_title_and_broken_h1(self):
html = """
@@ -244,4 +244,9 @@ class PreviewUrlTestCase(unittest.TestCase):
og = decode_and_calc_og(html, "http://example.com/test.html")
- self.assertEquals(og, {"og:title": None, "og:description": "Some text."})
+ self.assertEqual(og, {"og:title": None, "og:description": "Some text."})
+
+ def test_empty(self):
+ html = ""
+ og = decode_and_calc_og(html, "http://example.com/test.html")
+ self.assertEqual(og, {})
diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py
index d232b72264..6873d45eb6 100644
--- a/tests/test_utils/__init__.py
+++ b/tests/test_utils/__init__.py
@@ -22,6 +22,11 @@ import warnings
from asyncio import Future
from typing import Any, Awaitable, Callable, TypeVar
+import attr
+
+from twisted.python.failure import Failure
+from twisted.web.client import ResponseDone
+
TV = TypeVar("TV")
@@ -80,3 +85,25 @@ def setup_awaitable_errors() -> Callable[[], None]:
sys.unraisablehook = unraisablehook # type: ignore
return cleanup
+
+
+@attr.s
+class FakeResponse:
+ """A fake twisted.web.IResponse object
+
+ there is a similar class at treq.test.test_response, but it lacks a `phrase`
+ attribute, and didn't support deliverBody until recently.
+ """
+
+ # HTTP response code
+ code = attr.ib(type=int)
+
+ # HTTP response phrase (eg b'OK' for a 200)
+ phrase = attr.ib(type=bytes)
+
+ # body of the response
+ body = attr.ib(type=bytes)
+
+ def deliverBody(self, protocol):
+ protocol.dataReceived(self.body)
+ protocol.connectionLost(Failure(ResponseDone()))
diff --git a/tests/unittest.py b/tests/unittest.py
index a9d59e31f7..102b0a1f34 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -20,7 +20,7 @@ import hmac
import inspect
import logging
import time
-from typing import Optional, Tuple, Type, TypeVar, Union, overload
+from typing import Dict, Optional, Tuple, Type, TypeVar, Union, overload
from mock import Mock, patch
@@ -46,6 +46,7 @@ from synapse.logging.context import (
)
from synapse.server import HomeServer
from synapse.types import UserID, create_requester
+from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.ratelimitutils import FederationRateLimiter
from tests.server import FakeChannel, get_clock, make_request, setup_test_homeserver
@@ -320,15 +321,28 @@ class HomeserverTestCase(TestCase):
"""
Create a the root resource for the test server.
- The default implementation creates a JsonResource and calls each function in
- `servlets` to register servletes against it
+ The default calls `self.create_resource_dict` and builds the resultant dict
+ into a tree.
"""
- resource = JsonResource(self.hs)
+ root_resource = Resource()
+ create_resource_tree(self.create_resource_dict(), root_resource)
+ return root_resource
- for servlet in self.servlets:
- servlet(self.hs, resource)
+ def create_resource_dict(self) -> Dict[str, Resource]:
+ """Create a resource tree for the test server
- return resource
+ A resource tree is a mapping from path to twisted.web.resource.
+
+ The default implementation creates a JsonResource and calls each function in
+ `servlets` to register servlets against it.
+ """
+ servlet_resource = JsonResource(self.hs)
+ for servlet in self.servlets:
+ servlet(self.hs, servlet_resource)
+ return {
+ "/_matrix/client": servlet_resource,
+ "/_synapse/admin": servlet_resource,
+ }
def default_config(self):
"""
@@ -691,13 +705,29 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
A federating homeserver that authenticates incoming requests as `other.example.com`.
"""
- def prepare(self, reactor, clock, homeserver):
+ def create_resource_dict(self) -> Dict[str, Resource]:
+ d = super().create_resource_dict()
+ d["/_matrix/federation"] = TestTransportLayerServer(self.hs)
+ return d
+
+
+class TestTransportLayerServer(JsonResource):
+ """A test implementation of TransportLayerServer
+
+ authenticates incoming requests as `other.example.com`.
+ """
+
+ def __init__(self, hs):
+ super().__init__(hs)
+
class Authenticator:
def authenticate_request(self, request, content):
return succeed("other.example.com")
+ authenticator = Authenticator()
+
ratelimiter = FederationRateLimiter(
- clock,
+ hs.get_clock(),
FederationRateLimitConfig(
window_size=1,
sleep_limit=1,
@@ -706,11 +736,8 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
concurrent_requests=1000,
),
)
- federation_server.register_servlets(
- homeserver, self.resource, Authenticator(), ratelimiter
- )
- return super().prepare(reactor, clock, homeserver)
+ federation_server.register_servlets(hs, self, authenticator, ratelimiter)
def override_config(extra_config):
diff --git a/tests/utils.py b/tests/utils.py
index 1584eacb12..a3b01575dd 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -20,13 +20,12 @@ import os
import time
import uuid
import warnings
-from inspect import getcallargs
from typing import Type
from urllib import parse as urlparse
from mock import Mock, patch
-from twisted.internet import defer, reactor
+from twisted.internet import defer
from synapse.api.constants import EventTypes
from synapse.api.errors import CodeMessageException, cs_error
@@ -34,7 +33,6 @@ from synapse.api.room_versions import RoomVersions
from synapse.config.database import DatabaseConnectionConfig
from synapse.config.homeserver import HomeServerConfig
from synapse.config.server import DEFAULT_ROOM_VERSION
-from synapse.federation.transport import server as federation_server
from synapse.http.server import HttpServer
from synapse.logging.context import current_context, set_current_context
from synapse.server import HomeServer
@@ -42,7 +40,6 @@ from synapse.storage import DataStore
from synapse.storage.database import LoggingDatabaseConnection
from synapse.storage.engines import PostgresEngine, create_engine
from synapse.storage.prepare_database import prepare_database
-from synapse.util.ratelimitutils import FederationRateLimiter
# set this to True to run the tests against postgres instead of sqlite.
#
@@ -344,32 +341,9 @@ def setup_test_homeserver(
hs.get_auth_handler().validate_hash = validate_hash
- fed = kwargs.get("resource_for_federation", None)
- if fed:
- register_federation_servlets(hs, fed)
-
return hs
-def register_federation_servlets(hs, resource):
- federation_server.register_servlets(
- hs,
- resource=resource,
- authenticator=federation_server.Authenticator(hs),
- ratelimiter=FederationRateLimiter(
- hs.get_clock(), config=hs.config.rc_federation
- ),
- )
-
-
-def get_mock_call_args(pattern_func, mock_func):
- """ Return the arguments the mock function was called with interpreted
- by the pattern functions argument list.
- """
- invoked_args, invoked_kargs = mock_func.call_args
- return getcallargs(pattern_func, *invoked_args, **invoked_kargs)
-
-
def mock_getRawHeaders(headers=None):
headers = headers if headers is not None else {}
@@ -555,86 +529,6 @@ class MockClock:
return d
-def _format_call(args, kwargs):
- return ", ".join(
- ["%r" % (a) for a in args] + ["%s=%r" % (k, v) for k, v in kwargs.items()]
- )
-
-
-class DeferredMockCallable:
- """A callable instance that stores a set of pending call expectations and
- return values for them. It allows a unit test to assert that the given set
- of function calls are eventually made, by awaiting on them to be called.
- """
-
- def __init__(self):
- self.expectations = []
- self.calls = []
-
- def __call__(self, *args, **kwargs):
- self.calls.append((args, kwargs))
-
- if not self.expectations:
- raise ValueError(
- "%r has no pending calls to handle call(%s)"
- % (self, _format_call(args, kwargs))
- )
-
- for (call, result, d) in self.expectations:
- if args == call[1] and kwargs == call[2]:
- d.callback(None)
- return result
-
- failure = AssertionError(
- "Was not expecting call(%s)" % (_format_call(args, kwargs))
- )
-
- for _, _, d in self.expectations:
- try:
- d.errback(failure)
- except Exception:
- pass
-
- raise failure
-
- def expect_call_and_return(self, call, result):
- self.expectations.append((call, result, defer.Deferred()))
-
- @defer.inlineCallbacks
- def await_calls(self, timeout=1000):
- deferred = defer.DeferredList(
- [d for _, _, d in self.expectations], fireOnOneErrback=True
- )
-
- timer = reactor.callLater(
- timeout / 1000,
- deferred.errback,
- AssertionError(
- "%d pending calls left: %s"
- % (
- len([e for e in self.expectations if not e[2].called]),
- [e for e in self.expectations if not e[2].called],
- )
- ),
- )
-
- yield deferred
-
- timer.cancel()
-
- self.calls = []
-
- def assert_had_no_calls(self):
- if self.calls:
- calls = self.calls
- self.calls = []
-
- raise AssertionError(
- "Expected not to received any calls, got:\n"
- + "\n".join(["call(%s)" % _format_call(c[0], c[1]) for c in calls])
- )
-
-
async def create_room(hs, room_id: str, creator_id: str):
"""Creates and persist a creation event for the given room
"""
|