summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/config/test_util.py53
-rw-r--r--tests/events/test_utils.py185
-rw-r--r--tests/handlers/test_cas.py2
-rw-r--r--tests/handlers/test_oidc.py83
-rw-r--r--tests/handlers/test_profile.py30
-rw-r--r--tests/handlers/test_saml.py2
-rw-r--r--tests/http/test_fedclient.py2
-rw-r--r--tests/rest/admin/test_admin.py3
-rw-r--r--tests/rest/admin/test_event_reports.py4
-rw-r--r--tests/rest/admin/test_media.py2
-rw-r--r--tests/rest/admin/test_room.py2
-rw-r--r--tests/rest/admin/test_statistics.py1
-rw-r--r--tests/rest/admin/test_user.py283
-rw-r--r--tests/rest/client/v1/test_login.py280
-rw-r--r--tests/rest/client/v1/test_rooms.py6
-rw-r--r--tests/rest/client/v1/utils.py3
-rw-r--r--tests/rest/client/v2_alpha/test_auth.py26
-rw-r--r--tests/rest/media/v1/test_url_preview.py7
-rw-r--r--tests/storage/test_account_data.py120
-rw-r--r--tests/storage/test_event_chain.py472
-rw-r--r--tests/storage/test_event_federation.py249
-rw-r--r--tests/storage/test_profile.py26
-rw-r--r--tests/test_preview.py11
-rw-r--r--tests/unittest.py28
-rw-r--r--tests/util/caches/test_deferred_cache.py73
-rw-r--r--tests/util/test_itertools.py41
26 files changed, 1830 insertions, 164 deletions
diff --git a/tests/config/test_util.py b/tests/config/test_util.py
new file mode 100644
index 0000000000..10363e3765
--- /dev/null
+++ b/tests/config/test_util.py
@@ -0,0 +1,53 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.config import ConfigError
+from synapse.config._util import validate_config
+
+from tests.unittest import TestCase
+
+
+class ValidateConfigTestCase(TestCase):
+    """Test cases for synapse.config._util.validate_config"""
+
+    def test_bad_object_in_array(self):
+        """malformed objects within an array should be validated correctly"""
+
+        # consider a structure:
+        #
+        # array_of_objs:
+        #   - r: 1
+        #     foo: 2
+        #
+        #   - r: 2
+        #     bar: 3
+        #
+        # ... where each entry must contain an "r": check that the path
+        # to the required item is correclty reported.
+
+        schema = {
+            "type": "object",
+            "properties": {
+                "array_of_objs": {
+                    "type": "array",
+                    "items": {"type": "object", "required": ["r"]},
+                },
+            },
+        }
+
+        with self.assertRaises(ConfigError) as c:
+            validate_config(schema, {"array_of_objs": [{}]}, ("base",))
+
+        self.assertEqual(c.exception.path, ["base", "array_of_objs", "<item 0>"])
diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py
index c1274c14af..8ba36c6074 100644
--- a/tests/events/test_utils.py
+++ b/tests/events/test_utils.py
@@ -34,11 +34,17 @@ def MockEvent(**kwargs):
 
 
 class PruneEventTestCase(unittest.TestCase):
-    """ Asserts that a new event constructed with `evdict` will look like
-    `matchdict` when it is redacted. """
-
     def run_test(self, evdict, matchdict, **kwargs):
-        self.assertEquals(
+        """
+        Asserts that a new event constructed with `evdict` will look like
+        `matchdict` when it is redacted.
+
+        Args:
+             evdict: The dictionary to build the event from.
+             matchdict: The expected resulting dictionary.
+             kwargs: Additional keyword arguments used to create the event.
+        """
+        self.assertEqual(
             prune_event(make_event_from_dict(evdict, **kwargs)).get_dict(), matchdict
         )
 
@@ -55,54 +61,80 @@ class PruneEventTestCase(unittest.TestCase):
         )
 
     def test_basic_keys(self):
+        """Ensure that the keys that should be untouched are kept."""
+        # Note that some of the values below don't really make sense, but the
+        # pruning of events doesn't worry about the values of any fields (with
+        # the exception of the content field).
         self.run_test(
             {
+                "event_id": "$3:domain",
                 "type": "A",
                 "room_id": "!1:domain",
                 "sender": "@2:domain",
-                "event_id": "$3:domain",
+                "state_key": "B",
+                "content": {"other_key": "foo"},
+                "hashes": "hashes",
+                "signatures": {"domain": {"algo:1": "sigs"}},
+                "depth": 4,
+                "prev_events": "prev_events",
+                "prev_state": "prev_state",
+                "auth_events": "auth_events",
                 "origin": "domain",
+                "origin_server_ts": 1234,
+                "membership": "join",
+                # Also include a key that should be removed.
+                "other_key": "foo",
             },
             {
+                "event_id": "$3:domain",
                 "type": "A",
                 "room_id": "!1:domain",
                 "sender": "@2:domain",
-                "event_id": "$3:domain",
+                "state_key": "B",
+                "hashes": "hashes",
+                "depth": 4,
+                "prev_events": "prev_events",
+                "prev_state": "prev_state",
+                "auth_events": "auth_events",
                 "origin": "domain",
+                "origin_server_ts": 1234,
+                "membership": "join",
                 "content": {},
-                "signatures": {},
+                "signatures": {"domain": {"algo:1": "sigs"}},
                 "unsigned": {},
             },
         )
 
-    def test_unsigned_age_ts(self):
+        # As of MSC2176 we now redact the membership and prev_states keys.
         self.run_test(
-            {"type": "B", "event_id": "$test:domain", "unsigned": {"age_ts": 20}},
-            {
-                "type": "B",
-                "event_id": "$test:domain",
-                "content": {},
-                "signatures": {},
-                "unsigned": {"age_ts": 20},
-            },
+            {"type": "A", "prev_state": "prev_state", "membership": "join"},
+            {"type": "A", "content": {}, "signatures": {}, "unsigned": {}},
+            room_version=RoomVersions.MSC2176,
         )
 
+    def test_unsigned(self):
+        """Ensure that unsigned properties get stripped (except age_ts and replaces_state)."""
         self.run_test(
             {
                 "type": "B",
                 "event_id": "$test:domain",
-                "unsigned": {"other_key": "here"},
+                "unsigned": {
+                    "age_ts": 20,
+                    "replaces_state": "$test2:domain",
+                    "other_key": "foo",
+                },
             },
             {
                 "type": "B",
                 "event_id": "$test:domain",
                 "content": {},
                 "signatures": {},
-                "unsigned": {},
+                "unsigned": {"age_ts": 20, "replaces_state": "$test2:domain"},
             },
         )
 
     def test_content(self):
+        """The content dictionary should be stripped in most cases."""
         self.run_test(
             {"type": "C", "event_id": "$test:domain", "content": {"things": "here"}},
             {
@@ -114,11 +146,35 @@ class PruneEventTestCase(unittest.TestCase):
             },
         )
 
+        # Some events keep a single content key/value.
+        EVENT_KEEP_CONTENT_KEYS = [
+            ("member", "membership", "join"),
+            ("join_rules", "join_rule", "invite"),
+            ("history_visibility", "history_visibility", "shared"),
+        ]
+        for event_type, key, value in EVENT_KEEP_CONTENT_KEYS:
+            self.run_test(
+                {
+                    "type": "m.room." + event_type,
+                    "event_id": "$test:domain",
+                    "content": {key: value, "other_key": "foo"},
+                },
+                {
+                    "type": "m.room." + event_type,
+                    "event_id": "$test:domain",
+                    "content": {key: value},
+                    "signatures": {},
+                    "unsigned": {},
+                },
+            )
+
+    def test_create(self):
+        """Create events are partially redacted until MSC2176."""
         self.run_test(
             {
                 "type": "m.room.create",
                 "event_id": "$test:domain",
-                "content": {"creator": "@2:domain", "other_field": "here"},
+                "content": {"creator": "@2:domain", "other_key": "foo"},
             },
             {
                 "type": "m.room.create",
@@ -129,6 +185,68 @@ class PruneEventTestCase(unittest.TestCase):
             },
         )
 
+        # After MSC2176, create events get nothing redacted.
+        self.run_test(
+            {"type": "m.room.create", "content": {"not_a_real_key": True}},
+            {
+                "type": "m.room.create",
+                "content": {"not_a_real_key": True},
+                "signatures": {},
+                "unsigned": {},
+            },
+            room_version=RoomVersions.MSC2176,
+        )
+
+    def test_power_levels(self):
+        """Power level events keep a variety of content keys."""
+        self.run_test(
+            {
+                "type": "m.room.power_levels",
+                "event_id": "$test:domain",
+                "content": {
+                    "ban": 1,
+                    "events": {"m.room.name": 100},
+                    "events_default": 2,
+                    "invite": 3,
+                    "kick": 4,
+                    "redact": 5,
+                    "state_default": 6,
+                    "users": {"@admin:domain": 100},
+                    "users_default": 7,
+                    "other_key": 8,
+                },
+            },
+            {
+                "type": "m.room.power_levels",
+                "event_id": "$test:domain",
+                "content": {
+                    "ban": 1,
+                    "events": {"m.room.name": 100},
+                    "events_default": 2,
+                    # Note that invite is not here.
+                    "kick": 4,
+                    "redact": 5,
+                    "state_default": 6,
+                    "users": {"@admin:domain": 100},
+                    "users_default": 7,
+                },
+                "signatures": {},
+                "unsigned": {},
+            },
+        )
+
+        # After MSC2176, power levels events keep the invite key.
+        self.run_test(
+            {"type": "m.room.power_levels", "content": {"invite": 75}},
+            {
+                "type": "m.room.power_levels",
+                "content": {"invite": 75},
+                "signatures": {},
+                "unsigned": {},
+            },
+            room_version=RoomVersions.MSC2176,
+        )
+
     def test_alias_event(self):
         """Alias events have special behavior up through room version 6."""
         self.run_test(
@@ -146,8 +264,7 @@ class PruneEventTestCase(unittest.TestCase):
             },
         )
 
-    def test_msc2432_alias_event(self):
-        """After MSC2432, alias events have no special behavior."""
+        # After MSC2432, alias events have no special behavior.
         self.run_test(
             {"type": "m.room.aliases", "content": {"aliases": ["test"]}},
             {
@@ -159,6 +276,32 @@ class PruneEventTestCase(unittest.TestCase):
             room_version=RoomVersions.V6,
         )
 
+    def test_redacts(self):
+        """Redaction events have no special behaviour until MSC2174/MSC2176."""
+
+        self.run_test(
+            {"type": "m.room.redaction", "content": {"redacts": "$test2:domain"}},
+            {
+                "type": "m.room.redaction",
+                "content": {},
+                "signatures": {},
+                "unsigned": {},
+            },
+            room_version=RoomVersions.V6,
+        )
+
+        # After MSC2174, redaction events keep the redacts content key.
+        self.run_test(
+            {"type": "m.room.redaction", "content": {"redacts": "$test2:domain"}},
+            {
+                "type": "m.room.redaction",
+                "content": {"redacts": "$test2:domain"},
+                "signatures": {},
+                "unsigned": {},
+            },
+            room_version=RoomVersions.MSC2176,
+        )
+
 
 class SerializeEventTestCase(unittest.TestCase):
     def serialize(self, ev, fields):
diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py
index bd7a1b6891..c37bb6440e 100644
--- a/tests/handlers/test_cas.py
+++ b/tests/handlers/test_cas.py
@@ -118,4 +118,4 @@ class CasHandlerTestCase(HomeserverTestCase):
 
 def _mock_request():
     """Returns a mock which will stand in as a SynapseRequest"""
-    return Mock(spec=["getClientIP", "get_user_agent"])
+    return Mock(spec=["getClientIP", "getHeader"])
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index 368d600b33..2abd7a83b5 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 import json
 import re
-from typing import Dict
+from typing import Dict, Optional
 from urllib.parse import parse_qs, urlencode, urlparse
 
 from mock import ANY, Mock, patch
@@ -24,7 +24,6 @@ import pymacaroons
 from twisted.web.resource import Resource
 
 from synapse.api.errors import RedirectException
-from synapse.handlers.oidc_handler import OidcError
 from synapse.handlers.sso import MappingException
 from synapse.rest.client.v1 import login
 from synapse.rest.synapse.client.pick_username import pick_username_resource
@@ -34,6 +33,14 @@ from synapse.types import UserID
 from tests.test_utils import FakeResponse, simple_async_mock
 from tests.unittest import HomeserverTestCase, override_config
 
+try:
+    import authlib  # noqa: F401
+
+    HAS_OIDC = True
+except ImportError:
+    HAS_OIDC = False
+
+
 # These are a few constants that are used as config parameters in the tests.
 ISSUER = "https://issuer/"
 CLIENT_ID = "test-client-id"
@@ -113,6 +120,9 @@ async def get_json(url):
 
 
 class OidcHandlerTestCase(HomeserverTestCase):
+    if not HAS_OIDC:
+        skip = "requires OIDC"
+
     def default_config(self):
         config = super().default_config()
         config["public_baseurl"] = BASE_URL
@@ -339,9 +349,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
         cookie = args[1]
 
         macaroon = pymacaroons.Macaroon.deserialize(cookie)
-        state = self.handler._get_value_from_macaroon(macaroon, "state")
-        nonce = self.handler._get_value_from_macaroon(macaroon, "nonce")
-        redirect = self.handler._get_value_from_macaroon(
+        state = self.handler._token_generator._get_value_from_macaroon(
+            macaroon, "state"
+        )
+        nonce = self.handler._token_generator._get_value_from_macaroon(
+            macaroon, "nonce"
+        )
+        redirect = self.handler._token_generator._get_value_from_macaroon(
             macaroon, "client_redirect_url"
         )
 
@@ -401,12 +415,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         client_redirect_url = "http://client/redirect"
         user_agent = "Browser"
         ip_address = "10.0.0.1"
-        session = self.handler._generate_oidc_session_token(
-            state=state,
-            nonce=nonce,
-            client_redirect_url=client_redirect_url,
-            ui_auth_session_id=None,
-        )
+        session = self._generate_oidc_session_token(state, nonce, client_redirect_url)
         request = _build_callback_request(
             code, state, session, user_agent=user_agent, ip_address=ip_address
         )
@@ -458,6 +467,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.assertRenderedError("fetch_error")
 
         # Handle code exchange failure
+        from synapse.handlers.oidc_handler import OidcError
+
         self.handler._exchange_code = simple_async_mock(
             raises=OidcError("invalid_request")
         )
@@ -488,11 +499,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.assertRenderedError("invalid_session")
 
         # Mismatching session
-        session = self.handler._generate_oidc_session_token(
-            state="state",
-            nonce="nonce",
-            client_redirect_url="http://client/redirect",
-            ui_auth_session_id=None,
+        session = self._generate_oidc_session_token(
+            state="state", nonce="nonce", client_redirect_url="http://client/redirect",
         )
         request.args = {}
         request.args[b"state"] = [b"mismatching state"]
@@ -538,6 +546,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
                 body=b'{"error": "foo", "error_description": "bar"}',
             )
         )
+        from synapse.handlers.oidc_handler import OidcError
+
         exc = self.get_failure(self.handler._exchange_code(code), OidcError)
         self.assertEqual(exc.value.error, "foo")
         self.assertEqual(exc.value.error_description, "bar")
@@ -609,11 +619,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
 
         state = "state"
         client_redirect_url = "http://client/redirect"
-        session = self.handler._generate_oidc_session_token(
-            state=state,
-            nonce="nonce",
-            client_redirect_url=client_redirect_url,
-            ui_auth_session_id=None,
+        session = self._generate_oidc_session_token(
+            state=state, nonce="nonce", client_redirect_url=client_redirect_url,
         )
         request = _build_callback_request("code", state, session)
 
@@ -827,8 +834,29 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
         self.assertRenderedError("mapping_error", "localpart is invalid: ")
 
+    def _generate_oidc_session_token(
+        self,
+        state: str,
+        nonce: str,
+        client_redirect_url: str,
+        ui_auth_session_id: Optional[str] = None,
+    ) -> str:
+        from synapse.handlers.oidc_handler import OidcSessionData
+
+        return self.handler._token_generator.generate_oidc_session_token(
+            state=state,
+            session_data=OidcSessionData(
+                nonce=nonce,
+                client_redirect_url=client_redirect_url,
+                ui_auth_session_id=ui_auth_session_id,
+            ),
+        )
+
 
 class UsernamePickerTestCase(HomeserverTestCase):
+    if not HAS_OIDC:
+        skip = "requires OIDC"
+
     servlets = [login.register_servlets]
 
     def default_config(self):
@@ -948,17 +976,19 @@ async def _make_callback_with_userinfo(
         userinfo: the OIDC userinfo dict
         client_redirect_url: the URL to redirect to on success.
     """
+    from synapse.handlers.oidc_handler import OidcSessionData
+
     handler = hs.get_oidc_handler()
     handler._exchange_code = simple_async_mock(return_value={})
     handler._parse_id_token = simple_async_mock(return_value=userinfo)
     handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
 
     state = "state"
-    session = handler._generate_oidc_session_token(
+    session = handler._token_generator.generate_oidc_session_token(
         state=state,
-        nonce="nonce",
-        client_redirect_url=client_redirect_url,
-        ui_auth_session_id=None,
+        session_data=OidcSessionData(
+            nonce="nonce", client_redirect_url=client_redirect_url,
+        ),
     )
     request = _build_callback_request("code", state, session)
 
@@ -994,7 +1024,7 @@ def _build_callback_request(
             "addCookie",
             "requestHeaders",
             "getClientIP",
-            "get_user_agent",
+            "getHeader",
         ]
     )
 
@@ -1003,5 +1033,4 @@ def _build_callback_request(
     request.args[b"code"] = [code.encode("utf-8")]
     request.args[b"state"] = [state.encode("utf-8")]
     request.getClientIP.return_value = ip_address
-    request.get_user_agent.return_value = user_agent
     return request
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index 919547556b..022943a10a 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -105,6 +105,21 @@ class ProfileTestCase(unittest.TestCase):
             "Frank",
         )
 
+        # Set displayname to an empty string
+        yield defer.ensureDeferred(
+            self.handler.set_displayname(
+                self.frank, synapse.types.create_requester(self.frank), ""
+            )
+        )
+
+        self.assertIsNone(
+            (
+                yield defer.ensureDeferred(
+                    self.store.get_profile_displayname(self.frank.localpart)
+                )
+            )
+        )
+
     @defer.inlineCallbacks
     def test_set_my_name_if_disabled(self):
         self.hs.config.enable_set_displayname = False
@@ -223,6 +238,21 @@ class ProfileTestCase(unittest.TestCase):
             "http://my.server/me.png",
         )
 
+        # Set avatar to an empty string
+        yield defer.ensureDeferred(
+            self.handler.set_avatar_url(
+                self.frank, synapse.types.create_requester(self.frank), "",
+            )
+        )
+
+        self.assertIsNone(
+            (
+                yield defer.ensureDeferred(
+                    self.store.get_profile_avatar_url(self.frank.localpart)
+                )
+            ),
+        )
+
     @defer.inlineCallbacks
     def test_set_my_avatar_if_disabled(self):
         self.hs.config.enable_set_avatar_url = False
diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py
index 548038214b..261c7083d1 100644
--- a/tests/handlers/test_saml.py
+++ b/tests/handlers/test_saml.py
@@ -262,4 +262,4 @@ class SamlHandlerTestCase(HomeserverTestCase):
 
 def _mock_request():
     """Returns a mock which will stand in as a SynapseRequest"""
-    return Mock(spec=["getClientIP", "get_user_agent"])
+    return Mock(spec=["getClientIP", "getHeader"])
diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py
index 212484a7fe..9c52c8fdca 100644
--- a/tests/http/test_fedclient.py
+++ b/tests/http/test_fedclient.py
@@ -560,4 +560,4 @@ class FederationClientTests(HomeserverTestCase):
         self.pump()
 
         f = self.failureResultOf(test_d)
-        self.assertIsInstance(f.value, ValueError)
+        self.assertIsInstance(f.value, RequestSendFailed)
diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index 0504cd187e..586b877bda 100644
--- a/tests/rest/admin/test_admin.py
+++ b/tests/rest/admin/test_admin.py
@@ -58,8 +58,6 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
     ]
 
     def prepare(self, reactor, clock, hs):
-        self.store = hs.get_datastore()
-
         self.admin_user = self.register_user("admin", "pass", admin=True)
         self.admin_user_tok = self.login("admin", "pass")
 
@@ -156,7 +154,6 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
 
     def prepare(self, reactor, clock, hs):
         self.store = hs.get_datastore()
-        self.hs = hs
 
         # Allow for uploading and downloading to/from the media repo
         self.media_repo = hs.get_media_repository_resource()
diff --git a/tests/rest/admin/test_event_reports.py b/tests/rest/admin/test_event_reports.py
index aa389df12f..d0090faa4f 100644
--- a/tests/rest/admin/test_event_reports.py
+++ b/tests/rest/admin/test_event_reports.py
@@ -32,8 +32,6 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
     ]
 
     def prepare(self, reactor, clock, hs):
-        self.store = hs.get_datastore()
-
         self.admin_user = self.register_user("admin", "pass", admin=True)
         self.admin_user_tok = self.login("admin", "pass")
 
@@ -371,8 +369,6 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
     ]
 
     def prepare(self, reactor, clock, hs):
-        self.store = hs.get_datastore()
-
         self.admin_user = self.register_user("admin", "pass", admin=True)
         self.admin_user_tok = self.login("admin", "pass")
 
diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py
index c2b998cdae..51a7731693 100644
--- a/tests/rest/admin/test_media.py
+++ b/tests/rest/admin/test_media.py
@@ -35,7 +35,6 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
     ]
 
     def prepare(self, reactor, clock, hs):
-        self.handler = hs.get_device_handler()
         self.media_repo = hs.get_media_repository_resource()
         self.server_name = hs.hostname
 
@@ -181,7 +180,6 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
     ]
 
     def prepare(self, reactor, clock, hs):
-        self.handler = hs.get_device_handler()
         self.media_repo = hs.get_media_repository_resource()
         self.server_name = hs.hostname
 
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index fa620f97f3..a0f32c5512 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -605,8 +605,6 @@ class RoomTestCase(unittest.HomeserverTestCase):
     ]
 
     def prepare(self, reactor, clock, hs):
-        self.store = hs.get_datastore()
-
         # Create user
         self.admin_user = self.register_user("admin", "pass", admin=True)
         self.admin_user_tok = self.login("admin", "pass")
diff --git a/tests/rest/admin/test_statistics.py b/tests/rest/admin/test_statistics.py
index 73f8a8ec99..f48be3d65a 100644
--- a/tests/rest/admin/test_statistics.py
+++ b/tests/rest/admin/test_statistics.py
@@ -31,7 +31,6 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
     ]
 
     def prepare(self, reactor, clock, hs):
-        self.store = hs.get_datastore()
         self.media_repo = hs.get_media_repository_resource()
 
         self.admin_user = self.register_user("admin", "pass", admin=True)
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 9b2e4765f6..04599c2fcf 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -25,6 +25,7 @@ from mock import Mock
 import synapse.rest.admin
 from synapse.api.constants import UserTypes
 from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError
+from synapse.api.room_versions import RoomVersions
 from synapse.rest.client.v1 import login, logout, profile, room
 from synapse.rest.client.v2_alpha import devices, sync
 
@@ -587,6 +588,200 @@ class UsersListTestCase(unittest.HomeserverTestCase):
         _search_test(None, "bar", "user_id")
 
 
+class DeactivateAccountTestCase(unittest.HomeserverTestCase):
+
+    servlets = [
+        synapse.rest.admin.register_servlets,
+        login.register_servlets,
+    ]
+
+    def prepare(self, reactor, clock, hs):
+        self.store = hs.get_datastore()
+
+        self.admin_user = self.register_user("admin", "pass", admin=True)
+        self.admin_user_tok = self.login("admin", "pass")
+
+        self.other_user = self.register_user("user", "pass", displayname="User1")
+        self.other_user_token = self.login("user", "pass")
+        self.url_other_user = "/_synapse/admin/v2/users/%s" % urllib.parse.quote(
+            self.other_user
+        )
+        self.url = "/_synapse/admin/v1/deactivate/%s" % urllib.parse.quote(
+            self.other_user
+        )
+
+        # set attributes for user
+        self.get_success(
+            self.store.set_profile_avatar_url("user", "mxc://servername/mediaid")
+        )
+        self.get_success(
+            self.store.user_add_threepid("@user:test", "email", "foo@bar.com", 0, 0)
+        )
+
+    def test_no_auth(self):
+        """
+        Try to deactivate users without authentication.
+        """
+        channel = self.make_request("POST", self.url, b"{}")
+
+        self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+    def test_requester_is_not_admin(self):
+        """
+        If the user is not a server admin, an error is returned.
+        """
+        url = "/_synapse/admin/v1/deactivate/@bob:test"
+
+        channel = self.make_request("POST", url, access_token=self.other_user_token)
+
+        self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual("You are not a server admin", channel.json_body["error"])
+
+        channel = self.make_request(
+            "POST", url, access_token=self.other_user_token, content=b"{}",
+        )
+
+        self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual("You are not a server admin", channel.json_body["error"])
+
+    def test_user_does_not_exist(self):
+        """
+        Tests that deactivation for a user that does not exist returns a 404
+        """
+
+        channel = self.make_request(
+            "POST",
+            "/_synapse/admin/v1/deactivate/@unknown_person:test",
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(404, channel.code, msg=channel.json_body)
+        self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+    def test_erase_is_not_bool(self):
+        """
+        If parameter `erase` is not boolean, return an error
+        """
+        body = json.dumps({"erase": "False"})
+
+        channel = self.make_request(
+            "POST",
+            self.url,
+            content=body.encode(encoding="utf_8"),
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
+
+    def test_user_is_not_local(self):
+        """
+        Tests that deactivation for a user that is not a local returns a 400
+        """
+        url = "/_synapse/admin/v1/deactivate/@unknown_person:unknown_domain"
+
+        channel = self.make_request("POST", url, access_token=self.admin_user_tok)
+
+        self.assertEqual(400, channel.code, msg=channel.json_body)
+        self.assertEqual("Can only deactivate local users", channel.json_body["error"])
+
+    def test_deactivate_user_erase_true(self):
+        """
+        Test deactivating an user and set `erase` to `true`
+        """
+
+        # Get user
+        channel = self.make_request(
+            "GET", self.url_other_user, access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual("@user:test", channel.json_body["name"])
+        self.assertEqual(False, channel.json_body["deactivated"])
+        self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"])
+        self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
+        self.assertEqual("User1", channel.json_body["displayname"])
+
+        # Deactivate user
+        body = json.dumps({"erase": True})
+
+        channel = self.make_request(
+            "POST",
+            self.url,
+            access_token=self.admin_user_tok,
+            content=body.encode(encoding="utf_8"),
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+        # Get user
+        channel = self.make_request(
+            "GET", self.url_other_user, access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual("@user:test", channel.json_body["name"])
+        self.assertEqual(True, channel.json_body["deactivated"])
+        self.assertEqual(0, len(channel.json_body["threepids"]))
+        self.assertIsNone(channel.json_body["avatar_url"])
+        self.assertIsNone(channel.json_body["displayname"])
+
+        self._is_erased("@user:test", True)
+
+    def test_deactivate_user_erase_false(self):
+        """
+        Test deactivating an user and set `erase` to `false`
+        """
+
+        # Get user
+        channel = self.make_request(
+            "GET", self.url_other_user, access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual("@user:test", channel.json_body["name"])
+        self.assertEqual(False, channel.json_body["deactivated"])
+        self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"])
+        self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
+        self.assertEqual("User1", channel.json_body["displayname"])
+
+        # Deactivate user
+        body = json.dumps({"erase": False})
+
+        channel = self.make_request(
+            "POST",
+            self.url,
+            access_token=self.admin_user_tok,
+            content=body.encode(encoding="utf_8"),
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+        # Get user
+        channel = self.make_request(
+            "GET", self.url_other_user, access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual("@user:test", channel.json_body["name"])
+        self.assertEqual(True, channel.json_body["deactivated"])
+        self.assertEqual(0, len(channel.json_body["threepids"]))
+        self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
+        self.assertEqual("User1", channel.json_body["displayname"])
+
+        self._is_erased("@user:test", False)
+
+    def _is_erased(self, user_id: str, expect: bool) -> None:
+        """Assert that the user is erased or not
+        """
+        d = self.store.is_user_erased(user_id)
+        if expect:
+            self.assertTrue(self.get_success(d))
+        else:
+            self.assertFalse(self.get_success(d))
+
+
 class UserRestTestCase(unittest.HomeserverTestCase):
 
     servlets = [
@@ -986,6 +1181,26 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         Test deactivating another user.
         """
 
+        # set attributes for user
+        self.get_success(
+            self.store.set_profile_avatar_url("user", "mxc://servername/mediaid")
+        )
+        self.get_success(
+            self.store.user_add_threepid("@user:test", "email", "foo@bar.com", 0, 0)
+        )
+
+        # Get user
+        channel = self.make_request(
+            "GET", self.url_other_user, access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual("@user:test", channel.json_body["name"])
+        self.assertEqual(False, channel.json_body["deactivated"])
+        self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"])
+        self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
+        self.assertEqual("User", channel.json_body["displayname"])
+
         # Deactivate user
         body = json.dumps({"deactivated": True})
 
@@ -999,6 +1214,9 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("@user:test", channel.json_body["name"])
         self.assertEqual(True, channel.json_body["deactivated"])
+        self.assertEqual(0, len(channel.json_body["threepids"]))
+        self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
+        self.assertEqual("User", channel.json_body["displayname"])
         # the user is deactivated, the threepid will be deleted
 
         # Get user
@@ -1009,6 +1227,9 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("@user:test", channel.json_body["name"])
         self.assertEqual(True, channel.json_body["deactivated"])
+        self.assertEqual(0, len(channel.json_body["threepids"]))
+        self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
+        self.assertEqual("User", channel.json_body["displayname"])
 
     @override_config({"user_directory": {"enabled": True, "search_all_users": True}})
     def test_change_name_deactivate_user_user_directory(self):
@@ -1204,8 +1425,6 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
     ]
 
     def prepare(self, reactor, clock, hs):
-        self.store = hs.get_datastore()
-
         self.admin_user = self.register_user("admin", "pass", admin=True)
         self.admin_user_tok = self.login("admin", "pass")
 
@@ -1236,24 +1455,26 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
 
     def test_user_does_not_exist(self):
         """
-        Tests that a lookup for a user that does not exist returns a 404
+        Tests that a lookup for a user that does not exist returns an empty list
         """
         url = "/_synapse/admin/v1/users/@unknown_person:test/joined_rooms"
         channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
 
-        self.assertEqual(404, channel.code, msg=channel.json_body)
-        self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+        self.assertEqual(0, channel.json_body["total"])
+        self.assertEqual(0, len(channel.json_body["joined_rooms"]))
 
     def test_user_is_not_local(self):
         """
-        Tests that a lookup for a user that is not a local returns a 400
+        Tests that a lookup for a user that is not a local and participates in no conversation returns an empty list
         """
         url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/joined_rooms"
 
         channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
 
-        self.assertEqual(400, channel.code, msg=channel.json_body)
-        self.assertEqual("Can only lookup local users", channel.json_body["error"])
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+        self.assertEqual(0, channel.json_body["total"])
+        self.assertEqual(0, len(channel.json_body["joined_rooms"]))
 
     def test_no_memberships(self):
         """
@@ -1284,6 +1505,49 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual(number_rooms, channel.json_body["total"])
         self.assertEqual(number_rooms, len(channel.json_body["joined_rooms"]))
 
+    def test_get_rooms_with_nonlocal_user(self):
+        """
+        Tests that a normal lookup for rooms is successful with a non-local user
+        """
+
+        other_user_tok = self.login("user", "pass")
+        event_builder_factory = self.hs.get_event_builder_factory()
+        event_creation_handler = self.hs.get_event_creation_handler()
+        storage = self.hs.get_storage()
+
+        # Create two rooms, one with a local user only and one with both a local
+        # and remote user.
+        self.helper.create_room_as(self.other_user, tok=other_user_tok)
+        local_and_remote_room_id = self.helper.create_room_as(
+            self.other_user, tok=other_user_tok
+        )
+
+        # Add a remote user to the room.
+        builder = event_builder_factory.for_room_version(
+            RoomVersions.V1,
+            {
+                "type": "m.room.member",
+                "sender": "@joiner:remote_hs",
+                "state_key": "@joiner:remote_hs",
+                "room_id": local_and_remote_room_id,
+                "content": {"membership": "join"},
+            },
+        )
+
+        event, context = self.get_success(
+            event_creation_handler.create_new_client_event(builder)
+        )
+
+        self.get_success(storage.persistence.persist_event(event, context))
+
+        # Now get rooms
+        url = "/_synapse/admin/v1/users/@joiner:remote_hs/joined_rooms"
+        channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
+
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+        self.assertEqual(1, channel.json_body["total"])
+        self.assertEqual([local_and_remote_room_id], channel.json_body["joined_rooms"])
+
 
 class PushersRestTestCase(unittest.HomeserverTestCase):
 
@@ -1401,7 +1665,6 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
     ]
 
     def prepare(self, reactor, clock, hs):
-        self.store = hs.get_datastore()
         self.media_repo = hs.get_media_repository_resource()
 
         self.admin_user = self.register_user("admin", "pass", admin=True)
@@ -1868,8 +2131,6 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
     ]
 
     def prepare(self, reactor, clock, hs):
-        self.store = hs.get_datastore()
-
         self.admin_user = self.register_user("admin", "pass", admin=True)
         self.admin_user_tok = self.login("admin", "pass")
 
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index 18932d7518..f9b8011961 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -1,19 +1,68 @@
-import json
+# -*- coding: utf-8 -*-
+# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import time
 import urllib.parse
+from html.parser import HTMLParser
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
 
 from mock import Mock
 
-import jwt
+import pymacaroons
+
+from twisted.web.resource import Resource
 
 import synapse.rest.admin
 from synapse.appservice import ApplicationService
 from synapse.rest.client.v1 import login, logout
 from synapse.rest.client.v2_alpha import devices, register
 from synapse.rest.client.v2_alpha.account import WhoamiRestServlet
+from synapse.rest.synapse.client.pick_idp import PickIdpResource
+from synapse.types import create_requester
 
 from tests import unittest
-from tests.unittest import override_config
+from tests.handlers.test_oidc import HAS_OIDC
+from tests.handlers.test_saml import has_saml2
+from tests.rest.client.v1.utils import TEST_OIDC_AUTH_ENDPOINT, TEST_OIDC_CONFIG
+from tests.unittest import override_config, skip_unless
+
+try:
+    import jwt
+
+    HAS_JWT = True
+except ImportError:
+    HAS_JWT = False
+
+
+# public_base_url used in some tests
+BASE_URL = "https://synapse/"
+
+# CAS server used in some tests
+CAS_SERVER = "https://fake.test"
+
+# just enough to tell pysaml2 where to redirect to
+SAML_SERVER = "https://test.saml.server/idp/sso"
+TEST_SAML_METADATA = """
+<md:EntityDescriptor xmlns:md="urn:oasis:names:tc:SAML:2.0:metadata">
+  <md:IDPSSODescriptor protocolSupportEnumeration="urn:oasis:names:tc:SAML:2.0:protocol">
+      <md:SingleSignOnService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect" Location="%(SAML_SERVER)s"/>
+  </md:IDPSSODescriptor>
+</md:EntityDescriptor>
+""" % {
+    "SAML_SERVER": SAML_SERVER,
+}
 
 LOGIN_URL = b"/_matrix/client/r0/login"
 TEST_URL = b"/_matrix/client/r0/account/whoami"
@@ -311,6 +360,184 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
         self.assertEquals(channel.result["code"], b"200", channel.result)
 
 
+@skip_unless(has_saml2 and HAS_OIDC, "Requires SAML2 and OIDC")
+class MultiSSOTestCase(unittest.HomeserverTestCase):
+    """Tests for homeservers with multiple SSO providers enabled"""
+
+    servlets = [
+        login.register_servlets,
+    ]
+
+    def default_config(self) -> Dict[str, Any]:
+        config = super().default_config()
+
+        config["public_baseurl"] = BASE_URL
+
+        config["cas_config"] = {
+            "enabled": True,
+            "server_url": CAS_SERVER,
+            "service_url": "https://matrix.goodserver.com:8448",
+        }
+
+        config["saml2_config"] = {
+            "sp_config": {
+                "metadata": {"inline": [TEST_SAML_METADATA]},
+                # use the XMLSecurity backend to avoid relying on xmlsec1
+                "crypto_backend": "XMLSecurity",
+            },
+        }
+
+        config["oidc_config"] = TEST_OIDC_CONFIG
+
+        return config
+
+    def create_resource_dict(self) -> Dict[str, Resource]:
+        d = super().create_resource_dict()
+        d["/_synapse/client/pick_idp"] = PickIdpResource(self.hs)
+        return d
+
+    def test_multi_sso_redirect(self):
+        """/login/sso/redirect should redirect to an identity picker"""
+        client_redirect_url = "https://x?<abc>"
+
+        # first hit the redirect url, which should redirect to our idp picker
+        channel = self.make_request(
+            "GET",
+            "/_matrix/client/r0/login/sso/redirect?redirectUrl=" + client_redirect_url,
+        )
+        self.assertEqual(channel.code, 302, channel.result)
+        uri = channel.headers.getRawHeaders("Location")[0]
+
+        # hitting that picker should give us some HTML
+        channel = self.make_request("GET", uri)
+        self.assertEqual(channel.code, 200, channel.result)
+
+        # parse the form to check it has fields assumed elsewhere in this class
+        class FormPageParser(HTMLParser):
+            def __init__(self):
+                super().__init__()
+
+                # the values of the hidden inputs: map from name to value
+                self.hiddens = {}  # type: Dict[str, Optional[str]]
+
+                # the values of the radio buttons
+                self.radios = []  # type: List[Optional[str]]
+
+            def handle_starttag(
+                self, tag: str, attrs: Iterable[Tuple[str, Optional[str]]]
+            ) -> None:
+                attr_dict = dict(attrs)
+                if tag == "input":
+                    if attr_dict["type"] == "radio" and attr_dict["name"] == "idp":
+                        self.radios.append(attr_dict["value"])
+                    elif attr_dict["type"] == "hidden":
+                        input_name = attr_dict["name"]
+                        assert input_name
+                        self.hiddens[input_name] = attr_dict["value"]
+
+            def error(_, message):
+                self.fail(message)
+
+        p = FormPageParser()
+        p.feed(channel.result["body"].decode("utf-8"))
+        p.close()
+
+        self.assertCountEqual(p.radios, ["cas", "oidc", "saml"])
+
+        self.assertEqual(p.hiddens["redirectUrl"], client_redirect_url)
+
+    def test_multi_sso_redirect_to_cas(self):
+        """If CAS is chosen, should redirect to the CAS server"""
+        client_redirect_url = "https://x?<abc>"
+
+        channel = self.make_request(
+            "GET",
+            "/_synapse/client/pick_idp?redirectUrl=" + client_redirect_url + "&idp=cas",
+            shorthand=False,
+        )
+        self.assertEqual(channel.code, 302, channel.result)
+        cas_uri = channel.headers.getRawHeaders("Location")[0]
+        cas_uri_path, cas_uri_query = cas_uri.split("?", 1)
+
+        # it should redirect us to the login page of the cas server
+        self.assertEqual(cas_uri_path, CAS_SERVER + "/login")
+
+        # check that the redirectUrl is correctly encoded in the service param - ie, the
+        # place that CAS will redirect to
+        cas_uri_params = urllib.parse.parse_qs(cas_uri_query)
+        service_uri = cas_uri_params["service"][0]
+        _, service_uri_query = service_uri.split("?", 1)
+        service_uri_params = urllib.parse.parse_qs(service_uri_query)
+        self.assertEqual(service_uri_params["redirectUrl"][0], client_redirect_url)
+
+    def test_multi_sso_redirect_to_saml(self):
+        """If SAML is chosen, should redirect to the SAML server"""
+        client_redirect_url = "https://x?<abc>"
+
+        channel = self.make_request(
+            "GET",
+            "/_synapse/client/pick_idp?redirectUrl="
+            + client_redirect_url
+            + "&idp=saml",
+        )
+        self.assertEqual(channel.code, 302, channel.result)
+        saml_uri = channel.headers.getRawHeaders("Location")[0]
+        saml_uri_path, saml_uri_query = saml_uri.split("?", 1)
+
+        # it should redirect us to the login page of the SAML server
+        self.assertEqual(saml_uri_path, SAML_SERVER)
+
+        # the RelayState is used to carry the client redirect url
+        saml_uri_params = urllib.parse.parse_qs(saml_uri_query)
+        relay_state_param = saml_uri_params["RelayState"][0]
+        self.assertEqual(relay_state_param, client_redirect_url)
+
+    def test_multi_sso_redirect_to_oidc(self):
+        """If OIDC is chosen, should redirect to the OIDC auth endpoint"""
+        client_redirect_url = "https://x?<abc>"
+
+        channel = self.make_request(
+            "GET",
+            "/_synapse/client/pick_idp?redirectUrl="
+            + client_redirect_url
+            + "&idp=oidc",
+        )
+        self.assertEqual(channel.code, 302, channel.result)
+        oidc_uri = channel.headers.getRawHeaders("Location")[0]
+        oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1)
+
+        # it should redirect us to the auth page of the OIDC server
+        self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT)
+
+        # ... and should have set a cookie including the redirect url
+        cookies = dict(
+            h.split(";")[0].split("=", maxsplit=1)
+            for h in channel.headers.getRawHeaders("Set-Cookie")
+        )
+
+        oidc_session_cookie = cookies["oidc_session"]
+        macaroon = pymacaroons.Macaroon.deserialize(oidc_session_cookie)
+        self.assertEqual(
+            self._get_value_from_macaroon(macaroon, "client_redirect_url"),
+            client_redirect_url,
+        )
+
+    def test_multi_sso_redirect_to_unknown(self):
+        """An unknown IdP should cause a 400"""
+        channel = self.make_request(
+            "GET", "/_synapse/client/pick_idp?redirectUrl=http://x&idp=xyz",
+        )
+        self.assertEqual(channel.code, 400, channel.result)
+
+    @staticmethod
+    def _get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str:
+        prefix = key + " = "
+        for caveat in macaroon.caveats:
+            if caveat.caveat_id.startswith(prefix):
+                return caveat.caveat_id[len(prefix) :]
+        raise ValueError("No %s caveat in macaroon" % (key,))
+
+
 class CASTestCase(unittest.HomeserverTestCase):
 
     servlets = [
@@ -324,7 +551,7 @@ class CASTestCase(unittest.HomeserverTestCase):
         config = self.default_config()
         config["cas_config"] = {
             "enabled": True,
-            "server_url": "https://fake.test",
+            "server_url": CAS_SERVER,
             "service_url": "https://matrix.goodserver.com:8448",
         }
 
@@ -385,7 +612,7 @@ class CASTestCase(unittest.HomeserverTestCase):
         channel = self.make_request("GET", cas_ticket_url)
 
         # Test that the response is HTML.
-        self.assertEqual(channel.code, 200)
+        self.assertEqual(channel.code, 200, channel.result)
         content_type_header_value = ""
         for header in channel.result.get("headers", []):
             if header[0] == b"Content-Type":
@@ -410,8 +637,7 @@ class CASTestCase(unittest.HomeserverTestCase):
         }
     )
     def test_cas_redirect_whitelisted(self):
-        """Tests that the SSO login flow serves a redirect to a whitelisted url
-        """
+        """Tests that the SSO login flow serves a redirect to a whitelisted url"""
         self._test_redirect("https://legit-site.com/")
 
     @override_config({"public_baseurl": "https://example.com"})
@@ -442,7 +668,9 @@ class CASTestCase(unittest.HomeserverTestCase):
 
         # Deactivate the account.
         self.get_success(
-            self.deactivate_account_handler.deactivate_account(self.user_id, False)
+            self.deactivate_account_handler.deactivate_account(
+                self.user_id, False, create_requester(self.user_id)
+            )
         )
 
         # Request the CAS ticket.
@@ -459,6 +687,7 @@ class CASTestCase(unittest.HomeserverTestCase):
         self.assertIn(b"SSO account deactivated", channel.result["body"])
 
 
+@skip_unless(HAS_JWT, "requires jwt")
 class JWTTestCase(unittest.HomeserverTestCase):
     servlets = [
         synapse.rest.admin.register_servlets_for_client_rest_resource,
@@ -475,17 +704,17 @@ class JWTTestCase(unittest.HomeserverTestCase):
         self.hs.config.jwt_algorithm = self.jwt_algorithm
         return self.hs
 
-    def jwt_encode(self, token: str, secret: str = jwt_secret) -> str:
+    def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_secret) -> str:
         # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str.
-        result = jwt.encode(token, secret, self.jwt_algorithm)
+        result = jwt.encode(
+            payload, secret, self.jwt_algorithm
+        )  # type: Union[str, bytes]
         if isinstance(result, bytes):
             return result.decode("ascii")
         return result
 
     def jwt_login(self, *args):
-        params = json.dumps(
-            {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
-        )
+        params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
         channel = self.make_request(b"POST", LOGIN_URL, params)
         return channel
 
@@ -617,7 +846,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
         )
 
     def test_login_no_token(self):
-        params = json.dumps({"type": "org.matrix.login.jwt"})
+        params = {"type": "org.matrix.login.jwt"}
         channel = self.make_request(b"POST", LOGIN_URL, params)
         self.assertEqual(channel.result["code"], b"403", channel.result)
         self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
@@ -627,6 +856,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
 # The JWTPubKeyTestCase is a complement to JWTTestCase where we instead use
 # RSS256, with a public key configured in synapse as "jwt_secret", and tokens
 # signed by the private key.
+@skip_unless(HAS_JWT, "requires jwt")
 class JWTPubKeyTestCase(unittest.HomeserverTestCase):
     servlets = [
         login.register_servlets,
@@ -684,17 +914,15 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
         self.hs.config.jwt_algorithm = "RS256"
         return self.hs
 
-    def jwt_encode(self, token: str, secret: str = jwt_privatekey) -> str:
+    def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_privatekey) -> str:
         # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str.
-        result = jwt.encode(token, secret, "RS256")
+        result = jwt.encode(payload, secret, "RS256")  # type: Union[bytes,str]
         if isinstance(result, bytes):
             return result.decode("ascii")
         return result
 
     def jwt_login(self, *args):
-        params = json.dumps(
-            {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
-        )
+        params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
         channel = self.make_request(b"POST", LOGIN_URL, params)
         return channel
 
@@ -764,8 +992,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
         return self.hs
 
     def test_login_appservice_user(self):
-        """Test that an appservice user can use /login
-        """
+        """Test that an appservice user can use /login"""
         self.register_as_user(AS_USER)
 
         params = {
@@ -779,8 +1006,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
         self.assertEquals(channel.result["code"], b"200", channel.result)
 
     def test_login_appservice_user_bot(self):
-        """Test that the appservice bot can use /login
-        """
+        """Test that the appservice bot can use /login"""
         self.register_as_user(AS_USER)
 
         params = {
@@ -794,8 +1020,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
         self.assertEquals(channel.result["code"], b"200", channel.result)
 
     def test_login_appservice_wrong_user(self):
-        """Test that non-as users cannot login with the as token
-        """
+        """Test that non-as users cannot login with the as token"""
         self.register_as_user(AS_USER)
 
         params = {
@@ -809,8 +1034,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
         self.assertEquals(channel.result["code"], b"403", channel.result)
 
     def test_login_appservice_wrong_as(self):
-        """Test that as users cannot login with wrong as token
-        """
+        """Test that as users cannot login with wrong as token"""
         self.register_as_user(AS_USER)
 
         params = {
@@ -825,7 +1049,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
 
     def test_login_appservice_no_token(self):
         """Test that users must provide a token when using the appservice
-           login method
+        login method
         """
         self.register_as_user(AS_USER)
 
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index 6105eac47c..d4e3165436 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -29,7 +29,7 @@ from synapse.handlers.pagination import PurgeStatus
 from synapse.rest import admin
 from synapse.rest.client.v1 import directory, login, profile, room
 from synapse.rest.client.v2_alpha import account
-from synapse.types import JsonDict, RoomAlias, UserID
+from synapse.types import JsonDict, RoomAlias, UserID, create_requester
 from synapse.util.stringutils import random_string
 
 from tests import unittest
@@ -1687,7 +1687,9 @@ class ContextTestCase(unittest.HomeserverTestCase):
 
         deactivate_account_handler = self.hs.get_deactivate_account_handler()
         self.get_success(
-            deactivate_account_handler.deactivate_account(self.user_id, erase_data=True)
+            deactivate_account_handler.deactivate_account(
+                self.user_id, True, create_requester(self.user_id)
+            )
         )
 
         # Invite another user in the room. This is needed because messages will be
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index dbc27893b5..81b7f84360 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -444,6 +444,7 @@ class RestHelper:
 
 
 # an 'oidc_config' suitable for login_via_oidc.
+TEST_OIDC_AUTH_ENDPOINT = "https://issuer.test/auth"
 TEST_OIDC_CONFIG = {
     "enabled": True,
     "discover": False,
@@ -451,7 +452,7 @@ TEST_OIDC_CONFIG = {
     "client_id": "test-client-id",
     "client_secret": "test-client-secret",
     "scopes": ["profile"],
-    "authorization_endpoint": "https://z",
+    "authorization_endpoint": TEST_OIDC_AUTH_ENDPOINT,
     "token_endpoint": "https://issuer.test/token",
     "userinfo_endpoint": "https://issuer.test/userinfo",
     "user_mapping_provider": {"config": {"localpart_template": "{{ user.sub }}"}},
diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py
index ac66a4e0b7..bb91e0c331 100644
--- a/tests/rest/client/v2_alpha/test_auth.py
+++ b/tests/rest/client/v2_alpha/test_auth.py
@@ -26,8 +26,10 @@ from synapse.rest.oidc import OIDCResource
 from synapse.types import JsonDict, UserID
 
 from tests import unittest
+from tests.handlers.test_oidc import HAS_OIDC
 from tests.rest.client.v1.utils import TEST_OIDC_CONFIG
 from tests.server import FakeChannel
+from tests.unittest import override_config, skip_unless
 
 
 class DummyRecaptchaChecker(UserInteractiveAuthChecker):
@@ -158,20 +160,22 @@ class UIAuthTests(unittest.HomeserverTestCase):
 
     def default_config(self):
         config = super().default_config()
+        config["public_baseurl"] = "https://synapse.test"
 
-        # we enable OIDC as a way of testing SSO flows
-        oidc_config = {}
-        oidc_config.update(TEST_OIDC_CONFIG)
-        oidc_config["allow_existing_users"] = True
+        if HAS_OIDC:
+            # we enable OIDC as a way of testing SSO flows
+            oidc_config = {}
+            oidc_config.update(TEST_OIDC_CONFIG)
+            oidc_config["allow_existing_users"] = True
+            config["oidc_config"] = oidc_config
 
-        config["oidc_config"] = oidc_config
-        config["public_baseurl"] = "https://synapse.test"
         return config
 
     def create_resource_dict(self):
         resource_dict = super().create_resource_dict()
-        # mount the OIDC resource at /_synapse/oidc
-        resource_dict["/_synapse/oidc"] = OIDCResource(self.hs)
+        if HAS_OIDC:
+            # mount the OIDC resource at /_synapse/oidc
+            resource_dict["/_synapse/oidc"] = OIDCResource(self.hs)
         return resource_dict
 
     def prepare(self, reactor, clock, hs):
@@ -380,6 +384,8 @@ class UIAuthTests(unittest.HomeserverTestCase):
         # Note that *no auth* information is provided, not even a session iD!
         self.delete_device(self.user_tok, self.device_id, 200)
 
+    @skip_unless(HAS_OIDC, "requires OIDC")
+    @override_config({"oidc_config": TEST_OIDC_CONFIG})
     def test_does_not_offer_password_for_sso_user(self):
         login_resp = self.helper.login_via_oidc("username")
         user_tok = login_resp["access_token"]
@@ -393,13 +399,13 @@ class UIAuthTests(unittest.HomeserverTestCase):
         self.assertEqual(flows, [{"stages": ["m.login.sso"]}])
 
     def test_does_not_offer_sso_for_password_user(self):
-        # now call the device deletion API: we should get the option to auth with SSO
-        # and not password.
         channel = self.delete_device(self.user_tok, self.device_id, 401)
 
         flows = channel.json_body["flows"]
         self.assertEqual(flows, [{"stages": ["m.login.password"]}])
 
+    @skip_unless(HAS_OIDC, "requires OIDC")
+    @override_config({"oidc_config": TEST_OIDC_CONFIG})
     def test_offers_both_flows_for_upgraded_user(self):
         """A user that had a password and then logged in with SSO should get both flows
         """
diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py
index 83d728b4a4..6968502433 100644
--- a/tests/rest/media/v1/test_url_preview.py
+++ b/tests/rest/media/v1/test_url_preview.py
@@ -26,8 +26,15 @@ from twisted.test.proto_helpers import AccumulatingProtocol
 from tests import unittest
 from tests.server import FakeTransport
 
+try:
+    import lxml
+except ImportError:
+    lxml = None
+
 
 class URLPreviewTests(unittest.HomeserverTestCase):
+    if not lxml:
+        skip = "url preview feature requires lxml"
 
     hijack_auth = True
     user_id = "@test:user"
diff --git a/tests/storage/test_account_data.py b/tests/storage/test_account_data.py
new file mode 100644
index 0000000000..673e1fe3e3
--- /dev/null
+++ b/tests/storage/test_account_data.py
@@ -0,0 +1,120 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Iterable, Set
+
+from synapse.api.constants import AccountDataTypes
+
+from tests import unittest
+
+
+class IgnoredUsersTestCase(unittest.HomeserverTestCase):
+    def prepare(self, hs, reactor, clock):
+        self.store = self.hs.get_datastore()
+        self.user = "@user:test"
+
+    def _update_ignore_list(
+        self, *ignored_user_ids: Iterable[str], ignorer_user_id: str = None
+    ) -> None:
+        """Update the account data to block the given users."""
+        if ignorer_user_id is None:
+            ignorer_user_id = self.user
+
+        self.get_success(
+            self.store.add_account_data_for_user(
+                ignorer_user_id,
+                AccountDataTypes.IGNORED_USER_LIST,
+                {"ignored_users": {u: {} for u in ignored_user_ids}},
+            )
+        )
+
+    def assert_ignorers(
+        self, ignored_user_id: str, expected_ignorer_user_ids: Set[str]
+    ) -> None:
+        self.assertEqual(
+            self.get_success(self.store.ignored_by(ignored_user_id)),
+            expected_ignorer_user_ids,
+        )
+
+    def test_ignoring_users(self):
+        """Basic adding/removing of users from the ignore list."""
+        self._update_ignore_list("@other:test", "@another:remote")
+
+        # Check a user which no one ignores.
+        self.assert_ignorers("@user:test", set())
+
+        # Check a local user which is ignored.
+        self.assert_ignorers("@other:test", {self.user})
+
+        # Check a remote user which is ignored.
+        self.assert_ignorers("@another:remote", {self.user})
+
+        # Add one user, remove one user, and leave one user.
+        self._update_ignore_list("@foo:test", "@another:remote")
+
+        # Check the removed user.
+        self.assert_ignorers("@other:test", set())
+
+        # Check the added user.
+        self.assert_ignorers("@foo:test", {self.user})
+
+        # Check the removed user.
+        self.assert_ignorers("@another:remote", {self.user})
+
+    def test_caching(self):
+        """Ensure that caching works properly between different users."""
+        # The first user ignores a user.
+        self._update_ignore_list("@other:test")
+        self.assert_ignorers("@other:test", {self.user})
+
+        # The second user ignores them.
+        self._update_ignore_list("@other:test", ignorer_user_id="@second:test")
+        self.assert_ignorers("@other:test", {self.user, "@second:test"})
+
+        # The first user un-ignores them.
+        self._update_ignore_list()
+        self.assert_ignorers("@other:test", {"@second:test"})
+
+    def test_invalid_data(self):
+        """Invalid data ends up clearing out the ignored users list."""
+        # Add some data and ensure it is there.
+        self._update_ignore_list("@other:test")
+        self.assert_ignorers("@other:test", {self.user})
+
+        # No ignored_users key.
+        self.get_success(
+            self.store.add_account_data_for_user(
+                self.user, AccountDataTypes.IGNORED_USER_LIST, {},
+            )
+        )
+
+        # No one ignores the user now.
+        self.assert_ignorers("@other:test", set())
+
+        # Add some data and ensure it is there.
+        self._update_ignore_list("@other:test")
+        self.assert_ignorers("@other:test", {self.user})
+
+        # Invalid data.
+        self.get_success(
+            self.store.add_account_data_for_user(
+                self.user,
+                AccountDataTypes.IGNORED_USER_LIST,
+                {"ignored_users": "unexpected"},
+            )
+        )
+
+        # No one ignores the user now.
+        self.assert_ignorers("@other:test", set())
diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py
new file mode 100644
index 0000000000..83c377824b
--- /dev/null
+++ b/tests/storage/test_event_chain.py
@@ -0,0 +1,472 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict, List, Tuple
+
+from twisted.trial import unittest
+
+from synapse.api.constants import EventTypes
+from synapse.api.room_versions import RoomVersions
+from synapse.events import EventBase
+from synapse.storage.databases.main.events import _LinkMap
+
+from tests.unittest import HomeserverTestCase
+
+
+class EventChainStoreTestCase(HomeserverTestCase):
+    def prepare(self, reactor, clock, hs):
+        self.store = hs.get_datastore()
+        self._next_stream_ordering = 1
+
+    def test_simple(self):
+        """Test that the example in `docs/auth_chain_difference_algorithm.md`
+        works.
+        """
+
+        event_factory = self.hs.get_event_builder_factory()
+        bob = "@creator:test"
+        alice = "@alice:test"
+        room_id = "!room:test"
+
+        # Ensure that we have a rooms entry so that we generate the chain index.
+        self.get_success(
+            self.store.store_room(
+                room_id=room_id,
+                room_creator_user_id="",
+                is_public=True,
+                room_version=RoomVersions.V6,
+            )
+        )
+
+        create = self.get_success(
+            event_factory.for_room_version(
+                RoomVersions.V6,
+                {
+                    "type": EventTypes.Create,
+                    "state_key": "",
+                    "sender": bob,
+                    "room_id": room_id,
+                    "content": {"tag": "create"},
+                },
+            ).build(prev_event_ids=[], auth_event_ids=[])
+        )
+
+        bob_join = self.get_success(
+            event_factory.for_room_version(
+                RoomVersions.V6,
+                {
+                    "type": EventTypes.Member,
+                    "state_key": bob,
+                    "sender": bob,
+                    "room_id": room_id,
+                    "content": {"tag": "bob_join"},
+                },
+            ).build(prev_event_ids=[], auth_event_ids=[create.event_id])
+        )
+
+        power = self.get_success(
+            event_factory.for_room_version(
+                RoomVersions.V6,
+                {
+                    "type": EventTypes.PowerLevels,
+                    "state_key": "",
+                    "sender": bob,
+                    "room_id": room_id,
+                    "content": {"tag": "power"},
+                },
+            ).build(
+                prev_event_ids=[], auth_event_ids=[create.event_id, bob_join.event_id],
+            )
+        )
+
+        alice_invite = self.get_success(
+            event_factory.for_room_version(
+                RoomVersions.V6,
+                {
+                    "type": EventTypes.Member,
+                    "state_key": alice,
+                    "sender": bob,
+                    "room_id": room_id,
+                    "content": {"tag": "alice_invite"},
+                },
+            ).build(
+                prev_event_ids=[],
+                auth_event_ids=[create.event_id, bob_join.event_id, power.event_id],
+            )
+        )
+
+        alice_join = self.get_success(
+            event_factory.for_room_version(
+                RoomVersions.V6,
+                {
+                    "type": EventTypes.Member,
+                    "state_key": alice,
+                    "sender": alice,
+                    "room_id": room_id,
+                    "content": {"tag": "alice_join"},
+                },
+            ).build(
+                prev_event_ids=[],
+                auth_event_ids=[create.event_id, alice_invite.event_id, power.event_id],
+            )
+        )
+
+        power_2 = self.get_success(
+            event_factory.for_room_version(
+                RoomVersions.V6,
+                {
+                    "type": EventTypes.PowerLevels,
+                    "state_key": "",
+                    "sender": bob,
+                    "room_id": room_id,
+                    "content": {"tag": "power_2"},
+                },
+            ).build(
+                prev_event_ids=[],
+                auth_event_ids=[create.event_id, bob_join.event_id, power.event_id],
+            )
+        )
+
+        bob_join_2 = self.get_success(
+            event_factory.for_room_version(
+                RoomVersions.V6,
+                {
+                    "type": EventTypes.Member,
+                    "state_key": bob,
+                    "sender": bob,
+                    "room_id": room_id,
+                    "content": {"tag": "bob_join_2"},
+                },
+            ).build(
+                prev_event_ids=[],
+                auth_event_ids=[create.event_id, bob_join.event_id, power.event_id],
+            )
+        )
+
+        alice_join2 = self.get_success(
+            event_factory.for_room_version(
+                RoomVersions.V6,
+                {
+                    "type": EventTypes.Member,
+                    "state_key": alice,
+                    "sender": alice,
+                    "room_id": room_id,
+                    "content": {"tag": "alice_join2"},
+                },
+            ).build(
+                prev_event_ids=[],
+                auth_event_ids=[
+                    create.event_id,
+                    alice_join.event_id,
+                    power_2.event_id,
+                ],
+            )
+        )
+
+        events = [
+            create,
+            bob_join,
+            power,
+            alice_invite,
+            alice_join,
+            bob_join_2,
+            power_2,
+            alice_join2,
+        ]
+
+        expected_links = [
+            (bob_join, create),
+            (power, create),
+            (power, bob_join),
+            (alice_invite, create),
+            (alice_invite, power),
+            (alice_invite, bob_join),
+            (bob_join_2, power),
+            (alice_join2, power_2),
+        ]
+
+        self.persist(events)
+        chain_map, link_map = self.fetch_chains(events)
+
+        # Check that the expected links and only the expected links have been
+        # added.
+        self.assertEqual(len(expected_links), len(list(link_map.get_additions())))
+
+        for start, end in expected_links:
+            start_id, start_seq = chain_map[start.event_id]
+            end_id, end_seq = chain_map[end.event_id]
+
+            self.assertIn(
+                (start_seq, end_seq), list(link_map.get_links_between(start_id, end_id))
+            )
+
+        # Test that everything can reach the create event, but the create event
+        # can't reach anything.
+        for event in events[1:]:
+            self.assertTrue(
+                link_map.exists_path_from(
+                    chain_map[event.event_id], chain_map[create.event_id]
+                ),
+            )
+
+            self.assertFalse(
+                link_map.exists_path_from(
+                    chain_map[create.event_id], chain_map[event.event_id],
+                ),
+            )
+
+    def test_out_of_order_events(self):
+        """Test that we handle persisting events that we don't have the full
+        auth chain for yet (which should only happen for out of band memberships).
+        """
+        event_factory = self.hs.get_event_builder_factory()
+        bob = "@creator:test"
+        alice = "@alice:test"
+        room_id = "!room:test"
+
+        # Ensure that we have a rooms entry so that we generate the chain index.
+        self.get_success(
+            self.store.store_room(
+                room_id=room_id,
+                room_creator_user_id="",
+                is_public=True,
+                room_version=RoomVersions.V6,
+            )
+        )
+
+        # First persist the base room.
+        create = self.get_success(
+            event_factory.for_room_version(
+                RoomVersions.V6,
+                {
+                    "type": EventTypes.Create,
+                    "state_key": "",
+                    "sender": bob,
+                    "room_id": room_id,
+                    "content": {"tag": "create"},
+                },
+            ).build(prev_event_ids=[], auth_event_ids=[])
+        )
+
+        bob_join = self.get_success(
+            event_factory.for_room_version(
+                RoomVersions.V6,
+                {
+                    "type": EventTypes.Member,
+                    "state_key": bob,
+                    "sender": bob,
+                    "room_id": room_id,
+                    "content": {"tag": "bob_join"},
+                },
+            ).build(prev_event_ids=[], auth_event_ids=[create.event_id])
+        )
+
+        power = self.get_success(
+            event_factory.for_room_version(
+                RoomVersions.V6,
+                {
+                    "type": EventTypes.PowerLevels,
+                    "state_key": "",
+                    "sender": bob,
+                    "room_id": room_id,
+                    "content": {"tag": "power"},
+                },
+            ).build(
+                prev_event_ids=[], auth_event_ids=[create.event_id, bob_join.event_id],
+            )
+        )
+
+        self.persist([create, bob_join, power])
+
+        # Now persist an invite and a couple of memberships out of order.
+        alice_invite = self.get_success(
+            event_factory.for_room_version(
+                RoomVersions.V6,
+                {
+                    "type": EventTypes.Member,
+                    "state_key": alice,
+                    "sender": bob,
+                    "room_id": room_id,
+                    "content": {"tag": "alice_invite"},
+                },
+            ).build(
+                prev_event_ids=[],
+                auth_event_ids=[create.event_id, bob_join.event_id, power.event_id],
+            )
+        )
+
+        alice_join = self.get_success(
+            event_factory.for_room_version(
+                RoomVersions.V6,
+                {
+                    "type": EventTypes.Member,
+                    "state_key": alice,
+                    "sender": alice,
+                    "room_id": room_id,
+                    "content": {"tag": "alice_join"},
+                },
+            ).build(
+                prev_event_ids=[],
+                auth_event_ids=[create.event_id, alice_invite.event_id, power.event_id],
+            )
+        )
+
+        alice_join2 = self.get_success(
+            event_factory.for_room_version(
+                RoomVersions.V6,
+                {
+                    "type": EventTypes.Member,
+                    "state_key": alice,
+                    "sender": alice,
+                    "room_id": room_id,
+                    "content": {"tag": "alice_join2"},
+                },
+            ).build(
+                prev_event_ids=[],
+                auth_event_ids=[create.event_id, alice_join.event_id, power.event_id],
+            )
+        )
+
+        self.persist([alice_join])
+        self.persist([alice_join2])
+        self.persist([alice_invite])
+
+        # The end result should be sane.
+        events = [create, bob_join, power, alice_invite, alice_join]
+
+        chain_map, link_map = self.fetch_chains(events)
+
+        expected_links = [
+            (bob_join, create),
+            (power, create),
+            (power, bob_join),
+            (alice_invite, create),
+            (alice_invite, power),
+            (alice_invite, bob_join),
+        ]
+
+        # Check that the expected links and only the expected links have been
+        # added.
+        self.assertEqual(len(expected_links), len(list(link_map.get_additions())))
+
+        for start, end in expected_links:
+            start_id, start_seq = chain_map[start.event_id]
+            end_id, end_seq = chain_map[end.event_id]
+
+            self.assertIn(
+                (start_seq, end_seq), list(link_map.get_links_between(start_id, end_id))
+            )
+
+    def persist(
+        self, events: List[EventBase],
+    ):
+        """Persist the given events and check that the links generated match
+        those given.
+        """
+
+        persist_events_store = self.hs.get_datastores().persist_events
+
+        for e in events:
+            e.internal_metadata.stream_ordering = self._next_stream_ordering
+            self._next_stream_ordering += 1
+
+        def _persist(txn):
+            # We need to persist the events to the events and state_events
+            # tables.
+            persist_events_store._store_event_txn(txn, [(e, {}) for e in events])
+
+            # Actually call the function that calculates the auth chain stuff.
+            persist_events_store._persist_event_auth_chain_txn(txn, events)
+
+        self.get_success(
+            persist_events_store.db_pool.runInteraction("_persist", _persist,)
+        )
+
+    def fetch_chains(
+        self, events: List[EventBase]
+    ) -> Tuple[Dict[str, Tuple[int, int]], _LinkMap]:
+
+        # Fetch the map from event ID -> (chain ID, sequence number)
+        rows = self.get_success(
+            self.store.db_pool.simple_select_many_batch(
+                table="event_auth_chains",
+                column="event_id",
+                iterable=[e.event_id for e in events],
+                retcols=("event_id", "chain_id", "sequence_number"),
+                keyvalues={},
+            )
+        )
+
+        chain_map = {
+            row["event_id"]: (row["chain_id"], row["sequence_number"]) for row in rows
+        }
+
+        # Fetch all the links and pass them to the _LinkMap.
+        rows = self.get_success(
+            self.store.db_pool.simple_select_many_batch(
+                table="event_auth_chain_links",
+                column="origin_chain_id",
+                iterable=[chain_id for chain_id, _ in chain_map.values()],
+                retcols=(
+                    "origin_chain_id",
+                    "origin_sequence_number",
+                    "target_chain_id",
+                    "target_sequence_number",
+                ),
+                keyvalues={},
+            )
+        )
+
+        link_map = _LinkMap()
+        for row in rows:
+            added = link_map.add_link(
+                (row["origin_chain_id"], row["origin_sequence_number"]),
+                (row["target_chain_id"], row["target_sequence_number"]),
+            )
+
+            # We shouldn't have persisted any redundant links
+            self.assertTrue(added)
+
+        return chain_map, link_map
+
+
+class LinkMapTestCase(unittest.TestCase):
+    def test_simple(self):
+        """Basic tests for the LinkMap.
+        """
+        link_map = _LinkMap()
+
+        link_map.add_link((1, 1), (2, 1), new=False)
+        self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1)])
+        self.assertCountEqual(link_map.get_links_from((1, 1)), [(2, 1)])
+        self.assertCountEqual(link_map.get_additions(), [])
+        self.assertTrue(link_map.exists_path_from((1, 5), (2, 1)))
+        self.assertFalse(link_map.exists_path_from((1, 5), (2, 2)))
+        self.assertTrue(link_map.exists_path_from((1, 5), (1, 1)))
+        self.assertFalse(link_map.exists_path_from((1, 1), (1, 5)))
+
+        # Attempting to add a redundant link is ignored.
+        self.assertFalse(link_map.add_link((1, 4), (2, 1)))
+        self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1)])
+
+        # Adding new non-redundant links works
+        self.assertTrue(link_map.add_link((1, 3), (2, 3)))
+        self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1), (3, 3)])
+
+        self.assertTrue(link_map.add_link((2, 5), (1, 3)))
+        self.assertCountEqual(link_map.get_links_between(2, 1), [(5, 3)])
+        self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1), (3, 3)])
+
+        self.assertCountEqual(link_map.get_additions(), [(1, 3, 2, 3), (2, 5, 1, 3)])
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index 482506d731..9d04a066d8 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -13,6 +13,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import attr
+from parameterized import parameterized
+
+from synapse.events import _EventInternalMetadata
+
 import tests.unittest
 import tests.utils
 
@@ -113,7 +118,8 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
         r = self.get_success(self.store.get_rooms_with_many_extremities(5, 1, [room1]))
         self.assertTrue(r == [room2] or r == [room3])
 
-    def test_auth_difference(self):
+    @parameterized.expand([(True,), (False,)])
+    def test_auth_difference(self, use_chain_cover_index: bool):
         room_id = "@ROOM:local"
 
         # The silly auth graph we use to test the auth difference algorithm,
@@ -159,46 +165,223 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
             "j": 1,
         }
 
+        # Mark the room as not having a cover index
+
+        def store_room(txn):
+            self.store.db_pool.simple_insert_txn(
+                txn,
+                "rooms",
+                {
+                    "room_id": room_id,
+                    "creator": "room_creator_user_id",
+                    "is_public": True,
+                    "room_version": "6",
+                    "has_auth_chain_index": use_chain_cover_index,
+                },
+            )
+
+        self.get_success(self.store.db_pool.runInteraction("store_room", store_room))
+
         # We rudely fiddle with the appropriate tables directly, as that's much
         # easier than constructing events properly.
 
-        def insert_event(txn, event_id, stream_ordering):
+        def insert_event(txn):
+            stream_ordering = 0
+
+            for event_id in auth_graph:
+                stream_ordering += 1
+                depth = depth_map[event_id]
+
+                self.store.db_pool.simple_insert_txn(
+                    txn,
+                    table="events",
+                    values={
+                        "event_id": event_id,
+                        "room_id": room_id,
+                        "depth": depth,
+                        "topological_ordering": depth,
+                        "type": "m.test",
+                        "processed": True,
+                        "outlier": False,
+                        "stream_ordering": stream_ordering,
+                    },
+                )
+
+            self.hs.datastores.persist_events._persist_event_auth_chain_txn(
+                txn,
+                [
+                    FakeEvent(event_id, room_id, auth_graph[event_id])
+                    for event_id in auth_graph
+                ],
+            )
+
+        self.get_success(self.store.db_pool.runInteraction("insert", insert_event,))
+
+        # Now actually test that various combinations give the right result:
+
+        difference = self.get_success(
+            self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}])
+        )
+        self.assertSetEqual(difference, {"a", "b"})
+
+        difference = self.get_success(
+            self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"c"}])
+        )
+        self.assertSetEqual(difference, {"a", "b", "c", "e", "f"})
+
+        difference = self.get_success(
+            self.store.get_auth_chain_difference(room_id, [{"a", "c"}, {"b"}])
+        )
+        self.assertSetEqual(difference, {"a", "b", "c"})
+
+        difference = self.get_success(
+            self.store.get_auth_chain_difference(room_id, [{"a", "c"}, {"b", "c"}])
+        )
+        self.assertSetEqual(difference, {"a", "b"})
+
+        difference = self.get_success(
+            self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"d"}])
+        )
+        self.assertSetEqual(difference, {"a", "b", "d", "e"})
+
+        difference = self.get_success(
+            self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"c"}, {"d"}])
+        )
+        self.assertSetEqual(difference, {"a", "b", "c", "d", "e", "f"})
+
+        difference = self.get_success(
+            self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"e"}])
+        )
+        self.assertSetEqual(difference, {"a", "b"})
+
+        difference = self.get_success(
+            self.store.get_auth_chain_difference(room_id, [{"a"}])
+        )
+        self.assertSetEqual(difference, set())
+
+    def test_auth_difference_partial_cover(self):
+        """Test that we correctly handle rooms where not all events have a chain
+        cover calculated. This can happen in some obscure edge cases, including
+        during the background update that calculates the chain cover for old
+        rooms.
+        """
+
+        room_id = "@ROOM:local"
+
+        # The silly auth graph we use to test the auth difference algorithm,
+        # where the top are the most recent events.
+        #
+        #   A   B
+        #    \ /
+        #  D  E
+        #  \  |
+        #   ` F   C
+        #     |  /|
+        #     G ´ |
+        #     | \ |
+        #     H   I
+        #     |   |
+        #     K   J
+
+        auth_graph = {
+            "a": ["e"],
+            "b": ["e"],
+            "c": ["g", "i"],
+            "d": ["f"],
+            "e": ["f"],
+            "f": ["g"],
+            "g": ["h", "i"],
+            "h": ["k"],
+            "i": ["j"],
+            "k": [],
+            "j": [],
+        }
+
+        depth_map = {
+            "a": 7,
+            "b": 7,
+            "c": 4,
+            "d": 6,
+            "e": 6,
+            "f": 5,
+            "g": 3,
+            "h": 2,
+            "i": 2,
+            "k": 1,
+            "j": 1,
+        }
 
-            depth = depth_map[event_id]
+        # We rudely fiddle with the appropriate tables directly, as that's much
+        # easier than constructing events properly.
 
+        def insert_event(txn):
+            # First insert the room and mark it as having a chain cover.
             self.store.db_pool.simple_insert_txn(
                 txn,
-                table="events",
-                values={
-                    "event_id": event_id,
+                "rooms",
+                {
                     "room_id": room_id,
-                    "depth": depth,
-                    "topological_ordering": depth,
-                    "type": "m.test",
-                    "processed": True,
-                    "outlier": False,
-                    "stream_ordering": stream_ordering,
+                    "creator": "room_creator_user_id",
+                    "is_public": True,
+                    "room_version": "6",
+                    "has_auth_chain_index": True,
                 },
             )
 
-            self.store.db_pool.simple_insert_many_txn(
+            stream_ordering = 0
+
+            for event_id in auth_graph:
+                stream_ordering += 1
+                depth = depth_map[event_id]
+
+                self.store.db_pool.simple_insert_txn(
+                    txn,
+                    table="events",
+                    values={
+                        "event_id": event_id,
+                        "room_id": room_id,
+                        "depth": depth,
+                        "topological_ordering": depth,
+                        "type": "m.test",
+                        "processed": True,
+                        "outlier": False,
+                        "stream_ordering": stream_ordering,
+                    },
+                )
+
+            # Insert all events apart from 'B'
+            self.hs.datastores.persist_events._persist_event_auth_chain_txn(
                 txn,
-                table="event_auth",
-                values=[
-                    {"event_id": event_id, "room_id": room_id, "auth_id": a}
-                    for a in auth_graph[event_id]
+                [
+                    FakeEvent(event_id, room_id, auth_graph[event_id])
+                    for event_id in auth_graph
+                    if event_id != "b"
                 ],
             )
 
-        next_stream_ordering = 0
-        for event_id in auth_graph:
-            next_stream_ordering += 1
-            self.get_success(
-                self.store.db_pool.runInteraction(
-                    "insert", insert_event, event_id, next_stream_ordering
-                )
+            # Now we insert the event 'B' without a chain cover, by temporarily
+            # pretending the room doesn't have a chain cover.
+
+            self.store.db_pool.simple_update_txn(
+                txn,
+                table="rooms",
+                keyvalues={"room_id": room_id},
+                updatevalues={"has_auth_chain_index": False},
+            )
+
+            self.hs.datastores.persist_events._persist_event_auth_chain_txn(
+                txn, [FakeEvent("b", room_id, auth_graph["b"])],
+            )
+
+            self.store.db_pool.simple_update_txn(
+                txn,
+                table="rooms",
+                keyvalues={"room_id": room_id},
+                updatevalues={"has_auth_chain_index": True},
             )
 
+        self.get_success(self.store.db_pool.runInteraction("insert", insert_event,))
+
         # Now actually test that various combinations give the right result:
 
         difference = self.get_success(
@@ -240,3 +423,21 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
             self.store.get_auth_chain_difference(room_id, [{"a"}])
         )
         self.assertSetEqual(difference, set())
+
+
+@attr.s
+class FakeEvent:
+    event_id = attr.ib()
+    room_id = attr.ib()
+    auth_events = attr.ib()
+
+    type = "foo"
+    state_key = "foo"
+
+    internal_metadata = _EventInternalMetadata({})
+
+    def auth_event_ids(self):
+        return self.auth_events
+
+    def is_state(self):
+        return True
diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py
index 3fd0a38cf5..ea63bd56b4 100644
--- a/tests/storage/test_profile.py
+++ b/tests/storage/test_profile.py
@@ -48,6 +48,19 @@ class ProfileStoreTestCase(unittest.TestCase):
             ),
         )
 
+        # test set to None
+        yield defer.ensureDeferred(
+            self.store.set_profile_displayname(self.u_frank.localpart, None)
+        )
+
+        self.assertIsNone(
+            (
+                yield defer.ensureDeferred(
+                    self.store.get_profile_displayname(self.u_frank.localpart)
+                )
+            )
+        )
+
     @defer.inlineCallbacks
     def test_avatar_url(self):
         yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart))
@@ -66,3 +79,16 @@ class ProfileStoreTestCase(unittest.TestCase):
                 )
             ),
         )
+
+        # test set to None
+        yield defer.ensureDeferred(
+            self.store.set_profile_avatar_url(self.u_frank.localpart, None)
+        )
+
+        self.assertIsNone(
+            (
+                yield defer.ensureDeferred(
+                    self.store.get_profile_avatar_url(self.u_frank.localpart)
+                )
+            )
+        )
diff --git a/tests/test_preview.py b/tests/test_preview.py
index a883d707df..c19facc1cb 100644
--- a/tests/test_preview.py
+++ b/tests/test_preview.py
@@ -20,8 +20,16 @@ from synapse.rest.media.v1.preview_url_resource import (
 
 from . import unittest
 
+try:
+    import lxml
+except ImportError:
+    lxml = None
+
 
 class PreviewTestCase(unittest.TestCase):
+    if not lxml:
+        skip = "url preview feature requires lxml"
+
     def test_long_summarize(self):
         example_paras = [
             """Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:
@@ -137,6 +145,9 @@ class PreviewTestCase(unittest.TestCase):
 
 
 class PreviewUrlTestCase(unittest.TestCase):
+    if not lxml:
+        skip = "url preview feature requires lxml"
+
     def test_simple(self):
         html = """
         <html>
diff --git a/tests/unittest.py b/tests/unittest.py
index af7f752c5a..bbd295687c 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -20,7 +20,7 @@ import hmac
 import inspect
 import logging
 import time
-from typing import Dict, Iterable, Optional, Tuple, Type, TypeVar, Union
+from typing import Callable, Dict, Iterable, Optional, Tuple, Type, TypeVar, Union
 
 from mock import Mock, patch
 
@@ -736,3 +736,29 @@ def override_config(extra_config):
         return func
 
     return decorator
+
+
+TV = TypeVar("TV")
+
+
+def skip_unless(condition: bool, reason: str) -> Callable[[TV], TV]:
+    """A test decorator which will skip the decorated test unless a condition is set
+
+    For example:
+
+    class MyTestCase(TestCase):
+        @skip_unless(HAS_FOO, "Cannot test without foo")
+        def test_foo(self):
+            ...
+
+    Args:
+        condition: If true, the test will be skipped
+        reason: the reason to give for skipping the test
+    """
+
+    def decorator(f: TV) -> TV:
+        if not condition:
+            f.skip = reason  # type: ignore
+        return f
+
+    return decorator
diff --git a/tests/util/caches/test_deferred_cache.py b/tests/util/caches/test_deferred_cache.py
index dadfabd46d..ecd9efc4df 100644
--- a/tests/util/caches/test_deferred_cache.py
+++ b/tests/util/caches/test_deferred_cache.py
@@ -25,13 +25,8 @@ from tests.unittest import TestCase
 class DeferredCacheTestCase(TestCase):
     def test_empty(self):
         cache = DeferredCache("test")
-        failed = False
-        try:
+        with self.assertRaises(KeyError):
             cache.get("foo")
-        except KeyError:
-            failed = True
-
-        self.assertTrue(failed)
 
     def test_hit(self):
         cache = DeferredCache("test")
@@ -155,13 +150,8 @@ class DeferredCacheTestCase(TestCase):
         cache.prefill(("foo",), 123)
         cache.invalidate(("foo",))
 
-        failed = False
-        try:
+        with self.assertRaises(KeyError):
             cache.get(("foo",))
-        except KeyError:
-            failed = True
-
-        self.assertTrue(failed)
 
     def test_invalidate_all(self):
         cache = DeferredCache("testcache")
@@ -215,13 +205,8 @@ class DeferredCacheTestCase(TestCase):
         cache.prefill(2, "two")
         cache.prefill(3, "three")  # 1 will be evicted
 
-        failed = False
-        try:
+        with self.assertRaises(KeyError):
             cache.get(1)
-        except KeyError:
-            failed = True
-
-        self.assertTrue(failed)
 
         cache.get(2)
         cache.get(3)
@@ -239,13 +224,55 @@ class DeferredCacheTestCase(TestCase):
 
         cache.prefill(3, "three")
 
-        failed = False
-        try:
+        with self.assertRaises(KeyError):
             cache.get(2)
-        except KeyError:
-            failed = True
 
-        self.assertTrue(failed)
+        cache.get(1)
+        cache.get(3)
+
+    def test_eviction_iterable(self):
+        cache = DeferredCache(
+            "test", max_entries=3, apply_cache_factor_from_config=False, iterable=True,
+        )
+
+        cache.prefill(1, ["one", "two"])
+        cache.prefill(2, ["three"])
 
+        # Now access 1 again, thus causing 2 to be least-recently used
+        cache.get(1)
+
+        # Now add an item to the cache, which evicts 2.
+        cache.prefill(3, ["four"])
+        with self.assertRaises(KeyError):
+            cache.get(2)
+
+        # Ensure 1 & 3 are in the cache.
         cache.get(1)
         cache.get(3)
+
+        # Now access 1 again, thus causing 3 to be least-recently used
+        cache.get(1)
+
+        # Now add an item with multiple elements to the cache
+        cache.prefill(4, ["five", "six"])
+
+        # Both 1 and 3 are evicted since there's too many elements.
+        with self.assertRaises(KeyError):
+            cache.get(1)
+        with self.assertRaises(KeyError):
+            cache.get(3)
+
+        # Now add another item to fill the cache again.
+        cache.prefill(5, ["seven"])
+
+        # Now access 4, thus causing 5 to be least-recently used
+        cache.get(4)
+
+        # Add an empty item.
+        cache.prefill(6, [])
+
+        # 5 gets evicted and replaced since an empty element counts as an item.
+        with self.assertRaises(KeyError):
+            cache.get(5)
+        cache.get(4)
+        cache.get(6)
diff --git a/tests/util/test_itertools.py b/tests/util/test_itertools.py
index 0ab0a91483..1184cea5a3 100644
--- a/tests/util/test_itertools.py
+++ b/tests/util/test_itertools.py
@@ -12,7 +12,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from synapse.util.iterutils import chunk_seq
+from typing import Dict, List
+
+from synapse.util.iterutils import chunk_seq, sorted_topologically
 
 from tests.unittest import TestCase
 
@@ -45,3 +47,40 @@ class ChunkSeqTests(TestCase):
         self.assertEqual(
             list(parts), [],
         )
+
+
+class SortTopologically(TestCase):
+    def test_empty(self):
+        "Test that an empty graph works correctly"
+
+        graph = {}  # type: Dict[int, List[int]]
+        self.assertEqual(list(sorted_topologically([], graph)), [])
+
+    def test_disconnected(self):
+        "Test that a graph with no edges work"
+
+        graph = {1: [], 2: []}  # type: Dict[int, List[int]]
+
+        # For disconnected nodes the output is simply sorted.
+        self.assertEqual(list(sorted_topologically([1, 2], graph)), [1, 2])
+
+    def test_linear(self):
+        "Test that a simple `4 -> 3 -> 2 -> 1` graph works"
+
+        graph = {1: [], 2: [1], 3: [2], 4: [3]}  # type: Dict[int, List[int]]
+
+        self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4])
+
+    def test_subset(self):
+        "Test that only sorting a subset of the graph works"
+        graph = {1: [], 2: [1], 3: [2], 4: [3]}  # type: Dict[int, List[int]]
+
+        self.assertEqual(list(sorted_topologically([4, 3], graph)), [3, 4])
+
+    def test_fork(self):
+        "Test that a forked graph works"
+        graph = {1: [], 2: [1], 3: [1], 4: [2, 3]}  # type: Dict[int, List[int]]
+
+        # Valid orderings are `[1, 3, 2, 4]` or `[1, 2, 3, 4]`, but we should
+        # always get the same one.
+        self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4])