diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/config/test_tls.py | 29 | ||||
-rw-r--r-- | tests/events/test_utils.py | 35 | ||||
-rw-r--r-- | tests/handlers/test_directory.py | 190 | ||||
-rw-r--r-- | tests/http/federation/test_matrix_federation_agent.py | 6 | ||||
-rw-r--r-- | tests/replication/slave/storage/test_events.py | 10 | ||||
-rw-r--r-- | tests/rest/admin/test_admin.py | 7 | ||||
-rw-r--r-- | tests/rest/client/v1/test_directory.py | 41 | ||||
-rw-r--r-- | tests/rest/client/v1/test_login.py | 111 | ||||
-rw-r--r-- | tests/rest/client/v1/test_rooms.py | 160 | ||||
-rw-r--r-- | tests/state/test_v2.py | 13 | ||||
-rw-r--r-- | tests/storage/test_event_federation.py | 157 | ||||
-rw-r--r-- | tests/storage/test_monthly_active_users.py | 42 | ||||
-rw-r--r-- | tests/test_event_auth.py | 93 | ||||
-rw-r--r-- | tests/test_types.py | 2 |
14 files changed, 784 insertions, 112 deletions
diff --git a/tests/config/test_tls.py b/tests/config/test_tls.py index 1be6ff563b..ec32d4b1ca 100644 --- a/tests/config/test_tls.py +++ b/tests/config/test_tls.py @@ -23,7 +23,7 @@ from OpenSSL import SSL from synapse.config._base import Config, RootConfig from synapse.config.tls import ConfigError, TlsConfig -from synapse.crypto.context_factory import ClientTLSOptionsFactory +from synapse.crypto.context_factory import FederationPolicyForHTTPS from tests.unittest import TestCase @@ -180,12 +180,13 @@ s4niecZKPBizL6aucT59CsunNmmb5Glq8rlAcU+1ZTZZzGYqVYhF6axB9Qg= t = TestConfig() t.read_config(config, config_dir_path="", data_dir_path="") - cf = ClientTLSOptionsFactory(t) + cf = FederationPolicyForHTTPS(t) + options = _get_ssl_context_options(cf._verify_ssl_context) # The context has had NO_TLSv1_1 and NO_TLSv1_0 set, but not NO_TLSv1_2 - self.assertNotEqual(cf._verify_ssl._options & SSL.OP_NO_TLSv1, 0) - self.assertNotEqual(cf._verify_ssl._options & SSL.OP_NO_TLSv1_1, 0) - self.assertEqual(cf._verify_ssl._options & SSL.OP_NO_TLSv1_2, 0) + self.assertNotEqual(options & SSL.OP_NO_TLSv1, 0) + self.assertNotEqual(options & SSL.OP_NO_TLSv1_1, 0) + self.assertEqual(options & SSL.OP_NO_TLSv1_2, 0) def test_tls_client_minimum_set_passed_through_1_0(self): """ @@ -195,12 +196,13 @@ s4niecZKPBizL6aucT59CsunNmmb5Glq8rlAcU+1ZTZZzGYqVYhF6axB9Qg= t = TestConfig() t.read_config(config, config_dir_path="", data_dir_path="") - cf = ClientTLSOptionsFactory(t) + cf = FederationPolicyForHTTPS(t) + options = _get_ssl_context_options(cf._verify_ssl_context) # The context has not had any of the NO_TLS set. - self.assertEqual(cf._verify_ssl._options & SSL.OP_NO_TLSv1, 0) - self.assertEqual(cf._verify_ssl._options & SSL.OP_NO_TLSv1_1, 0) - self.assertEqual(cf._verify_ssl._options & SSL.OP_NO_TLSv1_2, 0) + self.assertEqual(options & SSL.OP_NO_TLSv1, 0) + self.assertEqual(options & SSL.OP_NO_TLSv1_1, 0) + self.assertEqual(options & SSL.OP_NO_TLSv1_2, 0) def test_acme_disabled_in_generated_config_no_acme_domain_provied(self): """ @@ -273,7 +275,7 @@ s4niecZKPBizL6aucT59CsunNmmb5Glq8rlAcU+1ZTZZzGYqVYhF6axB9Qg= t = TestConfig() t.read_config(config, config_dir_path="", data_dir_path="") - cf = ClientTLSOptionsFactory(t) + cf = FederationPolicyForHTTPS(t) # Not in the whitelist opts = cf.get_options(b"notexample.com") @@ -282,3 +284,10 @@ s4niecZKPBizL6aucT59CsunNmmb5Glq8rlAcU+1ZTZZzGYqVYhF6axB9Qg= # Caught by the wildcard opts = cf.get_options(idna.encode("ใในใ.ใใกใคใณ.ใในใ")) self.assertFalse(opts._verifier._verify_certs) + + +def _get_ssl_context_options(ssl_context: SSL.Context) -> int: + """get the options bits from an openssl context object""" + # the OpenSSL.SSL.Context wrapper doesn't expose get_options, so we have to + # use the low-level interface + return SSL._lib.SSL_CTX_get_options(ssl_context._context) diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py index 45d55b9e94..ab5f5ac549 100644 --- a/tests/events/test_utils.py +++ b/tests/events/test_utils.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from synapse.api.room_versions import RoomVersions from synapse.events import make_event_from_dict from synapse.events.utils import ( copy_power_levels_contents, @@ -36,9 +37,9 @@ class PruneEventTestCase(unittest.TestCase): """ Asserts that a new event constructed with `evdict` will look like `matchdict` when it is redacted. """ - def run_test(self, evdict, matchdict): + def run_test(self, evdict, matchdict, **kwargs): self.assertEquals( - prune_event(make_event_from_dict(evdict)).get_dict(), matchdict + prune_event(make_event_from_dict(evdict, **kwargs)).get_dict(), matchdict ) def test_minimal(self): @@ -128,6 +129,36 @@ class PruneEventTestCase(unittest.TestCase): }, ) + def test_alias_event(self): + """Alias events have special behavior up through room version 6.""" + self.run_test( + { + "type": "m.room.aliases", + "event_id": "$test:domain", + "content": {"aliases": ["test"]}, + }, + { + "type": "m.room.aliases", + "event_id": "$test:domain", + "content": {"aliases": ["test"]}, + "signatures": {}, + "unsigned": {}, + }, + ) + + def test_msc2432_alias_event(self): + """After MSC2432, alias events have no special behavior.""" + self.run_test( + {"type": "m.room.aliases", "content": {"aliases": ["test"]}}, + { + "type": "m.room.aliases", + "content": {}, + "signatures": {}, + "unsigned": {}, + }, + room_version=RoomVersions.MSC2432_DEV, + ) + class SerializeEventTestCase(unittest.TestCase): def serialize(self, ev, fields): diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index 27b916aed4..5e40adba52 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -18,6 +18,7 @@ from mock import Mock from twisted.internet import defer +import synapse import synapse.api.errors from synapse.api.constants import EventTypes from synapse.config.room_directory import RoomDirectoryConfig @@ -87,50 +88,131 @@ class DirectoryTestCase(unittest.HomeserverTestCase): ignore_backoff=True, ) - def test_delete_alias_not_allowed(self): - room_id = "!8765qwer:test" + def test_incoming_fed_query(self): self.get_success( - self.store.create_room_alias_association(self.my_room, room_id, ["test"]) + self.store.create_room_alias_association( + self.your_room, "!8765asdf:test", ["test"] + ) + ) + + response = self.get_success( + self.handler.on_directory_query({"room_alias": "#your-room:test"}) ) + self.assertEquals({"room_id": "!8765asdf:test", "servers": ["test"]}, response) + + +class TestDeleteAlias(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + room.register_servlets, + directory.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + self.handler = hs.get_handlers().directory_handler + self.state_handler = hs.get_state_handler() + + # Create user + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + # Create a test room + self.room_id = self.helper.create_room_as( + self.admin_user, tok=self.admin_user_tok + ) + + self.test_alias = "#test:test" + self.room_alias = RoomAlias.from_string(self.test_alias) + + # Create a test user. + self.test_user = self.register_user("user", "pass", admin=False) + self.test_user_tok = self.login("user", "pass") + self.helper.join(room=self.room_id, user=self.test_user, tok=self.test_user_tok) + + def _create_alias(self, user): + # Create a new alias to this room. + self.get_success( + self.store.create_room_alias_association( + self.room_alias, self.room_id, ["test"], user + ) + ) + + def test_delete_alias_not_allowed(self): + """A user that doesn't meet the expected guidelines cannot delete an alias.""" + self._create_alias(self.admin_user) self.get_failure( self.handler.delete_association( - create_requester("@user:test"), self.my_room + create_requester(self.test_user), self.room_alias ), synapse.api.errors.AuthError, ) - def test_delete_alias(self): - room_id = "!8765qwer:test" - user_id = "@user:test" - self.get_success( - self.store.create_room_alias_association( - self.my_room, room_id, ["test"], user_id + def test_delete_alias_creator(self): + """An alias creator can delete their own alias.""" + # Create an alias from a different user. + self._create_alias(self.test_user) + + # Delete the user's alias. + result = self.get_success( + self.handler.delete_association( + create_requester(self.test_user), self.room_alias ) ) + self.assertEquals(self.room_id, result) + + # Confirm the alias is gone. + self.get_failure( + self.handler.get_association(self.room_alias), + synapse.api.errors.SynapseError, + ) + + def test_delete_alias_admin(self): + """A server admin can delete an alias created by another user.""" + # Create an alias from a different user. + self._create_alias(self.test_user) + # Delete the user's alias as the admin. result = self.get_success( - self.handler.delete_association(create_requester(user_id), self.my_room) + self.handler.delete_association( + create_requester(self.admin_user), self.room_alias + ) ) - self.assertEquals(room_id, result) + self.assertEquals(self.room_id, result) - # The alias should not be found. + # Confirm the alias is gone. self.get_failure( - self.handler.get_association(self.my_room), synapse.api.errors.SynapseError + self.handler.get_association(self.room_alias), + synapse.api.errors.SynapseError, ) - def test_incoming_fed_query(self): - self.get_success( - self.store.create_room_alias_association( - self.your_room, "!8765asdf:test", ["test"] - ) + def test_delete_alias_sufficient_power(self): + """A user with a sufficient power level should be able to delete an alias.""" + self._create_alias(self.admin_user) + + # Increase the user's power level. + self.helper.send_state( + self.room_id, + "m.room.power_levels", + {"users": {self.test_user: 100}}, + tok=self.admin_user_tok, ) - response = self.get_success( - self.handler.on_directory_query({"room_alias": "#your-room:test"}) + # They can now delete the alias. + result = self.get_success( + self.handler.delete_association( + create_requester(self.test_user), self.room_alias + ) ) + self.assertEquals(self.room_id, result) - self.assertEquals({"room_id": "!8765asdf:test", "servers": ["test"]}, response) + # Confirm the alias is gone. + self.get_failure( + self.handler.get_association(self.room_alias), + synapse.api.errors.SynapseError, + ) class CanonicalAliasTestCase(unittest.HomeserverTestCase): @@ -159,30 +241,42 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase): ) self.test_alias = "#test:test" - self.room_alias = RoomAlias.from_string(self.test_alias) + self.room_alias = self._add_alias(self.test_alias) + + def _add_alias(self, alias: str) -> RoomAlias: + """Add an alias to the test room.""" + room_alias = RoomAlias.from_string(alias) # Create a new alias to this room. self.get_success( self.store.create_room_alias_association( - self.room_alias, self.room_id, ["test"], self.admin_user + room_alias, self.room_id, ["test"], self.admin_user ) ) + return room_alias - def test_remove_alias(self): - """Removing an alias that is the canonical alias should remove it there too.""" - # Set this new alias as the canonical alias for this room + def _set_canonical_alias(self, content): + """Configure the canonical alias state on the room.""" self.helper.send_state( - self.room_id, - "m.room.canonical_alias", - {"alias": self.test_alias, "alt_aliases": [self.test_alias]}, - tok=self.admin_user_tok, + self.room_id, "m.room.canonical_alias", content, tok=self.admin_user_tok, ) - data = self.get_success( + def _get_canonical_alias(self): + """Get the canonical alias state of the room.""" + return self.get_success( self.state_handler.get_current_state( self.room_id, EventTypes.CanonicalAlias, "" ) ) + + def test_remove_alias(self): + """Removing an alias that is the canonical alias should remove it there too.""" + # Set this new alias as the canonical alias for this room + self._set_canonical_alias( + {"alias": self.test_alias, "alt_aliases": [self.test_alias]} + ) + + data = self._get_canonical_alias() self.assertEqual(data["content"]["alias"], self.test_alias) self.assertEqual(data["content"]["alt_aliases"], [self.test_alias]) @@ -193,11 +287,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase): ) ) - data = self.get_success( - self.state_handler.get_current_state( - self.room_id, EventTypes.CanonicalAlias, "" - ) - ) + data = self._get_canonical_alias() self.assertNotIn("alias", data["content"]) self.assertNotIn("alt_aliases", data["content"]) @@ -205,29 +295,17 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase): """Removing an alias listed as in alt_aliases should remove it there too.""" # Create a second alias. other_test_alias = "#test2:test" - other_room_alias = RoomAlias.from_string(other_test_alias) - self.get_success( - self.store.create_room_alias_association( - other_room_alias, self.room_id, ["test"], self.admin_user - ) - ) + other_room_alias = self._add_alias(other_test_alias) # Set the alias as the canonical alias for this room. - self.helper.send_state( - self.room_id, - "m.room.canonical_alias", + self._set_canonical_alias( { "alias": self.test_alias, "alt_aliases": [self.test_alias, other_test_alias], - }, - tok=self.admin_user_tok, + } ) - data = self.get_success( - self.state_handler.get_current_state( - self.room_id, EventTypes.CanonicalAlias, "" - ) - ) + data = self._get_canonical_alias() self.assertEqual(data["content"]["alias"], self.test_alias) self.assertEqual( data["content"]["alt_aliases"], [self.test_alias, other_test_alias] @@ -240,11 +318,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase): ) ) - data = self.get_success( - self.state_handler.get_current_state( - self.room_id, EventTypes.CanonicalAlias, "" - ) - ) + data = self._get_canonical_alias() self.assertEqual(data["content"]["alias"], self.test_alias) self.assertEqual(data["content"]["alt_aliases"], [self.test_alias]) diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index cfcd98ff7d..fdc1d918ff 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -31,7 +31,7 @@ from twisted.web.http_headers import Headers from twisted.web.iweb import IPolicyForHTTPS from synapse.config.homeserver import HomeServerConfig -from synapse.crypto.context_factory import ClientTLSOptionsFactory +from synapse.crypto.context_factory import FederationPolicyForHTTPS from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent from synapse.http.federation.srv_resolver import Server from synapse.http.federation.well_known_resolver import ( @@ -79,7 +79,7 @@ class MatrixFederationAgentTests(unittest.TestCase): self._config = config = HomeServerConfig() config.parse_config_dict(config_dict, "", "") - self.tls_factory = ClientTLSOptionsFactory(config) + self.tls_factory = FederationPolicyForHTTPS(config) self.well_known_cache = TTLCache("test_cache", timer=self.reactor.seconds) self.had_well_known_cache = TTLCache("test_cache", timer=self.reactor.seconds) @@ -715,7 +715,7 @@ class MatrixFederationAgentTests(unittest.TestCase): config = default_config("test", parse=True) # Build a new agent and WellKnownResolver with a different tls factory - tls_factory = ClientTLSOptionsFactory(config) + tls_factory = FederationPolicyForHTTPS(config) agent = MatrixFederationAgent( reactor=self.reactor, tls_client_options_factory=tls_factory, diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index d31210fbe4..f0561b30e3 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -15,6 +15,7 @@ import logging from canonicaljson import encode_canonical_json +from synapse.api.room_versions import RoomVersions from synapse.events import FrozenEvent, _EventInternalMetadata, make_event_from_dict from synapse.events.snapshot import EventContext from synapse.handlers.room import RoomEventSource @@ -58,6 +59,15 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): self.unpatches = [patch__eq__(_EventInternalMetadata), patch__eq__(FrozenEvent)] return super(SlavedEventStoreTestCase, self).setUp() + def prepare(self, *args, **kwargs): + super().prepare(*args, **kwargs) + + self.get_success( + self.master_store.store_room( + ROOM_ID, USER_ID, is_public=False, room_version=RoomVersions.V1, + ) + ) + def tearDown(self): [unpatch() for unpatch in self.unpatches] diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index e5984aaad8..0342aed416 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -870,6 +870,13 @@ class RoomTestCase(unittest.HomeserverTestCase): # Set this new alias as the canonical alias for this room self.helper.send_state( room_id, + "m.room.aliases", + {"aliases": [test_alias]}, + tok=self.admin_user_tok, + state_key="test", + ) + self.helper.send_state( + room_id, "m.room.canonical_alias", {"alias": test_alias}, tok=self.admin_user_tok, diff --git a/tests/rest/client/v1/test_directory.py b/tests/rest/client/v1/test_directory.py index 914cf54927..633b7dbda0 100644 --- a/tests/rest/client/v1/test_directory.py +++ b/tests/rest/client/v1/test_directory.py @@ -51,30 +51,26 @@ class DirectoryTestCase(unittest.HomeserverTestCase): self.user = self.register_user("user", "test") self.user_tok = self.login("user", "test") - def test_cannot_set_alias_via_state_event(self): - self.ensure_user_joined_room() - url = "/_matrix/client/r0/rooms/%s/state/m.room.aliases/%s" % ( - self.room_id, - self.hs.hostname, - ) - - data = {"aliases": [self.random_alias(5)]} - request_data = json.dumps(data) - - request, channel = self.make_request( - "PUT", url, request_data, access_token=self.user_tok - ) - self.render(request) - self.assertEqual(channel.code, 400, channel.result) + def test_state_event_not_in_room(self): + self.ensure_user_left_room() + self.set_alias_via_state_event(403) def test_directory_endpoint_not_in_room(self): self.ensure_user_left_room() self.set_alias_via_directory(403) + def test_state_event_in_room_too_long(self): + self.ensure_user_joined_room() + self.set_alias_via_state_event(400, alias_length=256) + def test_directory_in_room_too_long(self): self.ensure_user_joined_room() self.set_alias_via_directory(400, alias_length=256) + def test_state_event_in_room(self): + self.ensure_user_joined_room() + self.set_alias_via_state_event(200) + def test_directory_in_room(self): self.ensure_user_joined_room() self.set_alias_via_directory(200) @@ -106,6 +102,21 @@ class DirectoryTestCase(unittest.HomeserverTestCase): self.render(request) self.assertEqual(channel.code, 200, channel.result) + def set_alias_via_state_event(self, expected_code, alias_length=5): + url = "/_matrix/client/r0/rooms/%s/state/m.room.aliases/%s" % ( + self.room_id, + self.hs.hostname, + ) + + data = {"aliases": [self.random_alias(alias_length)]} + request_data = json.dumps(data) + + request, channel = self.make_request( + "PUT", url, request_data, access_token=self.user_tok + ) + self.render(request) + self.assertEqual(channel.code, expected_code, channel.result) + def set_alias_via_directory(self, expected_code, alias_length=5): url = "/_matrix/client/r0/directory/room/%s" % self.random_alias(alias_length) data = {"room_id": self.room_id} diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index eae5411325..da2c9bfa1e 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -1,4 +1,7 @@ import json +import urllib.parse + +from mock import Mock import synapse.rest.admin from synapse.rest.client.v1 import login @@ -252,3 +255,111 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): ) self.render(request) self.assertEquals(channel.code, 200, channel.result) + + +class CASRedirectConfirmTestCase(unittest.HomeserverTestCase): + + servlets = [ + login.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + self.base_url = "https://matrix.goodserver.com/" + self.redirect_path = "_synapse/client/login/sso/redirect/confirm" + + config = self.default_config() + config["cas_config"] = { + "enabled": True, + "server_url": "https://fake.test", + "service_url": "https://matrix.goodserver.com:8448", + } + + async def get_raw(uri, args): + """Return an example response payload from a call to the `/proxyValidate` + endpoint of a CAS server, copied from + https://apereo.github.io/cas/5.0.x/protocol/CAS-Protocol-V2-Specification.html#26-proxyvalidate-cas-20 + + This needs to be returned by an async function (as opposed to set as the + mock's return value) because the corresponding Synapse code awaits on it. + """ + return """ + <cas:serviceResponse xmlns:cas='http://www.yale.edu/tp/cas'> + <cas:authenticationSuccess> + <cas:user>username</cas:user> + <cas:proxyGrantingTicket>PGTIOU-84678-8a9d...</cas:proxyGrantingTicket> + <cas:proxies> + <cas:proxy>https://proxy2/pgtUrl</cas:proxy> + <cas:proxy>https://proxy1/pgtUrl</cas:proxy> + </cas:proxies> + </cas:authenticationSuccess> + </cas:serviceResponse> + """ + + mocked_http_client = Mock(spec=["get_raw"]) + mocked_http_client.get_raw.side_effect = get_raw + + self.hs = self.setup_test_homeserver( + config=config, proxied_http_client=mocked_http_client, + ) + + return self.hs + + def test_cas_redirect_confirm(self): + """Tests that the SSO login flow serves a confirmation page before redirecting a + user to the redirect URL. + """ + base_url = "/_matrix/client/r0/login/cas/ticket?redirectUrl" + redirect_url = "https://dodgy-site.com/" + + url_parts = list(urllib.parse.urlparse(base_url)) + query = dict(urllib.parse.parse_qsl(url_parts[4])) + query.update({"redirectUrl": redirect_url}) + query.update({"ticket": "ticket"}) + url_parts[4] = urllib.parse.urlencode(query) + cas_ticket_url = urllib.parse.urlunparse(url_parts) + + # Get Synapse to call the fake CAS and serve the template. + request, channel = self.make_request("GET", cas_ticket_url) + self.render(request) + + # Test that the response is HTML. + self.assertEqual(channel.code, 200) + content_type_header_value = "" + for header in channel.result.get("headers", []): + if header[0] == b"Content-Type": + content_type_header_value = header[1].decode("utf8") + + self.assertTrue(content_type_header_value.startswith("text/html")) + + # Test that the body isn't empty. + self.assertTrue(len(channel.result["body"]) > 0) + + # And that it contains our redirect link + self.assertIn(redirect_url, channel.result["body"].decode("UTF-8")) + + @override_config( + { + "sso": { + "client_whitelist": [ + "https://legit-site.com/", + "https://other-site.com/", + ] + } + } + ) + def test_cas_redirect_whitelisted(self): + """Tests that the SSO login flow serves a redirect to a whitelisted url + """ + redirect_url = "https://legit-site.com/" + cas_ticket_url = ( + "/_matrix/client/r0/login/cas/ticket?redirectUrl=%s&ticket=ticket" + % (urllib.parse.quote(redirect_url)) + ) + + # Get Synapse to call the fake CAS and serve the template. + request, channel = self.make_request("GET", cas_ticket_url) + self.render(request) + + self.assertEqual(channel.code, 302) + location_headers = channel.headers.getRawHeaders("Location") + self.assertEqual(location_headers[0][: len(redirect_url)], redirect_url) diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 2f3df5f88f..7dd86d0c27 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -1821,3 +1821,163 @@ class RoomAliasListTestCase(unittest.HomeserverTestCase): ) self.render(request) self.assertEqual(channel.code, expected_code, channel.result) + + +class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + directory.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def prepare(self, reactor, clock, homeserver): + self.room_owner = self.register_user("room_owner", "test") + self.room_owner_tok = self.login("room_owner", "test") + + self.room_id = self.helper.create_room_as( + self.room_owner, tok=self.room_owner_tok + ) + + self.alias = "#alias:test" + self._set_alias_via_directory(self.alias) + + def _set_alias_via_directory(self, alias: str, expected_code: int = 200): + url = "/_matrix/client/r0/directory/room/" + alias + data = {"room_id": self.room_id} + request_data = json.dumps(data) + + request, channel = self.make_request( + "PUT", url, request_data, access_token=self.room_owner_tok + ) + self.render(request) + self.assertEqual(channel.code, expected_code, channel.result) + + def _get_canonical_alias(self, expected_code: int = 200) -> JsonDict: + """Calls the endpoint under test. returns the json response object.""" + request, channel = self.make_request( + "GET", + "rooms/%s/state/m.room.canonical_alias" % (self.room_id,), + access_token=self.room_owner_tok, + ) + self.render(request) + self.assertEqual(channel.code, expected_code, channel.result) + res = channel.json_body + self.assertIsInstance(res, dict) + return res + + def _set_canonical_alias(self, content: str, expected_code: int = 200) -> JsonDict: + """Calls the endpoint under test. returns the json response object.""" + request, channel = self.make_request( + "PUT", + "rooms/%s/state/m.room.canonical_alias" % (self.room_id,), + json.dumps(content), + access_token=self.room_owner_tok, + ) + self.render(request) + self.assertEqual(channel.code, expected_code, channel.result) + res = channel.json_body + self.assertIsInstance(res, dict) + return res + + def test_canonical_alias(self): + """Test a basic alias message.""" + # There is no canonical alias to start with. + self._get_canonical_alias(expected_code=404) + + # Create an alias. + self._set_canonical_alias({"alias": self.alias}) + + # Canonical alias now exists! + res = self._get_canonical_alias() + self.assertEqual(res, {"alias": self.alias}) + + # Now remove the alias. + self._set_canonical_alias({}) + + # There is an alias event, but it is empty. + res = self._get_canonical_alias() + self.assertEqual(res, {}) + + def test_alt_aliases(self): + """Test a canonical alias message with alt_aliases.""" + # Create an alias. + self._set_canonical_alias({"alt_aliases": [self.alias]}) + + # Canonical alias now exists! + res = self._get_canonical_alias() + self.assertEqual(res, {"alt_aliases": [self.alias]}) + + # Now remove the alt_aliases. + self._set_canonical_alias({}) + + # There is an alias event, but it is empty. + res = self._get_canonical_alias() + self.assertEqual(res, {}) + + def test_alias_alt_aliases(self): + """Test a canonical alias message with an alias and alt_aliases.""" + # Create an alias. + self._set_canonical_alias({"alias": self.alias, "alt_aliases": [self.alias]}) + + # Canonical alias now exists! + res = self._get_canonical_alias() + self.assertEqual(res, {"alias": self.alias, "alt_aliases": [self.alias]}) + + # Now remove the alias and alt_aliases. + self._set_canonical_alias({}) + + # There is an alias event, but it is empty. + res = self._get_canonical_alias() + self.assertEqual(res, {}) + + def test_partial_modify(self): + """Test removing only the alt_aliases.""" + # Create an alias. + self._set_canonical_alias({"alias": self.alias, "alt_aliases": [self.alias]}) + + # Canonical alias now exists! + res = self._get_canonical_alias() + self.assertEqual(res, {"alias": self.alias, "alt_aliases": [self.alias]}) + + # Now remove the alt_aliases. + self._set_canonical_alias({"alias": self.alias}) + + # There is an alias event, but it is empty. + res = self._get_canonical_alias() + self.assertEqual(res, {"alias": self.alias}) + + def test_add_alias(self): + """Test removing only the alt_aliases.""" + # Create an additional alias. + second_alias = "#second:test" + self._set_alias_via_directory(second_alias) + + # Add the canonical alias. + self._set_canonical_alias({"alias": self.alias, "alt_aliases": [self.alias]}) + + # Then add the second alias. + self._set_canonical_alias( + {"alias": self.alias, "alt_aliases": [self.alias, second_alias]} + ) + + # Canonical alias now exists! + res = self._get_canonical_alias() + self.assertEqual( + res, {"alias": self.alias, "alt_aliases": [self.alias, second_alias]} + ) + + def test_bad_data(self): + """Invalid data for alt_aliases should cause errors.""" + self._set_canonical_alias({"alt_aliases": "@bad:test"}, expected_code=400) + self._set_canonical_alias({"alt_aliases": None}, expected_code=400) + self._set_canonical_alias({"alt_aliases": 0}, expected_code=400) + self._set_canonical_alias({"alt_aliases": 1}, expected_code=400) + self._set_canonical_alias({"alt_aliases": False}, expected_code=400) + self._set_canonical_alias({"alt_aliases": True}, expected_code=400) + self._set_canonical_alias({"alt_aliases": {}}, expected_code=400) + + def test_bad_alias(self): + """An alias which does not point to the room raises a SynapseError.""" + self._set_canonical_alias({"alias": "@unknown:test"}, expected_code=400) + self._set_canonical_alias({"alt_aliases": ["@unknown:test"]}, expected_code=400) diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py index 5059ade850..a44960203e 100644 --- a/tests/state/test_v2.py +++ b/tests/state/test_v2.py @@ -603,7 +603,7 @@ class TestStateResolutionStore(object): return {eid: self.event_map[eid] for eid in event_ids if eid in self.event_map} - def get_auth_chain(self, event_ids, ignore_events): + def _get_auth_chain(self, event_ids): """Gets the full auth chain for a set of events (including rejected events). @@ -617,9 +617,6 @@ class TestStateResolutionStore(object): Args: event_ids (list): The event IDs of the events to fetch the auth chain for. Must be state events. - ignore_events: Set of events to exclude from the returned auth - chain. - Returns: Deferred[list[str]]: List of event IDs of the auth chain. """ @@ -629,7 +626,7 @@ class TestStateResolutionStore(object): stack = list(event_ids) while stack: event_id = stack.pop() - if event_id in result or event_id in ignore_events: + if event_id in result: continue result.add(event_id) @@ -639,3 +636,9 @@ class TestStateResolutionStore(object): stack.append(aid) return list(result) + + def get_auth_chain_difference(self, auth_sets): + chains = [frozenset(self._get_auth_chain(a)) for a in auth_sets] + + common = set(chains[0]).intersection(*chains[1:]) + return set(chains[0]).union(*chains[1:]) - common diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index a331517f4d..3aeec0dc0f 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -13,19 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer - import tests.unittest import tests.utils -class EventFederationWorkerStoreTestCase(tests.unittest.TestCase): - @defer.inlineCallbacks - def setUp(self): - hs = yield tests.utils.setup_test_homeserver(self.addCleanup) +class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): + def prepare(self, reactor, clock, hs): self.store = hs.get_datastore() - @defer.inlineCallbacks def test_get_prev_events_for_room(self): room_id = "@ROOM:local" @@ -61,15 +56,14 @@ class EventFederationWorkerStoreTestCase(tests.unittest.TestCase): ) for i in range(0, 20): - yield self.store.db.runInteraction("insert", insert_event, i) + self.get_success(self.store.db.runInteraction("insert", insert_event, i)) # this should get the last ten - r = yield self.store.get_prev_events_for_room(room_id) + r = self.get_success(self.store.get_prev_events_for_room(room_id)) self.assertEqual(10, len(r)) for i in range(0, 10): self.assertEqual("$event_%i:local" % (19 - i), r[i]) - @defer.inlineCallbacks def test_get_rooms_with_many_extremities(self): room1 = "#room1" room2 = "#room2" @@ -86,25 +80,154 @@ class EventFederationWorkerStoreTestCase(tests.unittest.TestCase): ) for i in range(0, 20): - yield self.store.db.runInteraction("insert", insert_event, i, room1) - yield self.store.db.runInteraction("insert", insert_event, i, room2) - yield self.store.db.runInteraction("insert", insert_event, i, room3) + self.get_success( + self.store.db.runInteraction("insert", insert_event, i, room1) + ) + self.get_success( + self.store.db.runInteraction("insert", insert_event, i, room2) + ) + self.get_success( + self.store.db.runInteraction("insert", insert_event, i, room3) + ) # Test simple case - r = yield self.store.get_rooms_with_many_extremities(5, 5, []) + r = self.get_success(self.store.get_rooms_with_many_extremities(5, 5, [])) self.assertEqual(len(r), 3) # Does filter work? - r = yield self.store.get_rooms_with_many_extremities(5, 5, [room1]) + r = self.get_success(self.store.get_rooms_with_many_extremities(5, 5, [room1])) self.assertTrue(room2 in r) self.assertTrue(room3 in r) self.assertEqual(len(r), 2) - r = yield self.store.get_rooms_with_many_extremities(5, 5, [room1, room2]) + r = self.get_success( + self.store.get_rooms_with_many_extremities(5, 5, [room1, room2]) + ) self.assertEqual(r, [room3]) # Does filter and limit work? - r = yield self.store.get_rooms_with_many_extremities(5, 1, [room1]) + r = self.get_success(self.store.get_rooms_with_many_extremities(5, 1, [room1])) self.assertTrue(r == [room2] or r == [room3]) + + def test_auth_difference(self): + room_id = "@ROOM:local" + + # The silly auth graph we use to test the auth difference algorithm, + # where the top are the most recent events. + # + # A B + # \ / + # D E + # \ | + # ` F C + # | /| + # G ยด | + # | \ | + # H I + # | | + # K J + + auth_graph = { + "a": ["e"], + "b": ["e"], + "c": ["g", "i"], + "d": ["f"], + "e": ["f"], + "f": ["g"], + "g": ["h", "i"], + "h": ["k"], + "i": ["j"], + "k": [], + "j": [], + } + + depth_map = { + "a": 7, + "b": 7, + "c": 4, + "d": 6, + "e": 6, + "f": 5, + "g": 3, + "h": 2, + "i": 2, + "k": 1, + "j": 1, + } + + # We rudely fiddle with the appropriate tables directly, as that's much + # easier than constructing events properly. + + def insert_event(txn, event_id, stream_ordering): + + depth = depth_map[event_id] + + self.store.db.simple_insert_txn( + txn, + table="events", + values={ + "event_id": event_id, + "room_id": room_id, + "depth": depth, + "topological_ordering": depth, + "type": "m.test", + "processed": True, + "outlier": False, + "stream_ordering": stream_ordering, + }, + ) + + self.store.db.simple_insert_many_txn( + txn, + table="event_auth", + values=[ + {"event_id": event_id, "room_id": room_id, "auth_id": a} + for a in auth_graph[event_id] + ], + ) + + next_stream_ordering = 0 + for event_id in auth_graph: + next_stream_ordering += 1 + self.get_success( + self.store.db.runInteraction( + "insert", insert_event, event_id, next_stream_ordering + ) + ) + + # Now actually test that various combinations give the right result: + + difference = self.get_success( + self.store.get_auth_chain_difference([{"a"}, {"b"}]) + ) + self.assertSetEqual(difference, {"a", "b"}) + + difference = self.get_success( + self.store.get_auth_chain_difference([{"a"}, {"b"}, {"c"}]) + ) + self.assertSetEqual(difference, {"a", "b", "c", "e", "f"}) + + difference = self.get_success( + self.store.get_auth_chain_difference([{"a", "c"}, {"b"}]) + ) + self.assertSetEqual(difference, {"a", "b", "c"}) + + difference = self.get_success( + self.store.get_auth_chain_difference([{"a"}, {"b"}, {"d"}]) + ) + self.assertSetEqual(difference, {"a", "b", "d", "e"}) + + difference = self.get_success( + self.store.get_auth_chain_difference([{"a"}, {"b"}, {"c"}, {"d"}]) + ) + self.assertSetEqual(difference, {"a", "b", "c", "d", "e", "f"}) + + difference = self.get_success( + self.store.get_auth_chain_difference([{"a"}, {"b"}, {"e"}]) + ) + self.assertSetEqual(difference, {"a", "b"}) + + difference = self.get_success(self.store.get_auth_chain_difference([{"a"}])) + self.assertSetEqual(difference, set()) diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index 3c78faab45..bc53bf0951 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -303,3 +303,45 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.pump() self.store.upsert_monthly_active_user.assert_not_called() + + def test_get_monthly_active_count_by_service(self): + appservice1_user1 = "@appservice1_user1:example.com" + appservice1_user2 = "@appservice1_user2:example.com" + + appservice2_user1 = "@appservice2_user1:example.com" + native_user1 = "@native_user1:example.com" + + service1 = "service1" + service2 = "service2" + native = "native" + + self.store.register_user( + user_id=appservice1_user1, password_hash=None, appservice_id=service1 + ) + self.store.register_user( + user_id=appservice1_user2, password_hash=None, appservice_id=service1 + ) + self.store.register_user( + user_id=appservice2_user1, password_hash=None, appservice_id=service2 + ) + self.store.register_user(user_id=native_user1, password_hash=None) + self.pump() + + count = self.store.get_monthly_active_count_by_service() + self.assertEqual({}, self.get_success(count)) + + self.store.upsert_monthly_active_user(native_user1) + self.store.upsert_monthly_active_user(appservice1_user1) + self.store.upsert_monthly_active_user(appservice1_user2) + self.store.upsert_monthly_active_user(appservice2_user1) + self.pump() + + count = self.store.get_monthly_active_count() + self.assertEqual(4, self.get_success(count)) + + count = self.store.get_monthly_active_count_by_service() + result = self.get_success(count) + + self.assertEqual(2, result[service1]) + self.assertEqual(1, result[service2]) + self.assertEqual(1, result[native]) diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py index bfa5d6f510..6c2351cf55 100644 --- a/tests/test_event_auth.py +++ b/tests/test_event_auth.py @@ -19,6 +19,7 @@ from synapse import event_auth from synapse.api.errors import AuthError from synapse.api.room_versions import RoomVersions from synapse.events import make_event_from_dict +from synapse.types import get_domain_from_id class EventAuthTestCase(unittest.TestCase): @@ -51,7 +52,7 @@ class EventAuthTestCase(unittest.TestCase): _random_state_event(joiner), auth_events, do_sig_check=False, - ), + ) def test_state_default_level(self): """ @@ -87,6 +88,83 @@ class EventAuthTestCase(unittest.TestCase): RoomVersions.V1, _random_state_event(king), auth_events, do_sig_check=False, ) + def test_alias_event(self): + """Alias events have special behavior up through room version 6.""" + creator = "@creator:example.com" + other = "@other:example.com" + auth_events = { + ("m.room.create", ""): _create_event(creator), + ("m.room.member", creator): _join_event(creator), + } + + # creator should be able to send aliases + event_auth.check( + RoomVersions.V1, _alias_event(creator), auth_events, do_sig_check=False, + ) + + # Reject an event with no state key. + with self.assertRaises(AuthError): + event_auth.check( + RoomVersions.V1, + _alias_event(creator, state_key=""), + auth_events, + do_sig_check=False, + ) + + # If the domain of the sender does not match the state key, reject. + with self.assertRaises(AuthError): + event_auth.check( + RoomVersions.V1, + _alias_event(creator, state_key="test.com"), + auth_events, + do_sig_check=False, + ) + + # Note that the member does *not* need to be in the room. + event_auth.check( + RoomVersions.V1, _alias_event(other), auth_events, do_sig_check=False, + ) + + def test_msc2432_alias_event(self): + """After MSC2432, alias events have no special behavior.""" + creator = "@creator:example.com" + other = "@other:example.com" + auth_events = { + ("m.room.create", ""): _create_event(creator), + ("m.room.member", creator): _join_event(creator), + } + + # creator should be able to send aliases + event_auth.check( + RoomVersions.MSC2432_DEV, + _alias_event(creator), + auth_events, + do_sig_check=False, + ) + + # No particular checks are done on the state key. + event_auth.check( + RoomVersions.MSC2432_DEV, + _alias_event(creator, state_key=""), + auth_events, + do_sig_check=False, + ) + event_auth.check( + RoomVersions.MSC2432_DEV, + _alias_event(creator, state_key="test.com"), + auth_events, + do_sig_check=False, + ) + + # Per standard auth rules, the member must be in the room. + with self.assertRaises(AuthError): + event_auth.check( + RoomVersions.MSC2432_DEV, + _alias_event(other), + auth_events, + do_sig_check=False, + ) + # helpers for making events @@ -131,6 +209,19 @@ def _power_levels_event(sender, content): ) +def _alias_event(sender, **kwargs): + data = { + "room_id": TEST_ROOM_ID, + "event_id": _get_event_id(), + "type": "m.room.aliases", + "sender": sender, + "state_key": get_domain_from_id(sender), + "content": {"aliases": []}, + } + data.update(**kwargs) + return make_event_from_dict(data) + + def _random_state_event(sender): return make_event_from_dict( { diff --git a/tests/test_types.py b/tests/test_types.py index 8d97c751ea..480bea1bdc 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -75,7 +75,7 @@ class GroupIDTestCase(unittest.TestCase): self.fail("Parsing '%s' should raise exception" % id_string) except SynapseError as exc: self.assertEqual(400, exc.code) - self.assertEqual("M_UNKNOWN", exc.errcode) + self.assertEqual("M_INVALID_PARAM", exc.errcode) class MapUsernameTestCase(unittest.TestCase): |