summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
authorRichard van der Hoff <richard@matrix.org>2021-02-26 14:05:40 +0000
committerRichard van der Hoff <richard@matrix.org>2021-02-26 14:05:40 +0000
commitfdbccc1e74c56da19bffc33bfe4f9f8d2b2d26f8 (patch)
tree4e2e1bdc58f6f77d7856e56261ace9b0218bb9f6 /tests
parentRevert "Redirect redirect requests if they arrive on the wrong URI" (diff)
parentSSO: redirect to public URL before setting cookies (#9436) (diff)
downloadsynapse-fdbccc1e74c56da19bffc33bfe4f9f8d2b2d26f8.tar.xz
Merge remote-tracking branch 'origin/develop' into matrix-org-hotfixes
Diffstat (limited to 'tests')
-rw-r--r--tests/push/test_email.py34
-rw-r--r--tests/replication/tcp/streams/test_federation.py2
-rw-r--r--tests/replication/test_federation_ack.py2
-rw-r--r--tests/replication/test_federation_sender_shard.py2
-rw-r--r--tests/replication/test_pusher_shard.py2
-rw-r--r--tests/rest/admin/test_user.py246
-rw-r--r--tests/rest/client/v1/test_login.py61
-rw-r--r--tests/rest/client/v1/utils.py19
-rw-r--r--tests/rest/client/v2_alpha/test_auth.py6
-rw-r--r--tests/rest/client/v2_alpha/test_shared_rooms.py75
-rw-r--r--tests/server.py6
-rw-r--r--tests/utils.py1
12 files changed, 367 insertions, 89 deletions
diff --git a/tests/push/test_email.py b/tests/push/test_email.py
index 22f452ec24..941cf42429 100644
--- a/tests/push/test_email.py
+++ b/tests/push/test_email.py
@@ -21,6 +21,7 @@ import pkg_resources
 from twisted.internet.defer import Deferred
 
 import synapse.rest.admin
+from synapse.api.errors import Codes, SynapseError
 from synapse.rest.client.v1 import login, room
 
 from tests.unittest import HomeserverTestCase
@@ -100,12 +101,19 @@ class EmailPusherTests(HomeserverTestCase):
         user_tuple = self.get_success(
             self.hs.get_datastore().get_user_by_access_token(self.access_token)
         )
-        token_id = user_tuple.token_id
+        self.token_id = user_tuple.token_id
+
+        # We need to add email to account before we can create a pusher.
+        self.get_success(
+            hs.get_datastore().user_add_threepid(
+                self.user_id, "email", "a@example.com", 0, 0
+            )
+        )
 
         self.pusher = self.get_success(
             self.hs.get_pusherpool().add_pusher(
                 user_id=self.user_id,
-                access_token=token_id,
+                access_token=self.token_id,
                 kind="email",
                 app_id="m.email",
                 app_display_name="Email Notifications",
@@ -116,6 +124,28 @@ class EmailPusherTests(HomeserverTestCase):
             )
         )
 
+    def test_need_validated_email(self):
+        """Test that we can only add an email pusher if the user has validated
+        their email.
+        """
+        with self.assertRaises(SynapseError) as cm:
+            self.get_success_or_raise(
+                self.hs.get_pusherpool().add_pusher(
+                    user_id=self.user_id,
+                    access_token=self.token_id,
+                    kind="email",
+                    app_id="m.email",
+                    app_display_name="Email Notifications",
+                    device_display_name="b@example.com",
+                    pushkey="b@example.com",
+                    lang=None,
+                    data={},
+                )
+            )
+
+        self.assertEqual(400, cm.exception.code)
+        self.assertEqual(Codes.THREEPID_NOT_FOUND, cm.exception.errcode)
+
     def test_simple_sends_email(self):
         # Create a simple room with two users
         room = self.helper.create_room_as(self.user_id, tok=self.access_token)
diff --git a/tests/replication/tcp/streams/test_federation.py b/tests/replication/tcp/streams/test_federation.py
index 2babea4e3e..aa4bf1c7e3 100644
--- a/tests/replication/tcp/streams/test_federation.py
+++ b/tests/replication/tcp/streams/test_federation.py
@@ -24,7 +24,7 @@ class FederationStreamTestCase(BaseStreamTestCase):
         # enable federation sending on the worker
         config = super()._get_worker_hs_config()
         # TODO: make it so we don't need both of these
-        config["send_federation"] = True
+        config["send_federation"] = False
         config["worker_app"] = "synapse.app.federation_sender"
         return config
 
diff --git a/tests/replication/test_federation_ack.py b/tests/replication/test_federation_ack.py
index 1853667558..f235f1bd83 100644
--- a/tests/replication/test_federation_ack.py
+++ b/tests/replication/test_federation_ack.py
@@ -27,7 +27,7 @@ class FederationAckTestCase(HomeserverTestCase):
     def default_config(self) -> dict:
         config = super().default_config()
         config["worker_app"] = "synapse.app.federation_sender"
-        config["send_federation"] = True
+        config["send_federation"] = False
         return config
 
     def make_homeserver(self, reactor, clock):
diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py
index fffdb742c8..2f2d117858 100644
--- a/tests/replication/test_federation_sender_shard.py
+++ b/tests/replication/test_federation_sender_shard.py
@@ -49,7 +49,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
 
         self.make_worker_hs(
             "synapse.app.federation_sender",
-            {"send_federation": True},
+            {"send_federation": False},
             federation_http_client=mock_client,
         )
 
diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py
index f118fe32af..ab2988a6ba 100644
--- a/tests/replication/test_pusher_shard.py
+++ b/tests/replication/test_pusher_shard.py
@@ -95,7 +95,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
 
         self.make_worker_hs(
             "synapse.app.pusher",
-            {"start_pushers": True},
+            {"start_pushers": False},
             proxied_blacklisted_http_client=http_client_mock,
         )
 
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index ba26895391..e58d5cf0db 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -18,7 +18,7 @@ import hmac
 import json
 import urllib.parse
 from binascii import unhexlify
-from typing import Optional
+from typing import List, Optional
 
 from mock import Mock
 
@@ -31,6 +31,7 @@ from synapse.rest.client.v2_alpha import devices, sync
 from synapse.types import JsonDict
 
 from tests import unittest
+from tests.server import FakeSite, make_request
 from tests.test_utils import make_awaitable
 from tests.unittest import override_config
 
@@ -1954,6 +1955,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
     ]
 
     def prepare(self, reactor, clock, hs):
+        self.store = hs.get_datastore()
         self.media_repo = hs.get_media_repository_resource()
 
         self.admin_user = self.register_user("admin", "pass", admin=True)
@@ -2024,7 +2026,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
 
         number_media = 20
         other_user_tok = self.login("user", "pass")
-        self._create_media(other_user_tok, number_media)
+        self._create_media_for_user(other_user_tok, number_media)
 
         channel = self.make_request(
             "GET",
@@ -2045,7 +2047,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
 
         number_media = 20
         other_user_tok = self.login("user", "pass")
-        self._create_media(other_user_tok, number_media)
+        self._create_media_for_user(other_user_tok, number_media)
 
         channel = self.make_request(
             "GET",
@@ -2066,7 +2068,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
 
         number_media = 20
         other_user_tok = self.login("user", "pass")
-        self._create_media(other_user_tok, number_media)
+        self._create_media_for_user(other_user_tok, number_media)
 
         channel = self.make_request(
             "GET",
@@ -2080,11 +2082,31 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual(len(channel.json_body["media"]), 10)
         self._check_fields(channel.json_body["media"])
 
-    def test_limit_is_negative(self):
+    def test_invalid_parameter(self):
         """
-        Testing that a negative limit parameter returns a 400
+        If parameters are invalid, an error is returned.
         """
+        # unkown order_by
+        channel = self.make_request(
+            "GET",
+            self.url + "?order_by=bar",
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
 
+        # invalid search order
+        channel = self.make_request(
+            "GET",
+            self.url + "?dir=bar",
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+
+        # negative limit
         channel = self.make_request(
             "GET",
             self.url + "?limit=-5",
@@ -2094,11 +2116,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
 
-    def test_from_is_negative(self):
-        """
-        Testing that a negative from parameter returns a 400
-        """
-
+        # negative from
         channel = self.make_request(
             "GET",
             self.url + "?from=-5",
@@ -2115,7 +2133,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
 
         number_media = 20
         other_user_tok = self.login("user", "pass")
-        self._create_media(other_user_tok, number_media)
+        self._create_media_for_user(other_user_tok, number_media)
 
         #  `next_token` does not appear
         # Number of results is the number of entries
@@ -2193,7 +2211,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
 
         number_media = 5
         other_user_tok = self.login("user", "pass")
-        self._create_media(other_user_tok, number_media)
+        self._create_media_for_user(other_user_tok, number_media)
 
         channel = self.make_request(
             "GET",
@@ -2207,11 +2225,118 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
         self.assertNotIn("next_token", channel.json_body)
         self._check_fields(channel.json_body["media"])
 
-    def _create_media(self, user_token, number_media):
+    def test_order_by(self):
+        """
+        Testing order list with parameter `order_by`
+        """
+
+        other_user_tok = self.login("user", "pass")
+
+        # Resolution: 1×1, MIME type: image/png, Extension: png, Size: 67 B
+        image_data1 = unhexlify(
+            b"89504e470d0a1a0a0000000d4948445200000001000000010806"
+            b"0000001f15c4890000000a49444154789c63000100000500010d"
+            b"0a2db40000000049454e44ae426082"
+        )
+        # Resolution: 1×1, MIME type: image/gif, Extension: gif, Size: 35 B
+        image_data2 = unhexlify(
+            b"47494638376101000100800100000000"
+            b"ffffff2c00000000010001000002024c"
+            b"01003b"
+        )
+        # Resolution: 1×1, MIME type: image/bmp, Extension: bmp, Size: 54 B
+        image_data3 = unhexlify(
+            b"424d3a0000000000000036000000280000000100000001000000"
+            b"0100180000000000040000000000000000000000000000000000"
+            b"0000"
+        )
+
+        # create media and make sure they do not have the same timestamp
+        media1 = self._create_media_and_access(other_user_tok, image_data1, "image.png")
+        self.pump(1.0)
+        media2 = self._create_media_and_access(other_user_tok, image_data2, "image.gif")
+        self.pump(1.0)
+        media3 = self._create_media_and_access(other_user_tok, image_data3, "image.bmp")
+        self.pump(1.0)
+
+        # Mark one media as safe from quarantine.
+        self.get_success(self.store.mark_local_media_as_safe(media2))
+        # Quarantine one media
+        self.get_success(
+            self.store.quarantine_media_by_id("test", media3, self.admin_user)
+        )
+
+        # order by default ("created_ts")
+        # default is backwards
+        self._order_test([media3, media2, media1], None)
+        self._order_test([media1, media2, media3], None, "f")
+        self._order_test([media3, media2, media1], None, "b")
+
+        # sort by media_id
+        sorted_media = sorted([media1, media2, media3], reverse=False)
+        sorted_media_reverse = sorted(sorted_media, reverse=True)
+
+        # order by media_id
+        self._order_test(sorted_media, "media_id")
+        self._order_test(sorted_media, "media_id", "f")
+        self._order_test(sorted_media_reverse, "media_id", "b")
+
+        # order by upload_name
+        self._order_test([media3, media2, media1], "upload_name")
+        self._order_test([media3, media2, media1], "upload_name", "f")
+        self._order_test([media1, media2, media3], "upload_name", "b")
+
+        # order by media_type
+        # result is ordered by media_id
+        # because of uploaded media_type is always 'application/json'
+        self._order_test(sorted_media, "media_type")
+        self._order_test(sorted_media, "media_type", "f")
+        self._order_test(sorted_media, "media_type", "b")
+
+        # order by media_length
+        self._order_test([media2, media3, media1], "media_length")
+        self._order_test([media2, media3, media1], "media_length", "f")
+        self._order_test([media1, media3, media2], "media_length", "b")
+
+        # order by created_ts
+        self._order_test([media1, media2, media3], "created_ts")
+        self._order_test([media1, media2, media3], "created_ts", "f")
+        self._order_test([media3, media2, media1], "created_ts", "b")
+
+        # order by last_access_ts
+        self._order_test([media1, media2, media3], "last_access_ts")
+        self._order_test([media1, media2, media3], "last_access_ts", "f")
+        self._order_test([media3, media2, media1], "last_access_ts", "b")
+
+        # order by quarantined_by
+        # one media is in quarantine, others are ordered by media_ids
+
+        # Different sort order of SQlite and PostreSQL
+        # If a media is not in quarantine `quarantined_by` is NULL
+        # SQLite considers NULL to be smaller than any other value.
+        # PostreSQL considers NULL to be larger than any other value.
+
+        # self._order_test(sorted([media1, media2]) + [media3], "quarantined_by")
+        # self._order_test(sorted([media1, media2]) + [media3], "quarantined_by", "f")
+        # self._order_test([media3] + sorted([media1, media2]), "quarantined_by", "b")
+
+        # order by safe_from_quarantine
+        # one media is safe from quarantine, others are ordered by media_ids
+        self._order_test(sorted([media1, media3]) + [media2], "safe_from_quarantine")
+        self._order_test(
+            sorted([media1, media3]) + [media2], "safe_from_quarantine", "f"
+        )
+        self._order_test(
+            [media2] + sorted([media1, media3]), "safe_from_quarantine", "b"
+        )
+
+    def _create_media_for_user(self, user_token: str, number_media: int):
         """
         Create a number of media for a specific user
+        Args:
+            user_token: Access token of the user
+            number_media: Number of media to be created for the user
         """
-        upload_resource = self.media_repo.children[b"upload"]
         for i in range(number_media):
             # file size is 67 Byte
             image_data = unhexlify(
@@ -2220,13 +2345,60 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
                 b"0a2db40000000049454e44ae426082"
             )
 
-            # Upload some media into the room
-            self.helper.upload_media(
-                upload_resource, image_data, tok=user_token, expect_code=200
-            )
+            self._create_media_and_access(user_token, image_data)
+
+    def _create_media_and_access(
+        self,
+        user_token: str,
+        image_data: bytes,
+        filename: str = "image1.png",
+    ) -> str:
+        """
+        Create one media for a specific user, access and returns `media_id`
+        Args:
+            user_token: Access token of the user
+            image_data: binary data of image
+            filename: The filename of the media to be uploaded
+        Returns:
+            The ID of the newly created media.
+        """
+        upload_resource = self.media_repo.children[b"upload"]
+        download_resource = self.media_repo.children[b"download"]
+
+        # Upload some media into the room
+        response = self.helper.upload_media(
+            upload_resource, image_data, user_token, filename, expect_code=200
+        )
+
+        # Extract media ID from the response
+        server_and_media_id = response["content_uri"][6:]  # Cut off 'mxc://'
+        media_id = server_and_media_id.split("/")[1]
+
+        # Try to access a media and to create `last_access_ts`
+        channel = make_request(
+            self.reactor,
+            FakeSite(download_resource),
+            "GET",
+            server_and_media_id,
+            shorthand=False,
+            access_token=user_token,
+        )
+
+        self.assertEqual(
+            200,
+            channel.code,
+            msg=(
+                "Expected to receive a 200 on accessing media: %s" % server_and_media_id
+            ),
+        )
 
-    def _check_fields(self, content):
-        """Checks that all attributes are present in content"""
+        return media_id
+
+    def _check_fields(self, content: JsonDict):
+        """Checks that the expected user attributes are present in content
+        Args:
+            content: List that is checked for content
+        """
         for m in content:
             self.assertIn("media_id", m)
             self.assertIn("media_type", m)
@@ -2237,6 +2409,38 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
             self.assertIn("quarantined_by", m)
             self.assertIn("safe_from_quarantine", m)
 
+    def _order_test(
+        self,
+        expected_media_list: List[str],
+        order_by: Optional[str],
+        dir: Optional[str] = None,
+    ):
+        """Request the list of media in a certain order. Assert that order is what
+        we expect
+        Args:
+            expected_media_list: The list of media_ids in the order we expect to get
+                back from the server
+            order_by: The type of ordering to give the server
+            dir: The direction of ordering to give the server
+        """
+
+        url = self.url + "?"
+        if order_by is not None:
+            url += "order_by=%s&" % (order_by,)
+        if dir is not None and dir in ("b", "f"):
+            url += "dir=%s" % (dir,)
+        channel = self.make_request(
+            "GET",
+            url.encode("ascii"),
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+        self.assertEqual(channel.json_body["total"], len(expected_media_list))
+
+        returned_order = [row["media_id"] for row in channel.json_body["media"]]
+        self.assertEqual(expected_media_list, returned_order)
+        self._check_fields(channel.json_body["media"])
+
 
 class UserTokenRestTestCase(unittest.HomeserverTestCase):
     """Test for /_synapse/admin/v1/users/<user>/login"""
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index fb29eaed6f..744d8d0941 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -15,7 +15,7 @@
 
 import time
 import urllib.parse
-from typing import Any, Dict, List, Union
+from typing import Any, Dict, List, Optional, Union
 from urllib.parse import urlencode
 
 from mock import Mock
@@ -47,8 +47,14 @@ except ImportError:
     HAS_JWT = False
 
 
-# public_base_url used in some tests
-BASE_URL = "https://synapse/"
+# synapse server name: used to populate public_baseurl in some tests
+SYNAPSE_SERVER_PUBLIC_HOSTNAME = "synapse"
+
+# public_baseurl for some tests. It uses an http:// scheme because
+# FakeChannel.isSecure() returns False, so synapse will see the requested uri as
+# http://..., so using http in the public_baseurl stops Synapse trying to redirect to
+# https://....
+BASE_URL = "http://%s/" % (SYNAPSE_SERVER_PUBLIC_HOSTNAME,)
 
 # CAS server used in some tests
 CAS_SERVER = "https://fake.test"
@@ -480,11 +486,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
     def test_multi_sso_redirect(self):
         """/login/sso/redirect should redirect to an identity picker"""
         # first hit the redirect url, which should redirect to our idp picker
-        channel = self.make_request(
-            "GET",
-            "/_matrix/client/r0/login/sso/redirect?redirectUrl="
-            + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL),
-        )
+        channel = self._make_sso_redirect_request(False, None)
         self.assertEqual(channel.code, 302, channel.result)
         uri = channel.headers.getRawHeaders("Location")[0]
 
@@ -628,34 +630,21 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
 
     def test_client_idp_redirect_msc2858_disabled(self):
         """If the client tries to pick an IdP but MSC2858 is disabled, return a 400"""
-        channel = self.make_request(
-            "GET",
-            "/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect/oidc?redirectUrl="
-            + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL),
-        )
+        channel = self._make_sso_redirect_request(True, "oidc")
         self.assertEqual(channel.code, 400, channel.result)
         self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED")
 
     @override_config({"experimental_features": {"msc2858_enabled": True}})
     def test_client_idp_redirect_to_unknown(self):
         """If the client tries to pick an unknown IdP, return a 404"""
-        channel = self.make_request(
-            "GET",
-            "/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect/xxx?redirectUrl="
-            + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL),
-        )
+        channel = self._make_sso_redirect_request(True, "xxx")
         self.assertEqual(channel.code, 404, channel.result)
         self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND")
 
     @override_config({"experimental_features": {"msc2858_enabled": True}})
     def test_client_idp_redirect_to_oidc(self):
         """If the client pick a known IdP, redirect to it"""
-        channel = self.make_request(
-            "GET",
-            "/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect/oidc?redirectUrl="
-            + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL),
-        )
-
+        channel = self._make_sso_redirect_request(True, "oidc")
         self.assertEqual(channel.code, 302, channel.result)
         oidc_uri = channel.headers.getRawHeaders("Location")[0]
         oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1)
@@ -663,6 +652,30 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
         # it should redirect us to the auth page of the OIDC server
         self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT)
 
+    def _make_sso_redirect_request(
+        self, unstable_endpoint: bool = False, idp_prov: Optional[str] = None
+    ):
+        """Send a request to /_matrix/client/r0/login/sso/redirect
+
+        ... or the unstable equivalent
+
+        ... possibly specifying an IDP provider
+        """
+        endpoint = (
+            "/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect"
+            if unstable_endpoint
+            else "/_matrix/client/r0/login/sso/redirect"
+        )
+        if idp_prov is not None:
+            endpoint += "/" + idp_prov
+        endpoint += "?redirectUrl=" + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
+
+        return self.make_request(
+            "GET",
+            endpoint,
+            custom_headers=[("Host", SYNAPSE_SERVER_PUBLIC_HOSTNAME)],
+        )
+
     @staticmethod
     def _get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str:
         prefix = key + " = "
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index 8231a423f3..946740aa5d 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -542,13 +542,30 @@ class RestHelper:
         if client_redirect_url:
             params["redirectUrl"] = client_redirect_url
 
-        # hit the redirect url (which will issue a cookie and state)
+        # hit the redirect url (which should redirect back to the redirect url. This
+        # is the easiest way of figuring out what the Host header ought to be set to
+        # to keep Synapse happy.
         channel = make_request(
             self.hs.get_reactor(),
             self.site,
             "GET",
             "/_matrix/client/r0/login/sso/redirect?" + urllib.parse.urlencode(params),
         )
+        assert channel.code == 302
+
+        # hit the redirect url again with the right Host header, which should now issue
+        # a cookie and redirect to the SSO provider.
+        location = channel.headers.getRawHeaders("Location")[0]
+        parts = urllib.parse.urlsplit(location)
+        channel = make_request(
+            self.hs.get_reactor(),
+            self.site,
+            "GET",
+            urllib.parse.urlunsplit(("", "") + parts[2:]),
+            custom_headers=[
+                ("Host", parts[1]),
+            ],
+        )
 
         assert channel.code == 302
         channel.extract_cookies(cookies)
diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py
index c26ad824f7..9734a2159a 100644
--- a/tests/rest/client/v2_alpha/test_auth.py
+++ b/tests/rest/client/v2_alpha/test_auth.py
@@ -161,7 +161,11 @@ class UIAuthTests(unittest.HomeserverTestCase):
 
     def default_config(self):
         config = super().default_config()
-        config["public_baseurl"] = "https://synapse.test"
+
+        # public_baseurl uses an http:// scheme because FakeChannel.isSecure() returns
+        # False, so synapse will see the requested uri as http://..., so using http in
+        # the public_baseurl stops Synapse trying to redirect to https.
+        config["public_baseurl"] = "http://synapse.test"
 
         if HAS_OIDC:
             # we enable OIDC as a way of testing SSO flows
diff --git a/tests/rest/client/v2_alpha/test_shared_rooms.py b/tests/rest/client/v2_alpha/test_shared_rooms.py
index 116ace1812..dd83a1f8ff 100644
--- a/tests/rest/client/v2_alpha/test_shared_rooms.py
+++ b/tests/rest/client/v2_alpha/test_shared_rooms.py
@@ -54,61 +54,62 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
         A room should show up in the shared list of rooms between two users
         if it is public.
         """
-        u1 = self.register_user("user1", "pass")
-        u1_token = self.login(u1, "pass")
-        u2 = self.register_user("user2", "pass")
-        u2_token = self.login(u2, "pass")
-
-        room = self.helper.create_room_as(u1, is_public=True, tok=u1_token)
-        self.helper.invite(room, src=u1, targ=u2, tok=u1_token)
-        self.helper.join(room, user=u2, tok=u2_token)
-
-        channel = self._get_shared_rooms(u1_token, u2)
-        self.assertEquals(200, channel.code, channel.result)
-        self.assertEquals(len(channel.json_body["joined"]), 1)
-        self.assertEquals(channel.json_body["joined"][0], room)
+        self._check_shared_rooms_with(room_one_is_public=True, room_two_is_public=True)
 
     def test_shared_room_list_private(self):
         """
         A room should show up in the shared list of rooms between two users
         if it is private.
         """
-        u1 = self.register_user("user1", "pass")
-        u1_token = self.login(u1, "pass")
-        u2 = self.register_user("user2", "pass")
-        u2_token = self.login(u2, "pass")
-
-        room = self.helper.create_room_as(u1, is_public=False, tok=u1_token)
-        self.helper.invite(room, src=u1, targ=u2, tok=u1_token)
-        self.helper.join(room, user=u2, tok=u2_token)
-
-        channel = self._get_shared_rooms(u1_token, u2)
-        self.assertEquals(200, channel.code, channel.result)
-        self.assertEquals(len(channel.json_body["joined"]), 1)
-        self.assertEquals(channel.json_body["joined"][0], room)
+        self._check_shared_rooms_with(
+            room_one_is_public=False, room_two_is_public=False
+        )
 
     def test_shared_room_list_mixed(self):
         """
         The shared room list between two users should contain both public and private
         rooms.
         """
+        self._check_shared_rooms_with(room_one_is_public=True, room_two_is_public=False)
+
+    def _check_shared_rooms_with(
+        self, room_one_is_public: bool, room_two_is_public: bool
+    ):
+        """Checks that shared public or private rooms between two users appear in
+        their shared room lists
+        """
         u1 = self.register_user("user1", "pass")
         u1_token = self.login(u1, "pass")
         u2 = self.register_user("user2", "pass")
         u2_token = self.login(u2, "pass")
 
-        room_public = self.helper.create_room_as(u1, is_public=True, tok=u1_token)
-        room_private = self.helper.create_room_as(u2, is_public=False, tok=u2_token)
-        self.helper.invite(room_public, src=u1, targ=u2, tok=u1_token)
-        self.helper.invite(room_private, src=u2, targ=u1, tok=u2_token)
-        self.helper.join(room_public, user=u2, tok=u2_token)
-        self.helper.join(room_private, user=u1, tok=u1_token)
+        # Create a room. user1 invites user2, who joins
+        room_id_one = self.helper.create_room_as(
+            u1, is_public=room_one_is_public, tok=u1_token
+        )
+        self.helper.invite(room_id_one, src=u1, targ=u2, tok=u1_token)
+        self.helper.join(room_id_one, user=u2, tok=u2_token)
 
+        # Check shared rooms from user1's perspective.
+        # We should see the one room in common
+        channel = self._get_shared_rooms(u1_token, u2)
+        self.assertEquals(200, channel.code, channel.result)
+        self.assertEquals(len(channel.json_body["joined"]), 1)
+        self.assertEquals(channel.json_body["joined"][0], room_id_one)
+
+        # Create another room and invite user2 to it
+        room_id_two = self.helper.create_room_as(
+            u1, is_public=room_two_is_public, tok=u1_token
+        )
+        self.helper.invite(room_id_two, src=u1, targ=u2, tok=u1_token)
+        self.helper.join(room_id_two, user=u2, tok=u2_token)
+
+        # Check shared rooms again. We should now see both rooms.
         channel = self._get_shared_rooms(u1_token, u2)
         self.assertEquals(200, channel.code, channel.result)
         self.assertEquals(len(channel.json_body["joined"]), 2)
-        self.assertTrue(room_public in channel.json_body["joined"])
-        self.assertTrue(room_private in channel.json_body["joined"])
+        for room_id_id in channel.json_body["joined"]:
+            self.assertIn(room_id_id, [room_id_one, room_id_two])
 
     def test_shared_room_list_after_leave(self):
         """
@@ -132,6 +133,12 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
 
         self.helper.leave(room, user=u1, tok=u1_token)
 
+        # Check user1's view of shared rooms with user2
+        channel = self._get_shared_rooms(u1_token, u2)
+        self.assertEquals(200, channel.code, channel.result)
+        self.assertEquals(len(channel.json_body["joined"]), 0)
+
+        # Check user2's view of shared rooms with user1
         channel = self._get_shared_rooms(u2_token, u1)
         self.assertEquals(200, channel.code, channel.result)
         self.assertEquals(len(channel.json_body["joined"]), 0)
diff --git a/tests/server.py b/tests/server.py
index d4ece5c448..939a0008ca 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -124,7 +124,11 @@ class FakeChannel:
         return address.IPv4Address("TCP", self._ip, 3423)
 
     def getHost(self):
-        return None
+        # this is called by Request.__init__ to configure Request.host.
+        return address.IPv4Address("TCP", "127.0.0.1", 8888)
+
+    def isSecure(self):
+        return False
 
     @property
     def transport(self):
diff --git a/tests/utils.py b/tests/utils.py
index 4fb5098550..be80b13760 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -114,7 +114,6 @@ def default_config(name, parse=False):
         "server_name": name,
         "send_federation": False,
         "media_store_path": "media",
-        "uploads_path": "uploads",
         # the test signing key is just an arbitrary ed25519 key to keep the config
         # parser happy
         "signing_key": "ed25519 a_lPym qvioDNmfExFBRPgdTU+wtFYKq4JfwFRv7sYVgWvmgJg",