diff options
author | Richard van der Hoff <richard@matrix.org> | 2021-02-26 14:05:40 +0000 |
---|---|---|
committer | Richard van der Hoff <richard@matrix.org> | 2021-02-26 14:05:40 +0000 |
commit | fdbccc1e74c56da19bffc33bfe4f9f8d2b2d26f8 (patch) | |
tree | 4e2e1bdc58f6f77d7856e56261ace9b0218bb9f6 /tests | |
parent | Revert "Redirect redirect requests if they arrive on the wrong URI" (diff) | |
parent | SSO: redirect to public URL before setting cookies (#9436) (diff) | |
download | synapse-fdbccc1e74c56da19bffc33bfe4f9f8d2b2d26f8.tar.xz |
Merge remote-tracking branch 'origin/develop' into matrix-org-hotfixes
Diffstat (limited to 'tests')
-rw-r--r-- | tests/push/test_email.py | 34 | ||||
-rw-r--r-- | tests/replication/tcp/streams/test_federation.py | 2 | ||||
-rw-r--r-- | tests/replication/test_federation_ack.py | 2 | ||||
-rw-r--r-- | tests/replication/test_federation_sender_shard.py | 2 | ||||
-rw-r--r-- | tests/replication/test_pusher_shard.py | 2 | ||||
-rw-r--r-- | tests/rest/admin/test_user.py | 246 | ||||
-rw-r--r-- | tests/rest/client/v1/test_login.py | 61 | ||||
-rw-r--r-- | tests/rest/client/v1/utils.py | 19 | ||||
-rw-r--r-- | tests/rest/client/v2_alpha/test_auth.py | 6 | ||||
-rw-r--r-- | tests/rest/client/v2_alpha/test_shared_rooms.py | 75 | ||||
-rw-r--r-- | tests/server.py | 6 | ||||
-rw-r--r-- | tests/utils.py | 1 |
12 files changed, 367 insertions, 89 deletions
diff --git a/tests/push/test_email.py b/tests/push/test_email.py index 22f452ec24..941cf42429 100644 --- a/tests/push/test_email.py +++ b/tests/push/test_email.py @@ -21,6 +21,7 @@ import pkg_resources from twisted.internet.defer import Deferred import synapse.rest.admin +from synapse.api.errors import Codes, SynapseError from synapse.rest.client.v1 import login, room from tests.unittest import HomeserverTestCase @@ -100,12 +101,19 @@ class EmailPusherTests(HomeserverTestCase): user_tuple = self.get_success( self.hs.get_datastore().get_user_by_access_token(self.access_token) ) - token_id = user_tuple.token_id + self.token_id = user_tuple.token_id + + # We need to add email to account before we can create a pusher. + self.get_success( + hs.get_datastore().user_add_threepid( + self.user_id, "email", "a@example.com", 0, 0 + ) + ) self.pusher = self.get_success( self.hs.get_pusherpool().add_pusher( user_id=self.user_id, - access_token=token_id, + access_token=self.token_id, kind="email", app_id="m.email", app_display_name="Email Notifications", @@ -116,6 +124,28 @@ class EmailPusherTests(HomeserverTestCase): ) ) + def test_need_validated_email(self): + """Test that we can only add an email pusher if the user has validated + their email. + """ + with self.assertRaises(SynapseError) as cm: + self.get_success_or_raise( + self.hs.get_pusherpool().add_pusher( + user_id=self.user_id, + access_token=self.token_id, + kind="email", + app_id="m.email", + app_display_name="Email Notifications", + device_display_name="b@example.com", + pushkey="b@example.com", + lang=None, + data={}, + ) + ) + + self.assertEqual(400, cm.exception.code) + self.assertEqual(Codes.THREEPID_NOT_FOUND, cm.exception.errcode) + def test_simple_sends_email(self): # Create a simple room with two users room = self.helper.create_room_as(self.user_id, tok=self.access_token) diff --git a/tests/replication/tcp/streams/test_federation.py b/tests/replication/tcp/streams/test_federation.py index 2babea4e3e..aa4bf1c7e3 100644 --- a/tests/replication/tcp/streams/test_federation.py +++ b/tests/replication/tcp/streams/test_federation.py @@ -24,7 +24,7 @@ class FederationStreamTestCase(BaseStreamTestCase): # enable federation sending on the worker config = super()._get_worker_hs_config() # TODO: make it so we don't need both of these - config["send_federation"] = True + config["send_federation"] = False config["worker_app"] = "synapse.app.federation_sender" return config diff --git a/tests/replication/test_federation_ack.py b/tests/replication/test_federation_ack.py index 1853667558..f235f1bd83 100644 --- a/tests/replication/test_federation_ack.py +++ b/tests/replication/test_federation_ack.py @@ -27,7 +27,7 @@ class FederationAckTestCase(HomeserverTestCase): def default_config(self) -> dict: config = super().default_config() config["worker_app"] = "synapse.app.federation_sender" - config["send_federation"] = True + config["send_federation"] = False return config def make_homeserver(self, reactor, clock): diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py index fffdb742c8..2f2d117858 100644 --- a/tests/replication/test_federation_sender_shard.py +++ b/tests/replication/test_federation_sender_shard.py @@ -49,7 +49,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): self.make_worker_hs( "synapse.app.federation_sender", - {"send_federation": True}, + {"send_federation": False}, federation_http_client=mock_client, ) diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py index f118fe32af..ab2988a6ba 100644 --- a/tests/replication/test_pusher_shard.py +++ b/tests/replication/test_pusher_shard.py @@ -95,7 +95,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase): self.make_worker_hs( "synapse.app.pusher", - {"start_pushers": True}, + {"start_pushers": False}, proxied_blacklisted_http_client=http_client_mock, ) diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index ba26895391..e58d5cf0db 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -18,7 +18,7 @@ import hmac import json import urllib.parse from binascii import unhexlify -from typing import Optional +from typing import List, Optional from mock import Mock @@ -31,6 +31,7 @@ from synapse.rest.client.v2_alpha import devices, sync from synapse.types import JsonDict from tests import unittest +from tests.server import FakeSite, make_request from tests.test_utils import make_awaitable from tests.unittest import override_config @@ -1954,6 +1955,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() self.media_repo = hs.get_media_repository_resource() self.admin_user = self.register_user("admin", "pass", admin=True) @@ -2024,7 +2026,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): number_media = 20 other_user_tok = self.login("user", "pass") - self._create_media(other_user_tok, number_media) + self._create_media_for_user(other_user_tok, number_media) channel = self.make_request( "GET", @@ -2045,7 +2047,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): number_media = 20 other_user_tok = self.login("user", "pass") - self._create_media(other_user_tok, number_media) + self._create_media_for_user(other_user_tok, number_media) channel = self.make_request( "GET", @@ -2066,7 +2068,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): number_media = 20 other_user_tok = self.login("user", "pass") - self._create_media(other_user_tok, number_media) + self._create_media_for_user(other_user_tok, number_media) channel = self.make_request( "GET", @@ -2080,11 +2082,31 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): self.assertEqual(len(channel.json_body["media"]), 10) self._check_fields(channel.json_body["media"]) - def test_limit_is_negative(self): + def test_invalid_parameter(self): """ - Testing that a negative limit parameter returns a 400 + If parameters are invalid, an error is returned. """ + # unkown order_by + channel = self.make_request( + "GET", + self.url + "?order_by=bar", + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + # invalid search order + channel = self.make_request( + "GET", + self.url + "?dir=bar", + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + + # negative limit channel = self.make_request( "GET", self.url + "?limit=-5", @@ -2094,11 +2116,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) - def test_from_is_negative(self): - """ - Testing that a negative from parameter returns a 400 - """ - + # negative from channel = self.make_request( "GET", self.url + "?from=-5", @@ -2115,7 +2133,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): number_media = 20 other_user_tok = self.login("user", "pass") - self._create_media(other_user_tok, number_media) + self._create_media_for_user(other_user_tok, number_media) # `next_token` does not appear # Number of results is the number of entries @@ -2193,7 +2211,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): number_media = 5 other_user_tok = self.login("user", "pass") - self._create_media(other_user_tok, number_media) + self._create_media_for_user(other_user_tok, number_media) channel = self.make_request( "GET", @@ -2207,11 +2225,118 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): self.assertNotIn("next_token", channel.json_body) self._check_fields(channel.json_body["media"]) - def _create_media(self, user_token, number_media): + def test_order_by(self): + """ + Testing order list with parameter `order_by` + """ + + other_user_tok = self.login("user", "pass") + + # Resolution: 1×1, MIME type: image/png, Extension: png, Size: 67 B + image_data1 = unhexlify( + b"89504e470d0a1a0a0000000d4948445200000001000000010806" + b"0000001f15c4890000000a49444154789c63000100000500010d" + b"0a2db40000000049454e44ae426082" + ) + # Resolution: 1×1, MIME type: image/gif, Extension: gif, Size: 35 B + image_data2 = unhexlify( + b"47494638376101000100800100000000" + b"ffffff2c00000000010001000002024c" + b"01003b" + ) + # Resolution: 1×1, MIME type: image/bmp, Extension: bmp, Size: 54 B + image_data3 = unhexlify( + b"424d3a0000000000000036000000280000000100000001000000" + b"0100180000000000040000000000000000000000000000000000" + b"0000" + ) + + # create media and make sure they do not have the same timestamp + media1 = self._create_media_and_access(other_user_tok, image_data1, "image.png") + self.pump(1.0) + media2 = self._create_media_and_access(other_user_tok, image_data2, "image.gif") + self.pump(1.0) + media3 = self._create_media_and_access(other_user_tok, image_data3, "image.bmp") + self.pump(1.0) + + # Mark one media as safe from quarantine. + self.get_success(self.store.mark_local_media_as_safe(media2)) + # Quarantine one media + self.get_success( + self.store.quarantine_media_by_id("test", media3, self.admin_user) + ) + + # order by default ("created_ts") + # default is backwards + self._order_test([media3, media2, media1], None) + self._order_test([media1, media2, media3], None, "f") + self._order_test([media3, media2, media1], None, "b") + + # sort by media_id + sorted_media = sorted([media1, media2, media3], reverse=False) + sorted_media_reverse = sorted(sorted_media, reverse=True) + + # order by media_id + self._order_test(sorted_media, "media_id") + self._order_test(sorted_media, "media_id", "f") + self._order_test(sorted_media_reverse, "media_id", "b") + + # order by upload_name + self._order_test([media3, media2, media1], "upload_name") + self._order_test([media3, media2, media1], "upload_name", "f") + self._order_test([media1, media2, media3], "upload_name", "b") + + # order by media_type + # result is ordered by media_id + # because of uploaded media_type is always 'application/json' + self._order_test(sorted_media, "media_type") + self._order_test(sorted_media, "media_type", "f") + self._order_test(sorted_media, "media_type", "b") + + # order by media_length + self._order_test([media2, media3, media1], "media_length") + self._order_test([media2, media3, media1], "media_length", "f") + self._order_test([media1, media3, media2], "media_length", "b") + + # order by created_ts + self._order_test([media1, media2, media3], "created_ts") + self._order_test([media1, media2, media3], "created_ts", "f") + self._order_test([media3, media2, media1], "created_ts", "b") + + # order by last_access_ts + self._order_test([media1, media2, media3], "last_access_ts") + self._order_test([media1, media2, media3], "last_access_ts", "f") + self._order_test([media3, media2, media1], "last_access_ts", "b") + + # order by quarantined_by + # one media is in quarantine, others are ordered by media_ids + + # Different sort order of SQlite and PostreSQL + # If a media is not in quarantine `quarantined_by` is NULL + # SQLite considers NULL to be smaller than any other value. + # PostreSQL considers NULL to be larger than any other value. + + # self._order_test(sorted([media1, media2]) + [media3], "quarantined_by") + # self._order_test(sorted([media1, media2]) + [media3], "quarantined_by", "f") + # self._order_test([media3] + sorted([media1, media2]), "quarantined_by", "b") + + # order by safe_from_quarantine + # one media is safe from quarantine, others are ordered by media_ids + self._order_test(sorted([media1, media3]) + [media2], "safe_from_quarantine") + self._order_test( + sorted([media1, media3]) + [media2], "safe_from_quarantine", "f" + ) + self._order_test( + [media2] + sorted([media1, media3]), "safe_from_quarantine", "b" + ) + + def _create_media_for_user(self, user_token: str, number_media: int): """ Create a number of media for a specific user + Args: + user_token: Access token of the user + number_media: Number of media to be created for the user """ - upload_resource = self.media_repo.children[b"upload"] for i in range(number_media): # file size is 67 Byte image_data = unhexlify( @@ -2220,13 +2345,60 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): b"0a2db40000000049454e44ae426082" ) - # Upload some media into the room - self.helper.upload_media( - upload_resource, image_data, tok=user_token, expect_code=200 - ) + self._create_media_and_access(user_token, image_data) + + def _create_media_and_access( + self, + user_token: str, + image_data: bytes, + filename: str = "image1.png", + ) -> str: + """ + Create one media for a specific user, access and returns `media_id` + Args: + user_token: Access token of the user + image_data: binary data of image + filename: The filename of the media to be uploaded + Returns: + The ID of the newly created media. + """ + upload_resource = self.media_repo.children[b"upload"] + download_resource = self.media_repo.children[b"download"] + + # Upload some media into the room + response = self.helper.upload_media( + upload_resource, image_data, user_token, filename, expect_code=200 + ) + + # Extract media ID from the response + server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://' + media_id = server_and_media_id.split("/")[1] + + # Try to access a media and to create `last_access_ts` + channel = make_request( + self.reactor, + FakeSite(download_resource), + "GET", + server_and_media_id, + shorthand=False, + access_token=user_token, + ) + + self.assertEqual( + 200, + channel.code, + msg=( + "Expected to receive a 200 on accessing media: %s" % server_and_media_id + ), + ) - def _check_fields(self, content): - """Checks that all attributes are present in content""" + return media_id + + def _check_fields(self, content: JsonDict): + """Checks that the expected user attributes are present in content + Args: + content: List that is checked for content + """ for m in content: self.assertIn("media_id", m) self.assertIn("media_type", m) @@ -2237,6 +2409,38 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): self.assertIn("quarantined_by", m) self.assertIn("safe_from_quarantine", m) + def _order_test( + self, + expected_media_list: List[str], + order_by: Optional[str], + dir: Optional[str] = None, + ): + """Request the list of media in a certain order. Assert that order is what + we expect + Args: + expected_media_list: The list of media_ids in the order we expect to get + back from the server + order_by: The type of ordering to give the server + dir: The direction of ordering to give the server + """ + + url = self.url + "?" + if order_by is not None: + url += "order_by=%s&" % (order_by,) + if dir is not None and dir in ("b", "f"): + url += "dir=%s" % (dir,) + channel = self.make_request( + "GET", + url.encode("ascii"), + access_token=self.admin_user_tok, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(channel.json_body["total"], len(expected_media_list)) + + returned_order = [row["media_id"] for row in channel.json_body["media"]] + self.assertEqual(expected_media_list, returned_order) + self._check_fields(channel.json_body["media"]) + class UserTokenRestTestCase(unittest.HomeserverTestCase): """Test for /_synapse/admin/v1/users/<user>/login""" diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index fb29eaed6f..744d8d0941 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -15,7 +15,7 @@ import time import urllib.parse -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Optional, Union from urllib.parse import urlencode from mock import Mock @@ -47,8 +47,14 @@ except ImportError: HAS_JWT = False -# public_base_url used in some tests -BASE_URL = "https://synapse/" +# synapse server name: used to populate public_baseurl in some tests +SYNAPSE_SERVER_PUBLIC_HOSTNAME = "synapse" + +# public_baseurl for some tests. It uses an http:// scheme because +# FakeChannel.isSecure() returns False, so synapse will see the requested uri as +# http://..., so using http in the public_baseurl stops Synapse trying to redirect to +# https://.... +BASE_URL = "http://%s/" % (SYNAPSE_SERVER_PUBLIC_HOSTNAME,) # CAS server used in some tests CAS_SERVER = "https://fake.test" @@ -480,11 +486,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): def test_multi_sso_redirect(self): """/login/sso/redirect should redirect to an identity picker""" # first hit the redirect url, which should redirect to our idp picker - channel = self.make_request( - "GET", - "/_matrix/client/r0/login/sso/redirect?redirectUrl=" - + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL), - ) + channel = self._make_sso_redirect_request(False, None) self.assertEqual(channel.code, 302, channel.result) uri = channel.headers.getRawHeaders("Location")[0] @@ -628,34 +630,21 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): def test_client_idp_redirect_msc2858_disabled(self): """If the client tries to pick an IdP but MSC2858 is disabled, return a 400""" - channel = self.make_request( - "GET", - "/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect/oidc?redirectUrl=" - + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL), - ) + channel = self._make_sso_redirect_request(True, "oidc") self.assertEqual(channel.code, 400, channel.result) self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED") @override_config({"experimental_features": {"msc2858_enabled": True}}) def test_client_idp_redirect_to_unknown(self): """If the client tries to pick an unknown IdP, return a 404""" - channel = self.make_request( - "GET", - "/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect/xxx?redirectUrl=" - + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL), - ) + channel = self._make_sso_redirect_request(True, "xxx") self.assertEqual(channel.code, 404, channel.result) self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND") @override_config({"experimental_features": {"msc2858_enabled": True}}) def test_client_idp_redirect_to_oidc(self): """If the client pick a known IdP, redirect to it""" - channel = self.make_request( - "GET", - "/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect/oidc?redirectUrl=" - + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL), - ) - + channel = self._make_sso_redirect_request(True, "oidc") self.assertEqual(channel.code, 302, channel.result) oidc_uri = channel.headers.getRawHeaders("Location")[0] oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1) @@ -663,6 +652,30 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): # it should redirect us to the auth page of the OIDC server self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT) + def _make_sso_redirect_request( + self, unstable_endpoint: bool = False, idp_prov: Optional[str] = None + ): + """Send a request to /_matrix/client/r0/login/sso/redirect + + ... or the unstable equivalent + + ... possibly specifying an IDP provider + """ + endpoint = ( + "/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect" + if unstable_endpoint + else "/_matrix/client/r0/login/sso/redirect" + ) + if idp_prov is not None: + endpoint += "/" + idp_prov + endpoint += "?redirectUrl=" + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL) + + return self.make_request( + "GET", + endpoint, + custom_headers=[("Host", SYNAPSE_SERVER_PUBLIC_HOSTNAME)], + ) + @staticmethod def _get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str: prefix = key + " = " diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index 8231a423f3..946740aa5d 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -542,13 +542,30 @@ class RestHelper: if client_redirect_url: params["redirectUrl"] = client_redirect_url - # hit the redirect url (which will issue a cookie and state) + # hit the redirect url (which should redirect back to the redirect url. This + # is the easiest way of figuring out what the Host header ought to be set to + # to keep Synapse happy. channel = make_request( self.hs.get_reactor(), self.site, "GET", "/_matrix/client/r0/login/sso/redirect?" + urllib.parse.urlencode(params), ) + assert channel.code == 302 + + # hit the redirect url again with the right Host header, which should now issue + # a cookie and redirect to the SSO provider. + location = channel.headers.getRawHeaders("Location")[0] + parts = urllib.parse.urlsplit(location) + channel = make_request( + self.hs.get_reactor(), + self.site, + "GET", + urllib.parse.urlunsplit(("", "") + parts[2:]), + custom_headers=[ + ("Host", parts[1]), + ], + ) assert channel.code == 302 channel.extract_cookies(cookies) diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py index c26ad824f7..9734a2159a 100644 --- a/tests/rest/client/v2_alpha/test_auth.py +++ b/tests/rest/client/v2_alpha/test_auth.py @@ -161,7 +161,11 @@ class UIAuthTests(unittest.HomeserverTestCase): def default_config(self): config = super().default_config() - config["public_baseurl"] = "https://synapse.test" + + # public_baseurl uses an http:// scheme because FakeChannel.isSecure() returns + # False, so synapse will see the requested uri as http://..., so using http in + # the public_baseurl stops Synapse trying to redirect to https. + config["public_baseurl"] = "http://synapse.test" if HAS_OIDC: # we enable OIDC as a way of testing SSO flows diff --git a/tests/rest/client/v2_alpha/test_shared_rooms.py b/tests/rest/client/v2_alpha/test_shared_rooms.py index 116ace1812..dd83a1f8ff 100644 --- a/tests/rest/client/v2_alpha/test_shared_rooms.py +++ b/tests/rest/client/v2_alpha/test_shared_rooms.py @@ -54,61 +54,62 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase): A room should show up in the shared list of rooms between two users if it is public. """ - u1 = self.register_user("user1", "pass") - u1_token = self.login(u1, "pass") - u2 = self.register_user("user2", "pass") - u2_token = self.login(u2, "pass") - - room = self.helper.create_room_as(u1, is_public=True, tok=u1_token) - self.helper.invite(room, src=u1, targ=u2, tok=u1_token) - self.helper.join(room, user=u2, tok=u2_token) - - channel = self._get_shared_rooms(u1_token, u2) - self.assertEquals(200, channel.code, channel.result) - self.assertEquals(len(channel.json_body["joined"]), 1) - self.assertEquals(channel.json_body["joined"][0], room) + self._check_shared_rooms_with(room_one_is_public=True, room_two_is_public=True) def test_shared_room_list_private(self): """ A room should show up in the shared list of rooms between two users if it is private. """ - u1 = self.register_user("user1", "pass") - u1_token = self.login(u1, "pass") - u2 = self.register_user("user2", "pass") - u2_token = self.login(u2, "pass") - - room = self.helper.create_room_as(u1, is_public=False, tok=u1_token) - self.helper.invite(room, src=u1, targ=u2, tok=u1_token) - self.helper.join(room, user=u2, tok=u2_token) - - channel = self._get_shared_rooms(u1_token, u2) - self.assertEquals(200, channel.code, channel.result) - self.assertEquals(len(channel.json_body["joined"]), 1) - self.assertEquals(channel.json_body["joined"][0], room) + self._check_shared_rooms_with( + room_one_is_public=False, room_two_is_public=False + ) def test_shared_room_list_mixed(self): """ The shared room list between two users should contain both public and private rooms. """ + self._check_shared_rooms_with(room_one_is_public=True, room_two_is_public=False) + + def _check_shared_rooms_with( + self, room_one_is_public: bool, room_two_is_public: bool + ): + """Checks that shared public or private rooms between two users appear in + their shared room lists + """ u1 = self.register_user("user1", "pass") u1_token = self.login(u1, "pass") u2 = self.register_user("user2", "pass") u2_token = self.login(u2, "pass") - room_public = self.helper.create_room_as(u1, is_public=True, tok=u1_token) - room_private = self.helper.create_room_as(u2, is_public=False, tok=u2_token) - self.helper.invite(room_public, src=u1, targ=u2, tok=u1_token) - self.helper.invite(room_private, src=u2, targ=u1, tok=u2_token) - self.helper.join(room_public, user=u2, tok=u2_token) - self.helper.join(room_private, user=u1, tok=u1_token) + # Create a room. user1 invites user2, who joins + room_id_one = self.helper.create_room_as( + u1, is_public=room_one_is_public, tok=u1_token + ) + self.helper.invite(room_id_one, src=u1, targ=u2, tok=u1_token) + self.helper.join(room_id_one, user=u2, tok=u2_token) + # Check shared rooms from user1's perspective. + # We should see the one room in common + channel = self._get_shared_rooms(u1_token, u2) + self.assertEquals(200, channel.code, channel.result) + self.assertEquals(len(channel.json_body["joined"]), 1) + self.assertEquals(channel.json_body["joined"][0], room_id_one) + + # Create another room and invite user2 to it + room_id_two = self.helper.create_room_as( + u1, is_public=room_two_is_public, tok=u1_token + ) + self.helper.invite(room_id_two, src=u1, targ=u2, tok=u1_token) + self.helper.join(room_id_two, user=u2, tok=u2_token) + + # Check shared rooms again. We should now see both rooms. channel = self._get_shared_rooms(u1_token, u2) self.assertEquals(200, channel.code, channel.result) self.assertEquals(len(channel.json_body["joined"]), 2) - self.assertTrue(room_public in channel.json_body["joined"]) - self.assertTrue(room_private in channel.json_body["joined"]) + for room_id_id in channel.json_body["joined"]: + self.assertIn(room_id_id, [room_id_one, room_id_two]) def test_shared_room_list_after_leave(self): """ @@ -132,6 +133,12 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase): self.helper.leave(room, user=u1, tok=u1_token) + # Check user1's view of shared rooms with user2 + channel = self._get_shared_rooms(u1_token, u2) + self.assertEquals(200, channel.code, channel.result) + self.assertEquals(len(channel.json_body["joined"]), 0) + + # Check user2's view of shared rooms with user1 channel = self._get_shared_rooms(u2_token, u1) self.assertEquals(200, channel.code, channel.result) self.assertEquals(len(channel.json_body["joined"]), 0) diff --git a/tests/server.py b/tests/server.py index d4ece5c448..939a0008ca 100644 --- a/tests/server.py +++ b/tests/server.py @@ -124,7 +124,11 @@ class FakeChannel: return address.IPv4Address("TCP", self._ip, 3423) def getHost(self): - return None + # this is called by Request.__init__ to configure Request.host. + return address.IPv4Address("TCP", "127.0.0.1", 8888) + + def isSecure(self): + return False @property def transport(self): diff --git a/tests/utils.py b/tests/utils.py index 4fb5098550..be80b13760 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -114,7 +114,6 @@ def default_config(name, parse=False): "server_name": name, "send_federation": False, "media_store_path": "media", - "uploads_path": "uploads", # the test signing key is just an arbitrary ed25519 key to keep the config # parser happy "signing_key": "ed25519 a_lPym qvioDNmfExFBRPgdTU+wtFYKq4JfwFRv7sYVgWvmgJg", |