diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py
index c98ae75974..279c94a03d 100644
--- a/tests/api/test_filtering.py
+++ b/tests/api/test_filtering.py
@@ -16,8 +16,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from mock import Mock
-
import jsonschema
from twisted.internet import defer
@@ -28,7 +26,7 @@ from synapse.api.filtering import Filter
from synapse.events import make_event_from_dict
from tests import unittest
-from tests.utils import DeferredMockCallable, MockHttpResource, setup_test_homeserver
+from tests.utils import setup_test_homeserver
user_localpart = "test_user"
@@ -42,19 +40,9 @@ def MockEvent(**kwargs):
class FilteringTestCase(unittest.TestCase):
- @defer.inlineCallbacks
def setUp(self):
- self.mock_federation_resource = MockHttpResource()
-
- self.mock_http_client = Mock(spec=[])
- self.mock_http_client.put_json = DeferredMockCallable()
-
- hs = yield setup_test_homeserver(
- self.addCleanup, http_client=self.mock_http_client, keyring=Mock(),
- )
-
+ hs = setup_test_homeserver(self.addCleanup)
self.filtering = hs.get_filtering()
-
self.datastore = hs.get_datastore()
def test_errors_on_invalid_filters(self):
diff --git a/tests/app/test_frontend_proxy.py b/tests/app/test_frontend_proxy.py
index 40abe9d72d..e0ca288829 100644
--- a/tests/app/test_frontend_proxy.py
+++ b/tests/app/test_frontend_proxy.py
@@ -23,7 +23,7 @@ class FrontendProxyTests(HomeserverTestCase):
def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver(
- http_client=None, homeserver_to_use=GenericWorkerServer
+ federation_http_client=None, homeserver_to_use=GenericWorkerServer
)
return hs
@@ -57,7 +57,7 @@ class FrontendProxyTests(HomeserverTestCase):
self.assertEqual(len(self.reactor.tcpServers), 1)
site = self.reactor.tcpServers[0][1]
- _, channel = make_request(self.reactor, site, "PUT", "presence/a/status")
+ channel = make_request(self.reactor, site, "PUT", "presence/a/status")
# 400 + unrecognised, because nothing is registered
self.assertEqual(channel.code, 400)
@@ -77,7 +77,7 @@ class FrontendProxyTests(HomeserverTestCase):
self.assertEqual(len(self.reactor.tcpServers), 1)
site = self.reactor.tcpServers[0][1]
- _, channel = make_request(self.reactor, site, "PUT", "presence/a/status")
+ channel = make_request(self.reactor, site, "PUT", "presence/a/status")
# 401, because the stub servlet still checks authentication
self.assertEqual(channel.code, 401)
diff --git a/tests/app/test_openid_listener.py b/tests/app/test_openid_listener.py
index ea3be95cf1..467033e201 100644
--- a/tests/app/test_openid_listener.py
+++ b/tests/app/test_openid_listener.py
@@ -27,7 +27,7 @@ from tests.unittest import HomeserverTestCase
class FederationReaderOpenIDListenerTests(HomeserverTestCase):
def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver(
- http_client=None, homeserver_to_use=GenericWorkerServer
+ federation_http_client=None, homeserver_to_use=GenericWorkerServer
)
return hs
@@ -73,7 +73,7 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase):
return
raise
- _, channel = make_request(
+ channel = make_request(
self.reactor, site, "GET", "/_matrix/federation/v1/openid/userinfo"
)
@@ -84,7 +84,7 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase):
class SynapseHomeserverOpenIDListenerTests(HomeserverTestCase):
def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver(
- http_client=None, homeserver_to_use=SynapseHomeServer
+ federation_http_client=None, homeserver_to_use=SynapseHomeServer
)
return hs
@@ -121,7 +121,7 @@ class SynapseHomeserverOpenIDListenerTests(HomeserverTestCase):
return
raise
- _, channel = make_request(
+ channel = make_request(
self.reactor, site, "GET", "/_matrix/federation/v1/openid/userinfo"
)
diff --git a/tests/config/test_util.py b/tests/config/test_util.py
new file mode 100644
index 0000000000..10363e3765
--- /dev/null
+++ b/tests/config/test_util.py
@@ -0,0 +1,53 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.config import ConfigError
+from synapse.config._util import validate_config
+
+from tests.unittest import TestCase
+
+
+class ValidateConfigTestCase(TestCase):
+ """Test cases for synapse.config._util.validate_config"""
+
+ def test_bad_object_in_array(self):
+ """malformed objects within an array should be validated correctly"""
+
+ # consider a structure:
+ #
+ # array_of_objs:
+ # - r: 1
+ # foo: 2
+ #
+ # - r: 2
+ # bar: 3
+ #
+ # ... where each entry must contain an "r": check that the path
+ # to the required item is correclty reported.
+
+ schema = {
+ "type": "object",
+ "properties": {
+ "array_of_objs": {
+ "type": "array",
+ "items": {"type": "object", "required": ["r"]},
+ },
+ },
+ }
+
+ with self.assertRaises(ConfigError) as c:
+ validate_config(schema, {"array_of_objs": [{}]}, ("base",))
+
+ self.assertEqual(c.exception.path, ["base", "array_of_objs", "<item 0>"])
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 697916a019..1d65ea2f9c 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -75,7 +75,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
return val
def test_verify_json_objects_for_server_awaits_previous_requests(self):
- mock_fetcher = keyring.KeyFetcher()
+ mock_fetcher = Mock()
mock_fetcher.get_keys = Mock()
kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher,))
@@ -195,7 +195,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
"""Tests that we correctly handle key requests for keys we've stored
with a null `ts_valid_until_ms`
"""
- mock_fetcher = keyring.KeyFetcher()
+ mock_fetcher = Mock()
mock_fetcher.get_keys = Mock(return_value=make_awaitable({}))
kr = keyring.Keyring(
@@ -249,7 +249,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
}
}
- mock_fetcher = keyring.KeyFetcher()
+ mock_fetcher = Mock()
mock_fetcher.get_keys = Mock(side_effect=get_keys)
kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher,))
@@ -288,9 +288,9 @@ class KeyringTestCase(unittest.HomeserverTestCase):
}
}
- mock_fetcher1 = keyring.KeyFetcher()
+ mock_fetcher1 = Mock()
mock_fetcher1.get_keys = Mock(side_effect=get_keys1)
- mock_fetcher2 = keyring.KeyFetcher()
+ mock_fetcher2 = Mock()
mock_fetcher2.get_keys = Mock(side_effect=get_keys2)
kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher1, mock_fetcher2))
@@ -315,7 +315,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
self.http_client = Mock()
- hs = self.setup_test_homeserver(http_client=self.http_client)
+ hs = self.setup_test_homeserver(federation_http_client=self.http_client)
return hs
def test_get_keys_from_server(self):
@@ -395,7 +395,9 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
}
]
- return self.setup_test_homeserver(http_client=self.http_client, config=config)
+ return self.setup_test_homeserver(
+ federation_http_client=self.http_client, config=config
+ )
def build_perspectives_response(
self, server_name: str, signing_key: SigningKey, valid_until_ts: int,
diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py
index c1274c14af..8ba36c6074 100644
--- a/tests/events/test_utils.py
+++ b/tests/events/test_utils.py
@@ -34,11 +34,17 @@ def MockEvent(**kwargs):
class PruneEventTestCase(unittest.TestCase):
- """ Asserts that a new event constructed with `evdict` will look like
- `matchdict` when it is redacted. """
-
def run_test(self, evdict, matchdict, **kwargs):
- self.assertEquals(
+ """
+ Asserts that a new event constructed with `evdict` will look like
+ `matchdict` when it is redacted.
+
+ Args:
+ evdict: The dictionary to build the event from.
+ matchdict: The expected resulting dictionary.
+ kwargs: Additional keyword arguments used to create the event.
+ """
+ self.assertEqual(
prune_event(make_event_from_dict(evdict, **kwargs)).get_dict(), matchdict
)
@@ -55,54 +61,80 @@ class PruneEventTestCase(unittest.TestCase):
)
def test_basic_keys(self):
+ """Ensure that the keys that should be untouched are kept."""
+ # Note that some of the values below don't really make sense, but the
+ # pruning of events doesn't worry about the values of any fields (with
+ # the exception of the content field).
self.run_test(
{
+ "event_id": "$3:domain",
"type": "A",
"room_id": "!1:domain",
"sender": "@2:domain",
- "event_id": "$3:domain",
+ "state_key": "B",
+ "content": {"other_key": "foo"},
+ "hashes": "hashes",
+ "signatures": {"domain": {"algo:1": "sigs"}},
+ "depth": 4,
+ "prev_events": "prev_events",
+ "prev_state": "prev_state",
+ "auth_events": "auth_events",
"origin": "domain",
+ "origin_server_ts": 1234,
+ "membership": "join",
+ # Also include a key that should be removed.
+ "other_key": "foo",
},
{
+ "event_id": "$3:domain",
"type": "A",
"room_id": "!1:domain",
"sender": "@2:domain",
- "event_id": "$3:domain",
+ "state_key": "B",
+ "hashes": "hashes",
+ "depth": 4,
+ "prev_events": "prev_events",
+ "prev_state": "prev_state",
+ "auth_events": "auth_events",
"origin": "domain",
+ "origin_server_ts": 1234,
+ "membership": "join",
"content": {},
- "signatures": {},
+ "signatures": {"domain": {"algo:1": "sigs"}},
"unsigned": {},
},
)
- def test_unsigned_age_ts(self):
+ # As of MSC2176 we now redact the membership and prev_states keys.
self.run_test(
- {"type": "B", "event_id": "$test:domain", "unsigned": {"age_ts": 20}},
- {
- "type": "B",
- "event_id": "$test:domain",
- "content": {},
- "signatures": {},
- "unsigned": {"age_ts": 20},
- },
+ {"type": "A", "prev_state": "prev_state", "membership": "join"},
+ {"type": "A", "content": {}, "signatures": {}, "unsigned": {}},
+ room_version=RoomVersions.MSC2176,
)
+ def test_unsigned(self):
+ """Ensure that unsigned properties get stripped (except age_ts and replaces_state)."""
self.run_test(
{
"type": "B",
"event_id": "$test:domain",
- "unsigned": {"other_key": "here"},
+ "unsigned": {
+ "age_ts": 20,
+ "replaces_state": "$test2:domain",
+ "other_key": "foo",
+ },
},
{
"type": "B",
"event_id": "$test:domain",
"content": {},
"signatures": {},
- "unsigned": {},
+ "unsigned": {"age_ts": 20, "replaces_state": "$test2:domain"},
},
)
def test_content(self):
+ """The content dictionary should be stripped in most cases."""
self.run_test(
{"type": "C", "event_id": "$test:domain", "content": {"things": "here"}},
{
@@ -114,11 +146,35 @@ class PruneEventTestCase(unittest.TestCase):
},
)
+ # Some events keep a single content key/value.
+ EVENT_KEEP_CONTENT_KEYS = [
+ ("member", "membership", "join"),
+ ("join_rules", "join_rule", "invite"),
+ ("history_visibility", "history_visibility", "shared"),
+ ]
+ for event_type, key, value in EVENT_KEEP_CONTENT_KEYS:
+ self.run_test(
+ {
+ "type": "m.room." + event_type,
+ "event_id": "$test:domain",
+ "content": {key: value, "other_key": "foo"},
+ },
+ {
+ "type": "m.room." + event_type,
+ "event_id": "$test:domain",
+ "content": {key: value},
+ "signatures": {},
+ "unsigned": {},
+ },
+ )
+
+ def test_create(self):
+ """Create events are partially redacted until MSC2176."""
self.run_test(
{
"type": "m.room.create",
"event_id": "$test:domain",
- "content": {"creator": "@2:domain", "other_field": "here"},
+ "content": {"creator": "@2:domain", "other_key": "foo"},
},
{
"type": "m.room.create",
@@ -129,6 +185,68 @@ class PruneEventTestCase(unittest.TestCase):
},
)
+ # After MSC2176, create events get nothing redacted.
+ self.run_test(
+ {"type": "m.room.create", "content": {"not_a_real_key": True}},
+ {
+ "type": "m.room.create",
+ "content": {"not_a_real_key": True},
+ "signatures": {},
+ "unsigned": {},
+ },
+ room_version=RoomVersions.MSC2176,
+ )
+
+ def test_power_levels(self):
+ """Power level events keep a variety of content keys."""
+ self.run_test(
+ {
+ "type": "m.room.power_levels",
+ "event_id": "$test:domain",
+ "content": {
+ "ban": 1,
+ "events": {"m.room.name": 100},
+ "events_default": 2,
+ "invite": 3,
+ "kick": 4,
+ "redact": 5,
+ "state_default": 6,
+ "users": {"@admin:domain": 100},
+ "users_default": 7,
+ "other_key": 8,
+ },
+ },
+ {
+ "type": "m.room.power_levels",
+ "event_id": "$test:domain",
+ "content": {
+ "ban": 1,
+ "events": {"m.room.name": 100},
+ "events_default": 2,
+ # Note that invite is not here.
+ "kick": 4,
+ "redact": 5,
+ "state_default": 6,
+ "users": {"@admin:domain": 100},
+ "users_default": 7,
+ },
+ "signatures": {},
+ "unsigned": {},
+ },
+ )
+
+ # After MSC2176, power levels events keep the invite key.
+ self.run_test(
+ {"type": "m.room.power_levels", "content": {"invite": 75}},
+ {
+ "type": "m.room.power_levels",
+ "content": {"invite": 75},
+ "signatures": {},
+ "unsigned": {},
+ },
+ room_version=RoomVersions.MSC2176,
+ )
+
def test_alias_event(self):
"""Alias events have special behavior up through room version 6."""
self.run_test(
@@ -146,8 +264,7 @@ class PruneEventTestCase(unittest.TestCase):
},
)
- def test_msc2432_alias_event(self):
- """After MSC2432, alias events have no special behavior."""
+ # After MSC2432, alias events have no special behavior.
self.run_test(
{"type": "m.room.aliases", "content": {"aliases": ["test"]}},
{
@@ -159,6 +276,32 @@ class PruneEventTestCase(unittest.TestCase):
room_version=RoomVersions.V6,
)
+ def test_redacts(self):
+ """Redaction events have no special behaviour until MSC2174/MSC2176."""
+
+ self.run_test(
+ {"type": "m.room.redaction", "content": {"redacts": "$test2:domain"}},
+ {
+ "type": "m.room.redaction",
+ "content": {},
+ "signatures": {},
+ "unsigned": {},
+ },
+ room_version=RoomVersions.V6,
+ )
+
+ # After MSC2174, redaction events keep the redacts content key.
+ self.run_test(
+ {"type": "m.room.redaction", "content": {"redacts": "$test2:domain"}},
+ {
+ "type": "m.room.redaction",
+ "content": {"redacts": "$test2:domain"},
+ "signatures": {},
+ "unsigned": {},
+ },
+ room_version=RoomVersions.MSC2176,
+ )
+
class SerializeEventTestCase(unittest.TestCase):
def serialize(self, ev, fields):
diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py
index 0187f56e21..9ccd2d76b8 100644
--- a/tests/federation/test_complexity.py
+++ b/tests/federation/test_complexity.py
@@ -48,7 +48,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
)
# Get the room complexity
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,)
)
self.assertEquals(200, channel.code)
@@ -60,7 +60,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
store.get_current_state_event_counts = lambda x: make_awaitable(500 * 1.23)
# Get the room complexity again -- make sure it's our artificial value
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,)
)
self.assertEquals(200, channel.code)
diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py
index 3009fbb6c4..cfeccc0577 100644
--- a/tests/federation/test_federation_server.py
+++ b/tests/federation/test_federation_server.py
@@ -46,7 +46,7 @@ class FederationServerTests(unittest.FederatingHomeserverTestCase):
"/get_missing_events/(?P<room_id>[^/]*)/?"
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
"/_matrix/federation/v1/get_missing_events/%s" % (room_1,),
query_content,
@@ -95,7 +95,7 @@ class StateQueryTests(unittest.FederatingHomeserverTestCase):
room_1 = self.helper.create_room_as(u1, tok=u1_token)
self.inject_room_member(room_1, "@user:other.example.com", "join")
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/_matrix/federation/v1/state/%s" % (room_1,)
)
self.assertEquals(200, channel.code, channel.result)
@@ -127,7 +127,7 @@ class StateQueryTests(unittest.FederatingHomeserverTestCase):
room_1 = self.helper.create_room_as(u1, tok=u1_token)
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/_matrix/federation/v1/state/%s" % (room_1,)
)
self.assertEquals(403, channel.code, channel.result)
diff --git a/tests/federation/transport/__init__.py b/tests/federation/transport/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/tests/federation/transport/__init__.py
diff --git a/tests/federation/transport/test_server.py b/tests/federation/transport/test_server.py
index f9e3c7a51f..85500e169c 100644
--- a/tests/federation/transport/test_server.py
+++ b/tests/federation/transport/test_server.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,38 +13,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
-from twisted.internet import defer
-
-from synapse.config.ratelimiting import FederationRateLimitConfig
-from synapse.federation.transport import server
-from synapse.util.ratelimitutils import FederationRateLimiter
-
from tests import unittest
from tests.unittest import override_config
-class RoomDirectoryFederationTests(unittest.HomeserverTestCase):
- def prepare(self, reactor, clock, homeserver):
- class Authenticator:
- def authenticate_request(self, request, content):
- return defer.succeed("otherserver.nottld")
-
- ratelimiter = FederationRateLimiter(clock, FederationRateLimitConfig())
- server.register_servlets(
- homeserver, self.resource, Authenticator(), ratelimiter
- )
-
+class RoomDirectoryFederationTests(unittest.FederatingHomeserverTestCase):
@override_config({"allow_public_rooms_over_federation": False})
def test_blocked_public_room_list_over_federation(self):
- request, channel = self.make_request(
- "GET", "/_matrix/federation/v1/publicRooms"
+ """Test that unauthenticated requests to the public rooms directory 403 when
+ allow_public_rooms_over_federation is False.
+ """
+ channel = self.make_request(
+ "GET",
+ "/_matrix/federation/v1/publicRooms",
+ federation_auth_origin=b"example.com",
)
self.assertEquals(403, channel.code)
@override_config({"allow_public_rooms_over_federation": True})
def test_open_public_room_list_over_federation(self):
- request, channel = self.make_request(
- "GET", "/_matrix/federation/v1/publicRooms"
+ """Test that unauthenticated requests to the public rooms directory 200 when
+ allow_public_rooms_over_federation is True.
+ """
+ channel = self.make_request(
+ "GET",
+ "/_matrix/federation/v1/publicRooms",
+ federation_auth_origin=b"example.com",
)
self.assertEquals(200, channel.code)
diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py
new file mode 100644
index 0000000000..7baf224f7e
--- /dev/null
+++ b/tests/handlers/test_cas.py
@@ -0,0 +1,121 @@
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from mock import Mock
+
+from synapse.handlers.cas_handler import CasResponse
+
+from tests.test_utils import simple_async_mock
+from tests.unittest import HomeserverTestCase
+
+# These are a few constants that are used as config parameters in the tests.
+BASE_URL = "https://synapse/"
+SERVER_URL = "https://issuer/"
+
+
+class CasHandlerTestCase(HomeserverTestCase):
+ def default_config(self):
+ config = super().default_config()
+ config["public_baseurl"] = BASE_URL
+ cas_config = {
+ "enabled": True,
+ "server_url": SERVER_URL,
+ "service_url": BASE_URL,
+ }
+ config["cas_config"] = cas_config
+
+ return config
+
+ def make_homeserver(self, reactor, clock):
+ hs = self.setup_test_homeserver()
+
+ self.handler = hs.get_cas_handler()
+
+ # Reduce the number of attempts when generating MXIDs.
+ sso_handler = hs.get_sso_handler()
+ sso_handler._MAP_USERNAME_RETRIES = 3
+
+ return hs
+
+ def test_map_cas_user_to_user(self):
+ """Ensure that mapping the CAS user returned from a provider to an MXID works properly."""
+
+ # stub out the auth handler
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = simple_async_mock()
+
+ cas_response = CasResponse("test_user", {})
+ request = _mock_request()
+ self.get_success(
+ self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
+ )
+
+ # check that the auth handler got called as expected
+ auth_handler.complete_sso_login.assert_called_once_with(
+ "@test_user:test", request, "redirect_uri", None, new_user=True
+ )
+
+ def test_map_cas_user_to_existing_user(self):
+ """Existing users can log in with CAS account."""
+ store = self.hs.get_datastore()
+ self.get_success(
+ store.register_user(user_id="@test_user:test", password_hash=None)
+ )
+
+ # stub out the auth handler
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = simple_async_mock()
+
+ # Map a user via SSO.
+ cas_response = CasResponse("test_user", {})
+ request = _mock_request()
+ self.get_success(
+ self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
+ )
+
+ # check that the auth handler got called as expected
+ auth_handler.complete_sso_login.assert_called_once_with(
+ "@test_user:test", request, "redirect_uri", None, new_user=False
+ )
+
+ # Subsequent calls should map to the same mxid.
+ auth_handler.complete_sso_login.reset_mock()
+ self.get_success(
+ self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
+ )
+ auth_handler.complete_sso_login.assert_called_once_with(
+ "@test_user:test", request, "redirect_uri", None, new_user=False
+ )
+
+ def test_map_cas_user_to_invalid_localpart(self):
+ """CAS automaps invalid characters to base-64 encoding."""
+
+ # stub out the auth handler
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = simple_async_mock()
+
+ cas_response = CasResponse("föö", {})
+ request = _mock_request()
+ self.get_success(
+ self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
+ )
+
+ # check that the auth handler got called as expected
+ auth_handler.complete_sso_login.assert_called_once_with(
+ "@f=c3=b6=c3=b6:test", request, "redirect_uri", None, new_user=True
+ )
+
+
+def _mock_request():
+ """Returns a mock which will stand in as a SynapseRequest"""
+ return Mock(spec=["getClientIP", "getHeader"])
diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index 875aaec2c6..5dfeccfeb6 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -27,7 +27,7 @@ user2 = "@theresa:bbb"
class DeviceTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
- hs = self.setup_test_homeserver("server", http_client=None)
+ hs = self.setup_test_homeserver("server", federation_http_client=None)
self.handler = hs.get_device_handler()
self.store = hs.get_datastore()
return hs
@@ -229,7 +229,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
class DehydrationTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
- hs = self.setup_test_homeserver("server", http_client=None)
+ hs = self.setup_test_homeserver("server", federation_http_client=None)
self.handler = hs.get_device_handler()
self.registration = hs.get_registration_handler()
self.auth = hs.get_auth()
diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index ee6ef5e6fa..a39f898608 100644
--- a/tests/handlers/test_directory.py
+++ b/tests/handlers/test_directory.py
@@ -42,8 +42,6 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
self.mock_registry.register_query_handler = register_query_handler
hs = self.setup_test_homeserver(
- http_client=None,
- resource_for_federation=Mock(),
federation_client=self.mock_federation,
federation_registry=self.mock_registry,
)
@@ -407,7 +405,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
def test_denied(self):
room_id = self.helper.create_room_as(self.user_id)
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
b"directory/room/%23test%3Atest",
('{"room_id":"%s"}' % (room_id,)).encode("ascii"),
@@ -417,7 +415,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
def test_allowed(self):
room_id = self.helper.create_room_as(self.user_id)
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
b"directory/room/%23unofficial_test%3Atest",
('{"room_id":"%s"}' % (room_id,)).encode("ascii"),
@@ -433,7 +431,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
room_id = self.helper.create_room_as(self.user_id)
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT", b"directory/list/room/%s" % (room_id.encode("ascii"),), b"{}"
)
self.assertEquals(200, channel.code, channel.result)
@@ -448,7 +446,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
self.directory_handler.enable_room_list_search = True
# Room list is enabled so we should get some results
- request, channel = self.make_request("GET", b"publicRooms")
+ channel = self.make_request("GET", b"publicRooms")
self.assertEquals(200, channel.code, channel.result)
self.assertTrue(len(channel.json_body["chunk"]) > 0)
@@ -456,13 +454,13 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
self.directory_handler.enable_room_list_search = False
# Room list disabled so we should get no results
- request, channel = self.make_request("GET", b"publicRooms")
+ channel = self.make_request("GET", b"publicRooms")
self.assertEquals(200, channel.code, channel.result)
self.assertTrue(len(channel.json_body["chunk"]) == 0)
# Room list disabled so we shouldn't be allowed to publish rooms
room_id = self.helper.create_room_as(self.user_id)
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT", b"directory/list/room/%s" % (room_id.encode("ascii"),), b"{}"
)
self.assertEquals(403, channel.code, channel.result)
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index bf866dacf3..983e368592 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -16,7 +16,7 @@ import logging
from unittest import TestCase
from synapse.api.constants import EventTypes
-from synapse.api.errors import AuthError, Codes, SynapseError
+from synapse.api.errors import AuthError, Codes, LimitExceededError, SynapseError
from synapse.api.room_versions import RoomVersions
from synapse.events import EventBase
from synapse.federation.federation_base import event_from_pdu_json
@@ -37,7 +37,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
]
def make_homeserver(self, reactor, clock):
- hs = self.setup_test_homeserver(http_client=None)
+ hs = self.setup_test_homeserver(federation_http_client=None)
self.handler = hs.get_federation_handler()
self.store = hs.get_datastore()
return hs
@@ -126,7 +126,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
room_version,
)
- with LoggingContext(request="send_rejected"):
+ with LoggingContext("send_rejected"):
d = run_in_background(self.handler.on_receive_pdu, OTHER_SERVER, ev)
self.get_success(d)
@@ -178,7 +178,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
room_version,
)
- with LoggingContext(request="send_rejected"):
+ with LoggingContext("send_rejected"):
d = run_in_background(self.handler.on_receive_pdu, OTHER_SERVER, ev)
self.get_success(d)
@@ -191,6 +191,50 @@ class FederationTestCase(unittest.HomeserverTestCase):
self.assertEqual(sg, sg2)
+ @unittest.override_config(
+ {"rc_invites": {"per_user": {"per_second": 0.5, "burst_count": 3}}}
+ )
+ def test_invite_by_user_ratelimit(self):
+ """Tests that invites from federation to a particular user are
+ actually rate-limited.
+ """
+ other_server = "otherserver"
+ other_user = "@otheruser:" + other_server
+
+ # create the room
+ user_id = self.register_user("kermit", "test")
+ tok = self.login("kermit", "test")
+
+ def create_invite():
+ room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
+ room_version = self.get_success(self.store.get_room_version(room_id))
+ return event_from_pdu_json(
+ {
+ "type": EventTypes.Member,
+ "content": {"membership": "invite"},
+ "room_id": room_id,
+ "sender": other_user,
+ "state_key": "@user:test",
+ "depth": 32,
+ "prev_events": [],
+ "auth_events": [],
+ "origin_server_ts": self.clock.time_msec(),
+ },
+ room_version,
+ )
+
+ for i in range(3):
+ event = create_invite()
+ self.get_success(
+ self.handler.on_invite_request(other_server, event, event.room_version,)
+ )
+
+ event = create_invite()
+ self.get_failure(
+ self.handler.on_invite_request(other_server, event, event.room_version,),
+ exc=LimitExceededError,
+ )
+
def _build_and_send_join_event(self, other_server, other_user, room_id):
join_event = self.get_success(
self.handler.on_make_join_request(other_server, room_id, other_user)
@@ -198,7 +242,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
# the auth code requires that a signature exists, but doesn't check that
# signature... go figure.
join_event.signatures[other_server] = {"x": "y"}
- with LoggingContext(request="send_join"):
+ with LoggingContext("send_join"):
d = run_in_background(
self.handler.on_send_join_request, other_server, join_event
)
diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py
index af42775815..f955dfa490 100644
--- a/tests/handlers/test_message.py
+++ b/tests/handlers/test_message.py
@@ -206,7 +206,7 @@ class ServerAclValidationTestCase(unittest.HomeserverTestCase):
# Redaction of event should fail.
path = "/_matrix/client/r0/rooms/%s/redact/%s" % (self.room_id, event_id)
- request, channel = self.make_request(
+ channel = self.make_request(
"POST", path, content={}, access_token=self.access_token
)
self.assertEqual(int(channel.result["code"]), 403)
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index a308c46da9..ad20400b1d 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -13,32 +13,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
+from typing import Optional
from urllib.parse import parse_qs, urlparse
-from mock import Mock, patch
+from mock import ANY, Mock, patch
-import attr
import pymacaroons
-from twisted.python.failure import Failure
-from twisted.web._newclient import ResponseDone
-
-from synapse.handlers.oidc_handler import OidcError, OidcMappingProvider
from synapse.handlers.sso import MappingException
+from synapse.server import HomeServer
from synapse.types import UserID
+from tests.test_utils import FakeResponse, simple_async_mock
from tests.unittest import HomeserverTestCase, override_config
+try:
+ import authlib # noqa: F401
-@attr.s
-class FakeResponse:
- code = attr.ib()
- body = attr.ib()
- phrase = attr.ib()
-
- def deliverBody(self, protocol):
- protocol.dataReceived(self.body)
- protocol.connectionLost(Failure(ResponseDone()))
+ HAS_OIDC = True
+except ImportError:
+ HAS_OIDC = False
# These are a few constants that are used as config parameters in the tests.
@@ -46,7 +40,7 @@ ISSUER = "https://issuer/"
CLIENT_ID = "test-client-id"
CLIENT_SECRET = "test-client-secret"
BASE_URL = "https://synapse/"
-CALLBACK_URL = BASE_URL + "_synapse/oidc/callback"
+CALLBACK_URL = BASE_URL + "_synapse/client/oidc/callback"
SCOPES = ["openid"]
AUTHORIZATION_ENDPOINT = ISSUER + "authorize"
@@ -64,17 +58,14 @@ COMMON_CONFIG = {
}
-# The cookie name and path don't really matter, just that it has to be coherent
-# between the callback & redirect handlers.
-COOKIE_NAME = b"oidc_session"
-COOKIE_PATH = "/_synapse/oidc"
-
-
-class TestMappingProvider(OidcMappingProvider):
+class TestMappingProvider:
@staticmethod
def parse_config(config):
return
+ def __init__(self, config):
+ pass
+
def get_remote_user_id(self, userinfo):
return userinfo["sub"]
@@ -97,16 +88,6 @@ class TestMappingProviderFailures(TestMappingProvider):
}
-def simple_async_mock(return_value=None, raises=None):
- # AsyncMock is not available in python3.5, this mimics part of its behaviour
- async def cb(*args, **kwargs):
- if raises:
- raise raises
- return return_value
-
- return Mock(side_effect=cb)
-
-
async def get_json(url):
# Mock get_json calls to handle jwks & oidc discovery endpoints
if url == WELL_KNOWN:
@@ -127,6 +108,9 @@ async def get_json(url):
class OidcHandlerTestCase(HomeserverTestCase):
+ if not HAS_OIDC:
+ skip = "requires OIDC"
+
def default_config(self):
config = super().default_config()
config["public_baseurl"] = BASE_URL
@@ -155,6 +139,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
hs = self.setup_test_homeserver(proxied_http_client=self.http_client)
self.handler = hs.get_oidc_handler()
+ self.provider = self.handler._providers["oidc"]
sso_handler = hs.get_sso_handler()
# Mock the render error method.
self.render_error = Mock(return_value=None)
@@ -166,27 +151,29 @@ class OidcHandlerTestCase(HomeserverTestCase):
return hs
def metadata_edit(self, values):
- return patch.dict(self.handler._provider_metadata, values)
+ return patch.dict(self.provider._provider_metadata, values)
def assertRenderedError(self, error, error_description=None):
+ self.render_error.assert_called_once()
args = self.render_error.call_args[0]
self.assertEqual(args[1], error)
if error_description is not None:
self.assertEqual(args[2], error_description)
# Reset the render_error mock
self.render_error.reset_mock()
+ return args
def test_config(self):
"""Basic config correctly sets up the callback URL and client auth correctly."""
- self.assertEqual(self.handler._callback_url, CALLBACK_URL)
- self.assertEqual(self.handler._client_auth.client_id, CLIENT_ID)
- self.assertEqual(self.handler._client_auth.client_secret, CLIENT_SECRET)
+ self.assertEqual(self.provider._callback_url, CALLBACK_URL)
+ self.assertEqual(self.provider._client_auth.client_id, CLIENT_ID)
+ self.assertEqual(self.provider._client_auth.client_secret, CLIENT_SECRET)
@override_config({"oidc_config": {"discover": True}})
def test_discovery(self):
"""The handler should discover the endpoints from OIDC discovery document."""
# This would throw if some metadata were invalid
- metadata = self.get_success(self.handler.load_metadata())
+ metadata = self.get_success(self.provider.load_metadata())
self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
self.assertEqual(metadata.issuer, ISSUER)
@@ -198,47 +185,47 @@ class OidcHandlerTestCase(HomeserverTestCase):
# subsequent calls should be cached
self.http_client.reset_mock()
- self.get_success(self.handler.load_metadata())
+ self.get_success(self.provider.load_metadata())
self.http_client.get_json.assert_not_called()
@override_config({"oidc_config": COMMON_CONFIG})
def test_no_discovery(self):
"""When discovery is disabled, it should not try to load from discovery document."""
- self.get_success(self.handler.load_metadata())
+ self.get_success(self.provider.load_metadata())
self.http_client.get_json.assert_not_called()
@override_config({"oidc_config": COMMON_CONFIG})
def test_load_jwks(self):
"""JWKS loading is done once (then cached) if used."""
- jwks = self.get_success(self.handler.load_jwks())
+ jwks = self.get_success(self.provider.load_jwks())
self.http_client.get_json.assert_called_once_with(JWKS_URI)
self.assertEqual(jwks, {"keys": []})
# subsequent calls should be cached…
self.http_client.reset_mock()
- self.get_success(self.handler.load_jwks())
+ self.get_success(self.provider.load_jwks())
self.http_client.get_json.assert_not_called()
# …unless forced
self.http_client.reset_mock()
- self.get_success(self.handler.load_jwks(force=True))
+ self.get_success(self.provider.load_jwks(force=True))
self.http_client.get_json.assert_called_once_with(JWKS_URI)
# Throw if the JWKS uri is missing
with self.metadata_edit({"jwks_uri": None}):
- self.get_failure(self.handler.load_jwks(force=True), RuntimeError)
+ self.get_failure(self.provider.load_jwks(force=True), RuntimeError)
# Return empty key set if JWKS are not used
- self.handler._scopes = [] # not asking the openid scope
+ self.provider._scopes = [] # not asking the openid scope
self.http_client.get_json.reset_mock()
- jwks = self.get_success(self.handler.load_jwks(force=True))
+ jwks = self.get_success(self.provider.load_jwks(force=True))
self.http_client.get_json.assert_not_called()
self.assertEqual(jwks, {"keys": []})
@override_config({"oidc_config": COMMON_CONFIG})
def test_validate_config(self):
"""Provider metadatas are extensively validated."""
- h = self.handler
+ h = self.provider
# Default test config does not throw
h._validate_metadata()
@@ -317,13 +304,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
"""Provider metadata validation can be disabled by config."""
with self.metadata_edit({"issuer": "http://insecure"}):
# This should not throw
- self.handler._validate_metadata()
+ self.provider._validate_metadata()
def test_redirect_request(self):
"""The redirect request has the right arguments & generates a valid session cookie."""
req = Mock(spec=["addCookie"])
url = self.get_success(
- self.handler.handle_redirect_request(req, b"http://client/redirect")
+ self.provider.handle_redirect_request(req, b"http://client/redirect")
)
url = urlparse(url)
auth_endpoint = urlparse(AUTHORIZATION_ENDPOINT)
@@ -347,14 +334,21 @@ class OidcHandlerTestCase(HomeserverTestCase):
# For some reason, call.args does not work with python3.5
args = calls[0][0]
kwargs = calls[0][1]
- self.assertEqual(args[0], COOKIE_NAME)
- self.assertEqual(kwargs["path"], COOKIE_PATH)
+
+ # The cookie name and path don't really matter, just that it has to be coherent
+ # between the callback & redirect handlers.
+ self.assertEqual(args[0], b"oidc_session")
+ self.assertEqual(kwargs["path"], "/_synapse/client/oidc")
cookie = args[1]
macaroon = pymacaroons.Macaroon.deserialize(cookie)
- state = self.handler._get_value_from_macaroon(macaroon, "state")
- nonce = self.handler._get_value_from_macaroon(macaroon, "nonce")
- redirect = self.handler._get_value_from_macaroon(
+ state = self.handler._token_generator._get_value_from_macaroon(
+ macaroon, "state"
+ )
+ nonce = self.handler._token_generator._get_value_from_macaroon(
+ macaroon, "nonce"
+ )
+ redirect = self.handler._token_generator._get_value_from_macaroon(
macaroon, "client_redirect_url"
)
@@ -384,31 +378,29 @@ class OidcHandlerTestCase(HomeserverTestCase):
- when the userinfo fetching fails
- when the code exchange fails
"""
+
+ # ensure that we are correctly testing the fallback when "get_extra_attributes"
+ # is not implemented.
+ mapping_provider = self.provider._user_mapping_provider
+ with self.assertRaises(AttributeError):
+ _ = mapping_provider.get_extra_attributes
+
token = {
"type": "bearer",
"id_token": "id_token",
"access_token": "access_token",
}
+ username = "bar"
userinfo = {
"sub": "foo",
- "preferred_username": "bar",
+ "username": username,
}
- user_id = "@foo:domain.org"
- self.handler._exchange_code = simple_async_mock(return_value=token)
- self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
- self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
- self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
- self.handler._auth_handler.complete_sso_login = simple_async_mock()
- request = Mock(
- spec=[
- "args",
- "getCookie",
- "addCookie",
- "requestHeaders",
- "getClientIP",
- "get_user_agent",
- ]
- )
+ expected_user_id = "@%s:%s" % (username, self.hs.hostname)
+ self.provider._exchange_code = simple_async_mock(return_value=token)
+ self.provider._parse_id_token = simple_async_mock(return_value=userinfo)
+ self.provider._fetch_userinfo = simple_async_mock(return_value=userinfo)
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = simple_async_mock()
code = "code"
state = "state"
@@ -416,74 +408,61 @@ class OidcHandlerTestCase(HomeserverTestCase):
client_redirect_url = "http://client/redirect"
user_agent = "Browser"
ip_address = "10.0.0.1"
- request.getCookie.return_value = self.handler._generate_oidc_session_token(
- state=state,
- nonce=nonce,
- client_redirect_url=client_redirect_url,
- ui_auth_session_id=None,
+ session = self._generate_oidc_session_token(state, nonce, client_redirect_url)
+ request = _build_callback_request(
+ code, state, session, user_agent=user_agent, ip_address=ip_address
)
- request.args = {}
- request.args[b"code"] = [code.encode("utf-8")]
- request.args[b"state"] = [state.encode("utf-8")]
-
- request.getClientIP.return_value = ip_address
- request.get_user_agent.return_value = user_agent
-
self.get_success(self.handler.handle_oidc_callback(request))
- self.handler._auth_handler.complete_sso_login.assert_called_once_with(
- user_id, request, client_redirect_url, {},
+ auth_handler.complete_sso_login.assert_called_once_with(
+ expected_user_id, request, client_redirect_url, None, new_user=True
)
- self.handler._exchange_code.assert_called_once_with(code)
- self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce)
- self.handler._map_userinfo_to_user.assert_called_once_with(
- userinfo, token, user_agent, ip_address
- )
- self.handler._fetch_userinfo.assert_not_called()
+ self.provider._exchange_code.assert_called_once_with(code)
+ self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce)
+ self.provider._fetch_userinfo.assert_not_called()
self.render_error.assert_not_called()
# Handle mapping errors
- self.handler._map_userinfo_to_user = simple_async_mock(
- raises=MappingException()
- )
- self.get_success(self.handler.handle_oidc_callback(request))
- self.assertRenderedError("mapping_error")
- self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
+ with patch.object(
+ self.provider,
+ "_remote_id_from_userinfo",
+ new=Mock(side_effect=MappingException()),
+ ):
+ self.get_success(self.handler.handle_oidc_callback(request))
+ self.assertRenderedError("mapping_error")
# Handle ID token errors
- self.handler._parse_id_token = simple_async_mock(raises=Exception())
+ self.provider._parse_id_token = simple_async_mock(raises=Exception())
self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("invalid_token")
- self.handler._auth_handler.complete_sso_login.reset_mock()
- self.handler._exchange_code.reset_mock()
- self.handler._parse_id_token.reset_mock()
- self.handler._map_userinfo_to_user.reset_mock()
- self.handler._fetch_userinfo.reset_mock()
+ auth_handler.complete_sso_login.reset_mock()
+ self.provider._exchange_code.reset_mock()
+ self.provider._parse_id_token.reset_mock()
+ self.provider._fetch_userinfo.reset_mock()
# With userinfo fetching
- self.handler._scopes = [] # do not ask the "openid" scope
+ self.provider._scopes = [] # do not ask the "openid" scope
self.get_success(self.handler.handle_oidc_callback(request))
- self.handler._auth_handler.complete_sso_login.assert_called_once_with(
- user_id, request, client_redirect_url, {},
- )
- self.handler._exchange_code.assert_called_once_with(code)
- self.handler._parse_id_token.assert_not_called()
- self.handler._map_userinfo_to_user.assert_called_once_with(
- userinfo, token, user_agent, ip_address
+ auth_handler.complete_sso_login.assert_called_once_with(
+ expected_user_id, request, client_redirect_url, None, new_user=False
)
- self.handler._fetch_userinfo.assert_called_once_with(token)
+ self.provider._exchange_code.assert_called_once_with(code)
+ self.provider._parse_id_token.assert_not_called()
+ self.provider._fetch_userinfo.assert_called_once_with(token)
self.render_error.assert_not_called()
# Handle userinfo fetching error
- self.handler._fetch_userinfo = simple_async_mock(raises=Exception())
+ self.provider._fetch_userinfo = simple_async_mock(raises=Exception())
self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("fetch_error")
# Handle code exchange failure
- self.handler._exchange_code = simple_async_mock(
+ from synapse.handlers.oidc_handler import OidcError
+
+ self.provider._exchange_code = simple_async_mock(
raises=OidcError("invalid_request")
)
self.get_success(self.handler.handle_oidc_callback(request))
@@ -513,11 +492,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.assertRenderedError("invalid_session")
# Mismatching session
- session = self.handler._generate_oidc_session_token(
- state="state",
- nonce="nonce",
- client_redirect_url="http://client/redirect",
- ui_auth_session_id=None,
+ session = self._generate_oidc_session_token(
+ state="state", nonce="nonce", client_redirect_url="http://client/redirect",
)
request.args = {}
request.args[b"state"] = [b"mismatching state"]
@@ -541,7 +517,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
return_value=FakeResponse(code=200, phrase=b"OK", body=token_json)
)
code = "code"
- ret = self.get_success(self.handler._exchange_code(code))
+ ret = self.get_success(self.provider._exchange_code(code))
kwargs = self.http_client.request.call_args[1]
self.assertEqual(ret, token)
@@ -563,7 +539,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
body=b'{"error": "foo", "error_description": "bar"}',
)
)
- exc = self.get_failure(self.handler._exchange_code(code), OidcError)
+ from synapse.handlers.oidc_handler import OidcError
+
+ exc = self.get_failure(self.provider._exchange_code(code), OidcError)
self.assertEqual(exc.value.error, "foo")
self.assertEqual(exc.value.error_description, "bar")
@@ -573,7 +551,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
code=500, phrase=b"Internal Server Error", body=b"Not JSON",
)
)
- exc = self.get_failure(self.handler._exchange_code(code), OidcError)
+ exc = self.get_failure(self.provider._exchange_code(code), OidcError)
self.assertEqual(exc.value.error, "server_error")
# Internal server error with JSON body
@@ -585,14 +563,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
)
- exc = self.get_failure(self.handler._exchange_code(code), OidcError)
+ exc = self.get_failure(self.provider._exchange_code(code), OidcError)
self.assertEqual(exc.value.error, "internal_server_error")
# 4xx error without "error" field
self.http_client.request = simple_async_mock(
return_value=FakeResponse(code=400, phrase=b"Bad request", body=b"{}",)
)
- exc = self.get_failure(self.handler._exchange_code(code), OidcError)
+ exc = self.get_failure(self.provider._exchange_code(code), OidcError)
self.assertEqual(exc.value.error, "server_error")
# 2xx error with "error" field
@@ -601,7 +579,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
code=200, phrase=b"OK", body=b'{"error": "some_error"}',
)
)
- exc = self.get_failure(self.handler._exchange_code(code), OidcError)
+ exc = self.get_failure(self.provider._exchange_code(code), OidcError)
self.assertEqual(exc.value.error, "some_error")
@override_config(
@@ -624,72 +602,56 @@ class OidcHandlerTestCase(HomeserverTestCase):
}
userinfo = {
"sub": "foo",
+ "username": "foo",
"phone": "1234567",
}
- user_id = "@foo:domain.org"
- self.handler._exchange_code = simple_async_mock(return_value=token)
- self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
- self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
- self.handler._auth_handler.complete_sso_login = simple_async_mock()
- request = Mock(
- spec=[
- "args",
- "getCookie",
- "addCookie",
- "requestHeaders",
- "getClientIP",
- "get_user_agent",
- ]
- )
+ self.provider._exchange_code = simple_async_mock(return_value=token)
+ self.provider._parse_id_token = simple_async_mock(return_value=userinfo)
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = simple_async_mock()
state = "state"
client_redirect_url = "http://client/redirect"
- request.getCookie.return_value = self.handler._generate_oidc_session_token(
- state=state,
- nonce="nonce",
- client_redirect_url=client_redirect_url,
- ui_auth_session_id=None,
+ session = self._generate_oidc_session_token(
+ state=state, nonce="nonce", client_redirect_url=client_redirect_url,
)
-
- request.args = {}
- request.args[b"code"] = [b"code"]
- request.args[b"state"] = [state.encode("utf-8")]
-
- request.getClientIP.return_value = "10.0.0.1"
- request.get_user_agent.return_value = "Browser"
+ request = _build_callback_request("code", state, session)
self.get_success(self.handler.handle_oidc_callback(request))
- self.handler._auth_handler.complete_sso_login.assert_called_once_with(
- user_id, request, client_redirect_url, {"phone": "1234567"},
+ auth_handler.complete_sso_login.assert_called_once_with(
+ "@foo:test",
+ request,
+ client_redirect_url,
+ {"phone": "1234567"},
+ new_user=True,
)
def test_map_userinfo_to_user(self):
"""Ensure that mapping the userinfo returned from a provider to an MXID works properly."""
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = simple_async_mock()
+
userinfo = {
"sub": "test_user",
"username": "test_user",
}
- # The token doesn't matter with the default user mapping provider.
- token = {}
- mxid = self.get_success(
- self.handler._map_userinfo_to_user(
- userinfo, token, "user-agent", "10.10.10.10"
- )
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ auth_handler.complete_sso_login.assert_called_once_with(
+ "@test_user:test", ANY, ANY, None, new_user=True
)
- self.assertEqual(mxid, "@test_user:test")
+ auth_handler.complete_sso_login.reset_mock()
# Some providers return an integer ID.
userinfo = {
"sub": 1234,
"username": "test_user_2",
}
- mxid = self.get_success(
- self.handler._map_userinfo_to_user(
- userinfo, token, "user-agent", "10.10.10.10"
- )
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ auth_handler.complete_sso_login.assert_called_once_with(
+ "@test_user_2:test", ANY, ANY, None, new_user=True
)
- self.assertEqual(mxid, "@test_user_2:test")
+ auth_handler.complete_sso_login.reset_mock()
# Test if the mxid is already taken
store = self.hs.get_datastore()
@@ -698,14 +660,11 @@ class OidcHandlerTestCase(HomeserverTestCase):
store.register_user(user_id=user3.to_string(), password_hash=None)
)
userinfo = {"sub": "test3", "username": "test_user_3"}
- e = self.get_failure(
- self.handler._map_userinfo_to_user(
- userinfo, token, "user-agent", "10.10.10.10"
- ),
- MappingException,
- )
- self.assertEqual(
- str(e.value), "Mapping provider does not support de-duplicating Matrix IDs",
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ auth_handler.complete_sso_login.assert_not_called()
+ self.assertRenderedError(
+ "mapping_error",
+ "Mapping provider does not support de-duplicating Matrix IDs",
)
@override_config({"oidc_config": {"allow_existing_users": True}})
@@ -717,26 +676,26 @@ class OidcHandlerTestCase(HomeserverTestCase):
store.register_user(user_id=user.to_string(), password_hash=None)
)
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = simple_async_mock()
+
# Map a user via SSO.
userinfo = {
"sub": "test",
"username": "test_user",
}
- token = {}
- mxid = self.get_success(
- self.handler._map_userinfo_to_user(
- userinfo, token, "user-agent", "10.10.10.10"
- )
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ auth_handler.complete_sso_login.assert_called_once_with(
+ user.to_string(), ANY, ANY, None, new_user=False
)
- self.assertEqual(mxid, "@test_user:test")
+ auth_handler.complete_sso_login.reset_mock()
# Subsequent calls should map to the same mxid.
- mxid = self.get_success(
- self.handler._map_userinfo_to_user(
- userinfo, token, "user-agent", "10.10.10.10"
- )
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ auth_handler.complete_sso_login.assert_called_once_with(
+ user.to_string(), ANY, ANY, None, new_user=False
)
- self.assertEqual(mxid, "@test_user:test")
+ auth_handler.complete_sso_login.reset_mock()
# Note that a second SSO user can be mapped to the same Matrix ID. (This
# requires a unique sub, but something that maps to the same matrix ID,
@@ -747,13 +706,11 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "test1",
"username": "test_user",
}
- token = {}
- mxid = self.get_success(
- self.handler._map_userinfo_to_user(
- userinfo, token, "user-agent", "10.10.10.10"
- )
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ auth_handler.complete_sso_login.assert_called_once_with(
+ user.to_string(), ANY, ANY, None, new_user=False
)
- self.assertEqual(mxid, "@test_user:test")
+ auth_handler.complete_sso_login.reset_mock()
# Register some non-exact matching cases.
user2 = UserID.from_string("@TEST_user_2:test")
@@ -770,14 +727,11 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "test2",
"username": "TEST_USER_2",
}
- e = self.get_failure(
- self.handler._map_userinfo_to_user(
- userinfo, token, "user-agent", "10.10.10.10"
- ),
- MappingException,
- )
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ auth_handler.complete_sso_login.assert_not_called()
+ args = self.assertRenderedError("mapping_error")
self.assertTrue(
- str(e.value).startswith(
+ args[2].startswith(
"Attempted to login as '@TEST_USER_2:test' but it matches more than one user inexactly:"
)
)
@@ -788,28 +742,17 @@ class OidcHandlerTestCase(HomeserverTestCase):
store.register_user(user_id=user2.to_string(), password_hash=None)
)
- mxid = self.get_success(
- self.handler._map_userinfo_to_user(
- userinfo, token, "user-agent", "10.10.10.10"
- )
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ auth_handler.complete_sso_login.assert_called_once_with(
+ "@TEST_USER_2:test", ANY, ANY, None, new_user=False
)
- self.assertEqual(mxid, "@TEST_USER_2:test")
def test_map_userinfo_to_invalid_localpart(self):
"""If the mapping provider generates an invalid localpart it should be rejected."""
- userinfo = {
- "sub": "test2",
- "username": "föö",
- }
- token = {}
-
- e = self.get_failure(
- self.handler._map_userinfo_to_user(
- userinfo, token, "user-agent", "10.10.10.10"
- ),
- MappingException,
+ self.get_success(
+ _make_callback_with_userinfo(self.hs, {"sub": "test2", "username": "föö"})
)
- self.assertEqual(str(e.value), "localpart is invalid: föö")
+ self.assertRenderedError("mapping_error", "localpart is invalid: föö")
@override_config(
{
@@ -822,6 +765,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
def test_map_userinfo_to_user_retries(self):
"""The mapping provider can retry generating an MXID if the MXID is already in use."""
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = simple_async_mock()
+
store = self.hs.get_datastore()
self.get_success(
store.register_user(user_id="@test_user:test", password_hash=None)
@@ -830,14 +776,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "test",
"username": "test_user",
}
- token = {}
- mxid = self.get_success(
- self.handler._map_userinfo_to_user(
- userinfo, token, "user-agent", "10.10.10.10"
- )
- )
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+
# test_user is already taken, so test_user1 gets registered instead.
- self.assertEqual(mxid, "@test_user1:test")
+ auth_handler.complete_sso_login.assert_called_once_with(
+ "@test_user1:test", ANY, ANY, None, new_user=True
+ )
+ auth_handler.complete_sso_login.reset_mock()
# Register all of the potential mxids for a particular OIDC username.
self.get_success(
@@ -853,12 +798,128 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "tester",
"username": "tester",
}
- e = self.get_failure(
- self.handler._map_userinfo_to_user(
- userinfo, token, "user-agent", "10.10.10.10"
- ),
- MappingException,
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ auth_handler.complete_sso_login.assert_not_called()
+ self.assertRenderedError(
+ "mapping_error", "Unable to generate a Matrix ID from the SSO response"
)
- self.assertEqual(
- str(e.value), "Unable to generate a Matrix ID from the SSO response"
+
+ def test_empty_localpart(self):
+ """Attempts to map onto an empty localpart should be rejected."""
+ userinfo = {
+ "sub": "tester",
+ "username": "",
+ }
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ self.assertRenderedError("mapping_error", "localpart is invalid: ")
+
+ @override_config(
+ {
+ "oidc_config": {
+ "user_mapping_provider": {
+ "config": {"localpart_template": "{{ user.username }}"}
+ }
+ }
+ }
+ )
+ def test_null_localpart(self):
+ """Mapping onto a null localpart via an empty OIDC attribute should be rejected"""
+ userinfo = {
+ "sub": "tester",
+ "username": None,
+ }
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ self.assertRenderedError("mapping_error", "localpart is invalid: ")
+
+ def _generate_oidc_session_token(
+ self,
+ state: str,
+ nonce: str,
+ client_redirect_url: str,
+ ui_auth_session_id: Optional[str] = None,
+ ) -> str:
+ from synapse.handlers.oidc_handler import OidcSessionData
+
+ return self.handler._token_generator.generate_oidc_session_token(
+ state=state,
+ session_data=OidcSessionData(
+ idp_id="oidc",
+ nonce=nonce,
+ client_redirect_url=client_redirect_url,
+ ui_auth_session_id=ui_auth_session_id,
+ ),
)
+
+
+async def _make_callback_with_userinfo(
+ hs: HomeServer, userinfo: dict, client_redirect_url: str = "http://client/redirect"
+) -> None:
+ """Mock up an OIDC callback with the given userinfo dict
+
+ We'll pull out the OIDC handler from the homeserver, stub out a couple of methods,
+ and poke in the userinfo dict as if it were the response to an OIDC userinfo call.
+
+ Args:
+ hs: the HomeServer impl to send the callback to.
+ userinfo: the OIDC userinfo dict
+ client_redirect_url: the URL to redirect to on success.
+ """
+ from synapse.handlers.oidc_handler import OidcSessionData
+
+ handler = hs.get_oidc_handler()
+ provider = handler._providers["oidc"]
+ provider._exchange_code = simple_async_mock(return_value={})
+ provider._parse_id_token = simple_async_mock(return_value=userinfo)
+ provider._fetch_userinfo = simple_async_mock(return_value=userinfo)
+
+ state = "state"
+ session = handler._token_generator.generate_oidc_session_token(
+ state=state,
+ session_data=OidcSessionData(
+ idp_id="oidc", nonce="nonce", client_redirect_url=client_redirect_url,
+ ),
+ )
+ request = _build_callback_request("code", state, session)
+
+ await handler.handle_oidc_callback(request)
+
+
+def _build_callback_request(
+ code: str,
+ state: str,
+ session: str,
+ user_agent: str = "Browser",
+ ip_address: str = "10.0.0.1",
+):
+ """Builds a fake SynapseRequest to mock the browser callback
+
+ Returns a Mock object which looks like the SynapseRequest we get from a browser
+ after SSO (before we return to the client)
+
+ Args:
+ code: the authorization code which would have been returned by the OIDC
+ provider
+ state: the "state" param which would have been passed around in the
+ query param. Should be the same as was embedded in the session in
+ _build_oidc_session.
+ session: the "session" which would have been passed around in the cookie.
+ user_agent: the user-agent to present
+ ip_address: the IP address to pretend the request came from
+ """
+ request = Mock(
+ spec=[
+ "args",
+ "getCookie",
+ "addCookie",
+ "requestHeaders",
+ "getClientIP",
+ "getHeader",
+ ]
+ )
+
+ request.getCookie.return_value = session
+ request.args = {}
+ request.args[b"code"] = [code.encode("utf-8")]
+ request.args[b"state"] = [state.encode("utf-8")]
+ request.getClientIP.return_value = ip_address
+ return request
diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py
index ceaf0902d2..f816594ee4 100644
--- a/tests/handlers/test_password_providers.py
+++ b/tests/handlers/test_password_providers.py
@@ -432,6 +432,29 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
@override_config(
{
+ **providers_config(CustomAuthProvider),
+ "password_config": {"enabled": False, "localdb_enabled": False},
+ }
+ )
+ def test_custom_auth_password_disabled_localdb_enabled(self):
+ """Check the localdb_enabled == enabled == False
+
+ Regression test for https://github.com/matrix-org/synapse/issues/8914: check
+ that setting *both* `localdb_enabled` *and* `password: enabled` to False doesn't
+ cause an exception.
+ """
+ self.register_user("localuser", "localpass")
+
+ flows = self._get_login_flows()
+ self.assertEqual(flows, [{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS)
+
+ # login shouldn't work and should be rejected with a 400 ("unknown login type")
+ channel = self._send_password_login("localuser", "localpass")
+ self.assertEqual(channel.code, 400, channel.result)
+ mock_password_provider.check_auth.assert_not_called()
+
+ @override_config(
+ {
**providers_config(PasswordCustomAuthProvider),
"password_config": {"enabled": False},
}
@@ -528,7 +551,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 400, channel.result)
def _get_login_flows(self) -> JsonDict:
- _, channel = self.make_request("GET", "/_matrix/client/r0/login")
+ channel = self.make_request("GET", "/_matrix/client/r0/login")
self.assertEqual(channel.code, 200, channel.result)
return channel.json_body["flows"]
@@ -537,7 +560,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
def _send_login(self, type, user, **params) -> FakeChannel:
params.update({"identifier": {"type": "m.id.user", "user": user}, "type": type})
- _, channel = self.make_request("POST", "/_matrix/client/r0/login", params)
+ channel = self.make_request("POST", "/_matrix/client/r0/login", params)
return channel
def _start_delete_device_session(self, access_token, device_id) -> str:
@@ -574,7 +597,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self, access_token: str, device: str, body: Union[JsonDict, bytes] = b"",
) -> FakeChannel:
"""Delete an individual device."""
- _, channel = self.make_request(
+ channel = self.make_request(
"DELETE", "devices/" + device, body, access_token=access_token
)
return channel
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index 8ed67640f8..0794b32c9c 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -463,7 +463,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver(
- "server", http_client=None, federation_sender=Mock()
+ "server", federation_http_client=None, federation_sender=Mock()
)
return hs
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index a69fa28b41..022943a10a 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -44,8 +44,6 @@ class ProfileTestCase(unittest.TestCase):
hs = yield setup_test_homeserver(
self.addCleanup,
- http_client=None,
- resource_for_federation=Mock(),
federation_client=self.mock_federation,
federation_server=Mock(),
federation_registry=self.mock_registry,
@@ -107,6 +105,21 @@ class ProfileTestCase(unittest.TestCase):
"Frank",
)
+ # Set displayname to an empty string
+ yield defer.ensureDeferred(
+ self.handler.set_displayname(
+ self.frank, synapse.types.create_requester(self.frank), ""
+ )
+ )
+
+ self.assertIsNone(
+ (
+ yield defer.ensureDeferred(
+ self.store.get_profile_displayname(self.frank.localpart)
+ )
+ )
+ )
+
@defer.inlineCallbacks
def test_set_my_name_if_disabled(self):
self.hs.config.enable_set_displayname = False
@@ -225,6 +238,21 @@ class ProfileTestCase(unittest.TestCase):
"http://my.server/me.png",
)
+ # Set avatar to an empty string
+ yield defer.ensureDeferred(
+ self.handler.set_avatar_url(
+ self.frank, synapse.types.create_requester(self.frank), "",
+ )
+ )
+
+ self.assertIsNone(
+ (
+ yield defer.ensureDeferred(
+ self.store.get_profile_avatar_url(self.frank.localpart)
+ )
+ ),
+ )
+
@defer.inlineCallbacks
def test_set_my_avatar_if_disabled(self):
self.hs.config.enable_set_avatar_url = False
diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py
index 45dc17aba5..a8d6c0f617 100644
--- a/tests/handlers/test_saml.py
+++ b/tests/handlers/test_saml.py
@@ -12,13 +12,35 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Optional
+
+from mock import Mock
+
import attr
from synapse.api.errors import RedirectException
-from synapse.handlers.sso import MappingException
+from tests.test_utils import simple_async_mock
from tests.unittest import HomeserverTestCase, override_config
+# Check if we have the dependencies to run the tests.
+try:
+ import saml2.config
+ from saml2.sigver import SigverError
+
+ has_saml2 = True
+
+ # pysaml2 can be installed and imported, but might not be able to find xmlsec1.
+ config = saml2.config.SPConfig()
+ try:
+ config.load({"metadata": {}})
+ has_xmlsec1 = True
+ except SigverError:
+ has_xmlsec1 = False
+except ImportError:
+ has_saml2 = False
+ has_xmlsec1 = False
+
# These are a few constants that are used as config parameters in the tests.
BASE_URL = "https://synapse/"
@@ -26,6 +48,8 @@ BASE_URL = "https://synapse/"
@attr.s
class FakeAuthnResponse:
ava = attr.ib(type=dict)
+ assertions = attr.ib(type=list, factory=list)
+ in_response_to = attr.ib(type=Optional[str], default=None)
class TestMappingProvider:
@@ -86,17 +110,29 @@ class SamlHandlerTestCase(HomeserverTestCase):
return hs
+ if not has_saml2:
+ skip = "Requires pysaml2"
+ elif not has_xmlsec1:
+ skip = "Requires xmlsec1"
+
def test_map_saml_response_to_user(self):
"""Ensure that mapping the SAML response returned from a provider to an MXID works properly."""
+
+ # stub out the auth handler
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = simple_async_mock()
+
+ # send a mocked-up SAML response to the callback
saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"})
- # The redirect_url doesn't matter with the default user mapping provider.
- redirect_url = ""
- mxid = self.get_success(
- self.handler._map_saml_response_to_user(
- saml_response, redirect_url, "user-agent", "10.10.10.10"
- )
+ request = _mock_request()
+ self.get_success(
+ self.handler._handle_authn_response(request, saml_response, "redirect_uri")
+ )
+
+ # check that the auth handler got called as expected
+ auth_handler.complete_sso_login.assert_called_once_with(
+ "@test_user:test", request, "redirect_uri", None, new_user=True
)
- self.assertEqual(mxid, "@test_user:test")
@override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}})
def test_map_saml_response_to_existing_user(self):
@@ -106,53 +142,81 @@ class SamlHandlerTestCase(HomeserverTestCase):
store.register_user(user_id="@test_user:test", password_hash=None)
)
+ # stub out the auth handler
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = simple_async_mock()
+
# Map a user via SSO.
saml_response = FakeAuthnResponse(
{"uid": "tester", "mxid": ["test_user"], "username": "test_user"}
)
- redirect_url = ""
- mxid = self.get_success(
- self.handler._map_saml_response_to_user(
- saml_response, redirect_url, "user-agent", "10.10.10.10"
- )
+ request = _mock_request()
+ self.get_success(
+ self.handler._handle_authn_response(request, saml_response, "")
+ )
+
+ # check that the auth handler got called as expected
+ auth_handler.complete_sso_login.assert_called_once_with(
+ "@test_user:test", request, "", None, new_user=False
)
- self.assertEqual(mxid, "@test_user:test")
# Subsequent calls should map to the same mxid.
- mxid = self.get_success(
- self.handler._map_saml_response_to_user(
- saml_response, redirect_url, "user-agent", "10.10.10.10"
- )
+ auth_handler.complete_sso_login.reset_mock()
+ self.get_success(
+ self.handler._handle_authn_response(request, saml_response, "")
+ )
+ auth_handler.complete_sso_login.assert_called_once_with(
+ "@test_user:test", request, "", None, new_user=False
)
- self.assertEqual(mxid, "@test_user:test")
def test_map_saml_response_to_invalid_localpart(self):
"""If the mapping provider generates an invalid localpart it should be rejected."""
+
+ # stub out the auth handler
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = simple_async_mock()
+
+ # mock out the error renderer too
+ sso_handler = self.hs.get_sso_handler()
+ sso_handler.render_error = Mock(return_value=None)
+
saml_response = FakeAuthnResponse({"uid": "test", "username": "föö"})
- redirect_url = ""
- e = self.get_failure(
- self.handler._map_saml_response_to_user(
- saml_response, redirect_url, "user-agent", "10.10.10.10"
- ),
- MappingException,
+ request = _mock_request()
+ self.get_success(
+ self.handler._handle_authn_response(request, saml_response, ""),
+ )
+ sso_handler.render_error.assert_called_once_with(
+ request, "mapping_error", "localpart is invalid: föö"
)
- self.assertEqual(str(e.value), "localpart is invalid: föö")
+ auth_handler.complete_sso_login.assert_not_called()
def test_map_saml_response_to_user_retries(self):
"""The mapping provider can retry generating an MXID if the MXID is already in use."""
+
+ # stub out the auth handler and error renderer
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = simple_async_mock()
+ sso_handler = self.hs.get_sso_handler()
+ sso_handler.render_error = Mock(return_value=None)
+
+ # register a user to occupy the first-choice MXID
store = self.hs.get_datastore()
self.get_success(
store.register_user(user_id="@test_user:test", password_hash=None)
)
+
+ # send the fake SAML response
saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"})
- redirect_url = ""
- mxid = self.get_success(
- self.handler._map_saml_response_to_user(
- saml_response, redirect_url, "user-agent", "10.10.10.10"
- )
+ request = _mock_request()
+ self.get_success(
+ self.handler._handle_authn_response(request, saml_response, ""),
)
+
# test_user is already taken, so test_user1 gets registered instead.
- self.assertEqual(mxid, "@test_user1:test")
+ auth_handler.complete_sso_login.assert_called_once_with(
+ "@test_user1:test", request, "", None, new_user=True
+ )
+ auth_handler.complete_sso_login.reset_mock()
# Register all of the potential mxids for a particular SAML username.
self.get_success(
@@ -165,15 +229,15 @@ class SamlHandlerTestCase(HomeserverTestCase):
# Now attempt to map to a username, this will fail since all potential usernames are taken.
saml_response = FakeAuthnResponse({"uid": "tester", "username": "tester"})
- e = self.get_failure(
- self.handler._map_saml_response_to_user(
- saml_response, redirect_url, "user-agent", "10.10.10.10"
- ),
- MappingException,
+ self.get_success(
+ self.handler._handle_authn_response(request, saml_response, ""),
)
- self.assertEqual(
- str(e.value), "Unable to generate a Matrix ID from the SSO response"
+ sso_handler.render_error.assert_called_once_with(
+ request,
+ "mapping_error",
+ "Unable to generate a Matrix ID from the SSO response",
)
+ auth_handler.complete_sso_login.assert_not_called()
@override_config(
{
@@ -185,12 +249,17 @@ class SamlHandlerTestCase(HomeserverTestCase):
}
)
def test_map_saml_response_redirect(self):
+ """Test a mapping provider that raises a RedirectException"""
+
saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"})
- redirect_url = ""
+ request = _mock_request()
e = self.get_failure(
- self.handler._map_saml_response_to_user(
- saml_response, redirect_url, "user-agent", "10.10.10.10"
- ),
+ self.handler._handle_authn_response(request, saml_response, ""),
RedirectException,
)
self.assertEqual(e.value.location, b"https://custom-saml-redirect/")
+
+
+def _mock_request():
+ """Returns a mock which will stand in as a SynapseRequest"""
+ return Mock(spec=["getClientIP", "getHeader"])
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index abbdf2d524..96e5bdac4a 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -15,18 +15,20 @@
import json
+from typing import Dict
from mock import ANY, Mock, call
from twisted.internet import defer
+from twisted.web.resource import Resource
from synapse.api.errors import AuthError
+from synapse.federation.transport.server import TransportLayerServer
from synapse.types import UserID, create_requester
from tests import unittest
from tests.test_utils import make_awaitable
from tests.unittest import override_config
-from tests.utils import register_federation_servlets
# Some local users to test with
U_APPLE = UserID.from_string("@apple:test")
@@ -53,8 +55,6 @@ def _make_edu_transaction_json(edu_type, content):
class TypingNotificationsTestCase(unittest.HomeserverTestCase):
- servlets = [register_federation_servlets]
-
def make_homeserver(self, reactor, clock):
# we mock out the keyring so as to skip the authentication check on the
# federation API call.
@@ -70,13 +70,18 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
hs = self.setup_test_homeserver(
notifier=Mock(),
- http_client=mock_federation_client,
+ federation_http_client=mock_federation_client,
keyring=mock_keyring,
replication_streams={},
)
return hs
+ def create_resource_dict(self) -> Dict[str, Resource]:
+ d = super().create_resource_dict()
+ d["/_matrix/federation"] = TransportLayerServer(self.hs)
+ return d
+
def prepare(self, reactor, clock, hs):
mock_notifier = hs.get_notifier()
self.on_new_event = mock_notifier.on_new_event
@@ -192,7 +197,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
)
)
- put_json = self.hs.get_http_client().put_json
+ put_json = self.hs.get_federation_http_client().put_json
put_json.assert_called_once_with(
"farm",
path="/_matrix/federation/v1/send/1000000",
@@ -215,7 +220,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(self.event_source.get_current_key(), 0)
- (request, channel) = self.make_request(
+ channel = self.make_request(
"PUT",
"/_matrix/federation/v1/send/1000000",
_make_edu_transaction_json(
@@ -270,7 +275,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])])
- put_json = self.hs.get_http_client().put_json
+ put_json = self.hs.get_federation_http_client().put_json
put_json.assert_called_once_with(
"farm",
path="/_matrix/federation/v1/send/1000000",
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 98e5af2072..9c886d671a 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -54,6 +54,10 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
user_id=support_user_id, password_hash=None, user_type=UserTypes.SUPPORT
)
)
+ regular_user_id = "@regular:test"
+ self.get_success(
+ self.store.register_user(user_id=regular_user_id, password_hash=None)
+ )
self.get_success(
self.handler.handle_local_profile_change(support_user_id, None)
@@ -63,13 +67,47 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
display_name = "display_name"
profile_info = ProfileInfo(avatar_url="avatar_url", display_name=display_name)
- regular_user_id = "@regular:test"
self.get_success(
self.handler.handle_local_profile_change(regular_user_id, profile_info)
)
profile = self.get_success(self.store.get_user_in_directory(regular_user_id))
self.assertTrue(profile["display_name"] == display_name)
+ def test_handle_local_profile_change_with_deactivated_user(self):
+ # create user
+ r_user_id = "@regular:test"
+ self.get_success(
+ self.store.register_user(user_id=r_user_id, password_hash=None)
+ )
+
+ # update profile
+ display_name = "Regular User"
+ profile_info = ProfileInfo(avatar_url="avatar_url", display_name=display_name)
+ self.get_success(
+ self.handler.handle_local_profile_change(r_user_id, profile_info)
+ )
+
+ # profile is in directory
+ profile = self.get_success(self.store.get_user_in_directory(r_user_id))
+ self.assertTrue(profile["display_name"] == display_name)
+
+ # deactivate user
+ self.get_success(self.store.set_user_deactivated_status(r_user_id, True))
+ self.get_success(self.handler.handle_user_deactivated(r_user_id))
+
+ # profile is not in directory
+ profile = self.get_success(self.store.get_user_in_directory(r_user_id))
+ self.assertTrue(profile is None)
+
+ # update profile after deactivation
+ self.get_success(
+ self.handler.handle_local_profile_change(r_user_id, profile_info)
+ )
+
+ # profile is furthermore not in directory
+ profile = self.get_success(self.store.get_user_in_directory(r_user_id))
+ self.assertTrue(profile is None)
+
def test_handle_user_deactivated_support_user(self):
s_user_id = "@support:test"
self.get_success(
@@ -270,7 +308,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
spam_checker = self.hs.get_spam_checker()
class AllowAll:
- def check_username_for_spam(self, user_profile):
+ async def check_username_for_spam(self, user_profile):
# Allow all users.
return False
@@ -283,7 +321,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
# Configure a spam checker that filters all users.
class BlockAll:
- def check_username_for_spam(self, user_profile):
+ async def check_username_for_spam(self, user_profile):
# All users are spammy.
return True
@@ -534,7 +572,7 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase):
self.helper.join(room, user=u2)
# Assert user directory is not empty
- request, channel = self.make_request(
+ channel = self.make_request(
"POST", b"user_directory/search", b'{"search_term":"user2"}'
)
self.assertEquals(200, channel.code, channel.result)
@@ -542,7 +580,7 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase):
# Disable user directory and check search returns nothing
self.config.user_directory_search_enabled = False
- request, channel = self.make_request(
+ channel = self.make_request(
"POST", b"user_directory/search", b'{"search_term":"user2"}'
)
self.assertEquals(200, channel.code, channel.result)
diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
index 8b5ad4574f..686012dd25 100644
--- a/tests/http/federation/test_matrix_federation_agent.py
+++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -17,6 +17,7 @@ import logging
from mock import Mock
import treq
+from netaddr import IPSet
from service_identity import VerificationError
from zope.interface import implementer
@@ -35,6 +36,7 @@ from synapse.crypto.context_factory import FederationPolicyForHTTPS
from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
from synapse.http.federation.srv_resolver import Server
from synapse.http.federation.well_known_resolver import (
+ WELL_KNOWN_MAX_SIZE,
WellKnownResolver,
_cache_period_from_headers,
)
@@ -103,6 +105,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
reactor=self.reactor,
tls_client_options_factory=self.tls_factory,
user_agent="test-agent", # Note that this is unused since _well_known_resolver is provided.
+ ip_blacklist=IPSet(),
_srv_resolver=self.mock_resolver,
_well_known_resolver=self.well_known_resolver,
)
@@ -736,6 +739,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
reactor=self.reactor,
tls_client_options_factory=tls_factory,
user_agent=b"test-agent", # This is unused since _well_known_resolver is passed below.
+ ip_blacklist=IPSet(),
_srv_resolver=self.mock_resolver,
_well_known_resolver=WellKnownResolver(
self.reactor,
@@ -1091,7 +1095,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
# Expire both caches and repeat the request
self.reactor.pump((10000.0,))
- # Repated the request, this time it should fail if the lookup fails.
+ # Repeat the request, this time it should fail if the lookup fails.
fetch_d = defer.ensureDeferred(
self.well_known_resolver.get_well_known(b"testserv")
)
@@ -1104,6 +1108,32 @@ class MatrixFederationAgentTests(unittest.TestCase):
r = self.successResultOf(fetch_d)
self.assertEqual(r.delegated_server, None)
+ def test_well_known_too_large(self):
+ """A well-known query that returns a result which is too large should be rejected."""
+ self.reactor.lookups["testserv"] = "1.2.3.4"
+
+ fetch_d = defer.ensureDeferred(
+ self.well_known_resolver.get_well_known(b"testserv")
+ )
+
+ # there should be an attempt to connect on port 443 for the .well-known
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
+ self.assertEqual(host, "1.2.3.4")
+ self.assertEqual(port, 443)
+
+ self._handle_well_known_connection(
+ client_factory,
+ expected_sni=b"testserv",
+ response_headers={b"Cache-Control": b"max-age=1000"},
+ content=b'{ "m.server": "' + (b"a" * WELL_KNOWN_MAX_SIZE) + b'" }',
+ )
+
+ # The result is successful, but disabled delegation.
+ r = self.successResultOf(fetch_d)
+ self.assertIsNone(r.delegated_server)
+
def test_srv_fallbacks(self):
"""Test that other SRV results are tried if the first one fails.
"""
diff --git a/tests/http/test_additional_resource.py b/tests/http/test_additional_resource.py
index 05e9c449be..453391a5a5 100644
--- a/tests/http/test_additional_resource.py
+++ b/tests/http/test_additional_resource.py
@@ -46,16 +46,16 @@ class AdditionalResourceTests(HomeserverTestCase):
handler = _AsyncTestCustomEndpoint({}, None).handle_request
resource = AdditionalResource(self.hs, handler)
- request, channel = make_request(self.reactor, FakeSite(resource), "GET", "/")
+ channel = make_request(self.reactor, FakeSite(resource), "GET", "/")
- self.assertEqual(request.code, 200)
+ self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body, {"some_key": "some_value_async"})
def test_sync(self):
handler = _SyncTestCustomEndpoint({}, None).handle_request
resource = AdditionalResource(self.hs, handler)
- request, channel = make_request(self.reactor, FakeSite(resource), "GET", "/")
+ channel = make_request(self.reactor, FakeSite(resource), "GET", "/")
- self.assertEqual(request.code, 200)
+ self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body, {"some_key": "some_value_sync"})
diff --git a/tests/http/test_client.py b/tests/http/test_client.py
new file mode 100644
index 0000000000..f17c122e93
--- /dev/null
+++ b/tests/http/test_client.py
@@ -0,0 +1,101 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from io import BytesIO
+
+from mock import Mock
+
+from twisted.python.failure import Failure
+from twisted.web.client import ResponseDone
+
+from synapse.http.client import BodyExceededMaxSize, read_body_with_max_size
+
+from tests.unittest import TestCase
+
+
+class ReadBodyWithMaxSizeTests(TestCase):
+ def setUp(self):
+ """Start reading the body, returns the response, result and proto"""
+ self.response = Mock()
+ self.result = BytesIO()
+ self.deferred = read_body_with_max_size(self.response, self.result, 6)
+
+ # Fish the protocol out of the response.
+ self.protocol = self.response.deliverBody.call_args[0][0]
+ self.protocol.transport = Mock()
+
+ def _cleanup_error(self):
+ """Ensure that the error in the Deferred is handled gracefully."""
+ called = [False]
+
+ def errback(f):
+ called[0] = True
+
+ self.deferred.addErrback(errback)
+ self.assertTrue(called[0])
+
+ def test_no_error(self):
+ """A response that is NOT too large."""
+
+ # Start sending data.
+ self.protocol.dataReceived(b"12345")
+ # Close the connection.
+ self.protocol.connectionLost(Failure(ResponseDone()))
+
+ self.assertEqual(self.result.getvalue(), b"12345")
+ self.assertEqual(self.deferred.result, 5)
+
+ def test_too_large(self):
+ """A response which is too large raises an exception."""
+
+ # Start sending data.
+ self.protocol.dataReceived(b"1234567890")
+ # Close the connection.
+ self.protocol.connectionLost(Failure(ResponseDone()))
+
+ self.assertEqual(self.result.getvalue(), b"1234567890")
+ self.assertIsInstance(self.deferred.result, Failure)
+ self.assertIsInstance(self.deferred.result.value, BodyExceededMaxSize)
+ self._cleanup_error()
+
+ def test_multiple_packets(self):
+ """Data should be accummulated through mutliple packets."""
+
+ # Start sending data.
+ self.protocol.dataReceived(b"12")
+ self.protocol.dataReceived(b"34")
+ # Close the connection.
+ self.protocol.connectionLost(Failure(ResponseDone()))
+
+ self.assertEqual(self.result.getvalue(), b"1234")
+ self.assertEqual(self.deferred.result, 4)
+
+ def test_additional_data(self):
+ """A connection can receive data after being closed."""
+
+ # Start sending data.
+ self.protocol.dataReceived(b"1234567890")
+ self.assertIsInstance(self.deferred.result, Failure)
+ self.assertIsInstance(self.deferred.result.value, BodyExceededMaxSize)
+ self.protocol.transport.loseConnection.assert_called_once()
+
+ # More data might have come in.
+ self.protocol.dataReceived(b"1234567890")
+ # Close the connection.
+ self.protocol.connectionLost(Failure(ResponseDone()))
+
+ self.assertEqual(self.result.getvalue(), b"1234567890")
+ self.assertIsInstance(self.deferred.result, Failure)
+ self.assertIsInstance(self.deferred.result.value, BodyExceededMaxSize)
+ self._cleanup_error()
diff --git a/tests/http/test_endpoint.py b/tests/http/test_endpoint.py
index b2e9533b07..d06ea518ce 100644
--- a/tests/http/test_endpoint.py
+++ b/tests/http/test_endpoint.py
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.http.endpoint import parse_and_validate_server_name, parse_server_name
+from synapse.util.stringutils import parse_and_validate_server_name, parse_server_name
from tests import unittest
diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py
index 212484a7fe..9c52c8fdca 100644
--- a/tests/http/test_fedclient.py
+++ b/tests/http/test_fedclient.py
@@ -560,4 +560,4 @@ class FederationClientTests(HomeserverTestCase):
self.pump()
f = self.failureResultOf(test_d)
- self.assertIsInstance(f.value, ValueError)
+ self.assertIsInstance(f.value, RequestSendFailed)
diff --git a/tests/http/test_proxyagent.py b/tests/http/test_proxyagent.py
index 22abf76515..9a56e1c14a 100644
--- a/tests/http/test_proxyagent.py
+++ b/tests/http/test_proxyagent.py
@@ -15,12 +15,14 @@
import logging
import treq
+from netaddr import IPSet
from twisted.internet import interfaces # noqa: F401
from twisted.internet.protocol import Factory
from twisted.protocols.tls import TLSMemoryBIOFactory
from twisted.web.http import HTTPChannel
+from synapse.http.client import BlacklistingReactorWrapper
from synapse.http.proxyagent import ProxyAgent
from tests.http import TestServerTLSConnectionFactory, get_test_https_policy
@@ -292,6 +294,134 @@ class MatrixFederationAgentTests(TestCase):
body = self.successResultOf(treq.content(resp))
self.assertEqual(body, b"result")
+ def test_http_request_via_proxy_with_blacklist(self):
+ # The blacklist includes the configured proxy IP.
+ agent = ProxyAgent(
+ BlacklistingReactorWrapper(
+ self.reactor, ip_whitelist=None, ip_blacklist=IPSet(["1.0.0.0/8"])
+ ),
+ self.reactor,
+ http_proxy=b"proxy.com:8888",
+ )
+
+ self.reactor.lookups["proxy.com"] = "1.2.3.5"
+ d = agent.request(b"GET", b"http://test.com")
+
+ # there should be a pending TCP connection
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients[0]
+ self.assertEqual(host, "1.2.3.5")
+ self.assertEqual(port, 8888)
+
+ # make a test server, and wire up the client
+ http_server = self._make_connection(
+ client_factory, _get_test_protocol_factory()
+ )
+
+ # the FakeTransport is async, so we need to pump the reactor
+ self.reactor.advance(0)
+
+ # now there should be a pending request
+ self.assertEqual(len(http_server.requests), 1)
+
+ request = http_server.requests[0]
+ self.assertEqual(request.method, b"GET")
+ self.assertEqual(request.path, b"http://test.com")
+ self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
+ request.write(b"result")
+ request.finish()
+
+ self.reactor.advance(0)
+
+ resp = self.successResultOf(d)
+ body = self.successResultOf(treq.content(resp))
+ self.assertEqual(body, b"result")
+
+ def test_https_request_via_proxy_with_blacklist(self):
+ # The blacklist includes the configured proxy IP.
+ agent = ProxyAgent(
+ BlacklistingReactorWrapper(
+ self.reactor, ip_whitelist=None, ip_blacklist=IPSet(["1.0.0.0/8"])
+ ),
+ self.reactor,
+ contextFactory=get_test_https_policy(),
+ https_proxy=b"proxy.com",
+ )
+
+ self.reactor.lookups["proxy.com"] = "1.2.3.5"
+ d = agent.request(b"GET", b"https://test.com/abc")
+
+ # there should be a pending TCP connection
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients[0]
+ self.assertEqual(host, "1.2.3.5")
+ self.assertEqual(port, 1080)
+
+ # make a test HTTP server, and wire up the client
+ proxy_server = self._make_connection(
+ client_factory, _get_test_protocol_factory()
+ )
+
+ # fish the transports back out so that we can do the old switcheroo
+ s2c_transport = proxy_server.transport
+ client_protocol = s2c_transport.other
+ c2s_transport = client_protocol.transport
+
+ # the FakeTransport is async, so we need to pump the reactor
+ self.reactor.advance(0)
+
+ # now there should be a pending CONNECT request
+ self.assertEqual(len(proxy_server.requests), 1)
+
+ request = proxy_server.requests[0]
+ self.assertEqual(request.method, b"CONNECT")
+ self.assertEqual(request.path, b"test.com:443")
+
+ # tell the proxy server not to close the connection
+ proxy_server.persistent = True
+
+ # this just stops the http Request trying to do a chunked response
+ # request.setHeader(b"Content-Length", b"0")
+ request.finish()
+
+ # now we can replace the proxy channel with a new, SSL-wrapped HTTP channel
+ ssl_factory = _wrap_server_factory_for_tls(_get_test_protocol_factory())
+ ssl_protocol = ssl_factory.buildProtocol(None)
+ http_server = ssl_protocol.wrappedProtocol
+
+ ssl_protocol.makeConnection(
+ FakeTransport(client_protocol, self.reactor, ssl_protocol)
+ )
+ c2s_transport.other = ssl_protocol
+
+ self.reactor.advance(0)
+
+ server_name = ssl_protocol._tlsConnection.get_servername()
+ expected_sni = b"test.com"
+ self.assertEqual(
+ server_name,
+ expected_sni,
+ "Expected SNI %s but got %s" % (expected_sni, server_name),
+ )
+
+ # now there should be a pending request
+ self.assertEqual(len(http_server.requests), 1)
+
+ request = http_server.requests[0]
+ self.assertEqual(request.method, b"GET")
+ self.assertEqual(request.path, b"/abc")
+ self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
+ request.write(b"result")
+ request.finish()
+
+ self.reactor.advance(0)
+
+ resp = self.successResultOf(d)
+ body = self.successResultOf(treq.content(resp))
+ self.assertEqual(body, b"result")
+
def _wrap_server_factory_for_tls(factory, sanlist=None):
"""Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory
diff --git a/tests/logging/test_terse_json.py b/tests/logging/test_terse_json.py
index 73f469b802..48a74e2eee 100644
--- a/tests/logging/test_terse_json.py
+++ b/tests/logging/test_terse_json.py
@@ -18,30 +18,35 @@ import logging
from io import StringIO
from synapse.logging._terse_json import JsonFormatter, TerseJsonFormatter
+from synapse.logging.context import LoggingContext, LoggingContextFilter
from tests.logging import LoggerCleanupMixin
from tests.unittest import TestCase
class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
+ def setUp(self):
+ self.output = StringIO()
+
+ def get_log_line(self):
+ # One log message, with a single trailing newline.
+ data = self.output.getvalue()
+ logs = data.splitlines()
+ self.assertEqual(len(logs), 1)
+ self.assertEqual(data.count("\n"), 1)
+ return json.loads(logs[0])
+
def test_terse_json_output(self):
"""
The Terse JSON formatter converts log messages to JSON.
"""
- output = StringIO()
-
- handler = logging.StreamHandler(output)
+ handler = logging.StreamHandler(self.output)
handler.setFormatter(TerseJsonFormatter())
logger = self.get_logger(handler)
logger.info("Hello there, %s!", "wally")
- # One log message, with a single trailing newline.
- data = output.getvalue()
- logs = data.splitlines()
- self.assertEqual(len(logs), 1)
- self.assertEqual(data.count("\n"), 1)
- log = json.loads(logs[0])
+ log = self.get_log_line()
# The terse logger should give us these keys.
expected_log_keys = [
@@ -57,9 +62,7 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
"""
Additional information can be included in the structured logging.
"""
- output = StringIO()
-
- handler = logging.StreamHandler(output)
+ handler = logging.StreamHandler(self.output)
handler.setFormatter(TerseJsonFormatter())
logger = self.get_logger(handler)
@@ -67,12 +70,7 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
"Hello there, %s!", "wally", extra={"foo": "bar", "int": 3, "bool": True}
)
- # One log message, with a single trailing newline.
- data = output.getvalue()
- logs = data.splitlines()
- self.assertEqual(len(logs), 1)
- self.assertEqual(data.count("\n"), 1)
- log = json.loads(logs[0])
+ log = self.get_log_line()
# The terse logger should give us these keys.
expected_log_keys = [
@@ -96,26 +94,44 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
"""
The Terse JSON formatter converts log messages to JSON.
"""
- output = StringIO()
-
- handler = logging.StreamHandler(output)
+ handler = logging.StreamHandler(self.output)
handler.setFormatter(JsonFormatter())
logger = self.get_logger(handler)
logger.info("Hello there, %s!", "wally")
- # One log message, with a single trailing newline.
- data = output.getvalue()
- logs = data.splitlines()
- self.assertEqual(len(logs), 1)
- self.assertEqual(data.count("\n"), 1)
- log = json.loads(logs[0])
+ log = self.get_log_line()
+
+ # The terse logger should give us these keys.
+ expected_log_keys = [
+ "log",
+ "level",
+ "namespace",
+ ]
+ self.assertCountEqual(log.keys(), expected_log_keys)
+ self.assertEqual(log["log"], "Hello there, wally!")
+
+ def test_with_context(self):
+ """
+ The logging context should be added to the JSON response.
+ """
+ handler = logging.StreamHandler(self.output)
+ handler.setFormatter(JsonFormatter())
+ handler.addFilter(LoggingContextFilter())
+ logger = self.get_logger(handler)
+
+ with LoggingContext(request="test"):
+ logger.info("Hello there, %s!", "wally")
+
+ log = self.get_log_line()
# The terse logger should give us these keys.
expected_log_keys = [
"log",
"level",
"namespace",
+ "request",
]
self.assertCountEqual(log.keys(), expected_log_keys)
self.assertEqual(log["log"], "Hello there, wally!")
+ self.assertEqual(log["request"], "test")
diff --git a/tests/push/test_email.py b/tests/push/test_email.py
index bcdcafa5a9..c4e1e7ed85 100644
--- a/tests/push/test_email.py
+++ b/tests/push/test_email.py
@@ -187,6 +187,36 @@ class EmailPusherTests(HomeserverTestCase):
# We should get emailed about those messages
self._check_for_mail()
+ def test_multiple_rooms(self):
+ # We want to test multiple notifications from multiple rooms, so we pause
+ # processing of push while we send messages.
+ self.pusher._pause_processing()
+
+ # Create a simple room with multiple other users
+ rooms = [
+ self.helper.create_room_as(self.user_id, tok=self.access_token),
+ self.helper.create_room_as(self.user_id, tok=self.access_token),
+ ]
+
+ for r, other in zip(rooms, self.others):
+ self.helper.invite(
+ room=r, src=self.user_id, tok=self.access_token, targ=other.id
+ )
+ self.helper.join(room=r, user=other.id, tok=other.token)
+
+ # The other users send some messages
+ self.helper.send(rooms[0], body="Hi!", tok=self.others[0].token)
+ self.helper.send(rooms[1], body="There!", tok=self.others[1].token)
+ self.helper.send(rooms[1], body="There!", tok=self.others[1].token)
+
+ # Nothing should have happened yet, as we're paused.
+ assert not self.email_attempts
+
+ self.pusher._resume_processing()
+
+ # We should get emailed about those messages
+ self._check_for_mail()
+
def test_encrypted_message(self):
room = self.helper.create_room_as(self.user_id, tok=self.access_token)
self.helper.invite(
@@ -209,7 +239,7 @@ class EmailPusherTests(HomeserverTestCase):
)
pushers = list(pushers)
self.assertEqual(len(pushers), 1)
- last_stream_ordering = pushers[0]["last_stream_ordering"]
+ last_stream_ordering = pushers[0].last_stream_ordering
# Advance time a bit, so the pusher will register something has happened
self.pump(10)
@@ -220,7 +250,7 @@ class EmailPusherTests(HomeserverTestCase):
)
pushers = list(pushers)
self.assertEqual(len(pushers), 1)
- self.assertEqual(last_stream_ordering, pushers[0]["last_stream_ordering"])
+ self.assertEqual(last_stream_ordering, pushers[0].last_stream_ordering)
# One email was attempted to be sent
self.assertEqual(len(self.email_attempts), 1)
@@ -238,4 +268,4 @@ class EmailPusherTests(HomeserverTestCase):
)
pushers = list(pushers)
self.assertEqual(len(pushers), 1)
- self.assertTrue(pushers[0]["last_stream_ordering"] > last_stream_ordering)
+ self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering)
diff --git a/tests/push/test_http.py b/tests/push/test_http.py
index f118430309..60f0820cff 100644
--- a/tests/push/test_http.py
+++ b/tests/push/test_http.py
@@ -18,6 +18,7 @@ from twisted.internet.defer import Deferred
import synapse.rest.admin
from synapse.logging.context import make_deferred_yieldable
+from synapse.push import PusherConfigException
from synapse.rest.client.v1 import login, room
from synapse.rest.client.v2_alpha import receipts
@@ -34,6 +35,11 @@ class HTTPPusherTests(HomeserverTestCase):
user_id = True
hijack_auth = False
+ def default_config(self):
+ config = super().default_config()
+ config["start_pushers"] = True
+ return config
+
def make_homeserver(self, reactor, clock):
self.push_attempts = []
@@ -46,13 +52,49 @@ class HTTPPusherTests(HomeserverTestCase):
m.post_json_get_json = post_json_get_json
- config = self.default_config()
- config["start_pushers"] = True
-
- hs = self.setup_test_homeserver(config=config, proxied_http_client=m)
+ hs = self.setup_test_homeserver(proxied_blacklisted_http_client=m)
return hs
+ def test_invalid_configuration(self):
+ """Invalid push configurations should be rejected."""
+ # Register the user who gets notified
+ user_id = self.register_user("user", "pass")
+ access_token = self.login("user", "pass")
+
+ # Register the pusher
+ user_tuple = self.get_success(
+ self.hs.get_datastore().get_user_by_access_token(access_token)
+ )
+ token_id = user_tuple.token_id
+
+ def test_data(data):
+ self.get_failure(
+ self.hs.get_pusherpool().add_pusher(
+ user_id=user_id,
+ access_token=token_id,
+ kind="http",
+ app_id="m.http",
+ app_display_name="HTTP Push Notifications",
+ device_display_name="pushy push",
+ pushkey="a@example.com",
+ lang=None,
+ data=data,
+ ),
+ PusherConfigException,
+ )
+
+ # Data must be provided with a URL.
+ test_data(None)
+ test_data({})
+ test_data({"url": 1})
+ # A bare domain name isn't accepted.
+ test_data({"url": "example.com"})
+ # A URL without a path isn't accepted.
+ test_data({"url": "http://example.com"})
+ # A url with an incorrect path isn't accepted.
+ test_data({"url": "http://example.com/foo"})
+
def test_sends_http(self):
"""
The HTTP pusher will send pushes for each message to a HTTP endpoint
@@ -82,7 +124,7 @@ class HTTPPusherTests(HomeserverTestCase):
device_display_name="pushy push",
pushkey="a@example.com",
lang=None,
- data={"url": "example.com"},
+ data={"url": "http://example.com/_matrix/push/v1/notify"},
)
)
@@ -102,7 +144,7 @@ class HTTPPusherTests(HomeserverTestCase):
)
pushers = list(pushers)
self.assertEqual(len(pushers), 1)
- last_stream_ordering = pushers[0]["last_stream_ordering"]
+ last_stream_ordering = pushers[0].last_stream_ordering
# Advance time a bit, so the pusher will register something has happened
self.pump()
@@ -113,11 +155,13 @@ class HTTPPusherTests(HomeserverTestCase):
)
pushers = list(pushers)
self.assertEqual(len(pushers), 1)
- self.assertEqual(last_stream_ordering, pushers[0]["last_stream_ordering"])
+ self.assertEqual(last_stream_ordering, pushers[0].last_stream_ordering)
# One push was attempted to be sent -- it'll be the first message
self.assertEqual(len(self.push_attempts), 1)
- self.assertEqual(self.push_attempts[0][1], "example.com")
+ self.assertEqual(
+ self.push_attempts[0][1], "http://example.com/_matrix/push/v1/notify"
+ )
self.assertEqual(
self.push_attempts[0][2]["notification"]["content"]["body"], "Hi!"
)
@@ -132,12 +176,14 @@ class HTTPPusherTests(HomeserverTestCase):
)
pushers = list(pushers)
self.assertEqual(len(pushers), 1)
- self.assertTrue(pushers[0]["last_stream_ordering"] > last_stream_ordering)
- last_stream_ordering = pushers[0]["last_stream_ordering"]
+ self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering)
+ last_stream_ordering = pushers[0].last_stream_ordering
# Now it'll try and send the second push message, which will be the second one
self.assertEqual(len(self.push_attempts), 2)
- self.assertEqual(self.push_attempts[1][1], "example.com")
+ self.assertEqual(
+ self.push_attempts[1][1], "http://example.com/_matrix/push/v1/notify"
+ )
self.assertEqual(
self.push_attempts[1][2]["notification"]["content"]["body"], "There!"
)
@@ -152,7 +198,7 @@ class HTTPPusherTests(HomeserverTestCase):
)
pushers = list(pushers)
self.assertEqual(len(pushers), 1)
- self.assertTrue(pushers[0]["last_stream_ordering"] > last_stream_ordering)
+ self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering)
def test_sends_high_priority_for_encrypted(self):
"""
@@ -194,7 +240,7 @@ class HTTPPusherTests(HomeserverTestCase):
device_display_name="pushy push",
pushkey="a@example.com",
lang=None,
- data={"url": "example.com"},
+ data={"url": "http://example.com/_matrix/push/v1/notify"},
)
)
@@ -230,7 +276,9 @@ class HTTPPusherTests(HomeserverTestCase):
# Check our push made it with high priority
self.assertEqual(len(self.push_attempts), 1)
- self.assertEqual(self.push_attempts[0][1], "example.com")
+ self.assertEqual(
+ self.push_attempts[0][1], "http://example.com/_matrix/push/v1/notify"
+ )
self.assertEqual(self.push_attempts[0][2]["notification"]["prio"], "high")
# Add yet another person — we want to make this room not a 1:1
@@ -268,7 +316,9 @@ class HTTPPusherTests(HomeserverTestCase):
# Advance time a bit, so the pusher will register something has happened
self.pump()
self.assertEqual(len(self.push_attempts), 2)
- self.assertEqual(self.push_attempts[1][1], "example.com")
+ self.assertEqual(
+ self.push_attempts[1][1], "http://example.com/_matrix/push/v1/notify"
+ )
self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "high")
def test_sends_high_priority_for_one_to_one_only(self):
@@ -310,7 +360,7 @@ class HTTPPusherTests(HomeserverTestCase):
device_display_name="pushy push",
pushkey="a@example.com",
lang=None,
- data={"url": "example.com"},
+ data={"url": "http://example.com/_matrix/push/v1/notify"},
)
)
@@ -326,7 +376,9 @@ class HTTPPusherTests(HomeserverTestCase):
# Check our push made it with high priority — this is a one-to-one room
self.assertEqual(len(self.push_attempts), 1)
- self.assertEqual(self.push_attempts[0][1], "example.com")
+ self.assertEqual(
+ self.push_attempts[0][1], "http://example.com/_matrix/push/v1/notify"
+ )
self.assertEqual(self.push_attempts[0][2]["notification"]["prio"], "high")
# Yet another user joins
@@ -345,7 +397,9 @@ class HTTPPusherTests(HomeserverTestCase):
# Advance time a bit, so the pusher will register something has happened
self.pump()
self.assertEqual(len(self.push_attempts), 2)
- self.assertEqual(self.push_attempts[1][1], "example.com")
+ self.assertEqual(
+ self.push_attempts[1][1], "http://example.com/_matrix/push/v1/notify"
+ )
# check that this is low-priority
self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low")
@@ -392,7 +446,7 @@ class HTTPPusherTests(HomeserverTestCase):
device_display_name="pushy push",
pushkey="a@example.com",
lang=None,
- data={"url": "example.com"},
+ data={"url": "http://example.com/_matrix/push/v1/notify"},
)
)
@@ -408,7 +462,9 @@ class HTTPPusherTests(HomeserverTestCase):
# Check our push made it with high priority
self.assertEqual(len(self.push_attempts), 1)
- self.assertEqual(self.push_attempts[0][1], "example.com")
+ self.assertEqual(
+ self.push_attempts[0][1], "http://example.com/_matrix/push/v1/notify"
+ )
self.assertEqual(self.push_attempts[0][2]["notification"]["prio"], "high")
# Send another event, this time with no mention
@@ -417,7 +473,9 @@ class HTTPPusherTests(HomeserverTestCase):
# Advance time a bit, so the pusher will register something has happened
self.pump()
self.assertEqual(len(self.push_attempts), 2)
- self.assertEqual(self.push_attempts[1][1], "example.com")
+ self.assertEqual(
+ self.push_attempts[1][1], "http://example.com/_matrix/push/v1/notify"
+ )
# check that this is low-priority
self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low")
@@ -465,7 +523,7 @@ class HTTPPusherTests(HomeserverTestCase):
device_display_name="pushy push",
pushkey="a@example.com",
lang=None,
- data={"url": "example.com"},
+ data={"url": "http://example.com/_matrix/push/v1/notify"},
)
)
@@ -485,7 +543,9 @@ class HTTPPusherTests(HomeserverTestCase):
# Check our push made it with high priority
self.assertEqual(len(self.push_attempts), 1)
- self.assertEqual(self.push_attempts[0][1], "example.com")
+ self.assertEqual(
+ self.push_attempts[0][1], "http://example.com/_matrix/push/v1/notify"
+ )
self.assertEqual(self.push_attempts[0][2]["notification"]["prio"], "high")
# Send another event, this time as someone without the power of @room
@@ -496,7 +556,9 @@ class HTTPPusherTests(HomeserverTestCase):
# Advance time a bit, so the pusher will register something has happened
self.pump()
self.assertEqual(len(self.push_attempts), 2)
- self.assertEqual(self.push_attempts[1][1], "example.com")
+ self.assertEqual(
+ self.push_attempts[1][1], "http://example.com/_matrix/push/v1/notify"
+ )
# check that this is low-priority
self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low")
@@ -570,7 +632,7 @@ class HTTPPusherTests(HomeserverTestCase):
device_display_name="pushy push",
pushkey="a@example.com",
lang=None,
- data={"url": "example.com"},
+ data={"url": "http://example.com/_matrix/push/v1/notify"},
)
)
@@ -589,7 +651,9 @@ class HTTPPusherTests(HomeserverTestCase):
# Check our push made it
self.assertEqual(len(self.push_attempts), 1)
- self.assertEqual(self.push_attempts[0][1], "example.com")
+ self.assertEqual(
+ self.push_attempts[0][1], "http://example.com/_matrix/push/v1/notify"
+ )
# Check that the unread count for the room is 0
#
@@ -603,7 +667,7 @@ class HTTPPusherTests(HomeserverTestCase):
# This will actually trigger a new notification to be sent out so that
# even if the user does not receive another message, their unread
# count goes down
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
"/rooms/%s/receipt/m.read/%s" % (room_id, first_message_event_id),
{},
diff --git a/tests/push/test_presentable_names.py b/tests/push/test_presentable_names.py
new file mode 100644
index 0000000000..aff563919d
--- /dev/null
+++ b/tests/push/test_presentable_names.py
@@ -0,0 +1,229 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Iterable, Optional, Tuple
+
+from synapse.api.constants import EventTypes, Membership
+from synapse.api.room_versions import RoomVersions
+from synapse.events import FrozenEvent
+from synapse.push.presentable_names import calculate_room_name
+from synapse.types import StateKey, StateMap
+
+from tests import unittest
+
+
+class MockDataStore:
+ """
+ A fake data store which stores a mapping of state key to event content.
+ (I.e. the state key is used as the event ID.)
+ """
+
+ def __init__(self, events: Iterable[Tuple[StateKey, dict]]):
+ """
+ Args:
+ events: A state map to event contents.
+ """
+ self._events = {}
+
+ for i, (event_id, content) in enumerate(events):
+ self._events[event_id] = FrozenEvent(
+ {
+ "event_id": "$event_id",
+ "type": event_id[0],
+ "sender": "@user:test",
+ "state_key": event_id[1],
+ "room_id": "#room:test",
+ "content": content,
+ "origin_server_ts": i,
+ },
+ RoomVersions.V1,
+ )
+
+ async def get_event(
+ self, event_id: StateKey, allow_none: bool = False
+ ) -> Optional[FrozenEvent]:
+ assert allow_none, "Mock not configured for allow_none = False"
+
+ return self._events.get(event_id)
+
+ async def get_events(self, event_ids: Iterable[StateKey]):
+ # This is cheating since it just returns all events.
+ return self._events
+
+
+class PresentableNamesTestCase(unittest.HomeserverTestCase):
+ USER_ID = "@test:test"
+ OTHER_USER_ID = "@user:test"
+
+ def _calculate_room_name(
+ self,
+ events: StateMap[dict],
+ user_id: str = "",
+ fallback_to_members: bool = True,
+ fallback_to_single_member: bool = True,
+ ):
+ # This isn't 100% accurate, but works with MockDataStore.
+ room_state_ids = {k[0]: k[0] for k in events}
+
+ return self.get_success(
+ calculate_room_name(
+ MockDataStore(events),
+ room_state_ids,
+ user_id or self.USER_ID,
+ fallback_to_members,
+ fallback_to_single_member,
+ )
+ )
+
+ def test_name(self):
+ """A room name event should be used."""
+ events = [
+ ((EventTypes.Name, ""), {"name": "test-name"}),
+ ]
+ self.assertEqual("test-name", self._calculate_room_name(events))
+
+ # Check if the event content has garbage.
+ events = [((EventTypes.Name, ""), {"foo": 1})]
+ self.assertEqual("Empty Room", self._calculate_room_name(events))
+
+ events = [((EventTypes.Name, ""), {"name": 1})]
+ self.assertEqual(1, self._calculate_room_name(events))
+
+ def test_canonical_alias(self):
+ """An canonical alias should be used."""
+ events = [
+ ((EventTypes.CanonicalAlias, ""), {"alias": "#test-name:test"}),
+ ]
+ self.assertEqual("#test-name:test", self._calculate_room_name(events))
+
+ # Check if the event content has garbage.
+ events = [((EventTypes.CanonicalAlias, ""), {"foo": 1})]
+ self.assertEqual("Empty Room", self._calculate_room_name(events))
+
+ events = [((EventTypes.CanonicalAlias, ""), {"alias": "test-name"})]
+ self.assertEqual("Empty Room", self._calculate_room_name(events))
+
+ def test_invite(self):
+ """An invite has special behaviour."""
+ events = [
+ ((EventTypes.Member, self.USER_ID), {"membership": Membership.INVITE}),
+ ((EventTypes.Member, self.OTHER_USER_ID), {"displayname": "Other User"}),
+ ]
+ self.assertEqual("Invite from Other User", self._calculate_room_name(events))
+ self.assertIsNone(
+ self._calculate_room_name(events, fallback_to_single_member=False)
+ )
+ # Ensure this logic is skipped if we don't fallback to members.
+ self.assertIsNone(self._calculate_room_name(events, fallback_to_members=False))
+
+ # Check if the event content has garbage.
+ events = [
+ ((EventTypes.Member, self.USER_ID), {"membership": Membership.INVITE}),
+ ((EventTypes.Member, self.OTHER_USER_ID), {"foo": 1}),
+ ]
+ self.assertEqual("Invite from @user:test", self._calculate_room_name(events))
+
+ # No member event for sender.
+ events = [
+ ((EventTypes.Member, self.USER_ID), {"membership": Membership.INVITE}),
+ ]
+ self.assertEqual("Room Invite", self._calculate_room_name(events))
+
+ def test_no_members(self):
+ """Behaviour of an empty room."""
+ events = []
+ self.assertEqual("Empty Room", self._calculate_room_name(events))
+
+ # Note that events with invalid (or missing) membership are ignored.
+ events = [
+ ((EventTypes.Member, self.OTHER_USER_ID), {"foo": 1}),
+ ((EventTypes.Member, "@foo:test"), {"membership": "foo"}),
+ ]
+ self.assertEqual("Empty Room", self._calculate_room_name(events))
+
+ def test_no_other_members(self):
+ """Behaviour of a room with no other members in it."""
+ events = [
+ (
+ (EventTypes.Member, self.USER_ID),
+ {"membership": Membership.JOIN, "displayname": "Me"},
+ ),
+ ]
+ self.assertEqual("Me", self._calculate_room_name(events))
+
+ # Check if the event content has no displayname.
+ events = [
+ ((EventTypes.Member, self.USER_ID), {"membership": Membership.JOIN}),
+ ]
+ self.assertEqual("@test:test", self._calculate_room_name(events))
+
+ # 3pid invite, use the other user (who is set as the sender).
+ events = [
+ ((EventTypes.Member, self.OTHER_USER_ID), {"membership": Membership.JOIN}),
+ ]
+ self.assertEqual(
+ "nobody", self._calculate_room_name(events, user_id=self.OTHER_USER_ID)
+ )
+
+ events = [
+ ((EventTypes.Member, self.OTHER_USER_ID), {"membership": Membership.JOIN}),
+ ((EventTypes.ThirdPartyInvite, self.OTHER_USER_ID), {}),
+ ]
+ self.assertEqual(
+ "Inviting email address",
+ self._calculate_room_name(events, user_id=self.OTHER_USER_ID),
+ )
+
+ def test_one_other_member(self):
+ """Behaviour of a room with a single other member."""
+ events = [
+ ((EventTypes.Member, self.USER_ID), {"membership": Membership.JOIN}),
+ (
+ (EventTypes.Member, self.OTHER_USER_ID),
+ {"membership": Membership.JOIN, "displayname": "Other User"},
+ ),
+ ]
+ self.assertEqual("Other User", self._calculate_room_name(events))
+ self.assertIsNone(
+ self._calculate_room_name(events, fallback_to_single_member=False)
+ )
+
+ # Check if the event content has no displayname and is an invite.
+ events = [
+ ((EventTypes.Member, self.USER_ID), {"membership": Membership.JOIN}),
+ (
+ (EventTypes.Member, self.OTHER_USER_ID),
+ {"membership": Membership.INVITE},
+ ),
+ ]
+ self.assertEqual("@user:test", self._calculate_room_name(events))
+
+ def test_other_members(self):
+ """Behaviour of a room with multiple other members."""
+ # Two other members.
+ events = [
+ ((EventTypes.Member, self.USER_ID), {"membership": Membership.JOIN}),
+ (
+ (EventTypes.Member, self.OTHER_USER_ID),
+ {"membership": Membership.JOIN, "displayname": "Other User"},
+ ),
+ ((EventTypes.Member, "@foo:test"), {"membership": Membership.JOIN}),
+ ]
+ self.assertEqual("Other User and @foo:test", self._calculate_room_name(events))
+
+ # Three or more other members.
+ events.append(
+ ((EventTypes.Member, "@fourth:test"), {"membership": Membership.INVITE})
+ )
+ self.assertEqual("Other User and 2 others", self._calculate_room_name(events))
diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py
index 1f4b5ca2ac..4a841f5bb8 100644
--- a/tests/push/test_push_rule_evaluator.py
+++ b/tests/push/test_push_rule_evaluator.py
@@ -29,7 +29,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
"type": "m.room.history_visibility",
"sender": "@user:test",
"state_key": "",
- "room_id": "@room:test",
+ "room_id": "#room:test",
"content": content,
},
RoomVersions.V1,
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index 295c5d58a6..d5dce1f83f 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import Any, Callable, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple
import attr
@@ -21,6 +21,7 @@ from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
from twisted.internet.protocol import Protocol
from twisted.internet.task import LoopingCall
from twisted.web.http import HTTPChannel
+from twisted.web.resource import Resource
from synapse.app.generic_worker import (
GenericWorkerReplicationHandler,
@@ -28,7 +29,7 @@ from synapse.app.generic_worker import (
)
from synapse.http.server import JsonResource
from synapse.http.site import SynapseRequest, SynapseSite
-from synapse.replication.http import ReplicationRestResource, streams
+from synapse.replication.http import ReplicationRestResource
from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
@@ -54,10 +55,6 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
if not hiredis:
skip = "Requires hiredis"
- servlets = [
- streams.register_servlets,
- ]
-
def prepare(self, reactor, clock, hs):
# build a replication server
server_factory = ReplicationStreamProtocolFactory(hs)
@@ -67,7 +64,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
# Make a new HomeServer object for the worker
self.reactor.lookups["testserv"] = "1.2.3.4"
self.worker_hs = self.setup_test_homeserver(
- http_client=None,
+ federation_http_client=None,
homeserver_to_use=GenericWorkerServer,
config=self._get_worker_hs_config(),
reactor=self.reactor,
@@ -88,6 +85,11 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
self._client_transport = None
self._server_transport = None
+ def create_resource_dict(self) -> Dict[str, Resource]:
+ d = super().create_resource_dict()
+ d["/_synapse/replication"] = ReplicationRestResource(self.hs)
+ return d
+
def _get_worker_hs_config(self) -> dict:
config = self.default_config()
config["worker_app"] = "synapse.app.generic_worker"
@@ -210,6 +212,9 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
# Fake in memory Redis server that servers can connect to.
self._redis_server = FakeRedisPubSubServer()
+ # We may have an attempt to connect to redis for the external cache already.
+ self.connect_any_redis_attempts()
+
store = self.hs.get_datastore()
self.database_pool = store.db_pool
@@ -264,7 +269,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
worker_app: Type of worker, e.g. `synapse.app.federation_sender`.
extra_config: Any extra config to use for this instances.
**kwargs: Options that get passed to `self.setup_test_homeserver`,
- useful to e.g. pass some mocks for things like `http_client`
+ useful to e.g. pass some mocks for things like `federation_http_client`
Returns:
The new worker HomeServer instance.
@@ -399,25 +404,23 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
fake one.
"""
clients = self.reactor.tcpClients
- self.assertEqual(len(clients), 1)
- (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
- self.assertEqual(host, "localhost")
- self.assertEqual(port, 6379)
+ while clients:
+ (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
+ self.assertEqual(host, "localhost")
+ self.assertEqual(port, 6379)
- client_protocol = client_factory.buildProtocol(None)
- server_protocol = self._redis_server.buildProtocol(None)
+ client_protocol = client_factory.buildProtocol(None)
+ server_protocol = self._redis_server.buildProtocol(None)
- client_to_server_transport = FakeTransport(
- server_protocol, self.reactor, client_protocol
- )
- client_protocol.makeConnection(client_to_server_transport)
-
- server_to_client_transport = FakeTransport(
- client_protocol, self.reactor, server_protocol
- )
- server_protocol.makeConnection(server_to_client_transport)
+ client_to_server_transport = FakeTransport(
+ server_protocol, self.reactor, client_protocol
+ )
+ client_protocol.makeConnection(client_to_server_transport)
- return client_to_server_transport, server_to_client_transport
+ server_to_client_transport = FakeTransport(
+ client_protocol, self.reactor, server_protocol
+ )
+ server_protocol.makeConnection(server_to_client_transport)
class TestReplicationDataHandler(GenericWorkerReplicationHandler):
@@ -622,6 +625,12 @@ class FakeRedisPubSubProtocol(Protocol):
(channel,) = args
self._server.add_subscriber(self)
self.send(["subscribe", channel, 1])
+
+ # Since we use SET/GET to cache things we can safely no-op them.
+ elif command == b"SET":
+ self.send("OK")
+ elif command == b"GET":
+ self.send(None)
else:
raise Exception("Unknown command")
@@ -643,6 +652,8 @@ class FakeRedisPubSubProtocol(Protocol):
# We assume bytes are just unicode strings.
obj = obj.decode("utf-8")
+ if obj is None:
+ return "$-1\r\n"
if isinstance(obj, str):
return "${len}\r\n{str}\r\n".format(len=len(obj), str=obj)
if isinstance(obj, int):
diff --git a/tests/replication/test_auth.py b/tests/replication/test_auth.py
new file mode 100644
index 0000000000..f35a5235e1
--- /dev/null
+++ b/tests/replication/test_auth.py
@@ -0,0 +1,117 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from synapse.rest.client.v2_alpha import register
+
+from tests.replication._base import BaseMultiWorkerStreamTestCase
+from tests.server import FakeChannel, make_request
+from tests.unittest import override_config
+
+logger = logging.getLogger(__name__)
+
+
+class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase):
+ """Test the authentication of HTTP calls between workers."""
+
+ servlets = [register.register_servlets]
+
+ def make_homeserver(self, reactor, clock):
+ config = self.default_config()
+ # This isn't a real configuration option but is used to provide the main
+ # homeserver and worker homeserver different options.
+ main_replication_secret = config.pop("main_replication_secret", None)
+ if main_replication_secret:
+ config["worker_replication_secret"] = main_replication_secret
+ return self.setup_test_homeserver(config=config)
+
+ def _get_worker_hs_config(self) -> dict:
+ config = self.default_config()
+ config["worker_app"] = "synapse.app.client_reader"
+ config["worker_replication_host"] = "testserv"
+ config["worker_replication_http_port"] = "8765"
+
+ return config
+
+ def _test_register(self) -> FakeChannel:
+ """Run the actual test:
+
+ 1. Create a worker homeserver.
+ 2. Start registration by providing a user/password.
+ 3. Complete registration by providing dummy auth (this hits the main synapse).
+ 4. Return the final request.
+
+ """
+ worker_hs = self.make_worker_hs("synapse.app.client_reader")
+ site = self._hs_to_site[worker_hs]
+
+ channel_1 = make_request(
+ self.reactor,
+ site,
+ "POST",
+ "register",
+ {"username": "user", "type": "m.login.password", "password": "bar"},
+ )
+ self.assertEqual(channel_1.code, 401)
+
+ # Grab the session
+ session = channel_1.json_body["session"]
+
+ # also complete the dummy auth
+ return make_request(
+ self.reactor,
+ site,
+ "POST",
+ "register",
+ {"auth": {"session": session, "type": "m.login.dummy"}},
+ )
+
+ def test_no_auth(self):
+ """With no authentication the request should finish.
+ """
+ channel = self._test_register()
+ self.assertEqual(channel.code, 200)
+
+ # We're given a registered user.
+ self.assertEqual(channel.json_body["user_id"], "@user:test")
+
+ @override_config({"main_replication_secret": "my-secret"})
+ def test_missing_auth(self):
+ """If the main process expects a secret that is not provided, an error results.
+ """
+ channel = self._test_register()
+ self.assertEqual(channel.code, 500)
+
+ @override_config(
+ {
+ "main_replication_secret": "my-secret",
+ "worker_replication_secret": "wrong-secret",
+ }
+ )
+ def test_unauthorized(self):
+ """If the main process receives the wrong secret, an error results.
+ """
+ channel = self._test_register()
+ self.assertEqual(channel.code, 500)
+
+ @override_config({"worker_replication_secret": "my-secret"})
+ def test_authorized(self):
+ """The request should finish when the worker provides the authentication header.
+ """
+ channel = self._test_register()
+ self.assertEqual(channel.code, 200)
+
+ # We're given a registered user.
+ self.assertEqual(channel.json_body["user_id"], "@user:test")
diff --git a/tests/replication/test_client_reader_shard.py b/tests/replication/test_client_reader_shard.py
index 96801db473..4608b65a0c 100644
--- a/tests/replication/test_client_reader_shard.py
+++ b/tests/replication/test_client_reader_shard.py
@@ -14,27 +14,19 @@
# limitations under the License.
import logging
-from synapse.api.constants import LoginType
-from synapse.http.site import SynapseRequest
from synapse.rest.client.v2_alpha import register
from tests.replication._base import BaseMultiWorkerStreamTestCase
-from tests.rest.client.v2_alpha.test_auth import DummyRecaptchaChecker
-from tests.server import FakeChannel, make_request
+from tests.server import make_request
logger = logging.getLogger(__name__)
class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
- """Base class for tests of the replication streams"""
+ """Test using one or more client readers for registration."""
servlets = [register.register_servlets]
- def prepare(self, reactor, clock, hs):
- self.recaptcha_checker = DummyRecaptchaChecker(hs)
- auth_handler = hs.get_auth_handler()
- auth_handler.checkers[LoginType.RECAPTCHA] = self.recaptcha_checker
-
def _get_worker_hs_config(self) -> dict:
config = self.default_config()
config["worker_app"] = "synapse.app.client_reader"
@@ -48,27 +40,27 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
worker_hs = self.make_worker_hs("synapse.app.client_reader")
site = self._hs_to_site[worker_hs]
- request_1, channel_1 = make_request(
+ channel_1 = make_request(
self.reactor,
site,
"POST",
"register",
{"username": "user", "type": "m.login.password", "password": "bar"},
- ) # type: SynapseRequest, FakeChannel
- self.assertEqual(request_1.code, 401)
+ )
+ self.assertEqual(channel_1.code, 401)
# Grab the session
session = channel_1.json_body["session"]
# also complete the dummy auth
- request_2, channel_2 = make_request(
+ channel_2 = make_request(
self.reactor,
site,
"POST",
"register",
{"auth": {"session": session, "type": "m.login.dummy"}},
- ) # type: SynapseRequest, FakeChannel
- self.assertEqual(request_2.code, 200)
+ )
+ self.assertEqual(channel_2.code, 200)
# We're given a registered user.
self.assertEqual(channel_2.json_body["user_id"], "@user:test")
@@ -80,28 +72,28 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
worker_hs_2 = self.make_worker_hs("synapse.app.client_reader")
site_1 = self._hs_to_site[worker_hs_1]
- request_1, channel_1 = make_request(
+ channel_1 = make_request(
self.reactor,
site_1,
"POST",
"register",
{"username": "user", "type": "m.login.password", "password": "bar"},
- ) # type: SynapseRequest, FakeChannel
- self.assertEqual(request_1.code, 401)
+ )
+ self.assertEqual(channel_1.code, 401)
# Grab the session
session = channel_1.json_body["session"]
# also complete the dummy auth
site_2 = self._hs_to_site[worker_hs_2]
- request_2, channel_2 = make_request(
+ channel_2 = make_request(
self.reactor,
site_2,
"POST",
"register",
{"auth": {"session": session, "type": "m.login.dummy"}},
- ) # type: SynapseRequest, FakeChannel
- self.assertEqual(request_2.code, 200)
+ )
+ self.assertEqual(channel_2.code, 200)
# We're given a registered user.
self.assertEqual(channel_2.json_body["user_id"], "@user:test")
diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py
index 779745ae9d..fffdb742c8 100644
--- a/tests/replication/test_federation_sender_shard.py
+++ b/tests/replication/test_federation_sender_shard.py
@@ -50,7 +50,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
self.make_worker_hs(
"synapse.app.federation_sender",
{"send_federation": True},
- http_client=mock_client,
+ federation_http_client=mock_client,
)
user = self.register_user("user", "pass")
@@ -81,7 +81,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
"worker_name": "sender1",
"federation_sender_instances": ["sender1", "sender2"],
},
- http_client=mock_client1,
+ federation_http_client=mock_client1,
)
mock_client2 = Mock(spec=["put_json"])
@@ -93,7 +93,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
"worker_name": "sender2",
"federation_sender_instances": ["sender1", "sender2"],
},
- http_client=mock_client2,
+ federation_http_client=mock_client2,
)
user = self.register_user("user2", "pass")
@@ -144,7 +144,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
"worker_name": "sender1",
"federation_sender_instances": ["sender1", "sender2"],
},
- http_client=mock_client1,
+ federation_http_client=mock_client1,
)
mock_client2 = Mock(spec=["put_json"])
@@ -156,7 +156,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
"worker_name": "sender2",
"federation_sender_instances": ["sender1", "sender2"],
},
- http_client=mock_client2,
+ federation_http_client=mock_client2,
)
user = self.register_user("user3", "pass")
diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py
index 48b574ccbe..d1feca961f 100644
--- a/tests/replication/test_multi_media_repo.py
+++ b/tests/replication/test_multi_media_repo.py
@@ -48,7 +48,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
self.user_id = self.register_user("user", "pass")
self.access_token = self.login("user", "pass")
- self.reactor.lookups["example.com"] = "127.0.0.2"
+ self.reactor.lookups["example.com"] = "1.2.3.4"
def default_config(self):
conf = super().default_config()
@@ -68,7 +68,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
the media which the caller should respond to.
"""
resource = hs.get_media_repository_resource().children[b"download"]
- _, channel = make_request(
+ channel = make_request(
self.reactor,
FakeSite(resource),
"GET",
diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py
index 67c27a089f..800ad94a04 100644
--- a/tests/replication/test_pusher_shard.py
+++ b/tests/replication/test_pusher_shard.py
@@ -67,7 +67,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
device_display_name="pushy push",
pushkey="a@example.com",
lang=None,
- data={"url": "https://push.example.com/push"},
+ data={"url": "https://push.example.com/_matrix/push/v1/notify"},
)
)
@@ -98,7 +98,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
self.make_worker_hs(
"synapse.app.pusher",
{"start_pushers": True},
- proxied_http_client=http_client_mock,
+ proxied_blacklisted_http_client=http_client_mock,
)
event_id = self._create_pusher_and_send_msg("user")
@@ -109,7 +109,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
http_client_mock.post_json_get_json.assert_called_once()
self.assertEqual(
http_client_mock.post_json_get_json.call_args[0][0],
- "https://push.example.com/push",
+ "https://push.example.com/_matrix/push/v1/notify",
)
self.assertEqual(
event_id,
@@ -133,7 +133,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
"worker_name": "pusher1",
"pusher_instances": ["pusher1", "pusher2"],
},
- proxied_http_client=http_client_mock1,
+ proxied_blacklisted_http_client=http_client_mock1,
)
http_client_mock2 = Mock(spec_set=["post_json_get_json"])
@@ -148,7 +148,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
"worker_name": "pusher2",
"pusher_instances": ["pusher1", "pusher2"],
},
- proxied_http_client=http_client_mock2,
+ proxied_blacklisted_http_client=http_client_mock2,
)
# We choose a user name that we know should go to pusher1.
@@ -161,7 +161,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
http_client_mock2.post_json_get_json.assert_not_called()
self.assertEqual(
http_client_mock1.post_json_get_json.call_args[0][0],
- "https://push.example.com/push",
+ "https://push.example.com/_matrix/push/v1/notify",
)
self.assertEqual(
event_id,
@@ -183,7 +183,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
http_client_mock2.post_json_get_json.assert_called_once()
self.assertEqual(
http_client_mock2.post_json_get_json.call_args[0][0],
- "https://push.example.com/push",
+ "https://push.example.com/_matrix/push/v1/notify",
)
self.assertEqual(
event_id,
diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py
index 77fc3856d5..8d494ebc03 100644
--- a/tests/replication/test_sharded_event_persister.py
+++ b/tests/replication/test_sharded_event_persister.py
@@ -180,7 +180,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
)
# Do an initial sync so that we're up to date.
- request, channel = make_request(
+ channel = make_request(
self.reactor, sync_hs_site, "GET", "/sync", access_token=access_token
)
next_batch = channel.json_body["next_batch"]
@@ -206,7 +206,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
# Check that syncing still gets the new event, despite the gap in the
# stream IDs.
- request, channel = make_request(
+ channel = make_request(
self.reactor,
sync_hs_site,
"GET",
@@ -236,7 +236,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
response = self.helper.send(room_id2, body="Hi!", tok=self.other_access_token)
first_event_in_room2 = response["event_id"]
- request, channel = make_request(
+ channel = make_request(
self.reactor,
sync_hs_site,
"GET",
@@ -261,7 +261,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
self.helper.send(room_id1, body="Hi again!", tok=self.other_access_token)
self.helper.send(room_id2, body="Hi again!", tok=self.other_access_token)
- request, channel = make_request(
+ channel = make_request(
self.reactor,
sync_hs_site,
"GET",
@@ -279,7 +279,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
# Paginating back in the first room should not produce any results, as
# no events have happened in it. This tests that we are correctly
# filtering results based on the vector clock portion.
- request, channel = make_request(
+ channel = make_request(
self.reactor,
sync_hs_site,
"GET",
@@ -292,7 +292,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
# Paginating back on the second room should produce the first event
# again. This tests that pagination isn't completely broken.
- request, channel = make_request(
+ channel = make_request(
self.reactor,
sync_hs_site,
"GET",
@@ -307,7 +307,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
)
# Paginating forwards should give the same results
- request, channel = make_request(
+ channel = make_request(
self.reactor,
sync_hs_site,
"GET",
@@ -318,7 +318,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
)
self.assertListEqual([], channel.json_body["chunk"])
- request, channel = make_request(
+ channel = make_request(
self.reactor,
sync_hs_site,
"GET",
diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index 4f76f8f768..9d22c04073 100644
--- a/tests/rest/admin/test_admin.py
+++ b/tests/rest/admin/test_admin.py
@@ -42,7 +42,7 @@ class VersionTestCase(unittest.HomeserverTestCase):
return resource
def test_version_string(self):
- request, channel = self.make_request("GET", self.url, shorthand=False)
+ channel = self.make_request("GET", self.url, shorthand=False)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(
@@ -58,8 +58,6 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
]
def prepare(self, reactor, clock, hs):
- self.store = hs.get_datastore()
-
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
@@ -68,7 +66,7 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
def test_delete_group(self):
# Create a new group
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
"/create_group".encode("ascii"),
access_token=self.admin_user_tok,
@@ -84,13 +82,13 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
# Invite/join another user
url = "/groups/%s/admin/users/invite/%s" % (group_id, self.other_user)
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT", url.encode("ascii"), access_token=self.admin_user_tok, content={}
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
url = "/groups/%s/self/accept_invite" % (group_id,)
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT", url.encode("ascii"), access_token=self.other_user_token, content={}
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -101,7 +99,7 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
# Now delete the group
url = "/_synapse/admin/v1/delete_group/" + group_id
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
url.encode("ascii"),
access_token=self.admin_user_tok,
@@ -123,7 +121,7 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
"""
url = "/groups/%s/profile" % (group_id,)
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", url.encode("ascii"), access_token=self.admin_user_tok
)
@@ -134,7 +132,7 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
def _get_groups_user_is_in(self, access_token):
"""Returns the list of groups the user is in (given their access token)
"""
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/joined_groups".encode("ascii"), access_token=access_token
)
@@ -155,9 +153,6 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
]
def prepare(self, reactor, clock, hs):
- self.store = hs.get_datastore()
- self.hs = hs
-
# Allow for uploading and downloading to/from the media repo
self.media_repo = hs.get_media_repository_resource()
self.download_resource = self.media_repo.children[b"download"]
@@ -210,13 +205,13 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
}
config["media_storage_providers"] = [provider_config]
- hs = self.setup_test_homeserver(config=config, http_client=client)
+ hs = self.setup_test_homeserver(config=config, federation_http_client=client)
return hs
def _ensure_quarantined(self, admin_user_tok, server_and_media_id):
"""Ensure a piece of media is quarantined when trying to access it."""
- request, channel = make_request(
+ channel = make_request(
self.reactor,
FakeSite(self.download_resource),
"GET",
@@ -241,7 +236,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
# Attempt quarantine media APIs as non-admin
url = "/_synapse/admin/v1/media/quarantine/example.org/abcde12345"
- request, channel = self.make_request(
+ channel = self.make_request(
"POST", url.encode("ascii"), access_token=non_admin_user_tok,
)
@@ -254,7 +249,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
# And the roomID/userID endpoint
url = "/_synapse/admin/v1/room/!room%3Aexample.com/media/quarantine"
- request, channel = self.make_request(
+ channel = self.make_request(
"POST", url.encode("ascii"), access_token=non_admin_user_tok,
)
@@ -282,7 +277,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
server_name, media_id = server_name_and_media_id.split("/")
# Attempt to access the media
- request, channel = make_request(
+ channel = make_request(
self.reactor,
FakeSite(self.download_resource),
"GET",
@@ -299,7 +294,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
urllib.parse.quote(server_name),
urllib.parse.quote(media_id),
)
- request, channel = self.make_request("POST", url, access_token=admin_user_tok,)
+ channel = self.make_request("POST", url, access_token=admin_user_tok,)
self.pump(1.0)
self.assertEqual(200, int(channel.code), msg=channel.result["body"])
@@ -351,7 +346,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
url = "/_synapse/admin/v1/room/%s/media/quarantine" % urllib.parse.quote(
room_id
)
- request, channel = self.make_request("POST", url, access_token=admin_user_tok,)
+ channel = self.make_request("POST", url, access_token=admin_user_tok,)
self.pump(1.0)
self.assertEqual(200, int(channel.code), msg=channel.result["body"])
self.assertEqual(
@@ -395,7 +390,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
url = "/_synapse/admin/v1/user/%s/media/quarantine" % urllib.parse.quote(
non_admin_user
)
- request, channel = self.make_request(
+ channel = self.make_request(
"POST", url.encode("ascii"), access_token=admin_user_tok,
)
self.pump(1.0)
@@ -431,13 +426,17 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
# Mark the second item as safe from quarantine.
_, media_id_2 = server_and_media_id_2.split("/")
- self.get_success(self.store.mark_local_media_as_safe(media_id_2))
+ # Quarantine the media
+ url = "/_synapse/admin/v1/media/protect/%s" % (urllib.parse.quote(media_id_2),)
+ channel = self.make_request("POST", url, access_token=admin_user_tok)
+ self.pump(1.0)
+ self.assertEqual(200, int(channel.code), msg=channel.result["body"])
# Quarantine all media by this user
url = "/_synapse/admin/v1/user/%s/media/quarantine" % urllib.parse.quote(
non_admin_user
)
- request, channel = self.make_request(
+ channel = self.make_request(
"POST", url.encode("ascii"), access_token=admin_user_tok,
)
self.pump(1.0)
@@ -453,7 +452,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
self._ensure_quarantined(admin_user_tok, server_and_media_id_1)
# Attempt to access each piece of media
- request, channel = make_request(
+ channel = make_request(
self.reactor,
FakeSite(self.download_resource),
"GET",
diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py
index cf3a007598..248c4442c3 100644
--- a/tests/rest/admin/test_device.py
+++ b/tests/rest/admin/test_device.py
@@ -50,17 +50,17 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
"""
Try to get a device of an user without authentication.
"""
- request, channel = self.make_request("GET", self.url, b"{}")
+ channel = self.make_request("GET", self.url, b"{}")
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
- request, channel = self.make_request("PUT", self.url, b"{}")
+ channel = self.make_request("PUT", self.url, b"{}")
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
- request, channel = self.make_request("DELETE", self.url, b"{}")
+ channel = self.make_request("DELETE", self.url, b"{}")
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@@ -69,21 +69,21 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
"""
If the user is not a server admin, an error is returned.
"""
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", self.url, access_token=self.other_user_token,
)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT", self.url, access_token=self.other_user_token,
)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
- request, channel = self.make_request(
+ channel = self.make_request(
"DELETE", self.url, access_token=self.other_user_token,
)
@@ -99,23 +99,17 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
% self.other_user_device_id
)
- request, channel = self.make_request(
- "GET", url, access_token=self.admin_user_tok,
- )
+ channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
- request, channel = self.make_request(
- "PUT", url, access_token=self.admin_user_tok,
- )
+ channel = self.make_request("PUT", url, access_token=self.admin_user_tok,)
self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
- request, channel = self.make_request(
- "DELETE", url, access_token=self.admin_user_tok,
- )
+ channel = self.make_request("DELETE", url, access_token=self.admin_user_tok,)
self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
@@ -129,23 +123,17 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
% self.other_user_device_id
)
- request, channel = self.make_request(
- "GET", url, access_token=self.admin_user_tok,
- )
+ channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only lookup local users", channel.json_body["error"])
- request, channel = self.make_request(
- "PUT", url, access_token=self.admin_user_tok,
- )
+ channel = self.make_request("PUT", url, access_token=self.admin_user_tok,)
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only lookup local users", channel.json_body["error"])
- request, channel = self.make_request(
- "DELETE", url, access_token=self.admin_user_tok,
- )
+ channel = self.make_request("DELETE", url, access_token=self.admin_user_tok,)
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only lookup local users", channel.json_body["error"])
@@ -158,22 +146,16 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
self.other_user
)
- request, channel = self.make_request(
- "GET", url, access_token=self.admin_user_tok,
- )
+ channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
- request, channel = self.make_request(
- "PUT", url, access_token=self.admin_user_tok,
- )
+ channel = self.make_request("PUT", url, access_token=self.admin_user_tok,)
self.assertEqual(200, channel.code, msg=channel.json_body)
- request, channel = self.make_request(
- "DELETE", url, access_token=self.admin_user_tok,
- )
+ channel = self.make_request("DELETE", url, access_token=self.admin_user_tok,)
# Delete unknown device returns status 200
self.assertEqual(200, channel.code, msg=channel.json_body)
@@ -197,7 +179,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
}
body = json.dumps(update)
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
self.url,
access_token=self.admin_user_tok,
@@ -208,9 +190,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(Codes.TOO_LARGE, channel.json_body["errcode"])
# Ensure the display name was not updated.
- request, channel = self.make_request(
- "GET", self.url, access_token=self.admin_user_tok,
- )
+ channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("new display", channel.json_body["display_name"])
@@ -227,16 +207,12 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
)
)
- request, channel = self.make_request(
- "PUT", self.url, access_token=self.admin_user_tok,
- )
+ channel = self.make_request("PUT", self.url, access_token=self.admin_user_tok,)
self.assertEqual(200, channel.code, msg=channel.json_body)
# Ensure the display name was not updated.
- request, channel = self.make_request(
- "GET", self.url, access_token=self.admin_user_tok,
- )
+ channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("new display", channel.json_body["display_name"])
@@ -247,7 +223,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
"""
# Set new display_name
body = json.dumps({"display_name": "new displayname"})
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
self.url,
access_token=self.admin_user_tok,
@@ -257,9 +233,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, msg=channel.json_body)
# Check new display_name
- request, channel = self.make_request(
- "GET", self.url, access_token=self.admin_user_tok,
- )
+ channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("new displayname", channel.json_body["display_name"])
@@ -268,9 +242,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
"""
Tests that a normal lookup for a device is successfully
"""
- request, channel = self.make_request(
- "GET", self.url, access_token=self.admin_user_tok,
- )
+ channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(self.other_user, channel.json_body["user_id"])
@@ -291,7 +263,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(1, number_devices)
# Delete device
- request, channel = self.make_request(
+ channel = self.make_request(
"DELETE", self.url, access_token=self.admin_user_tok,
)
@@ -323,7 +295,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
"""
Try to list devices of an user without authentication.
"""
- request, channel = self.make_request("GET", self.url, b"{}")
+ channel = self.make_request("GET", self.url, b"{}")
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@@ -334,9 +306,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
"""
other_user_token = self.login("user", "pass")
- request, channel = self.make_request(
- "GET", self.url, access_token=other_user_token,
- )
+ channel = self.make_request("GET", self.url, access_token=other_user_token,)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@@ -346,9 +316,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
Tests that a lookup for a user that does not exist returns a 404
"""
url = "/_synapse/admin/v2/users/@unknown_person:test/devices"
- request, channel = self.make_request(
- "GET", url, access_token=self.admin_user_tok,
- )
+ channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
@@ -359,9 +327,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
"""
url = "/_synapse/admin/v2/users/@unknown_person:unknown_domain/devices"
- request, channel = self.make_request(
- "GET", url, access_token=self.admin_user_tok,
- )
+ channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only lookup local users", channel.json_body["error"])
@@ -373,9 +339,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
"""
# Get devices
- request, channel = self.make_request(
- "GET", self.url, access_token=self.admin_user_tok,
- )
+ channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"])
@@ -391,9 +355,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
self.login("user", "pass")
# Get devices
- request, channel = self.make_request(
- "GET", self.url, access_token=self.admin_user_tok,
- )
+ channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(number_devices, channel.json_body["total"])
@@ -431,7 +393,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
"""
Try to delete devices of an user without authentication.
"""
- request, channel = self.make_request("POST", self.url, b"{}")
+ channel = self.make_request("POST", self.url, b"{}")
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@@ -442,9 +404,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
"""
other_user_token = self.login("user", "pass")
- request, channel = self.make_request(
- "POST", self.url, access_token=other_user_token,
- )
+ channel = self.make_request("POST", self.url, access_token=other_user_token,)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@@ -454,9 +414,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
Tests that a lookup for a user that does not exist returns a 404
"""
url = "/_synapse/admin/v2/users/@unknown_person:test/delete_devices"
- request, channel = self.make_request(
- "POST", url, access_token=self.admin_user_tok,
- )
+ channel = self.make_request("POST", url, access_token=self.admin_user_tok,)
self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
@@ -467,9 +425,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
"""
url = "/_synapse/admin/v2/users/@unknown_person:unknown_domain/delete_devices"
- request, channel = self.make_request(
- "POST", url, access_token=self.admin_user_tok,
- )
+ channel = self.make_request("POST", url, access_token=self.admin_user_tok,)
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only lookup local users", channel.json_body["error"])
@@ -479,7 +435,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
Tests that a remove of a device that does not exist returns 200.
"""
body = json.dumps({"devices": ["unknown_device1", "unknown_device2"]})
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
self.url,
access_token=self.admin_user_tok,
@@ -510,7 +466,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
# Delete devices
body = json.dumps({"devices": device_ids})
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
self.url,
access_token=self.admin_user_tok,
diff --git a/tests/rest/admin/test_event_reports.py b/tests/rest/admin/test_event_reports.py
index 11b72c10f7..d0090faa4f 100644
--- a/tests/rest/admin/test_event_reports.py
+++ b/tests/rest/admin/test_event_reports.py
@@ -32,8 +32,6 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
]
def prepare(self, reactor, clock, hs):
- self.store = hs.get_datastore()
-
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
@@ -74,7 +72,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
"""
Try to get an event report without authentication.
"""
- request, channel = self.make_request("GET", self.url, b"{}")
+ channel = self.make_request("GET", self.url, b"{}")
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@@ -84,9 +82,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
If the user is not a server admin, an error 403 is returned.
"""
- request, channel = self.make_request(
- "GET", self.url, access_token=self.other_user_tok,
- )
+ channel = self.make_request("GET", self.url, access_token=self.other_user_tok,)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@@ -96,9 +92,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
Testing list of reported events
"""
- request, channel = self.make_request(
- "GET", self.url, access_token=self.admin_user_tok,
- )
+ channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], 20)
@@ -111,7 +105,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
Testing list of reported events with limit
"""
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", self.url + "?limit=5", access_token=self.admin_user_tok,
)
@@ -126,7 +120,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
Testing list of reported events with a defined starting point (from)
"""
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", self.url + "?from=5", access_token=self.admin_user_tok,
)
@@ -141,7 +135,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
Testing list of reported events with a defined starting point and limit
"""
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", self.url + "?from=5&limit=10", access_token=self.admin_user_tok,
)
@@ -156,7 +150,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
Testing list of reported events with a filter of room
"""
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
self.url + "?room_id=%s" % self.room_id1,
access_token=self.admin_user_tok,
@@ -176,7 +170,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
Testing list of reported events with a filter of user
"""
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
self.url + "?user_id=%s" % self.other_user,
access_token=self.admin_user_tok,
@@ -196,7 +190,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
Testing list of reported events with a filter of user and room
"""
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
self.url + "?user_id=%s&room_id=%s" % (self.other_user, self.room_id1),
access_token=self.admin_user_tok,
@@ -218,7 +212,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
"""
# fetch the most recent first, largest timestamp
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", self.url + "?dir=b", access_token=self.admin_user_tok,
)
@@ -234,7 +228,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
report += 1
# fetch the oldest first, smallest timestamp
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", self.url + "?dir=f", access_token=self.admin_user_tok,
)
@@ -254,7 +248,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
Testing that a invalid search order returns a 400
"""
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", self.url + "?dir=bar", access_token=self.admin_user_tok,
)
@@ -267,7 +261,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
Testing that a negative limit parameter returns a 400
"""
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", self.url + "?limit=-5", access_token=self.admin_user_tok,
)
@@ -279,7 +273,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
Testing that a negative from parameter returns a 400
"""
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", self.url + "?from=-5", access_token=self.admin_user_tok,
)
@@ -293,7 +287,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
# `next_token` does not appear
# Number of results is the number of entries
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", self.url + "?limit=20", access_token=self.admin_user_tok,
)
@@ -304,7 +298,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
# `next_token` does not appear
# Number of max results is larger than the number of entries
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", self.url + "?limit=21", access_token=self.admin_user_tok,
)
@@ -315,7 +309,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
# `next_token` does appear
# Number of max results is smaller than the number of entries
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", self.url + "?limit=19", access_token=self.admin_user_tok,
)
@@ -327,7 +321,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
# Check
# Set `from` to value of `next_token` for request remaining entries
# `next_token` does not appear
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", self.url + "?from=19", access_token=self.admin_user_tok,
)
@@ -342,7 +336,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
resp = self.helper.send(room_id, tok=user_tok)
event_id = resp["event_id"]
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
"rooms/%s/report/%s" % (room_id, event_id),
json.dumps({"score": -100, "reason": "this makes me sad"}),
@@ -375,8 +369,6 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
]
def prepare(self, reactor, clock, hs):
- self.store = hs.get_datastore()
-
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
@@ -399,7 +391,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
"""
Try to get event report without authentication.
"""
- request, channel = self.make_request("GET", self.url, b"{}")
+ channel = self.make_request("GET", self.url, b"{}")
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@@ -409,9 +401,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
If the user is not a server admin, an error 403 is returned.
"""
- request, channel = self.make_request(
- "GET", self.url, access_token=self.other_user_tok,
- )
+ channel = self.make_request("GET", self.url, access_token=self.other_user_tok,)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@@ -421,9 +411,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
Testing get a reported event
"""
- request, channel = self.make_request(
- "GET", self.url, access_token=self.admin_user_tok,
- )
+ channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self._check_fields(channel.json_body)
@@ -434,7 +422,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
"""
# `report_id` is negative
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/_synapse/admin/v1/event_reports/-123",
access_token=self.admin_user_tok,
@@ -448,7 +436,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
)
# `report_id` is a non-numerical string
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/_synapse/admin/v1/event_reports/abcdef",
access_token=self.admin_user_tok,
@@ -462,7 +450,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
)
# `report_id` is undefined
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/_synapse/admin/v1/event_reports/",
access_token=self.admin_user_tok,
@@ -480,7 +468,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
Testing that a not existing `report_id` returns a 404.
"""
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/_synapse/admin/v1/event_reports/123",
access_token=self.admin_user_tok,
@@ -496,7 +484,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
resp = self.helper.send(room_id, tok=user_tok)
event_id = resp["event_id"]
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
"rooms/%s/report/%s" % (room_id, event_id),
json.dumps({"score": -100, "reason": "this makes me sad"}),
diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py
index dadf9db660..51a7731693 100644
--- a/tests/rest/admin/test_media.py
+++ b/tests/rest/admin/test_media.py
@@ -35,7 +35,6 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
]
def prepare(self, reactor, clock, hs):
- self.handler = hs.get_device_handler()
self.media_repo = hs.get_media_repository_resource()
self.server_name = hs.hostname
@@ -50,7 +49,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
"""
url = "/_synapse/admin/v1/media/%s/%s" % (self.server_name, "12345")
- request, channel = self.make_request("DELETE", url, b"{}")
+ channel = self.make_request("DELETE", url, b"{}")
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@@ -64,9 +63,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
url = "/_synapse/admin/v1/media/%s/%s" % (self.server_name, "12345")
- request, channel = self.make_request(
- "DELETE", url, access_token=self.other_user_token,
- )
+ channel = self.make_request("DELETE", url, access_token=self.other_user_token,)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@@ -77,9 +74,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
"""
url = "/_synapse/admin/v1/media/%s/%s" % (self.server_name, "12345")
- request, channel = self.make_request(
- "DELETE", url, access_token=self.admin_user_tok,
- )
+ channel = self.make_request("DELETE", url, access_token=self.admin_user_tok,)
self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
@@ -90,9 +85,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
"""
url = "/_synapse/admin/v1/media/%s/%s" % ("unknown_domain", "12345")
- request, channel = self.make_request(
- "DELETE", url, access_token=self.admin_user_tok,
- )
+ channel = self.make_request("DELETE", url, access_token=self.admin_user_tok,)
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only delete local media", channel.json_body["error"])
@@ -121,7 +114,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
self.assertEqual(server_name, self.server_name)
# Attempt to access media
- request, channel = make_request(
+ channel = make_request(
self.reactor,
FakeSite(download_resource),
"GET",
@@ -146,9 +139,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
url = "/_synapse/admin/v1/media/%s/%s" % (self.server_name, media_id)
# Delete media
- request, channel = self.make_request(
- "DELETE", url, access_token=self.admin_user_tok,
- )
+ channel = self.make_request("DELETE", url, access_token=self.admin_user_tok,)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, channel.json_body["total"])
@@ -157,7 +148,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
)
# Attempt to access media
- request, channel = make_request(
+ channel = make_request(
self.reactor,
FakeSite(download_resource),
"GET",
@@ -189,7 +180,6 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
]
def prepare(self, reactor, clock, hs):
- self.handler = hs.get_device_handler()
self.media_repo = hs.get_media_repository_resource()
self.server_name = hs.hostname
@@ -204,7 +194,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
Try to delete media without authentication.
"""
- request, channel = self.make_request("POST", self.url, b"{}")
+ channel = self.make_request("POST", self.url, b"{}")
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@@ -216,7 +206,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.other_user = self.register_user("user", "pass")
self.other_user_token = self.login("user", "pass")
- request, channel = self.make_request(
+ channel = self.make_request(
"POST", self.url, access_token=self.other_user_token,
)
@@ -229,7 +219,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
"""
url = "/_synapse/admin/v1/media/%s/delete" % "unknown_domain"
- request, channel = self.make_request(
+ channel = self.make_request(
"POST", url + "?before_ts=1234", access_token=self.admin_user_tok,
)
@@ -240,9 +230,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
"""
If the parameter `before_ts` is missing, an error is returned.
"""
- request, channel = self.make_request(
- "POST", self.url, access_token=self.admin_user_tok,
- )
+ channel = self.make_request("POST", self.url, access_token=self.admin_user_tok,)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
@@ -254,7 +242,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
"""
If parameters are invalid, an error is returned.
"""
- request, channel = self.make_request(
+ channel = self.make_request(
"POST", self.url + "?before_ts=-1234", access_token=self.admin_user_tok,
)
@@ -265,7 +253,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
channel.json_body["error"],
)
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
self.url + "?before_ts=1234&size_gt=-1234",
access_token=self.admin_user_tok,
@@ -278,7 +266,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
channel.json_body["error"],
)
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
self.url + "?before_ts=1234&keep_profiles=not_bool",
access_token=self.admin_user_tok,
@@ -308,7 +296,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
# timestamp after upload/create
now_ms = self.clock.time_msec()
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
self.url + "?before_ts=" + str(now_ms),
access_token=self.admin_user_tok,
@@ -332,7 +320,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self._access_media(server_and_media_id)
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
self.url + "?before_ts=" + str(now_ms),
access_token=self.admin_user_tok,
@@ -344,7 +332,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
# timestamp after upload
now_ms = self.clock.time_msec()
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
self.url + "?before_ts=" + str(now_ms),
access_token=self.admin_user_tok,
@@ -367,7 +355,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self._access_media(server_and_media_id)
now_ms = self.clock.time_msec()
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
self.url + "?before_ts=" + str(now_ms) + "&size_gt=67",
access_token=self.admin_user_tok,
@@ -378,7 +366,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self._access_media(server_and_media_id)
now_ms = self.clock.time_msec()
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
self.url + "?before_ts=" + str(now_ms) + "&size_gt=66",
access_token=self.admin_user_tok,
@@ -401,7 +389,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self._access_media(server_and_media_id)
# set media as avatar
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/profile/%s/avatar_url" % (self.admin_user,),
content=json.dumps({"avatar_url": "mxc://%s" % (server_and_media_id,)}),
@@ -410,7 +398,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, msg=channel.json_body)
now_ms = self.clock.time_msec()
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=true",
access_token=self.admin_user_tok,
@@ -421,7 +409,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self._access_media(server_and_media_id)
now_ms = self.clock.time_msec()
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=false",
access_token=self.admin_user_tok,
@@ -445,7 +433,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
# set media as room avatar
room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/rooms/%s/state/m.room.avatar" % (room_id,),
content=json.dumps({"url": "mxc://%s" % (server_and_media_id,)}),
@@ -454,7 +442,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, msg=channel.json_body)
now_ms = self.clock.time_msec()
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=true",
access_token=self.admin_user_tok,
@@ -465,7 +453,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self._access_media(server_and_media_id)
now_ms = self.clock.time_msec()
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=false",
access_token=self.admin_user_tok,
@@ -512,7 +500,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
media_id = server_and_media_id.split("/")[1]
local_path = self.filepaths.local_media_filepath(media_id)
- request, channel = make_request(
+ channel = make_request(
self.reactor,
FakeSite(download_resource),
"GET",
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index 46933a0493..7c47aa7e0a 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -20,6 +20,7 @@ from typing import List, Optional
from mock import Mock
import synapse.rest.admin
+from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import Codes
from synapse.rest.client.v1 import directory, events, login, room
@@ -79,7 +80,7 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase):
# Test that the admin can still send shutdown
url = "/_synapse/admin/v1/shutdown_room/" + room_id
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
url.encode("ascii"),
json.dumps({"new_room_user_id": self.admin_user}),
@@ -103,7 +104,7 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase):
# Enable world readable
url = "rooms/%s/state/m.room.history_visibility" % (room_id,)
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
url.encode("ascii"),
json.dumps({"history_visibility": "world_readable"}),
@@ -113,7 +114,7 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase):
# Test that the admin can still send shutdown
url = "/_synapse/admin/v1/shutdown_room/" + room_id
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
url.encode("ascii"),
json.dumps({"new_room_user_id": self.admin_user}),
@@ -130,7 +131,7 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase):
"""
url = "rooms/%s/initialSync" % (room_id,)
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", url.encode("ascii"), access_token=self.admin_user_tok
)
self.assertEqual(
@@ -138,7 +139,7 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase):
)
url = "events?timeout=0&room_id=" + room_id
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", url.encode("ascii"), access_token=self.admin_user_tok
)
self.assertEqual(
@@ -184,7 +185,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
If the user is not a server admin, an error 403 is returned.
"""
- request, channel = self.make_request(
+ channel = self.make_request(
"POST", self.url, json.dumps({}), access_token=self.other_user_tok,
)
@@ -197,7 +198,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
"""
url = "/_synapse/admin/v1/rooms/!unknown:test/delete"
- request, channel = self.make_request(
+ channel = self.make_request(
"POST", url, json.dumps({}), access_token=self.admin_user_tok,
)
@@ -210,7 +211,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
"""
url = "/_synapse/admin/v1/rooms/invalidroom/delete"
- request, channel = self.make_request(
+ channel = self.make_request(
"POST", url, json.dumps({}), access_token=self.admin_user_tok,
)
@@ -225,7 +226,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
"""
body = json.dumps({"new_room_user_id": "@unknown:test"})
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
self.url,
content=body.encode(encoding="utf_8"),
@@ -244,7 +245,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
"""
body = json.dumps({"new_room_user_id": "@not:exist.bla"})
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
self.url,
content=body.encode(encoding="utf_8"),
@@ -262,7 +263,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
"""
body = json.dumps({"block": "NotBool"})
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
self.url,
content=body.encode(encoding="utf_8"),
@@ -278,7 +279,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
"""
body = json.dumps({"purge": "NotBool"})
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
self.url,
content=body.encode(encoding="utf_8"),
@@ -304,7 +305,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
body = json.dumps({"block": True, "purge": True})
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
self.url.encode("ascii"),
content=body.encode(encoding="utf_8"),
@@ -337,7 +338,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
body = json.dumps({"block": False, "purge": True})
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
self.url.encode("ascii"),
content=body.encode(encoding="utf_8"),
@@ -371,7 +372,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
body = json.dumps({"block": False, "purge": False})
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
self.url.encode("ascii"),
content=body.encode(encoding="utf_8"),
@@ -418,7 +419,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
# Test that the admin can still send shutdown
url = "/_synapse/admin/v1/rooms/%s/delete" % self.room_id
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
url.encode("ascii"),
json.dumps({"new_room_user_id": self.admin_user}),
@@ -448,7 +449,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
# Enable world readable
url = "rooms/%s/state/m.room.history_visibility" % (self.room_id,)
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
url.encode("ascii"),
json.dumps({"history_visibility": "world_readable"}),
@@ -465,7 +466,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
# Test that the admin can still send shutdown
url = "/_synapse/admin/v1/rooms/%s/delete" % self.room_id
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
url.encode("ascii"),
json.dumps({"new_room_user_id": self.admin_user}),
@@ -530,7 +531,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
"""
url = "rooms/%s/initialSync" % (room_id,)
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", url.encode("ascii"), access_token=self.admin_user_tok
)
self.assertEqual(
@@ -538,7 +539,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
)
url = "events?timeout=0&room_id=" + room_id
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", url.encode("ascii"), access_token=self.admin_user_tok
)
self.assertEqual(
@@ -569,7 +570,7 @@ class PurgeRoomTestCase(unittest.HomeserverTestCase):
self.helper.leave(room_id, user=self.admin_user, tok=self.admin_user_tok)
url = "/_synapse/admin/v1/purge_room"
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
url.encode("ascii"),
{"room_id": room_id},
@@ -604,8 +605,6 @@ class RoomTestCase(unittest.HomeserverTestCase):
]
def prepare(self, reactor, clock, hs):
- self.store = hs.get_datastore()
-
# Create user
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
@@ -623,7 +622,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
# Request the list of rooms
url = "/_synapse/admin/v1/rooms"
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", url.encode("ascii"), access_token=self.admin_user_tok,
)
@@ -704,7 +703,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
limit,
"name",
)
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", url.encode("ascii"), access_token=self.admin_user_tok,
)
self.assertEqual(
@@ -744,7 +743,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
self.assertEqual(room_ids, returned_room_ids)
url = "/_synapse/admin/v1/rooms?from=%d&limit=%d" % (start, limit)
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", url.encode("ascii"), access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -764,7 +763,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
# Create a new alias to this room
url = "/_matrix/client/r0/directory/room/%s" % (urllib.parse.quote(test_alias),)
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
url.encode("ascii"),
{"room_id": room_id},
@@ -794,7 +793,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
# Request the list of rooms
url = "/_synapse/admin/v1/rooms"
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", url.encode("ascii"), access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -835,7 +834,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
url = "/_matrix/client/r0/directory/room/%s" % (
urllib.parse.quote(test_alias),
)
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
url.encode("ascii"),
{"room_id": room_id},
@@ -875,7 +874,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
url = "/_synapse/admin/v1/rooms?order_by=%s" % (order_type,)
if reverse:
url += "&dir=b"
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", url.encode("ascii"), access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
@@ -1011,7 +1010,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
expected_http_code: The expected http code for the request
"""
url = "/_synapse/admin/v1/rooms?search_term=%s" % (search_term,)
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", url.encode("ascii"), access_token=self.admin_user_tok,
)
self.assertEqual(expected_http_code, channel.code, msg=channel.json_body)
@@ -1050,6 +1049,13 @@ class RoomTestCase(unittest.HomeserverTestCase):
_search_test(room_id_2, "else")
_search_test(room_id_2, "se")
+ # Test case insensitive
+ _search_test(room_id_1, "SOMETHING")
+ _search_test(room_id_1, "THING")
+
+ _search_test(room_id_2, "ELSE")
+ _search_test(room_id_2, "SE")
+
_search_test(None, "foo")
_search_test(None, "bar")
_search_test(None, "", expected_http_code=400)
@@ -1072,7 +1078,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
)
url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,)
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", url.encode("ascii"), access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
@@ -1084,6 +1090,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
self.assertIn("canonical_alias", channel.json_body)
self.assertIn("joined_members", channel.json_body)
self.assertIn("joined_local_members", channel.json_body)
+ self.assertIn("joined_local_devices", channel.json_body)
self.assertIn("version", channel.json_body)
self.assertIn("creator", channel.json_body)
self.assertIn("encryption", channel.json_body)
@@ -1096,6 +1103,39 @@ class RoomTestCase(unittest.HomeserverTestCase):
self.assertEqual(room_id_1, channel.json_body["room_id"])
+ def test_single_room_devices(self):
+ """Test that `joined_local_devices` can be requested correctly"""
+ room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+ url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,)
+ channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(1, channel.json_body["joined_local_devices"])
+
+ # Have another user join the room
+ user_1 = self.register_user("foo", "pass")
+ user_tok_1 = self.login("foo", "pass")
+ self.helper.join(room_id_1, user_1, tok=user_tok_1)
+
+ url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,)
+ channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(2, channel.json_body["joined_local_devices"])
+
+ # leave room
+ self.helper.leave(room_id_1, self.admin_user, tok=self.admin_user_tok)
+ self.helper.leave(room_id_1, user_1, tok=user_tok_1)
+ url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,)
+ channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(0, channel.json_body["joined_local_devices"])
+
def test_room_members(self):
"""Test that room members can be requested correctly"""
# Create two test rooms
@@ -1119,7 +1159,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
self.helper.join(room_id_2, user_3, tok=user_tok_3)
url = "/_synapse/admin/v1/rooms/%s/members" % (room_id_1,)
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", url.encode("ascii"), access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
@@ -1130,7 +1170,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.json_body["total"], 3)
url = "/_synapse/admin/v1/rooms/%s/members" % (room_id_2,)
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", url.encode("ascii"), access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
@@ -1140,6 +1180,21 @@ class RoomTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(channel.json_body["total"], 3)
+ def test_room_state(self):
+ """Test that room state can be requested correctly"""
+ # Create two test rooms
+ room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+ url = "/_synapse/admin/v1/rooms/%s/state" % (room_id,)
+ channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertIn("state", channel.json_body)
+ # testing that the state events match is painful and not done here. We assume that
+ # the create_room already does the right thing, so no need to verify that we got
+ # the state events it created.
+
class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
@@ -1170,7 +1225,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
"""
body = json.dumps({"user_id": self.second_user_id})
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
self.url,
content=body.encode(encoding="utf_8"),
@@ -1186,7 +1241,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
"""
body = json.dumps({"unknown_parameter": "@unknown:test"})
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
self.url,
content=body.encode(encoding="utf_8"),
@@ -1202,7 +1257,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
"""
body = json.dumps({"user_id": "@unknown:test"})
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
self.url,
content=body.encode(encoding="utf_8"),
@@ -1218,7 +1273,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
"""
body = json.dumps({"user_id": "@not:exist.bla"})
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
self.url,
content=body.encode(encoding="utf_8"),
@@ -1238,7 +1293,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
body = json.dumps({"user_id": self.second_user_id})
url = "/_synapse/admin/v1/join/!unknown:test"
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
url,
content=body.encode(encoding="utf_8"),
@@ -1255,7 +1310,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
body = json.dumps({"user_id": self.second_user_id})
url = "/_synapse/admin/v1/join/invalidroom"
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
url,
content=body.encode(encoding="utf_8"),
@@ -1274,7 +1329,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
"""
body = json.dumps({"user_id": self.second_user_id})
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
self.url,
content=body.encode(encoding="utf_8"),
@@ -1286,7 +1341,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
# Validate if user is a member of the room
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok,
)
self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -1303,7 +1358,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
url = "/_synapse/admin/v1/join/{}".format(private_room_id)
body = json.dumps({"user_id": self.second_user_id})
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
url,
content=body.encode(encoding="utf_8"),
@@ -1333,7 +1388,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
# Validate if server admin is a member of the room
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/_matrix/client/r0/joined_rooms", access_token=self.admin_user_tok,
)
self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -1344,7 +1399,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
url = "/_synapse/admin/v1/join/{}".format(private_room_id)
body = json.dumps({"user_id": self.second_user_id})
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
url,
content=body.encode(encoding="utf_8"),
@@ -1355,7 +1410,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
# Validate if user is a member of the room
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok,
)
self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -1372,7 +1427,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
url = "/_synapse/admin/v1/join/{}".format(private_room_id)
body = json.dumps({"user_id": self.second_user_id})
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
url,
content=body.encode(encoding="utf_8"),
@@ -1384,13 +1439,150 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
# Validate if user is a member of the room
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok,
)
self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
+class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, homeserver):
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.creator = self.register_user("creator", "test")
+ self.creator_tok = self.login("creator", "test")
+
+ self.second_user_id = self.register_user("second", "test")
+ self.second_tok = self.login("second", "test")
+
+ self.public_room_id = self.helper.create_room_as(
+ self.creator, tok=self.creator_tok, is_public=True
+ )
+ self.url = "/_synapse/admin/v1/rooms/{}/make_room_admin".format(
+ self.public_room_id
+ )
+
+ def test_public_room(self):
+ """Test that getting admin in a public room works.
+ """
+ room_id = self.helper.create_room_as(
+ self.creator, tok=self.creator_tok, is_public=True
+ )
+
+ channel = self.make_request(
+ "POST",
+ "/_synapse/admin/v1/rooms/{}/make_room_admin".format(room_id),
+ content={},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Now we test that we can join the room and ban a user.
+ self.helper.join(room_id, self.admin_user, tok=self.admin_user_tok)
+ self.helper.change_membership(
+ room_id,
+ self.admin_user,
+ "@test:test",
+ Membership.BAN,
+ tok=self.admin_user_tok,
+ )
+
+ def test_private_room(self):
+ """Test that getting admin in a private room works and we get invited.
+ """
+ room_id = self.helper.create_room_as(
+ self.creator, tok=self.creator_tok, is_public=False,
+ )
+
+ channel = self.make_request(
+ "POST",
+ "/_synapse/admin/v1/rooms/{}/make_room_admin".format(room_id),
+ content={},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Now we test that we can join the room (we should have received an
+ # invite) and can ban a user.
+ self.helper.join(room_id, self.admin_user, tok=self.admin_user_tok)
+ self.helper.change_membership(
+ room_id,
+ self.admin_user,
+ "@test:test",
+ Membership.BAN,
+ tok=self.admin_user_tok,
+ )
+
+ def test_other_user(self):
+ """Test that giving admin in a public room works to a non-admin user works.
+ """
+ room_id = self.helper.create_room_as(
+ self.creator, tok=self.creator_tok, is_public=True
+ )
+
+ channel = self.make_request(
+ "POST",
+ "/_synapse/admin/v1/rooms/{}/make_room_admin".format(room_id),
+ content={"user_id": self.second_user_id},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Now we test that we can join the room and ban a user.
+ self.helper.join(room_id, self.second_user_id, tok=self.second_tok)
+ self.helper.change_membership(
+ room_id,
+ self.second_user_id,
+ "@test:test",
+ Membership.BAN,
+ tok=self.second_tok,
+ )
+
+ def test_not_enough_power(self):
+ """Test that we get a sensible error if there are no local room admins.
+ """
+ room_id = self.helper.create_room_as(
+ self.creator, tok=self.creator_tok, is_public=True
+ )
+
+ # The creator drops admin rights in the room.
+ pl = self.helper.get_state(
+ room_id, EventTypes.PowerLevels, tok=self.creator_tok
+ )
+ pl["users"][self.creator] = 0
+ self.helper.send_state(
+ room_id, EventTypes.PowerLevels, body=pl, tok=self.creator_tok
+ )
+
+ channel = self.make_request(
+ "POST",
+ "/_synapse/admin/v1/rooms/{}/make_room_admin".format(room_id),
+ content={},
+ access_token=self.admin_user_tok,
+ )
+
+ # We expect this to fail with a 400 as there are no room admins.
+ #
+ # (Note we assert the error message to ensure that it's not denied for
+ # some other reason)
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ channel.json_body["error"],
+ "No local admin user in room with power to update power levels.",
+ )
+
+
PURGE_TABLES = [
"current_state_events",
"event_backward_extremities",
@@ -1419,7 +1611,6 @@ PURGE_TABLES = [
"event_push_summary",
"pusher_throttle",
"group_summary_rooms",
- "local_invites",
"room_account_data",
"room_tags",
# "state_groups", # Current impl leaves orphaned state groups around.
diff --git a/tests/rest/admin/test_statistics.py b/tests/rest/admin/test_statistics.py
index 907b49f889..f48be3d65a 100644
--- a/tests/rest/admin/test_statistics.py
+++ b/tests/rest/admin/test_statistics.py
@@ -31,7 +31,6 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
]
def prepare(self, reactor, clock, hs):
- self.store = hs.get_datastore()
self.media_repo = hs.get_media_repository_resource()
self.admin_user = self.register_user("admin", "pass", admin=True)
@@ -46,7 +45,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
"""
Try to list users without authentication.
"""
- request, channel = self.make_request("GET", self.url, b"{}")
+ channel = self.make_request("GET", self.url, b"{}")
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@@ -55,7 +54,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
"""
If the user is not a server admin, an error 403 is returned.
"""
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", self.url, json.dumps({}), access_token=self.other_user_tok,
)
@@ -67,7 +66,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
If parameters are invalid, an error is returned.
"""
# unkown order_by
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", self.url + "?order_by=bar", access_token=self.admin_user_tok,
)
@@ -75,7 +74,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# negative from
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", self.url + "?from=-5", access_token=self.admin_user_tok,
)
@@ -83,7 +82,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# negative limit
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", self.url + "?limit=-5", access_token=self.admin_user_tok,
)
@@ -91,7 +90,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# negative from_ts
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", self.url + "?from_ts=-1234", access_token=self.admin_user_tok,
)
@@ -99,7 +98,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# negative until_ts
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", self.url + "?until_ts=-1234", access_token=self.admin_user_tok,
)
@@ -107,7 +106,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# until_ts smaller from_ts
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
self.url + "?from_ts=10&until_ts=5",
access_token=self.admin_user_tok,
@@ -117,7 +116,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# empty search term
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", self.url + "?search_term=", access_token=self.admin_user_tok,
)
@@ -125,7 +124,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# invalid search order
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", self.url + "?dir=bar", access_token=self.admin_user_tok,
)
@@ -138,7 +137,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
"""
self._create_users_with_media(10, 2)
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", self.url + "?limit=5", access_token=self.admin_user_tok,
)
@@ -154,7 +153,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
"""
self._create_users_with_media(20, 2)
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", self.url + "?from=5", access_token=self.admin_user_tok,
)
@@ -170,7 +169,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
"""
self._create_users_with_media(20, 2)
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", self.url + "?from=5&limit=10", access_token=self.admin_user_tok,
)
@@ -190,7 +189,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
# `next_token` does not appear
# Number of results is the number of entries
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", self.url + "?limit=20", access_token=self.admin_user_tok,
)
@@ -201,7 +200,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
# `next_token` does not appear
# Number of max results is larger than the number of entries
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", self.url + "?limit=21", access_token=self.admin_user_tok,
)
@@ -212,7 +211,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
# `next_token` does appear
# Number of max results is smaller than the number of entries
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", self.url + "?limit=19", access_token=self.admin_user_tok,
)
@@ -223,7 +222,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
# Set `from` to value of `next_token` for request remaining entries
# Check `next_token` does not appear
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", self.url + "?from=19", access_token=self.admin_user_tok,
)
@@ -238,9 +237,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
if users have no media created
"""
- request, channel = self.make_request(
- "GET", self.url, access_token=self.admin_user_tok,
- )
+ channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"])
@@ -316,15 +313,13 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
ts1 = self.clock.time_msec()
# list all media when filter is not set
- request, channel = self.make_request(
- "GET", self.url, access_token=self.admin_user_tok,
- )
+ channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["users"][0]["media_count"], 3)
# filter media starting at `ts1` after creating first media
# result is 0
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", self.url + "?from_ts=%s" % (ts1,), access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -337,7 +332,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
self._create_media(self.other_user_tok, 3)
# filter media between `ts1` and `ts2`
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
self.url + "?from_ts=%s&until_ts=%s" % (ts1, ts2),
access_token=self.admin_user_tok,
@@ -346,7 +341,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.json_body["users"][0]["media_count"], 3)
# filter media until `ts2` and earlier
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", self.url + "?until_ts=%s" % (ts2,), access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -356,14 +351,12 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
self._create_users_with_media(20, 1)
# check without filter get all users
- request, channel = self.make_request(
- "GET", self.url, access_token=self.admin_user_tok,
- )
+ channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], 20)
# filter user 1 and 10-19 by `user_id`
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
self.url + "?search_term=foo_user_1",
access_token=self.admin_user_tok,
@@ -372,7 +365,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.json_body["total"], 11)
# filter on this user in `displayname`
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
self.url + "?search_term=bar_user_10",
access_token=self.admin_user_tok,
@@ -382,7 +375,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.json_body["total"], 1)
# filter and get empty result
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", self.url + "?search_term=foobar", access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -447,7 +440,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
url = self.url + "?order_by=%s" % (order_type,)
if dir is not None and dir in ("b", "f"):
url += "&dir=%s" % (dir,)
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", url.encode("ascii"), access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 54d46f4bd3..ee05ee60bc 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -18,14 +18,17 @@ import hmac
import json
import urllib.parse
from binascii import unhexlify
+from typing import Optional
from mock import Mock
import synapse.rest.admin
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError
+from synapse.api.room_versions import RoomVersions
from synapse.rest.client.v1 import login, logout, profile, room
from synapse.rest.client.v2_alpha import devices, sync
+from synapse.types import JsonDict
from tests import unittest
from tests.test_utils import make_awaitable
@@ -70,7 +73,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
"""
self.hs.config.registration_shared_secret = None
- request, channel = self.make_request("POST", self.url, b"{}")
+ channel = self.make_request("POST", self.url, b"{}")
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(
@@ -87,7 +90,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
self.hs.get_secrets = Mock(return_value=secrets)
- request, channel = self.make_request("GET", self.url)
+ channel = self.make_request("GET", self.url)
self.assertEqual(channel.json_body, {"nonce": "abcd"})
@@ -96,14 +99,14 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
Calling GET on the endpoint will return a randomised nonce, which will
only last for SALT_TIMEOUT (60s).
"""
- request, channel = self.make_request("GET", self.url)
+ channel = self.make_request("GET", self.url)
nonce = channel.json_body["nonce"]
# 59 seconds
self.reactor.advance(59)
body = json.dumps({"nonce": nonce})
- request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+ channel = self.make_request("POST", self.url, body.encode("utf8"))
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("username must be specified", channel.json_body["error"])
@@ -111,7 +114,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
# 61 seconds
self.reactor.advance(2)
- request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+ channel = self.make_request("POST", self.url, body.encode("utf8"))
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("unrecognised nonce", channel.json_body["error"])
@@ -120,7 +123,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
"""
Only the provided nonce can be used, as it's checked in the MAC.
"""
- request, channel = self.make_request("GET", self.url)
+ channel = self.make_request("GET", self.url)
nonce = channel.json_body["nonce"]
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
@@ -136,7 +139,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
"mac": want_mac,
}
)
- request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+ channel = self.make_request("POST", self.url, body.encode("utf8"))
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("HMAC incorrect", channel.json_body["error"])
@@ -146,7 +149,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
When the correct nonce is provided, and the right key is provided, the
user is registered.
"""
- request, channel = self.make_request("GET", self.url)
+ channel = self.make_request("GET", self.url)
nonce = channel.json_body["nonce"]
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
@@ -165,7 +168,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
"mac": want_mac,
}
)
- request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+ channel = self.make_request("POST", self.url, body.encode("utf8"))
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob:test", channel.json_body["user_id"])
@@ -174,7 +177,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
"""
A valid unrecognised nonce.
"""
- request, channel = self.make_request("GET", self.url)
+ channel = self.make_request("GET", self.url)
nonce = channel.json_body["nonce"]
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
@@ -190,13 +193,13 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
"mac": want_mac,
}
)
- request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+ channel = self.make_request("POST", self.url, body.encode("utf8"))
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob:test", channel.json_body["user_id"])
# Now, try and reuse it
- request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+ channel = self.make_request("POST", self.url, body.encode("utf8"))
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("unrecognised nonce", channel.json_body["error"])
@@ -209,7 +212,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
"""
def nonce():
- request, channel = self.make_request("GET", self.url)
+ channel = self.make_request("GET", self.url)
return channel.json_body["nonce"]
#
@@ -218,7 +221,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
# Must be present
body = json.dumps({})
- request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+ channel = self.make_request("POST", self.url, body.encode("utf8"))
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("nonce must be specified", channel.json_body["error"])
@@ -229,28 +232,28 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
# Must be present
body = json.dumps({"nonce": nonce()})
- request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+ channel = self.make_request("POST", self.url, body.encode("utf8"))
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("username must be specified", channel.json_body["error"])
# Must be a string
body = json.dumps({"nonce": nonce(), "username": 1234})
- request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+ channel = self.make_request("POST", self.url, body.encode("utf8"))
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("Invalid username", channel.json_body["error"])
# Must not have null bytes
body = json.dumps({"nonce": nonce(), "username": "abcd\u0000"})
- request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+ channel = self.make_request("POST", self.url, body.encode("utf8"))
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("Invalid username", channel.json_body["error"])
# Must not have null bytes
body = json.dumps({"nonce": nonce(), "username": "a" * 1000})
- request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+ channel = self.make_request("POST", self.url, body.encode("utf8"))
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("Invalid username", channel.json_body["error"])
@@ -261,28 +264,28 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
# Must be present
body = json.dumps({"nonce": nonce(), "username": "a"})
- request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+ channel = self.make_request("POST", self.url, body.encode("utf8"))
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("password must be specified", channel.json_body["error"])
# Must be a string
body = json.dumps({"nonce": nonce(), "username": "a", "password": 1234})
- request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+ channel = self.make_request("POST", self.url, body.encode("utf8"))
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("Invalid password", channel.json_body["error"])
# Must not have null bytes
body = json.dumps({"nonce": nonce(), "username": "a", "password": "abcd\u0000"})
- request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+ channel = self.make_request("POST", self.url, body.encode("utf8"))
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("Invalid password", channel.json_body["error"])
# Super long
body = json.dumps({"nonce": nonce(), "username": "a", "password": "A" * 1000})
- request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+ channel = self.make_request("POST", self.url, body.encode("utf8"))
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("Invalid password", channel.json_body["error"])
@@ -300,7 +303,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
"user_type": "invalid",
}
)
- request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+ channel = self.make_request("POST", self.url, body.encode("utf8"))
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("Invalid user type", channel.json_body["error"])
@@ -311,7 +314,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
"""
# set no displayname
- request, channel = self.make_request("GET", self.url)
+ channel = self.make_request("GET", self.url)
nonce = channel.json_body["nonce"]
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
@@ -321,17 +324,17 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
body = json.dumps(
{"nonce": nonce, "username": "bob1", "password": "abc123", "mac": want_mac}
)
- request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+ channel = self.make_request("POST", self.url, body.encode("utf8"))
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob1:test", channel.json_body["user_id"])
- request, channel = self.make_request("GET", "/profile/@bob1:test/displayname")
+ channel = self.make_request("GET", "/profile/@bob1:test/displayname")
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("bob1", channel.json_body["displayname"])
# displayname is None
- request, channel = self.make_request("GET", self.url)
+ channel = self.make_request("GET", self.url)
nonce = channel.json_body["nonce"]
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
@@ -347,17 +350,17 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
"mac": want_mac,
}
)
- request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+ channel = self.make_request("POST", self.url, body.encode("utf8"))
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob2:test", channel.json_body["user_id"])
- request, channel = self.make_request("GET", "/profile/@bob2:test/displayname")
+ channel = self.make_request("GET", "/profile/@bob2:test/displayname")
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("bob2", channel.json_body["displayname"])
# displayname is empty
- request, channel = self.make_request("GET", self.url)
+ channel = self.make_request("GET", self.url)
nonce = channel.json_body["nonce"]
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
@@ -373,16 +376,16 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
"mac": want_mac,
}
)
- request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+ channel = self.make_request("POST", self.url, body.encode("utf8"))
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob3:test", channel.json_body["user_id"])
- request, channel = self.make_request("GET", "/profile/@bob3:test/displayname")
+ channel = self.make_request("GET", "/profile/@bob3:test/displayname")
self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
# set displayname
- request, channel = self.make_request("GET", self.url)
+ channel = self.make_request("GET", self.url)
nonce = channel.json_body["nonce"]
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
@@ -398,12 +401,12 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
"mac": want_mac,
}
)
- request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+ channel = self.make_request("POST", self.url, body.encode("utf8"))
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob4:test", channel.json_body["user_id"])
- request, channel = self.make_request("GET", "/profile/@bob4:test/displayname")
+ channel = self.make_request("GET", "/profile/@bob4:test/displayname")
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("Bob's Name", channel.json_body["displayname"])
@@ -429,7 +432,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
)
# Register new user with admin API
- request, channel = self.make_request("GET", self.url)
+ channel = self.make_request("GET", self.url)
nonce = channel.json_body["nonce"]
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
@@ -448,7 +451,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
"mac": want_mac,
}
)
- request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+ channel = self.make_request("POST", self.url, body.encode("utf8"))
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob:test", channel.json_body["user_id"])
@@ -466,23 +469,34 @@ class UsersListTestCase(unittest.HomeserverTestCase):
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
- self.register_user("user1", "pass1", admin=False)
- self.register_user("user2", "pass2", admin=False)
-
def test_no_auth(self):
"""
Try to list users without authentication.
"""
- request, channel = self.make_request("GET", self.url, b"{}")
+ channel = self.make_request("GET", self.url, b"{}")
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual("M_MISSING_TOKEN", channel.json_body["errcode"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ def test_requester_is_no_admin(self):
+ """
+ If the user is not a server admin, an error is returned.
+ """
+ self._create_users(1)
+ other_user_token = self.login("user1", "pass1")
+
+ channel = self.make_request("GET", self.url, access_token=other_user_token)
+
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_all_users(self):
"""
List all users, including deactivated users.
"""
- request, channel = self.make_request(
+ self._create_users(2)
+
+ channel = self.make_request(
"GET",
self.url + "?deactivated=true",
b"{}",
@@ -493,6 +507,449 @@ class UsersListTestCase(unittest.HomeserverTestCase):
self.assertEqual(3, len(channel.json_body["users"]))
self.assertEqual(3, channel.json_body["total"])
+ # Check that all fields are available
+ self._check_fields(channel.json_body["users"])
+
+ def test_search_term(self):
+ """Test that searching for a users works correctly"""
+
+ def _search_test(
+ expected_user_id: Optional[str],
+ search_term: str,
+ search_field: Optional[str] = "name",
+ expected_http_code: Optional[int] = 200,
+ ):
+ """Search for a user and check that the returned user's id is a match
+
+ Args:
+ expected_user_id: The user_id expected to be returned by the API. Set
+ to None to expect zero results for the search
+ search_term: The term to search for user names with
+ search_field: Field which is to request: `name` or `user_id`
+ expected_http_code: The expected http code for the request
+ """
+ url = self.url + "?%s=%s" % (search_field, search_term,)
+ channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ )
+ self.assertEqual(expected_http_code, channel.code, msg=channel.json_body)
+
+ if expected_http_code != 200:
+ return
+
+ # Check that users were returned
+ self.assertTrue("users" in channel.json_body)
+ self._check_fields(channel.json_body["users"])
+ users = channel.json_body["users"]
+
+ # Check that the expected number of users were returned
+ expected_user_count = 1 if expected_user_id else 0
+ self.assertEqual(len(users), expected_user_count)
+ self.assertEqual(channel.json_body["total"], expected_user_count)
+
+ if expected_user_id:
+ # Check that the first returned user id is correct
+ u = users[0]
+ self.assertEqual(expected_user_id, u["name"])
+
+ self._create_users(2)
+
+ user1 = "@user1:test"
+ user2 = "@user2:test"
+
+ # Perform search tests
+ _search_test(user1, "er1")
+ _search_test(user1, "me 1")
+
+ _search_test(user2, "er2")
+ _search_test(user2, "me 2")
+
+ _search_test(user1, "er1", "user_id")
+ _search_test(user2, "er2", "user_id")
+
+ # Test case insensitive
+ _search_test(user1, "ER1")
+ _search_test(user1, "NAME 1")
+
+ _search_test(user2, "ER2")
+ _search_test(user2, "NAME 2")
+
+ _search_test(user1, "ER1", "user_id")
+ _search_test(user2, "ER2", "user_id")
+
+ _search_test(None, "foo")
+ _search_test(None, "bar")
+
+ _search_test(None, "foo", "user_id")
+ _search_test(None, "bar", "user_id")
+
+ def test_invalid_parameter(self):
+ """
+ If parameters are invalid, an error is returned.
+ """
+
+ # negative limit
+ channel = self.make_request(
+ "GET", self.url + "?limit=-5", access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+
+ # negative from
+ channel = self.make_request(
+ "GET", self.url + "?from=-5", access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+
+ # invalid guests
+ channel = self.make_request(
+ "GET", self.url + "?guests=not_bool", access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+
+ # invalid deactivated
+ channel = self.make_request(
+ "GET", self.url + "?deactivated=not_bool", access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+
+ def test_limit(self):
+ """
+ Testing list of users with limit
+ """
+
+ number_users = 20
+ # Create one less user (since there's already an admin user).
+ self._create_users(number_users - 1)
+
+ channel = self.make_request(
+ "GET", self.url + "?limit=5", access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], number_users)
+ self.assertEqual(len(channel.json_body["users"]), 5)
+ self.assertEqual(channel.json_body["next_token"], "5")
+ self._check_fields(channel.json_body["users"])
+
+ def test_from(self):
+ """
+ Testing list of users with a defined starting point (from)
+ """
+
+ number_users = 20
+ # Create one less user (since there's already an admin user).
+ self._create_users(number_users - 1)
+
+ channel = self.make_request(
+ "GET", self.url + "?from=5", access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], number_users)
+ self.assertEqual(len(channel.json_body["users"]), 15)
+ self.assertNotIn("next_token", channel.json_body)
+ self._check_fields(channel.json_body["users"])
+
+ def test_limit_and_from(self):
+ """
+ Testing list of users with a defined starting point and limit
+ """
+
+ number_users = 20
+ # Create one less user (since there's already an admin user).
+ self._create_users(number_users - 1)
+
+ channel = self.make_request(
+ "GET", self.url + "?from=5&limit=10", access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], number_users)
+ self.assertEqual(channel.json_body["next_token"], "15")
+ self.assertEqual(len(channel.json_body["users"]), 10)
+ self._check_fields(channel.json_body["users"])
+
+ def test_next_token(self):
+ """
+ Testing that `next_token` appears at the right place
+ """
+
+ number_users = 20
+ # Create one less user (since there's already an admin user).
+ self._create_users(number_users - 1)
+
+ # `next_token` does not appear
+ # Number of results is the number of entries
+ channel = self.make_request(
+ "GET", self.url + "?limit=20", access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], number_users)
+ self.assertEqual(len(channel.json_body["users"]), number_users)
+ self.assertNotIn("next_token", channel.json_body)
+
+ # `next_token` does not appear
+ # Number of max results is larger than the number of entries
+ channel = self.make_request(
+ "GET", self.url + "?limit=21", access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], number_users)
+ self.assertEqual(len(channel.json_body["users"]), number_users)
+ self.assertNotIn("next_token", channel.json_body)
+
+ # `next_token` does appear
+ # Number of max results is smaller than the number of entries
+ channel = self.make_request(
+ "GET", self.url + "?limit=19", access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], number_users)
+ self.assertEqual(len(channel.json_body["users"]), 19)
+ self.assertEqual(channel.json_body["next_token"], "19")
+
+ # Check
+ # Set `from` to value of `next_token` for request remaining entries
+ # `next_token` does not appear
+ channel = self.make_request(
+ "GET", self.url + "?from=19", access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], number_users)
+ self.assertEqual(len(channel.json_body["users"]), 1)
+ self.assertNotIn("next_token", channel.json_body)
+
+ def _check_fields(self, content: JsonDict):
+ """Checks that the expected user attributes are present in content
+ Args:
+ content: List that is checked for content
+ """
+ for u in content:
+ self.assertIn("name", u)
+ self.assertIn("is_guest", u)
+ self.assertIn("admin", u)
+ self.assertIn("user_type", u)
+ self.assertIn("deactivated", u)
+ self.assertIn("displayname", u)
+ self.assertIn("avatar_url", u)
+
+ def _create_users(self, number_users: int):
+ """
+ Create a number of users
+ Args:
+ number_users: Number of users to be created
+ """
+ for i in range(1, number_users + 1):
+ self.register_user(
+ "user%d" % i, "pass%d" % i, admin=False, displayname="Name %d" % i,
+ )
+
+
+class DeactivateAccountTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.other_user = self.register_user("user", "pass", displayname="User1")
+ self.other_user_token = self.login("user", "pass")
+ self.url_other_user = "/_synapse/admin/v2/users/%s" % urllib.parse.quote(
+ self.other_user
+ )
+ self.url = "/_synapse/admin/v1/deactivate/%s" % urllib.parse.quote(
+ self.other_user
+ )
+
+ # set attributes for user
+ self.get_success(
+ self.store.set_profile_avatar_url("user", "mxc://servername/mediaid")
+ )
+ self.get_success(
+ self.store.user_add_threepid("@user:test", "email", "foo@bar.com", 0, 0)
+ )
+
+ def test_no_auth(self):
+ """
+ Try to deactivate users without authentication.
+ """
+ channel = self.make_request("POST", self.url, b"{}")
+
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ def test_requester_is_not_admin(self):
+ """
+ If the user is not a server admin, an error is returned.
+ """
+ url = "/_synapse/admin/v1/deactivate/@bob:test"
+
+ channel = self.make_request("POST", url, access_token=self.other_user_token)
+
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("You are not a server admin", channel.json_body["error"])
+
+ channel = self.make_request(
+ "POST", url, access_token=self.other_user_token, content=b"{}",
+ )
+
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("You are not a server admin", channel.json_body["error"])
+
+ def test_user_does_not_exist(self):
+ """
+ Tests that deactivation for a user that does not exist returns a 404
+ """
+
+ channel = self.make_request(
+ "POST",
+ "/_synapse/admin/v1/deactivate/@unknown_person:test",
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(404, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+ def test_erase_is_not_bool(self):
+ """
+ If parameter `erase` is not boolean, return an error
+ """
+ body = json.dumps({"erase": "False"})
+
+ channel = self.make_request(
+ "POST",
+ self.url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
+
+ def test_user_is_not_local(self):
+ """
+ Tests that deactivation for a user that is not a local returns a 400
+ """
+ url = "/_synapse/admin/v1/deactivate/@unknown_person:unknown_domain"
+
+ channel = self.make_request("POST", url, access_token=self.admin_user_tok)
+
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual("Can only deactivate local users", channel.json_body["error"])
+
+ def test_deactivate_user_erase_true(self):
+ """
+ Test deactivating an user and set `erase` to `true`
+ """
+
+ # Get user
+ channel = self.make_request(
+ "GET", self.url_other_user, access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("@user:test", channel.json_body["name"])
+ self.assertEqual(False, channel.json_body["deactivated"])
+ self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"])
+ self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
+ self.assertEqual("User1", channel.json_body["displayname"])
+
+ # Deactivate user
+ body = json.dumps({"erase": True})
+
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ content=body.encode(encoding="utf_8"),
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Get user
+ channel = self.make_request(
+ "GET", self.url_other_user, access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("@user:test", channel.json_body["name"])
+ self.assertEqual(True, channel.json_body["deactivated"])
+ self.assertEqual(0, len(channel.json_body["threepids"]))
+ self.assertIsNone(channel.json_body["avatar_url"])
+ self.assertIsNone(channel.json_body["displayname"])
+
+ self._is_erased("@user:test", True)
+
+ def test_deactivate_user_erase_false(self):
+ """
+ Test deactivating an user and set `erase` to `false`
+ """
+
+ # Get user
+ channel = self.make_request(
+ "GET", self.url_other_user, access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("@user:test", channel.json_body["name"])
+ self.assertEqual(False, channel.json_body["deactivated"])
+ self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"])
+ self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
+ self.assertEqual("User1", channel.json_body["displayname"])
+
+ # Deactivate user
+ body = json.dumps({"erase": False})
+
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ content=body.encode(encoding="utf_8"),
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Get user
+ channel = self.make_request(
+ "GET", self.url_other_user, access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("@user:test", channel.json_body["name"])
+ self.assertEqual(True, channel.json_body["deactivated"])
+ self.assertEqual(0, len(channel.json_body["threepids"]))
+ self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
+ self.assertEqual("User1", channel.json_body["displayname"])
+
+ self._is_erased("@user:test", False)
+
+ def _is_erased(self, user_id: str, expect: bool) -> None:
+ """Assert that the user is erased or not
+ """
+ d = self.store.is_user_erased(user_id)
+ if expect:
+ self.assertTrue(self.get_success(d))
+ else:
+ self.assertFalse(self.get_success(d))
+
class UserRestTestCase(unittest.HomeserverTestCase):
@@ -508,7 +965,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
- self.other_user = self.register_user("user", "pass")
+ self.other_user = self.register_user("user", "pass", displayname="User")
self.other_user_token = self.login("user", "pass")
self.url_other_user = "/_synapse/admin/v2/users/%s" % urllib.parse.quote(
self.other_user
@@ -520,14 +977,12 @@ class UserRestTestCase(unittest.HomeserverTestCase):
"""
url = "/_synapse/admin/v2/users/@bob:test"
- request, channel = self.make_request(
- "GET", url, access_token=self.other_user_token,
- )
+ channel = self.make_request("GET", url, access_token=self.other_user_token,)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("You are not a server admin", channel.json_body["error"])
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT", url, access_token=self.other_user_token, content=b"{}",
)
@@ -539,7 +994,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
Tests that a lookup for a user that does not exist returns a 404
"""
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/_synapse/admin/v2/users/@unknown_person:test",
access_token=self.admin_user_tok,
@@ -561,11 +1016,11 @@ class UserRestTestCase(unittest.HomeserverTestCase):
"admin": True,
"displayname": "Bob's name",
"threepids": [{"medium": "email", "address": "bob@bob.bob"}],
- "avatar_url": None,
+ "avatar_url": "mxc://fibble/wibble",
}
)
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
url,
access_token=self.admin_user_tok,
@@ -578,11 +1033,10 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
self.assertEqual(True, channel.json_body["admin"])
+ self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
# Get user
- request, channel = self.make_request(
- "GET", url, access_token=self.admin_user_tok,
- )
+ channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob:test", channel.json_body["name"])
@@ -592,6 +1046,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(True, channel.json_body["admin"])
self.assertEqual(False, channel.json_body["is_guest"])
self.assertEqual(False, channel.json_body["deactivated"])
+ self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
def test_create_user(self):
"""
@@ -606,10 +1061,11 @@ class UserRestTestCase(unittest.HomeserverTestCase):
"admin": False,
"displayname": "Bob's name",
"threepids": [{"medium": "email", "address": "bob@bob.bob"}],
+ "avatar_url": "mxc://fibble/wibble",
}
)
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
url,
access_token=self.admin_user_tok,
@@ -622,11 +1078,10 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
self.assertEqual(False, channel.json_body["admin"])
+ self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
# Get user
- request, channel = self.make_request(
- "GET", url, access_token=self.admin_user_tok,
- )
+ channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob:test", channel.json_body["name"])
@@ -636,6 +1091,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(False, channel.json_body["admin"])
self.assertEqual(False, channel.json_body["is_guest"])
self.assertEqual(False, channel.json_body["deactivated"])
+ self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
@override_config(
{"limit_usage_by_mau": True, "max_mau_value": 2, "mau_trial_days": 0}
@@ -651,9 +1107,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# Sync to set admin user to active
# before limit of monthly active users is reached
- request, channel = self.make_request(
- "GET", "/sync", access_token=self.admin_user_tok
- )
+ channel = self.make_request("GET", "/sync", access_token=self.admin_user_tok)
if channel.code != 200:
raise HttpResponseException(
@@ -676,7 +1130,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# Create user
body = json.dumps({"password": "abc123", "admin": False})
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
url,
access_token=self.admin_user_tok,
@@ -715,7 +1169,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# Create user
body = json.dumps({"password": "abc123", "admin": False})
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
url,
access_token=self.admin_user_tok,
@@ -752,7 +1206,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
}
)
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
url,
access_token=self.admin_user_tok,
@@ -769,7 +1223,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
)
pushers = list(pushers)
self.assertEqual(len(pushers), 1)
- self.assertEqual("@bob:test", pushers[0]["user_name"])
+ self.assertEqual("@bob:test", pushers[0].user_name)
@override_config(
{
@@ -796,7 +1250,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
}
)
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
url,
access_token=self.admin_user_tok,
@@ -822,7 +1276,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# Change password
body = json.dumps({"password": "hahaha"})
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
@@ -839,7 +1293,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# Modify user
body = json.dumps({"displayname": "foobar"})
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
@@ -851,7 +1305,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual("foobar", channel.json_body["displayname"])
# Get user
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", self.url_other_user, access_token=self.admin_user_tok,
)
@@ -869,7 +1323,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
{"threepids": [{"medium": "email", "address": "bob3@bob.bob"}]}
)
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
@@ -882,7 +1336,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][0]["address"])
# Get user
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", self.url_other_user, access_token=self.admin_user_tok,
)
@@ -896,10 +1350,30 @@ class UserRestTestCase(unittest.HomeserverTestCase):
Test deactivating another user.
"""
+ # set attributes for user
+ self.get_success(
+ self.store.set_profile_avatar_url("user", "mxc://servername/mediaid")
+ )
+ self.get_success(
+ self.store.user_add_threepid("@user:test", "email", "foo@bar.com", 0, 0)
+ )
+
+ # Get user
+ channel = self.make_request(
+ "GET", self.url_other_user, access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("@user:test", channel.json_body["name"])
+ self.assertEqual(False, channel.json_body["deactivated"])
+ self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"])
+ self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
+ self.assertEqual("User", channel.json_body["displayname"])
+
# Deactivate user
body = json.dumps({"deactivated": True})
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
@@ -909,16 +1383,70 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(True, channel.json_body["deactivated"])
+ self.assertEqual(0, len(channel.json_body["threepids"]))
+ self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
+ self.assertEqual("User", channel.json_body["displayname"])
# the user is deactivated, the threepid will be deleted
# Get user
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", self.url_other_user, access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(True, channel.json_body["deactivated"])
+ self.assertEqual(0, len(channel.json_body["threepids"]))
+ self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
+ self.assertEqual("User", channel.json_body["displayname"])
+
+ @override_config({"user_directory": {"enabled": True, "search_all_users": True}})
+ def test_change_name_deactivate_user_user_directory(self):
+ """
+ Test change profile information of a deactivated user and
+ check that it does not appear in user directory
+ """
+
+ # is in user directory
+ profile = self.get_success(self.store.get_user_in_directory(self.other_user))
+ self.assertTrue(profile["display_name"] == "User")
+
+ # Deactivate user
+ body = json.dumps({"deactivated": True})
+
+ channel = self.make_request(
+ "PUT",
+ self.url_other_user,
+ access_token=self.admin_user_tok,
+ content=body.encode(encoding="utf_8"),
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("@user:test", channel.json_body["name"])
+ self.assertEqual(True, channel.json_body["deactivated"])
+
+ # is not in user directory
+ profile = self.get_success(self.store.get_user_in_directory(self.other_user))
+ self.assertTrue(profile is None)
+
+ # Set new displayname user
+ body = json.dumps({"displayname": "Foobar"})
+
+ channel = self.make_request(
+ "PUT",
+ self.url_other_user,
+ access_token=self.admin_user_tok,
+ content=body.encode(encoding="utf_8"),
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("@user:test", channel.json_body["name"])
+ self.assertEqual(True, channel.json_body["deactivated"])
+ self.assertEqual("Foobar", channel.json_body["displayname"])
+
+ # is not in user directory
+ profile = self.get_success(self.store.get_user_in_directory(self.other_user))
+ self.assertTrue(profile is None)
def test_reactivate_user(self):
"""
@@ -926,7 +1454,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
"""
# Deactivate the user.
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
@@ -939,7 +1467,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self._is_erased("@user:test", True)
# Attempt to reactivate the user (without a password).
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
@@ -948,7 +1476,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
# Reactivate the user.
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
@@ -959,7 +1487,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
# Get user
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", self.url_other_user, access_token=self.admin_user_tok,
)
@@ -976,7 +1504,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# Set a user as an admin
body = json.dumps({"admin": True})
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
@@ -988,7 +1516,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(True, channel.json_body["admin"])
# Get user
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", self.url_other_user, access_token=self.admin_user_tok,
)
@@ -1006,7 +1534,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# Create user
body = json.dumps({"password": "abc123"})
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
url,
access_token=self.admin_user_tok,
@@ -1018,9 +1546,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual("bob", channel.json_body["displayname"])
# Get user
- request, channel = self.make_request(
- "GET", url, access_token=self.admin_user_tok,
- )
+ channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob:test", channel.json_body["name"])
@@ -1030,7 +1556,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# Change password (and use a str for deactivate instead of a bool)
body = json.dumps({"password": "abc123", "deactivated": "false"}) # oops!
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
url,
access_token=self.admin_user_tok,
@@ -1040,9 +1566,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
# Check user is not deactivated
- request, channel = self.make_request(
- "GET", url, access_token=self.admin_user_tok,
- )
+ channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob:test", channel.json_body["name"])
@@ -1070,8 +1594,6 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
]
def prepare(self, reactor, clock, hs):
- self.store = hs.get_datastore()
-
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
@@ -1084,7 +1606,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
"""
Try to list rooms of an user without authentication.
"""
- request, channel = self.make_request("GET", self.url, b"{}")
+ channel = self.make_request("GET", self.url, b"{}")
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@@ -1095,37 +1617,33 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
"""
other_user_token = self.login("user", "pass")
- request, channel = self.make_request(
- "GET", self.url, access_token=other_user_token,
- )
+ channel = self.make_request("GET", self.url, access_token=other_user_token,)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_user_does_not_exist(self):
"""
- Tests that a lookup for a user that does not exist returns a 404
+ Tests that a lookup for a user that does not exist returns an empty list
"""
url = "/_synapse/admin/v1/users/@unknown_person:test/joined_rooms"
- request, channel = self.make_request(
- "GET", url, access_token=self.admin_user_tok,
- )
+ channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
- self.assertEqual(404, channel.code, msg=channel.json_body)
- self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(0, channel.json_body["total"])
+ self.assertEqual(0, len(channel.json_body["joined_rooms"]))
def test_user_is_not_local(self):
"""
- Tests that a lookup for a user that is not a local returns a 400
+ Tests that a lookup for a user that is not a local and participates in no conversation returns an empty list
"""
url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/joined_rooms"
- request, channel = self.make_request(
- "GET", url, access_token=self.admin_user_tok,
- )
+ channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
- self.assertEqual(400, channel.code, msg=channel.json_body)
- self.assertEqual("Can only lookup local users", channel.json_body["error"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(0, channel.json_body["total"])
+ self.assertEqual(0, len(channel.json_body["joined_rooms"]))
def test_no_memberships(self):
"""
@@ -1133,9 +1651,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
if user has no memberships
"""
# Get rooms
- request, channel = self.make_request(
- "GET", self.url, access_token=self.admin_user_tok,
- )
+ channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"])
@@ -1152,14 +1668,55 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
self.helper.create_room_as(self.other_user, tok=other_user_tok)
# Get rooms
- request, channel = self.make_request(
- "GET", self.url, access_token=self.admin_user_tok,
- )
+ channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(number_rooms, channel.json_body["total"])
self.assertEqual(number_rooms, len(channel.json_body["joined_rooms"]))
+ def test_get_rooms_with_nonlocal_user(self):
+ """
+ Tests that a normal lookup for rooms is successful with a non-local user
+ """
+
+ other_user_tok = self.login("user", "pass")
+ event_builder_factory = self.hs.get_event_builder_factory()
+ event_creation_handler = self.hs.get_event_creation_handler()
+ storage = self.hs.get_storage()
+
+ # Create two rooms, one with a local user only and one with both a local
+ # and remote user.
+ self.helper.create_room_as(self.other_user, tok=other_user_tok)
+ local_and_remote_room_id = self.helper.create_room_as(
+ self.other_user, tok=other_user_tok
+ )
+
+ # Add a remote user to the room.
+ builder = event_builder_factory.for_room_version(
+ RoomVersions.V1,
+ {
+ "type": "m.room.member",
+ "sender": "@joiner:remote_hs",
+ "state_key": "@joiner:remote_hs",
+ "room_id": local_and_remote_room_id,
+ "content": {"membership": "join"},
+ },
+ )
+
+ event, context = self.get_success(
+ event_creation_handler.create_new_client_event(builder)
+ )
+
+ self.get_success(storage.persistence.persist_event(event, context))
+
+ # Now get rooms
+ url = "/_synapse/admin/v1/users/@joiner:remote_hs/joined_rooms"
+ channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(1, channel.json_body["total"])
+ self.assertEqual([local_and_remote_room_id], channel.json_body["joined_rooms"])
+
class PushersRestTestCase(unittest.HomeserverTestCase):
@@ -1183,7 +1740,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
"""
Try to list pushers of an user without authentication.
"""
- request, channel = self.make_request("GET", self.url, b"{}")
+ channel = self.make_request("GET", self.url, b"{}")
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@@ -1194,9 +1751,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
"""
other_user_token = self.login("user", "pass")
- request, channel = self.make_request(
- "GET", self.url, access_token=other_user_token,
- )
+ channel = self.make_request("GET", self.url, access_token=other_user_token,)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@@ -1206,9 +1761,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
Tests that a lookup for a user that does not exist returns a 404
"""
url = "/_synapse/admin/v1/users/@unknown_person:test/pushers"
- request, channel = self.make_request(
- "GET", url, access_token=self.admin_user_tok,
- )
+ channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
@@ -1219,9 +1772,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
"""
url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/pushers"
- request, channel = self.make_request(
- "GET", url, access_token=self.admin_user_tok,
- )
+ channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only lookup local users", channel.json_body["error"])
@@ -1232,9 +1783,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
"""
# Get pushers
- request, channel = self.make_request(
- "GET", self.url, access_token=self.admin_user_tok,
- )
+ channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"])
@@ -1256,14 +1805,12 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
device_display_name="pushy push",
pushkey="a@example.com",
lang=None,
- data={"url": "example.com"},
+ data={"url": "https://example.com/_matrix/push/v1/notify"},
)
)
# Get pushers
- request, channel = self.make_request(
- "GET", self.url, access_token=self.admin_user_tok,
- )
+ channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, channel.json_body["total"])
@@ -1287,7 +1834,6 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
]
def prepare(self, reactor, clock, hs):
- self.store = hs.get_datastore()
self.media_repo = hs.get_media_repository_resource()
self.admin_user = self.register_user("admin", "pass", admin=True)
@@ -1302,7 +1848,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
"""
Try to list media of an user without authentication.
"""
- request, channel = self.make_request("GET", self.url, b"{}")
+ channel = self.make_request("GET", self.url, b"{}")
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@@ -1313,9 +1859,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
"""
other_user_token = self.login("user", "pass")
- request, channel = self.make_request(
- "GET", self.url, access_token=other_user_token,
- )
+ channel = self.make_request("GET", self.url, access_token=other_user_token,)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@@ -1325,9 +1869,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
Tests that a lookup for a user that does not exist returns a 404
"""
url = "/_synapse/admin/v1/users/@unknown_person:test/media"
- request, channel = self.make_request(
- "GET", url, access_token=self.admin_user_tok,
- )
+ channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
@@ -1338,9 +1880,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
"""
url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/media"
- request, channel = self.make_request(
- "GET", url, access_token=self.admin_user_tok,
- )
+ channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only lookup local users", channel.json_body["error"])
@@ -1354,7 +1894,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
other_user_tok = self.login("user", "pass")
self._create_media(other_user_tok, number_media)
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", self.url + "?limit=5", access_token=self.admin_user_tok,
)
@@ -1373,7 +1913,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
other_user_tok = self.login("user", "pass")
self._create_media(other_user_tok, number_media)
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", self.url + "?from=5", access_token=self.admin_user_tok,
)
@@ -1392,7 +1932,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
other_user_tok = self.login("user", "pass")
self._create_media(other_user_tok, number_media)
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", self.url + "?from=5&limit=10", access_token=self.admin_user_tok,
)
@@ -1407,7 +1947,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
Testing that a negative limit parameter returns a 400
"""
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", self.url + "?limit=-5", access_token=self.admin_user_tok,
)
@@ -1419,7 +1959,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
Testing that a negative from parameter returns a 400
"""
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", self.url + "?from=-5", access_token=self.admin_user_tok,
)
@@ -1437,7 +1977,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
# `next_token` does not appear
# Number of results is the number of entries
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", self.url + "?limit=20", access_token=self.admin_user_tok,
)
@@ -1448,7 +1988,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
# `next_token` does not appear
# Number of max results is larger than the number of entries
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", self.url + "?limit=21", access_token=self.admin_user_tok,
)
@@ -1459,7 +1999,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
# `next_token` does appear
# Number of max results is smaller than the number of entries
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", self.url + "?limit=19", access_token=self.admin_user_tok,
)
@@ -1471,7 +2011,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
# Check
# Set `from` to value of `next_token` for request remaining entries
# `next_token` does not appear
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", self.url + "?from=19", access_token=self.admin_user_tok,
)
@@ -1486,9 +2026,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
if user has no media created
"""
- request, channel = self.make_request(
- "GET", self.url, access_token=self.admin_user_tok,
- )
+ channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"])
@@ -1503,9 +2041,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
other_user_tok = self.login("user", "pass")
self._create_media(other_user_tok, number_media)
- request, channel = self.make_request(
- "GET", self.url, access_token=self.admin_user_tok,
- )
+ channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(number_media, channel.json_body["total"])
@@ -1571,7 +2107,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
)
def _get_token(self) -> str:
- request, channel = self.make_request(
+ channel = self.make_request(
"POST", self.url, b"{}", access_token=self.admin_user_tok
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -1580,7 +2116,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
def test_no_auth(self):
"""Try to login as a user without authentication.
"""
- request, channel = self.make_request("POST", self.url, b"{}")
+ channel = self.make_request("POST", self.url, b"{}")
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@@ -1588,7 +2124,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
def test_not_admin(self):
"""Try to login as a user as a non-admin user.
"""
- request, channel = self.make_request(
+ channel = self.make_request(
"POST", self.url, b"{}", access_token=self.other_user_tok
)
@@ -1616,7 +2152,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
self._get_token()
# Check that we don't see a new device in our devices list
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "devices", b"{}", access_token=self.other_user_tok
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -1631,25 +2167,19 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
puppet_token = self._get_token()
# Test that we can successfully make a request
- request, channel = self.make_request(
- "GET", "devices", b"{}", access_token=puppet_token
- )
+ channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
# Logout with the puppet token
- request, channel = self.make_request(
- "POST", "logout", b"{}", access_token=puppet_token
- )
+ channel = self.make_request("POST", "logout", b"{}", access_token=puppet_token)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
# The puppet token should no longer work
- request, channel = self.make_request(
- "GET", "devices", b"{}", access_token=puppet_token
- )
+ channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
# .. but the real user's tokens should still work
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "devices", b"{}", access_token=self.other_user_tok
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -1662,25 +2192,21 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
puppet_token = self._get_token()
# Test that we can successfully make a request
- request, channel = self.make_request(
- "GET", "devices", b"{}", access_token=puppet_token
- )
+ channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
# Logout all with the real user token
- request, channel = self.make_request(
+ channel = self.make_request(
"POST", "logout/all", b"{}", access_token=self.other_user_tok
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
# The puppet token should still work
- request, channel = self.make_request(
- "GET", "devices", b"{}", access_token=puppet_token
- )
+ channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
# .. but the real user's tokens shouldn't
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "devices", b"{}", access_token=self.other_user_tok
)
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
@@ -1693,25 +2219,21 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
puppet_token = self._get_token()
# Test that we can successfully make a request
- request, channel = self.make_request(
- "GET", "devices", b"{}", access_token=puppet_token
- )
+ channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
# Logout all with the admin user token
- request, channel = self.make_request(
+ channel = self.make_request(
"POST", "logout/all", b"{}", access_token=self.admin_user_tok
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
# The puppet token should no longer work
- request, channel = self.make_request(
- "GET", "devices", b"{}", access_token=puppet_token
- )
+ channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
# .. but the real user's tokens should still work
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "devices", b"{}", access_token=self.other_user_tok
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -1778,8 +2300,6 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
]
def prepare(self, reactor, clock, hs):
- self.store = hs.get_datastore()
-
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
@@ -1793,11 +2313,11 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
"""
Try to get information of an user without authentication.
"""
- request, channel = self.make_request("GET", self.url1, b"{}")
+ channel = self.make_request("GET", self.url1, b"{}")
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
- request, channel = self.make_request("GET", self.url2, b"{}")
+ channel = self.make_request("GET", self.url2, b"{}")
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@@ -1808,15 +2328,11 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
self.register_user("user2", "pass")
other_user2_token = self.login("user2", "pass")
- request, channel = self.make_request(
- "GET", self.url1, access_token=other_user2_token,
- )
+ channel = self.make_request("GET", self.url1, access_token=other_user2_token,)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
- request, channel = self.make_request(
- "GET", self.url2, access_token=other_user2_token,
- )
+ channel = self.make_request("GET", self.url2, access_token=other_user2_token,)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@@ -1827,15 +2343,11 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
url1 = "/_synapse/admin/v1/whois/@unknown_person:unknown_domain"
url2 = "/_matrix/client/r0/admin/whois/@unknown_person:unknown_domain"
- request, channel = self.make_request(
- "GET", url1, access_token=self.admin_user_tok,
- )
+ channel = self.make_request("GET", url1, access_token=self.admin_user_tok,)
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only whois a local user", channel.json_body["error"])
- request, channel = self.make_request(
- "GET", url2, access_token=self.admin_user_tok,
- )
+ channel = self.make_request("GET", url2, access_token=self.admin_user_tok,)
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only whois a local user", channel.json_body["error"])
@@ -1843,16 +2355,12 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
"""
The lookup should succeed for an admin.
"""
- request, channel = self.make_request(
- "GET", self.url1, access_token=self.admin_user_tok,
- )
+ channel = self.make_request("GET", self.url1, access_token=self.admin_user_tok,)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(self.other_user, channel.json_body["user_id"])
self.assertIn("devices", channel.json_body)
- request, channel = self.make_request(
- "GET", self.url2, access_token=self.admin_user_tok,
- )
+ channel = self.make_request("GET", self.url2, access_token=self.admin_user_tok,)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(self.other_user, channel.json_body["user_id"])
self.assertIn("devices", channel.json_body)
@@ -1863,16 +2371,76 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
"""
other_user_token = self.login("user", "pass")
- request, channel = self.make_request(
- "GET", self.url1, access_token=other_user_token,
- )
+ channel = self.make_request("GET", self.url1, access_token=other_user_token,)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(self.other_user, channel.json_body["user_id"])
self.assertIn("devices", channel.json_body)
- request, channel = self.make_request(
- "GET", self.url2, access_token=other_user_token,
- )
+ channel = self.make_request("GET", self.url2, access_token=other_user_token,)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(self.other_user, channel.json_body["user_id"])
self.assertIn("devices", channel.json_body)
+
+
+class ShadowBanRestTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.other_user = self.register_user("user", "pass")
+
+ self.url = "/_synapse/admin/v1/users/%s/shadow_ban" % urllib.parse.quote(
+ self.other_user
+ )
+
+ def test_no_auth(self):
+ """
+ Try to get information of an user without authentication.
+ """
+ channel = self.make_request("POST", self.url)
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ def test_requester_is_not_admin(self):
+ """
+ If the user is not a server admin, an error is returned.
+ """
+ other_user_token = self.login("user", "pass")
+
+ channel = self.make_request("POST", self.url, access_token=other_user_token)
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_user_is_not_local(self):
+ """
+ Tests that shadow-banning for a user that is not a local returns a 400
+ """
+ url = "/_synapse/admin/v1/whois/@unknown_person:unknown_domain"
+
+ channel = self.make_request("POST", url, access_token=self.admin_user_tok)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+
+ def test_success(self):
+ """
+ Shadow-banning should succeed for an admin.
+ """
+ # The user starts off as not shadow-banned.
+ other_user_token = self.login("user", "pass")
+ result = self.get_success(self.store.get_user_by_access_token(other_user_token))
+ self.assertFalse(result.shadow_banned)
+
+ channel = self.make_request("POST", self.url, access_token=self.admin_user_tok)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual({}, channel.json_body)
+
+ # Ensure the user is shadow-banned (and the cache was cleared).
+ result = self.get_success(self.store.get_user_by_access_token(other_user_token))
+ self.assertTrue(result.shadow_banned)
diff --git a/tests/rest/client/test_consent.py b/tests/rest/client/test_consent.py
index e2e6a5e16d..c74693e9b2 100644
--- a/tests/rest/client/test_consent.py
+++ b/tests/rest/client/test_consent.py
@@ -61,7 +61,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
def test_render_public_consent(self):
"""You can observe the terms form without specifying a user"""
resource = consent_resource.ConsentResource(self.hs)
- request, channel = make_request(
+ channel = make_request(
self.reactor, FakeSite(resource), "GET", "/consent?v=1", shorthand=False
)
self.assertEqual(channel.code, 200)
@@ -82,7 +82,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
uri_builder.build_user_consent_uri(user_id).replace("_matrix/", "")
+ "&u=user"
)
- request, channel = make_request(
+ channel = make_request(
self.reactor,
FakeSite(resource),
"GET",
@@ -97,7 +97,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
self.assertEqual(consented, "False")
# POST to the consent page, saying we've agreed
- request, channel = make_request(
+ channel = make_request(
self.reactor,
FakeSite(resource),
"POST",
@@ -109,7 +109,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
# Fetch the consent page, to get the consent version -- it should have
# changed
- request, channel = make_request(
+ channel = make_request(
self.reactor,
FakeSite(resource),
"GET",
diff --git a/tests/rest/client/test_ephemeral_message.py b/tests/rest/client/test_ephemeral_message.py
index a1ccc4ee9a..56937dcd2e 100644
--- a/tests/rest/client/test_ephemeral_message.py
+++ b/tests/rest/client/test_ephemeral_message.py
@@ -93,7 +93,7 @@ class EphemeralMessageTestCase(unittest.HomeserverTestCase):
def get_event(self, room_id, event_id, expected_code=200):
url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id)
- request, channel = self.make_request("GET", url)
+ channel = self.make_request("GET", url)
self.assertEqual(channel.code, expected_code, channel.result)
diff --git a/tests/rest/client/test_identity.py b/tests/rest/client/test_identity.py
index 259c6a1985..c0a9fc6925 100644
--- a/tests/rest/client/test_identity.py
+++ b/tests/rest/client/test_identity.py
@@ -43,9 +43,7 @@ class IdentityTestCase(unittest.HomeserverTestCase):
self.register_user("kermit", "monkey")
tok = self.login("kermit", "monkey")
- request, channel = self.make_request(
- b"POST", "/createRoom", b"{}", access_token=tok
- )
+ channel = self.make_request(b"POST", "/createRoom", b"{}", access_token=tok)
self.assertEquals(channel.result["code"], b"200", channel.result)
room_id = channel.json_body["room_id"]
@@ -56,7 +54,7 @@ class IdentityTestCase(unittest.HomeserverTestCase):
}
request_data = json.dumps(params)
request_url = ("/rooms/%s/invite" % (room_id)).encode("ascii")
- request, channel = self.make_request(
+ channel = self.make_request(
b"POST", request_url, request_data, access_token=tok
)
self.assertEquals(channel.result["code"], b"403", channel.result)
diff --git a/tests/rest/client/test_redactions.py b/tests/rest/client/test_redactions.py
index c1f516cc93..f0707646bb 100644
--- a/tests/rest/client/test_redactions.py
+++ b/tests/rest/client/test_redactions.py
@@ -69,16 +69,12 @@ class RedactionsTestCase(HomeserverTestCase):
"""
path = "/_matrix/client/r0/rooms/%s/redact/%s" % (room_id, event_id)
- request, channel = self.make_request(
- "POST", path, content={}, access_token=access_token
- )
+ channel = self.make_request("POST", path, content={}, access_token=access_token)
self.assertEqual(int(channel.result["code"]), expect_code)
return channel.json_body
def _sync_room_timeline(self, access_token, room_id):
- request, channel = self.make_request(
- "GET", "sync", access_token=self.mod_access_token
- )
+ channel = self.make_request("GET", "sync", access_token=self.mod_access_token)
self.assertEqual(channel.result["code"], b"200")
room_sync = channel.json_body["rooms"]["join"][room_id]
return room_sync["timeline"]["events"]
diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py
index f56b5d9231..31dc832fd5 100644
--- a/tests/rest/client/test_retention.py
+++ b/tests/rest/client/test_retention.py
@@ -325,7 +325,7 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase):
def get_event(self, room_id, event_id, expected_code=200):
url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id)
- request, channel = self.make_request("GET", url, access_token=self.token)
+ channel = self.make_request("GET", url, access_token=self.token)
self.assertEqual(channel.code, expected_code, channel.result)
diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py
index 94dcfb9f7c..0ebdf1415b 100644
--- a/tests/rest/client/test_shadow_banned.py
+++ b/tests/rest/client/test_shadow_banned.py
@@ -18,6 +18,7 @@ import synapse.rest.admin
from synapse.api.constants import EventTypes
from synapse.rest.client.v1 import directory, login, profile, room
from synapse.rest.client.v2_alpha import room_upgrade_rest_servlet
+from synapse.types import UserID
from tests import unittest
@@ -31,12 +32,7 @@ class _ShadowBannedBase(unittest.HomeserverTestCase):
self.store = self.hs.get_datastore()
self.get_success(
- self.store.db_pool.simple_update(
- table="users",
- keyvalues={"name": self.banned_user_id},
- updatevalues={"shadow_banned": True},
- desc="shadow_ban",
- )
+ self.store.set_shadow_banned(UserID.from_string(self.banned_user_id), True)
)
self.other_user_id = self.register_user("otheruser", "pass")
@@ -89,7 +85,7 @@ class RoomTestCase(_ShadowBannedBase):
)
# Inviting the user completes successfully.
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
"/rooms/%s/invite" % (room_id,),
{"id_server": "test", "medium": "email", "address": "test@test.test"},
@@ -103,7 +99,7 @@ class RoomTestCase(_ShadowBannedBase):
def test_create_room(self):
"""Invitations during a room creation should be discarded, but the room still gets created."""
# The room creation is successful.
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
"/_matrix/client/r0/createRoom",
{"visibility": "public", "invite": [self.other_user_id]},
@@ -158,7 +154,7 @@ class RoomTestCase(_ShadowBannedBase):
self.banned_user_id, tok=self.banned_access_token
)
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
"/_matrix/client/r0/rooms/%s/upgrade" % (room_id,),
{"new_version": "6"},
@@ -183,7 +179,7 @@ class RoomTestCase(_ShadowBannedBase):
self.banned_user_id, tok=self.banned_access_token
)
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/rooms/%s/typing/%s" % (room_id, self.banned_user_id),
{"typing": True, "timeout": 30000},
@@ -198,7 +194,7 @@ class RoomTestCase(_ShadowBannedBase):
# The other user can join and send typing events.
self.helper.join(room_id, self.other_user_id, tok=self.other_access_token)
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/rooms/%s/typing/%s" % (room_id, self.other_user_id),
{"typing": True, "timeout": 30000},
@@ -244,7 +240,7 @@ class ProfileTestCase(_ShadowBannedBase):
)
# The update should succeed.
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/_matrix/client/r0/profile/%s/displayname" % (self.banned_user_id,),
{"displayname": new_display_name},
@@ -254,7 +250,7 @@ class ProfileTestCase(_ShadowBannedBase):
self.assertEqual(channel.json_body, {})
# The user's display name should be updated.
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/profile/%s/displayname" % (self.banned_user_id,)
)
self.assertEqual(channel.code, 200, channel.result)
@@ -282,7 +278,7 @@ class ProfileTestCase(_ShadowBannedBase):
)
# The update should succeed.
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/_matrix/client/r0/rooms/%s/state/m.room.member/%s"
% (room_id, self.banned_user_id),
diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index 0e96697f9b..227fffab58 100644
--- a/tests/rest/client/test_third_party_rules.py
+++ b/tests/rest/client/test_third_party_rules.py
@@ -86,7 +86,7 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
callback = Mock(spec=[], side_effect=check)
current_rules_module().check_event_allowed = callback
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/_matrix/client/r0/rooms/%s/send/foo.bar.allowed/1" % self.room_id,
{},
@@ -104,7 +104,7 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
self.assertEqual(ev.type, k[0])
self.assertEqual(ev.state_key, k[1])
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/_matrix/client/r0/rooms/%s/send/foo.bar.forbidden/2" % self.room_id,
{},
@@ -123,7 +123,7 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
current_rules_module().check_event_allowed = check
# now send the event
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/_matrix/client/r0/rooms/%s/send/modifyme/1" % self.room_id,
{"x": "x"},
@@ -142,7 +142,7 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
current_rules_module().check_event_allowed = check
# now send the event
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/_matrix/client/r0/rooms/%s/send/modifyme/1" % self.room_id,
{"x": "x"},
@@ -152,7 +152,7 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
event_id = channel.json_body["event_id"]
# ... and check that it got modified
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id),
access_token=self.tok,
diff --git a/tests/rest/client/v1/test_directory.py b/tests/rest/client/v1/test_directory.py
index 7a2c653df8..edd1d184f8 100644
--- a/tests/rest/client/v1/test_directory.py
+++ b/tests/rest/client/v1/test_directory.py
@@ -91,7 +91,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
# that we can make sure that the check is done on the whole alias.
data = {"room_alias_name": random_string(256 - len(self.hs.hostname))}
request_data = json.dumps(data)
- request, channel = self.make_request(
+ channel = self.make_request(
"POST", url, request_data, access_token=self.user_tok
)
self.assertEqual(channel.code, 400, channel.result)
@@ -104,7 +104,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
# as cautious as possible here.
data = {"room_alias_name": random_string(5)}
request_data = json.dumps(data)
- request, channel = self.make_request(
+ channel = self.make_request(
"POST", url, request_data, access_token=self.user_tok
)
self.assertEqual(channel.code, 200, channel.result)
@@ -118,7 +118,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
data = {"aliases": [self.random_alias(alias_length)]}
request_data = json.dumps(data)
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT", url, request_data, access_token=self.user_tok
)
self.assertEqual(channel.code, expected_code, channel.result)
@@ -128,7 +128,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
data = {"room_id": self.room_id}
request_data = json.dumps(data)
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT", url, request_data, access_token=self.user_tok
)
self.assertEqual(channel.code, expected_code, channel.result)
diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py
index 12a93f5687..0a5ca317ea 100644
--- a/tests/rest/client/v1/test_events.py
+++ b/tests/rest/client/v1/test_events.py
@@ -63,13 +63,13 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase):
# implementation is now part of the r0 implementation, the newer
# behaviour is used instead to be consistent with the r0 spec.
# see issue #2602
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/events?access_token=%s" % ("invalid" + self.token,)
)
self.assertEquals(channel.code, 401, msg=channel.result)
# valid token, expect content
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/events?access_token=%s&timeout=0" % (self.token,)
)
self.assertEquals(channel.code, 200, msg=channel.result)
@@ -87,7 +87,7 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase):
)
# valid token, expect content
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/events?access_token=%s&timeout=0" % (self.token,)
)
self.assertEquals(channel.code, 200, msg=channel.result)
@@ -149,7 +149,7 @@ class GetEventsTestCase(unittest.HomeserverTestCase):
resp = self.helper.send(self.room_id, tok=self.token)
event_id = resp["event_id"]
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/events/" + event_id, access_token=self.token,
)
self.assertEquals(channel.code, 200, msg=channel.result)
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index 176ddf7ec9..bfcb786af8 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -1,23 +1,83 @@
-import json
+# -*- coding: utf-8 -*-
+# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import time
import urllib.parse
+from typing import Any, Dict, Union
+from urllib.parse import urlencode
from mock import Mock
-import jwt
+import pymacaroons
+
+from twisted.web.resource import Resource
import synapse.rest.admin
from synapse.appservice import ApplicationService
from synapse.rest.client.v1 import login, logout
from synapse.rest.client.v2_alpha import devices, register
from synapse.rest.client.v2_alpha.account import WhoamiRestServlet
+from synapse.rest.synapse.client import build_synapse_client_resource_tree
+from synapse.types import create_requester
from tests import unittest
-from tests.unittest import override_config
+from tests.handlers.test_oidc import HAS_OIDC
+from tests.handlers.test_saml import has_saml2
+from tests.rest.client.v1.utils import TEST_OIDC_AUTH_ENDPOINT, TEST_OIDC_CONFIG
+from tests.test_utils.html_parsers import TestHtmlParser
+from tests.unittest import HomeserverTestCase, override_config, skip_unless
+
+try:
+ import jwt
+
+ HAS_JWT = True
+except ImportError:
+ HAS_JWT = False
+
+
+# public_base_url used in some tests
+BASE_URL = "https://synapse/"
+
+# CAS server used in some tests
+CAS_SERVER = "https://fake.test"
+
+# just enough to tell pysaml2 where to redirect to
+SAML_SERVER = "https://test.saml.server/idp/sso"
+TEST_SAML_METADATA = """
+<md:EntityDescriptor xmlns:md="urn:oasis:names:tc:SAML:2.0:metadata">
+ <md:IDPSSODescriptor protocolSupportEnumeration="urn:oasis:names:tc:SAML:2.0:protocol">
+ <md:SingleSignOnService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect" Location="%(SAML_SERVER)s"/>
+ </md:IDPSSODescriptor>
+</md:EntityDescriptor>
+""" % {
+ "SAML_SERVER": SAML_SERVER,
+}
LOGIN_URL = b"/_matrix/client/r0/login"
TEST_URL = b"/_matrix/client/r0/account/whoami"
+# a (valid) url with some annoying characters in. %3D is =, %26 is &, %2B is +
+TEST_CLIENT_REDIRECT_URL = 'https://x?<ab c>&q"+%3D%2B"="fö%26=o"'
+
+# the query params in TEST_CLIENT_REDIRECT_URL
+EXPECTED_CLIENT_REDIRECT_URL_PARAMS = [("<ab c>", ""), ('q" =+"', '"fö&=o"')]
+
+# (possibly experimental) login flows we expect to appear in the list after the normal
+# ones
+ADDITIONAL_LOGIN_FLOWS = [{"type": "uk.half-shot.msc2778.login.application_service"}]
+
class LoginRestServletTestCase(unittest.HomeserverTestCase):
@@ -63,7 +123,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
"identifier": {"type": "m.id.user", "user": "kermit" + str(i)},
"password": "monkey",
}
- request, channel = self.make_request(b"POST", LOGIN_URL, params)
+ channel = self.make_request(b"POST", LOGIN_URL, params)
if i == 5:
self.assertEquals(channel.result["code"], b"429", channel.result)
@@ -82,7 +142,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
"identifier": {"type": "m.id.user", "user": "kermit" + str(i)},
"password": "monkey",
}
- request, channel = self.make_request(b"POST", LOGIN_URL, params)
+ channel = self.make_request(b"POST", LOGIN_URL, params)
self.assertEquals(channel.result["code"], b"200", channel.result)
@@ -108,7 +168,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
"identifier": {"type": "m.id.user", "user": "kermit"},
"password": "monkey",
}
- request, channel = self.make_request(b"POST", LOGIN_URL, params)
+ channel = self.make_request(b"POST", LOGIN_URL, params)
if i == 5:
self.assertEquals(channel.result["code"], b"429", channel.result)
@@ -127,7 +187,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
"identifier": {"type": "m.id.user", "user": "kermit"},
"password": "monkey",
}
- request, channel = self.make_request(b"POST", LOGIN_URL, params)
+ channel = self.make_request(b"POST", LOGIN_URL, params)
self.assertEquals(channel.result["code"], b"200", channel.result)
@@ -153,7 +213,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
"identifier": {"type": "m.id.user", "user": "kermit"},
"password": "notamonkey",
}
- request, channel = self.make_request(b"POST", LOGIN_URL, params)
+ channel = self.make_request(b"POST", LOGIN_URL, params)
if i == 5:
self.assertEquals(channel.result["code"], b"429", channel.result)
@@ -172,7 +232,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
"identifier": {"type": "m.id.user", "user": "kermit"},
"password": "notamonkey",
}
- request, channel = self.make_request(b"POST", LOGIN_URL, params)
+ channel = self.make_request(b"POST", LOGIN_URL, params)
self.assertEquals(channel.result["code"], b"403", channel.result)
@@ -181,7 +241,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
self.register_user("kermit", "monkey")
# we shouldn't be able to make requests without an access token
- request, channel = self.make_request(b"GET", TEST_URL)
+ channel = self.make_request(b"GET", TEST_URL)
self.assertEquals(channel.result["code"], b"401", channel.result)
self.assertEquals(channel.json_body["errcode"], "M_MISSING_TOKEN")
@@ -191,25 +251,21 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
"identifier": {"type": "m.id.user", "user": "kermit"},
"password": "monkey",
}
- request, channel = self.make_request(b"POST", LOGIN_URL, params)
+ channel = self.make_request(b"POST", LOGIN_URL, params)
self.assertEquals(channel.code, 200, channel.result)
access_token = channel.json_body["access_token"]
device_id = channel.json_body["device_id"]
# we should now be able to make requests with the access token
- request, channel = self.make_request(
- b"GET", TEST_URL, access_token=access_token
- )
+ channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
self.assertEquals(channel.code, 200, channel.result)
# time passes
self.reactor.advance(24 * 3600)
# ... and we should be soft-logouted
- request, channel = self.make_request(
- b"GET", TEST_URL, access_token=access_token
- )
+ channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
self.assertEquals(channel.code, 401, channel.result)
self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
self.assertEquals(channel.json_body["soft_logout"], True)
@@ -223,9 +279,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
# more requests with the expired token should still return a soft-logout
self.reactor.advance(3600)
- request, channel = self.make_request(
- b"GET", TEST_URL, access_token=access_token
- )
+ channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
self.assertEquals(channel.code, 401, channel.result)
self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
self.assertEquals(channel.json_body["soft_logout"], True)
@@ -233,16 +287,14 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
# ... but if we delete that device, it will be a proper logout
self._delete_device(access_token_2, "kermit", "monkey", device_id)
- request, channel = self.make_request(
- b"GET", TEST_URL, access_token=access_token
- )
+ channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
self.assertEquals(channel.code, 401, channel.result)
self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
self.assertEquals(channel.json_body["soft_logout"], False)
def _delete_device(self, access_token, user_id, password, device_id):
"""Perform the UI-Auth to delete a device"""
- request, channel = self.make_request(
+ channel = self.make_request(
b"DELETE", "devices/" + device_id, access_token=access_token
)
self.assertEquals(channel.code, 401, channel.result)
@@ -262,7 +314,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
"session": channel.json_body["session"],
}
- request, channel = self.make_request(
+ channel = self.make_request(
b"DELETE",
"devices/" + device_id,
access_token=access_token,
@@ -278,26 +330,20 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
access_token = self.login("kermit", "monkey")
# we should now be able to make requests with the access token
- request, channel = self.make_request(
- b"GET", TEST_URL, access_token=access_token
- )
+ channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
self.assertEquals(channel.code, 200, channel.result)
# time passes
self.reactor.advance(24 * 3600)
# ... and we should be soft-logouted
- request, channel = self.make_request(
- b"GET", TEST_URL, access_token=access_token
- )
+ channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
self.assertEquals(channel.code, 401, channel.result)
self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
self.assertEquals(channel.json_body["soft_logout"], True)
# Now try to hard logout this session
- request, channel = self.make_request(
- b"POST", "/logout", access_token=access_token
- )
+ channel = self.make_request(b"POST", "/logout", access_token=access_token)
self.assertEquals(channel.result["code"], b"200", channel.result)
@override_config({"session_lifetime": "24h"})
@@ -308,29 +354,313 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
access_token = self.login("kermit", "monkey")
# we should now be able to make requests with the access token
- request, channel = self.make_request(
- b"GET", TEST_URL, access_token=access_token
- )
+ channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
self.assertEquals(channel.code, 200, channel.result)
# time passes
self.reactor.advance(24 * 3600)
# ... and we should be soft-logouted
- request, channel = self.make_request(
- b"GET", TEST_URL, access_token=access_token
- )
+ channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
self.assertEquals(channel.code, 401, channel.result)
self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
self.assertEquals(channel.json_body["soft_logout"], True)
# Now try to hard log out all of the user's sessions
- request, channel = self.make_request(
- b"POST", "/logout/all", access_token=access_token
- )
+ channel = self.make_request(b"POST", "/logout/all", access_token=access_token)
self.assertEquals(channel.result["code"], b"200", channel.result)
+@skip_unless(has_saml2 and HAS_OIDC, "Requires SAML2 and OIDC")
+class MultiSSOTestCase(unittest.HomeserverTestCase):
+ """Tests for homeservers with multiple SSO providers enabled"""
+
+ servlets = [
+ login.register_servlets,
+ ]
+
+ def default_config(self) -> Dict[str, Any]:
+ config = super().default_config()
+
+ config["public_baseurl"] = BASE_URL
+
+ config["cas_config"] = {
+ "enabled": True,
+ "server_url": CAS_SERVER,
+ "service_url": "https://matrix.goodserver.com:8448",
+ }
+
+ config["saml2_config"] = {
+ "sp_config": {
+ "metadata": {"inline": [TEST_SAML_METADATA]},
+ # use the XMLSecurity backend to avoid relying on xmlsec1
+ "crypto_backend": "XMLSecurity",
+ },
+ }
+
+ # default OIDC provider
+ config["oidc_config"] = TEST_OIDC_CONFIG
+
+ # additional OIDC providers
+ config["oidc_providers"] = [
+ {
+ "idp_id": "idp1",
+ "idp_name": "IDP1",
+ "discover": False,
+ "issuer": "https://issuer1",
+ "client_id": "test-client-id",
+ "client_secret": "test-client-secret",
+ "scopes": ["profile"],
+ "authorization_endpoint": "https://issuer1/auth",
+ "token_endpoint": "https://issuer1/token",
+ "userinfo_endpoint": "https://issuer1/userinfo",
+ "user_mapping_provider": {
+ "config": {"localpart_template": "{{ user.sub }}"}
+ },
+ }
+ ]
+ return config
+
+ def create_resource_dict(self) -> Dict[str, Resource]:
+ d = super().create_resource_dict()
+ d.update(build_synapse_client_resource_tree(self.hs))
+ return d
+
+ def test_get_login_flows(self):
+ """GET /login should return password and SSO flows"""
+ channel = self.make_request("GET", "/_matrix/client/r0/login")
+ self.assertEqual(channel.code, 200, channel.result)
+
+ expected_flows = [
+ {"type": "m.login.cas"},
+ {"type": "m.login.sso"},
+ {"type": "m.login.token"},
+ {"type": "m.login.password"},
+ ] + ADDITIONAL_LOGIN_FLOWS
+
+ self.assertCountEqual(channel.json_body["flows"], expected_flows)
+
+ @override_config({"experimental_features": {"msc2858_enabled": True}})
+ def test_get_msc2858_login_flows(self):
+ """The SSO flow should include IdP info if MSC2858 is enabled"""
+ channel = self.make_request("GET", "/_matrix/client/r0/login")
+ self.assertEqual(channel.code, 200, channel.result)
+
+ # stick the flows results in a dict by type
+ flow_results = {} # type: Dict[str, Any]
+ for f in channel.json_body["flows"]:
+ flow_type = f["type"]
+ self.assertNotIn(
+ flow_type, flow_results, "duplicate flow type %s" % (flow_type,)
+ )
+ flow_results[flow_type] = f
+
+ self.assertIn("m.login.sso", flow_results, "m.login.sso was not returned")
+ sso_flow = flow_results.pop("m.login.sso")
+ # we should have a set of IdPs
+ self.assertCountEqual(
+ sso_flow["org.matrix.msc2858.identity_providers"],
+ [
+ {"id": "cas", "name": "CAS"},
+ {"id": "saml", "name": "SAML"},
+ {"id": "oidc-idp1", "name": "IDP1"},
+ {"id": "oidc", "name": "OIDC"},
+ ],
+ )
+
+ # the rest of the flows are simple
+ expected_flows = [
+ {"type": "m.login.cas"},
+ {"type": "m.login.token"},
+ {"type": "m.login.password"},
+ ] + ADDITIONAL_LOGIN_FLOWS
+
+ self.assertCountEqual(flow_results.values(), expected_flows)
+
+ def test_multi_sso_redirect(self):
+ """/login/sso/redirect should redirect to an identity picker"""
+ # first hit the redirect url, which should redirect to our idp picker
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/r0/login/sso/redirect?redirectUrl="
+ + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL),
+ )
+ self.assertEqual(channel.code, 302, channel.result)
+ uri = channel.headers.getRawHeaders("Location")[0]
+
+ # hitting that picker should give us some HTML
+ channel = self.make_request("GET", uri)
+ self.assertEqual(channel.code, 200, channel.result)
+
+ # parse the form to check it has fields assumed elsewhere in this class
+ p = TestHtmlParser()
+ p.feed(channel.result["body"].decode("utf-8"))
+ p.close()
+
+ self.assertCountEqual(p.radios["idp"], ["cas", "oidc", "oidc-idp1", "saml"])
+
+ self.assertEqual(p.hiddens["redirectUrl"], TEST_CLIENT_REDIRECT_URL)
+
+ def test_multi_sso_redirect_to_cas(self):
+ """If CAS is chosen, should redirect to the CAS server"""
+
+ channel = self.make_request(
+ "GET",
+ "/_synapse/client/pick_idp?redirectUrl="
+ + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
+ + "&idp=cas",
+ shorthand=False,
+ )
+ self.assertEqual(channel.code, 302, channel.result)
+ cas_uri = channel.headers.getRawHeaders("Location")[0]
+ cas_uri_path, cas_uri_query = cas_uri.split("?", 1)
+
+ # it should redirect us to the login page of the cas server
+ self.assertEqual(cas_uri_path, CAS_SERVER + "/login")
+
+ # check that the redirectUrl is correctly encoded in the service param - ie, the
+ # place that CAS will redirect to
+ cas_uri_params = urllib.parse.parse_qs(cas_uri_query)
+ service_uri = cas_uri_params["service"][0]
+ _, service_uri_query = service_uri.split("?", 1)
+ service_uri_params = urllib.parse.parse_qs(service_uri_query)
+ self.assertEqual(service_uri_params["redirectUrl"][0], TEST_CLIENT_REDIRECT_URL)
+
+ def test_multi_sso_redirect_to_saml(self):
+ """If SAML is chosen, should redirect to the SAML server"""
+ channel = self.make_request(
+ "GET",
+ "/_synapse/client/pick_idp?redirectUrl="
+ + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
+ + "&idp=saml",
+ )
+ self.assertEqual(channel.code, 302, channel.result)
+ saml_uri = channel.headers.getRawHeaders("Location")[0]
+ saml_uri_path, saml_uri_query = saml_uri.split("?", 1)
+
+ # it should redirect us to the login page of the SAML server
+ self.assertEqual(saml_uri_path, SAML_SERVER)
+
+ # the RelayState is used to carry the client redirect url
+ saml_uri_params = urllib.parse.parse_qs(saml_uri_query)
+ relay_state_param = saml_uri_params["RelayState"][0]
+ self.assertEqual(relay_state_param, TEST_CLIENT_REDIRECT_URL)
+
+ def test_login_via_oidc(self):
+ """If OIDC is chosen, should redirect to the OIDC auth endpoint"""
+
+ # pick the default OIDC provider
+ channel = self.make_request(
+ "GET",
+ "/_synapse/client/pick_idp?redirectUrl="
+ + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
+ + "&idp=oidc",
+ )
+ self.assertEqual(channel.code, 302, channel.result)
+ oidc_uri = channel.headers.getRawHeaders("Location")[0]
+ oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1)
+
+ # it should redirect us to the auth page of the OIDC server
+ self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT)
+
+ # ... and should have set a cookie including the redirect url
+ cookies = dict(
+ h.split(";")[0].split("=", maxsplit=1)
+ for h in channel.headers.getRawHeaders("Set-Cookie")
+ )
+
+ oidc_session_cookie = cookies["oidc_session"]
+ macaroon = pymacaroons.Macaroon.deserialize(oidc_session_cookie)
+ self.assertEqual(
+ self._get_value_from_macaroon(macaroon, "client_redirect_url"),
+ TEST_CLIENT_REDIRECT_URL,
+ )
+
+ channel = self.helper.complete_oidc_auth(oidc_uri, cookies, {"sub": "user1"})
+
+ # that should serve a confirmation page
+ self.assertEqual(channel.code, 200, channel.result)
+ self.assertTrue(
+ channel.headers.getRawHeaders("Content-Type")[-1].startswith("text/html")
+ )
+ p = TestHtmlParser()
+ p.feed(channel.text_body)
+ p.close()
+
+ # ... which should contain our redirect link
+ self.assertEqual(len(p.links), 1)
+ path, query = p.links[0].split("?", 1)
+ self.assertEqual(path, "https://x")
+
+ # it will have url-encoded the params properly, so we'll have to parse them
+ params = urllib.parse.parse_qsl(
+ query, keep_blank_values=True, strict_parsing=True, errors="strict"
+ )
+ self.assertEqual(params[0:2], EXPECTED_CLIENT_REDIRECT_URL_PARAMS)
+ self.assertEqual(params[2][0], "loginToken")
+
+ # finally, submit the matrix login token to the login API, which gives us our
+ # matrix access token, mxid, and device id.
+ login_token = params[2][1]
+ chan = self.make_request(
+ "POST", "/login", content={"type": "m.login.token", "token": login_token},
+ )
+ self.assertEqual(chan.code, 200, chan.result)
+ self.assertEqual(chan.json_body["user_id"], "@user1:test")
+
+ def test_multi_sso_redirect_to_unknown(self):
+ """An unknown IdP should cause a 400"""
+ channel = self.make_request(
+ "GET", "/_synapse/client/pick_idp?redirectUrl=http://x&idp=xyz",
+ )
+ self.assertEqual(channel.code, 400, channel.result)
+
+ def test_client_idp_redirect_msc2858_disabled(self):
+ """If the client tries to pick an IdP but MSC2858 is disabled, return a 400"""
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect/oidc?redirectUrl="
+ + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL),
+ )
+ self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED")
+
+ @override_config({"experimental_features": {"msc2858_enabled": True}})
+ def test_client_idp_redirect_to_unknown(self):
+ """If the client tries to pick an unknown IdP, return a 404"""
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect/xxx?redirectUrl="
+ + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL),
+ )
+ self.assertEqual(channel.code, 404, channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND")
+
+ @override_config({"experimental_features": {"msc2858_enabled": True}})
+ def test_client_idp_redirect_to_oidc(self):
+ """If the client pick a known IdP, redirect to it"""
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect/oidc?redirectUrl="
+ + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL),
+ )
+
+ self.assertEqual(channel.code, 302, channel.result)
+ oidc_uri = channel.headers.getRawHeaders("Location")[0]
+ oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1)
+
+ # it should redirect us to the auth page of the OIDC server
+ self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT)
+
+ @staticmethod
+ def _get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str:
+ prefix = key + " = "
+ for caveat in macaroon.caveats:
+ if caveat.caveat_id.startswith(prefix):
+ return caveat.caveat_id[len(prefix) :]
+ raise ValueError("No %s caveat in macaroon" % (key,))
+
+
class CASTestCase(unittest.HomeserverTestCase):
servlets = [
@@ -342,10 +672,12 @@ class CASTestCase(unittest.HomeserverTestCase):
self.redirect_path = "_synapse/client/login/sso/redirect/confirm"
config = self.default_config()
+ config["public_baseurl"] = (
+ config.get("public_baseurl") or "https://matrix.goodserver.com:8448"
+ )
config["cas_config"] = {
"enabled": True,
- "server_url": "https://fake.test",
- "service_url": "https://matrix.goodserver.com:8448",
+ "server_url": CAS_SERVER,
}
cas_user_id = "username"
@@ -402,10 +734,10 @@ class CASTestCase(unittest.HomeserverTestCase):
cas_ticket_url = urllib.parse.urlunparse(url_parts)
# Get Synapse to call the fake CAS and serve the template.
- request, channel = self.make_request("GET", cas_ticket_url)
+ channel = self.make_request("GET", cas_ticket_url)
# Test that the response is HTML.
- self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.code, 200, channel.result)
content_type_header_value = ""
for header in channel.result.get("headers", []):
if header[0] == b"Content-Type":
@@ -430,8 +762,7 @@ class CASTestCase(unittest.HomeserverTestCase):
}
)
def test_cas_redirect_whitelisted(self):
- """Tests that the SSO login flow serves a redirect to a whitelisted url
- """
+ """Tests that the SSO login flow serves a redirect to a whitelisted url"""
self._test_redirect("https://legit-site.com/")
@override_config({"public_baseurl": "https://example.com"})
@@ -446,7 +777,7 @@ class CASTestCase(unittest.HomeserverTestCase):
)
# Get Synapse to call the fake CAS and serve the template.
- request, channel = self.make_request("GET", cas_ticket_url)
+ channel = self.make_request("GET", cas_ticket_url)
self.assertEqual(channel.code, 302)
location_headers = channel.headers.getRawHeaders("Location")
@@ -462,7 +793,9 @@ class CASTestCase(unittest.HomeserverTestCase):
# Deactivate the account.
self.get_success(
- self.deactivate_account_handler.deactivate_account(self.user_id, False)
+ self.deactivate_account_handler.deactivate_account(
+ self.user_id, False, create_requester(self.user_id)
+ )
)
# Request the CAS ticket.
@@ -472,13 +805,14 @@ class CASTestCase(unittest.HomeserverTestCase):
)
# Get Synapse to call the fake CAS and serve the template.
- request, channel = self.make_request("GET", cas_ticket_url)
+ channel = self.make_request("GET", cas_ticket_url)
# Because the user is deactivated they are served an error template.
self.assertEqual(channel.code, 403)
self.assertIn(b"SSO account deactivated", channel.result["body"])
+@skip_unless(HAS_JWT, "requires jwt")
class JWTTestCase(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
@@ -495,14 +829,18 @@ class JWTTestCase(unittest.HomeserverTestCase):
self.hs.config.jwt_algorithm = self.jwt_algorithm
return self.hs
- def jwt_encode(self, token, secret=jwt_secret):
- return jwt.encode(token, secret, self.jwt_algorithm).decode("ascii")
+ def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_secret) -> str:
+ # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str.
+ result = jwt.encode(
+ payload, secret, self.jwt_algorithm
+ ) # type: Union[str, bytes]
+ if isinstance(result, bytes):
+ return result.decode("ascii")
+ return result
def jwt_login(self, *args):
- params = json.dumps(
- {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
- )
- request, channel = self.make_request(b"POST", LOGIN_URL, params)
+ params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
+ channel = self.make_request(b"POST", LOGIN_URL, params)
return channel
def test_login_jwt_valid_registered(self):
@@ -633,8 +971,8 @@ class JWTTestCase(unittest.HomeserverTestCase):
)
def test_login_no_token(self):
- params = json.dumps({"type": "org.matrix.login.jwt"})
- request, channel = self.make_request(b"POST", LOGIN_URL, params)
+ params = {"type": "org.matrix.login.jwt"}
+ channel = self.make_request(b"POST", LOGIN_URL, params)
self.assertEqual(channel.result["code"], b"403", channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(channel.json_body["error"], "Token field for JWT is missing")
@@ -643,6 +981,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
# The JWTPubKeyTestCase is a complement to JWTTestCase where we instead use
# RSS256, with a public key configured in synapse as "jwt_secret", and tokens
# signed by the private key.
+@skip_unless(HAS_JWT, "requires jwt")
class JWTPubKeyTestCase(unittest.HomeserverTestCase):
servlets = [
login.register_servlets,
@@ -700,14 +1039,16 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
self.hs.config.jwt_algorithm = "RS256"
return self.hs
- def jwt_encode(self, token, secret=jwt_privatekey):
- return jwt.encode(token, secret, "RS256").decode("ascii")
+ def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_privatekey) -> str:
+ # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str.
+ result = jwt.encode(payload, secret, "RS256") # type: Union[bytes,str]
+ if isinstance(result, bytes):
+ return result.decode("ascii")
+ return result
def jwt_login(self, *args):
- params = json.dumps(
- {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
- )
- request, channel = self.make_request(b"POST", LOGIN_URL, params)
+ params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
+ channel = self.make_request(b"POST", LOGIN_URL, params)
return channel
def test_login_jwt_valid(self):
@@ -735,7 +1076,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
]
def register_as_user(self, username):
- request, channel = self.make_request(
+ self.make_request(
b"POST",
"/_matrix/client/r0/register?access_token=%s" % (self.service.token,),
{"username": username},
@@ -776,60 +1117,56 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
return self.hs
def test_login_appservice_user(self):
- """Test that an appservice user can use /login
- """
+ """Test that an appservice user can use /login"""
self.register_as_user(AS_USER)
params = {
"type": login.LoginRestServlet.APPSERVICE_TYPE,
"identifier": {"type": "m.id.user", "user": AS_USER},
}
- request, channel = self.make_request(
+ channel = self.make_request(
b"POST", LOGIN_URL, params, access_token=self.service.token
)
self.assertEquals(channel.result["code"], b"200", channel.result)
def test_login_appservice_user_bot(self):
- """Test that the appservice bot can use /login
- """
+ """Test that the appservice bot can use /login"""
self.register_as_user(AS_USER)
params = {
"type": login.LoginRestServlet.APPSERVICE_TYPE,
"identifier": {"type": "m.id.user", "user": self.service.sender},
}
- request, channel = self.make_request(
+ channel = self.make_request(
b"POST", LOGIN_URL, params, access_token=self.service.token
)
self.assertEquals(channel.result["code"], b"200", channel.result)
def test_login_appservice_wrong_user(self):
- """Test that non-as users cannot login with the as token
- """
+ """Test that non-as users cannot login with the as token"""
self.register_as_user(AS_USER)
params = {
"type": login.LoginRestServlet.APPSERVICE_TYPE,
"identifier": {"type": "m.id.user", "user": "fibble_wibble"},
}
- request, channel = self.make_request(
+ channel = self.make_request(
b"POST", LOGIN_URL, params, access_token=self.service.token
)
self.assertEquals(channel.result["code"], b"403", channel.result)
def test_login_appservice_wrong_as(self):
- """Test that as users cannot login with wrong as token
- """
+ """Test that as users cannot login with wrong as token"""
self.register_as_user(AS_USER)
params = {
"type": login.LoginRestServlet.APPSERVICE_TYPE,
"identifier": {"type": "m.id.user", "user": AS_USER},
}
- request, channel = self.make_request(
+ channel = self.make_request(
b"POST", LOGIN_URL, params, access_token=self.another_service.token
)
@@ -837,7 +1174,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
def test_login_appservice_no_token(self):
"""Test that users must provide a token when using the appservice
- login method
+ login method
"""
self.register_as_user(AS_USER)
@@ -845,6 +1182,116 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
"type": login.LoginRestServlet.APPSERVICE_TYPE,
"identifier": {"type": "m.id.user", "user": AS_USER},
}
- request, channel = self.make_request(b"POST", LOGIN_URL, params)
+ channel = self.make_request(b"POST", LOGIN_URL, params)
self.assertEquals(channel.result["code"], b"401", channel.result)
+
+
+@skip_unless(HAS_OIDC, "requires OIDC")
+class UsernamePickerTestCase(HomeserverTestCase):
+ """Tests for the username picker flow of SSO login"""
+
+ servlets = [login.register_servlets]
+
+ def default_config(self):
+ config = super().default_config()
+ config["public_baseurl"] = BASE_URL
+
+ config["oidc_config"] = {}
+ config["oidc_config"].update(TEST_OIDC_CONFIG)
+ config["oidc_config"]["user_mapping_provider"] = {
+ "config": {"display_name_template": "{{ user.displayname }}"}
+ }
+
+ # whitelist this client URI so we redirect straight to it rather than
+ # serving a confirmation page
+ config["sso"] = {"client_whitelist": ["https://x"]}
+ return config
+
+ def create_resource_dict(self) -> Dict[str, Resource]:
+ d = super().create_resource_dict()
+ d.update(build_synapse_client_resource_tree(self.hs))
+ return d
+
+ def test_username_picker(self):
+ """Test the happy path of a username picker flow."""
+
+ # do the start of the login flow
+ channel = self.helper.auth_via_oidc(
+ {"sub": "tester", "displayname": "Jonny"}, TEST_CLIENT_REDIRECT_URL
+ )
+
+ # that should redirect to the username picker
+ self.assertEqual(channel.code, 302, channel.result)
+ picker_url = channel.headers.getRawHeaders("Location")[0]
+ self.assertEqual(picker_url, "/_synapse/client/pick_username/account_details")
+
+ # ... with a username_mapping_session cookie
+ cookies = {} # type: Dict[str,str]
+ channel.extract_cookies(cookies)
+ self.assertIn("username_mapping_session", cookies)
+ session_id = cookies["username_mapping_session"]
+
+ # introspect the sso handler a bit to check that the username mapping session
+ # looks ok.
+ username_mapping_sessions = self.hs.get_sso_handler()._username_mapping_sessions
+ self.assertIn(
+ session_id, username_mapping_sessions, "session id not found in map",
+ )
+ session = username_mapping_sessions[session_id]
+ self.assertEqual(session.remote_user_id, "tester")
+ self.assertEqual(session.display_name, "Jonny")
+ self.assertEqual(session.client_redirect_url, TEST_CLIENT_REDIRECT_URL)
+
+ # the expiry time should be about 15 minutes away
+ expected_expiry = self.clock.time_msec() + (15 * 60 * 1000)
+ self.assertApproximates(session.expiry_time_ms, expected_expiry, tolerance=1000)
+
+ # Now, submit a username to the username picker, which should serve a redirect
+ # to the completion page
+ content = urlencode({b"username": b"bobby"}).encode("utf8")
+ chan = self.make_request(
+ "POST",
+ path=picker_url,
+ content=content,
+ content_is_form=True,
+ custom_headers=[
+ ("Cookie", "username_mapping_session=" + session_id),
+ # old versions of twisted don't do form-parsing without a valid
+ # content-length header.
+ ("Content-Length", str(len(content))),
+ ],
+ )
+ self.assertEqual(chan.code, 302, chan.result)
+ location_headers = chan.headers.getRawHeaders("Location")
+
+ # send a request to the completion page, which should 302 to the client redirectUrl
+ chan = self.make_request(
+ "GET",
+ path=location_headers[0],
+ custom_headers=[("Cookie", "username_mapping_session=" + session_id)],
+ )
+ self.assertEqual(chan.code, 302, chan.result)
+ location_headers = chan.headers.getRawHeaders("Location")
+
+ # ensure that the returned location matches the requested redirect URL
+ path, query = location_headers[0].split("?", 1)
+ self.assertEqual(path, "https://x")
+
+ # it will have url-encoded the params properly, so we'll have to parse them
+ params = urllib.parse.parse_qsl(
+ query, keep_blank_values=True, strict_parsing=True, errors="strict"
+ )
+ self.assertEqual(params[0:2], EXPECTED_CLIENT_REDIRECT_URL_PARAMS)
+ self.assertEqual(params[2][0], "loginToken")
+
+ # fish the login token out of the returned redirect uri
+ login_token = params[2][1]
+
+ # finally, submit the matrix login token to the login API, which gives us our
+ # matrix access token, mxid, and device id.
+ chan = self.make_request(
+ "POST", "/login", content={"type": "m.login.token", "token": login_token},
+ )
+ self.assertEqual(chan.code, 200, chan.result)
+ self.assertEqual(chan.json_body["user_id"], "@bobby:test")
diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py
index 5d5c24d01c..94a5154834 100644
--- a/tests/rest/client/v1/test_presence.py
+++ b/tests/rest/client/v1/test_presence.py
@@ -38,7 +38,7 @@ class PresenceTestCase(unittest.HomeserverTestCase):
hs = self.setup_test_homeserver(
"red",
- http_client=None,
+ federation_http_client=None,
federation_client=Mock(),
presence_handler=presence_handler,
)
@@ -53,7 +53,7 @@ class PresenceTestCase(unittest.HomeserverTestCase):
self.hs.config.use_presence = True
body = {"presence": "here", "status_msg": "beep boop"}
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT", "/presence/%s/status" % (self.user_id,), body
)
@@ -68,7 +68,7 @@ class PresenceTestCase(unittest.HomeserverTestCase):
self.hs.config.use_presence = False
body = {"presence": "here", "status_msg": "beep boop"}
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT", "/presence/%s/status" % (self.user_id,), body
)
diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py
index 383a9eafac..e59fa70baa 100644
--- a/tests/rest/client/v1/test_profile.py
+++ b/tests/rest/client/v1/test_profile.py
@@ -63,7 +63,7 @@ class MockHandlerProfileTestCase(unittest.TestCase):
hs = yield setup_test_homeserver(
self.addCleanup,
"test",
- http_client=None,
+ federation_http_client=None,
resource_for_client=self.mock_resource,
federation=Mock(),
federation_client=Mock(),
@@ -189,7 +189,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.owner_tok = self.login("owner", "pass")
def test_set_displayname(self):
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/profile/%s/displayname" % (self.owner,),
content=json.dumps({"displayname": "test"}),
@@ -202,7 +202,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
def test_set_displayname_too_long(self):
"""Attempts to set a stupid displayname should get a 400"""
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/profile/%s/displayname" % (self.owner,),
content=json.dumps({"displayname": "test" * 100}),
@@ -214,9 +214,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.assertEqual(res, "owner")
def get_displayname(self):
- request, channel = self.make_request(
- "GET", "/profile/%s/displayname" % (self.owner,)
- )
+ channel = self.make_request("GET", "/profile/%s/displayname" % (self.owner,))
self.assertEqual(channel.code, 200, channel.result)
return channel.json_body["displayname"]
@@ -278,7 +276,7 @@ class ProfilesRestrictedTestCase(unittest.HomeserverTestCase):
)
def request_profile(self, expected_code, url_suffix="", access_token=None):
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", self.profile_url + url_suffix, access_token=access_token
)
self.assertEqual(channel.code, expected_code, channel.result)
@@ -320,19 +318,19 @@ class OwnProfileUnrestrictedTestCase(unittest.HomeserverTestCase):
"""Tests that a user can lookup their own profile without having to be in a room
if 'require_auth_for_profile_requests' is set to true in the server's config.
"""
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/profile/" + self.requester, access_token=self.requester_tok
)
self.assertEqual(channel.code, 200, channel.result)
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/profile/" + self.requester + "/displayname",
access_token=self.requester_tok,
)
self.assertEqual(channel.code, 200, channel.result)
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/profile/" + self.requester + "/avatar_url",
access_token=self.requester_tok,
diff --git a/tests/rest/client/v1/test_push_rule_attrs.py b/tests/rest/client/v1/test_push_rule_attrs.py
index 7add5523c8..2bc512d75e 100644
--- a/tests/rest/client/v1/test_push_rule_attrs.py
+++ b/tests/rest/client/v1/test_push_rule_attrs.py
@@ -45,13 +45,13 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
}
# PUT a new rule
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT", "/pushrules/global/override/best.friend", body, access_token=token
)
self.assertEqual(channel.code, 200)
# GET enabled for that new rule
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/pushrules/global/override/best.friend/enabled", access_token=token
)
self.assertEqual(channel.code, 200)
@@ -74,13 +74,13 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
}
# PUT a new rule
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT", "/pushrules/global/override/best.friend", body, access_token=token
)
self.assertEqual(channel.code, 200)
# disable the rule
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/pushrules/global/override/best.friend/enabled",
{"enabled": False},
@@ -89,26 +89,26 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
self.assertEqual(channel.code, 200)
# check rule disabled
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/pushrules/global/override/best.friend/enabled", access_token=token
)
self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body["enabled"], False)
# DELETE the rule
- request, channel = self.make_request(
+ channel = self.make_request(
"DELETE", "/pushrules/global/override/best.friend", access_token=token
)
self.assertEqual(channel.code, 200)
# PUT a new rule
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT", "/pushrules/global/override/best.friend", body, access_token=token
)
self.assertEqual(channel.code, 200)
# GET enabled for that new rule
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/pushrules/global/override/best.friend/enabled", access_token=token
)
self.assertEqual(channel.code, 200)
@@ -130,13 +130,13 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
}
# PUT a new rule
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT", "/pushrules/global/override/best.friend", body, access_token=token
)
self.assertEqual(channel.code, 200)
# disable the rule
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/pushrules/global/override/best.friend/enabled",
{"enabled": False},
@@ -145,14 +145,14 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
self.assertEqual(channel.code, 200)
# check rule disabled
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/pushrules/global/override/best.friend/enabled", access_token=token
)
self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body["enabled"], False)
# re-enable the rule
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/pushrules/global/override/best.friend/enabled",
{"enabled": True},
@@ -161,7 +161,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
self.assertEqual(channel.code, 200)
# check rule enabled
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/pushrules/global/override/best.friend/enabled", access_token=token
)
self.assertEqual(channel.code, 200)
@@ -182,32 +182,32 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
}
# check 404 for never-heard-of rule
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/pushrules/global/override/best.friend/enabled", access_token=token
)
self.assertEqual(channel.code, 404)
self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
# PUT a new rule
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT", "/pushrules/global/override/best.friend", body, access_token=token
)
self.assertEqual(channel.code, 200)
# GET enabled for that new rule
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/pushrules/global/override/best.friend/enabled", access_token=token
)
self.assertEqual(channel.code, 200)
# DELETE the rule
- request, channel = self.make_request(
+ channel = self.make_request(
"DELETE", "/pushrules/global/override/best.friend", access_token=token
)
self.assertEqual(channel.code, 200)
# check 404 for deleted rule
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/pushrules/global/override/best.friend/enabled", access_token=token
)
self.assertEqual(channel.code, 404)
@@ -221,7 +221,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
token = self.login("user", "pass")
# check 404 for never-heard-of rule
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/pushrules/global/override/.m.muahahaha/enabled", access_token=token
)
self.assertEqual(channel.code, 404)
@@ -235,7 +235,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
token = self.login("user", "pass")
# enable & check 404 for never-heard-of rule
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/pushrules/global/override/best.friend/enabled",
{"enabled": True},
@@ -252,7 +252,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
token = self.login("user", "pass")
# enable & check 404 for never-heard-of rule
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/pushrules/global/override/.m.muahahah/enabled",
{"enabled": True},
@@ -276,13 +276,13 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
}
# PUT a new rule
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT", "/pushrules/global/override/best.friend", body, access_token=token
)
self.assertEqual(channel.code, 200)
# GET actions for that new rule
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/pushrules/global/override/best.friend/actions", access_token=token
)
self.assertEqual(channel.code, 200)
@@ -305,13 +305,13 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
}
# PUT a new rule
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT", "/pushrules/global/override/best.friend", body, access_token=token
)
self.assertEqual(channel.code, 200)
# change the rule actions
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/pushrules/global/override/best.friend/actions",
{"actions": ["dont_notify"]},
@@ -320,7 +320,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
self.assertEqual(channel.code, 200)
# GET actions for that new rule
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/pushrules/global/override/best.friend/actions", access_token=token
)
self.assertEqual(channel.code, 200)
@@ -341,26 +341,26 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
}
# check 404 for never-heard-of rule
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/pushrules/global/override/best.friend/enabled", access_token=token
)
self.assertEqual(channel.code, 404)
self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
# PUT a new rule
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT", "/pushrules/global/override/best.friend", body, access_token=token
)
self.assertEqual(channel.code, 200)
# DELETE the rule
- request, channel = self.make_request(
+ channel = self.make_request(
"DELETE", "/pushrules/global/override/best.friend", access_token=token
)
self.assertEqual(channel.code, 200)
# check 404 for deleted rule
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/pushrules/global/override/best.friend/enabled", access_token=token
)
self.assertEqual(channel.code, 404)
@@ -374,7 +374,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
token = self.login("user", "pass")
# check 404 for never-heard-of rule
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/pushrules/global/override/.m.muahahaha/actions", access_token=token
)
self.assertEqual(channel.code, 404)
@@ -388,7 +388,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
token = self.login("user", "pass")
# enable & check 404 for never-heard-of rule
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/pushrules/global/override/best.friend/actions",
{"actions": ["dont_notify"]},
@@ -405,7 +405,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
token = self.login("user", "pass")
# enable & check 404 for never-heard-of rule
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/pushrules/global/override/.m.muahahah/actions",
{"actions": ["dont_notify"]},
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index 49f1073c88..2548b3a80c 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -26,9 +26,10 @@ from mock import Mock
import synapse.rest.admin
from synapse.api.constants import EventContentFields, EventTypes, Membership
from synapse.handlers.pagination import PurgeStatus
+from synapse.rest import admin
from synapse.rest.client.v1 import directory, login, profile, room
from synapse.rest.client.v2_alpha import account
-from synapse.types import JsonDict, RoomAlias, UserID
+from synapse.types import JsonDict, RoomAlias, UserID, create_requester
from synapse.util.stringutils import random_string
from tests import unittest
@@ -45,7 +46,7 @@ class RoomBase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
self.hs = self.setup_test_homeserver(
- "red", http_client=None, federation_client=Mock(),
+ "red", federation_http_client=None, federation_client=Mock(),
)
self.hs.get_federation_handler = Mock()
@@ -83,13 +84,13 @@ class RoomPermissionsTestCase(RoomBase):
self.created_rmid_msg_path = (
"rooms/%s/send/m.room.message/a1" % (self.created_rmid)
).encode("ascii")
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT", self.created_rmid_msg_path, b'{"msgtype":"m.text","body":"test msg"}'
)
self.assertEquals(200, channel.code, channel.result)
# set topic for public room
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
("rooms/%s/state/m.room.topic" % self.created_public_rmid).encode("ascii"),
b'{"topic":"Public Room Topic"}',
@@ -111,7 +112,7 @@ class RoomPermissionsTestCase(RoomBase):
)
# send message in uncreated room, expect 403
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/rooms/%s/send/m.room.message/mid2" % (self.uncreated_rmid,),
msg_content,
@@ -119,24 +120,24 @@ class RoomPermissionsTestCase(RoomBase):
self.assertEquals(403, channel.code, msg=channel.result["body"])
# send message in created room not joined (no state), expect 403
- request, channel = self.make_request("PUT", send_msg_path(), msg_content)
+ channel = self.make_request("PUT", send_msg_path(), msg_content)
self.assertEquals(403, channel.code, msg=channel.result["body"])
# send message in created room and invited, expect 403
self.helper.invite(
room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id
)
- request, channel = self.make_request("PUT", send_msg_path(), msg_content)
+ channel = self.make_request("PUT", send_msg_path(), msg_content)
self.assertEquals(403, channel.code, msg=channel.result["body"])
# send message in created room and joined, expect 200
self.helper.join(room=self.created_rmid, user=self.user_id)
- request, channel = self.make_request("PUT", send_msg_path(), msg_content)
+ channel = self.make_request("PUT", send_msg_path(), msg_content)
self.assertEquals(200, channel.code, msg=channel.result["body"])
# send message in created room and left, expect 403
self.helper.leave(room=self.created_rmid, user=self.user_id)
- request, channel = self.make_request("PUT", send_msg_path(), msg_content)
+ channel = self.make_request("PUT", send_msg_path(), msg_content)
self.assertEquals(403, channel.code, msg=channel.result["body"])
def test_topic_perms(self):
@@ -144,30 +145,30 @@ class RoomPermissionsTestCase(RoomBase):
topic_path = "/rooms/%s/state/m.room.topic" % self.created_rmid
# set/get topic in uncreated room, expect 403
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid, topic_content
)
self.assertEquals(403, channel.code, msg=channel.result["body"])
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid
)
self.assertEquals(403, channel.code, msg=channel.result["body"])
# set/get topic in created PRIVATE room not joined, expect 403
- request, channel = self.make_request("PUT", topic_path, topic_content)
+ channel = self.make_request("PUT", topic_path, topic_content)
self.assertEquals(403, channel.code, msg=channel.result["body"])
- request, channel = self.make_request("GET", topic_path)
+ channel = self.make_request("GET", topic_path)
self.assertEquals(403, channel.code, msg=channel.result["body"])
# set topic in created PRIVATE room and invited, expect 403
self.helper.invite(
room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id
)
- request, channel = self.make_request("PUT", topic_path, topic_content)
+ channel = self.make_request("PUT", topic_path, topic_content)
self.assertEquals(403, channel.code, msg=channel.result["body"])
# get topic in created PRIVATE room and invited, expect 403
- request, channel = self.make_request("GET", topic_path)
+ channel = self.make_request("GET", topic_path)
self.assertEquals(403, channel.code, msg=channel.result["body"])
# set/get topic in created PRIVATE room and joined, expect 200
@@ -175,29 +176,29 @@ class RoomPermissionsTestCase(RoomBase):
# Only room ops can set topic by default
self.helper.auth_user_id = self.rmcreator_id
- request, channel = self.make_request("PUT", topic_path, topic_content)
+ channel = self.make_request("PUT", topic_path, topic_content)
self.assertEquals(200, channel.code, msg=channel.result["body"])
self.helper.auth_user_id = self.user_id
- request, channel = self.make_request("GET", topic_path)
+ channel = self.make_request("GET", topic_path)
self.assertEquals(200, channel.code, msg=channel.result["body"])
self.assert_dict(json.loads(topic_content.decode("utf8")), channel.json_body)
# set/get topic in created PRIVATE room and left, expect 403
self.helper.leave(room=self.created_rmid, user=self.user_id)
- request, channel = self.make_request("PUT", topic_path, topic_content)
+ channel = self.make_request("PUT", topic_path, topic_content)
self.assertEquals(403, channel.code, msg=channel.result["body"])
- request, channel = self.make_request("GET", topic_path)
+ channel = self.make_request("GET", topic_path)
self.assertEquals(200, channel.code, msg=channel.result["body"])
# get topic in PUBLIC room, not joined, expect 403
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/rooms/%s/state/m.room.topic" % self.created_public_rmid
)
self.assertEquals(403, channel.code, msg=channel.result["body"])
# set topic in PUBLIC room, not joined, expect 403
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/rooms/%s/state/m.room.topic" % self.created_public_rmid,
topic_content,
@@ -207,7 +208,7 @@ class RoomPermissionsTestCase(RoomBase):
def _test_get_membership(self, room=None, members=[], expect_code=None):
for member in members:
path = "/rooms/%s/state/m.room.member/%s" % (room, member)
- request, channel = self.make_request("GET", path)
+ channel = self.make_request("GET", path)
self.assertEquals(expect_code, channel.code)
def test_membership_basic_room_perms(self):
@@ -379,16 +380,16 @@ class RoomsMemberListTestCase(RoomBase):
def test_get_member_list(self):
room_id = self.helper.create_room_as(self.user_id)
- request, channel = self.make_request("GET", "/rooms/%s/members" % room_id)
+ channel = self.make_request("GET", "/rooms/%s/members" % room_id)
self.assertEquals(200, channel.code, msg=channel.result["body"])
def test_get_member_list_no_room(self):
- request, channel = self.make_request("GET", "/rooms/roomdoesnotexist/members")
+ channel = self.make_request("GET", "/rooms/roomdoesnotexist/members")
self.assertEquals(403, channel.code, msg=channel.result["body"])
def test_get_member_list_no_permission(self):
room_id = self.helper.create_room_as("@some_other_guy:red")
- request, channel = self.make_request("GET", "/rooms/%s/members" % room_id)
+ channel = self.make_request("GET", "/rooms/%s/members" % room_id)
self.assertEquals(403, channel.code, msg=channel.result["body"])
def test_get_member_list_mixed_memberships(self):
@@ -397,17 +398,17 @@ class RoomsMemberListTestCase(RoomBase):
room_path = "/rooms/%s/members" % room_id
self.helper.invite(room=room_id, src=room_creator, targ=self.user_id)
# can't see list if you're just invited.
- request, channel = self.make_request("GET", room_path)
+ channel = self.make_request("GET", room_path)
self.assertEquals(403, channel.code, msg=channel.result["body"])
self.helper.join(room=room_id, user=self.user_id)
# can see list now joined
- request, channel = self.make_request("GET", room_path)
+ channel = self.make_request("GET", room_path)
self.assertEquals(200, channel.code, msg=channel.result["body"])
self.helper.leave(room=room_id, user=self.user_id)
# can see old list once left
- request, channel = self.make_request("GET", room_path)
+ channel = self.make_request("GET", room_path)
self.assertEquals(200, channel.code, msg=channel.result["body"])
@@ -418,30 +419,26 @@ class RoomsCreateTestCase(RoomBase):
def test_post_room_no_keys(self):
# POST with no config keys, expect new room id
- request, channel = self.make_request("POST", "/createRoom", "{}")
+ channel = self.make_request("POST", "/createRoom", "{}")
self.assertEquals(200, channel.code, channel.result)
self.assertTrue("room_id" in channel.json_body)
def test_post_room_visibility_key(self):
# POST with visibility config key, expect new room id
- request, channel = self.make_request(
- "POST", "/createRoom", b'{"visibility":"private"}'
- )
+ channel = self.make_request("POST", "/createRoom", b'{"visibility":"private"}')
self.assertEquals(200, channel.code)
self.assertTrue("room_id" in channel.json_body)
def test_post_room_custom_key(self):
# POST with custom config keys, expect new room id
- request, channel = self.make_request(
- "POST", "/createRoom", b'{"custom":"stuff"}'
- )
+ channel = self.make_request("POST", "/createRoom", b'{"custom":"stuff"}')
self.assertEquals(200, channel.code)
self.assertTrue("room_id" in channel.json_body)
def test_post_room_known_and_unknown_keys(self):
# POST with custom + known config keys, expect new room id
- request, channel = self.make_request(
+ channel = self.make_request(
"POST", "/createRoom", b'{"visibility":"private","custom":"things"}'
)
self.assertEquals(200, channel.code)
@@ -449,16 +446,16 @@ class RoomsCreateTestCase(RoomBase):
def test_post_room_invalid_content(self):
# POST with invalid content / paths, expect 400
- request, channel = self.make_request("POST", "/createRoom", b'{"visibili')
+ channel = self.make_request("POST", "/createRoom", b'{"visibili')
self.assertEquals(400, channel.code)
- request, channel = self.make_request("POST", "/createRoom", b'["hello"]')
+ channel = self.make_request("POST", "/createRoom", b'["hello"]')
self.assertEquals(400, channel.code)
def test_post_room_invitees_invalid_mxid(self):
# POST with invalid invitee, see https://github.com/matrix-org/synapse/issues/4088
# Note the trailing space in the MXID here!
- request, channel = self.make_request(
+ channel = self.make_request(
"POST", "/createRoom", b'{"invite":["@alice:example.com "]}'
)
self.assertEquals(400, channel.code)
@@ -476,54 +473,54 @@ class RoomTopicTestCase(RoomBase):
def test_invalid_puts(self):
# missing keys or invalid json
- request, channel = self.make_request("PUT", self.path, "{}")
+ channel = self.make_request("PUT", self.path, "{}")
self.assertEquals(400, channel.code, msg=channel.result["body"])
- request, channel = self.make_request("PUT", self.path, '{"_name":"bo"}')
+ channel = self.make_request("PUT", self.path, '{"_name":"bo"}')
self.assertEquals(400, channel.code, msg=channel.result["body"])
- request, channel = self.make_request("PUT", self.path, '{"nao')
+ channel = self.make_request("PUT", self.path, '{"nao')
self.assertEquals(400, channel.code, msg=channel.result["body"])
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT", self.path, '[{"_name":"bo"},{"_name":"jill"}]'
)
self.assertEquals(400, channel.code, msg=channel.result["body"])
- request, channel = self.make_request("PUT", self.path, "text only")
+ channel = self.make_request("PUT", self.path, "text only")
self.assertEquals(400, channel.code, msg=channel.result["body"])
- request, channel = self.make_request("PUT", self.path, "")
+ channel = self.make_request("PUT", self.path, "")
self.assertEquals(400, channel.code, msg=channel.result["body"])
# valid key, wrong type
content = '{"topic":["Topic name"]}'
- request, channel = self.make_request("PUT", self.path, content)
+ channel = self.make_request("PUT", self.path, content)
self.assertEquals(400, channel.code, msg=channel.result["body"])
def test_rooms_topic(self):
# nothing should be there
- request, channel = self.make_request("GET", self.path)
+ channel = self.make_request("GET", self.path)
self.assertEquals(404, channel.code, msg=channel.result["body"])
# valid put
content = '{"topic":"Topic name"}'
- request, channel = self.make_request("PUT", self.path, content)
+ channel = self.make_request("PUT", self.path, content)
self.assertEquals(200, channel.code, msg=channel.result["body"])
# valid get
- request, channel = self.make_request("GET", self.path)
+ channel = self.make_request("GET", self.path)
self.assertEquals(200, channel.code, msg=channel.result["body"])
self.assert_dict(json.loads(content), channel.json_body)
def test_rooms_topic_with_extra_keys(self):
# valid put with extra keys
content = '{"topic":"Seasons","subtopic":"Summer"}'
- request, channel = self.make_request("PUT", self.path, content)
+ channel = self.make_request("PUT", self.path, content)
self.assertEquals(200, channel.code, msg=channel.result["body"])
# valid get
- request, channel = self.make_request("GET", self.path)
+ channel = self.make_request("GET", self.path)
self.assertEquals(200, channel.code, msg=channel.result["body"])
self.assert_dict(json.loads(content), channel.json_body)
@@ -539,24 +536,22 @@ class RoomMemberStateTestCase(RoomBase):
def test_invalid_puts(self):
path = "/rooms/%s/state/m.room.member/%s" % (self.room_id, self.user_id)
# missing keys or invalid json
- request, channel = self.make_request("PUT", path, "{}")
+ channel = self.make_request("PUT", path, "{}")
self.assertEquals(400, channel.code, msg=channel.result["body"])
- request, channel = self.make_request("PUT", path, '{"_name":"bo"}')
+ channel = self.make_request("PUT", path, '{"_name":"bo"}')
self.assertEquals(400, channel.code, msg=channel.result["body"])
- request, channel = self.make_request("PUT", path, '{"nao')
+ channel = self.make_request("PUT", path, '{"nao')
self.assertEquals(400, channel.code, msg=channel.result["body"])
- request, channel = self.make_request(
- "PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]'
- )
+ channel = self.make_request("PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]')
self.assertEquals(400, channel.code, msg=channel.result["body"])
- request, channel = self.make_request("PUT", path, "text only")
+ channel = self.make_request("PUT", path, "text only")
self.assertEquals(400, channel.code, msg=channel.result["body"])
- request, channel = self.make_request("PUT", path, "")
+ channel = self.make_request("PUT", path, "")
self.assertEquals(400, channel.code, msg=channel.result["body"])
# valid keys, wrong types
@@ -565,7 +560,7 @@ class RoomMemberStateTestCase(RoomBase):
Membership.JOIN,
Membership.LEAVE,
)
- request, channel = self.make_request("PUT", path, content.encode("ascii"))
+ channel = self.make_request("PUT", path, content.encode("ascii"))
self.assertEquals(400, channel.code, msg=channel.result["body"])
def test_rooms_members_self(self):
@@ -576,10 +571,10 @@ class RoomMemberStateTestCase(RoomBase):
# valid join message (NOOP since we made the room)
content = '{"membership":"%s"}' % Membership.JOIN
- request, channel = self.make_request("PUT", path, content.encode("ascii"))
+ channel = self.make_request("PUT", path, content.encode("ascii"))
self.assertEquals(200, channel.code, msg=channel.result["body"])
- request, channel = self.make_request("GET", path, None)
+ channel = self.make_request("GET", path, None)
self.assertEquals(200, channel.code, msg=channel.result["body"])
expected_response = {"membership": Membership.JOIN}
@@ -594,10 +589,10 @@ class RoomMemberStateTestCase(RoomBase):
# valid invite message
content = '{"membership":"%s"}' % Membership.INVITE
- request, channel = self.make_request("PUT", path, content)
+ channel = self.make_request("PUT", path, content)
self.assertEquals(200, channel.code, msg=channel.result["body"])
- request, channel = self.make_request("GET", path, None)
+ channel = self.make_request("GET", path, None)
self.assertEquals(200, channel.code, msg=channel.result["body"])
self.assertEquals(json.loads(content), channel.json_body)
@@ -613,18 +608,54 @@ class RoomMemberStateTestCase(RoomBase):
Membership.INVITE,
"Join us!",
)
- request, channel = self.make_request("PUT", path, content)
+ channel = self.make_request("PUT", path, content)
self.assertEquals(200, channel.code, msg=channel.result["body"])
- request, channel = self.make_request("GET", path, None)
+ channel = self.make_request("GET", path, None)
self.assertEquals(200, channel.code, msg=channel.result["body"])
self.assertEquals(json.loads(content), channel.json_body)
+class RoomInviteRatelimitTestCase(RoomBase):
+ user_id = "@sid1:red"
+
+ servlets = [
+ admin.register_servlets,
+ profile.register_servlets,
+ room.register_servlets,
+ ]
+
+ @unittest.override_config(
+ {"rc_invites": {"per_room": {"per_second": 0.5, "burst_count": 3}}}
+ )
+ def test_invites_by_rooms_ratelimit(self):
+ """Tests that invites in a room are actually rate-limited."""
+ room_id = self.helper.create_room_as(self.user_id)
+
+ for i in range(3):
+ self.helper.invite(room_id, self.user_id, "@user-%s:red" % (i,))
+
+ self.helper.invite(room_id, self.user_id, "@user-4:red", expect_code=429)
+
+ @unittest.override_config(
+ {"rc_invites": {"per_user": {"per_second": 0.5, "burst_count": 3}}}
+ )
+ def test_invites_by_users_ratelimit(self):
+ """Tests that invites to a specific user are actually rate-limited."""
+
+ for i in range(3):
+ room_id = self.helper.create_room_as(self.user_id)
+ self.helper.invite(room_id, self.user_id, "@other-users:red")
+
+ room_id = self.helper.create_room_as(self.user_id)
+ self.helper.invite(room_id, self.user_id, "@other-users:red", expect_code=429)
+
+
class RoomJoinRatelimitTestCase(RoomBase):
user_id = "@sid1:red"
servlets = [
+ admin.register_servlets,
profile.register_servlets,
room.register_servlets,
]
@@ -666,7 +697,7 @@ class RoomJoinRatelimitTestCase(RoomBase):
# Update the display name for the user.
path = "/_matrix/client/r0/profile/%s/displayname" % self.user_id
- request, channel = self.make_request("PUT", path, {"displayname": "John Doe"})
+ channel = self.make_request("PUT", path, {"displayname": "John Doe"})
self.assertEquals(channel.code, 200, channel.json_body)
# Check that all the rooms have been sent a profile update into.
@@ -676,7 +707,7 @@ class RoomJoinRatelimitTestCase(RoomBase):
self.user_id,
)
- request, channel = self.make_request("GET", path)
+ channel = self.make_request("GET", path)
self.assertEquals(channel.code, 200)
self.assertIn("displayname", channel.json_body)
@@ -700,9 +731,23 @@ class RoomJoinRatelimitTestCase(RoomBase):
# Make sure we send more requests than the rate-limiting config would allow
# if all of these requests ended up joining the user to a room.
for i in range(4):
- request, channel = self.make_request("POST", path % room_id, {})
+ channel = self.make_request("POST", path % room_id, {})
self.assertEquals(channel.code, 200)
+ @unittest.override_config(
+ {
+ "rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}},
+ "auto_join_rooms": ["#room:red", "#room2:red", "#room3:red", "#room4:red"],
+ "autocreate_auto_join_rooms": True,
+ },
+ )
+ def test_autojoin_rooms(self):
+ user_id = self.register_user("testuser", "password")
+
+ # Check that the new user successfully joined the four rooms
+ rooms = self.get_success(self.hs.get_datastore().get_rooms_for_user(user_id))
+ self.assertEqual(len(rooms), 4)
+
class RoomMessagesTestCase(RoomBase):
""" Tests /rooms/$room_id/messages/$user_id/$msg_id REST events. """
@@ -715,42 +760,40 @@ class RoomMessagesTestCase(RoomBase):
def test_invalid_puts(self):
path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id))
# missing keys or invalid json
- request, channel = self.make_request("PUT", path, b"{}")
+ channel = self.make_request("PUT", path, b"{}")
self.assertEquals(400, channel.code, msg=channel.result["body"])
- request, channel = self.make_request("PUT", path, b'{"_name":"bo"}')
+ channel = self.make_request("PUT", path, b'{"_name":"bo"}')
self.assertEquals(400, channel.code, msg=channel.result["body"])
- request, channel = self.make_request("PUT", path, b'{"nao')
+ channel = self.make_request("PUT", path, b'{"nao')
self.assertEquals(400, channel.code, msg=channel.result["body"])
- request, channel = self.make_request(
- "PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]'
- )
+ channel = self.make_request("PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]')
self.assertEquals(400, channel.code, msg=channel.result["body"])
- request, channel = self.make_request("PUT", path, b"text only")
+ channel = self.make_request("PUT", path, b"text only")
self.assertEquals(400, channel.code, msg=channel.result["body"])
- request, channel = self.make_request("PUT", path, b"")
+ channel = self.make_request("PUT", path, b"")
self.assertEquals(400, channel.code, msg=channel.result["body"])
def test_rooms_messages_sent(self):
path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id))
content = b'{"body":"test","msgtype":{"type":"a"}}'
- request, channel = self.make_request("PUT", path, content)
+ channel = self.make_request("PUT", path, content)
self.assertEquals(400, channel.code, msg=channel.result["body"])
# custom message types
content = b'{"body":"test","msgtype":"test.custom.text"}'
- request, channel = self.make_request("PUT", path, content)
+ channel = self.make_request("PUT", path, content)
self.assertEquals(200, channel.code, msg=channel.result["body"])
# m.text message type
path = "/rooms/%s/send/m.room.message/mid2" % (urlparse.quote(self.room_id))
content = b'{"body":"test2","msgtype":"m.text"}'
- request, channel = self.make_request("PUT", path, content)
+ channel = self.make_request("PUT", path, content)
self.assertEquals(200, channel.code, msg=channel.result["body"])
@@ -764,9 +807,7 @@ class RoomInitialSyncTestCase(RoomBase):
self.room_id = self.helper.create_room_as(self.user_id)
def test_initial_sync(self):
- request, channel = self.make_request(
- "GET", "/rooms/%s/initialSync" % self.room_id
- )
+ channel = self.make_request("GET", "/rooms/%s/initialSync" % self.room_id)
self.assertEquals(200, channel.code)
self.assertEquals(self.room_id, channel.json_body["room_id"])
@@ -807,7 +848,7 @@ class RoomMessageListTestCase(RoomBase):
def test_topo_token_is_accepted(self):
token = "t1-0_0_0_0_0_0_0_0_0"
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token)
)
self.assertEquals(200, channel.code)
@@ -818,7 +859,7 @@ class RoomMessageListTestCase(RoomBase):
def test_stream_token_is_accepted_for_fwd_pagianation(self):
token = "s0_0_0_0_0_0_0_0_0"
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token)
)
self.assertEquals(200, channel.code)
@@ -851,7 +892,7 @@ class RoomMessageListTestCase(RoomBase):
self.helper.send(self.room_id, "message 3")
# Check that we get the first and second message when querying /messages.
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s"
% (
@@ -879,7 +920,7 @@ class RoomMessageListTestCase(RoomBase):
# Check that we only get the second message through /message now that the first
# has been purged.
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s"
% (
@@ -896,7 +937,7 @@ class RoomMessageListTestCase(RoomBase):
# Check that we get no event, but also no error, when querying /messages with
# the token that was pointing at the first event, because we don't have it
# anymore.
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s"
% (
@@ -955,7 +996,7 @@ class RoomSearchTestCase(unittest.HomeserverTestCase):
self.helper.send(self.room, body="Hi!", tok=self.other_access_token)
self.helper.send(self.room, body="There!", tok=self.other_access_token)
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
"/search?access_token=%s" % (self.access_token,),
{
@@ -984,7 +1025,7 @@ class RoomSearchTestCase(unittest.HomeserverTestCase):
self.helper.send(self.room, body="Hi!", tok=self.other_access_token)
self.helper.send(self.room, body="There!", tok=self.other_access_token)
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
"/search?access_token=%s" % (self.access_token,),
{
@@ -1032,14 +1073,14 @@ class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase):
return self.hs
def test_restricted_no_auth(self):
- request, channel = self.make_request("GET", self.url)
+ channel = self.make_request("GET", self.url)
self.assertEqual(channel.code, 401, channel.result)
def test_restricted_auth(self):
self.register_user("user", "pass")
tok = self.login("user", "pass")
- request, channel = self.make_request("GET", self.url, access_token=tok)
+ channel = self.make_request("GET", self.url, access_token=tok)
self.assertEqual(channel.code, 200, channel.result)
@@ -1067,7 +1108,7 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase):
self.displayname = "test user"
data = {"displayname": self.displayname}
request_data = json.dumps(data)
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/_matrix/client/r0/profile/%s/displayname" % (self.user_id,),
request_data,
@@ -1080,7 +1121,7 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase):
def test_per_room_profile_forbidden(self):
data = {"membership": "join", "displayname": "other test user"}
request_data = json.dumps(data)
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/_matrix/client/r0/rooms/%s/state/m.room.member/%s"
% (self.room_id, self.user_id),
@@ -1090,7 +1131,7 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200, channel.result)
event_id = channel.json_body["event_id"]
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id),
access_token=self.tok,
@@ -1123,7 +1164,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
def test_join_reason(self):
reason = "hello"
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
"/_matrix/client/r0/rooms/{}/join".format(self.room_id),
content={"reason": reason},
@@ -1137,7 +1178,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok)
reason = "hello"
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
"/_matrix/client/r0/rooms/{}/leave".format(self.room_id),
content={"reason": reason},
@@ -1151,7 +1192,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok)
reason = "hello"
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
"/_matrix/client/r0/rooms/{}/kick".format(self.room_id),
content={"reason": reason, "user_id": self.second_user_id},
@@ -1165,7 +1206,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok)
reason = "hello"
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
"/_matrix/client/r0/rooms/{}/ban".format(self.room_id),
content={"reason": reason, "user_id": self.second_user_id},
@@ -1177,7 +1218,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
def test_unban_reason(self):
reason = "hello"
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
"/_matrix/client/r0/rooms/{}/unban".format(self.room_id),
content={"reason": reason, "user_id": self.second_user_id},
@@ -1189,7 +1230,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
def test_invite_reason(self):
reason = "hello"
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
"/_matrix/client/r0/rooms/{}/invite".format(self.room_id),
content={"reason": reason, "user_id": self.second_user_id},
@@ -1208,7 +1249,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
)
reason = "hello"
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
"/_matrix/client/r0/rooms/{}/leave".format(self.room_id),
content={"reason": reason},
@@ -1219,7 +1260,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
self._check_for_reason(reason)
def _check_for_reason(self, reason):
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/_matrix/client/r0/rooms/{}/state/m.room.member/{}".format(
self.room_id, self.second_user_id
@@ -1268,7 +1309,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
"""Test that we can filter by a label on a /context request."""
event_id = self._send_labelled_messages_in_room()
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/rooms/%s/context/%s?filter=%s"
% (self.room_id, event_id, json.dumps(self.FILTER_LABELS)),
@@ -1298,7 +1339,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
"""Test that we can filter by the absence of a label on a /context request."""
event_id = self._send_labelled_messages_in_room()
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/rooms/%s/context/%s?filter=%s"
% (self.room_id, event_id, json.dumps(self.FILTER_NOT_LABELS)),
@@ -1333,7 +1374,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
"""
event_id = self._send_labelled_messages_in_room()
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/rooms/%s/context/%s?filter=%s"
% (self.room_id, event_id, json.dumps(self.FILTER_LABELS_NOT_LABELS)),
@@ -1361,7 +1402,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
self._send_labelled_messages_in_room()
token = "s0_0_0_0_0_0_0_0_0"
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/rooms/%s/messages?access_token=%s&from=%s&filter=%s"
% (self.room_id, self.tok, token, json.dumps(self.FILTER_LABELS)),
@@ -1378,7 +1419,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
self._send_labelled_messages_in_room()
token = "s0_0_0_0_0_0_0_0_0"
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/rooms/%s/messages?access_token=%s&from=%s&filter=%s"
% (self.room_id, self.tok, token, json.dumps(self.FILTER_NOT_LABELS)),
@@ -1401,7 +1442,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
self._send_labelled_messages_in_room()
token = "s0_0_0_0_0_0_0_0_0"
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/rooms/%s/messages?access_token=%s&from=%s&filter=%s"
% (
@@ -1432,7 +1473,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
self._send_labelled_messages_in_room()
- request, channel = self.make_request(
+ channel = self.make_request(
"POST", "/search?access_token=%s" % self.tok, request_data
)
@@ -1467,7 +1508,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
self._send_labelled_messages_in_room()
- request, channel = self.make_request(
+ channel = self.make_request(
"POST", "/search?access_token=%s" % self.tok, request_data
)
@@ -1514,7 +1555,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
self._send_labelled_messages_in_room()
- request, channel = self.make_request(
+ channel = self.make_request(
"POST", "/search?access_token=%s" % self.tok, request_data
)
@@ -1635,7 +1676,7 @@ class ContextTestCase(unittest.HomeserverTestCase):
# Check that we can still see the messages before the erasure request.
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
'/rooms/%s/context/%s?filter={"types":["m.room.message"]}'
% (self.room_id, event_id),
@@ -1681,7 +1722,9 @@ class ContextTestCase(unittest.HomeserverTestCase):
deactivate_account_handler = self.hs.get_deactivate_account_handler()
self.get_success(
- deactivate_account_handler.deactivate_account(self.user_id, erase_data=True)
+ deactivate_account_handler.deactivate_account(
+ self.user_id, True, create_requester(self.user_id)
+ )
)
# Invite another user in the room. This is needed because messages will be
@@ -1699,7 +1742,7 @@ class ContextTestCase(unittest.HomeserverTestCase):
# Check that a user that joined the room after the erasure request can't see
# the messages anymore.
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
'/rooms/%s/context/%s?filter={"types":["m.room.message"]}'
% (self.room_id, event_id),
@@ -1789,7 +1832,7 @@ class RoomAliasListTestCase(unittest.HomeserverTestCase):
def _get_aliases(self, access_token: str, expected_code: int = 200) -> JsonDict:
"""Calls the endpoint under test. returns the json response object."""
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/_matrix/client/unstable/org.matrix.msc2432/rooms/%s/aliases"
% (self.room_id,),
@@ -1810,7 +1853,7 @@ class RoomAliasListTestCase(unittest.HomeserverTestCase):
data = {"room_id": self.room_id}
request_data = json.dumps(data)
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT", url, request_data, access_token=self.room_owner_tok
)
self.assertEqual(channel.code, expected_code, channel.result)
@@ -1840,14 +1883,14 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
data = {"room_id": self.room_id}
request_data = json.dumps(data)
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT", url, request_data, access_token=self.room_owner_tok
)
self.assertEqual(channel.code, expected_code, channel.result)
def _get_canonical_alias(self, expected_code: int = 200) -> JsonDict:
"""Calls the endpoint under test. returns the json response object."""
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"rooms/%s/state/m.room.canonical_alias" % (self.room_id,),
access_token=self.room_owner_tok,
@@ -1859,7 +1902,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
def _set_canonical_alias(self, content: str, expected_code: int = 200) -> JsonDict:
"""Calls the endpoint under test. returns the json response object."""
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"rooms/%s/state/m.room.canonical_alias" % (self.room_id,),
json.dumps(content),
diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py
index bbd30f594b..38c51525a3 100644
--- a/tests/rest/client/v1/test_typing.py
+++ b/tests/rest/client/v1/test_typing.py
@@ -39,7 +39,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver(
- "red", http_client=None, federation_client=Mock(),
+ "red", federation_http_client=None, federation_client=Mock(),
)
self.event_source = hs.get_event_sources().sources["typing"]
@@ -94,7 +94,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
self.helper.join(self.room_id, user="@jim:red")
def test_set_typing(self):
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/rooms/%s/typing/%s" % (self.room_id, self.user_id),
b'{"typing": true, "timeout": 30000}',
@@ -117,7 +117,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
)
def test_set_not_typing(self):
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/rooms/%s/typing/%s" % (self.room_id, self.user_id),
b'{"typing": false}',
@@ -125,7 +125,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
self.assertEquals(200, channel.code)
def test_typing_timeout(self):
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/rooms/%s/typing/%s" % (self.room_id, self.user_id),
b'{"typing": true, "timeout": 30000}',
@@ -138,7 +138,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
self.assertEquals(self.event_source.get_current_key(), 2)
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/rooms/%s/typing/%s" % (self.room_id, self.user_id),
b'{"typing": true, "timeout": 30000}',
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index 737c38c396..b1333df82d 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -2,7 +2,7 @@
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2017 Vector Creations Ltd
# Copyright 2018-2019 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,8 +17,12 @@
# limitations under the License.
import json
+import re
import time
-from typing import Any, Dict, Optional
+import urllib.parse
+from typing import Any, Dict, Mapping, MutableMapping, Optional
+
+from mock import patch
import attr
@@ -26,8 +30,11 @@ from twisted.web.resource import Resource
from twisted.web.server import Site
from synapse.api.constants import Membership
+from synapse.types import JsonDict
-from tests.server import FakeSite, make_request
+from tests.server import FakeChannel, FakeSite, make_request
+from tests.test_utils import FakeResponse
+from tests.test_utils.html_parsers import TestHtmlParser
@attr.s
@@ -75,7 +82,7 @@ class RestHelper:
if tok:
path = path + "?access_token=%s" % tok
- _, channel = make_request(
+ channel = make_request(
self.hs.get_reactor(),
self.site,
"POST",
@@ -151,7 +158,7 @@ class RestHelper:
data = {"membership": membership}
data.update(extra_data)
- _, channel = make_request(
+ channel = make_request(
self.hs.get_reactor(),
self.site,
"PUT",
@@ -186,7 +193,7 @@ class RestHelper:
if tok:
path = path + "?access_token=%s" % tok
- _, channel = make_request(
+ channel = make_request(
self.hs.get_reactor(),
self.site,
"PUT",
@@ -242,9 +249,7 @@ class RestHelper:
if body is not None:
content = json.dumps(body).encode("utf8")
- _, channel = make_request(
- self.hs.get_reactor(), self.site, method, path, content
- )
+ channel = make_request(self.hs.get_reactor(), self.site, method, path, content)
assert int(channel.result["code"]) == expect_code, (
"Expected: %d, got: %d, resp: %r"
@@ -327,7 +332,7 @@ class RestHelper:
"""
image_length = len(image_data)
path = "/_matrix/media/r0/upload?filename=%s" % (filename,)
- _, channel = make_request(
+ channel = make_request(
self.hs.get_reactor(),
FakeSite(resource),
"POST",
@@ -344,3 +349,246 @@ class RestHelper:
)
return channel.json_body
+
+ def login_via_oidc(self, remote_user_id: str) -> JsonDict:
+ """Log in (as a new user) via OIDC
+
+ Returns the result of the final token login.
+
+ Requires that "oidc_config" in the homeserver config be set appropriately
+ (TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a
+ "public_base_url".
+
+ Also requires the login servlet and the OIDC callback resource to be mounted at
+ the normal places.
+ """
+ client_redirect_url = "https://x"
+ channel = self.auth_via_oidc({"sub": remote_user_id}, client_redirect_url)
+
+ # expect a confirmation page
+ assert channel.code == 200, channel.result
+
+ # fish the matrix login token out of the body of the confirmation page
+ m = re.search(
+ 'a href="%s.*loginToken=([^"]*)"' % (client_redirect_url,),
+ channel.text_body,
+ )
+ assert m, channel.text_body
+ login_token = m.group(1)
+
+ # finally, submit the matrix login token to the login API, which gives us our
+ # matrix access token and device id.
+ channel = make_request(
+ self.hs.get_reactor(),
+ self.site,
+ "POST",
+ "/login",
+ content={"type": "m.login.token", "token": login_token},
+ )
+ assert channel.code == 200
+ return channel.json_body
+
+ def auth_via_oidc(
+ self,
+ user_info_dict: JsonDict,
+ client_redirect_url: Optional[str] = None,
+ ui_auth_session_id: Optional[str] = None,
+ ) -> FakeChannel:
+ """Perform an OIDC authentication flow via a mock OIDC provider.
+
+ This can be used for either login or user-interactive auth.
+
+ Starts by making a request to the relevant synapse redirect endpoint, which is
+ expected to serve a 302 to the OIDC provider. We then make a request to the
+ OIDC callback endpoint, intercepting the HTTP requests that will get sent back
+ to the OIDC provider.
+
+ Requires that "oidc_config" in the homeserver config be set appropriately
+ (TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a
+ "public_base_url".
+
+ Also requires the login servlet and the OIDC callback resource to be mounted at
+ the normal places.
+
+ Args:
+ user_info_dict: the remote userinfo that the OIDC provider should present.
+ Typically this should be '{"sub": "<remote user id>"}'.
+ client_redirect_url: for a login flow, the client redirect URL to pass to
+ the login redirect endpoint
+ ui_auth_session_id: if set, we will perform a UI Auth flow. The session id
+ of the UI auth.
+
+ Returns:
+ A FakeChannel containing the result of calling the OIDC callback endpoint.
+ Note that the response code may be a 200, 302 or 400 depending on how things
+ went.
+ """
+
+ cookies = {}
+
+ # if we're doing a ui auth, hit the ui auth redirect endpoint
+ if ui_auth_session_id:
+ # can't set the client redirect url for UI Auth
+ assert client_redirect_url is None
+ oauth_uri = self.initiate_sso_ui_auth(ui_auth_session_id, cookies)
+ else:
+ # otherwise, hit the login redirect endpoint
+ oauth_uri = self.initiate_sso_login(client_redirect_url, cookies)
+
+ # we now have a URI for the OIDC IdP, but we skip that and go straight
+ # back to synapse's OIDC callback resource. However, we do need the "state"
+ # param that synapse passes to the IdP via query params, as well as the cookie
+ # that synapse passes to the client.
+
+ oauth_uri_path, _ = oauth_uri.split("?", 1)
+ assert oauth_uri_path == TEST_OIDC_AUTH_ENDPOINT, (
+ "unexpected SSO URI " + oauth_uri_path
+ )
+ return self.complete_oidc_auth(oauth_uri, cookies, user_info_dict)
+
+ def complete_oidc_auth(
+ self, oauth_uri: str, cookies: Mapping[str, str], user_info_dict: JsonDict,
+ ) -> FakeChannel:
+ """Mock out an OIDC authentication flow
+
+ Assumes that an OIDC auth has been initiated by one of initiate_sso_login or
+ initiate_sso_ui_auth; completes the OIDC bits of the flow by making a request to
+ Synapse's OIDC callback endpoint, intercepting the HTTP requests that will get
+ sent back to the OIDC provider.
+
+ Requires the OIDC callback resource to be mounted at the normal place.
+
+ Args:
+ oauth_uri: the OIDC URI returned by synapse's redirect endpoint (ie,
+ from initiate_sso_login or initiate_sso_ui_auth).
+ cookies: the cookies set by synapse's redirect endpoint, which will be
+ sent back to the callback endpoint.
+ user_info_dict: the remote userinfo that the OIDC provider should present.
+ Typically this should be '{"sub": "<remote user id>"}'.
+
+ Returns:
+ A FakeChannel containing the result of calling the OIDC callback endpoint.
+ """
+ _, oauth_uri_qs = oauth_uri.split("?", 1)
+ params = urllib.parse.parse_qs(oauth_uri_qs)
+ callback_uri = "%s?%s" % (
+ urllib.parse.urlparse(params["redirect_uri"][0]).path,
+ urllib.parse.urlencode({"state": params["state"][0], "code": "TEST_CODE"}),
+ )
+
+ # before we hit the callback uri, stub out some methods in the http client so
+ # that we don't have to handle full HTTPS requests.
+ # (expected url, json response) pairs, in the order we expect them.
+ expected_requests = [
+ # first we get a hit to the token endpoint, which we tell to return
+ # a dummy OIDC access token
+ (TEST_OIDC_TOKEN_ENDPOINT, {"access_token": "TEST"}),
+ # and then one to the user_info endpoint, which returns our remote user id.
+ (TEST_OIDC_USERINFO_ENDPOINT, user_info_dict),
+ ]
+
+ async def mock_req(method: str, uri: str, data=None, headers=None):
+ (expected_uri, resp_obj) = expected_requests.pop(0)
+ assert uri == expected_uri
+ resp = FakeResponse(
+ code=200, phrase=b"OK", body=json.dumps(resp_obj).encode("utf-8"),
+ )
+ return resp
+
+ with patch.object(self.hs.get_proxied_http_client(), "request", mock_req):
+ # now hit the callback URI with the right params and a made-up code
+ channel = make_request(
+ self.hs.get_reactor(),
+ self.site,
+ "GET",
+ callback_uri,
+ custom_headers=[
+ ("Cookie", "%s=%s" % (k, v)) for (k, v) in cookies.items()
+ ],
+ )
+ return channel
+
+ def initiate_sso_login(
+ self, client_redirect_url: Optional[str], cookies: MutableMapping[str, str]
+ ) -> str:
+ """Make a request to the login-via-sso redirect endpoint, and return the target
+
+ Assumes that exactly one SSO provider has been configured. Requires the login
+ servlet to be mounted.
+
+ Args:
+ client_redirect_url: the client redirect URL to pass to the login redirect
+ endpoint
+ cookies: any cookies returned will be added to this dict
+
+ Returns:
+ the URI that the client gets redirected to (ie, the SSO server)
+ """
+ params = {}
+ if client_redirect_url:
+ params["redirectUrl"] = client_redirect_url
+
+ # hit the redirect url (which will issue a cookie and state)
+ channel = make_request(
+ self.hs.get_reactor(),
+ self.site,
+ "GET",
+ "/_matrix/client/r0/login/sso/redirect?" + urllib.parse.urlencode(params),
+ )
+
+ assert channel.code == 302
+ channel.extract_cookies(cookies)
+ return channel.headers.getRawHeaders("Location")[0]
+
+ def initiate_sso_ui_auth(
+ self, ui_auth_session_id: str, cookies: MutableMapping[str, str]
+ ) -> str:
+ """Make a request to the ui-auth-via-sso endpoint, and return the target
+
+ Assumes that exactly one SSO provider has been configured. Requires the
+ AuthRestServlet to be mounted.
+
+ Args:
+ ui_auth_session_id: the session id of the UI auth
+ cookies: any cookies returned will be added to this dict
+
+ Returns:
+ the URI that the client gets linked to (ie, the SSO server)
+ """
+ sso_redirect_endpoint = (
+ "/_matrix/client/r0/auth/m.login.sso/fallback/web?"
+ + urllib.parse.urlencode({"session": ui_auth_session_id})
+ )
+ # hit the redirect url (which will issue a cookie and state)
+ channel = make_request(
+ self.hs.get_reactor(), self.site, "GET", sso_redirect_endpoint
+ )
+ # that should serve a confirmation page
+ assert channel.code == 200, channel.text_body
+ channel.extract_cookies(cookies)
+
+ # parse the confirmation page to fish out the link.
+ p = TestHtmlParser()
+ p.feed(channel.text_body)
+ p.close()
+ assert len(p.links) == 1, "not exactly one link in confirmation page"
+ oauth_uri = p.links[0]
+ return oauth_uri
+
+
+# an 'oidc_config' suitable for login_via_oidc.
+TEST_OIDC_AUTH_ENDPOINT = "https://issuer.test/auth"
+TEST_OIDC_TOKEN_ENDPOINT = "https://issuer.test/token"
+TEST_OIDC_USERINFO_ENDPOINT = "https://issuer.test/userinfo"
+TEST_OIDC_CONFIG = {
+ "enabled": True,
+ "discover": False,
+ "issuer": "https://issuer.test",
+ "client_id": "test-client-id",
+ "client_secret": "test-client-secret",
+ "scopes": ["profile"],
+ "authorization_endpoint": TEST_OIDC_AUTH_ENDPOINT,
+ "token_endpoint": TEST_OIDC_TOKEN_ENDPOINT,
+ "userinfo_endpoint": TEST_OIDC_USERINFO_ENDPOINT,
+ "user_mapping_provider": {"config": {"localpart_template": "{{ user.sub }}"}},
+}
diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py
index 2ac1ecb7d3..177dc476da 100644
--- a/tests/rest/client/v2_alpha/test_account.py
+++ b/tests/rest/client/v2_alpha/test_account.py
@@ -19,13 +19,12 @@ import os
import re
from email.parser import Parser
from typing import Optional
-from urllib.parse import urlencode
import pkg_resources
import synapse.rest.admin
from synapse.api.constants import LoginType, Membership
-from synapse.api.errors import Codes
+from synapse.api.errors import Codes, HttpResponseException
from synapse.rest.client.v1 import login, room
from synapse.rest.client.v2_alpha import account, register
from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource
@@ -113,6 +112,56 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
# Assert we can't log in with the old password
self.attempt_wrong_password_login("kermit", old_password)
+ @override_config({"rc_3pid_validation": {"burst_count": 3}})
+ def test_ratelimit_by_email(self):
+ """Test that we ratelimit /requestToken for the same email.
+ """
+ old_password = "monkey"
+ new_password = "kangeroo"
+
+ user_id = self.register_user("kermit", old_password)
+ self.login("kermit", old_password)
+
+ email = "test1@example.com"
+
+ # Add a threepid
+ self.get_success(
+ self.store.user_add_threepid(
+ user_id=user_id,
+ medium="email",
+ address=email,
+ validated_at=0,
+ added_at=0,
+ )
+ )
+
+ def reset(ip):
+ client_secret = "foobar"
+ session_id = self._request_token(email, client_secret, ip)
+
+ self.assertEquals(len(self.email_attempts), 1)
+ link = self._get_link_from_email()
+
+ self._validate_token(link)
+
+ self._reset_password(new_password, session_id, client_secret)
+
+ self.email_attempts.clear()
+
+ # We expect to be able to make three requests before getting rate
+ # limited.
+ #
+ # We change IPs to ensure that we're not being ratelimited due to the
+ # same IP
+ reset("127.0.0.1")
+ reset("127.0.0.2")
+ reset("127.0.0.3")
+
+ with self.assertRaises(HttpResponseException) as cm:
+ reset("127.0.0.4")
+
+ self.assertEqual(cm.exception.code, 429)
+
def test_basic_password_reset_canonicalise_email(self):
"""Test basic password reset flow
Request password reset with different spelling
@@ -240,13 +289,18 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
self.assertIsNotNone(session_id)
- def _request_token(self, email, client_secret):
- request, channel = self.make_request(
+ def _request_token(self, email, client_secret, ip="127.0.0.1"):
+ channel = self.make_request(
"POST",
b"account/password/email/requestToken",
{"client_secret": client_secret, "email": email, "send_attempt": 1},
+ client_ip=ip,
)
- self.assertEquals(200, channel.code, channel.result)
+
+ if channel.code != 200:
+ raise HttpResponseException(
+ channel.code, channel.result["reason"], channel.result["body"],
+ )
return channel.json_body["sid"]
@@ -255,7 +309,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
path = link.replace("https://example.com", "")
# Load the password reset confirmation page
- request, channel = make_request(
+ channel = make_request(
self.reactor,
FakeSite(self.submit_token_resource),
"GET",
@@ -268,20 +322,13 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
# Now POST to the same endpoint, mimicking the same behaviour as clicking the
# password reset confirm button
- # Send arguments as url-encoded form data, matching the template's behaviour
- form_args = []
- for key, value_list in request.args.items():
- for value in value_list:
- arg = (key, value)
- form_args.append(arg)
-
# Confirm the password reset
- request, channel = make_request(
+ channel = make_request(
self.reactor,
FakeSite(self.submit_token_resource),
"POST",
path,
- content=urlencode(form_args).encode("utf8"),
+ content=b"",
shorthand=False,
content_is_form=True,
)
@@ -310,7 +357,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
def _reset_password(
self, new_password, session_id, client_secret, expected_code=200
):
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
b"account/password",
{
@@ -352,8 +399,8 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
self.assertTrue(self.get_success(store.get_user_deactivated_status(user_id)))
# Check that this access token has been invalidated.
- request, channel = self.make_request("GET", "account/whoami")
- self.assertEqual(request.code, 401)
+ channel = self.make_request("GET", "account/whoami")
+ self.assertEqual(channel.code, 401)
def test_pending_invites(self):
"""Tests that deactivating a user rejects every pending invite for them."""
@@ -407,10 +454,10 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
"erase": False,
}
)
- request, channel = self.make_request(
+ channel = self.make_request(
"POST", "account/deactivate", request_data, access_token=tok
)
- self.assertEqual(request.code, 200)
+ self.assertEqual(channel.code, 200)
class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
@@ -517,6 +564,21 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
def test_address_trim(self):
self.get_success(self._add_email(" foo@test.bar ", "foo@test.bar"))
+ @override_config({"rc_3pid_validation": {"burst_count": 3}})
+ def test_ratelimit_by_ip(self):
+ """Tests that adding emails is ratelimited by IP
+ """
+
+ # We expect to be able to set three emails before getting ratelimited.
+ self.get_success(self._add_email("foo1@test.bar", "foo1@test.bar"))
+ self.get_success(self._add_email("foo2@test.bar", "foo2@test.bar"))
+ self.get_success(self._add_email("foo3@test.bar", "foo3@test.bar"))
+
+ with self.assertRaises(HttpResponseException) as cm:
+ self.get_success(self._add_email("foo4@test.bar", "foo4@test.bar"))
+
+ self.assertEqual(cm.exception.code, 429)
+
def test_add_email_if_disabled(self):
"""Test adding email to profile when doing so is disallowed
"""
@@ -530,7 +592,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
self._validate_token(link)
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
b"/_matrix/client/unstable/account/3pid/add",
{
@@ -548,7 +610,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
# Get user
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", self.url_3pid, access_token=self.user_id_tok,
)
@@ -569,7 +631,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
)
)
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
b"account/3pid/delete",
{"medium": "email", "address": self.email},
@@ -578,7 +640,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
# Get user
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", self.url_3pid, access_token=self.user_id_tok,
)
@@ -601,7 +663,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
)
)
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
b"account/3pid/delete",
{"medium": "email", "address": self.email},
@@ -612,7 +674,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
# Get user
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", self.url_3pid, access_token=self.user_id_tok,
)
@@ -629,7 +691,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
self.assertEquals(len(self.email_attempts), 1)
# Attempt to add email without clicking the link
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
b"/_matrix/client/unstable/account/3pid/add",
{
@@ -647,7 +709,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"])
# Get user
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", self.url_3pid, access_token=self.user_id_tok,
)
@@ -662,7 +724,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
session_id = "weasle"
# Attempt to add email without even requesting an email
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
b"/_matrix/client/unstable/account/3pid/add",
{
@@ -680,7 +742,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"])
# Get user
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", self.url_3pid, access_token=self.user_id_tok,
)
@@ -784,17 +846,19 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
if next_link:
body["next_link"] = next_link
- request, channel = self.make_request(
- "POST", b"account/3pid/email/requestToken", body,
- )
- self.assertEquals(expect_code, channel.code, channel.result)
+ channel = self.make_request("POST", b"account/3pid/email/requestToken", body,)
+
+ if channel.code != expect_code:
+ raise HttpResponseException(
+ channel.code, channel.result["reason"], channel.result["body"],
+ )
return channel.json_body.get("sid")
def _request_token_invalid_email(
self, email, expected_errcode, expected_error, client_secret="foobar",
):
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
b"account/3pid/email/requestToken",
{"client_secret": client_secret, "email": email, "send_attempt": 1},
@@ -807,7 +871,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
# Remove the host
path = link.replace("https://example.com", "")
- request, channel = self.make_request("GET", path, shorthand=False)
+ channel = self.make_request("GET", path, shorthand=False)
self.assertEquals(200, channel.code, channel.result)
def _get_link_from_email(self):
@@ -833,15 +897,17 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
def _add_email(self, request_email, expected_email):
"""Test adding an email to profile
"""
+ previous_email_attempts = len(self.email_attempts)
+
client_secret = "foobar"
session_id = self._request_token(request_email, client_secret)
- self.assertEquals(len(self.email_attempts), 1)
+ self.assertEquals(len(self.email_attempts) - previous_email_attempts, 1)
link = self._get_link_from_email()
self._validate_token(link)
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
b"/_matrix/client/unstable/account/3pid/add",
{
@@ -859,10 +925,12 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
# Get user
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", self.url_3pid, access_token=self.user_id_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
- self.assertEqual(expected_email, channel.json_body["threepids"][0]["address"])
+
+ threepids = {threepid["address"] for threepid in channel.json_body["threepids"]}
+ self.assertIn(expected_email, threepids)
diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py
index 77246e478f..3f50c56745 100644
--- a/tests/rest/client/v2_alpha/test_auth.py
+++ b/tests/rest/client/v2_alpha/test_auth.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector
+# Copyright 2020-2021 The Matrix.org Foundation C.I.C
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,20 +13,23 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import List, Union
+from typing import Union
from twisted.internet.defer import succeed
import synapse.rest.admin
from synapse.api.constants import LoginType
from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
-from synapse.http.site import SynapseRequest
from synapse.rest.client.v1 import login
from synapse.rest.client.v2_alpha import auth, devices, register
-from synapse.types import JsonDict
+from synapse.rest.synapse.client import build_synapse_client_resource_tree
+from synapse.types import JsonDict, UserID
from tests import unittest
+from tests.handlers.test_oidc import HAS_OIDC
+from tests.rest.client.v1.utils import TEST_OIDC_CONFIG
from tests.server import FakeChannel
+from tests.unittest import override_config, skip_unless
class DummyRecaptchaChecker(UserInteractiveAuthChecker):
@@ -64,11 +68,9 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
def register(self, expected_response: int, body: JsonDict) -> FakeChannel:
"""Make a register request."""
- request, channel = self.make_request(
- "POST", "register", body
- ) # type: SynapseRequest, FakeChannel
+ channel = self.make_request("POST", "register", body)
- self.assertEqual(request.code, expected_response)
+ self.assertEqual(channel.code, expected_response)
return channel
def recaptcha(
@@ -78,18 +80,18 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
if post_session is None:
post_session = session
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "auth/m.login.recaptcha/fallback/web?session=" + session
- ) # type: SynapseRequest, FakeChannel
- self.assertEqual(request.code, 200)
+ )
+ self.assertEqual(channel.code, 200)
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
"auth/m.login.recaptcha/fallback/web?session="
+ post_session
+ "&g-recaptcha-response=a",
)
- self.assertEqual(request.code, expected_post_response)
+ self.assertEqual(channel.code, expected_post_response)
# The recaptcha handler is called with the response given
attempts = self.recaptcha_checker.recaptcha_attempts
@@ -156,31 +158,44 @@ class UIAuthTests(unittest.HomeserverTestCase):
register.register_servlets,
]
+ def default_config(self):
+ config = super().default_config()
+ config["public_baseurl"] = "https://synapse.test"
+
+ if HAS_OIDC:
+ # we enable OIDC as a way of testing SSO flows
+ oidc_config = {}
+ oidc_config.update(TEST_OIDC_CONFIG)
+ oidc_config["allow_existing_users"] = True
+ config["oidc_config"] = oidc_config
+
+ return config
+
+ def create_resource_dict(self):
+ resource_dict = super().create_resource_dict()
+ resource_dict.update(build_synapse_client_resource_tree(self.hs))
+ return resource_dict
+
def prepare(self, reactor, clock, hs):
self.user_pass = "pass"
self.user = self.register_user("test", self.user_pass)
- self.user_tok = self.login("test", self.user_pass)
-
- def get_device_ids(self) -> List[str]:
- # Get the list of devices so one can be deleted.
- request, channel = self.make_request(
- "GET", "devices", access_token=self.user_tok,
- ) # type: SynapseRequest, FakeChannel
-
- # Get the ID of the device.
- self.assertEqual(request.code, 200)
- return [d["device_id"] for d in channel.json_body["devices"]]
+ self.device_id = "dev1"
+ self.user_tok = self.login("test", self.user_pass, self.device_id)
def delete_device(
- self, device: str, expected_response: int, body: Union[bytes, JsonDict] = b""
+ self,
+ access_token: str,
+ device: str,
+ expected_response: int,
+ body: Union[bytes, JsonDict] = b"",
) -> FakeChannel:
"""Delete an individual device."""
- request, channel = self.make_request(
- "DELETE", "devices/" + device, body, access_token=self.user_tok
- ) # type: SynapseRequest, FakeChannel
+ channel = self.make_request(
+ "DELETE", "devices/" + device, body, access_token=access_token,
+ )
# Ensure the response is sane.
- self.assertEqual(request.code, expected_response)
+ self.assertEqual(channel.code, expected_response)
return channel
@@ -188,12 +203,12 @@ class UIAuthTests(unittest.HomeserverTestCase):
"""Delete 1 or more devices."""
# Note that this uses the delete_devices endpoint so that we can modify
# the payload half-way through some tests.
- request, channel = self.make_request(
+ channel = self.make_request(
"POST", "delete_devices", body, access_token=self.user_tok,
- ) # type: SynapseRequest, FakeChannel
+ )
# Ensure the response is sane.
- self.assertEqual(request.code, expected_response)
+ self.assertEqual(channel.code, expected_response)
return channel
@@ -201,11 +216,9 @@ class UIAuthTests(unittest.HomeserverTestCase):
"""
Test user interactive authentication outside of registration.
"""
- device_id = self.get_device_ids()[0]
-
# Attempt to delete this device.
# Returns a 401 as per the spec
- channel = self.delete_device(device_id, 401)
+ channel = self.delete_device(self.user_tok, self.device_id, 401)
# Grab the session
session = channel.json_body["session"]
@@ -214,7 +227,8 @@ class UIAuthTests(unittest.HomeserverTestCase):
# Make another request providing the UI auth flow.
self.delete_device(
- device_id,
+ self.user_tok,
+ self.device_id,
200,
{
"auth": {
@@ -233,13 +247,13 @@ class UIAuthTests(unittest.HomeserverTestCase):
UIA - check that still works.
"""
- device_id = self.get_device_ids()[0]
- channel = self.delete_device(device_id, 401)
+ channel = self.delete_device(self.user_tok, self.device_id, 401)
session = channel.json_body["session"]
# Make another request providing the UI auth flow.
self.delete_device(
- device_id,
+ self.user_tok,
+ self.device_id,
200,
{
"auth": {
@@ -262,14 +276,11 @@ class UIAuthTests(unittest.HomeserverTestCase):
session ID should be rejected.
"""
# Create a second login.
- self.login("test", self.user_pass)
-
- device_ids = self.get_device_ids()
- self.assertEqual(len(device_ids), 2)
+ self.login("test", self.user_pass, "dev2")
# Attempt to delete the first device.
# Returns a 401 as per the spec
- channel = self.delete_devices(401, {"devices": [device_ids[0]]})
+ channel = self.delete_devices(401, {"devices": [self.device_id]})
# Grab the session
session = channel.json_body["session"]
@@ -281,7 +292,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
self.delete_devices(
200,
{
- "devices": [device_ids[1]],
+ "devices": ["dev2"],
"auth": {
"type": "m.login.password",
"identifier": {"type": "m.id.user", "user": self.user},
@@ -296,14 +307,11 @@ class UIAuthTests(unittest.HomeserverTestCase):
The initial requested URI cannot be modified during the user interactive authentication session.
"""
# Create a second login.
- self.login("test", self.user_pass)
-
- device_ids = self.get_device_ids()
- self.assertEqual(len(device_ids), 2)
+ self.login("test", self.user_pass, "dev2")
# Attempt to delete the first device.
# Returns a 401 as per the spec
- channel = self.delete_device(device_ids[0], 401)
+ channel = self.delete_device(self.user_tok, self.device_id, 401)
# Grab the session
session = channel.json_body["session"]
@@ -312,8 +320,11 @@ class UIAuthTests(unittest.HomeserverTestCase):
# Make another request providing the UI auth flow, but try to delete the
# second device. This results in an error.
+ #
+ # This makes use of the fact that the device ID is embedded into the URL.
self.delete_device(
- device_ids[1],
+ self.user_tok,
+ "dev2",
403,
{
"auth": {
@@ -324,3 +335,152 @@ class UIAuthTests(unittest.HomeserverTestCase):
},
},
)
+
+ @unittest.override_config({"ui_auth": {"session_timeout": 5 * 1000}})
+ def test_can_reuse_session(self):
+ """
+ The session can be reused if configured.
+
+ Compare to test_cannot_change_uri.
+ """
+ # Create a second and third login.
+ self.login("test", self.user_pass, "dev2")
+ self.login("test", self.user_pass, "dev3")
+
+ # Attempt to delete a device. This works since the user just logged in.
+ self.delete_device(self.user_tok, "dev2", 200)
+
+ # Move the clock forward past the validation timeout.
+ self.reactor.advance(6)
+
+ # Deleting another devices throws the user into UI auth.
+ channel = self.delete_device(self.user_tok, "dev3", 401)
+
+ # Grab the session
+ session = channel.json_body["session"]
+ # Ensure that flows are what is expected.
+ self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"])
+
+ # Make another request providing the UI auth flow.
+ self.delete_device(
+ self.user_tok,
+ "dev3",
+ 200,
+ {
+ "auth": {
+ "type": "m.login.password",
+ "identifier": {"type": "m.id.user", "user": self.user},
+ "password": self.user_pass,
+ "session": session,
+ },
+ },
+ )
+
+ # Make another request, but try to delete the first device. This works
+ # due to re-using the previous session.
+ #
+ # Note that *no auth* information is provided, not even a session iD!
+ self.delete_device(self.user_tok, self.device_id, 200)
+
+ @skip_unless(HAS_OIDC, "requires OIDC")
+ @override_config({"oidc_config": TEST_OIDC_CONFIG})
+ def test_ui_auth_via_sso(self):
+ """Test a successful UI Auth flow via SSO
+
+ This includes:
+ * hitting the UIA SSO redirect endpoint
+ * checking it serves a confirmation page which links to the OIDC provider
+ * calling back to the synapse oidc callback
+ * checking that the original operation succeeds
+ """
+
+ # log the user in
+ remote_user_id = UserID.from_string(self.user).localpart
+ login_resp = self.helper.login_via_oidc(remote_user_id)
+ self.assertEqual(login_resp["user_id"], self.user)
+
+ # initiate a UI Auth process by attempting to delete the device
+ channel = self.delete_device(self.user_tok, self.device_id, 401)
+
+ # check that SSO is offered
+ flows = channel.json_body["flows"]
+ self.assertIn({"stages": ["m.login.sso"]}, flows)
+
+ # run the UIA-via-SSO flow
+ session_id = channel.json_body["session"]
+ channel = self.helper.auth_via_oidc(
+ {"sub": remote_user_id}, ui_auth_session_id=session_id
+ )
+
+ # that should serve a confirmation page
+ self.assertEqual(channel.code, 200, channel.result)
+
+ # and now the delete request should succeed.
+ self.delete_device(
+ self.user_tok, self.device_id, 200, body={"auth": {"session": session_id}},
+ )
+
+ @skip_unless(HAS_OIDC, "requires OIDC")
+ @override_config({"oidc_config": TEST_OIDC_CONFIG})
+ def test_does_not_offer_password_for_sso_user(self):
+ login_resp = self.helper.login_via_oidc("username")
+ user_tok = login_resp["access_token"]
+ device_id = login_resp["device_id"]
+
+ # now call the device deletion API: we should get the option to auth with SSO
+ # and not password.
+ channel = self.delete_device(user_tok, device_id, 401)
+
+ flows = channel.json_body["flows"]
+ self.assertEqual(flows, [{"stages": ["m.login.sso"]}])
+
+ def test_does_not_offer_sso_for_password_user(self):
+ channel = self.delete_device(self.user_tok, self.device_id, 401)
+
+ flows = channel.json_body["flows"]
+ self.assertEqual(flows, [{"stages": ["m.login.password"]}])
+
+ @skip_unless(HAS_OIDC, "requires OIDC")
+ @override_config({"oidc_config": TEST_OIDC_CONFIG})
+ def test_offers_both_flows_for_upgraded_user(self):
+ """A user that had a password and then logged in with SSO should get both flows
+ """
+ login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart)
+ self.assertEqual(login_resp["user_id"], self.user)
+
+ channel = self.delete_device(self.user_tok, self.device_id, 401)
+
+ flows = channel.json_body["flows"]
+ # we have no particular expectations of ordering here
+ self.assertIn({"stages": ["m.login.password"]}, flows)
+ self.assertIn({"stages": ["m.login.sso"]}, flows)
+ self.assertEqual(len(flows), 2)
+
+ @skip_unless(HAS_OIDC, "requires OIDC")
+ @override_config({"oidc_config": TEST_OIDC_CONFIG})
+ def test_ui_auth_fails_for_incorrect_sso_user(self):
+ """If the user tries to authenticate with the wrong SSO user, they get an error
+ """
+ # log the user in
+ login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart)
+ self.assertEqual(login_resp["user_id"], self.user)
+
+ # start a UI Auth flow by attempting to delete a device
+ channel = self.delete_device(self.user_tok, self.device_id, 401)
+
+ flows = channel.json_body["flows"]
+ self.assertIn({"stages": ["m.login.sso"]}, flows)
+ session_id = channel.json_body["session"]
+
+ # do the OIDC auth, but auth as the wrong user
+ channel = self.helper.auth_via_oidc(
+ {"sub": "wrong_user"}, ui_auth_session_id=session_id
+ )
+
+ # that should return a failure message
+ self.assertSubstring("We were unable to validate", channel.text_body)
+
+ # ... and the delete op should now fail with a 403
+ self.delete_device(
+ self.user_tok, self.device_id, 403, body={"auth": {"session": session_id}}
+ )
diff --git a/tests/rest/client/v2_alpha/test_capabilities.py b/tests/rest/client/v2_alpha/test_capabilities.py
index 767e126875..e808339fb3 100644
--- a/tests/rest/client/v2_alpha/test_capabilities.py
+++ b/tests/rest/client/v2_alpha/test_capabilities.py
@@ -36,7 +36,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
return hs
def test_check_auth_required(self):
- request, channel = self.make_request("GET", self.url)
+ channel = self.make_request("GET", self.url)
self.assertEqual(channel.code, 401)
@@ -44,7 +44,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
self.register_user("user", "pass")
access_token = self.login("user", "pass")
- request, channel = self.make_request("GET", self.url, access_token=access_token)
+ channel = self.make_request("GET", self.url, access_token=access_token)
capabilities = channel.json_body["capabilities"]
self.assertEqual(channel.code, 200)
@@ -62,7 +62,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
user = self.register_user(localpart, password)
access_token = self.login(user, password)
- request, channel = self.make_request("GET", self.url, access_token=access_token)
+ channel = self.make_request("GET", self.url, access_token=access_token)
capabilities = channel.json_body["capabilities"]
self.assertEqual(channel.code, 200)
@@ -70,7 +70,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
# Test case where password is handled outside of Synapse
self.assertTrue(capabilities["m.change_password"]["enabled"])
self.get_success(self.store.user_set_password_hash(user, None))
- request, channel = self.make_request("GET", self.url, access_token=access_token)
+ channel = self.make_request("GET", self.url, access_token=access_token)
capabilities = channel.json_body["capabilities"]
self.assertEqual(channel.code, 200)
diff --git a/tests/rest/client/v2_alpha/test_filter.py b/tests/rest/client/v2_alpha/test_filter.py
index 231d5aefea..f761c44936 100644
--- a/tests/rest/client/v2_alpha/test_filter.py
+++ b/tests/rest/client/v2_alpha/test_filter.py
@@ -36,7 +36,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
self.store = hs.get_datastore()
def test_add_filter(self):
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
"/_matrix/client/r0/user/%s/filter" % (self.user_id),
self.EXAMPLE_FILTER_JSON,
@@ -49,7 +49,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
self.assertEquals(filter.result, self.EXAMPLE_FILTER)
def test_add_filter_for_other_user(self):
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
"/_matrix/client/r0/user/%s/filter" % ("@watermelon:test"),
self.EXAMPLE_FILTER_JSON,
@@ -61,7 +61,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
def test_add_filter_non_local_user(self):
_is_mine = self.hs.is_mine
self.hs.is_mine = lambda target_user: False
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
"/_matrix/client/r0/user/%s/filter" % (self.user_id),
self.EXAMPLE_FILTER_JSON,
@@ -79,7 +79,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
)
self.reactor.advance(1)
filter_id = filter_id.result
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/_matrix/client/r0/user/%s/filter/%s" % (self.user_id, filter_id)
)
@@ -87,7 +87,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.json_body, self.EXAMPLE_FILTER)
def test_get_filter_non_existant(self):
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/_matrix/client/r0/user/%s/filter/12382148321" % (self.user_id)
)
@@ -97,7 +97,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
# Currently invalid params do not have an appropriate errcode
# in errors.py
def test_get_filter_invalid_id(self):
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/_matrix/client/r0/user/%s/filter/foobar" % (self.user_id)
)
@@ -105,7 +105,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
# No ID also returns an invalid_id error
def test_get_filter_no_id(self):
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/_matrix/client/r0/user/%s/filter/" % (self.user_id)
)
diff --git a/tests/rest/client/v2_alpha/test_password_policy.py b/tests/rest/client/v2_alpha/test_password_policy.py
index ee86b94917..fba34def30 100644
--- a/tests/rest/client/v2_alpha/test_password_policy.py
+++ b/tests/rest/client/v2_alpha/test_password_policy.py
@@ -70,9 +70,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
def test_get_policy(self):
"""Tests if the /password_policy endpoint returns the configured policy."""
- request, channel = self.make_request(
- "GET", "/_matrix/client/r0/password_policy"
- )
+ channel = self.make_request("GET", "/_matrix/client/r0/password_policy")
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(
@@ -89,7 +87,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
def test_password_too_short(self):
request_data = json.dumps({"username": "kermit", "password": "shorty"})
- request, channel = self.make_request("POST", self.register_url, request_data)
+ channel = self.make_request("POST", self.register_url, request_data)
self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(
@@ -98,7 +96,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
def test_password_no_digit(self):
request_data = json.dumps({"username": "kermit", "password": "longerpassword"})
- request, channel = self.make_request("POST", self.register_url, request_data)
+ channel = self.make_request("POST", self.register_url, request_data)
self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(
@@ -107,7 +105,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
def test_password_no_symbol(self):
request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword"})
- request, channel = self.make_request("POST", self.register_url, request_data)
+ channel = self.make_request("POST", self.register_url, request_data)
self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(
@@ -116,7 +114,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
def test_password_no_uppercase(self):
request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword!"})
- request, channel = self.make_request("POST", self.register_url, request_data)
+ channel = self.make_request("POST", self.register_url, request_data)
self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(
@@ -125,7 +123,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
def test_password_no_lowercase(self):
request_data = json.dumps({"username": "kermit", "password": "L0NGERPASSWORD!"})
- request, channel = self.make_request("POST", self.register_url, request_data)
+ channel = self.make_request("POST", self.register_url, request_data)
self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(
@@ -134,7 +132,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
def test_password_compliant(self):
request_data = json.dumps({"username": "kermit", "password": "L0ngerpassword!"})
- request, channel = self.make_request("POST", self.register_url, request_data)
+ channel = self.make_request("POST", self.register_url, request_data)
# Getting a 401 here means the password has passed validation and the server has
# responded with a list of registration flows.
@@ -160,7 +158,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
},
}
)
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
"/_matrix/client/r0/account/password",
request_data,
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index 8f0c2430e8..27db4f551e 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -61,7 +61,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.hs.get_datastore().services_cache.append(appservice)
request_data = json.dumps({"username": "as_user_kermit"})
- request, channel = self.make_request(
+ channel = self.make_request(
b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
)
@@ -72,7 +72,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
def test_POST_appservice_registration_invalid(self):
self.appservice = None # no application service exists
request_data = json.dumps({"username": "kermit"})
- request, channel = self.make_request(
+ channel = self.make_request(
b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
)
@@ -80,14 +80,14 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
def test_POST_bad_password(self):
request_data = json.dumps({"username": "kermit", "password": 666})
- request, channel = self.make_request(b"POST", self.url, request_data)
+ channel = self.make_request(b"POST", self.url, request_data)
self.assertEquals(channel.result["code"], b"400", channel.result)
self.assertEquals(channel.json_body["error"], "Invalid password")
def test_POST_bad_username(self):
request_data = json.dumps({"username": 777, "password": "monkey"})
- request, channel = self.make_request(b"POST", self.url, request_data)
+ channel = self.make_request(b"POST", self.url, request_data)
self.assertEquals(channel.result["code"], b"400", channel.result)
self.assertEquals(channel.json_body["error"], "Invalid username")
@@ -102,7 +102,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
"auth": {"type": LoginType.DUMMY},
}
request_data = json.dumps(params)
- request, channel = self.make_request(b"POST", self.url, request_data)
+ channel = self.make_request(b"POST", self.url, request_data)
det_data = {
"user_id": user_id,
@@ -117,16 +117,17 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
request_data = json.dumps({"username": "kermit", "password": "monkey"})
self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None)
- request, channel = self.make_request(b"POST", self.url, request_data)
+ channel = self.make_request(b"POST", self.url, request_data)
self.assertEquals(channel.result["code"], b"403", channel.result)
self.assertEquals(channel.json_body["error"], "Registration has been disabled")
+ self.assertEquals(channel.json_body["errcode"], "M_FORBIDDEN")
def test_POST_guest_registration(self):
self.hs.config.macaroon_secret_key = "test"
self.hs.config.allow_guest_access = True
- request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
+ channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
det_data = {"home_server": self.hs.hostname, "device_id": "guest_device"}
self.assertEquals(channel.result["code"], b"200", channel.result)
@@ -135,7 +136,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
def test_POST_disabled_guest_registration(self):
self.hs.config.allow_guest_access = False
- request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
+ channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
self.assertEquals(channel.result["code"], b"403", channel.result)
self.assertEquals(channel.json_body["error"], "Guest access is disabled")
@@ -144,7 +145,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
def test_POST_ratelimiting_guest(self):
for i in range(0, 6):
url = self.url + b"?kind=guest"
- request, channel = self.make_request(b"POST", url, b"{}")
+ channel = self.make_request(b"POST", url, b"{}")
if i == 5:
self.assertEquals(channel.result["code"], b"429", channel.result)
@@ -154,7 +155,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
- request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
+ channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
self.assertEquals(channel.result["code"], b"200", channel.result)
@@ -168,7 +169,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
"auth": {"type": LoginType.DUMMY},
}
request_data = json.dumps(params)
- request, channel = self.make_request(b"POST", self.url, request_data)
+ channel = self.make_request(b"POST", self.url, request_data)
if i == 5:
self.assertEquals(channel.result["code"], b"429", channel.result)
@@ -178,12 +179,12 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
- request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
+ channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
self.assertEquals(channel.result["code"], b"200", channel.result)
def test_advertised_flows(self):
- request, channel = self.make_request(b"POST", self.url, b"{}")
+ channel = self.make_request(b"POST", self.url, b"{}")
self.assertEquals(channel.result["code"], b"401", channel.result)
flows = channel.json_body["flows"]
@@ -206,7 +207,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
}
)
def test_advertised_flows_captcha_and_terms_and_3pids(self):
- request, channel = self.make_request(b"POST", self.url, b"{}")
+ channel = self.make_request(b"POST", self.url, b"{}")
self.assertEquals(channel.result["code"], b"401", channel.result)
flows = channel.json_body["flows"]
@@ -238,7 +239,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
}
)
def test_advertised_flows_no_msisdn_email_required(self):
- request, channel = self.make_request(b"POST", self.url, b"{}")
+ channel = self.make_request(b"POST", self.url, b"{}")
self.assertEquals(channel.result["code"], b"401", channel.result)
flows = channel.json_body["flows"]
@@ -278,7 +279,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
)
)
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
b"register/email/requestToken",
{"client_secret": "foobar", "email": email, "send_attempt": 1},
@@ -317,13 +318,13 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
# The specific endpoint doesn't matter, all we need is an authenticated
# endpoint.
- request, channel = self.make_request(b"GET", "/sync", access_token=tok)
+ channel = self.make_request(b"GET", "/sync", access_token=tok)
self.assertEquals(channel.result["code"], b"200", channel.result)
self.reactor.advance(datetime.timedelta(weeks=1).total_seconds())
- request, channel = self.make_request(b"GET", "/sync", access_token=tok)
+ channel = self.make_request(b"GET", "/sync", access_token=tok)
self.assertEquals(channel.result["code"], b"403", channel.result)
self.assertEquals(
@@ -345,14 +346,12 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
url = "/_synapse/admin/v1/account_validity/validity"
params = {"user_id": user_id}
request_data = json.dumps(params)
- request, channel = self.make_request(
- b"POST", url, request_data, access_token=admin_tok
- )
+ channel = self.make_request(b"POST", url, request_data, access_token=admin_tok)
self.assertEquals(channel.result["code"], b"200", channel.result)
# The specific endpoint doesn't matter, all we need is an authenticated
# endpoint.
- request, channel = self.make_request(b"GET", "/sync", access_token=tok)
+ channel = self.make_request(b"GET", "/sync", access_token=tok)
self.assertEquals(channel.result["code"], b"200", channel.result)
def test_manual_expire(self):
@@ -369,14 +368,12 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
"enable_renewal_emails": False,
}
request_data = json.dumps(params)
- request, channel = self.make_request(
- b"POST", url, request_data, access_token=admin_tok
- )
+ channel = self.make_request(b"POST", url, request_data, access_token=admin_tok)
self.assertEquals(channel.result["code"], b"200", channel.result)
# The specific endpoint doesn't matter, all we need is an authenticated
# endpoint.
- request, channel = self.make_request(b"GET", "/sync", access_token=tok)
+ channel = self.make_request(b"GET", "/sync", access_token=tok)
self.assertEquals(channel.result["code"], b"403", channel.result)
self.assertEquals(
channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result
@@ -396,20 +393,18 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
"enable_renewal_emails": False,
}
request_data = json.dumps(params)
- request, channel = self.make_request(
- b"POST", url, request_data, access_token=admin_tok
- )
+ channel = self.make_request(b"POST", url, request_data, access_token=admin_tok)
self.assertEquals(channel.result["code"], b"200", channel.result)
# Try to log the user out
- request, channel = self.make_request(b"POST", "/logout", access_token=tok)
+ channel = self.make_request(b"POST", "/logout", access_token=tok)
self.assertEquals(channel.result["code"], b"200", channel.result)
# Log the user in again (allowed for expired accounts)
tok = self.login("kermit", "monkey")
# Try to log out all of the user's sessions
- request, channel = self.make_request(b"POST", "/logout/all", access_token=tok)
+ channel = self.make_request(b"POST", "/logout/all", access_token=tok)
self.assertEquals(channel.result["code"], b"200", channel.result)
@@ -483,7 +478,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
# retrieve the token from the DB.
renewal_token = self.get_success(self.store.get_renewal_token_for_user(user_id))
url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token
- request, channel = self.make_request(b"GET", url)
+ channel = self.make_request(b"GET", url)
self.assertEquals(channel.result["code"], b"200", channel.result)
# Check that we're getting HTML back.
@@ -503,14 +498,14 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
# our access token should be denied from now, otherwise they should
# succeed.
self.reactor.advance(datetime.timedelta(days=3).total_seconds())
- request, channel = self.make_request(b"GET", "/sync", access_token=tok)
+ channel = self.make_request(b"GET", "/sync", access_token=tok)
self.assertEquals(channel.result["code"], b"200", channel.result)
def test_renewal_invalid_token(self):
# Hit the renewal endpoint with an invalid token and check that it behaves as
# expected, i.e. that it responds with 404 Not Found and the correct HTML.
url = "/_matrix/client/unstable/account_validity/renew?token=123"
- request, channel = self.make_request(b"GET", url)
+ channel = self.make_request(b"GET", url)
self.assertEquals(channel.result["code"], b"404", channel.result)
# Check that we're getting HTML back.
@@ -531,7 +526,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
self.email_attempts = []
(user_id, tok) = self.create_user()
- request, channel = self.make_request(
+ channel = self.make_request(
b"POST",
"/_matrix/client/unstable/account_validity/send_mail",
access_token=tok,
@@ -555,10 +550,10 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
"erase": False,
}
)
- request, channel = self.make_request(
+ channel = self.make_request(
"POST", "account/deactivate", request_data, access_token=tok
)
- self.assertEqual(request.code, 200)
+ self.assertEqual(channel.code, 200)
self.reactor.advance(datetime.timedelta(days=8).total_seconds())
@@ -606,7 +601,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
self.email_attempts = []
# Test that we're still able to manually trigger a mail to be sent.
- request, channel = self.make_request(
+ channel = self.make_request(
b"POST",
"/_matrix/client/unstable/account_validity/send_mail",
access_token=tok,
diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py
index 6cd4eb6624..bd574077e7 100644
--- a/tests/rest/client/v2_alpha/test_relations.py
+++ b/tests/rest/client/v2_alpha/test_relations.py
@@ -60,7 +60,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
event_id = channel.json_body["event_id"]
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/rooms/%s/event/%s" % (self.room, event_id),
access_token=self.user_token,
@@ -107,7 +107,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(200, channel.code, channel.json_body)
annotation_id = channel.json_body["event_id"]
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/_matrix/client/unstable/rooms/%s/relations/%s?limit=1"
% (self.room, self.parent_id),
@@ -152,7 +152,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
if prev_token:
from_token = "&from=" + prev_token
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/_matrix/client/unstable/rooms/%s/relations/%s?limit=1%s"
% (self.room, self.parent_id, from_token),
@@ -210,7 +210,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
if prev_token:
from_token = "&from=" + prev_token
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/_matrix/client/unstable/rooms/%s/aggregations/%s?limit=1%s"
% (self.room, self.parent_id, from_token),
@@ -279,7 +279,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
if prev_token:
from_token = "&from=" + prev_token
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/_matrix/client/unstable/rooms/%s"
"/aggregations/%s/%s/m.reaction/%s?limit=1%s"
@@ -325,7 +325,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
self.assertEquals(200, channel.code, channel.json_body)
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/_matrix/client/unstable/rooms/%s/aggregations/%s"
% (self.room, self.parent_id),
@@ -357,7 +357,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(200, channel.code, channel.json_body)
# Now lets redact one of the 'a' reactions
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
"/_matrix/client/r0/rooms/%s/redact/%s" % (self.room, to_redact_event_id),
access_token=self.user_token,
@@ -365,7 +365,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
)
self.assertEquals(200, channel.code, channel.json_body)
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/_matrix/client/unstable/rooms/%s/aggregations/%s"
% (self.room, self.parent_id),
@@ -382,7 +382,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
"""Test that aggregations must be annotations.
"""
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/_matrix/client/unstable/rooms/%s/aggregations/%s/%s?limit=1"
% (self.room, self.parent_id, RelationTypes.REPLACE),
@@ -414,7 +414,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(200, channel.code, channel.json_body)
reply_2 = channel.json_body["event_id"]
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/rooms/%s/event/%s" % (self.room, self.parent_id),
access_token=self.user_token,
@@ -450,7 +450,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
edit_event_id = channel.json_body["event_id"]
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/rooms/%s/event/%s" % (self.room, self.parent_id),
access_token=self.user_token,
@@ -507,7 +507,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
)
self.assertEquals(200, channel.code, channel.json_body)
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/rooms/%s/event/%s" % (self.room, self.parent_id),
access_token=self.user_token,
@@ -549,7 +549,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(200, channel.code, channel.json_body)
# Check the relation is returned
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/_matrix/client/unstable/rooms/%s/relations/%s/m.replace/m.room.message"
% (self.room, original_event_id),
@@ -561,7 +561,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(len(channel.json_body["chunk"]), 1)
# Redact the original event
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/rooms/%s/redact/%s/%s"
% (self.room, original_event_id, "test_relations_redaction_redacts_edits"),
@@ -571,7 +571,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(200, channel.code, channel.json_body)
# Try to check for remaining m.replace relations
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/_matrix/client/unstable/rooms/%s/relations/%s/m.replace/m.room.message"
% (self.room, original_event_id),
@@ -598,7 +598,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(200, channel.code, channel.json_body)
# Redact the original
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/rooms/%s/redact/%s/%s"
% (
@@ -612,7 +612,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(200, channel.code, channel.json_body)
# Check that aggregations returns zero
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/_matrix/client/unstable/rooms/%s/aggregations/%s/m.annotation/m.reaction"
% (self.room, original_event_id),
@@ -656,7 +656,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
original_id = parent_id if parent_id else self.parent_id
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
"/_matrix/client/unstable/rooms/%s/send_relation/%s/%s/%s%s"
% (self.room, original_id, relation_type, event_type, query),
diff --git a/tests/rest/client/v2_alpha/test_shared_rooms.py b/tests/rest/client/v2_alpha/test_shared_rooms.py
index 562a9c1ba4..116ace1812 100644
--- a/tests/rest/client/v2_alpha/test_shared_rooms.py
+++ b/tests/rest/client/v2_alpha/test_shared_rooms.py
@@ -17,6 +17,7 @@ from synapse.rest.client.v1 import login, room
from synapse.rest.client.v2_alpha import shared_rooms
from tests import unittest
+from tests.server import FakeChannel
class UserSharedRoomsTest(unittest.HomeserverTestCase):
@@ -40,14 +41,13 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
self.store = hs.get_datastore()
self.handler = hs.get_user_directory_handler()
- def _get_shared_rooms(self, token, other_user):
- request, channel = self.make_request(
+ def _get_shared_rooms(self, token, other_user) -> FakeChannel:
+ return self.make_request(
"GET",
"/_matrix/client/unstable/uk.half-shot.msc2666/user/shared_rooms/%s"
% other_user,
access_token=token,
)
- return request, channel
def test_shared_room_list_public(self):
"""
@@ -63,7 +63,7 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
self.helper.invite(room, src=u1, targ=u2, tok=u1_token)
self.helper.join(room, user=u2, tok=u2_token)
- request, channel = self._get_shared_rooms(u1_token, u2)
+ channel = self._get_shared_rooms(u1_token, u2)
self.assertEquals(200, channel.code, channel.result)
self.assertEquals(len(channel.json_body["joined"]), 1)
self.assertEquals(channel.json_body["joined"][0], room)
@@ -82,7 +82,7 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
self.helper.invite(room, src=u1, targ=u2, tok=u1_token)
self.helper.join(room, user=u2, tok=u2_token)
- request, channel = self._get_shared_rooms(u1_token, u2)
+ channel = self._get_shared_rooms(u1_token, u2)
self.assertEquals(200, channel.code, channel.result)
self.assertEquals(len(channel.json_body["joined"]), 1)
self.assertEquals(channel.json_body["joined"][0], room)
@@ -104,7 +104,7 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
self.helper.join(room_public, user=u2, tok=u2_token)
self.helper.join(room_private, user=u1, tok=u1_token)
- request, channel = self._get_shared_rooms(u1_token, u2)
+ channel = self._get_shared_rooms(u1_token, u2)
self.assertEquals(200, channel.code, channel.result)
self.assertEquals(len(channel.json_body["joined"]), 2)
self.assertTrue(room_public in channel.json_body["joined"])
@@ -125,13 +125,13 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
self.helper.join(room, user=u2, tok=u2_token)
# Assert user directory is not empty
- request, channel = self._get_shared_rooms(u1_token, u2)
+ channel = self._get_shared_rooms(u1_token, u2)
self.assertEquals(200, channel.code, channel.result)
self.assertEquals(len(channel.json_body["joined"]), 1)
self.assertEquals(channel.json_body["joined"][0], room)
self.helper.leave(room, user=u1, tok=u1_token)
- request, channel = self._get_shared_rooms(u2_token, u1)
+ channel = self._get_shared_rooms(u2_token, u1)
self.assertEquals(200, channel.code, channel.result)
self.assertEquals(len(channel.json_body["joined"]), 0)
diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py
index 31ac0fccb8..512e36c236 100644
--- a/tests/rest/client/v2_alpha/test_sync.py
+++ b/tests/rest/client/v2_alpha/test_sync.py
@@ -35,7 +35,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
]
def test_sync_argless(self):
- request, channel = self.make_request("GET", "/sync")
+ channel = self.make_request("GET", "/sync")
self.assertEqual(channel.code, 200)
self.assertTrue(
@@ -55,7 +55,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
"""
self.hs.config.use_presence = False
- request, channel = self.make_request("GET", "/sync")
+ channel = self.make_request("GET", "/sync")
self.assertEqual(channel.code, 200)
self.assertTrue(
@@ -194,7 +194,7 @@ class SyncFilterTestCase(unittest.HomeserverTestCase):
tok=tok,
)
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/sync?filter=%s" % sync_filter, access_token=tok
)
self.assertEqual(channel.code, 200, channel.result)
@@ -245,21 +245,19 @@ class SyncTypingTests(unittest.HomeserverTestCase):
self.helper.send(room, body="There!", tok=other_access_token)
# Start typing.
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
typing_url % (room, other_user_id, other_access_token),
b'{"typing": true, "timeout": 30000}',
)
self.assertEquals(200, channel.code)
- request, channel = self.make_request(
- "GET", "/sync?access_token=%s" % (access_token,)
- )
+ channel = self.make_request("GET", "/sync?access_token=%s" % (access_token,))
self.assertEquals(200, channel.code)
next_batch = channel.json_body["next_batch"]
# Stop typing.
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
typing_url % (room, other_user_id, other_access_token),
b'{"typing": false}',
@@ -267,7 +265,7 @@ class SyncTypingTests(unittest.HomeserverTestCase):
self.assertEquals(200, channel.code)
# Start typing.
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
typing_url % (room, other_user_id, other_access_token),
b'{"typing": true, "timeout": 30000}',
@@ -275,9 +273,7 @@ class SyncTypingTests(unittest.HomeserverTestCase):
self.assertEquals(200, channel.code)
# Should return immediately
- request, channel = self.make_request(
- "GET", sync_url % (access_token, next_batch)
- )
+ channel = self.make_request("GET", sync_url % (access_token, next_batch))
self.assertEquals(200, channel.code)
next_batch = channel.json_body["next_batch"]
@@ -289,9 +285,7 @@ class SyncTypingTests(unittest.HomeserverTestCase):
# invalidate the stream token.
self.helper.send(room, body="There!", tok=other_access_token)
- request, channel = self.make_request(
- "GET", sync_url % (access_token, next_batch)
- )
+ channel = self.make_request("GET", sync_url % (access_token, next_batch))
self.assertEquals(200, channel.code)
next_batch = channel.json_body["next_batch"]
@@ -299,9 +293,7 @@ class SyncTypingTests(unittest.HomeserverTestCase):
# ahead, and therefore it's saying the typing (that we've actually
# already seen) is new, since it's got a token above our new, now-reset
# stream token.
- request, channel = self.make_request(
- "GET", sync_url % (access_token, next_batch)
- )
+ channel = self.make_request("GET", sync_url % (access_token, next_batch))
self.assertEquals(200, channel.code)
next_batch = channel.json_body["next_batch"]
@@ -383,7 +375,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
# Send a read receipt to tell the server we've read the latest event.
body = json.dumps({"m.read": res["event_id"]}).encode("utf8")
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
"/rooms/%s/read_markers" % self.room_id,
body,
@@ -450,7 +442,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
def _check_unread_count(self, expected_count: True):
"""Syncs and compares the unread count with the expected value."""
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", self.url % self.next_batch, access_token=self.tok,
)
diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py
index fbcf8d5b86..5e90d656f7 100644
--- a/tests/rest/key/v2/test_remote_key_resource.py
+++ b/tests/rest/key/v2/test_remote_key_resource.py
@@ -39,7 +39,7 @@ from tests.utils import default_config
class BaseRemoteKeyResourceTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
self.http_client = Mock()
- return self.setup_test_homeserver(http_client=self.http_client)
+ return self.setup_test_homeserver(federation_http_client=self.http_client)
def create_test_resource(self):
return create_resource_tree(
@@ -172,7 +172,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
}
]
self.hs2 = self.setup_test_homeserver(
- http_client=self.http_client2, config=config
+ federation_http_client=self.http_client2, config=config
)
# wire up outbound POST /key/v2/query requests from hs2 so that they
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index 2a3b2a8f27..a6c6985173 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -202,7 +202,6 @@ class MediaRepoTests(unittest.HomeserverTestCase):
config = self.default_config()
config["media_store_path"] = self.media_store_path
- config["thumbnail_requirements"] = {}
config["max_image_pixels"] = 2000000
provider_config = {
@@ -214,7 +213,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
}
config["media_storage_providers"] = [provider_config]
- hs = self.setup_test_homeserver(config=config, http_client=client)
+ hs = self.setup_test_homeserver(config=config, federation_http_client=client)
return hs
@@ -228,7 +227,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
def _req(self, content_disposition):
- request, channel = make_request(
+ channel = make_request(
self.reactor,
FakeSite(self.download_resource),
"GET",
@@ -313,18 +312,42 @@ class MediaRepoTests(unittest.HomeserverTestCase):
self.assertEqual(headers.getRawHeaders(b"Content-Disposition"), None)
def test_thumbnail_crop(self):
+ """Test that a cropped remote thumbnail is available."""
self._test_thumbnail(
"crop", self.test_image.expected_cropped, self.test_image.expected_found
)
def test_thumbnail_scale(self):
+ """Test that a scaled remote thumbnail is available."""
self._test_thumbnail(
"scale", self.test_image.expected_scaled, self.test_image.expected_found
)
+ def test_invalid_type(self):
+ """An invalid thumbnail type is never available."""
+ self._test_thumbnail("invalid", None, False)
+
+ @unittest.override_config(
+ {"thumbnail_sizes": [{"width": 32, "height": 32, "method": "scale"}]}
+ )
+ def test_no_thumbnail_crop(self):
+ """
+ Override the config to generate only scaled thumbnails, but request a cropped one.
+ """
+ self._test_thumbnail("crop", None, False)
+
+ @unittest.override_config(
+ {"thumbnail_sizes": [{"width": 32, "height": 32, "method": "crop"}]}
+ )
+ def test_no_thumbnail_scale(self):
+ """
+ Override the config to generate only cropped thumbnails, but request a scaled one.
+ """
+ self._test_thumbnail("scale", None, False)
+
def _test_thumbnail(self, method, expected_body, expected_found):
params = "?width=32&height=32&method=" + method
- request, channel = make_request(
+ channel = make_request(
self.reactor,
FakeSite(self.thumbnail_resource),
"GET",
@@ -362,3 +385,16 @@ class MediaRepoTests(unittest.HomeserverTestCase):
"error": "Not found [b'example.com', b'12345']",
},
)
+
+ def test_x_robots_tag_header(self):
+ """
+ Tests that the `X-Robots-Tag` header is present, which informs web crawlers
+ to not index, archive, or follow links in media.
+ """
+ channel = self._req(b"inline; filename=out" + self.test_image.extension)
+
+ headers = channel.headers
+ self.assertEqual(
+ headers.getRawHeaders(b"X-Robots-Tag"),
+ [b"noindex, nofollow, noarchive, noimageindex"],
+ )
diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py
index ccdc8c2ecf..6968502433 100644
--- a/tests/rest/media/v1/test_url_preview.py
+++ b/tests/rest/media/v1/test_url_preview.py
@@ -18,42 +18,23 @@ import re
from mock import patch
-import attr
-
from twisted.internet._resolver import HostResolution
from twisted.internet.address import IPv4Address, IPv6Address
from twisted.internet.error import DNSLookupError
-from twisted.python.failure import Failure
from twisted.test.proto_helpers import AccumulatingProtocol
-from twisted.web._newclient import ResponseDone
from tests import unittest
from tests.server import FakeTransport
-
-@attr.s
-class FakeResponse:
- version = attr.ib()
- code = attr.ib()
- phrase = attr.ib()
- headers = attr.ib()
- body = attr.ib()
- absoluteURI = attr.ib()
-
- @property
- def request(self):
- @attr.s
- class FakeTransport:
- absoluteURI = self.absoluteURI
-
- return FakeTransport()
-
- def deliverBody(self, protocol):
- protocol.dataReceived(self.body)
- protocol.connectionLost(Failure(ResponseDone()))
+try:
+ import lxml
+except ImportError:
+ lxml = None
class URLPreviewTests(unittest.HomeserverTestCase):
+ if not lxml:
+ skip = "url preview feature requires lxml"
hijack_auth = True
user_id = "@test:user"
@@ -139,7 +120,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
def test_cache_returns_correct_type(self):
self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"preview_url?url=http://matrix.org",
shorthand=False,
@@ -164,7 +145,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
)
# Check the cache returns the correct response
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "preview_url?url=http://matrix.org", shorthand=False
)
@@ -180,7 +161,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.assertNotIn("http://matrix.org", self.preview_url._cache)
# Check the database cache returns the correct response
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "preview_url?url=http://matrix.org", shorthand=False
)
@@ -201,7 +182,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
b"</head></html>"
)
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"preview_url?url=http://matrix.org",
shorthand=False,
@@ -236,7 +217,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
b"</head></html>"
)
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"preview_url?url=http://matrix.org",
shorthand=False,
@@ -271,7 +252,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
b"</head></html>"
)
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"preview_url?url=http://matrix.org",
shorthand=False,
@@ -304,7 +285,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
"""
self.lookups["example.com"] = [(IPv4Address, "10.1.2.3")]
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"preview_url?url=http://example.com",
shorthand=False,
@@ -334,7 +315,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
"""
self.lookups["example.com"] = [(IPv4Address, "192.168.1.1")]
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "preview_url?url=http://example.com", shorthand=False
)
@@ -355,7 +336,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
"""
self.lookups["example.com"] = [(IPv4Address, "1.1.1.2")]
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "preview_url?url=http://example.com", shorthand=False
)
@@ -372,7 +353,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
"""
Blacklisted IP addresses, accessed directly, are not spidered.
"""
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "preview_url?url=http://192.168.1.1", shorthand=False
)
@@ -391,7 +372,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
"""
Blacklisted IP ranges, accessed directly, are not spidered.
"""
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "preview_url?url=http://1.1.1.2", shorthand=False
)
@@ -411,7 +392,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
"""
self.lookups["example.com"] = [(IPv4Address, "1.1.1.1")]
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"preview_url?url=http://example.com",
shorthand=False,
@@ -448,7 +429,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
(IPv4Address, "10.1.2.3"),
]
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "preview_url?url=http://example.com", shorthand=False
)
self.assertEqual(channel.code, 502)
@@ -468,7 +449,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
(IPv6Address, "3fff:ffff:ffff:ffff:ffff:ffff:ffff:ffff")
]
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "preview_url?url=http://example.com", shorthand=False
)
@@ -489,7 +470,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
"""
self.lookups["example.com"] = [(IPv6Address, "2001:800::1")]
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "preview_url?url=http://example.com", shorthand=False
)
@@ -506,7 +487,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
"""
OPTIONS returns the OPTIONS.
"""
- request, channel = self.make_request(
+ channel = self.make_request(
"OPTIONS", "preview_url?url=http://example.com", shorthand=False
)
self.assertEqual(channel.code, 200)
@@ -519,7 +500,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.lookups["example.com"] = [(IPv4Address, "10.1.2.3")]
# Build and make a request to the server
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"preview_url?url=http://example.com",
shorthand=False,
@@ -593,7 +574,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
b"</head></html>"
)
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"preview_url?url=http://twitter.com/matrixdotorg/status/12345",
shorthand=False,
@@ -658,7 +639,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
}
end_content = json.dumps(result).encode("utf-8")
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"preview_url?url=http://twitter.com/matrixdotorg/status/12345",
shorthand=False,
diff --git a/tests/rest/test_health.py b/tests/rest/test_health.py
index 02a46e5fda..32acd93dc1 100644
--- a/tests/rest/test_health.py
+++ b/tests/rest/test_health.py
@@ -25,7 +25,7 @@ class HealthCheckTests(unittest.HomeserverTestCase):
return HealthResource()
def test_health(self):
- request, channel = self.make_request("GET", "/health", shorthand=False)
+ channel = self.make_request("GET", "/health", shorthand=False)
- self.assertEqual(request.code, 200)
+ self.assertEqual(channel.code, 200)
self.assertEqual(channel.result["body"], b"OK")
diff --git a/tests/rest/test_well_known.py b/tests/rest/test_well_known.py
index 6a930f4148..14de0921be 100644
--- a/tests/rest/test_well_known.py
+++ b/tests/rest/test_well_known.py
@@ -28,11 +28,11 @@ class WellKnownTests(unittest.HomeserverTestCase):
self.hs.config.public_baseurl = "https://tesths"
self.hs.config.default_identity_server = "https://testis"
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/.well-known/matrix/client", shorthand=False
)
- self.assertEqual(request.code, 200)
+ self.assertEqual(channel.code, 200)
self.assertEqual(
channel.json_body,
{
@@ -44,8 +44,8 @@ class WellKnownTests(unittest.HomeserverTestCase):
def test_well_known_no_public_baseurl(self):
self.hs.config.public_baseurl = None
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/.well-known/matrix/client", shorthand=False
)
- self.assertEqual(request.code, 404)
+ self.assertEqual(channel.code, 404)
diff --git a/tests/server.py b/tests/server.py
index a51ad0c14e..6419c445ec 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -2,7 +2,7 @@ import json
import logging
from collections import deque
from io import SEEK_END, BytesIO
-from typing import Callable, Iterable, Optional, Tuple, Union
+from typing import Callable, Iterable, MutableMapping, Optional, Tuple, Union
import attr
from typing_extensions import Deque
@@ -47,13 +47,26 @@ class FakeChannel:
site = attr.ib(type=Site)
_reactor = attr.ib()
result = attr.ib(type=dict, default=attr.Factory(dict))
+ _ip = attr.ib(type=str, default="127.0.0.1")
_producer = None
@property
def json_body(self):
- if not self.result:
- raise Exception("No result yet.")
- return json.loads(self.result["body"].decode("utf8"))
+ return json.loads(self.text_body)
+
+ @property
+ def text_body(self) -> str:
+ """The body of the result, utf-8-decoded.
+
+ Raises an exception if the request has not yet completed.
+ """
+ if not self.is_finished:
+ raise Exception("Request not yet completed")
+ return self.result["body"].decode("utf8")
+
+ def is_finished(self) -> bool:
+ """check if the response has been completely received"""
+ return self.result.get("done", False)
@property
def code(self):
@@ -62,7 +75,7 @@ class FakeChannel:
return int(self.result["code"])
@property
- def headers(self):
+ def headers(self) -> Headers:
if not self.result:
raise Exception("No result yet.")
h = Headers()
@@ -108,7 +121,7 @@ class FakeChannel:
def getPeer(self):
# We give an address so that getClientIP returns a non null entry,
# causing us to record the MAU
- return address.IPv4Address("TCP", "127.0.0.1", 3423)
+ return address.IPv4Address("TCP", self._ip, 3423)
def getHost(self):
return None
@@ -124,7 +137,7 @@ class FakeChannel:
self._reactor.run()
x = 0
- while not self.result.get("done"):
+ while not self.is_finished():
# If there's a producer, tell it to resume producing so we get content
if self._producer:
self._producer.resumeProducing()
@@ -136,6 +149,16 @@ class FakeChannel:
self._reactor.advance(0.1)
+ def extract_cookies(self, cookies: MutableMapping[str, str]) -> None:
+ """Process the contents of any Set-Cookie headers in the response
+
+ Any cookines found are added to the given dict
+ """
+ for h in self.headers.getRawHeaders("Set-Cookie"):
+ parts = h.split(";")
+ k, v = parts[0].split("=", maxsplit=1)
+ cookies[k] = v
+
class FakeSite:
"""
@@ -174,11 +197,12 @@ def make_request(
custom_headers: Optional[
Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
] = None,
-):
+ client_ip: str = "127.0.0.1",
+) -> FakeChannel:
"""
Make a web request using the given method, path and content, and render it
- Returns the Request and the Channel underneath.
+ Returns the fake Channel object which records the response to the request.
Args:
site: The twisted Site to use to render the request
@@ -201,8 +225,11 @@ def make_request(
will pump the reactor until the the renderer tells the channel the request
is finished.
+ client_ip: The IP to use as the requesting IP. Useful for testing
+ ratelimiting.
+
Returns:
- Tuple[synapse.http.site.SynapseRequest, channel]
+ channel
"""
if not isinstance(method, bytes):
method = method.encode("ascii")
@@ -216,8 +243,9 @@ def make_request(
and not path.startswith(b"/_matrix")
and not path.startswith(b"/_synapse")
):
+ if path.startswith(b"/"):
+ path = path[1:]
path = b"/_matrix/client/r0/" + path
- path = path.replace(b"//", b"/")
if not path.startswith(b"/"):
path = b"/" + path
@@ -227,7 +255,7 @@ def make_request(
if isinstance(content, str):
content = content.encode("utf8")
- channel = FakeChannel(site, reactor)
+ channel = FakeChannel(site, reactor, ip=client_ip)
req = request(channel)
req.content = BytesIO(content)
@@ -258,12 +286,13 @@ def make_request(
for k, v in custom_headers:
req.requestHeaders.addRawHeader(k, v)
+ req.parseCookies()
req.requestReceived(method, path, b"1.1")
if await_result:
channel.await_result()
- return req, channel
+ return channel
@implementer(IReactorPluggableNameResolver)
diff --git a/tests/server_notices/test_consent.py b/tests/server_notices/test_consent.py
index e0a9cd93ac..4dd5a36178 100644
--- a/tests/server_notices/test_consent.py
+++ b/tests/server_notices/test_consent.py
@@ -70,7 +70,7 @@ class ConsentNoticesTests(unittest.HomeserverTestCase):
the notice URL + an authentication code.
"""
# Initial sync, to get the user consent room invite
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/_matrix/client/r0/sync", access_token=self.access_token
)
self.assertEqual(channel.code, 200)
@@ -79,7 +79,7 @@ class ConsentNoticesTests(unittest.HomeserverTestCase):
room_id = list(channel.json_body["rooms"]["invite"].keys())[0]
# Join the room
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
"/_matrix/client/r0/rooms/" + room_id + "/join",
access_token=self.access_token,
@@ -87,7 +87,7 @@ class ConsentNoticesTests(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200)
# Sync again, to get the message in the room
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/_matrix/client/r0/sync", access_token=self.access_token
)
self.assertEqual(channel.code, 200)
diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py
index 9c8027a5b2..fea54464af 100644
--- a/tests/server_notices/test_resource_limits_server_notices.py
+++ b/tests/server_notices/test_resource_limits_server_notices.py
@@ -305,7 +305,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
self.register_user("user", "password")
tok = self.login("user", "password")
- request, channel = self.make_request("GET", "/sync?timeout=0", access_token=tok)
+ channel = self.make_request("GET", "/sync?timeout=0", access_token=tok)
invites = channel.json_body["rooms"]["invite"]
self.assertEqual(len(invites), 0, invites)
@@ -318,7 +318,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
# Sync again to retrieve the events in the room, so we can check whether this
# room has a notice in it.
- request, channel = self.make_request("GET", "/sync?timeout=0", access_token=tok)
+ channel = self.make_request("GET", "/sync?timeout=0", access_token=tok)
# Scan the events in the room to search for a message from the server notices
# user.
@@ -353,9 +353,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
tok = self.login(localpart, "password")
# Sync with the user's token to mark the user as active.
- request, channel = self.make_request(
- "GET", "/sync?timeout=0", access_token=tok,
- )
+ channel = self.make_request("GET", "/sync?timeout=0", access_token=tok,)
# Also retrieves the list of invites for this user. We don't care about that
# one except if we're processing the last user, which should have received an
diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py
index ad9bbef9d2..77c72834f2 100644
--- a/tests/state/test_v2.py
+++ b/tests/state/test_v2.py
@@ -24,7 +24,11 @@ from synapse.api.constants import EventTypes, JoinRules, Membership
from synapse.api.room_versions import RoomVersions
from synapse.event_auth import auth_types_for_event
from synapse.events import make_event_from_dict
-from synapse.state.v2 import lexicographical_topological_sort, resolve_events_with_store
+from synapse.state.v2 import (
+ _get_auth_chain_difference,
+ lexicographical_topological_sort,
+ resolve_events_with_store,
+)
from synapse.types import EventID
from tests import unittest
@@ -84,7 +88,7 @@ class FakeEvent:
event_dict = {
"auth_events": [(a, {}) for a in auth_events],
"prev_events": [(p, {}) for p in prev_events],
- "event_id": self.node_id,
+ "event_id": self.event_id,
"sender": self.sender,
"type": self.type,
"content": self.content,
@@ -377,6 +381,61 @@ class StateTestCase(unittest.TestCase):
self.do_check(events, edges, expected_state_ids)
+ def test_mainline_sort(self):
+ """Tests that the mainline ordering works correctly.
+ """
+
+ events = [
+ FakeEvent(
+ id="T1", sender=ALICE, type=EventTypes.Topic, state_key="", content={}
+ ),
+ FakeEvent(
+ id="PA1",
+ sender=ALICE,
+ type=EventTypes.PowerLevels,
+ state_key="",
+ content={"users": {ALICE: 100, BOB: 50}},
+ ),
+ FakeEvent(
+ id="T2", sender=ALICE, type=EventTypes.Topic, state_key="", content={}
+ ),
+ FakeEvent(
+ id="PA2",
+ sender=ALICE,
+ type=EventTypes.PowerLevels,
+ state_key="",
+ content={
+ "users": {ALICE: 100, BOB: 50},
+ "events": {EventTypes.PowerLevels: 100},
+ },
+ ),
+ FakeEvent(
+ id="PB",
+ sender=BOB,
+ type=EventTypes.PowerLevels,
+ state_key="",
+ content={"users": {ALICE: 100, BOB: 50}},
+ ),
+ FakeEvent(
+ id="T3", sender=BOB, type=EventTypes.Topic, state_key="", content={}
+ ),
+ FakeEvent(
+ id="T4", sender=ALICE, type=EventTypes.Topic, state_key="", content={}
+ ),
+ ]
+
+ edges = [
+ ["END", "T3", "PA2", "T2", "PA1", "T1", "START"],
+ ["END", "T4", "PB", "PA1"],
+ ]
+
+ # We expect T3 to be picked as the other topics are pointing at older
+ # power levels. Note that without mainline ordering we'd pick T4 due to
+ # it being sent *after* T3.
+ expected_state_ids = ["T3", "PA2"]
+
+ self.do_check(events, edges, expected_state_ids)
+
def do_check(self, events, edges, expected_state_ids):
"""Take a list of events and edges and calculate the state of the
graph at END, and asserts it matches `expected_state_ids`
@@ -587,6 +646,134 @@ class SimpleParamStateTestCase(unittest.TestCase):
self.assert_dict(self.expected_combined_state, state)
+class AuthChainDifferenceTestCase(unittest.TestCase):
+ """We test that `_get_auth_chain_difference` correctly handles unpersisted
+ events.
+ """
+
+ def test_simple(self):
+ # Test getting the auth difference for a simple chain with a single
+ # unpersisted event:
+ #
+ # Unpersisted | Persisted
+ # |
+ # C -|-> B -> A
+
+ a = FakeEvent(
+ id="A", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+ ).to_event([], [])
+
+ b = FakeEvent(
+ id="B", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+ ).to_event([a.event_id], [])
+
+ c = FakeEvent(
+ id="C", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+ ).to_event([b.event_id], [])
+
+ persisted_events = {a.event_id: a, b.event_id: b}
+ unpersited_events = {c.event_id: c}
+
+ state_sets = [{"a": a.event_id, "b": b.event_id}, {"c": c.event_id}]
+
+ store = TestStateResolutionStore(persisted_events)
+
+ diff_d = _get_auth_chain_difference(
+ ROOM_ID, state_sets, unpersited_events, store
+ )
+ difference = self.successResultOf(defer.ensureDeferred(diff_d))
+
+ self.assertEqual(difference, {c.event_id})
+
+ def test_multiple_unpersisted_chain(self):
+ # Test getting the auth difference for a simple chain with multiple
+ # unpersisted events:
+ #
+ # Unpersisted | Persisted
+ # |
+ # D -> C -|-> B -> A
+
+ a = FakeEvent(
+ id="A", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+ ).to_event([], [])
+
+ b = FakeEvent(
+ id="B", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+ ).to_event([a.event_id], [])
+
+ c = FakeEvent(
+ id="C", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+ ).to_event([b.event_id], [])
+
+ d = FakeEvent(
+ id="D", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+ ).to_event([c.event_id], [])
+
+ persisted_events = {a.event_id: a, b.event_id: b}
+ unpersited_events = {c.event_id: c, d.event_id: d}
+
+ state_sets = [
+ {"a": a.event_id, "b": b.event_id},
+ {"c": c.event_id, "d": d.event_id},
+ ]
+
+ store = TestStateResolutionStore(persisted_events)
+
+ diff_d = _get_auth_chain_difference(
+ ROOM_ID, state_sets, unpersited_events, store
+ )
+ difference = self.successResultOf(defer.ensureDeferred(diff_d))
+
+ self.assertEqual(difference, {d.event_id, c.event_id})
+
+ def test_unpersisted_events_different_sets(self):
+ # Test getting the auth difference for with multiple unpersisted events
+ # in different branches:
+ #
+ # Unpersisted | Persisted
+ # |
+ # D --> C -|-> B -> A
+ # E ----^ -|---^
+ # |
+
+ a = FakeEvent(
+ id="A", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+ ).to_event([], [])
+
+ b = FakeEvent(
+ id="B", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+ ).to_event([a.event_id], [])
+
+ c = FakeEvent(
+ id="C", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+ ).to_event([b.event_id], [])
+
+ d = FakeEvent(
+ id="D", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+ ).to_event([c.event_id], [])
+
+ e = FakeEvent(
+ id="E", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+ ).to_event([c.event_id, b.event_id], [])
+
+ persisted_events = {a.event_id: a, b.event_id: b}
+ unpersited_events = {c.event_id: c, d.event_id: d, e.event_id: e}
+
+ state_sets = [
+ {"a": a.event_id, "b": b.event_id, "e": e.event_id},
+ {"c": c.event_id, "d": d.event_id},
+ ]
+
+ store = TestStateResolutionStore(persisted_events)
+
+ diff_d = _get_auth_chain_difference(
+ ROOM_ID, state_sets, unpersited_events, store
+ )
+ difference = self.successResultOf(defer.ensureDeferred(diff_d))
+
+ self.assertEqual(difference, {d.event_id, e.event_id})
+
+
def pairwise(iterable):
"s -> (s0,s1), (s1,s2), (s2, s3), ..."
a, b = itertools.tee(iterable)
@@ -647,7 +834,7 @@ class TestStateResolutionStore:
return list(result)
- def get_auth_chain_difference(self, auth_sets):
+ def get_auth_chain_difference(self, room_id, auth_sets):
chains = [frozenset(self._get_auth_chain(a)) for a in auth_sets]
common = set(chains[0]).intersection(*chains[1:])
diff --git a/tests/storage/test_account_data.py b/tests/storage/test_account_data.py
new file mode 100644
index 0000000000..673e1fe3e3
--- /dev/null
+++ b/tests/storage/test_account_data.py
@@ -0,0 +1,120 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Iterable, Set
+
+from synapse.api.constants import AccountDataTypes
+
+from tests import unittest
+
+
+class IgnoredUsersTestCase(unittest.HomeserverTestCase):
+ def prepare(self, hs, reactor, clock):
+ self.store = self.hs.get_datastore()
+ self.user = "@user:test"
+
+ def _update_ignore_list(
+ self, *ignored_user_ids: Iterable[str], ignorer_user_id: str = None
+ ) -> None:
+ """Update the account data to block the given users."""
+ if ignorer_user_id is None:
+ ignorer_user_id = self.user
+
+ self.get_success(
+ self.store.add_account_data_for_user(
+ ignorer_user_id,
+ AccountDataTypes.IGNORED_USER_LIST,
+ {"ignored_users": {u: {} for u in ignored_user_ids}},
+ )
+ )
+
+ def assert_ignorers(
+ self, ignored_user_id: str, expected_ignorer_user_ids: Set[str]
+ ) -> None:
+ self.assertEqual(
+ self.get_success(self.store.ignored_by(ignored_user_id)),
+ expected_ignorer_user_ids,
+ )
+
+ def test_ignoring_users(self):
+ """Basic adding/removing of users from the ignore list."""
+ self._update_ignore_list("@other:test", "@another:remote")
+
+ # Check a user which no one ignores.
+ self.assert_ignorers("@user:test", set())
+
+ # Check a local user which is ignored.
+ self.assert_ignorers("@other:test", {self.user})
+
+ # Check a remote user which is ignored.
+ self.assert_ignorers("@another:remote", {self.user})
+
+ # Add one user, remove one user, and leave one user.
+ self._update_ignore_list("@foo:test", "@another:remote")
+
+ # Check the removed user.
+ self.assert_ignorers("@other:test", set())
+
+ # Check the added user.
+ self.assert_ignorers("@foo:test", {self.user})
+
+ # Check the removed user.
+ self.assert_ignorers("@another:remote", {self.user})
+
+ def test_caching(self):
+ """Ensure that caching works properly between different users."""
+ # The first user ignores a user.
+ self._update_ignore_list("@other:test")
+ self.assert_ignorers("@other:test", {self.user})
+
+ # The second user ignores them.
+ self._update_ignore_list("@other:test", ignorer_user_id="@second:test")
+ self.assert_ignorers("@other:test", {self.user, "@second:test"})
+
+ # The first user un-ignores them.
+ self._update_ignore_list()
+ self.assert_ignorers("@other:test", {"@second:test"})
+
+ def test_invalid_data(self):
+ """Invalid data ends up clearing out the ignored users list."""
+ # Add some data and ensure it is there.
+ self._update_ignore_list("@other:test")
+ self.assert_ignorers("@other:test", {self.user})
+
+ # No ignored_users key.
+ self.get_success(
+ self.store.add_account_data_for_user(
+ self.user, AccountDataTypes.IGNORED_USER_LIST, {},
+ )
+ )
+
+ # No one ignores the user now.
+ self.assert_ignorers("@other:test", set())
+
+ # Add some data and ensure it is there.
+ self._update_ignore_list("@other:test")
+ self.assert_ignorers("@other:test", {self.user})
+
+ # Invalid data.
+ self.get_success(
+ self.store.add_account_data_for_user(
+ self.user,
+ AccountDataTypes.IGNORED_USER_LIST,
+ {"ignored_users": "unexpected"},
+ )
+ )
+
+ # No one ignores the user now.
+ self.assert_ignorers("@other:test", set())
diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index ecb00f4e02..dabc1c5f09 100644
--- a/tests/storage/test_devices.py
+++ b/tests/storage/test_devices.py
@@ -80,6 +80,32 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
)
@defer.inlineCallbacks
+ def test_count_devices_by_users(self):
+ yield defer.ensureDeferred(
+ self.store.store_device("user_id", "device1", "display_name 1")
+ )
+ yield defer.ensureDeferred(
+ self.store.store_device("user_id", "device2", "display_name 2")
+ )
+ yield defer.ensureDeferred(
+ self.store.store_device("user_id2", "device3", "display_name 3")
+ )
+
+ res = yield defer.ensureDeferred(self.store.count_devices_by_users())
+ self.assertEqual(0, res)
+
+ res = yield defer.ensureDeferred(self.store.count_devices_by_users(["unknown"]))
+ self.assertEqual(0, res)
+
+ res = yield defer.ensureDeferred(self.store.count_devices_by_users(["user_id"]))
+ self.assertEqual(2, res)
+
+ res = yield defer.ensureDeferred(
+ self.store.count_devices_by_users(["user_id", "user_id2"])
+ )
+ self.assertEqual(3, res)
+
+ @defer.inlineCallbacks
def test_get_device_updates_by_remote(self):
device_ids = ["device_id1", "device_id2"]
diff --git a/tests/storage/test_e2e_room_keys.py b/tests/storage/test_e2e_room_keys.py
index 35dafbb904..3d7760d5d9 100644
--- a/tests/storage/test_e2e_room_keys.py
+++ b/tests/storage/test_e2e_room_keys.py
@@ -26,7 +26,7 @@ room_key = {
class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
- hs = self.setup_test_homeserver("server", http_client=None)
+ hs = self.setup_test_homeserver("server", federation_http_client=None)
self.store = hs.get_datastore()
return hs
diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py
new file mode 100644
index 0000000000..0c46ad595b
--- /dev/null
+++ b/tests/storage/test_event_chain.py
@@ -0,0 +1,741 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict, List, Set, Tuple
+
+from twisted.trial import unittest
+
+from synapse.api.constants import EventTypes
+from synapse.api.room_versions import RoomVersions
+from synapse.events import EventBase
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+from synapse.storage.databases.main.events import _LinkMap
+from synapse.types import create_requester
+
+from tests.unittest import HomeserverTestCase
+
+
+class EventChainStoreTestCase(HomeserverTestCase):
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+ self._next_stream_ordering = 1
+
+ def test_simple(self):
+ """Test that the example in `docs/auth_chain_difference_algorithm.md`
+ works.
+ """
+
+ event_factory = self.hs.get_event_builder_factory()
+ bob = "@creator:test"
+ alice = "@alice:test"
+ room_id = "!room:test"
+
+ # Ensure that we have a rooms entry so that we generate the chain index.
+ self.get_success(
+ self.store.store_room(
+ room_id=room_id,
+ room_creator_user_id="",
+ is_public=True,
+ room_version=RoomVersions.V6,
+ )
+ )
+
+ create = self.get_success(
+ event_factory.for_room_version(
+ RoomVersions.V6,
+ {
+ "type": EventTypes.Create,
+ "state_key": "",
+ "sender": bob,
+ "room_id": room_id,
+ "content": {"tag": "create"},
+ },
+ ).build(prev_event_ids=[], auth_event_ids=[])
+ )
+
+ bob_join = self.get_success(
+ event_factory.for_room_version(
+ RoomVersions.V6,
+ {
+ "type": EventTypes.Member,
+ "state_key": bob,
+ "sender": bob,
+ "room_id": room_id,
+ "content": {"tag": "bob_join"},
+ },
+ ).build(prev_event_ids=[], auth_event_ids=[create.event_id])
+ )
+
+ power = self.get_success(
+ event_factory.for_room_version(
+ RoomVersions.V6,
+ {
+ "type": EventTypes.PowerLevels,
+ "state_key": "",
+ "sender": bob,
+ "room_id": room_id,
+ "content": {"tag": "power"},
+ },
+ ).build(
+ prev_event_ids=[], auth_event_ids=[create.event_id, bob_join.event_id],
+ )
+ )
+
+ alice_invite = self.get_success(
+ event_factory.for_room_version(
+ RoomVersions.V6,
+ {
+ "type": EventTypes.Member,
+ "state_key": alice,
+ "sender": bob,
+ "room_id": room_id,
+ "content": {"tag": "alice_invite"},
+ },
+ ).build(
+ prev_event_ids=[],
+ auth_event_ids=[create.event_id, bob_join.event_id, power.event_id],
+ )
+ )
+
+ alice_join = self.get_success(
+ event_factory.for_room_version(
+ RoomVersions.V6,
+ {
+ "type": EventTypes.Member,
+ "state_key": alice,
+ "sender": alice,
+ "room_id": room_id,
+ "content": {"tag": "alice_join"},
+ },
+ ).build(
+ prev_event_ids=[],
+ auth_event_ids=[create.event_id, alice_invite.event_id, power.event_id],
+ )
+ )
+
+ power_2 = self.get_success(
+ event_factory.for_room_version(
+ RoomVersions.V6,
+ {
+ "type": EventTypes.PowerLevels,
+ "state_key": "",
+ "sender": bob,
+ "room_id": room_id,
+ "content": {"tag": "power_2"},
+ },
+ ).build(
+ prev_event_ids=[],
+ auth_event_ids=[create.event_id, bob_join.event_id, power.event_id],
+ )
+ )
+
+ bob_join_2 = self.get_success(
+ event_factory.for_room_version(
+ RoomVersions.V6,
+ {
+ "type": EventTypes.Member,
+ "state_key": bob,
+ "sender": bob,
+ "room_id": room_id,
+ "content": {"tag": "bob_join_2"},
+ },
+ ).build(
+ prev_event_ids=[],
+ auth_event_ids=[create.event_id, bob_join.event_id, power.event_id],
+ )
+ )
+
+ alice_join2 = self.get_success(
+ event_factory.for_room_version(
+ RoomVersions.V6,
+ {
+ "type": EventTypes.Member,
+ "state_key": alice,
+ "sender": alice,
+ "room_id": room_id,
+ "content": {"tag": "alice_join2"},
+ },
+ ).build(
+ prev_event_ids=[],
+ auth_event_ids=[
+ create.event_id,
+ alice_join.event_id,
+ power_2.event_id,
+ ],
+ )
+ )
+
+ events = [
+ create,
+ bob_join,
+ power,
+ alice_invite,
+ alice_join,
+ bob_join_2,
+ power_2,
+ alice_join2,
+ ]
+
+ expected_links = [
+ (bob_join, create),
+ (power, create),
+ (power, bob_join),
+ (alice_invite, create),
+ (alice_invite, power),
+ (alice_invite, bob_join),
+ (bob_join_2, power),
+ (alice_join2, power_2),
+ ]
+
+ self.persist(events)
+ chain_map, link_map = self.fetch_chains(events)
+
+ # Check that the expected links and only the expected links have been
+ # added.
+ self.assertEqual(len(expected_links), len(list(link_map.get_additions())))
+
+ for start, end in expected_links:
+ start_id, start_seq = chain_map[start.event_id]
+ end_id, end_seq = chain_map[end.event_id]
+
+ self.assertIn(
+ (start_seq, end_seq), list(link_map.get_links_between(start_id, end_id))
+ )
+
+ # Test that everything can reach the create event, but the create event
+ # can't reach anything.
+ for event in events[1:]:
+ self.assertTrue(
+ link_map.exists_path_from(
+ chain_map[event.event_id], chain_map[create.event_id]
+ ),
+ )
+
+ self.assertFalse(
+ link_map.exists_path_from(
+ chain_map[create.event_id], chain_map[event.event_id],
+ ),
+ )
+
+ def test_out_of_order_events(self):
+ """Test that we handle persisting events that we don't have the full
+ auth chain for yet (which should only happen for out of band memberships).
+ """
+ event_factory = self.hs.get_event_builder_factory()
+ bob = "@creator:test"
+ alice = "@alice:test"
+ room_id = "!room:test"
+
+ # Ensure that we have a rooms entry so that we generate the chain index.
+ self.get_success(
+ self.store.store_room(
+ room_id=room_id,
+ room_creator_user_id="",
+ is_public=True,
+ room_version=RoomVersions.V6,
+ )
+ )
+
+ # First persist the base room.
+ create = self.get_success(
+ event_factory.for_room_version(
+ RoomVersions.V6,
+ {
+ "type": EventTypes.Create,
+ "state_key": "",
+ "sender": bob,
+ "room_id": room_id,
+ "content": {"tag": "create"},
+ },
+ ).build(prev_event_ids=[], auth_event_ids=[])
+ )
+
+ bob_join = self.get_success(
+ event_factory.for_room_version(
+ RoomVersions.V6,
+ {
+ "type": EventTypes.Member,
+ "state_key": bob,
+ "sender": bob,
+ "room_id": room_id,
+ "content": {"tag": "bob_join"},
+ },
+ ).build(prev_event_ids=[], auth_event_ids=[create.event_id])
+ )
+
+ power = self.get_success(
+ event_factory.for_room_version(
+ RoomVersions.V6,
+ {
+ "type": EventTypes.PowerLevels,
+ "state_key": "",
+ "sender": bob,
+ "room_id": room_id,
+ "content": {"tag": "power"},
+ },
+ ).build(
+ prev_event_ids=[], auth_event_ids=[create.event_id, bob_join.event_id],
+ )
+ )
+
+ self.persist([create, bob_join, power])
+
+ # Now persist an invite and a couple of memberships out of order.
+ alice_invite = self.get_success(
+ event_factory.for_room_version(
+ RoomVersions.V6,
+ {
+ "type": EventTypes.Member,
+ "state_key": alice,
+ "sender": bob,
+ "room_id": room_id,
+ "content": {"tag": "alice_invite"},
+ },
+ ).build(
+ prev_event_ids=[],
+ auth_event_ids=[create.event_id, bob_join.event_id, power.event_id],
+ )
+ )
+
+ alice_join = self.get_success(
+ event_factory.for_room_version(
+ RoomVersions.V6,
+ {
+ "type": EventTypes.Member,
+ "state_key": alice,
+ "sender": alice,
+ "room_id": room_id,
+ "content": {"tag": "alice_join"},
+ },
+ ).build(
+ prev_event_ids=[],
+ auth_event_ids=[create.event_id, alice_invite.event_id, power.event_id],
+ )
+ )
+
+ alice_join2 = self.get_success(
+ event_factory.for_room_version(
+ RoomVersions.V6,
+ {
+ "type": EventTypes.Member,
+ "state_key": alice,
+ "sender": alice,
+ "room_id": room_id,
+ "content": {"tag": "alice_join2"},
+ },
+ ).build(
+ prev_event_ids=[],
+ auth_event_ids=[create.event_id, alice_join.event_id, power.event_id],
+ )
+ )
+
+ self.persist([alice_join])
+ self.persist([alice_join2])
+ self.persist([alice_invite])
+
+ # The end result should be sane.
+ events = [create, bob_join, power, alice_invite, alice_join]
+
+ chain_map, link_map = self.fetch_chains(events)
+
+ expected_links = [
+ (bob_join, create),
+ (power, create),
+ (power, bob_join),
+ (alice_invite, create),
+ (alice_invite, power),
+ (alice_invite, bob_join),
+ ]
+
+ # Check that the expected links and only the expected links have been
+ # added.
+ self.assertEqual(len(expected_links), len(list(link_map.get_additions())))
+
+ for start, end in expected_links:
+ start_id, start_seq = chain_map[start.event_id]
+ end_id, end_seq = chain_map[end.event_id]
+
+ self.assertIn(
+ (start_seq, end_seq), list(link_map.get_links_between(start_id, end_id))
+ )
+
+ def persist(
+ self, events: List[EventBase],
+ ):
+ """Persist the given events and check that the links generated match
+ those given.
+ """
+
+ persist_events_store = self.hs.get_datastores().persist_events
+
+ for e in events:
+ e.internal_metadata.stream_ordering = self._next_stream_ordering
+ self._next_stream_ordering += 1
+
+ def _persist(txn):
+ # We need to persist the events to the events and state_events
+ # tables.
+ persist_events_store._store_event_txn(txn, [(e, {}) for e in events])
+
+ # Actually call the function that calculates the auth chain stuff.
+ persist_events_store._persist_event_auth_chain_txn(txn, events)
+
+ self.get_success(
+ persist_events_store.db_pool.runInteraction("_persist", _persist,)
+ )
+
+ def fetch_chains(
+ self, events: List[EventBase]
+ ) -> Tuple[Dict[str, Tuple[int, int]], _LinkMap]:
+
+ # Fetch the map from event ID -> (chain ID, sequence number)
+ rows = self.get_success(
+ self.store.db_pool.simple_select_many_batch(
+ table="event_auth_chains",
+ column="event_id",
+ iterable=[e.event_id for e in events],
+ retcols=("event_id", "chain_id", "sequence_number"),
+ keyvalues={},
+ )
+ )
+
+ chain_map = {
+ row["event_id"]: (row["chain_id"], row["sequence_number"]) for row in rows
+ }
+
+ # Fetch all the links and pass them to the _LinkMap.
+ rows = self.get_success(
+ self.store.db_pool.simple_select_many_batch(
+ table="event_auth_chain_links",
+ column="origin_chain_id",
+ iterable=[chain_id for chain_id, _ in chain_map.values()],
+ retcols=(
+ "origin_chain_id",
+ "origin_sequence_number",
+ "target_chain_id",
+ "target_sequence_number",
+ ),
+ keyvalues={},
+ )
+ )
+
+ link_map = _LinkMap()
+ for row in rows:
+ added = link_map.add_link(
+ (row["origin_chain_id"], row["origin_sequence_number"]),
+ (row["target_chain_id"], row["target_sequence_number"]),
+ )
+
+ # We shouldn't have persisted any redundant links
+ self.assertTrue(added)
+
+ return chain_map, link_map
+
+
+class LinkMapTestCase(unittest.TestCase):
+ def test_simple(self):
+ """Basic tests for the LinkMap.
+ """
+ link_map = _LinkMap()
+
+ link_map.add_link((1, 1), (2, 1), new=False)
+ self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1)])
+ self.assertCountEqual(link_map.get_links_from((1, 1)), [(2, 1)])
+ self.assertCountEqual(link_map.get_additions(), [])
+ self.assertTrue(link_map.exists_path_from((1, 5), (2, 1)))
+ self.assertFalse(link_map.exists_path_from((1, 5), (2, 2)))
+ self.assertTrue(link_map.exists_path_from((1, 5), (1, 1)))
+ self.assertFalse(link_map.exists_path_from((1, 1), (1, 5)))
+
+ # Attempting to add a redundant link is ignored.
+ self.assertFalse(link_map.add_link((1, 4), (2, 1)))
+ self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1)])
+
+ # Adding new non-redundant links works
+ self.assertTrue(link_map.add_link((1, 3), (2, 3)))
+ self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1), (3, 3)])
+
+ self.assertTrue(link_map.add_link((2, 5), (1, 3)))
+ self.assertCountEqual(link_map.get_links_between(2, 1), [(5, 3)])
+ self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1), (3, 3)])
+
+ self.assertCountEqual(link_map.get_additions(), [(1, 3, 2, 3), (2, 5, 1, 3)])
+
+
+class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
+
+ servlets = [
+ admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+ self.user_id = self.register_user("foo", "pass")
+ self.token = self.login("foo", "pass")
+ self.requester = create_requester(self.user_id)
+
+ def _generate_room(self) -> Tuple[str, List[Set[str]]]:
+ """Insert a room without a chain cover index.
+ """
+ room_id = self.helper.create_room_as(self.user_id, tok=self.token)
+
+ # Mark the room as not having a chain cover index
+ self.get_success(
+ self.store.db_pool.simple_update(
+ table="rooms",
+ keyvalues={"room_id": room_id},
+ updatevalues={"has_auth_chain_index": False},
+ desc="test",
+ )
+ )
+
+ # Create a fork in the DAG with different events.
+ event_handler = self.hs.get_event_creation_handler()
+ latest_event_ids = self.get_success(
+ self.store.get_prev_events_for_room(room_id)
+ )
+ event, context = self.get_success(
+ event_handler.create_event(
+ self.requester,
+ {
+ "type": "some_state_type",
+ "state_key": "",
+ "content": {},
+ "room_id": room_id,
+ "sender": self.user_id,
+ },
+ prev_event_ids=latest_event_ids,
+ )
+ )
+ self.get_success(
+ event_handler.handle_new_client_event(self.requester, event, context)
+ )
+ state1 = set(self.get_success(context.get_current_state_ids()).values())
+
+ event, context = self.get_success(
+ event_handler.create_event(
+ self.requester,
+ {
+ "type": "some_state_type",
+ "state_key": "",
+ "content": {},
+ "room_id": room_id,
+ "sender": self.user_id,
+ },
+ prev_event_ids=latest_event_ids,
+ )
+ )
+ self.get_success(
+ event_handler.handle_new_client_event(self.requester, event, context)
+ )
+ state2 = set(self.get_success(context.get_current_state_ids()).values())
+
+ # Delete the chain cover info.
+
+ def _delete_tables(txn):
+ txn.execute("DELETE FROM event_auth_chains")
+ txn.execute("DELETE FROM event_auth_chain_links")
+
+ self.get_success(self.store.db_pool.runInteraction("test", _delete_tables))
+
+ return room_id, [state1, state2]
+
+ def test_background_update_single_room(self):
+ """Test that the background update to calculate auth chains for historic
+ rooms works correctly.
+ """
+
+ # Create a room
+ room_id, states = self._generate_room()
+
+ # Insert and run the background update.
+ self.get_success(
+ self.store.db_pool.simple_insert(
+ "background_updates",
+ {"update_name": "chain_cover", "progress_json": "{}"},
+ )
+ )
+
+ # Ugh, have to reset this flag
+ self.store.db_pool.updates._all_done = False
+
+ while not self.get_success(
+ self.store.db_pool.updates.has_completed_background_updates()
+ ):
+ self.get_success(
+ self.store.db_pool.updates.do_next_background_update(100), by=0.1
+ )
+
+ # Test that the `has_auth_chain_index` has been set
+ self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id)))
+
+ # Test that calculating the auth chain difference using the newly
+ # calculated chain cover works.
+ self.get_success(
+ self.store.db_pool.runInteraction(
+ "test",
+ self.store._get_auth_chain_difference_using_cover_index_txn,
+ room_id,
+ states,
+ )
+ )
+
+ def test_background_update_multiple_rooms(self):
+ """Test that the background update to calculate auth chains for historic
+ rooms works correctly.
+ """
+ # Create a room
+ room_id1, states1 = self._generate_room()
+ room_id2, states2 = self._generate_room()
+ room_id3, states2 = self._generate_room()
+
+ # Insert and run the background update.
+ self.get_success(
+ self.store.db_pool.simple_insert(
+ "background_updates",
+ {"update_name": "chain_cover", "progress_json": "{}"},
+ )
+ )
+
+ # Ugh, have to reset this flag
+ self.store.db_pool.updates._all_done = False
+
+ while not self.get_success(
+ self.store.db_pool.updates.has_completed_background_updates()
+ ):
+ self.get_success(
+ self.store.db_pool.updates.do_next_background_update(100), by=0.1
+ )
+
+ # Test that the `has_auth_chain_index` has been set
+ self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id1)))
+ self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id2)))
+ self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id3)))
+
+ # Test that calculating the auth chain difference using the newly
+ # calculated chain cover works.
+ self.get_success(
+ self.store.db_pool.runInteraction(
+ "test",
+ self.store._get_auth_chain_difference_using_cover_index_txn,
+ room_id1,
+ states1,
+ )
+ )
+
+ def test_background_update_single_large_room(self):
+ """Test that the background update to calculate auth chains for historic
+ rooms works correctly.
+ """
+
+ # Create a room
+ room_id, states = self._generate_room()
+
+ # Add a bunch of state so that it takes multiple iterations of the
+ # background update to process the room.
+ for i in range(0, 150):
+ self.helper.send_state(
+ room_id, event_type="m.test", body={"index": i}, tok=self.token
+ )
+
+ # Insert and run the background update.
+ self.get_success(
+ self.store.db_pool.simple_insert(
+ "background_updates",
+ {"update_name": "chain_cover", "progress_json": "{}"},
+ )
+ )
+
+ # Ugh, have to reset this flag
+ self.store.db_pool.updates._all_done = False
+
+ iterations = 0
+ while not self.get_success(
+ self.store.db_pool.updates.has_completed_background_updates()
+ ):
+ iterations += 1
+ self.get_success(
+ self.store.db_pool.updates.do_next_background_update(100), by=0.1
+ )
+
+ # Ensure that we did actually take multiple iterations to process the
+ # room.
+ self.assertGreater(iterations, 1)
+
+ # Test that the `has_auth_chain_index` has been set
+ self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id)))
+
+ # Test that calculating the auth chain difference using the newly
+ # calculated chain cover works.
+ self.get_success(
+ self.store.db_pool.runInteraction(
+ "test",
+ self.store._get_auth_chain_difference_using_cover_index_txn,
+ room_id,
+ states,
+ )
+ )
+
+ def test_background_update_multiple_large_room(self):
+ """Test that the background update to calculate auth chains for historic
+ rooms works correctly.
+ """
+
+ # Create the rooms
+ room_id1, _ = self._generate_room()
+ room_id2, _ = self._generate_room()
+
+ # Add a bunch of state so that it takes multiple iterations of the
+ # background update to process the room.
+ for i in range(0, 150):
+ self.helper.send_state(
+ room_id1, event_type="m.test", body={"index": i}, tok=self.token
+ )
+
+ for i in range(0, 150):
+ self.helper.send_state(
+ room_id2, event_type="m.test", body={"index": i}, tok=self.token
+ )
+
+ # Insert and run the background update.
+ self.get_success(
+ self.store.db_pool.simple_insert(
+ "background_updates",
+ {"update_name": "chain_cover", "progress_json": "{}"},
+ )
+ )
+
+ # Ugh, have to reset this flag
+ self.store.db_pool.updates._all_done = False
+
+ iterations = 0
+ while not self.get_success(
+ self.store.db_pool.updates.has_completed_background_updates()
+ ):
+ iterations += 1
+ self.get_success(
+ self.store.db_pool.updates.do_next_background_update(100), by=0.1
+ )
+
+ # Ensure that we did actually take multiple iterations to process the
+ # room.
+ self.assertGreater(iterations, 1)
+
+ # Test that the `has_auth_chain_index` has been set
+ self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id1)))
+ self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id2)))
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index d4c3b867e3..9d04a066d8 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -13,6 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import attr
+from parameterized import parameterized
+
+from synapse.events import _EventInternalMetadata
+
import tests.unittest
import tests.utils
@@ -113,7 +118,8 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
r = self.get_success(self.store.get_rooms_with_many_extremities(5, 1, [room1]))
self.assertTrue(r == [room2] or r == [room3])
- def test_auth_difference(self):
+ @parameterized.expand([(True,), (False,)])
+ def test_auth_difference(self, use_chain_cover_index: bool):
room_id = "@ROOM:local"
# The silly auth graph we use to test the auth difference algorithm,
@@ -159,77 +165,279 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
"j": 1,
}
+ # Mark the room as not having a cover index
+
+ def store_room(txn):
+ self.store.db_pool.simple_insert_txn(
+ txn,
+ "rooms",
+ {
+ "room_id": room_id,
+ "creator": "room_creator_user_id",
+ "is_public": True,
+ "room_version": "6",
+ "has_auth_chain_index": use_chain_cover_index,
+ },
+ )
+
+ self.get_success(self.store.db_pool.runInteraction("store_room", store_room))
+
# We rudely fiddle with the appropriate tables directly, as that's much
# easier than constructing events properly.
- def insert_event(txn, event_id, stream_ordering):
+ def insert_event(txn):
+ stream_ordering = 0
+
+ for event_id in auth_graph:
+ stream_ordering += 1
+ depth = depth_map[event_id]
+
+ self.store.db_pool.simple_insert_txn(
+ txn,
+ table="events",
+ values={
+ "event_id": event_id,
+ "room_id": room_id,
+ "depth": depth,
+ "topological_ordering": depth,
+ "type": "m.test",
+ "processed": True,
+ "outlier": False,
+ "stream_ordering": stream_ordering,
+ },
+ )
+
+ self.hs.datastores.persist_events._persist_event_auth_chain_txn(
+ txn,
+ [
+ FakeEvent(event_id, room_id, auth_graph[event_id])
+ for event_id in auth_graph
+ ],
+ )
+
+ self.get_success(self.store.db_pool.runInteraction("insert", insert_event,))
+
+ # Now actually test that various combinations give the right result:
+
+ difference = self.get_success(
+ self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}])
+ )
+ self.assertSetEqual(difference, {"a", "b"})
+
+ difference = self.get_success(
+ self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"c"}])
+ )
+ self.assertSetEqual(difference, {"a", "b", "c", "e", "f"})
+
+ difference = self.get_success(
+ self.store.get_auth_chain_difference(room_id, [{"a", "c"}, {"b"}])
+ )
+ self.assertSetEqual(difference, {"a", "b", "c"})
+
+ difference = self.get_success(
+ self.store.get_auth_chain_difference(room_id, [{"a", "c"}, {"b", "c"}])
+ )
+ self.assertSetEqual(difference, {"a", "b"})
+
+ difference = self.get_success(
+ self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"d"}])
+ )
+ self.assertSetEqual(difference, {"a", "b", "d", "e"})
+
+ difference = self.get_success(
+ self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"c"}, {"d"}])
+ )
+ self.assertSetEqual(difference, {"a", "b", "c", "d", "e", "f"})
+
+ difference = self.get_success(
+ self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"e"}])
+ )
+ self.assertSetEqual(difference, {"a", "b"})
+
+ difference = self.get_success(
+ self.store.get_auth_chain_difference(room_id, [{"a"}])
+ )
+ self.assertSetEqual(difference, set())
+
+ def test_auth_difference_partial_cover(self):
+ """Test that we correctly handle rooms where not all events have a chain
+ cover calculated. This can happen in some obscure edge cases, including
+ during the background update that calculates the chain cover for old
+ rooms.
+ """
+
+ room_id = "@ROOM:local"
+
+ # The silly auth graph we use to test the auth difference algorithm,
+ # where the top are the most recent events.
+ #
+ # A B
+ # \ /
+ # D E
+ # \ |
+ # ` F C
+ # | /|
+ # G ´ |
+ # | \ |
+ # H I
+ # | |
+ # K J
+
+ auth_graph = {
+ "a": ["e"],
+ "b": ["e"],
+ "c": ["g", "i"],
+ "d": ["f"],
+ "e": ["f"],
+ "f": ["g"],
+ "g": ["h", "i"],
+ "h": ["k"],
+ "i": ["j"],
+ "k": [],
+ "j": [],
+ }
- depth = depth_map[event_id]
+ depth_map = {
+ "a": 7,
+ "b": 7,
+ "c": 4,
+ "d": 6,
+ "e": 6,
+ "f": 5,
+ "g": 3,
+ "h": 2,
+ "i": 2,
+ "k": 1,
+ "j": 1,
+ }
+
+ # We rudely fiddle with the appropriate tables directly, as that's much
+ # easier than constructing events properly.
+ def insert_event(txn):
+ # First insert the room and mark it as having a chain cover.
self.store.db_pool.simple_insert_txn(
txn,
- table="events",
- values={
- "event_id": event_id,
+ "rooms",
+ {
"room_id": room_id,
- "depth": depth,
- "topological_ordering": depth,
- "type": "m.test",
- "processed": True,
- "outlier": False,
- "stream_ordering": stream_ordering,
+ "creator": "room_creator_user_id",
+ "is_public": True,
+ "room_version": "6",
+ "has_auth_chain_index": True,
},
)
- self.store.db_pool.simple_insert_many_txn(
+ stream_ordering = 0
+
+ for event_id in auth_graph:
+ stream_ordering += 1
+ depth = depth_map[event_id]
+
+ self.store.db_pool.simple_insert_txn(
+ txn,
+ table="events",
+ values={
+ "event_id": event_id,
+ "room_id": room_id,
+ "depth": depth,
+ "topological_ordering": depth,
+ "type": "m.test",
+ "processed": True,
+ "outlier": False,
+ "stream_ordering": stream_ordering,
+ },
+ )
+
+ # Insert all events apart from 'B'
+ self.hs.datastores.persist_events._persist_event_auth_chain_txn(
txn,
- table="event_auth",
- values=[
- {"event_id": event_id, "room_id": room_id, "auth_id": a}
- for a in auth_graph[event_id]
+ [
+ FakeEvent(event_id, room_id, auth_graph[event_id])
+ for event_id in auth_graph
+ if event_id != "b"
],
)
- next_stream_ordering = 0
- for event_id in auth_graph:
- next_stream_ordering += 1
- self.get_success(
- self.store.db_pool.runInteraction(
- "insert", insert_event, event_id, next_stream_ordering
- )
+ # Now we insert the event 'B' without a chain cover, by temporarily
+ # pretending the room doesn't have a chain cover.
+
+ self.store.db_pool.simple_update_txn(
+ txn,
+ table="rooms",
+ keyvalues={"room_id": room_id},
+ updatevalues={"has_auth_chain_index": False},
+ )
+
+ self.hs.datastores.persist_events._persist_event_auth_chain_txn(
+ txn, [FakeEvent("b", room_id, auth_graph["b"])],
+ )
+
+ self.store.db_pool.simple_update_txn(
+ txn,
+ table="rooms",
+ keyvalues={"room_id": room_id},
+ updatevalues={"has_auth_chain_index": True},
)
+ self.get_success(self.store.db_pool.runInteraction("insert", insert_event,))
+
# Now actually test that various combinations give the right result:
difference = self.get_success(
- self.store.get_auth_chain_difference([{"a"}, {"b"}])
+ self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}])
)
self.assertSetEqual(difference, {"a", "b"})
difference = self.get_success(
- self.store.get_auth_chain_difference([{"a"}, {"b"}, {"c"}])
+ self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"c"}])
)
self.assertSetEqual(difference, {"a", "b", "c", "e", "f"})
difference = self.get_success(
- self.store.get_auth_chain_difference([{"a", "c"}, {"b"}])
+ self.store.get_auth_chain_difference(room_id, [{"a", "c"}, {"b"}])
)
self.assertSetEqual(difference, {"a", "b", "c"})
difference = self.get_success(
- self.store.get_auth_chain_difference([{"a"}, {"b"}, {"d"}])
+ self.store.get_auth_chain_difference(room_id, [{"a", "c"}, {"b", "c"}])
+ )
+ self.assertSetEqual(difference, {"a", "b"})
+
+ difference = self.get_success(
+ self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"d"}])
)
self.assertSetEqual(difference, {"a", "b", "d", "e"})
difference = self.get_success(
- self.store.get_auth_chain_difference([{"a"}, {"b"}, {"c"}, {"d"}])
+ self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"c"}, {"d"}])
)
self.assertSetEqual(difference, {"a", "b", "c", "d", "e", "f"})
difference = self.get_success(
- self.store.get_auth_chain_difference([{"a"}, {"b"}, {"e"}])
+ self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"e"}])
)
self.assertSetEqual(difference, {"a", "b"})
- difference = self.get_success(self.store.get_auth_chain_difference([{"a"}]))
+ difference = self.get_success(
+ self.store.get_auth_chain_difference(room_id, [{"a"}])
+ )
self.assertSetEqual(difference, set())
+
+
+@attr.s
+class FakeEvent:
+ event_id = attr.ib()
+ room_id = attr.ib()
+ auth_events = attr.ib()
+
+ type = "foo"
+ state_key = "foo"
+
+ internal_metadata = _EventInternalMetadata({})
+
+ def auth_event_ids(self):
+ return self.auth_events
+
+ def is_state(self):
+ return True
diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py
new file mode 100644
index 0000000000..71210ce606
--- /dev/null
+++ b/tests/storage/test_events.py
@@ -0,0 +1,334 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from synapse.api.constants import EventTypes, Membership
+from synapse.api.room_versions import RoomVersions
+from synapse.federation.federation_base import event_from_pdu_json
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+
+from tests.unittest import HomeserverTestCase
+
+
+class ExtremPruneTestCase(HomeserverTestCase):
+ servlets = [
+ admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, homeserver):
+ self.state = self.hs.get_state_handler()
+ self.persistence = self.hs.get_storage().persistence
+ self.store = self.hs.get_datastore()
+
+ self.register_user("user", "pass")
+ self.token = self.login("user", "pass")
+
+ self.room_id = self.helper.create_room_as(
+ "user", room_version=RoomVersions.V6.identifier, tok=self.token
+ )
+
+ body = self.helper.send(self.room_id, body="Test", tok=self.token)
+ local_message_event_id = body["event_id"]
+
+ # Fudge a remote event and persist it. This will be the extremity before
+ # the gap.
+ self.remote_event_1 = event_from_pdu_json(
+ {
+ "type": EventTypes.Message,
+ "state_key": "@user:other",
+ "content": {},
+ "room_id": self.room_id,
+ "sender": "@user:other",
+ "depth": 5,
+ "prev_events": [local_message_event_id],
+ "auth_events": [],
+ "origin_server_ts": self.clock.time_msec(),
+ },
+ RoomVersions.V6,
+ )
+
+ self.persist_event(self.remote_event_1)
+
+ # Check that the current extremities is the remote event.
+ self.assert_extremities([self.remote_event_1.event_id])
+
+ def persist_event(self, event, state=None):
+ """Persist the event, with optional state
+ """
+ context = self.get_success(
+ self.state.compute_event_context(event, old_state=state)
+ )
+ self.get_success(self.persistence.persist_event(event, context))
+
+ def assert_extremities(self, expected_extremities):
+ """Assert the current extremities for the room
+ """
+ extremities = self.get_success(
+ self.store.get_prev_events_for_room(self.room_id)
+ )
+ self.assertCountEqual(extremities, expected_extremities)
+
+ def test_prune_gap(self):
+ """Test that we drop extremities after a gap when we see an event from
+ the same domain.
+ """
+
+ # Fudge a second event which points to an event we don't have. This is a
+ # state event so that the state changes (otherwise we won't prune the
+ # extremity as they'll have the same state group).
+ remote_event_2 = event_from_pdu_json(
+ {
+ "type": EventTypes.Member,
+ "state_key": "@user:other",
+ "content": {"membership": Membership.JOIN},
+ "room_id": self.room_id,
+ "sender": "@user:other",
+ "depth": 50,
+ "prev_events": ["$some_unknown_message"],
+ "auth_events": [],
+ "origin_server_ts": self.clock.time_msec(),
+ },
+ RoomVersions.V6,
+ )
+
+ state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
+
+ self.persist_event(remote_event_2, state=state_before_gap.values())
+
+ # Check the new extremity is just the new remote event.
+ self.assert_extremities([remote_event_2.event_id])
+
+ def test_do_not_prune_gap_if_state_different(self):
+ """Test that we don't prune extremities after a gap if the resolved
+ state is different.
+ """
+
+ # Fudge a second event which points to an event we don't have.
+ remote_event_2 = event_from_pdu_json(
+ {
+ "type": EventTypes.Message,
+ "state_key": "@user:other",
+ "content": {},
+ "room_id": self.room_id,
+ "sender": "@user:other",
+ "depth": 10,
+ "prev_events": ["$some_unknown_message"],
+ "auth_events": [],
+ "origin_server_ts": self.clock.time_msec(),
+ },
+ RoomVersions.V6,
+ )
+
+ # Now we persist it with state with a dropped history visibility
+ # setting. The state resolution across the old and new event will then
+ # include it, and so the resolved state won't match the new state.
+ state_before_gap = dict(
+ self.get_success(self.state.get_current_state(self.room_id))
+ )
+ state_before_gap.pop(("m.room.history_visibility", ""))
+
+ context = self.get_success(
+ self.state.compute_event_context(
+ remote_event_2, old_state=state_before_gap.values()
+ )
+ )
+
+ self.get_success(self.persistence.persist_event(remote_event_2, context))
+
+ # Check that we haven't dropped the old extremity.
+ self.assert_extremities([self.remote_event_1.event_id, remote_event_2.event_id])
+
+ def test_prune_gap_if_old(self):
+ """Test that we drop extremities after a gap when the previous extremity
+ is "old"
+ """
+
+ # Advance the clock for many days to make the old extremity "old". We
+ # also set the depth to "lots".
+ self.reactor.advance(7 * 24 * 60 * 60)
+
+ # Fudge a second event which points to an event we don't have. This is a
+ # state event so that the state changes (otherwise we won't prune the
+ # extremity as they'll have the same state group).
+ remote_event_2 = event_from_pdu_json(
+ {
+ "type": EventTypes.Member,
+ "state_key": "@user:other2",
+ "content": {"membership": Membership.JOIN},
+ "room_id": self.room_id,
+ "sender": "@user:other2",
+ "depth": 10000,
+ "prev_events": ["$some_unknown_message"],
+ "auth_events": [],
+ "origin_server_ts": self.clock.time_msec(),
+ },
+ RoomVersions.V6,
+ )
+
+ state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
+
+ self.persist_event(remote_event_2, state=state_before_gap.values())
+
+ # Check the new extremity is just the new remote event.
+ self.assert_extremities([remote_event_2.event_id])
+
+ def test_do_not_prune_gap_if_other_server(self):
+ """Test that we do not drop extremities after a gap when we see an event
+ from a different domain.
+ """
+
+ # Fudge a second event which points to an event we don't have. This is a
+ # state event so that the state changes (otherwise we won't prune the
+ # extremity as they'll have the same state group).
+ remote_event_2 = event_from_pdu_json(
+ {
+ "type": EventTypes.Member,
+ "state_key": "@user:other2",
+ "content": {"membership": Membership.JOIN},
+ "room_id": self.room_id,
+ "sender": "@user:other2",
+ "depth": 10,
+ "prev_events": ["$some_unknown_message"],
+ "auth_events": [],
+ "origin_server_ts": self.clock.time_msec(),
+ },
+ RoomVersions.V6,
+ )
+
+ state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
+
+ self.persist_event(remote_event_2, state=state_before_gap.values())
+
+ # Check the new extremity is just the new remote event.
+ self.assert_extremities([self.remote_event_1.event_id, remote_event_2.event_id])
+
+ def test_prune_gap_if_dummy_remote(self):
+ """Test that we drop extremities after a gap when the previous extremity
+ is a local dummy event and only points to remote events.
+ """
+
+ body = self.helper.send_event(
+ self.room_id, type=EventTypes.Dummy, content={}, tok=self.token
+ )
+ local_message_event_id = body["event_id"]
+ self.assert_extremities([local_message_event_id])
+
+ # Advance the clock for many days to make the old extremity "old". We
+ # also set the depth to "lots".
+ self.reactor.advance(7 * 24 * 60 * 60)
+
+ # Fudge a second event which points to an event we don't have. This is a
+ # state event so that the state changes (otherwise we won't prune the
+ # extremity as they'll have the same state group).
+ remote_event_2 = event_from_pdu_json(
+ {
+ "type": EventTypes.Member,
+ "state_key": "@user:other2",
+ "content": {"membership": Membership.JOIN},
+ "room_id": self.room_id,
+ "sender": "@user:other2",
+ "depth": 10000,
+ "prev_events": ["$some_unknown_message"],
+ "auth_events": [],
+ "origin_server_ts": self.clock.time_msec(),
+ },
+ RoomVersions.V6,
+ )
+
+ state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
+
+ self.persist_event(remote_event_2, state=state_before_gap.values())
+
+ # Check the new extremity is just the new remote event.
+ self.assert_extremities([remote_event_2.event_id])
+
+ def test_prune_gap_if_dummy_local(self):
+ """Test that we don't drop extremities after a gap when the previous
+ extremity is a local dummy event and points to local events.
+ """
+
+ body = self.helper.send(self.room_id, body="Test", tok=self.token)
+
+ body = self.helper.send_event(
+ self.room_id, type=EventTypes.Dummy, content={}, tok=self.token
+ )
+ local_message_event_id = body["event_id"]
+ self.assert_extremities([local_message_event_id])
+
+ # Advance the clock for many days to make the old extremity "old". We
+ # also set the depth to "lots".
+ self.reactor.advance(7 * 24 * 60 * 60)
+
+ # Fudge a second event which points to an event we don't have. This is a
+ # state event so that the state changes (otherwise we won't prune the
+ # extremity as they'll have the same state group).
+ remote_event_2 = event_from_pdu_json(
+ {
+ "type": EventTypes.Member,
+ "state_key": "@user:other2",
+ "content": {"membership": Membership.JOIN},
+ "room_id": self.room_id,
+ "sender": "@user:other2",
+ "depth": 10000,
+ "prev_events": ["$some_unknown_message"],
+ "auth_events": [],
+ "origin_server_ts": self.clock.time_msec(),
+ },
+ RoomVersions.V6,
+ )
+
+ state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
+
+ self.persist_event(remote_event_2, state=state_before_gap.values())
+
+ # Check the new extremity is just the new remote event.
+ self.assert_extremities([remote_event_2.event_id, local_message_event_id])
+
+ def test_do_not_prune_gap_if_not_dummy(self):
+ """Test that we do not drop extremities after a gap when the previous extremity
+ is not a dummy event.
+ """
+
+ body = self.helper.send(self.room_id, body="test", tok=self.token)
+ local_message_event_id = body["event_id"]
+ self.assert_extremities([local_message_event_id])
+
+ # Fudge a second event which points to an event we don't have. This is a
+ # state event so that the state changes (otherwise we won't prune the
+ # extremity as they'll have the same state group).
+ remote_event_2 = event_from_pdu_json(
+ {
+ "type": EventTypes.Member,
+ "state_key": "@user:other2",
+ "content": {"membership": Membership.JOIN},
+ "room_id": self.room_id,
+ "sender": "@user:other2",
+ "depth": 10000,
+ "prev_events": ["$some_unknown_message"],
+ "auth_events": [],
+ "origin_server_ts": self.clock.time_msec(),
+ },
+ RoomVersions.V6,
+ )
+
+ state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
+
+ self.persist_event(remote_event_2, state=state_before_gap.values())
+
+ # Check the new extremity is just the new remote event.
+ self.assert_extremities([local_message_event_id, remote_event_2.event_id])
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index cc0612cf65..3e2fd4da01 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -51,9 +51,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.db_pool,
stream_name="test_stream",
instance_name=instance_name,
- table="foobar",
- instance_column="instance_name",
- id_column="stream_id",
+ tables=[("foobar", "instance_name", "stream_id")],
sequence_name="foobar_seq",
writers=writers,
)
@@ -487,9 +485,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.db_pool,
stream_name="test_stream",
instance_name=instance_name,
- table="foobar",
- instance_column="instance_name",
- id_column="stream_id",
+ tables=[("foobar", "instance_name", "stream_id")],
sequence_name="foobar_seq",
writers=writers,
positive=False,
@@ -579,3 +575,107 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(id_gen_2.get_positions(), {"first": -1, "second": -2})
self.assertEqual(id_gen_1.get_persisted_upto_position(), -2)
self.assertEqual(id_gen_2.get_persisted_upto_position(), -2)
+
+
+class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase):
+ if not USE_POSTGRES_FOR_TESTS:
+ skip = "Requires Postgres"
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+ self.db_pool = self.store.db_pool # type: DatabasePool
+
+ self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
+
+ def _setup_db(self, txn):
+ txn.execute("CREATE SEQUENCE foobar_seq")
+ txn.execute(
+ """
+ CREATE TABLE foobar1 (
+ stream_id BIGINT NOT NULL,
+ instance_name TEXT NOT NULL,
+ data TEXT
+ );
+ """
+ )
+
+ txn.execute(
+ """
+ CREATE TABLE foobar2 (
+ stream_id BIGINT NOT NULL,
+ instance_name TEXT NOT NULL,
+ data TEXT
+ );
+ """
+ )
+
+ def _create_id_generator(
+ self, instance_name="master", writers=["master"]
+ ) -> MultiWriterIdGenerator:
+ def _create(conn):
+ return MultiWriterIdGenerator(
+ conn,
+ self.db_pool,
+ stream_name="test_stream",
+ instance_name=instance_name,
+ tables=[
+ ("foobar1", "instance_name", "stream_id"),
+ ("foobar2", "instance_name", "stream_id"),
+ ],
+ sequence_name="foobar_seq",
+ writers=writers,
+ )
+
+ return self.get_success_or_raise(self.db_pool.runWithConnection(_create))
+
+ def _insert_rows(
+ self,
+ table: str,
+ instance_name: str,
+ number: int,
+ update_stream_table: bool = True,
+ ):
+ """Insert N rows as the given instance, inserting with stream IDs pulled
+ from the postgres sequence.
+ """
+
+ def _insert(txn):
+ for _ in range(number):
+ txn.execute(
+ "INSERT INTO %s VALUES (nextval('foobar_seq'), ?)" % (table,),
+ (instance_name,),
+ )
+ if update_stream_table:
+ txn.execute(
+ """
+ INSERT INTO stream_positions VALUES ('test_stream', ?, lastval())
+ ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = lastval()
+ """,
+ (instance_name,),
+ )
+
+ self.get_success(self.db_pool.runInteraction("_insert_rows", _insert))
+
+ def test_load_existing_stream(self):
+ """Test creating ID gens with multiple tables that have rows from after
+ the position in `stream_positions` table.
+ """
+ self._insert_rows("foobar1", "first", 3)
+ self._insert_rows("foobar2", "second", 3)
+ self._insert_rows("foobar2", "second", 1, update_stream_table=False)
+
+ first_id_gen = self._create_id_generator("first", writers=["first", "second"])
+ second_id_gen = self._create_id_generator("second", writers=["first", "second"])
+
+ # The first ID gen will notice that it can advance its token to 7 as it
+ # has no in progress writes...
+ self.assertEqual(first_id_gen.get_positions(), {"first": 7, "second": 6})
+ self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7)
+ self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 6)
+ self.assertEqual(first_id_gen.get_persisted_upto_position(), 7)
+
+ # ... but the second ID gen doesn't know that.
+ self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7})
+ self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 3)
+ self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7)
+ self.assertEqual(first_id_gen.get_persisted_upto_position(), 7)
diff --git a/tests/storage/test_main.py b/tests/storage/test_main.py
index 7e7f1286d9..e9e3bca3bf 100644
--- a/tests/storage/test_main.py
+++ b/tests/storage/test_main.py
@@ -48,3 +48,10 @@ class DataStoreTestCase(unittest.TestCase):
self.assertEquals(1, total)
self.assertEquals(self.displayname, users.pop()["displayname"])
+
+ users, total = yield defer.ensureDeferred(
+ self.store.get_users_paginate(0, 10, name="BC", guests=False)
+ )
+
+ self.assertEquals(1, total)
+ self.assertEquals(self.displayname, users.pop()["displayname"])
diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py
index 3fd0a38cf5..ea63bd56b4 100644
--- a/tests/storage/test_profile.py
+++ b/tests/storage/test_profile.py
@@ -48,6 +48,19 @@ class ProfileStoreTestCase(unittest.TestCase):
),
)
+ # test set to None
+ yield defer.ensureDeferred(
+ self.store.set_profile_displayname(self.u_frank.localpart, None)
+ )
+
+ self.assertIsNone(
+ (
+ yield defer.ensureDeferred(
+ self.store.get_profile_displayname(self.u_frank.localpart)
+ )
+ )
+ )
+
@defer.inlineCallbacks
def test_avatar_url(self):
yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart))
@@ -66,3 +79,16 @@ class ProfileStoreTestCase(unittest.TestCase):
)
),
)
+
+ # test set to None
+ yield defer.ensureDeferred(
+ self.store.set_profile_avatar_url(self.u_frank.localpart, None)
+ )
+
+ self.assertIsNone(
+ (
+ yield defer.ensureDeferred(
+ self.store.get_profile_avatar_url(self.u_frank.localpart)
+ )
+ )
+ )
diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py
index cc1f3c53c5..a06ad2c03e 100644
--- a/tests/storage/test_purge.py
+++ b/tests/storage/test_purge.py
@@ -27,7 +27,7 @@ class PurgeTests(HomeserverTestCase):
servlets = [room.register_servlets]
def make_homeserver(self, reactor, clock):
- hs = self.setup_test_homeserver("server", http_client=None)
+ hs = self.setup_test_homeserver("server", federation_http_client=None)
return hs
def prepare(self, reactor, clock, hs):
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index d4f9e809db..a6303bf0ee 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -14,9 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
-from mock import Mock
-
from canonicaljson import json
from twisted.internet import defer
@@ -30,12 +27,10 @@ from tests.utils import create_room
class RedactionTestCase(unittest.HomeserverTestCase):
- def make_homeserver(self, reactor, clock):
- config = self.default_config()
+ def default_config(self):
+ config = super().default_config()
config["redaction_retention_period"] = "30d"
- return self.setup_test_homeserver(
- resource_for_federation=Mock(), http_client=None, config=config
- )
+ return config
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index ff972daeaa..d2aed66f6d 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -14,8 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from unittest.mock import Mock
-
from synapse.api.constants import Membership
from synapse.rest.admin import register_servlets_for_client_rest_resource
from synapse.rest.client.v1 import login, room
@@ -34,12 +32,6 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
room.register_servlets,
]
- def make_homeserver(self, reactor, clock):
- hs = self.setup_test_homeserver(
- resource_for_federation=Mock(), http_client=None
- )
- return hs
-
def prepare(self, reactor, clock, hs: TestHomeServer):
# We can't test the RoomMemberStore on its own without the other event
diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index 738e912468..a6f63f4aaf 100644
--- a/tests/storage/test_user_directory.py
+++ b/tests/storage/test_user_directory.py
@@ -21,6 +21,8 @@ from tests.utils import setup_test_homeserver
ALICE = "@alice:a"
BOB = "@bob:b"
BOBBY = "@bobby:a"
+# The localpart isn't 'Bela' on purpose so we can test looking up display names.
+BELA = "@somenickname:a"
class UserDirectoryStoreTestCase(unittest.TestCase):
@@ -41,6 +43,9 @@ class UserDirectoryStoreTestCase(unittest.TestCase):
self.store.update_profile_in_user_dir(BOBBY, "bobby", None)
)
yield defer.ensureDeferred(
+ self.store.update_profile_in_user_dir(BELA, "Bela", None)
+ )
+ yield defer.ensureDeferred(
self.store.add_users_in_public_rooms("!room:id", (ALICE, BOB))
)
@@ -72,3 +77,21 @@ class UserDirectoryStoreTestCase(unittest.TestCase):
)
finally:
self.hs.config.user_directory_search_all_users = False
+
+ @defer.inlineCallbacks
+ def test_search_user_dir_stop_words(self):
+ """Tests that a user can look up another user by searching for the start if its
+ display name even if that name happens to be a common English word that would
+ usually be ignored in full text searches.
+ """
+ self.hs.config.user_directory_search_all_users = True
+ try:
+ r = yield defer.ensureDeferred(self.store.search_user_dir(ALICE, "be", 10))
+ self.assertFalse(r["limited"])
+ self.assertEqual(1, len(r["results"]))
+ self.assertDictEqual(
+ r["results"][0],
+ {"user_id": BELA, "display_name": "Bela", "avatar_url": None},
+ )
+ finally:
+ self.hs.config.user_directory_search_all_users = False
diff --git a/tests/test_federation.py b/tests/test_federation.py
index 1ce4ea3a01..fc9aab32d0 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -37,7 +37,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
self.hs_clock = Clock(self.reactor)
self.homeserver = setup_test_homeserver(
self.addCleanup,
- http_client=self.http_client,
+ federation_http_client=self.http_client,
clock=self.hs_clock,
reactor=self.reactor,
)
@@ -134,7 +134,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
}
)
- with LoggingContext(request="lying_event"):
+ with LoggingContext():
failure = self.get_failure(
self.handler.on_receive_pdu(
"test.serv", lying_event, sent_to_us_directly=True
diff --git a/tests/test_mau.py b/tests/test_mau.py
index c5ec6396a7..51660b51d5 100644
--- a/tests/test_mau.py
+++ b/tests/test_mau.py
@@ -19,6 +19,7 @@ import json
from synapse.api.constants import LoginType
from synapse.api.errors import Codes, HttpResponseException, SynapseError
+from synapse.appservice import ApplicationService
from synapse.rest.client.v2_alpha import register, sync
from tests import unittest
@@ -75,6 +76,45 @@ class TestMauLimit(unittest.HomeserverTestCase):
self.assertEqual(e.code, 403)
self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
+ def test_as_ignores_mau(self):
+ """Test that application services can still create users when the MAU
+ limit has been reached. This only works when application service
+ user ip tracking is disabled.
+ """
+
+ # Create and sync so that the MAU counts get updated
+ token1 = self.create_user("kermit1")
+ self.do_sync_for_user(token1)
+ token2 = self.create_user("kermit2")
+ self.do_sync_for_user(token2)
+
+ # check we're testing what we think we are: there should be two active users
+ self.assertEqual(self.get_success(self.store.get_monthly_active_count()), 2)
+
+ # We've created and activated two users, we shouldn't be able to
+ # register new users
+ with self.assertRaises(SynapseError) as cm:
+ self.create_user("kermit3")
+
+ e = cm.exception
+ self.assertEqual(e.code, 403)
+ self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
+
+ # Cheekily add an application service that we use to register a new user
+ # with.
+ as_token = "foobartoken"
+ self.store.services_cache.append(
+ ApplicationService(
+ token=as_token,
+ hostname=self.hs.hostname,
+ id="SomeASID",
+ sender="@as_sender:test",
+ namespaces={"users": [{"regex": "@as_*", "exclusive": True}]},
+ )
+ )
+
+ self.create_user("as_kermit4", token=as_token)
+
def test_allowed_after_a_month_mau(self):
# Create and sync so that the MAU counts get updated
token1 = self.create_user("kermit1")
@@ -192,7 +232,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
self.reactor.advance(100)
self.assertEqual(2, self.successResultOf(count))
- def create_user(self, localpart):
+ def create_user(self, localpart, token=None):
request_data = json.dumps(
{
"username": localpart,
@@ -201,7 +241,9 @@ class TestMauLimit(unittest.HomeserverTestCase):
}
)
- request, channel = self.make_request("POST", "/register", request_data)
+ channel = self.make_request(
+ "POST", "/register", request_data, access_token=token,
+ )
if channel.code != 200:
raise HttpResponseException(
@@ -213,7 +255,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
return access_token
def do_sync_for_user(self, token):
- request, channel = self.make_request("GET", "/sync", access_token=token)
+ channel = self.make_request("GET", "/sync", access_token=token)
if channel.code != 200:
raise HttpResponseException(
diff --git a/tests/test_preview.py b/tests/test_preview.py
index 7f67ee9e1f..0c6cbbd921 100644
--- a/tests/test_preview.py
+++ b/tests/test_preview.py
@@ -20,8 +20,16 @@ from synapse.rest.media.v1.preview_url_resource import (
from . import unittest
+try:
+ import lxml
+except ImportError:
+ lxml = None
+
class PreviewTestCase(unittest.TestCase):
+ if not lxml:
+ skip = "url preview feature requires lxml"
+
def test_long_summarize(self):
example_paras = [
"""Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:
@@ -56,7 +64,7 @@ class PreviewTestCase(unittest.TestCase):
desc = summarize_paragraphs(example_paras, min_size=200, max_size=500)
- self.assertEquals(
+ self.assertEqual(
desc,
"Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:"
" Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in"
@@ -69,7 +77,7 @@ class PreviewTestCase(unittest.TestCase):
desc = summarize_paragraphs(example_paras[1:], min_size=200, max_size=500)
- self.assertEquals(
+ self.assertEqual(
desc,
"Tromsø lies in Northern Norway. The municipality has a population of"
" (2015) 72,066, but with an annual influx of students it has over 75,000"
@@ -96,7 +104,7 @@ class PreviewTestCase(unittest.TestCase):
desc = summarize_paragraphs(example_paras, min_size=200, max_size=500)
- self.assertEquals(
+ self.assertEqual(
desc,
"Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:"
" Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in"
@@ -122,7 +130,7 @@ class PreviewTestCase(unittest.TestCase):
]
desc = summarize_paragraphs(example_paras, min_size=200, max_size=500)
- self.assertEquals(
+ self.assertEqual(
desc,
"Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:"
" Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in"
@@ -137,6 +145,9 @@ class PreviewTestCase(unittest.TestCase):
class PreviewUrlTestCase(unittest.TestCase):
+ if not lxml:
+ skip = "url preview feature requires lxml"
+
def test_simple(self):
html = """
<html>
@@ -149,7 +160,7 @@ class PreviewUrlTestCase(unittest.TestCase):
og = decode_and_calc_og(html, "http://example.com/test.html")
- self.assertEquals(og, {"og:title": "Foo", "og:description": "Some text."})
+ self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
def test_comment(self):
html = """
@@ -164,7 +175,7 @@ class PreviewUrlTestCase(unittest.TestCase):
og = decode_and_calc_og(html, "http://example.com/test.html")
- self.assertEquals(og, {"og:title": "Foo", "og:description": "Some text."})
+ self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
def test_comment2(self):
html = """
@@ -182,7 +193,7 @@ class PreviewUrlTestCase(unittest.TestCase):
og = decode_and_calc_og(html, "http://example.com/test.html")
- self.assertEquals(
+ self.assertEqual(
og,
{
"og:title": "Foo",
@@ -203,7 +214,7 @@ class PreviewUrlTestCase(unittest.TestCase):
og = decode_and_calc_og(html, "http://example.com/test.html")
- self.assertEquals(og, {"og:title": "Foo", "og:description": "Some text."})
+ self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
def test_missing_title(self):
html = """
@@ -216,7 +227,7 @@ class PreviewUrlTestCase(unittest.TestCase):
og = decode_and_calc_og(html, "http://example.com/test.html")
- self.assertEquals(og, {"og:title": None, "og:description": "Some text."})
+ self.assertEqual(og, {"og:title": None, "og:description": "Some text."})
def test_h1_as_title(self):
html = """
@@ -230,7 +241,7 @@ class PreviewUrlTestCase(unittest.TestCase):
og = decode_and_calc_og(html, "http://example.com/test.html")
- self.assertEquals(og, {"og:title": "Title", "og:description": "Some text."})
+ self.assertEqual(og, {"og:title": "Title", "og:description": "Some text."})
def test_missing_title_and_broken_h1(self):
html = """
@@ -244,4 +255,38 @@ class PreviewUrlTestCase(unittest.TestCase):
og = decode_and_calc_og(html, "http://example.com/test.html")
- self.assertEquals(og, {"og:title": None, "og:description": "Some text."})
+ self.assertEqual(og, {"og:title": None, "og:description": "Some text."})
+
+ def test_empty(self):
+ html = ""
+ og = decode_and_calc_og(html, "http://example.com/test.html")
+ self.assertEqual(og, {})
+
+ def test_invalid_encoding(self):
+ """An invalid character encoding should be ignored and treated as UTF-8, if possible."""
+ html = """
+ <html>
+ <head><title>Foo</title></head>
+ <body>
+ Some text.
+ </body>
+ </html>
+ """
+ og = decode_and_calc_og(
+ html, "http://example.com/test.html", "invalid-encoding"
+ )
+ self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
+
+ def test_invalid_encoding2(self):
+ """A body which doesn't match the sent character encoding."""
+ # Note that this contains an invalid UTF-8 sequence in the title.
+ html = b"""
+ <html>
+ <head><title>\xff\xff Foo</title></head>
+ <body>
+ Some text.
+ </body>
+ </html>
+ """
+ og = decode_and_calc_og(html, "http://example.com/test.html")
+ self.assertEqual(og, {"og:title": "ÿÿ Foo", "og:description": "Some text."})
diff --git a/tests/test_server.py b/tests/test_server.py
index c387a85f2e..815da18e65 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -38,7 +38,10 @@ class JsonResourceTests(unittest.TestCase):
self.reactor = ThreadedMemoryReactorClock()
self.hs_clock = Clock(self.reactor)
self.homeserver = setup_test_homeserver(
- self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.reactor
+ self.addCleanup,
+ federation_http_client=None,
+ clock=self.hs_clock,
+ reactor=self.reactor,
)
def test_handler_for_request(self):
@@ -61,11 +64,10 @@ class JsonResourceTests(unittest.TestCase):
"test_servlet",
)
- request, channel = make_request(
+ make_request(
self.reactor, FakeSite(res), b"GET", b"/_matrix/foo/%E2%98%83?a=%E2%98%83"
)
- self.assertEqual(request.args, {b"a": ["\N{SNOWMAN}".encode("utf8")]})
self.assertEqual(got_kwargs, {"room_id": "\N{SNOWMAN}"})
def test_callback_direct_exception(self):
@@ -82,7 +84,7 @@ class JsonResourceTests(unittest.TestCase):
"GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
)
- _, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foo")
+ channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foo")
self.assertEqual(channel.result["code"], b"500")
@@ -106,7 +108,7 @@ class JsonResourceTests(unittest.TestCase):
"GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
)
- _, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foo")
+ channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foo")
self.assertEqual(channel.result["code"], b"500")
@@ -124,7 +126,7 @@ class JsonResourceTests(unittest.TestCase):
"GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
)
- _, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foo")
+ channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foo")
self.assertEqual(channel.result["code"], b"403")
self.assertEqual(channel.json_body["error"], "Forbidden!!one!")
@@ -146,9 +148,7 @@ class JsonResourceTests(unittest.TestCase):
"GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
)
- _, channel = make_request(
- self.reactor, FakeSite(res), b"GET", b"/_matrix/foobar"
- )
+ channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foobar")
self.assertEqual(channel.result["code"], b"400")
self.assertEqual(channel.json_body["error"], "Unrecognized request")
@@ -170,7 +170,7 @@ class JsonResourceTests(unittest.TestCase):
)
# The path was registered as GET, but this is a HEAD request.
- _, channel = make_request(self.reactor, FakeSite(res), b"HEAD", b"/_matrix/foo")
+ channel = make_request(self.reactor, FakeSite(res), b"HEAD", b"/_matrix/foo")
self.assertEqual(channel.result["code"], b"200")
self.assertNotIn("body", channel.result)
@@ -202,7 +202,7 @@ class OptionsResourceTests(unittest.TestCase):
)
# render the request and return the channel
- _, channel = make_request(self.reactor, site, method, path, shorthand=False)
+ channel = make_request(self.reactor, site, method, path, shorthand=False)
return channel
def test_unknown_options_request(self):
@@ -275,7 +275,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
res = WrapHtmlRequestHandlerTests.TestResource()
res.callback = callback
- _, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path")
+ channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path")
self.assertEqual(channel.result["code"], b"200")
body = channel.result["body"]
@@ -293,7 +293,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
res = WrapHtmlRequestHandlerTests.TestResource()
res.callback = callback
- _, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path")
+ channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path")
self.assertEqual(channel.result["code"], b"301")
headers = channel.result["headers"]
@@ -314,7 +314,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
res = WrapHtmlRequestHandlerTests.TestResource()
res.callback = callback
- _, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path")
+ channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path")
self.assertEqual(channel.result["code"], b"304")
headers = channel.result["headers"]
@@ -333,7 +333,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
res = WrapHtmlRequestHandlerTests.TestResource()
res.callback = callback
- _, channel = make_request(self.reactor, FakeSite(res), b"HEAD", b"/path")
+ channel = make_request(self.reactor, FakeSite(res), b"HEAD", b"/path")
self.assertEqual(channel.result["code"], b"200")
self.assertNotIn("body", channel.result)
diff --git a/tests/test_terms_auth.py b/tests/test_terms_auth.py
index 71580b454d..a743cdc3a9 100644
--- a/tests/test_terms_auth.py
+++ b/tests/test_terms_auth.py
@@ -53,7 +53,7 @@ class TermsTestCase(unittest.HomeserverTestCase):
def test_ui_auth(self):
# Do a UI auth request
request_data = json.dumps({"username": "kermit", "password": "monkey"})
- request, channel = self.make_request(b"POST", self.url, request_data)
+ channel = self.make_request(b"POST", self.url, request_data)
self.assertEquals(channel.result["code"], b"401", channel.result)
@@ -96,7 +96,7 @@ class TermsTestCase(unittest.HomeserverTestCase):
self.registration_handler.check_username = Mock(return_value=True)
- request, channel = self.make_request(b"POST", self.url, request_data)
+ channel = self.make_request(b"POST", self.url, request_data)
# We don't bother checking that the response is correct - we'll leave that to
# other tests. We just want to make sure we're on the right path.
@@ -113,7 +113,7 @@ class TermsTestCase(unittest.HomeserverTestCase):
},
}
)
- request, channel = self.make_request(b"POST", self.url, request_data)
+ channel = self.make_request(b"POST", self.url, request_data)
# We're interested in getting a response that looks like a successful
# registration, not so much that the details are exactly what we want.
diff --git a/tests/test_types.py b/tests/test_types.py
index 480bea1bdc..acdeea7a09 100644
--- a/tests/test_types.py
+++ b/tests/test_types.py
@@ -58,6 +58,10 @@ class RoomAliasTestCase(unittest.HomeserverTestCase):
self.assertEquals(room.to_string(), "#channel:my.domain")
+ def test_validate(self):
+ id_string = "#test:domain,test"
+ self.assertFalse(RoomAlias.is_valid(id_string))
+
class GroupIDTestCase(unittest.TestCase):
def test_parse(self):
diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py
index d232b72264..43898d8142 100644
--- a/tests/test_utils/__init__.py
+++ b/tests/test_utils/__init__.py
@@ -22,6 +22,13 @@ import warnings
from asyncio import Future
from typing import Any, Awaitable, Callable, TypeVar
+from mock import Mock
+
+import attr
+
+from twisted.python.failure import Failure
+from twisted.web.client import ResponseDone
+
TV = TypeVar("TV")
@@ -80,3 +87,35 @@ def setup_awaitable_errors() -> Callable[[], None]:
sys.unraisablehook = unraisablehook # type: ignore
return cleanup
+
+
+def simple_async_mock(return_value=None, raises=None) -> Mock:
+ # AsyncMock is not available in python3.5, this mimics part of its behaviour
+ async def cb(*args, **kwargs):
+ if raises:
+ raise raises
+ return return_value
+
+ return Mock(side_effect=cb)
+
+
+@attr.s
+class FakeResponse:
+ """A fake twisted.web.IResponse object
+
+ there is a similar class at treq.test.test_response, but it lacks a `phrase`
+ attribute, and didn't support deliverBody until recently.
+ """
+
+ # HTTP response code
+ code = attr.ib(type=int)
+
+ # HTTP response phrase (eg b'OK' for a 200)
+ phrase = attr.ib(type=bytes)
+
+ # body of the response
+ body = attr.ib(type=bytes)
+
+ def deliverBody(self, protocol):
+ protocol.dataReceived(self.body)
+ protocol.connectionLost(Failure(ResponseDone()))
diff --git a/tests/test_utils/html_parsers.py b/tests/test_utils/html_parsers.py
new file mode 100644
index 0000000000..ad563eb3f0
--- /dev/null
+++ b/tests/test_utils/html_parsers.py
@@ -0,0 +1,53 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from html.parser import HTMLParser
+from typing import Dict, Iterable, List, Optional, Tuple
+
+
+class TestHtmlParser(HTMLParser):
+ """A generic HTML page parser which extracts useful things from the HTML"""
+
+ def __init__(self):
+ super().__init__()
+
+ # a list of links found in the doc
+ self.links = [] # type: List[str]
+
+ # the values of any hidden <input>s: map from name to value
+ self.hiddens = {} # type: Dict[str, Optional[str]]
+
+ # the values of any radio buttons: map from name to list of values
+ self.radios = {} # type: Dict[str, List[Optional[str]]]
+
+ def handle_starttag(
+ self, tag: str, attrs: Iterable[Tuple[str, Optional[str]]]
+ ) -> None:
+ attr_dict = dict(attrs)
+ if tag == "a":
+ href = attr_dict["href"]
+ if href:
+ self.links.append(href)
+ elif tag == "input":
+ input_name = attr_dict.get("name")
+ if attr_dict["type"] == "radio":
+ assert input_name
+ self.radios.setdefault(input_name, []).append(attr_dict["value"])
+ elif attr_dict["type"] == "hidden":
+ assert input_name
+ self.hiddens[input_name] = attr_dict["value"]
+
+ def error(_, message):
+ raise AssertionError(message)
diff --git a/tests/test_utils/logging_setup.py b/tests/test_utils/logging_setup.py
index fdfb840b62..52ae5c5713 100644
--- a/tests/test_utils/logging_setup.py
+++ b/tests/test_utils/logging_setup.py
@@ -48,7 +48,7 @@ def setup_logging():
handler = ToTwistedHandler()
formatter = logging.Formatter(log_format)
handler.setFormatter(formatter)
- handler.addFilter(LoggingContextFilter(request=""))
+ handler.addFilter(LoggingContextFilter())
root_logger.addHandler(handler)
log_level = os.environ.get("SYNAPSE_TEST_LOG_LEVEL", "ERROR")
diff --git a/tests/unittest.py b/tests/unittest.py
index a9d59e31f7..767d5d6077 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -20,7 +20,7 @@ import hmac
import inspect
import logging
import time
-from typing import Optional, Tuple, Type, TypeVar, Union, overload
+from typing import Callable, Dict, Iterable, Optional, Tuple, Type, TypeVar, Union
from mock import Mock, patch
@@ -46,6 +46,7 @@ from synapse.logging.context import (
)
from synapse.server import HomeServer
from synapse.types import UserID, create_requester
+from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.ratelimitutils import FederationRateLimiter
from tests.server import FakeChannel, get_clock, make_request, setup_test_homeserver
@@ -320,15 +321,28 @@ class HomeserverTestCase(TestCase):
"""
Create a the root resource for the test server.
- The default implementation creates a JsonResource and calls each function in
- `servlets` to register servletes against it
+ The default calls `self.create_resource_dict` and builds the resultant dict
+ into a tree.
"""
- resource = JsonResource(self.hs)
+ root_resource = Resource()
+ create_resource_tree(self.create_resource_dict(), root_resource)
+ return root_resource
- for servlet in self.servlets:
- servlet(self.hs, resource)
+ def create_resource_dict(self) -> Dict[str, Resource]:
+ """Create a resource tree for the test server
- return resource
+ A resource tree is a mapping from path to twisted.web.resource.
+
+ The default implementation creates a JsonResource and calls each function in
+ `servlets` to register servlets against it.
+ """
+ servlet_resource = JsonResource(self.hs)
+ for servlet in self.servlets:
+ servlet(self.hs, servlet_resource)
+ return {
+ "/_matrix/client": servlet_resource,
+ "/_synapse/admin": servlet_resource,
+ }
def default_config(self):
"""
@@ -358,24 +372,6 @@ class HomeserverTestCase(TestCase):
Function to optionally be overridden in subclasses.
"""
- # Annoyingly mypy doesn't seem to pick up the fact that T is SynapseRequest
- # when the `request` arg isn't given, so we define an explicit override to
- # cover that case.
- @overload
- def make_request(
- self,
- method: Union[bytes, str],
- path: Union[bytes, str],
- content: Union[bytes, dict] = b"",
- access_token: Optional[str] = None,
- shorthand: bool = True,
- federation_auth_origin: str = None,
- content_is_form: bool = False,
- await_result: bool = True,
- ) -> Tuple[SynapseRequest, FakeChannel]:
- ...
-
- @overload
def make_request(
self,
method: Union[bytes, str],
@@ -387,21 +383,11 @@ class HomeserverTestCase(TestCase):
federation_auth_origin: str = None,
content_is_form: bool = False,
await_result: bool = True,
- ) -> Tuple[T, FakeChannel]:
- ...
-
- def make_request(
- self,
- method: Union[bytes, str],
- path: Union[bytes, str],
- content: Union[bytes, dict] = b"",
- access_token: Optional[str] = None,
- request: Type[T] = SynapseRequest,
- shorthand: bool = True,
- federation_auth_origin: str = None,
- content_is_form: bool = False,
- await_result: bool = True,
- ) -> Tuple[T, FakeChannel]:
+ custom_headers: Optional[
+ Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
+ ] = None,
+ client_ip: str = "127.0.0.1",
+ ) -> FakeChannel:
"""
Create a SynapseRequest at the path using the method and containing the
given content.
@@ -423,8 +409,13 @@ class HomeserverTestCase(TestCase):
true (the default), will pump the test reactor until the the renderer
tells the channel the request is finished.
+ custom_headers: (name, value) pairs to add as request headers
+
+ client_ip: The IP to use as the requesting IP. Useful for testing
+ ratelimiting.
+
Returns:
- Tuple[synapse.http.site.SynapseRequest, channel]
+ The FakeChannel object which stores the result of the request.
"""
return make_request(
self.reactor,
@@ -438,6 +429,8 @@ class HomeserverTestCase(TestCase):
federation_auth_origin,
content_is_form,
await_result,
+ custom_headers,
+ client_ip,
)
def setup_test_homeserver(self, *args, **kwargs):
@@ -554,7 +547,7 @@ class HomeserverTestCase(TestCase):
self.hs.config.registration_shared_secret = "shared"
# Create the user
- request, channel = self.make_request("GET", "/_synapse/admin/v1/register")
+ channel = self.make_request("GET", "/_synapse/admin/v1/register")
self.assertEqual(channel.code, 200, msg=channel.result)
nonce = channel.json_body["nonce"]
@@ -579,7 +572,7 @@ class HomeserverTestCase(TestCase):
"inhibit_login": True,
}
)
- request, channel = self.make_request(
+ channel = self.make_request(
"POST", "/_synapse/admin/v1/register", body.encode("utf8")
)
self.assertEqual(channel.code, 200, channel.json_body)
@@ -597,7 +590,7 @@ class HomeserverTestCase(TestCase):
if device_id:
body["device_id"] = device_id
- request, channel = self.make_request(
+ channel = self.make_request(
"POST", "/_matrix/client/r0/login", json.dumps(body).encode("utf8")
)
self.assertEqual(channel.code, 200, channel.result)
@@ -665,7 +658,7 @@ class HomeserverTestCase(TestCase):
"""
body = {"type": "m.login.password", "user": username, "password": password}
- request, channel = self.make_request(
+ channel = self.make_request(
"POST", "/_matrix/client/r0/login", json.dumps(body).encode("utf8")
)
self.assertEqual(channel.code, 403, channel.result)
@@ -691,13 +684,29 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
A federating homeserver that authenticates incoming requests as `other.example.com`.
"""
- def prepare(self, reactor, clock, homeserver):
+ def create_resource_dict(self) -> Dict[str, Resource]:
+ d = super().create_resource_dict()
+ d["/_matrix/federation"] = TestTransportLayerServer(self.hs)
+ return d
+
+
+class TestTransportLayerServer(JsonResource):
+ """A test implementation of TransportLayerServer
+
+ authenticates incoming requests as `other.example.com`.
+ """
+
+ def __init__(self, hs):
+ super().__init__(hs)
+
class Authenticator:
def authenticate_request(self, request, content):
return succeed("other.example.com")
+ authenticator = Authenticator()
+
ratelimiter = FederationRateLimiter(
- clock,
+ hs.get_clock(),
FederationRateLimitConfig(
window_size=1,
sleep_limit=1,
@@ -706,11 +715,8 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
concurrent_requests=1000,
),
)
- federation_server.register_servlets(
- homeserver, self.resource, Authenticator(), ratelimiter
- )
- return super().prepare(reactor, clock, homeserver)
+ federation_server.register_servlets(hs, self, authenticator, ratelimiter)
def override_config(extra_config):
@@ -735,3 +741,29 @@ def override_config(extra_config):
return func
return decorator
+
+
+TV = TypeVar("TV")
+
+
+def skip_unless(condition: bool, reason: str) -> Callable[[TV], TV]:
+ """A test decorator which will skip the decorated test unless a condition is set
+
+ For example:
+
+ class MyTestCase(TestCase):
+ @skip_unless(HAS_FOO, "Cannot test without foo")
+ def test_foo(self):
+ ...
+
+ Args:
+ condition: If true, the test will be skipped
+ reason: the reason to give for skipping the test
+ """
+
+ def decorator(f: TV) -> TV:
+ if not condition:
+ f.skip = reason # type: ignore
+ return f
+
+ return decorator
diff --git a/tests/util/caches/test_deferred_cache.py b/tests/util/caches/test_deferred_cache.py
index dadfabd46d..ecd9efc4df 100644
--- a/tests/util/caches/test_deferred_cache.py
+++ b/tests/util/caches/test_deferred_cache.py
@@ -25,13 +25,8 @@ from tests.unittest import TestCase
class DeferredCacheTestCase(TestCase):
def test_empty(self):
cache = DeferredCache("test")
- failed = False
- try:
+ with self.assertRaises(KeyError):
cache.get("foo")
- except KeyError:
- failed = True
-
- self.assertTrue(failed)
def test_hit(self):
cache = DeferredCache("test")
@@ -155,13 +150,8 @@ class DeferredCacheTestCase(TestCase):
cache.prefill(("foo",), 123)
cache.invalidate(("foo",))
- failed = False
- try:
+ with self.assertRaises(KeyError):
cache.get(("foo",))
- except KeyError:
- failed = True
-
- self.assertTrue(failed)
def test_invalidate_all(self):
cache = DeferredCache("testcache")
@@ -215,13 +205,8 @@ class DeferredCacheTestCase(TestCase):
cache.prefill(2, "two")
cache.prefill(3, "three") # 1 will be evicted
- failed = False
- try:
+ with self.assertRaises(KeyError):
cache.get(1)
- except KeyError:
- failed = True
-
- self.assertTrue(failed)
cache.get(2)
cache.get(3)
@@ -239,13 +224,55 @@ class DeferredCacheTestCase(TestCase):
cache.prefill(3, "three")
- failed = False
- try:
+ with self.assertRaises(KeyError):
cache.get(2)
- except KeyError:
- failed = True
- self.assertTrue(failed)
+ cache.get(1)
+ cache.get(3)
+
+ def test_eviction_iterable(self):
+ cache = DeferredCache(
+ "test", max_entries=3, apply_cache_factor_from_config=False, iterable=True,
+ )
+
+ cache.prefill(1, ["one", "two"])
+ cache.prefill(2, ["three"])
+ # Now access 1 again, thus causing 2 to be least-recently used
+ cache.get(1)
+
+ # Now add an item to the cache, which evicts 2.
+ cache.prefill(3, ["four"])
+ with self.assertRaises(KeyError):
+ cache.get(2)
+
+ # Ensure 1 & 3 are in the cache.
cache.get(1)
cache.get(3)
+
+ # Now access 1 again, thus causing 3 to be least-recently used
+ cache.get(1)
+
+ # Now add an item with multiple elements to the cache
+ cache.prefill(4, ["five", "six"])
+
+ # Both 1 and 3 are evicted since there's too many elements.
+ with self.assertRaises(KeyError):
+ cache.get(1)
+ with self.assertRaises(KeyError):
+ cache.get(3)
+
+ # Now add another item to fill the cache again.
+ cache.prefill(5, ["seven"])
+
+ # Now access 4, thus causing 5 to be least-recently used
+ cache.get(4)
+
+ # Add an empty item.
+ cache.prefill(6, [])
+
+ # 5 gets evicted and replaced since an empty element counts as an item.
+ with self.assertRaises(KeyError):
+ cache.get(5)
+ cache.get(4)
+ cache.get(6)
diff --git a/tests/util/test_itertools.py b/tests/util/test_itertools.py
index 0ab0a91483..1ef0af8e8f 100644
--- a/tests/util/test_itertools.py
+++ b/tests/util/test_itertools.py
@@ -12,7 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.util.iterutils import chunk_seq
+from typing import Dict, List
+
+from synapse.util.iterutils import chunk_seq, sorted_topologically
from tests.unittest import TestCase
@@ -45,3 +47,60 @@ class ChunkSeqTests(TestCase):
self.assertEqual(
list(parts), [],
)
+
+
+class SortTopologically(TestCase):
+ def test_empty(self):
+ "Test that an empty graph works correctly"
+
+ graph = {} # type: Dict[int, List[int]]
+ self.assertEqual(list(sorted_topologically([], graph)), [])
+
+ def test_handle_empty_graph(self):
+ "Test that a graph where a node doesn't have an entry is treated as empty"
+
+ graph = {} # type: Dict[int, List[int]]
+
+ # For disconnected nodes the output is simply sorted.
+ self.assertEqual(list(sorted_topologically([1, 2], graph)), [1, 2])
+
+ def test_disconnected(self):
+ "Test that a graph with no edges work"
+
+ graph = {1: [], 2: []} # type: Dict[int, List[int]]
+
+ # For disconnected nodes the output is simply sorted.
+ self.assertEqual(list(sorted_topologically([1, 2], graph)), [1, 2])
+
+ def test_linear(self):
+ "Test that a simple `4 -> 3 -> 2 -> 1` graph works"
+
+ graph = {1: [], 2: [1], 3: [2], 4: [3]} # type: Dict[int, List[int]]
+
+ self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4])
+
+ def test_subset(self):
+ "Test that only sorting a subset of the graph works"
+ graph = {1: [], 2: [1], 3: [2], 4: [3]} # type: Dict[int, List[int]]
+
+ self.assertEqual(list(sorted_topologically([4, 3], graph)), [3, 4])
+
+ def test_fork(self):
+ "Test that a forked graph works"
+ graph = {1: [], 2: [1], 3: [1], 4: [2, 3]} # type: Dict[int, List[int]]
+
+ # Valid orderings are `[1, 3, 2, 4]` or `[1, 2, 3, 4]`, but we should
+ # always get the same one.
+ self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4])
+
+ def test_duplicates(self):
+ "Test that a graph with duplicate edges work"
+ graph = {1: [], 2: [1, 1], 3: [2, 2], 4: [3]} # type: Dict[int, List[int]]
+
+ self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4])
+
+ def test_multiple_paths(self):
+ "Test that a graph with multiple paths between two nodes work"
+ graph = {1: [], 2: [1], 3: [2], 4: [3, 2, 1]} # type: Dict[int, List[int]]
+
+ self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4])
diff --git a/tests/utils.py b/tests/utils.py
index c8d3ffbaba..840b657f82 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -20,13 +20,12 @@ import os
import time
import uuid
import warnings
-from inspect import getcallargs
from typing import Type
from urllib import parse as urlparse
from mock import Mock, patch
-from twisted.internet import defer, reactor
+from twisted.internet import defer
from synapse.api.constants import EventTypes
from synapse.api.errors import CodeMessageException, cs_error
@@ -34,15 +33,12 @@ from synapse.api.room_versions import RoomVersions
from synapse.config.database import DatabaseConnectionConfig
from synapse.config.homeserver import HomeServerConfig
from synapse.config.server import DEFAULT_ROOM_VERSION
-from synapse.federation.transport import server as federation_server
-from synapse.http.server import HttpServer
from synapse.logging.context import current_context, set_current_context
from synapse.server import HomeServer
from synapse.storage import DataStore
from synapse.storage.database import LoggingDatabaseConnection
from synapse.storage.engines import PostgresEngine, create_engine
from synapse.storage.prepare_database import prepare_database
-from synapse.util.ratelimitutils import FederationRateLimiter
# set this to True to run the tests against postgres instead of sqlite.
#
@@ -161,6 +157,7 @@ def default_config(name, parse=False):
"local": {"per_second": 10000, "burst_count": 10000},
"remote": {"per_second": 10000, "burst_count": 10000},
},
+ "rc_3pid_validation": {"per_second": 10000, "burst_count": 10000},
"saml2_enabled": False,
"public_baseurl": None,
"default_identity_server": None,
@@ -342,32 +339,9 @@ def setup_test_homeserver(
hs.get_auth_handler().validate_hash = validate_hash
- fed = kwargs.get("resource_for_federation", None)
- if fed:
- register_federation_servlets(hs, fed)
-
return hs
-def register_federation_servlets(hs, resource):
- federation_server.register_servlets(
- hs,
- resource=resource,
- authenticator=federation_server.Authenticator(hs),
- ratelimiter=FederationRateLimiter(
- hs.get_clock(), config=hs.config.rc_federation
- ),
- )
-
-
-def get_mock_call_args(pattern_func, mock_func):
- """ Return the arguments the mock function was called with interpreted
- by the pattern functions argument list.
- """
- invoked_args, invoked_kargs = mock_func.call_args
- return getcallargs(pattern_func, *invoked_args, **invoked_kargs)
-
-
def mock_getRawHeaders(headers=None):
headers = headers if headers is not None else {}
@@ -378,7 +352,7 @@ def mock_getRawHeaders(headers=None):
# This is a mock /resource/ not an entire server
-class MockHttpResource(HttpServer):
+class MockHttpResource:
def __init__(self, prefix=""):
self.callbacks = [] # 3-tuple of method/pattern/function
self.prefix = prefix
@@ -553,86 +527,6 @@ class MockClock:
return d
-def _format_call(args, kwargs):
- return ", ".join(
- ["%r" % (a) for a in args] + ["%s=%r" % (k, v) for k, v in kwargs.items()]
- )
-
-
-class DeferredMockCallable:
- """A callable instance that stores a set of pending call expectations and
- return values for them. It allows a unit test to assert that the given set
- of function calls are eventually made, by awaiting on them to be called.
- """
-
- def __init__(self):
- self.expectations = []
- self.calls = []
-
- def __call__(self, *args, **kwargs):
- self.calls.append((args, kwargs))
-
- if not self.expectations:
- raise ValueError(
- "%r has no pending calls to handle call(%s)"
- % (self, _format_call(args, kwargs))
- )
-
- for (call, result, d) in self.expectations:
- if args == call[1] and kwargs == call[2]:
- d.callback(None)
- return result
-
- failure = AssertionError(
- "Was not expecting call(%s)" % (_format_call(args, kwargs))
- )
-
- for _, _, d in self.expectations:
- try:
- d.errback(failure)
- except Exception:
- pass
-
- raise failure
-
- def expect_call_and_return(self, call, result):
- self.expectations.append((call, result, defer.Deferred()))
-
- @defer.inlineCallbacks
- def await_calls(self, timeout=1000):
- deferred = defer.DeferredList(
- [d for _, _, d in self.expectations], fireOnOneErrback=True
- )
-
- timer = reactor.callLater(
- timeout / 1000,
- deferred.errback,
- AssertionError(
- "%d pending calls left: %s"
- % (
- len([e for e in self.expectations if not e[2].called]),
- [e for e in self.expectations if not e[2].called],
- )
- ),
- )
-
- yield deferred
-
- timer.cancel()
-
- self.calls = []
-
- def assert_had_no_calls(self):
- if self.calls:
- calls = self.calls
- self.calls = []
-
- raise AssertionError(
- "Expected not to received any calls, got:\n"
- + "\n".join(["call(%s)" % _format_call(c[0], c[1]) for c in calls])
- )
-
-
async def create_room(hs, room_id: str, creator_id: str):
"""Creates and persist a creation event for the given room
"""
|