summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/8802.doc1
-rw-r--r--changelog.d/8821.bugfix1
-rw-r--r--changelog.d/8827.bugfix1
-rw-r--r--changelog.d/8837.bugfix1
-rw-r--r--changelog.d/8853.feature1
-rw-r--r--changelog.d/8858.bugfix1
-rw-r--r--changelog.d/8861.misc1
-rw-r--r--changelog.d/8862.bugfix1
-rw-r--r--changelog.d/8864.misc1
-rw-r--r--changelog.d/8865.bugfix1
-rw-r--r--changelog.d/8867.bugfix1
-rw-r--r--changelog.d/8872.bugfix1
-rw-r--r--changelog.d/8873.doc1
-rw-r--r--changelog.d/8874.feature1
-rw-r--r--changelog.d/8879.misc1
-rw-r--r--changelog.d/8880.misc1
-rw-r--r--changelog.d/8881.misc1
-rw-r--r--changelog.d/8882.misc1
-rw-r--r--changelog.d/8883.bugfix1
-rw-r--r--changelog.d/8887.feature1
-rw-r--r--changelog.d/8891.doc1
-rw-r--r--contrib/prometheus/synapse-v2.rules18
-rw-r--r--docs/sample_config.yaml33
-rw-r--r--docs/sso_mapping_providers.md4
-rw-r--r--docs/workers.md6
-rw-r--r--mypy.ini5
-rw-r--r--synapse/app/generic_worker.py1
-rw-r--r--synapse/app/homeserver.py46
-rw-r--r--synapse/config/_base.py14
-rw-r--r--synapse/config/_base.pyi7
-rw-r--r--synapse/config/_util.py35
-rw-r--r--synapse/config/emailconfig.py5
-rw-r--r--synapse/config/federation.py40
-rw-r--r--synapse/config/oidc_config.py2
-rw-r--r--synapse/config/password_auth_providers.py5
-rw-r--r--synapse/config/repository.py6
-rw-r--r--synapse/config/room_directory.py2
-rw-r--r--synapse/config/saml2_config.py2
-rw-r--r--synapse/config/spam_checker.py9
-rw-r--r--synapse/config/sso.py7
-rw-r--r--synapse/config/third_party_event_rules.py4
-rw-r--r--synapse/config/workers.py10
-rw-r--r--synapse/crypto/keyring.py4
-rw-r--r--synapse/federation/federation_server.py1
-rw-r--r--synapse/federation/transport/client.py2
-rw-r--r--synapse/federation/transport/server.py2
-rw-r--r--synapse/handlers/_base.py4
-rw-r--r--synapse/handlers/auth.py69
-rw-r--r--synapse/handlers/federation.py2
-rw-r--r--synapse/handlers/identity.py6
-rw-r--r--synapse/handlers/oidc_handler.py44
-rw-r--r--synapse/handlers/saml_handler.py64
-rw-r--r--synapse/handlers/sso.py59
-rw-r--r--synapse/http/client.py46
-rw-r--r--synapse/http/federation/matrix_federation_agent.py16
-rw-r--r--synapse/http/matrixfederationclient.py26
-rw-r--r--synapse/push/__init__.py53
-rw-r--r--synapse/push/emailpusher.py82
-rw-r--r--synapse/push/httppusher.py86
-rw-r--r--synapse/push/mailer.py123
-rw-r--r--synapse/push/push_tools.py7
-rw-r--r--synapse/push/pusher.py22
-rw-r--r--synapse/push/pusherpool.py36
-rw-r--r--synapse/replication/http/_base.py47
-rw-r--r--synapse/rest/admin/users.py7
-rw-r--r--synapse/rest/client/v2_alpha/register.py2
-rw-r--r--synapse/rest/media/v1/_base.py5
-rw-r--r--synapse/rest/media/v1/media_repository.py2
-rw-r--r--synapse/rest/media/v1/preview_url_resource.py6
-rw-r--r--synapse/rest/media/v1/upload_resource.py2
-rw-r--r--synapse/server.py36
-rw-r--r--synapse/state/__init__.py4
-rw-r--r--synapse/state/v2.py90
-rw-r--r--synapse/storage/databases/main/event_federation.py4
-rw-r--r--synapse/storage/databases/main/registration.py25
-rw-r--r--synapse/storage/databases/main/schema/delta/58/25user_external_ids_user_id_idx.sql17
-rw-r--r--synapse/util/module_loader.py64
-rw-r--r--tests/api/test_filtering.py16
-rw-r--r--tests/app/test_frontend_proxy.py2
-rw-r--r--tests/app/test_openid_listener.py4
-rw-r--r--tests/crypto/test_keyring.py6
-rw-r--r--tests/handlers/test_device.py4
-rw-r--r--tests/handlers/test_directory.py2
-rw-r--r--tests/handlers/test_federation.py2
-rw-r--r--tests/handlers/test_oidc.py17
-rw-r--r--tests/handlers/test_presence.py2
-rw-r--r--tests/handlers/test_profile.py2
-rw-r--r--tests/handlers/test_typing.py17
-rw-r--r--tests/http/federation/test_matrix_federation_agent.py3
-rw-r--r--tests/push/test_http.py106
-rw-r--r--tests/replication/_base.py18
-rw-r--r--tests/replication/test_auth.py119
-rw-r--r--tests/replication/test_client_reader_shard.py9
-rw-r--r--tests/replication/test_federation_sender_shard.py10
-rw-r--r--tests/replication/test_pusher_shard.py14
-rw-r--r--tests/rest/admin/test_admin.py2
-rw-r--r--tests/rest/admin/test_user.py9
-rw-r--r--tests/rest/client/v1/test_presence.py2
-rw-r--r--tests/rest/client/v1/test_profile.py2
-rw-r--r--tests/rest/client/v1/test_rooms.py2
-rw-r--r--tests/rest/client/v1/test_typing.py2
-rw-r--r--tests/rest/client/v1/utils.py116
-rw-r--r--tests/rest/client/v2_alpha/test_auth.py94
-rw-r--r--tests/rest/client/v2_alpha/test_register.py1
-rw-r--r--tests/rest/key/v2/test_remote_key_resource.py4
-rw-r--r--tests/rest/media/v1/test_media_storage.py15
-rw-r--r--tests/rest/media/v1/test_url_preview.py26
-rw-r--r--tests/server.py4
-rw-r--r--tests/state/test_v2.py136
-rw-r--r--tests/storage/test_e2e_room_keys.py2
-rw-r--r--tests/storage/test_event_federation.py21
-rw-r--r--tests/storage/test_purge.py2
-rw-r--r--tests/storage/test_redaction.py11
-rw-r--r--tests/storage/test_roommember.py8
-rw-r--r--tests/test_federation.py2
-rw-r--r--tests/test_preview.py27
-rw-r--r--tests/test_server.py5
-rw-r--r--tests/test_utils/__init__.py27
-rw-r--r--tests/unittest.py53
-rw-r--r--tests/utils.py108
120 files changed, 1721 insertions, 660 deletions
diff --git a/changelog.d/8802.doc b/changelog.d/8802.doc
new file mode 100644
index 0000000000..580c4281f8
--- /dev/null
+++ b/changelog.d/8802.doc
@@ -0,0 +1 @@
+Fix the "Event persist rate" section of the included grafana dashboard by adding missing prometheus rules.
diff --git a/changelog.d/8821.bugfix b/changelog.d/8821.bugfix
new file mode 100644
index 0000000000..8ddfbf31ce
--- /dev/null
+++ b/changelog.d/8821.bugfix
@@ -0,0 +1 @@
+Apply the `federation_ip_range_blacklist` to push and key revocation requests.
diff --git a/changelog.d/8827.bugfix b/changelog.d/8827.bugfix
new file mode 100644
index 0000000000..18195680d3
--- /dev/null
+++ b/changelog.d/8827.bugfix
@@ -0,0 +1 @@
+Fix bug where we might not correctly calculate the current state for rooms with multiple extremities.
diff --git a/changelog.d/8837.bugfix b/changelog.d/8837.bugfix
new file mode 100644
index 0000000000..b2977d0c31
--- /dev/null
+++ b/changelog.d/8837.bugfix
@@ -0,0 +1 @@
+Fix a long standing bug in the register admin endpoint (`/_synapse/admin/v1/register`) when the `mac` field was not provided. The endpoint now properly returns a 400 error. Contributed by @edwargix.
diff --git a/changelog.d/8853.feature b/changelog.d/8853.feature
new file mode 100644
index 0000000000..63c59f4ff2
--- /dev/null
+++ b/changelog.d/8853.feature
@@ -0,0 +1 @@
+Add optional HTTP authentication to replication endpoints.
diff --git a/changelog.d/8858.bugfix b/changelog.d/8858.bugfix
new file mode 100644
index 0000000000..0d58cb9abc
--- /dev/null
+++ b/changelog.d/8858.bugfix
@@ -0,0 +1 @@
+Fix a long-standing bug on Synapse instances supporting Single-Sign-On, where users would be prompted to enter their password to confirm certain actions, even though they have not set a password.
diff --git a/changelog.d/8861.misc b/changelog.d/8861.misc
new file mode 100644
index 0000000000..9821f804cf
--- /dev/null
+++ b/changelog.d/8861.misc
@@ -0,0 +1 @@
+Remove some unnecessary stubbing from unit tests.
diff --git a/changelog.d/8862.bugfix b/changelog.d/8862.bugfix
new file mode 100644
index 0000000000..bdbd633f72
--- /dev/null
+++ b/changelog.d/8862.bugfix
@@ -0,0 +1 @@
+Fix a longstanding bug where a 500 error would be returned if the `Content-Length` header was not provided to the upload media resource.
diff --git a/changelog.d/8864.misc b/changelog.d/8864.misc
new file mode 100644
index 0000000000..a780883495
--- /dev/null
+++ b/changelog.d/8864.misc
@@ -0,0 +1 @@
+Remove unused `FakeResponse` class from unit tests.
diff --git a/changelog.d/8865.bugfix b/changelog.d/8865.bugfix
new file mode 100644
index 0000000000..a1e625f552
--- /dev/null
+++ b/changelog.d/8865.bugfix
@@ -0,0 +1 @@
+Add additional validation to pusher URLs to be compliant with the specification.
diff --git a/changelog.d/8867.bugfix b/changelog.d/8867.bugfix
new file mode 100644
index 0000000000..f2414ff111
--- /dev/null
+++ b/changelog.d/8867.bugfix
@@ -0,0 +1 @@
+Fix the error code that is returned when a user tries to register on a homeserver on which new-user registration has been disabled.
diff --git a/changelog.d/8872.bugfix b/changelog.d/8872.bugfix
new file mode 100644
index 0000000000..ed00b70a0f
--- /dev/null
+++ b/changelog.d/8872.bugfix
@@ -0,0 +1 @@
+Fix a bug where `PUT /_synapse/admin/v2/users/<user_id>` failed to create a new user when `avatar_url` is specified. Bug introduced in Synapse v1.9.0.
diff --git a/changelog.d/8873.doc b/changelog.d/8873.doc
new file mode 100644
index 0000000000..0c2a043bd1
--- /dev/null
+++ b/changelog.d/8873.doc
@@ -0,0 +1 @@
+Fix an error in the documentation for the SAML username mapping provider.
diff --git a/changelog.d/8874.feature b/changelog.d/8874.feature
new file mode 100644
index 0000000000..720665ecac
--- /dev/null
+++ b/changelog.d/8874.feature
@@ -0,0 +1 @@
+Improve the error messages printed as a result of configuration problems for extension modules.
diff --git a/changelog.d/8879.misc b/changelog.d/8879.misc
new file mode 100644
index 0000000000..6f9516b314
--- /dev/null
+++ b/changelog.d/8879.misc
@@ -0,0 +1 @@
+Pass `room_id` to `get_auth_chain_difference`.
diff --git a/changelog.d/8880.misc b/changelog.d/8880.misc
new file mode 100644
index 0000000000..4ff0b94b94
--- /dev/null
+++ b/changelog.d/8880.misc
@@ -0,0 +1 @@
+Add type hints to push module.
diff --git a/changelog.d/8881.misc b/changelog.d/8881.misc
new file mode 100644
index 0000000000..07d3f30fb2
--- /dev/null
+++ b/changelog.d/8881.misc
@@ -0,0 +1 @@
+Simplify logic for handling user-interactive-auth via single-sign-on servers.
diff --git a/changelog.d/8882.misc b/changelog.d/8882.misc
new file mode 100644
index 0000000000..4ff0b94b94
--- /dev/null
+++ b/changelog.d/8882.misc
@@ -0,0 +1 @@
+Add type hints to push module.
diff --git a/changelog.d/8883.bugfix b/changelog.d/8883.bugfix
new file mode 100644
index 0000000000..6137fc5b2b
--- /dev/null
+++ b/changelog.d/8883.bugfix
@@ -0,0 +1 @@
+Fix a 500 error when attempting to preview an empty HTML file.
diff --git a/changelog.d/8887.feature b/changelog.d/8887.feature
new file mode 100644
index 0000000000..729eb1f1ea
--- /dev/null
+++ b/changelog.d/8887.feature
@@ -0,0 +1 @@
+Add `X-Robots-Tag` header to stop web crawlers from indexing media.
diff --git a/changelog.d/8891.doc b/changelog.d/8891.doc
new file mode 100644
index 0000000000..c3947fe7c2
--- /dev/null
+++ b/changelog.d/8891.doc
@@ -0,0 +1 @@
+Clarify comments around template directories in `sample_config.yaml`.
diff --git a/contrib/prometheus/synapse-v2.rules b/contrib/prometheus/synapse-v2.rules
index 6ccca2daaf..7e405bf7f0 100644
--- a/contrib/prometheus/synapse-v2.rules
+++ b/contrib/prometheus/synapse-v2.rules
@@ -58,3 +58,21 @@ groups:
     labels:
       type: "PDU"
     expr: 'synapse_federation_transaction_queue_pending_pdus + 0'
+
+  - record: synapse_storage_events_persisted_by_source_type
+    expr: sum without(type, origin_type, origin_entity) (synapse_storage_events_persisted_events_sep{origin_type="remote"})
+    labels:
+      type: remote
+  - record: synapse_storage_events_persisted_by_source_type
+    expr: sum without(type, origin_type, origin_entity) (synapse_storage_events_persisted_events_sep{origin_entity="*client*",origin_type="local"})
+    labels:
+      type: local
+  - record: synapse_storage_events_persisted_by_source_type
+    expr: sum without(type, origin_type, origin_entity) (synapse_storage_events_persisted_events_sep{origin_entity!="*client*",origin_type="local"})
+    labels:
+      type: bridges
+  - record: synapse_storage_events_persisted_by_event_type
+    expr: sum without(origin_entity, origin_type) (synapse_storage_events_persisted_events_sep)
+  - record: synapse_storage_events_persisted_by_origin
+    expr: sum without(type) (synapse_storage_events_persisted_events_sep)
+
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index 394eb9a3ff..68c8f4f0e2 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -642,17 +642,19 @@ acme:
 #  - nyc.example.com
 #  - syd.example.com
 
-# Prevent federation requests from being sent to the following
-# blacklist IP address CIDR ranges. If this option is not specified, or
-# specified with an empty list, no ip range blacklist will be enforced.
+# Prevent outgoing requests from being sent to the following blacklisted IP address
+# CIDR ranges. If this option is not specified, or specified with an empty list,
+# no IP range blacklist will be enforced.
 #
-# As of Synapse v1.4.0 this option also affects any outbound requests to identity
-# servers provided by user input.
+# The blacklist applies to the outbound requests for federation, identity servers,
+# push servers, and for checking key validitity for third-party invite events.
 #
 # (0.0.0.0 and :: are always blacklisted, whether or not they are explicitly
 # listed here, since they correspond to unroutable addresses.)
 #
-federation_ip_range_blacklist:
+# This option replaces federation_ip_range_blacklist in Synapse v1.24.0.
+#
+ip_range_blacklist:
   - '127.0.0.0/8'
   - '10.0.0.0/8'
   - '172.16.0.0/12'
@@ -1877,11 +1879,8 @@ sso:
     #  - https://my.custom.client/
 
     # Directory in which Synapse will try to find the template files below.
-    # If not set, default templates from within the Synapse package will be used.
-    #
-    # DO NOT UNCOMMENT THIS SETTING unless you want to customise the templates.
-    # If you *do* uncomment it, you will need to make sure that all the templates
-    # below are in the directory.
+    # If not set, or the files named below are not found within the template
+    # directory, default templates from within the Synapse package will be used.
     #
     # Synapse will look for the following templates in this directory:
     #
@@ -2111,9 +2110,8 @@ email:
   #validation_token_lifetime: 15m
 
   # Directory in which Synapse will try to find the template files below.
-  # If not set, default templates from within the Synapse package will be used.
-  #
-  # Do not uncomment this setting unless you want to customise the templates.
+  # If not set, or the files named below are not found within the template
+  # directory, default templates from within the Synapse package will be used.
   #
   # Synapse will look for the following templates in this directory:
   #
@@ -2587,6 +2585,13 @@ opentracing:
 #
 #run_background_tasks_on: worker1
 
+# A shared secret used by the replication APIs to authenticate HTTP requests
+# from workers.
+#
+# By default this is unused and traffic is not authenticated.
+#
+#worker_replication_secret: ""
+
 
 # Configuration for Redis when using workers. This *must* be enabled when
 # using workers (unless using old style direct TCP configuration).
diff --git a/docs/sso_mapping_providers.md b/docs/sso_mapping_providers.md
index ab2a648910..7714b1d844 100644
--- a/docs/sso_mapping_providers.md
+++ b/docs/sso_mapping_providers.md
@@ -116,11 +116,13 @@ comment these options out and use those specified by the module instead.
 
 A custom mapping provider must specify the following methods:
 
-* `__init__(self, parsed_config)`
+* `__init__(self, parsed_config, module_api)`
    - Arguments:
      - `parsed_config` - A configuration object that is the return value of the
        `parse_config` method. You should set any configuration options needed by
        the module here.
+     - `module_api` - a `synapse.module_api.ModuleApi` object which provides the
+       stable API available for extension modules.
 * `parse_config(config)`
     - This method should have the `@staticmethod` decoration.
     - Arguments:
diff --git a/docs/workers.md b/docs/workers.md
index c53d1bd2ff..efe97af31a 100644
--- a/docs/workers.md
+++ b/docs/workers.md
@@ -89,7 +89,8 @@ shared configuration file.
 Normally, only a couple of changes are needed to make an existing configuration
 file suitable for use with workers. First, you need to enable an "HTTP replication
 listener" for the main process; and secondly, you need to enable redis-based
-replication. For example:
+replication. Optionally, a shared secret can be used to authenticate HTTP
+traffic between workers. For example:
 
 
 ```yaml
@@ -103,6 +104,9 @@ listeners:
     resources:
      - names: [replication]
 
+# Add a random shared secret to authenticate traffic.
+worker_replication_secret: ""
+
 redis:
     enabled: true
 ```
diff --git a/mypy.ini b/mypy.ini
index 3c8d303064..12408b8d95 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -43,6 +43,7 @@ files =
   synapse/handlers/room_member.py,
   synapse/handlers/room_member_worker.py,
   synapse/handlers/saml_handler.py,
+  synapse/handlers/sso.py,
   synapse/handlers/sync.py,
   synapse/handlers/ui_auth,
   synapse/http/client.py,
@@ -55,6 +56,10 @@ files =
   synapse/metrics,
   synapse/module_api,
   synapse/notifier.py,
+  synapse/push/emailpusher.py,
+  synapse/push/httppusher.py,
+  synapse/push/mailer.py,
+  synapse/push/pusher.py,
   synapse/push/pusherpool.py,
   synapse/push/push_rule_evaluator.py,
   synapse/replication,
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index 1b511890aa..aa12c74358 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -266,7 +266,6 @@ class GenericWorkerPresence(BasePresenceHandler):
         super().__init__(hs)
         self.hs = hs
         self.is_mine_id = hs.is_mine_id
-        self.http_client = hs.get_simple_http_client()
 
         self._presence_enabled = hs.config.use_presence
 
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 2b5465417f..bbb7407838 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -19,7 +19,7 @@ import gc
 import logging
 import os
 import sys
-from typing import Iterable
+from typing import Iterable, Iterator
 
 from twisted.application import service
 from twisted.internet import defer, reactor
@@ -90,7 +90,7 @@ class SynapseHomeServer(HomeServer):
         tls = listener_config.tls
         site_tag = listener_config.http_options.tag
         if site_tag is None:
-            site_tag = port
+            site_tag = str(port)
 
         # We always include a health resource.
         resources = {"/health": HealthResource()}
@@ -107,7 +107,10 @@ class SynapseHomeServer(HomeServer):
         logger.debug("Configuring additional resources: %r", additional_resources)
         module_api = self.get_module_api()
         for path, resmodule in additional_resources.items():
-            handler_cls, config = load_module(resmodule)
+            handler_cls, config = load_module(
+                resmodule,
+                ("listeners", site_tag, "additional_resources", "<%s>" % (path,)),
+            )
             handler = handler_cls(config, module_api)
             if IResource.providedBy(handler):
                 resource = handler
@@ -342,7 +345,10 @@ def setup(config_options):
             "Synapse Homeserver", config_options
         )
     except ConfigError as e:
-        sys.stderr.write("\nERROR: %s\n" % (e,))
+        sys.stderr.write("\n")
+        for f in format_config_error(e):
+            sys.stderr.write(f)
+        sys.stderr.write("\n")
         sys.exit(1)
 
     if not config:
@@ -445,6 +451,38 @@ def setup(config_options):
     return hs
 
 
+def format_config_error(e: ConfigError) -> Iterator[str]:
+    """
+    Formats a config error neatly
+
+    The idea is to format the immediate error, plus the "causes" of those errors,
+    hopefully in a way that makes sense to the user. For example:
+
+        Error in configuration at 'oidc_config.user_mapping_provider.config.display_name_template':
+          Failed to parse config for module 'JinjaOidcMappingProvider':
+            invalid jinja template:
+              unexpected end of template, expected 'end of print statement'.
+
+    Args:
+        e: the error to be formatted
+
+    Returns: An iterator which yields string fragments to be formatted
+    """
+    yield "Error in configuration"
+
+    if e.path:
+        yield " at '%s'" % (".".join(e.path),)
+
+    yield ":\n  %s" % (e.msg,)
+
+    e = e.__cause__
+    indent = 1
+    while e:
+        indent += 1
+        yield ":\n%s%s" % ("  " * indent, str(e))
+        e = e.__cause__
+
+
 class SynapseService(service.Service):
     """
     A twisted Service class that will start synapse. Used to run synapse
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index 85f65da4d9..2931a88207 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -23,7 +23,7 @@ import urllib.parse
 from collections import OrderedDict
 from hashlib import sha256
 from textwrap import dedent
-from typing import Any, Callable, List, MutableMapping, Optional
+from typing import Any, Callable, Iterable, List, MutableMapping, Optional
 
 import attr
 import jinja2
@@ -32,7 +32,17 @@ import yaml
 
 
 class ConfigError(Exception):
-    pass
+    """Represents a problem parsing the configuration
+
+    Args:
+        msg:  A textual description of the error.
+        path: Where appropriate, an indication of where in the configuration
+           the problem lies.
+    """
+
+    def __init__(self, msg: str, path: Optional[Iterable[str]] = None):
+        self.msg = msg
+        self.path = path
 
 
 # We split these messages out to allow packages to override with package
diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi
index b8faafa9bd..ed26e2fb60 100644
--- a/synapse/config/_base.pyi
+++ b/synapse/config/_base.pyi
@@ -1,4 +1,4 @@
-from typing import Any, List, Optional
+from typing import Any, Iterable, List, Optional
 
 from synapse.config import (
     api,
@@ -35,7 +35,10 @@ from synapse.config import (
     workers,
 )
 
-class ConfigError(Exception): ...
+class ConfigError(Exception):
+    def __init__(self, msg: str, path: Optional[Iterable[str]] = None):
+        self.msg = msg
+        self.path = path
 
 MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS: str
 MISSING_REPORT_STATS_SPIEL: str
diff --git a/synapse/config/_util.py b/synapse/config/_util.py
index c74969a977..1bbe83c317 100644
--- a/synapse/config/_util.py
+++ b/synapse/config/_util.py
@@ -38,14 +38,27 @@ def validate_config(
     try:
         jsonschema.validate(config, json_schema)
     except jsonschema.ValidationError as e:
-        # copy `config_path` before modifying it.
-        path = list(config_path)
-        for p in list(e.path):
-            if isinstance(p, int):
-                path.append("<item %i>" % p)
-            else:
-                path.append(str(p))
-
-        raise ConfigError(
-            "Unable to parse configuration: %s at %s" % (e.message, ".".join(path))
-        )
+        raise json_error_to_config_error(e, config_path)
+
+
+def json_error_to_config_error(
+    e: jsonschema.ValidationError, config_path: Iterable[str]
+) -> ConfigError:
+    """Converts a json validation error to a user-readable ConfigError
+
+    Args:
+        e: the exception to be converted
+        config_path: the path within the config file. This will be used as a basis
+           for the error message.
+
+    Returns:
+        a ConfigError
+    """
+    # copy `config_path` before modifying it.
+    path = list(config_path)
+    for p in list(e.path):
+        if isinstance(p, int):
+            path.append("<item %i>" % p)
+        else:
+            path.append(str(p))
+    return ConfigError(e.message, path)
diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py
index cceffbfee2..7c8b64d84b 100644
--- a/synapse/config/emailconfig.py
+++ b/synapse/config/emailconfig.py
@@ -390,9 +390,8 @@ class EmailConfig(Config):
           #validation_token_lifetime: 15m
 
           # Directory in which Synapse will try to find the template files below.
-          # If not set, default templates from within the Synapse package will be used.
-          #
-          # Do not uncomment this setting unless you want to customise the templates.
+          # If not set, or the files named below are not found within the template
+          # directory, default templates from within the Synapse package will be used.
           #
           # Synapse will look for the following templates in this directory:
           #
diff --git a/synapse/config/federation.py b/synapse/config/federation.py
index ffd8fca54e..27ccf61c3c 100644
--- a/synapse/config/federation.py
+++ b/synapse/config/federation.py
@@ -36,22 +36,30 @@ class FederationConfig(Config):
             for domain in federation_domain_whitelist:
                 self.federation_domain_whitelist[domain] = True
 
-        self.federation_ip_range_blacklist = config.get(
-            "federation_ip_range_blacklist", []
-        )
+        ip_range_blacklist = config.get("ip_range_blacklist", [])
 
         # Attempt to create an IPSet from the given ranges
         try:
-            self.federation_ip_range_blacklist = IPSet(
-                self.federation_ip_range_blacklist
-            )
-
-            # Always blacklist 0.0.0.0, ::
-            self.federation_ip_range_blacklist.update(["0.0.0.0", "::"])
+            self.ip_range_blacklist = IPSet(ip_range_blacklist)
+        except Exception as e:
+            raise ConfigError("Invalid range(s) provided in ip_range_blacklist: %s" % e)
+        # Always blacklist 0.0.0.0, ::
+        self.ip_range_blacklist.update(["0.0.0.0", "::"])
+
+        # The federation_ip_range_blacklist is used for backwards-compatibility
+        # and only applies to federation and identity servers. If it is not given,
+        # default to ip_range_blacklist.
+        federation_ip_range_blacklist = config.get(
+            "federation_ip_range_blacklist", ip_range_blacklist
+        )
+        try:
+            self.federation_ip_range_blacklist = IPSet(federation_ip_range_blacklist)
         except Exception as e:
             raise ConfigError(
                 "Invalid range(s) provided in federation_ip_range_blacklist: %s" % e
             )
+        # Always blacklist 0.0.0.0, ::
+        self.federation_ip_range_blacklist.update(["0.0.0.0", "::"])
 
         federation_metrics_domains = config.get("federation_metrics_domains") or []
         validate_config(
@@ -76,17 +84,19 @@ class FederationConfig(Config):
         #  - nyc.example.com
         #  - syd.example.com
 
-        # Prevent federation requests from being sent to the following
-        # blacklist IP address CIDR ranges. If this option is not specified, or
-        # specified with an empty list, no ip range blacklist will be enforced.
+        # Prevent outgoing requests from being sent to the following blacklisted IP address
+        # CIDR ranges. If this option is not specified, or specified with an empty list,
+        # no IP range blacklist will be enforced.
         #
-        # As of Synapse v1.4.0 this option also affects any outbound requests to identity
-        # servers provided by user input.
+        # The blacklist applies to the outbound requests for federation, identity servers,
+        # push servers, and for checking key validitity for third-party invite events.
         #
         # (0.0.0.0 and :: are always blacklisted, whether or not they are explicitly
         # listed here, since they correspond to unroutable addresses.)
         #
-        federation_ip_range_blacklist:
+        # This option replaces federation_ip_range_blacklist in Synapse v1.24.0.
+        #
+        ip_range_blacklist:
           - '127.0.0.0/8'
           - '10.0.0.0/8'
           - '172.16.0.0/12'
diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py
index 69d188341c..1abf8ed405 100644
--- a/synapse/config/oidc_config.py
+++ b/synapse/config/oidc_config.py
@@ -66,7 +66,7 @@ class OIDCConfig(Config):
         (
             self.oidc_user_mapping_provider_class,
             self.oidc_user_mapping_provider_config,
-        ) = load_module(ump_config)
+        ) = load_module(ump_config, ("oidc_config", "user_mapping_provider"))
 
         # Ensure loaded user mapping module has defined all necessary methods
         required_methods = [
diff --git a/synapse/config/password_auth_providers.py b/synapse/config/password_auth_providers.py
index 4fda8ae987..85d07c4f8f 100644
--- a/synapse/config/password_auth_providers.py
+++ b/synapse/config/password_auth_providers.py
@@ -36,7 +36,7 @@ class PasswordAuthProviderConfig(Config):
             providers.append({"module": LDAP_PROVIDER, "config": ldap_config})
 
         providers.extend(config.get("password_providers") or [])
-        for provider in providers:
+        for i, provider in enumerate(providers):
             mod_name = provider["module"]
 
             # This is for backwards compat when the ldap auth provider resided
@@ -45,7 +45,8 @@ class PasswordAuthProviderConfig(Config):
                 mod_name = LDAP_PROVIDER
 
             (provider_class, provider_config) = load_module(
-                {"module": mod_name, "config": provider["config"]}
+                {"module": mod_name, "config": provider["config"]},
+                ("password_providers", "<item %i>" % i),
             )
 
             self.password_providers.append((provider_class, provider_config))
diff --git a/synapse/config/repository.py b/synapse/config/repository.py
index ba1e9d2361..17ce9145ef 100644
--- a/synapse/config/repository.py
+++ b/synapse/config/repository.py
@@ -142,7 +142,7 @@ class ContentRepositoryConfig(Config):
         # them to be started.
         self.media_storage_providers = []  # type: List[tuple]
 
-        for provider_config in storage_providers:
+        for i, provider_config in enumerate(storage_providers):
             # We special case the module "file_system" so as not to need to
             # expose FileStorageProviderBackend
             if provider_config["module"] == "file_system":
@@ -151,7 +151,9 @@ class ContentRepositoryConfig(Config):
                     ".FileStorageProviderBackend"
                 )
 
-            provider_class, parsed_config = load_module(provider_config)
+            provider_class, parsed_config = load_module(
+                provider_config, ("media_storage_providers", "<item %i>" % i)
+            )
 
             wrapper_config = MediaStorageProviderConfig(
                 provider_config.get("store_local", False),
diff --git a/synapse/config/room_directory.py b/synapse/config/room_directory.py
index 92e1b67528..9a3e1c3e7d 100644
--- a/synapse/config/room_directory.py
+++ b/synapse/config/room_directory.py
@@ -180,7 +180,7 @@ class _RoomDirectoryRule:
             self._alias_regex = glob_to_regex(alias)
             self._room_id_regex = glob_to_regex(room_id)
         except Exception as e:
-            raise ConfigError("Failed to parse glob into regex: %s", e)
+            raise ConfigError("Failed to parse glob into regex") from e
 
     def matches(self, user_id, room_id, aliases):
         """Tests if this rule matches the given user_id, room_id and aliases.
diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py
index c1b8e98ae0..7b97d4f114 100644
--- a/synapse/config/saml2_config.py
+++ b/synapse/config/saml2_config.py
@@ -125,7 +125,7 @@ class SAML2Config(Config):
         (
             self.saml2_user_mapping_provider_class,
             self.saml2_user_mapping_provider_config,
-        ) = load_module(ump_dict)
+        ) = load_module(ump_dict, ("saml2_config", "user_mapping_provider"))
 
         # Ensure loaded user mapping module has defined all necessary methods
         # Note parse_config() is already checked during the call to load_module
diff --git a/synapse/config/spam_checker.py b/synapse/config/spam_checker.py
index 3d067d29db..3d05abc158 100644
--- a/synapse/config/spam_checker.py
+++ b/synapse/config/spam_checker.py
@@ -33,13 +33,14 @@ class SpamCheckerConfig(Config):
             # spam checker, and thus was simply a dictionary with module
             # and config keys. Support this old behaviour by checking
             # to see if the option resolves to a dictionary
-            self.spam_checkers.append(load_module(spam_checkers))
+            self.spam_checkers.append(load_module(spam_checkers, ("spam_checker",)))
         elif isinstance(spam_checkers, list):
-            for spam_checker in spam_checkers:
+            for i, spam_checker in enumerate(spam_checkers):
+                config_path = ("spam_checker", "<item %i>" % i)
                 if not isinstance(spam_checker, dict):
-                    raise ConfigError("spam_checker syntax is incorrect")
+                    raise ConfigError("expected a mapping", config_path)
 
-                self.spam_checkers.append(load_module(spam_checker))
+                self.spam_checkers.append(load_module(spam_checker, config_path))
         else:
             raise ConfigError("spam_checker syntax is incorrect")
 
diff --git a/synapse/config/sso.py b/synapse/config/sso.py
index 4427676167..93bbd40937 100644
--- a/synapse/config/sso.py
+++ b/synapse/config/sso.py
@@ -93,11 +93,8 @@ class SSOConfig(Config):
             #  - https://my.custom.client/
 
             # Directory in which Synapse will try to find the template files below.
-            # If not set, default templates from within the Synapse package will be used.
-            #
-            # DO NOT UNCOMMENT THIS SETTING unless you want to customise the templates.
-            # If you *do* uncomment it, you will need to make sure that all the templates
-            # below are in the directory.
+            # If not set, or the files named below are not found within the template
+            # directory, default templates from within the Synapse package will be used.
             #
             # Synapse will look for the following templates in this directory:
             #
diff --git a/synapse/config/third_party_event_rules.py b/synapse/config/third_party_event_rules.py
index 10a99c792e..c04e1c4e07 100644
--- a/synapse/config/third_party_event_rules.py
+++ b/synapse/config/third_party_event_rules.py
@@ -26,7 +26,9 @@ class ThirdPartyRulesConfig(Config):
 
         provider = config.get("third_party_event_rules", None)
         if provider is not None:
-            self.third_party_event_rules = load_module(provider)
+            self.third_party_event_rules = load_module(
+                provider, ("third_party_event_rules",)
+            )
 
     def generate_config_section(self, **kwargs):
         return """\
diff --git a/synapse/config/workers.py b/synapse/config/workers.py
index 57ab097eba..7ca9efec52 100644
--- a/synapse/config/workers.py
+++ b/synapse/config/workers.py
@@ -85,6 +85,9 @@ class WorkerConfig(Config):
         # The port on the main synapse for HTTP replication endpoint
         self.worker_replication_http_port = config.get("worker_replication_http_port")
 
+        # The shared secret used for authentication when connecting to the main synapse.
+        self.worker_replication_secret = config.get("worker_replication_secret", None)
+
         self.worker_name = config.get("worker_name", self.worker_app)
 
         self.worker_main_http_uri = config.get("worker_main_http_uri", None)
@@ -185,6 +188,13 @@ class WorkerConfig(Config):
         # data). If not provided this defaults to the main process.
         #
         #run_background_tasks_on: worker1
+
+        # A shared secret used by the replication APIs to authenticate HTTP requests
+        # from workers.
+        #
+        # By default this is unused and traffic is not authenticated.
+        #
+        #worker_replication_secret: ""
         """
 
     def read_arguments(self, args):
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index c04ad77cf9..f23eacc0d7 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -578,7 +578,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
     def __init__(self, hs):
         super().__init__(hs)
         self.clock = hs.get_clock()
-        self.client = hs.get_http_client()
+        self.client = hs.get_federation_http_client()
         self.key_servers = self.config.key_servers
 
     async def get_keys(self, keys_to_fetch):
@@ -748,7 +748,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
     def __init__(self, hs):
         super().__init__(hs)
         self.clock = hs.get_clock()
-        self.client = hs.get_http_client()
+        self.client = hs.get_federation_http_client()
 
     async def get_keys(self, keys_to_fetch):
         """
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 4b6ab470d0..35e345ce70 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -845,7 +845,6 @@ class FederationHandlerRegistry:
 
     def __init__(self, hs: "HomeServer"):
         self.config = hs.config
-        self.http_client = hs.get_simple_http_client()
         self.clock = hs.get_clock()
         self._instance_name = hs.get_instance_name()
 
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 17a10f622e..abe9168c78 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -35,7 +35,7 @@ class TransportLayerClient:
 
     def __init__(self, hs):
         self.server_name = hs.hostname
-        self.client = hs.get_http_client()
+        self.client = hs.get_federation_http_client()
 
     @log_function
     def get_room_state_ids(self, destination, room_id, event_id):
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index b53e7a20ec..434718ddfc 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -1462,7 +1462,7 @@ def register_servlets(hs, resource, authenticator, ratelimiter, servlet_groups=N
 
     Args:
         hs (synapse.server.HomeServer): homeserver
-        resource (TransportLayerServer): resource class to register to
+        resource (JsonResource): resource class to register to
         authenticator (Authenticator): authenticator to use
         ratelimiter (util.ratelimitutils.FederationRateLimiter): ratelimiter to use
         servlet_groups (list[str], optional): List of servlet groups to register.
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index bb81c0e81d..d29b066a56 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -32,6 +32,10 @@ logger = logging.getLogger(__name__)
 class BaseHandler:
     """
     Common base class for the event handlers.
+
+    Deprecated: new code should not use this. Instead, Handler classes should define the
+    fields they actually need. The utility methods should either be factored out to
+    standalone helper functions, or to different Handler classes.
     """
 
     def __init__(self, hs: "HomeServer"):
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index c7dc07008a..afae6d3272 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -36,6 +36,8 @@ import attr
 import bcrypt
 import pymacaroons
 
+from twisted.web.http import Request
+
 from synapse.api.constants import LoginType
 from synapse.api.errors import (
     AuthError,
@@ -193,9 +195,7 @@ class AuthHandler(BaseHandler):
         self.hs = hs  # FIXME better possibility to access registrationHandler later?
         self.macaroon_gen = hs.get_macaroon_generator()
         self._password_enabled = hs.config.password_enabled
-        self._sso_enabled = (
-            hs.config.cas_enabled or hs.config.saml2_enabled or hs.config.oidc_enabled
-        )
+        self._password_localdb_enabled = hs.config.password_localdb_enabled
 
         # we keep this as a list despite the O(N^2) implication so that we can
         # keep PASSWORD first and avoid confusing clients which pick the first
@@ -205,7 +205,7 @@ class AuthHandler(BaseHandler):
 
         # start out by assuming PASSWORD is enabled; we will remove it later if not.
         login_types = []
-        if hs.config.password_localdb_enabled:
+        if self._password_localdb_enabled:
             login_types.append(LoginType.PASSWORD)
 
         for provider in self.password_providers:
@@ -219,14 +219,6 @@ class AuthHandler(BaseHandler):
 
         self._supported_login_types = login_types
 
-        # Login types and UI Auth types have a heavy overlap, but are not
-        # necessarily identical. Login types have SSO (and other login types)
-        # added in the rest layer, see synapse.rest.client.v1.login.LoginRestServerlet.on_GET.
-        ui_auth_types = login_types.copy()
-        if self._sso_enabled:
-            ui_auth_types.append(LoginType.SSO)
-        self._supported_ui_auth_types = ui_auth_types
-
         # Ratelimiter for failed auth during UIA. Uses same ratelimit config
         # as per `rc_login.failed_attempts`.
         self._failed_uia_attempts_ratelimiter = Ratelimiter(
@@ -339,7 +331,10 @@ class AuthHandler(BaseHandler):
         self._failed_uia_attempts_ratelimiter.ratelimit(user_id, update=False)
 
         # build a list of supported flows
-        flows = [[login_type] for login_type in self._supported_ui_auth_types]
+        supported_ui_auth_types = await self._get_available_ui_auth_types(
+            requester.user
+        )
+        flows = [[login_type] for login_type in supported_ui_auth_types]
 
         try:
             result, params, session_id = await self.check_ui_auth(
@@ -351,7 +346,7 @@ class AuthHandler(BaseHandler):
             raise
 
         # find the completed login type
-        for login_type in self._supported_ui_auth_types:
+        for login_type in supported_ui_auth_types:
             if login_type not in result:
                 continue
 
@@ -367,6 +362,41 @@ class AuthHandler(BaseHandler):
 
         return params, session_id
 
+    async def _get_available_ui_auth_types(self, user: UserID) -> Iterable[str]:
+        """Get a list of the authentication types this user can use
+        """
+
+        ui_auth_types = set()
+
+        # if the HS supports password auth, and the user has a non-null password, we
+        # support password auth
+        if self._password_localdb_enabled and self._password_enabled:
+            lookupres = await self._find_user_id_and_pwd_hash(user.to_string())
+            if lookupres:
+                _, password_hash = lookupres
+                if password_hash:
+                    ui_auth_types.add(LoginType.PASSWORD)
+
+        # also allow auth from password providers
+        for provider in self.password_providers:
+            for t in provider.get_supported_login_types().keys():
+                if t == LoginType.PASSWORD and not self._password_enabled:
+                    continue
+                ui_auth_types.add(t)
+
+        # if sso is enabled, allow the user to log in via SSO iff they have a mapping
+        # from sso to mxid.
+        if self.hs.config.saml2.saml2_enabled or self.hs.config.oidc.oidc_enabled:
+            if await self.store.get_external_ids_by_user(user.to_string()):
+                ui_auth_types.add(LoginType.SSO)
+
+        # Our CAS impl does not (yet) correctly register users in user_external_ids,
+        # so always offer that if it's available.
+        if self.hs.config.cas.cas_enabled:
+            ui_auth_types.add(LoginType.SSO)
+
+        return ui_auth_types
+
     def get_enabled_auth_types(self):
         """Return the enabled user-interactive authentication types
 
@@ -1029,7 +1059,7 @@ class AuthHandler(BaseHandler):
             if result:
                 return result
 
-        if login_type == LoginType.PASSWORD and self.hs.config.password_localdb_enabled:
+        if login_type == LoginType.PASSWORD and self._password_localdb_enabled:
             known_login_type = True
 
             # we've already checked that there is a (valid) password field
@@ -1303,15 +1333,14 @@ class AuthHandler(BaseHandler):
         )
 
     async def complete_sso_ui_auth(
-        self, registered_user_id: str, session_id: str, request: SynapseRequest,
+        self, registered_user_id: str, session_id: str, request: Request,
     ):
         """Having figured out a mxid for this user, complete the HTTP request
 
         Args:
             registered_user_id: The registered user ID to complete SSO login for.
+            session_id: The ID of the user-interactive auth session.
             request: The request to complete.
-            client_redirect_url: The URL to which to redirect the user at the end of the
-                process.
         """
         # Mark the stage of the authentication as successful.
         # Save the user who authenticated with SSO, this will be used to ensure
@@ -1327,7 +1356,7 @@ class AuthHandler(BaseHandler):
     async def complete_sso_login(
         self,
         registered_user_id: str,
-        request: SynapseRequest,
+        request: Request,
         client_redirect_url: str,
         extra_attributes: Optional[JsonDict] = None,
     ):
@@ -1355,7 +1384,7 @@ class AuthHandler(BaseHandler):
     def _complete_sso_login(
         self,
         registered_user_id: str,
-        request: SynapseRequest,
+        request: Request,
         client_redirect_url: str,
         extra_attributes: Optional[JsonDict] = None,
     ):
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index b9799090f7..df82e60b33 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -140,7 +140,7 @@ class FederationHandler(BaseHandler):
         self._message_handler = hs.get_message_handler()
         self._server_notices_mxid = hs.config.server_notices_mxid
         self.config = hs.config
-        self.http_client = hs.get_simple_http_client()
+        self.http_client = hs.get_proxied_blacklisted_http_client()
         self._instance_name = hs.get_instance_name()
         self._replication = hs.get_replication_data_handler()
 
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 9b3c6b4551..7301c24710 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -46,13 +46,13 @@ class IdentityHandler(BaseHandler):
     def __init__(self, hs):
         super().__init__(hs)
 
+        # An HTTP client for contacting trusted URLs.
         self.http_client = SimpleHttpClient(hs)
-        # We create a blacklisting instance of SimpleHttpClient for contacting identity
-        # servers specified by clients
+        # An HTTP client for contacting identity servers specified by clients.
         self.blacklisting_http_client = SimpleHttpClient(
             hs, ip_blacklist=hs.config.federation_ip_range_blacklist
         )
-        self.federation_http_client = hs.get_http_client()
+        self.federation_http_client = hs.get_federation_http_client()
         self.hs = hs
 
     async def threepid_from_creds(
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index c605f7082a..f626117f76 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -674,6 +674,21 @@ class OidcHandler(BaseHandler):
                 self._sso_handler.render_error(request, "invalid_token", str(e))
                 return
 
+        # first check if we're doing a UIA
+        if ui_auth_session_id:
+            try:
+                remote_user_id = self._remote_id_from_userinfo(userinfo)
+            except Exception as e:
+                logger.exception("Could not extract remote user id")
+                self._sso_handler.render_error(request, "mapping_error", str(e))
+                return
+
+            return await self._sso_handler.complete_sso_ui_auth_request(
+                self._auth_provider_id, remote_user_id, ui_auth_session_id, request
+            )
+
+        # otherwise, it's a login
+
         # Pull out the user-agent and IP from the request.
         user_agent = request.get_user_agent("")
         ip_address = self.hs.get_ip_from_request(request)
@@ -698,14 +713,9 @@ class OidcHandler(BaseHandler):
             extra_attributes = await get_extra_attributes(userinfo, token)
 
         # and finally complete the login
-        if ui_auth_session_id:
-            await self._auth_handler.complete_sso_ui_auth(
-                user_id, ui_auth_session_id, request
-            )
-        else:
-            await self._auth_handler.complete_sso_login(
-                user_id, request, client_redirect_url, extra_attributes
-            )
+        await self._auth_handler.complete_sso_login(
+            user_id, request, client_redirect_url, extra_attributes
+        )
 
     def _generate_oidc_session_token(
         self,
@@ -856,14 +866,11 @@ class OidcHandler(BaseHandler):
             The mxid of the user
         """
         try:
-            remote_user_id = self._user_mapping_provider.get_remote_user_id(userinfo)
+            remote_user_id = self._remote_id_from_userinfo(userinfo)
         except Exception as e:
             raise MappingException(
                 "Failed to extract subject from OIDC response: %s" % (e,)
             )
-        # Some OIDC providers use integer IDs, but Synapse expects external IDs
-        # to be strings.
-        remote_user_id = str(remote_user_id)
 
         # Older mapping providers don't accept the `failures` argument, so we
         # try and detect support.
@@ -933,6 +940,19 @@ class OidcHandler(BaseHandler):
             grandfather_existing_users,
         )
 
+    def _remote_id_from_userinfo(self, userinfo: UserInfo) -> str:
+        """Extract the unique remote id from an OIDC UserInfo block
+
+        Args:
+            userinfo: An object representing the user given by the OIDC provider
+        Returns:
+            remote user id
+        """
+        remote_user_id = self._user_mapping_provider.get_remote_user_id(userinfo)
+        # Some OIDC providers use integer IDs, but Synapse expects external IDs
+        # to be strings.
+        return str(remote_user_id)
+
 
 UserAttributeDict = TypedDict(
     "UserAttributeDict", {"localpart": str, "display_name": Optional[str]}
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index 76d4169fe2..5846f08609 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -183,6 +183,24 @@ class SamlHandler(BaseHandler):
             saml2_auth.in_response_to, None
         )
 
+        # first check if we're doing a UIA
+        if current_session and current_session.ui_auth_session_id:
+            try:
+                remote_user_id = self._remote_id_from_saml_response(saml2_auth, None)
+            except MappingException as e:
+                logger.exception("Failed to extract remote user id from SAML response")
+                self._sso_handler.render_error(request, "mapping_error", str(e))
+                return
+
+            return await self._sso_handler.complete_sso_ui_auth_request(
+                self._auth_provider_id,
+                remote_user_id,
+                current_session.ui_auth_session_id,
+                request,
+            )
+
+        # otherwise, we're handling a login request.
+
         # Ensure that the attributes of the logged in user meet the required
         # attributes.
         for requirement in self._saml2_attribute_requirements:
@@ -206,14 +224,7 @@ class SamlHandler(BaseHandler):
             self._sso_handler.render_error(request, "mapping_error", str(e))
             return
 
-        # Complete the interactive auth session or the login.
-        if current_session and current_session.ui_auth_session_id:
-            await self._auth_handler.complete_sso_ui_auth(
-                user_id, current_session.ui_auth_session_id, request
-            )
-
-        else:
-            await self._auth_handler.complete_sso_login(user_id, request, relay_state)
+        await self._auth_handler.complete_sso_login(user_id, request, relay_state)
 
     async def _map_saml_response_to_user(
         self,
@@ -239,16 +250,10 @@ class SamlHandler(BaseHandler):
             RedirectException: some mapping providers may raise this if they need
                 to redirect to an interstitial page.
         """
-
-        remote_user_id = self._user_mapping_provider.get_remote_user_id(
+        remote_user_id = self._remote_id_from_saml_response(
             saml2_auth, client_redirect_url
         )
 
-        if not remote_user_id:
-            raise MappingException(
-                "Failed to extract remote user id from SAML response"
-            )
-
         async def saml_response_to_remapped_user_attributes(
             failures: int,
         ) -> UserAttributes:
@@ -304,6 +309,35 @@ class SamlHandler(BaseHandler):
                 grandfather_existing_users,
             )
 
+    def _remote_id_from_saml_response(
+        self,
+        saml2_auth: saml2.response.AuthnResponse,
+        client_redirect_url: Optional[str],
+    ) -> str:
+        """Extract the unique remote id from a SAML2 AuthnResponse
+
+        Args:
+            saml2_auth: The parsed SAML2 response.
+            client_redirect_url: The redirect URL passed in by the client.
+        Returns:
+            remote user id
+
+        Raises:
+            MappingException if there was an error extracting the user id
+        """
+        # It's not obvious why we need to pass in the redirect URI to the mapping
+        # provider, but we do :/
+        remote_user_id = self._user_mapping_provider.get_remote_user_id(
+            saml2_auth, client_redirect_url
+        )
+
+        if not remote_user_id:
+            raise MappingException(
+                "Failed to extract remote user id from SAML response"
+            )
+
+        return remote_user_id
+
     def expire_sessions(self):
         expire_before = self.clock.time_msec() - self._saml2_session_lifetime
         to_expire = set()
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 47ad96f97e..e24767b921 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -17,8 +17,9 @@ from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional
 
 import attr
 
+from twisted.web.http import Request
+
 from synapse.api.errors import RedirectException
-from synapse.handlers._base import BaseHandler
 from synapse.http.server import respond_with_html
 from synapse.types import UserID, contains_invalid_mxid_characters
 
@@ -42,14 +43,16 @@ class UserAttributes:
     emails = attr.ib(type=List[str], default=attr.Factory(list))
 
 
-class SsoHandler(BaseHandler):
+class SsoHandler:
     # The number of attempts to ask the mapping provider for when generating an MXID.
     _MAP_USERNAME_RETRIES = 1000
 
     def __init__(self, hs: "HomeServer"):
-        super().__init__(hs)
+        self._store = hs.get_datastore()
+        self._server_name = hs.hostname
         self._registration_handler = hs.get_registration_handler()
         self._error_template = hs.config.sso_error_template
+        self._auth_handler = hs.get_auth_handler()
 
     def render_error(
         self, request, error: str, error_description: Optional[str] = None
@@ -95,7 +98,7 @@ class SsoHandler(BaseHandler):
         )
 
         # Check if we already have a mapping for this user.
-        previously_registered_user_id = await self.store.get_user_by_external_id(
+        previously_registered_user_id = await self._store.get_user_by_external_id(
             auth_provider_id, remote_user_id,
         )
 
@@ -181,7 +184,7 @@ class SsoHandler(BaseHandler):
             previously_registered_user_id = await grandfather_existing_users()
             if previously_registered_user_id:
                 # Future logins should also match this user ID.
-                await self.store.record_user_external_id(
+                await self._store.record_user_external_id(
                     auth_provider_id, remote_user_id, previously_registered_user_id
                 )
                 return previously_registered_user_id
@@ -214,8 +217,8 @@ class SsoHandler(BaseHandler):
                 )
 
             # Check if this mxid already exists
-            user_id = UserID(attributes.localpart, self.server_name).to_string()
-            if not await self.store.get_users_by_id_case_insensitive(user_id):
+            user_id = UserID(attributes.localpart, self._server_name).to_string()
+            if not await self._store.get_users_by_id_case_insensitive(user_id):
                 # This mxid is free
                 break
         else:
@@ -238,7 +241,47 @@ class SsoHandler(BaseHandler):
             user_agent_ips=[(user_agent, ip_address)],
         )
 
-        await self.store.record_user_external_id(
+        await self._store.record_user_external_id(
             auth_provider_id, remote_user_id, registered_user_id
         )
         return registered_user_id
+
+    async def complete_sso_ui_auth_request(
+        self,
+        auth_provider_id: str,
+        remote_user_id: str,
+        ui_auth_session_id: str,
+        request: Request,
+    ) -> None:
+        """
+        Given an SSO ID, retrieve the user ID for it and complete UIA.
+
+        Note that this requires that the user is mapped in the "user_external_ids"
+        table. This will be the case if they have ever logged in via SAML or OIDC in
+        recentish synapse versions, but may not be for older users.
+
+        Args:
+            auth_provider_id: A unique identifier for this SSO provider, e.g.
+                "oidc" or "saml".
+            remote_user_id: The unique identifier from the SSO provider.
+            ui_auth_session_id: The ID of the user-interactive auth session.
+            request: The request to complete.
+        """
+
+        user_id = await self.get_sso_user_by_remote_user_id(
+            auth_provider_id, remote_user_id,
+        )
+
+        if not user_id:
+            logger.warning(
+                "Remote user %s/%s has not previously logged in here: UIA will fail",
+                auth_provider_id,
+                remote_user_id,
+            )
+            # Let the UIA flow handle this the same as if they presented creds for a
+            # different user.
+            user_id = ""
+
+        await self._auth_handler.complete_sso_ui_auth(
+            user_id, ui_auth_session_id, request
+        )
diff --git a/synapse/http/client.py b/synapse/http/client.py
index e5b13593f2..df7730078f 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -125,7 +125,7 @@ def _make_scheduler(reactor):
     return _scheduler
 
 
-class IPBlacklistingResolver:
+class _IPBlacklistingResolver:
     """
     A proxy for reactor.nameResolver which only produces non-blacklisted IP
     addresses, preventing DNS rebinding attacks on URL preview.
@@ -199,6 +199,35 @@ class IPBlacklistingResolver:
         return r
 
 
+@implementer(IReactorPluggableNameResolver)
+class BlacklistingReactorWrapper:
+    """
+    A Reactor wrapper which will prevent DNS resolution to blacklisted IP
+    addresses, to prevent DNS rebinding.
+    """
+
+    def __init__(
+        self,
+        reactor: IReactorPluggableNameResolver,
+        ip_whitelist: Optional[IPSet],
+        ip_blacklist: IPSet,
+    ):
+        self._reactor = reactor
+
+        # We need to use a DNS resolver which filters out blacklisted IP
+        # addresses, to prevent DNS rebinding.
+        self._nameResolver = _IPBlacklistingResolver(
+            self._reactor, ip_whitelist, ip_blacklist
+        )
+
+    def __getattr__(self, attr: str) -> Any:
+        # Passthrough to the real reactor except for the DNS resolver.
+        if attr == "nameResolver":
+            return self._nameResolver
+        else:
+            return getattr(self._reactor, attr)
+
+
 class BlacklistingAgentWrapper(Agent):
     """
     An Agent wrapper which will prevent access to IP addresses being accessed
@@ -292,22 +321,11 @@ class SimpleHttpClient:
         self.user_agent = self.user_agent.encode("ascii")
 
         if self._ip_blacklist:
-            real_reactor = hs.get_reactor()
             # If we have an IP blacklist, we need to use a DNS resolver which
             # filters out blacklisted IP addresses, to prevent DNS rebinding.
-            nameResolver = IPBlacklistingResolver(
-                real_reactor, self._ip_whitelist, self._ip_blacklist
+            self.reactor = BlacklistingReactorWrapper(
+                hs.get_reactor(), self._ip_whitelist, self._ip_blacklist
             )
-
-            @implementer(IReactorPluggableNameResolver)
-            class Reactor:
-                def __getattr__(_self, attr):
-                    if attr == "nameResolver":
-                        return nameResolver
-                    else:
-                        return getattr(real_reactor, attr)
-
-            self.reactor = Reactor()
         else:
             self.reactor = hs.get_reactor()
 
diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
index e77f9587d0..3b756a7dc2 100644
--- a/synapse/http/federation/matrix_federation_agent.py
+++ b/synapse/http/federation/matrix_federation_agent.py
@@ -16,7 +16,7 @@ import logging
 import urllib.parse
 from typing import List, Optional
 
-from netaddr import AddrFormatError, IPAddress
+from netaddr import AddrFormatError, IPAddress, IPSet
 from zope.interface import implementer
 
 from twisted.internet import defer
@@ -31,6 +31,7 @@ from twisted.web.http_headers import Headers
 from twisted.web.iweb import IAgent, IAgentEndpointFactory, IBodyProducer
 
 from synapse.crypto.context_factory import FederationPolicyForHTTPS
+from synapse.http.client import BlacklistingAgentWrapper
 from synapse.http.federation.srv_resolver import Server, SrvResolver
 from synapse.http.federation.well_known_resolver import WellKnownResolver
 from synapse.logging.context import make_deferred_yieldable, run_in_background
@@ -70,6 +71,7 @@ class MatrixFederationAgent:
         reactor: IReactorCore,
         tls_client_options_factory: Optional[FederationPolicyForHTTPS],
         user_agent: bytes,
+        ip_blacklist: IPSet,
         _srv_resolver: Optional[SrvResolver] = None,
         _well_known_resolver: Optional[WellKnownResolver] = None,
     ):
@@ -90,12 +92,18 @@ class MatrixFederationAgent:
         self.user_agent = user_agent
 
         if _well_known_resolver is None:
+            # Note that the name resolver has already been wrapped in a
+            # IPBlacklistingResolver by MatrixFederationHttpClient.
             _well_known_resolver = WellKnownResolver(
                 self._reactor,
-                agent=Agent(
+                agent=BlacklistingAgentWrapper(
+                    Agent(
+                        self._reactor,
+                        pool=self._pool,
+                        contextFactory=tls_client_options_factory,
+                    ),
                     self._reactor,
-                    pool=self._pool,
-                    contextFactory=tls_client_options_factory,
+                    ip_blacklist=ip_blacklist,
                 ),
                 user_agent=self.user_agent,
             )
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 4e27f93b7a..c962994727 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -26,11 +26,10 @@ import treq
 from canonicaljson import encode_canonical_json
 from prometheus_client import Counter
 from signedjson.sign import sign_json
-from zope.interface import implementer
 
 from twisted.internet import defer
 from twisted.internet.error import DNSLookupError
-from twisted.internet.interfaces import IReactorPluggableNameResolver, IReactorTime
+from twisted.internet.interfaces import IReactorTime
 from twisted.internet.task import _EPSILON, Cooperator
 from twisted.web.http_headers import Headers
 from twisted.web.iweb import IBodyProducer, IResponse
@@ -45,7 +44,7 @@ from synapse.api.errors import (
 from synapse.http import QuieterFileBodyProducer
 from synapse.http.client import (
     BlacklistingAgentWrapper,
-    IPBlacklistingResolver,
+    BlacklistingReactorWrapper,
     encode_query_args,
     readBodyToFile,
 )
@@ -221,31 +220,22 @@ class MatrixFederationHttpClient:
         self.signing_key = hs.signing_key
         self.server_name = hs.hostname
 
-        real_reactor = hs.get_reactor()
-
         # We need to use a DNS resolver which filters out blacklisted IP
         # addresses, to prevent DNS rebinding.
-        nameResolver = IPBlacklistingResolver(
-            real_reactor, None, hs.config.federation_ip_range_blacklist
+        self.reactor = BlacklistingReactorWrapper(
+            hs.get_reactor(), None, hs.config.federation_ip_range_blacklist
         )
 
-        @implementer(IReactorPluggableNameResolver)
-        class Reactor:
-            def __getattr__(_self, attr):
-                if attr == "nameResolver":
-                    return nameResolver
-                else:
-                    return getattr(real_reactor, attr)
-
-        self.reactor = Reactor()
-
         user_agent = hs.version_string
         if hs.config.user_agent_suffix:
             user_agent = "%s %s" % (user_agent, hs.config.user_agent_suffix)
         user_agent = user_agent.encode("ascii")
 
         self.agent = MatrixFederationAgent(
-            self.reactor, tls_client_options_factory, user_agent
+            self.reactor,
+            tls_client_options_factory,
+            user_agent,
+            hs.config.federation_ip_range_blacklist,
         )
 
         # Use a BlacklistingAgentWrapper to prevent circumventing the IP
diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py
index 5a437f9810..3d2e874838 100644
--- a/synapse/push/__init__.py
+++ b/synapse/push/__init__.py
@@ -13,7 +13,56 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import abc
+from typing import TYPE_CHECKING, Any, Dict, Optional
+
+from synapse.types import RoomStreamToken
+
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
+
+class Pusher(metaclass=abc.ABCMeta):
+    def __init__(self, hs: "HomeServer", pusherdict: Dict[str, Any]):
+        self.hs = hs
+        self.store = self.hs.get_datastore()
+        self.clock = self.hs.get_clock()
+
+        self.pusher_id = pusherdict["id"]
+        self.user_id = pusherdict["user_name"]
+        self.app_id = pusherdict["app_id"]
+        self.pushkey = pusherdict["pushkey"]
+
+        # This is the highest stream ordering we know it's safe to process.
+        # When new events arrive, we'll be given a window of new events: we
+        # should honour this rather than just looking for anything higher
+        # because of potential out-of-order event serialisation. This starts
+        # off as None though as we don't know any better.
+        self.max_stream_ordering = None  # type: Optional[int]
+
+    @abc.abstractmethod
+    def on_new_notifications(self, max_token: RoomStreamToken) -> None:
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def on_new_receipts(self, min_stream_id: int, max_stream_id: int) -> None:
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def on_started(self, have_notifs: bool) -> None:
+        """Called when this pusher has been started.
+
+        Args:
+            should_check_for_notifs: Whether we should immediately
+                check for push to send. Set to False only if it's known there
+                is nothing to send
+        """
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def on_stop(self) -> None:
+        raise NotImplementedError()
+
 
 class PusherConfigException(Exception):
-    def __init__(self, msg):
-        super().__init__(msg)
+    """An error occurred when creating a pusher."""
diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py
index c6763971ee..64a35c1994 100644
--- a/synapse/push/emailpusher.py
+++ b/synapse/push/emailpusher.py
@@ -14,12 +14,19 @@
 # limitations under the License.
 
 import logging
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
 
+from twisted.internet.base import DelayedCall
 from twisted.internet.error import AlreadyCalled, AlreadyCancelled
 
 from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.push import Pusher
+from synapse.push.mailer import Mailer
 from synapse.types import RoomStreamToken
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 # The amount of time we always wait before ever emailing about a notification
@@ -46,7 +53,7 @@ THROTTLE_RESET_AFTER_MS = 12 * 60 * 60 * 1000
 INCLUDE_ALL_UNREAD_NOTIFS = False
 
 
-class EmailPusher:
+class EmailPusher(Pusher):
     """
     A pusher that sends email notifications about events (approximately)
     when they happen.
@@ -54,37 +61,31 @@ class EmailPusher:
     factor out the common parts
     """
 
-    def __init__(self, hs, pusherdict, mailer):
-        self.hs = hs
+    def __init__(self, hs: "HomeServer", pusherdict: Dict[str, Any], mailer: Mailer):
+        super().__init__(hs, pusherdict)
         self.mailer = mailer
 
         self.store = self.hs.get_datastore()
-        self.clock = self.hs.get_clock()
-        self.pusher_id = pusherdict["id"]
-        self.user_id = pusherdict["user_name"]
-        self.app_id = pusherdict["app_id"]
         self.email = pusherdict["pushkey"]
         self.last_stream_ordering = pusherdict["last_stream_ordering"]
-        self.timed_call = None
-        self.throttle_params = None
-
-        # See httppusher
-        self.max_stream_ordering = None
+        self.timed_call = None  # type: Optional[DelayedCall]
+        self.throttle_params = {}  # type: Dict[str, Dict[str, int]]
+        self._inited = False
 
         self._is_processing = False
 
-    def on_started(self, should_check_for_notifs):
+    def on_started(self, should_check_for_notifs: bool) -> None:
         """Called when this pusher has been started.
 
         Args:
-            should_check_for_notifs (bool): Whether we should immediately
+            should_check_for_notifs: Whether we should immediately
                 check for push to send. Set to False only if it's known there
                 is nothing to send
         """
         if should_check_for_notifs and self.mailer is not None:
             self._start_processing()
 
-    def on_stop(self):
+    def on_stop(self) -> None:
         if self.timed_call:
             try:
                 self.timed_call.cancel()
@@ -92,7 +93,7 @@ class EmailPusher:
                 pass
             self.timed_call = None
 
-    def on_new_notifications(self, max_token: RoomStreamToken):
+    def on_new_notifications(self, max_token: RoomStreamToken) -> None:
         # We just use the minimum stream ordering and ignore the vector clock
         # component. This is safe to do as long as we *always* ignore the vector
         # clock components.
@@ -106,23 +107,23 @@ class EmailPusher:
             self.max_stream_ordering = max_stream_ordering
         self._start_processing()
 
-    def on_new_receipts(self, min_stream_id, max_stream_id):
+    def on_new_receipts(self, min_stream_id: int, max_stream_id: int) -> None:
         # We could wake up and cancel the timer but there tend to be quite a
         # lot of read receipts so it's probably less work to just let the
         # timer fire
         pass
 
-    def on_timer(self):
+    def on_timer(self) -> None:
         self.timed_call = None
         self._start_processing()
 
-    def _start_processing(self):
+    def _start_processing(self) -> None:
         if self._is_processing:
             return
 
         run_as_background_process("emailpush.process", self._process)
 
-    def _pause_processing(self):
+    def _pause_processing(self) -> None:
         """Used by tests to temporarily pause processing of events.
 
         Asserts that its not currently processing.
@@ -130,25 +131,26 @@ class EmailPusher:
         assert not self._is_processing
         self._is_processing = True
 
-    def _resume_processing(self):
+    def _resume_processing(self) -> None:
         """Used by tests to resume processing of events after pausing.
         """
         assert self._is_processing
         self._is_processing = False
         self._start_processing()
 
-    async def _process(self):
+    async def _process(self) -> None:
         # we should never get here if we are already processing
         assert not self._is_processing
 
         try:
             self._is_processing = True
 
-            if self.throttle_params is None:
+            if not self._inited:
                 # this is our first loop: load up the throttle params
                 self.throttle_params = await self.store.get_throttle_params_by_room(
                     self.pusher_id
                 )
+                self._inited = True
 
             # if the max ordering changes while we're running _unsafe_process,
             # call it again, and so on until we've caught up.
@@ -163,17 +165,19 @@ class EmailPusher:
         finally:
             self._is_processing = False
 
-    async def _unsafe_process(self):
+    async def _unsafe_process(self) -> None:
         """
         Main logic of the push loop without the wrapper function that sets
         up logging, measures and guards against multiple instances of it
         being run.
         """
         start = 0 if INCLUDE_ALL_UNREAD_NOTIFS else self.last_stream_ordering
-        fn = self.store.get_unread_push_actions_for_user_in_range_for_email
-        unprocessed = await fn(self.user_id, start, self.max_stream_ordering)
+        assert self.max_stream_ordering is not None
+        unprocessed = await self.store.get_unread_push_actions_for_user_in_range_for_email(
+            self.user_id, start, self.max_stream_ordering
+        )
 
-        soonest_due_at = None
+        soonest_due_at = None  # type: Optional[int]
 
         if not unprocessed:
             await self.save_last_stream_ordering_and_success(self.max_stream_ordering)
@@ -230,7 +234,9 @@ class EmailPusher:
                 self.seconds_until(soonest_due_at), self.on_timer
             )
 
-    async def save_last_stream_ordering_and_success(self, last_stream_ordering):
+    async def save_last_stream_ordering_and_success(
+        self, last_stream_ordering: Optional[int]
+    ) -> None:
         if last_stream_ordering is None:
             # This happens if we haven't yet processed anything
             return
@@ -248,28 +254,30 @@ class EmailPusher:
             # lets just stop and return.
             self.on_stop()
 
-    def seconds_until(self, ts_msec):
+    def seconds_until(self, ts_msec: int) -> float:
         secs = (ts_msec - self.clock.time_msec()) / 1000
         return max(secs, 0)
 
-    def get_room_throttle_ms(self, room_id):
+    def get_room_throttle_ms(self, room_id: str) -> int:
         if room_id in self.throttle_params:
             return self.throttle_params[room_id]["throttle_ms"]
         else:
             return 0
 
-    def get_room_last_sent_ts(self, room_id):
+    def get_room_last_sent_ts(self, room_id: str) -> int:
         if room_id in self.throttle_params:
             return self.throttle_params[room_id]["last_sent_ts"]
         else:
             return 0
 
-    def room_ready_to_notify_at(self, room_id):
+    def room_ready_to_notify_at(self, room_id: str) -> int:
         """
         Determines whether throttling should prevent us from sending an email
         for the given room
-        Returns: The timestamp when we are next allowed to send an email notif
-        for this room
+
+        Returns:
+            The timestamp when we are next allowed to send an email notif
+            for this room
         """
         last_sent_ts = self.get_room_last_sent_ts(room_id)
         throttle_ms = self.get_room_throttle_ms(room_id)
@@ -277,7 +285,9 @@ class EmailPusher:
         may_send_at = last_sent_ts + throttle_ms
         return may_send_at
 
-    async def sent_notif_update_throttle(self, room_id, notified_push_action):
+    async def sent_notif_update_throttle(
+        self, room_id: str, notified_push_action: dict
+    ) -> None:
         # We have sent a notification, so update the throttle accordingly.
         # If the event that triggered the notif happened more than
         # THROTTLE_RESET_AFTER_MS after the previous one that triggered a
@@ -315,7 +325,7 @@ class EmailPusher:
             self.pusher_id, room_id, self.throttle_params[room_id]
         )
 
-    async def send_notification(self, push_actions, reason):
+    async def send_notification(self, push_actions: List[dict], reason: dict) -> None:
         logger.info("Sending notif email for user %r", self.user_id)
 
         await self.mailer.send_notification_mail(
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index eff0975b6a..5408aa1295 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -14,19 +14,25 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+import urllib.parse
+from typing import TYPE_CHECKING, Any, Dict, Iterable, Union
 
 from prometheus_client import Counter
 
 from twisted.internet.error import AlreadyCalled, AlreadyCancelled
 
 from synapse.api.constants import EventTypes
+from synapse.events import EventBase
 from synapse.logging import opentracing
 from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.push import PusherConfigException
+from synapse.push import Pusher, PusherConfigException
 from synapse.types import RoomStreamToken
 
 from . import push_rule_evaluator, push_tools
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 http_push_processed_counter = Counter(
@@ -50,24 +56,18 @@ http_badges_failed_counter = Counter(
 )
 
 
-class HttpPusher:
+class HttpPusher(Pusher):
     INITIAL_BACKOFF_SEC = 1  # in seconds because that's what Twisted takes
     MAX_BACKOFF_SEC = 60 * 60
 
     # This one's in ms because we compare it against the clock
     GIVE_UP_AFTER_MS = 24 * 60 * 60 * 1000
 
-    def __init__(self, hs, pusherdict):
-        self.hs = hs
-        self.store = self.hs.get_datastore()
+    def __init__(self, hs: "HomeServer", pusherdict: Dict[str, Any]):
+        super().__init__(hs, pusherdict)
         self.storage = self.hs.get_storage()
-        self.clock = self.hs.get_clock()
-        self.state_handler = self.hs.get_state_handler()
-        self.user_id = pusherdict["user_name"]
-        self.app_id = pusherdict["app_id"]
         self.app_display_name = pusherdict["app_display_name"]
         self.device_display_name = pusherdict["device_display_name"]
-        self.pushkey = pusherdict["pushkey"]
         self.pushkey_ts = pusherdict["ts"]
         self.data = pusherdict["data"]
         self.last_stream_ordering = pusherdict["last_stream_ordering"]
@@ -77,13 +77,6 @@ class HttpPusher:
         self._is_processing = False
         self._group_unread_count_by_room = hs.config.push_group_unread_count_by_room
 
-        # This is the highest stream ordering we know it's safe to process.
-        # When new events arrive, we'll be given a window of new events: we
-        # should honour this rather than just looking for anything higher
-        # because of potential out-of-order event serialisation. This starts
-        # off as None though as we don't know any better.
-        self.max_stream_ordering = None
-
         if "data" not in pusherdict:
             raise PusherConfigException("No 'data' key for HTTP pusher")
         self.data = pusherdict["data"]
@@ -97,26 +90,39 @@ class HttpPusher:
         if self.data is None:
             raise PusherConfigException("data can not be null for HTTP pusher")
 
+        # Validate that there's a URL and it is of the proper form.
         if "url" not in self.data:
             raise PusherConfigException("'url' required in data for HTTP pusher")
-        self.url = self.data["url"]
-        self.http_client = hs.get_proxied_http_client()
+
+        url = self.data["url"]
+        if not isinstance(url, str):
+            raise PusherConfigException("'url' must be a string")
+        url_parts = urllib.parse.urlparse(url)
+        # Note that the specification also says the scheme must be HTTPS, but
+        # it isn't up to the homeserver to verify that.
+        if url_parts.path != "/_matrix/push/v1/notify":
+            raise PusherConfigException(
+                "'url' must have a path of '/_matrix/push/v1/notify'"
+            )
+
+        self.url = url
+        self.http_client = hs.get_proxied_blacklisted_http_client()
         self.data_minus_url = {}
         self.data_minus_url.update(self.data)
         del self.data_minus_url["url"]
 
-    def on_started(self, should_check_for_notifs):
+    def on_started(self, should_check_for_notifs: bool) -> None:
         """Called when this pusher has been started.
 
         Args:
-            should_check_for_notifs (bool): Whether we should immediately
+            should_check_for_notifs: Whether we should immediately
                 check for push to send. Set to False only if it's known there
                 is nothing to send
         """
         if should_check_for_notifs:
             self._start_processing()
 
-    def on_new_notifications(self, max_token: RoomStreamToken):
+    def on_new_notifications(self, max_token: RoomStreamToken) -> None:
         # We just use the minimum stream ordering and ignore the vector clock
         # component. This is safe to do as long as we *always* ignore the vector
         # clock components.
@@ -127,14 +133,14 @@ class HttpPusher:
         )
         self._start_processing()
 
-    def on_new_receipts(self, min_stream_id, max_stream_id):
+    def on_new_receipts(self, min_stream_id: int, max_stream_id: int) -> None:
         # Note that the min here shouldn't be relied upon to be accurate.
 
         # We could check the receipts are actually m.read receipts here,
         # but currently that's the only type of receipt anyway...
         run_as_background_process("http_pusher.on_new_receipts", self._update_badge)
 
-    async def _update_badge(self):
+    async def _update_badge(self) -> None:
         # XXX as per https://github.com/matrix-org/matrix-doc/issues/2627, this seems
         # to be largely redundant. perhaps we can remove it.
         badge = await push_tools.get_badge_count(
@@ -144,10 +150,10 @@ class HttpPusher:
         )
         await self._send_badge(badge)
 
-    def on_timer(self):
+    def on_timer(self) -> None:
         self._start_processing()
 
-    def on_stop(self):
+    def on_stop(self) -> None:
         if self.timed_call:
             try:
                 self.timed_call.cancel()
@@ -155,13 +161,13 @@ class HttpPusher:
                 pass
             self.timed_call = None
 
-    def _start_processing(self):
+    def _start_processing(self) -> None:
         if self._is_processing:
             return
 
         run_as_background_process("httppush.process", self._process)
 
-    async def _process(self):
+    async def _process(self) -> None:
         # we should never get here if we are already processing
         assert not self._is_processing
 
@@ -180,7 +186,7 @@ class HttpPusher:
         finally:
             self._is_processing = False
 
-    async def _unsafe_process(self):
+    async def _unsafe_process(self) -> None:
         """
         Looks for unset notifications and dispatch them, in order
         Never call this directly: use _process which will only allow this to
@@ -188,6 +194,7 @@ class HttpPusher:
         """
 
         fn = self.store.get_unread_push_actions_for_user_in_range_for_http
+        assert self.max_stream_ordering is not None
         unprocessed = await fn(
             self.user_id, self.last_stream_ordering, self.max_stream_ordering
         )
@@ -257,17 +264,12 @@ class HttpPusher:
                     )
                     self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
                     self.last_stream_ordering = push_action["stream_ordering"]
-                    pusher_still_exists = await self.store.update_pusher_last_stream_ordering(
+                    await self.store.update_pusher_last_stream_ordering(
                         self.app_id,
                         self.pushkey,
                         self.user_id,
                         self.last_stream_ordering,
                     )
-                    if not pusher_still_exists:
-                        # The pusher has been deleted while we were processing, so
-                        # lets just stop and return.
-                        self.on_stop()
-                        return
 
                     self.failing_since = None
                     await self.store.update_pusher_failing_since(
@@ -283,7 +285,7 @@ class HttpPusher:
                     )
                     break
 
-    async def _process_one(self, push_action):
+    async def _process_one(self, push_action: dict) -> bool:
         if "notify" not in push_action["actions"]:
             return True
 
@@ -314,7 +316,9 @@ class HttpPusher:
                     await self.hs.remove_pusher(self.app_id, pk, self.user_id)
         return True
 
-    async def _build_notification_dict(self, event, tweaks, badge):
+    async def _build_notification_dict(
+        self, event: EventBase, tweaks: Dict[str, bool], badge: int
+    ) -> Dict[str, Any]:
         priority = "low"
         if (
             event.type == EventTypes.Encrypted
@@ -344,9 +348,7 @@ class HttpPusher:
             }
             return d
 
-        ctx = await push_tools.get_context_for_event(
-            self.storage, self.state_handler, event, self.user_id
-        )
+        ctx = await push_tools.get_context_for_event(self.storage, event, self.user_id)
 
         d = {
             "notification": {
@@ -386,7 +388,9 @@ class HttpPusher:
 
         return d
 
-    async def dispatch_push(self, event, tweaks, badge):
+    async def dispatch_push(
+        self, event: EventBase, tweaks: Dict[str, bool], badge: int
+    ) -> Union[bool, Iterable[str]]:
         notification_dict = await self._build_notification_dict(event, tweaks, badge)
         if not notification_dict:
             return []
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index 38195c8eea..9ff092e8bb 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -19,7 +19,7 @@ import logging
 import urllib.parse
 from email.mime.multipart import MIMEMultipart
 from email.mime.text import MIMEText
-from typing import Iterable, List, TypeVar
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, TypeVar
 
 import bleach
 import jinja2
@@ -27,16 +27,20 @@ import jinja2
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.errors import StoreError
 from synapse.config.emailconfig import EmailSubjectConfig
+from synapse.events import EventBase
 from synapse.logging.context import make_deferred_yieldable
 from synapse.push.presentable_names import (
     calculate_room_name,
     descriptor_from_member_events,
     name_from_member_event,
 )
-from synapse.types import UserID
+from synapse.types import StateMap, UserID
 from synapse.util.async_helpers import concurrently_execute
 from synapse.visibility import filter_events_for_client
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 T = TypeVar("T")
@@ -93,7 +97,13 @@ ALLOWED_ATTRS = {
 
 
 class Mailer:
-    def __init__(self, hs, app_name, template_html, template_text):
+    def __init__(
+        self,
+        hs: "HomeServer",
+        app_name: str,
+        template_html: jinja2.Template,
+        template_text: jinja2.Template,
+    ):
         self.hs = hs
         self.template_html = template_html
         self.template_text = template_text
@@ -108,17 +118,19 @@ class Mailer:
 
         logger.info("Created Mailer for app_name %s" % app_name)
 
-    async def send_password_reset_mail(self, email_address, token, client_secret, sid):
+    async def send_password_reset_mail(
+        self, email_address: str, token: str, client_secret: str, sid: str
+    ) -> None:
         """Send an email with a password reset link to a user
 
         Args:
-            email_address (str): Email address we're sending the password
+            email_address: Email address we're sending the password
                 reset to
-            token (str): Unique token generated by the server to verify
+            token: Unique token generated by the server to verify
                 the email was received
-            client_secret (str): Unique token generated by the client to
+            client_secret: Unique token generated by the client to
                 group together multiple email sending attempts
-            sid (str): The generated session ID
+            sid: The generated session ID
         """
         params = {"token": token, "client_secret": client_secret, "sid": sid}
         link = (
@@ -136,17 +148,19 @@ class Mailer:
             template_vars,
         )
 
-    async def send_registration_mail(self, email_address, token, client_secret, sid):
+    async def send_registration_mail(
+        self, email_address: str, token: str, client_secret: str, sid: str
+    ) -> None:
         """Send an email with a registration confirmation link to a user
 
         Args:
-            email_address (str): Email address we're sending the registration
+            email_address: Email address we're sending the registration
                 link to
-            token (str): Unique token generated by the server to verify
+            token: Unique token generated by the server to verify
                 the email was received
-            client_secret (str): Unique token generated by the client to
+            client_secret: Unique token generated by the client to
                 group together multiple email sending attempts
-            sid (str): The generated session ID
+            sid: The generated session ID
         """
         params = {"token": token, "client_secret": client_secret, "sid": sid}
         link = (
@@ -164,18 +178,20 @@ class Mailer:
             template_vars,
         )
 
-    async def send_add_threepid_mail(self, email_address, token, client_secret, sid):
+    async def send_add_threepid_mail(
+        self, email_address: str, token: str, client_secret: str, sid: str
+    ) -> None:
         """Send an email with a validation link to a user for adding a 3pid to their account
 
         Args:
-            email_address (str): Email address we're sending the validation link to
+            email_address: Email address we're sending the validation link to
 
-            token (str): Unique token generated by the server to verify the email was received
+            token: Unique token generated by the server to verify the email was received
 
-            client_secret (str): Unique token generated by the client to group together
+            client_secret: Unique token generated by the client to group together
                 multiple email sending attempts
 
-            sid (str): The generated session ID
+            sid: The generated session ID
         """
         params = {"token": token, "client_secret": client_secret, "sid": sid}
         link = (
@@ -194,8 +210,13 @@ class Mailer:
         )
 
     async def send_notification_mail(
-        self, app_id, user_id, email_address, push_actions, reason
-    ):
+        self,
+        app_id: str,
+        user_id: str,
+        email_address: str,
+        push_actions: Iterable[Dict[str, Any]],
+        reason: Dict[str, Any],
+    ) -> None:
         """Send email regarding a user's room notifications"""
         rooms_in_order = deduped_ordered_list([pa["room_id"] for pa in push_actions])
 
@@ -203,7 +224,7 @@ class Mailer:
             [pa["event_id"] for pa in push_actions]
         )
 
-        notifs_by_room = {}
+        notifs_by_room = {}  # type: Dict[str, List[Dict[str, Any]]]
         for pa in push_actions:
             notifs_by_room.setdefault(pa["room_id"], []).append(pa)
 
@@ -262,7 +283,9 @@ class Mailer:
 
         await self.send_email(email_address, summary_text, template_vars)
 
-    async def send_email(self, email_address, subject, extra_template_vars):
+    async def send_email(
+        self, email_address: str, subject: str, extra_template_vars: Dict[str, Any]
+    ) -> None:
         """Send an email with the given information and template text"""
         try:
             from_string = self.hs.config.email_notif_from % {"app": self.app_name}
@@ -315,8 +338,13 @@ class Mailer:
         )
 
     async def get_room_vars(
-        self, room_id, user_id, notifs, notif_events, room_state_ids
-    ):
+        self,
+        room_id: str,
+        user_id: str,
+        notifs: Iterable[Dict[str, Any]],
+        notif_events: Dict[str, EventBase],
+        room_state_ids: StateMap[str],
+    ) -> Dict[str, Any]:
         # Check if one of the notifs is an invite event for the user.
         is_invite = False
         for n in notifs:
@@ -334,7 +362,7 @@ class Mailer:
             "notifs": [],
             "invite": is_invite,
             "link": self.make_room_link(room_id),
-        }
+        }  # type: Dict[str, Any]
 
         if not is_invite:
             for n in notifs:
@@ -365,7 +393,13 @@ class Mailer:
 
         return room_vars
 
-    async def get_notif_vars(self, notif, user_id, notif_event, room_state_ids):
+    async def get_notif_vars(
+        self,
+        notif: Dict[str, Any],
+        user_id: str,
+        notif_event: EventBase,
+        room_state_ids: StateMap[str],
+    ) -> Dict[str, Any]:
         results = await self.store.get_events_around(
             notif["room_id"],
             notif["event_id"],
@@ -391,7 +425,9 @@ class Mailer:
 
         return ret
 
-    async def get_message_vars(self, notif, event, room_state_ids):
+    async def get_message_vars(
+        self, notif: Dict[str, Any], event: EventBase, room_state_ids: StateMap[str]
+    ) -> Optional[Dict[str, Any]]:
         if event.type != EventTypes.Message and event.type != EventTypes.Encrypted:
             return None
 
@@ -432,7 +468,9 @@ class Mailer:
 
         return ret
 
-    def add_text_message_vars(self, messagevars, event):
+    def add_text_message_vars(
+        self, messagevars: Dict[str, Any], event: EventBase
+    ) -> None:
         msgformat = event.content.get("format")
 
         messagevars["format"] = msgformat
@@ -445,15 +483,18 @@ class Mailer:
         elif body:
             messagevars["body_text_html"] = safe_text(body)
 
-        return messagevars
-
-    def add_image_message_vars(self, messagevars, event):
+    def add_image_message_vars(
+        self, messagevars: Dict[str, Any], event: EventBase
+    ) -> None:
         messagevars["image_url"] = event.content["url"]
 
-        return messagevars
-
     async def make_summary_text(
-        self, notifs_by_room, room_state_ids, notif_events, user_id, reason
+        self,
+        notifs_by_room: Dict[str, List[Dict[str, Any]]],
+        room_state_ids: Dict[str, StateMap[str]],
+        notif_events: Dict[str, EventBase],
+        user_id: str,
+        reason: Dict[str, Any],
     ):
         if len(notifs_by_room) == 1:
             # Only one room has new stuff
@@ -580,7 +621,7 @@ class Mailer:
                     "app": self.app_name,
                 }
 
-    def make_room_link(self, room_id):
+    def make_room_link(self, room_id: str) -> str:
         if self.hs.config.email_riot_base_url:
             base_url = "%s/#/room" % (self.hs.config.email_riot_base_url)
         elif self.app_name == "Vector":
@@ -590,7 +631,7 @@ class Mailer:
             base_url = "https://matrix.to/#"
         return "%s/%s" % (base_url, room_id)
 
-    def make_notif_link(self, notif):
+    def make_notif_link(self, notif: Dict[str, str]) -> str:
         if self.hs.config.email_riot_base_url:
             return "%s/#/room/%s/%s" % (
                 self.hs.config.email_riot_base_url,
@@ -606,7 +647,9 @@ class Mailer:
         else:
             return "https://matrix.to/#/%s/%s" % (notif["room_id"], notif["event_id"])
 
-    def make_unsubscribe_link(self, user_id, app_id, email_address):
+    def make_unsubscribe_link(
+        self, user_id: str, app_id: str, email_address: str
+    ) -> str:
         params = {
             "access_token": self.macaroon_gen.generate_delete_pusher_token(user_id),
             "app_id": app_id,
@@ -620,7 +663,7 @@ class Mailer:
         )
 
 
-def safe_markup(raw_html):
+def safe_markup(raw_html: str) -> jinja2.Markup:
     return jinja2.Markup(
         bleach.linkify(
             bleach.clean(
@@ -635,7 +678,7 @@ def safe_markup(raw_html):
     )
 
 
-def safe_text(raw_text):
+def safe_text(raw_text: str) -> jinja2.Markup:
     """
     Process text: treat it as HTML but escape any tags (ie. just escape the
     HTML) then linkify it.
@@ -655,7 +698,7 @@ def deduped_ordered_list(it: Iterable[T]) -> List[T]:
     return ret
 
 
-def string_ordinal_total(s):
+def string_ordinal_total(s: str) -> int:
     tot = 0
     for c in s:
         tot += ord(c)
diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py
index 6e7c880dc0..df34103224 100644
--- a/synapse/push/push_tools.py
+++ b/synapse/push/push_tools.py
@@ -12,6 +12,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Dict
+
+from synapse.events import EventBase
 from synapse.push.presentable_names import calculate_room_name, name_from_member_event
 from synapse.storage import Storage
 from synapse.storage.databases.main import DataStore
@@ -46,7 +49,9 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -
     return badge
 
 
-async def get_context_for_event(storage: Storage, state_handler, ev, user_id):
+async def get_context_for_event(
+    storage: Storage, ev: EventBase, user_id: str
+) -> Dict[str, str]:
     ctx = {}
 
     room_state_ids = await storage.state.get_state_ids_for_event(ev.event_id)
diff --git a/synapse/push/pusher.py b/synapse/push/pusher.py
index 2a52e226e3..8f1072b094 100644
--- a/synapse/push/pusher.py
+++ b/synapse/push/pusher.py
@@ -14,25 +14,31 @@
 # limitations under the License.
 
 import logging
+from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
 
+from synapse.push import Pusher
 from synapse.push.emailpusher import EmailPusher
+from synapse.push.httppusher import HttpPusher
 from synapse.push.mailer import Mailer
 
-from .httppusher import HttpPusher
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
 
 logger = logging.getLogger(__name__)
 
 
 class PusherFactory:
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.config = hs.config
 
-        self.pusher_types = {"http": HttpPusher}
+        self.pusher_types = {
+            "http": HttpPusher
+        }  # type: Dict[str, Callable[[HomeServer, dict], Pusher]]
 
         logger.info("email enable notifs: %r", hs.config.email_enable_notifs)
         if hs.config.email_enable_notifs:
-            self.mailers = {}  # app_name -> Mailer
+            self.mailers = {}  # type: Dict[str, Mailer]
 
             self._notif_template_html = hs.config.email_notif_template_html
             self._notif_template_text = hs.config.email_notif_template_text
@@ -41,7 +47,7 @@ class PusherFactory:
 
             logger.info("defined email pusher type")
 
-    def create_pusher(self, pusherdict):
+    def create_pusher(self, pusherdict: Dict[str, Any]) -> Optional[Pusher]:
         kind = pusherdict["kind"]
         f = self.pusher_types.get(kind, None)
         if not f:
@@ -49,7 +55,9 @@ class PusherFactory:
         logger.debug("creating %s pusher for %r", kind, pusherdict)
         return f(self.hs, pusherdict)
 
-    def _create_email_pusher(self, _hs, pusherdict):
+    def _create_email_pusher(
+        self, _hs: "HomeServer", pusherdict: Dict[str, Any]
+    ) -> EmailPusher:
         app_name = self._app_name_from_pusherdict(pusherdict)
         mailer = self.mailers.get(app_name)
         if not mailer:
@@ -62,7 +70,7 @@ class PusherFactory:
             self.mailers[app_name] = mailer
         return EmailPusher(self.hs, pusherdict, mailer)
 
-    def _app_name_from_pusherdict(self, pusherdict):
+    def _app_name_from_pusherdict(self, pusherdict: Dict[str, Any]) -> str:
         data = pusherdict["data"]
 
         if isinstance(data, dict):
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index f325964983..9fcc0b8a64 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -15,7 +15,7 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, Dict, Union
+from typing import TYPE_CHECKING, Any, Dict, Optional
 
 from prometheus_client import Gauge
 
@@ -23,9 +23,7 @@ from synapse.metrics.background_process_metrics import (
     run_as_background_process,
     wrap_as_background_process,
 )
-from synapse.push import PusherConfigException
-from synapse.push.emailpusher import EmailPusher
-from synapse.push.httppusher import HttpPusher
+from synapse.push import Pusher, PusherConfigException
 from synapse.push.pusher import PusherFactory
 from synapse.types import RoomStreamToken
 from synapse.util.async_helpers import concurrently_execute
@@ -77,7 +75,7 @@ class PusherPool:
         self._last_room_stream_id_seen = self.store.get_room_max_stream_ordering()
 
         # map from user id to app_id:pushkey to pusher
-        self.pushers = {}  # type: Dict[str, Dict[str, Union[HttpPusher, EmailPusher]]]
+        self.pushers = {}  # type: Dict[str, Dict[str, Pusher]]
 
     def start(self):
         """Starts the pushers off in a background process.
@@ -99,11 +97,11 @@ class PusherPool:
         lang,
         data,
         profile_tag="",
-    ):
+    ) -> Optional[Pusher]:
         """Creates a new pusher and adds it to the pool
 
         Returns:
-            EmailPusher|HttpPusher
+            The newly created pusher.
         """
 
         time_now_msec = self.clock.time_msec()
@@ -267,17 +265,19 @@ class PusherPool:
         except Exception:
             logger.exception("Exception in pusher on_new_receipts")
 
-    async def start_pusher_by_id(self, app_id, pushkey, user_id):
+    async def start_pusher_by_id(
+        self, app_id: str, pushkey: str, user_id: str
+    ) -> Optional[Pusher]:
         """Look up the details for the given pusher, and start it
 
         Returns:
-            EmailPusher|HttpPusher|None: The pusher started, if any
+            The pusher started, if any
         """
         if not self._should_start_pushers:
-            return
+            return None
 
         if not self._pusher_shard_config.should_handle(self._instance_name, user_id):
-            return
+            return None
 
         resultlist = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey)
 
@@ -303,19 +303,19 @@ class PusherPool:
 
         logger.info("Started pushers")
 
-    async def _start_pusher(self, pusherdict):
+    async def _start_pusher(self, pusherdict: Dict[str, Any]) -> Optional[Pusher]:
         """Start the given pusher
 
         Args:
-            pusherdict (dict): dict with the values pulled from the db table
+            pusherdict: dict with the values pulled from the db table
 
         Returns:
-            EmailPusher|HttpPusher
+            The newly created pusher or None.
         """
         if not self._pusher_shard_config.should_handle(
             self._instance_name, pusherdict["user_name"]
         ):
-            return
+            return None
 
         try:
             p = self.pusher_factory.create_pusher(pusherdict)
@@ -328,15 +328,15 @@ class PusherPool:
                 pusherdict.get("pushkey"),
                 e,
             )
-            return
+            return None
         except Exception:
             logger.exception(
                 "Couldn't start pusher id %i: caught Exception", pusherdict["id"],
             )
-            return
+            return None
 
         if not p:
-            return
+            return None
 
         appid_pushkey = "%s:%s" % (pusherdict["app_id"], pusherdict["pushkey"])
 
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index 2b3972cb14..1492ac922c 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -106,6 +106,25 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
 
         assert self.METHOD in ("PUT", "POST", "GET")
 
+        self._replication_secret = None
+        if hs.config.worker.worker_replication_secret:
+            self._replication_secret = hs.config.worker.worker_replication_secret
+
+    def _check_auth(self, request) -> None:
+        # Get the authorization header.
+        auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
+
+        if len(auth_headers) > 1:
+            raise RuntimeError("Too many Authorization headers.")
+        parts = auth_headers[0].split(b" ")
+        if parts[0] == b"Bearer" and len(parts) == 2:
+            received_secret = parts[1].decode("ascii")
+            if self._replication_secret == received_secret:
+                # Success!
+                return
+
+        raise RuntimeError("Invalid Authorization header.")
+
     @abc.abstractmethod
     async def _serialize_payload(**kwargs):
         """Static method that is called when creating a request.
@@ -150,6 +169,12 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
 
         outgoing_gauge = _pending_outgoing_requests.labels(cls.NAME)
 
+        replication_secret = None
+        if hs.config.worker.worker_replication_secret:
+            replication_secret = hs.config.worker.worker_replication_secret.encode(
+                "ascii"
+            )
+
         @trace(opname="outgoing_replication_request")
         @outgoing_gauge.track_inprogress()
         async def send_request(instance_name="master", **kwargs):
@@ -202,6 +227,9 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
                 # the master, and so whether we should clean up or not.
                 while True:
                     headers = {}  # type: Dict[bytes, List[bytes]]
+                    # Add an authorization header, if configured.
+                    if replication_secret:
+                        headers[b"Authorization"] = [b"Bearer " + replication_secret]
                     inject_active_span_byte_dict(headers, None, check_destination=False)
                     try:
                         result = await request_func(uri, data, headers=headers)
@@ -236,21 +264,19 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
         """
 
         url_args = list(self.PATH_ARGS)
-        handler = self._handle_request
         method = self.METHOD
 
         if self.CACHE:
-            handler = self._cached_handler  # type: ignore
             url_args.append("txn_id")
 
         args = "/".join("(?P<%s>[^/]+)" % (arg,) for arg in url_args)
         pattern = re.compile("^/_synapse/replication/%s/%s$" % (self.NAME, args))
 
         http_server.register_paths(
-            method, [pattern], handler, self.__class__.__name__,
+            method, [pattern], self._check_auth_and_handle, self.__class__.__name__,
         )
 
-    def _cached_handler(self, request, txn_id, **kwargs):
+    def _check_auth_and_handle(self, request, **kwargs):
         """Called on new incoming requests when caching is enabled. Checks
         if there is a cached response for the request and returns that,
         otherwise calls `_handle_request` and caches its response.
@@ -258,6 +284,15 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
         # We just use the txn_id here, but we probably also want to use the
         # other PATH_ARGS as well.
 
-        assert self.CACHE
+        # Check the authorization headers before handling the request.
+        if self._replication_secret:
+            self._check_auth(request)
+
+        if self.CACHE:
+            txn_id = kwargs.pop("txn_id")
+
+            return self.response_cache.wrap(
+                txn_id, self._handle_request, request, **kwargs
+            )
 
-        return self.response_cache.wrap(txn_id, self._handle_request, request, **kwargs)
+        return self._handle_request(request, **kwargs)
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index b0ff5e1ead..88cba369f5 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -320,9 +320,9 @@ class UserRestServletV2(RestServlet):
                             data={},
                         )
 
-            if "avatar_url" in body and type(body["avatar_url"]) == str:
+            if "avatar_url" in body and isinstance(body["avatar_url"], str):
                 await self.profile_handler.set_avatar_url(
-                    user_id, requester, body["avatar_url"], True
+                    target_user, requester, body["avatar_url"], True
                 )
 
             ret = await self.admin_handler.get_user(target_user)
@@ -420,6 +420,9 @@ class UserRegisterServlet(RestServlet):
         if user_type is not None and user_type not in UserTypes.ALL_USER_TYPES:
             raise SynapseError(400, "Invalid user type")
 
+        if "mac" not in body:
+            raise SynapseError(400, "mac must be specified", errcode=Codes.BAD_JSON)
+
         got_mac = body["mac"]
 
         want_mac_builder = hmac.new(
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index a89ae6ddf9..9041e7ed76 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -451,7 +451,7 @@ class RegisterRestServlet(RestServlet):
 
         # == Normal User Registration == (everyone else)
         if not self._registration_enabled:
-            raise SynapseError(403, "Registration has been disabled")
+            raise SynapseError(403, "Registration has been disabled", Codes.FORBIDDEN)
 
         # For regular registration, convert the provided username to lowercase
         # before attempting to register it. This should mean that people who try
diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py
index 67aa993f19..47c2b44bff 100644
--- a/synapse/rest/media/v1/_base.py
+++ b/synapse/rest/media/v1/_base.py
@@ -155,6 +155,11 @@ def add_file_headers(request, media_type, file_size, upload_name):
     request.setHeader(b"Cache-Control", b"public,max-age=86400,s-maxage=86400")
     request.setHeader(b"Content-Length", b"%d" % (file_size,))
 
+    # Tell web crawlers to not index, archive, or follow links in media. This
+    # should help to prevent things in the media repo from showing up in web
+    # search results.
+    request.setHeader(b"X-Robots-Tag", "noindex, nofollow, noarchive, noimageindex")
+
 
 # separators as defined in RFC2616. SP and HT are handled separately.
 # see _can_encode_filename_as_token.
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 9cac74ebd8..83beb02b05 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -66,7 +66,7 @@ class MediaRepository:
     def __init__(self, hs):
         self.hs = hs
         self.auth = hs.get_auth()
-        self.client = hs.get_http_client()
+        self.client = hs.get_federation_http_client()
         self.clock = hs.get_clock()
         self.server_name = hs.hostname
         self.store = hs.get_datastore()
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index dce6c4d168..1082389d9b 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -676,7 +676,11 @@ class PreviewUrlResource(DirectServeJsonResource):
             logger.debug("No media removed from url cache")
 
 
-def decode_and_calc_og(body, media_uri, request_encoding=None):
+def decode_and_calc_og(body, media_uri, request_encoding=None) -> Dict[str, str]:
+    # If there's no body, nothing useful is going to be found.
+    if not body:
+        return {}
+
     from lxml import etree
 
     try:
diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py
index d76f7389e1..42febc9afc 100644
--- a/synapse/rest/media/v1/upload_resource.py
+++ b/synapse/rest/media/v1/upload_resource.py
@@ -44,7 +44,7 @@ class UploadResource(DirectServeJsonResource):
         requester = await self.auth.get_user_by_req(request)
         # TODO: The checks here are a bit late. The content will have
         # already been uploaded to a tmp file at this point
-        content_length = request.getHeader(b"Content-Length").decode("ascii")
+        content_length = request.getHeader("Content-Length")
         if content_length is None:
             raise SynapseError(msg="Request must specify a Content-Length", code=400)
         if int(content_length) > self.max_upload_size:
diff --git a/synapse/server.py b/synapse/server.py
index b017e3489f..9af759626e 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -350,17 +350,46 @@ class HomeServer(metaclass=abc.ABCMeta):
 
     @cache_in_self
     def get_simple_http_client(self) -> SimpleHttpClient:
+        """
+        An HTTP client with no special configuration.
+        """
         return SimpleHttpClient(self)
 
     @cache_in_self
     def get_proxied_http_client(self) -> SimpleHttpClient:
+        """
+        An HTTP client that uses configured HTTP(S) proxies.
+        """
+        return SimpleHttpClient(
+            self,
+            http_proxy=os.getenvb(b"http_proxy"),
+            https_proxy=os.getenvb(b"HTTPS_PROXY"),
+        )
+
+    @cache_in_self
+    def get_proxied_blacklisted_http_client(self) -> SimpleHttpClient:
+        """
+        An HTTP client that uses configured HTTP(S) proxies and blacklists IPs
+        based on the IP range blacklist.
+        """
         return SimpleHttpClient(
             self,
+            ip_blacklist=self.config.ip_range_blacklist,
             http_proxy=os.getenvb(b"http_proxy"),
             https_proxy=os.getenvb(b"HTTPS_PROXY"),
         )
 
     @cache_in_self
+    def get_federation_http_client(self) -> MatrixFederationHttpClient:
+        """
+        An HTTP client for federation.
+        """
+        tls_client_options_factory = context_factory.FederationPolicyForHTTPS(
+            self.config
+        )
+        return MatrixFederationHttpClient(self, tls_client_options_factory)
+
+    @cache_in_self
     def get_room_creation_handler(self) -> RoomCreationHandler:
         return RoomCreationHandler(self)
 
@@ -515,13 +544,6 @@ class HomeServer(metaclass=abc.ABCMeta):
         return PusherPool(self)
 
     @cache_in_self
-    def get_http_client(self) -> MatrixFederationHttpClient:
-        tls_client_options_factory = context_factory.FederationPolicyForHTTPS(
-            self.config
-        )
-        return MatrixFederationHttpClient(self, tls_client_options_factory)
-
-    @cache_in_self
     def get_media_repository_resource(self) -> MediaRepositoryResource:
         # build the media repo resource. This indirects through the HomeServer
         # to ensure that we only have a single instance of
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 1fa3b280b4..84f59c7d85 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -783,7 +783,7 @@ class StateResolutionStore:
         )
 
     def get_auth_chain_difference(
-        self, state_sets: List[Set[str]]
+        self, room_id: str, state_sets: List[Set[str]]
     ) -> Awaitable[Set[str]]:
         """Given sets of state events figure out the auth chain difference (as
         per state res v2 algorithm).
@@ -796,4 +796,4 @@ class StateResolutionStore:
             An awaitable that resolves to a set of event IDs.
         """
 
-        return self.store.get_auth_chain_difference(state_sets)
+        return self.store.get_auth_chain_difference(room_id, state_sets)
diff --git a/synapse/state/v2.py b/synapse/state/v2.py
index f57df0d728..f85124bf81 100644
--- a/synapse/state/v2.py
+++ b/synapse/state/v2.py
@@ -38,7 +38,7 @@ from synapse.api.constants import EventTypes
 from synapse.api.errors import AuthError
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
 from synapse.events import EventBase
-from synapse.types import MutableStateMap, StateMap
+from synapse.types import Collection, MutableStateMap, StateMap
 from synapse.util import Clock
 
 logger = logging.getLogger(__name__)
@@ -97,7 +97,9 @@ async def resolve_events_with_store(
 
     # Also fetch all auth events that appear in only some of the state sets'
     # auth chains.
-    auth_diff = await _get_auth_chain_difference(state_sets, event_map, state_res_store)
+    auth_diff = await _get_auth_chain_difference(
+        room_id, state_sets, event_map, state_res_store
+    )
 
     full_conflicted_set = set(
         itertools.chain(
@@ -236,6 +238,7 @@ async def _get_power_level_for_sender(
 
 
 async def _get_auth_chain_difference(
+    room_id: str,
     state_sets: Sequence[StateMap[str]],
     event_map: Dict[str, EventBase],
     state_res_store: "synapse.state.StateResolutionStore",
@@ -252,9 +255,90 @@ async def _get_auth_chain_difference(
         Set of event IDs
     """
 
+    # The `StateResolutionStore.get_auth_chain_difference` function assumes that
+    # all events passed to it (and their auth chains) have been persisted
+    # previously. This is not the case for any events in the `event_map`, and so
+    # we need to manually handle those events.
+    #
+    # We do this by:
+    #   1. calculating the auth chain difference for the state sets based on the
+    #      events in `event_map` alone
+    #   2. replacing any events in the state_sets that are also in `event_map`
+    #      with their auth events (recursively), and then calling
+    #      `store.get_auth_chain_difference` as normal
+    #   3. adding the results of 1 and 2 together.
+
+    # Map from event ID in `event_map` to their auth event IDs, and their auth
+    # event IDs if they appear in the `event_map`. This is the intersection of
+    # the event's auth chain with the events in the `event_map` *plus* their
+    # auth event IDs.
+    events_to_auth_chain = {}  # type: Dict[str, Set[str]]
+    for event in event_map.values():
+        chain = {event.event_id}
+        events_to_auth_chain[event.event_id] = chain
+
+        to_search = [event]
+        while to_search:
+            for auth_id in to_search.pop().auth_event_ids():
+                chain.add(auth_id)
+                auth_event = event_map.get(auth_id)
+                if auth_event:
+                    to_search.append(auth_event)
+
+    # We now a) calculate the auth chain difference for the unpersisted events
+    # and b) work out the state sets to pass to the store.
+    #
+    # Note: If the `event_map` is empty (which is the common case), we can do a
+    # much simpler calculation.
+    if event_map:
+        # The list of state sets to pass to the store, where each state set is a set
+        # of the event ids making up the state. This is similar to `state_sets`,
+        # except that (a) we only have event ids, not the complete
+        # ((type, state_key)->event_id) mappings; and (b) we have stripped out
+        # unpersisted events and replaced them with the persisted events in
+        # their auth chain.
+        state_sets_ids = []  # type: List[Set[str]]
+
+        # For each state set, the unpersisted event IDs reachable (by their auth
+        # chain) from the events in that set.
+        unpersisted_set_ids = []  # type: List[Set[str]]
+
+        for state_set in state_sets:
+            set_ids = set()  # type: Set[str]
+            state_sets_ids.append(set_ids)
+
+            unpersisted_ids = set()  # type: Set[str]
+            unpersisted_set_ids.append(unpersisted_ids)
+
+            for event_id in state_set.values():
+                event_chain = events_to_auth_chain.get(event_id)
+                if event_chain is not None:
+                    # We have an event in `event_map`. We add all the auth
+                    # events that it references (that aren't also in `event_map`).
+                    set_ids.update(e for e in event_chain if e not in event_map)
+
+                    # We also add the full chain of unpersisted event IDs
+                    # referenced by this state set, so that we can work out the
+                    # auth chain difference of the unpersisted events.
+                    unpersisted_ids.update(e for e in event_chain if e in event_map)
+                else:
+                    set_ids.add(event_id)
+
+        # The auth chain difference of the unpersisted events of the state sets
+        # is calculated by taking the difference between the union and
+        # intersections.
+        union = unpersisted_set_ids[0].union(*unpersisted_set_ids[1:])
+        intersection = unpersisted_set_ids[0].intersection(*unpersisted_set_ids[1:])
+
+        difference_from_event_map = union - intersection  # type: Collection[str]
+    else:
+        difference_from_event_map = ()
+        state_sets_ids = [set(state_set.values()) for state_set in state_sets]
+
     difference = await state_res_store.get_auth_chain_difference(
-        [set(state_set.values()) for state_set in state_sets]
+        room_id, state_sets_ids
     )
+    difference.update(difference_from_event_map)
 
     return difference
 
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 2e07c37340..ebffd89251 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -137,7 +137,9 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
 
         return list(results)
 
-    async def get_auth_chain_difference(self, state_sets: List[Set[str]]) -> Set[str]:
+    async def get_auth_chain_difference(
+        self, room_id: str, state_sets: List[Set[str]]
+    ) -> Set[str]:
         """Given sets of state events figure out the auth chain difference (as
         per state res v2 algorithm).
 
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index fedb8a6c26..ff96c34c2e 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -463,6 +463,23 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             desc="get_user_by_external_id",
         )
 
+    async def get_external_ids_by_user(self, mxid: str) -> List[Tuple[str, str]]:
+        """Look up external ids for the given user
+
+        Args:
+            mxid: the MXID to be looked up
+
+        Returns:
+            Tuples of (auth_provider, external_id)
+        """
+        res = await self.db_pool.simple_select_list(
+            table="user_external_ids",
+            keyvalues={"user_id": mxid},
+            retcols=("auth_provider", "external_id"),
+            desc="get_external_ids_by_user",
+        )
+        return [(r["auth_provider"], r["external_id"]) for r in res]
+
     async def count_all_users(self):
         """Counts all users registered on the homeserver."""
 
@@ -963,6 +980,14 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
             "users_set_deactivated_flag", self._background_update_set_deactivated_flag
         )
 
+        self.db_pool.updates.register_background_index_update(
+            "user_external_ids_user_id_idx",
+            index_name="user_external_ids_user_id_idx",
+            table="user_external_ids",
+            columns=["user_id"],
+            unique=False,
+        )
+
     async def _background_update_set_deactivated_flag(self, progress, batch_size):
         """Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1
         for each of them.
diff --git a/synapse/storage/databases/main/schema/delta/58/25user_external_ids_user_id_idx.sql b/synapse/storage/databases/main/schema/delta/58/25user_external_ids_user_id_idx.sql
new file mode 100644
index 0000000000..8f5e65aa71
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/25user_external_ids_user_id_idx.sql
@@ -0,0 +1,17 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+  (5825, 'user_external_ids_user_id_idx', '{}');
diff --git a/synapse/util/module_loader.py b/synapse/util/module_loader.py
index 94b59afb38..1ee61851e4 100644
--- a/synapse/util/module_loader.py
+++ b/synapse/util/module_loader.py
@@ -15,28 +15,56 @@
 
 import importlib
 import importlib.util
+import itertools
+from typing import Any, Iterable, Tuple, Type
+
+import jsonschema
 
 from synapse.config._base import ConfigError
+from synapse.config._util import json_error_to_config_error
 
 
-def load_module(provider):
+def load_module(provider: dict, config_path: Iterable[str]) -> Tuple[Type, Any]:
     """ Loads a synapse module with its config
-    Take a dict with keys 'module' (the module name) and 'config'
-    (the config dict).
+
+    Args:
+        provider: a dict with keys 'module' (the module name) and 'config'
+           (the config dict).
+        config_path: the path within the config file. This will be used as a basis
+           for any error message.
 
     Returns
         Tuple of (provider class, parsed config object)
     """
+
+    modulename = provider.get("module")
+    if not isinstance(modulename, str):
+        raise ConfigError(
+            "expected a string", path=itertools.chain(config_path, ("module",))
+        )
+
     # We need to import the module, and then pick the class out of
     # that, so we split based on the last dot.
-    module, clz = provider["module"].rsplit(".", 1)
+    module, clz = modulename.rsplit(".", 1)
     module = importlib.import_module(module)
     provider_class = getattr(module, clz)
 
+    module_config = provider.get("config")
     try:
-        provider_config = provider_class.parse_config(provider.get("config"))
+        provider_config = provider_class.parse_config(module_config)
+    except jsonschema.ValidationError as e:
+        raise json_error_to_config_error(e, itertools.chain(config_path, ("config",)))
+    except ConfigError as e:
+        raise _wrap_config_error(
+            "Failed to parse config for module %r" % (modulename,),
+            prefix=itertools.chain(config_path, ("config",)),
+            e=e,
+        )
     except Exception as e:
-        raise ConfigError("Failed to parse config for %r: %s" % (provider["module"], e))
+        raise ConfigError(
+            "Failed to parse config for module %r" % (modulename,),
+            path=itertools.chain(config_path, ("config",)),
+        ) from e
 
     return provider_class, provider_config
 
@@ -56,3 +84,27 @@ def load_python_module(location: str):
     mod = importlib.util.module_from_spec(spec)
     spec.loader.exec_module(mod)  # type: ignore
     return mod
+
+
+def _wrap_config_error(
+    msg: str, prefix: Iterable[str], e: ConfigError
+) -> "ConfigError":
+    """Wrap a relative ConfigError with a new path
+
+    This is useful when we have a ConfigError with a relative path due to a problem
+    parsing part of the config, and we now need to set it in context.
+    """
+    path = prefix
+    if e.path:
+        path = itertools.chain(prefix, e.path)
+
+    e1 = ConfigError(msg, path)
+
+    # ideally we would set the 'cause' of the new exception to the original exception;
+    # however now that we have merged the path into our own, the stringification of
+    # e will be incorrect, so instead we create a new exception with just the "msg"
+    # part.
+
+    e1.__cause__ = Exception(e.msg)
+    e1.__cause__.__cause__ = e.__cause__
+    return e1
diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py
index c98ae75974..279c94a03d 100644
--- a/tests/api/test_filtering.py
+++ b/tests/api/test_filtering.py
@@ -16,8 +16,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from mock import Mock
-
 import jsonschema
 
 from twisted.internet import defer
@@ -28,7 +26,7 @@ from synapse.api.filtering import Filter
 from synapse.events import make_event_from_dict
 
 from tests import unittest
-from tests.utils import DeferredMockCallable, MockHttpResource, setup_test_homeserver
+from tests.utils import setup_test_homeserver
 
 user_localpart = "test_user"
 
@@ -42,19 +40,9 @@ def MockEvent(**kwargs):
 
 
 class FilteringTestCase(unittest.TestCase):
-    @defer.inlineCallbacks
     def setUp(self):
-        self.mock_federation_resource = MockHttpResource()
-
-        self.mock_http_client = Mock(spec=[])
-        self.mock_http_client.put_json = DeferredMockCallable()
-
-        hs = yield setup_test_homeserver(
-            self.addCleanup, http_client=self.mock_http_client, keyring=Mock(),
-        )
-
+        hs = setup_test_homeserver(self.addCleanup)
         self.filtering = hs.get_filtering()
-
         self.datastore = hs.get_datastore()
 
     def test_errors_on_invalid_filters(self):
diff --git a/tests/app/test_frontend_proxy.py b/tests/app/test_frontend_proxy.py
index 40abe9d72d..43fef5d64a 100644
--- a/tests/app/test_frontend_proxy.py
+++ b/tests/app/test_frontend_proxy.py
@@ -23,7 +23,7 @@ class FrontendProxyTests(HomeserverTestCase):
     def make_homeserver(self, reactor, clock):
 
         hs = self.setup_test_homeserver(
-            http_client=None, homeserver_to_use=GenericWorkerServer
+            federation_http_client=None, homeserver_to_use=GenericWorkerServer
         )
 
         return hs
diff --git a/tests/app/test_openid_listener.py b/tests/app/test_openid_listener.py
index ea3be95cf1..b260ab734d 100644
--- a/tests/app/test_openid_listener.py
+++ b/tests/app/test_openid_listener.py
@@ -27,7 +27,7 @@ from tests.unittest import HomeserverTestCase
 class FederationReaderOpenIDListenerTests(HomeserverTestCase):
     def make_homeserver(self, reactor, clock):
         hs = self.setup_test_homeserver(
-            http_client=None, homeserver_to_use=GenericWorkerServer
+            federation_http_client=None, homeserver_to_use=GenericWorkerServer
         )
         return hs
 
@@ -84,7 +84,7 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase):
 class SynapseHomeserverOpenIDListenerTests(HomeserverTestCase):
     def make_homeserver(self, reactor, clock):
         hs = self.setup_test_homeserver(
-            http_client=None, homeserver_to_use=SynapseHomeServer
+            federation_http_client=None, homeserver_to_use=SynapseHomeServer
         )
         return hs
 
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 697916a019..d146f2254f 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -315,7 +315,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
 class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
     def make_homeserver(self, reactor, clock):
         self.http_client = Mock()
-        hs = self.setup_test_homeserver(http_client=self.http_client)
+        hs = self.setup_test_homeserver(federation_http_client=self.http_client)
         return hs
 
     def test_get_keys_from_server(self):
@@ -395,7 +395,9 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
             }
         ]
 
-        return self.setup_test_homeserver(http_client=self.http_client, config=config)
+        return self.setup_test_homeserver(
+            federation_http_client=self.http_client, config=config
+        )
 
     def build_perspectives_response(
         self, server_name: str, signing_key: SigningKey, valid_until_ts: int,
diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index 875aaec2c6..5dfeccfeb6 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -27,7 +27,7 @@ user2 = "@theresa:bbb"
 
 class DeviceTestCase(unittest.HomeserverTestCase):
     def make_homeserver(self, reactor, clock):
-        hs = self.setup_test_homeserver("server", http_client=None)
+        hs = self.setup_test_homeserver("server", federation_http_client=None)
         self.handler = hs.get_device_handler()
         self.store = hs.get_datastore()
         return hs
@@ -229,7 +229,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
 
 class DehydrationTestCase(unittest.HomeserverTestCase):
     def make_homeserver(self, reactor, clock):
-        hs = self.setup_test_homeserver("server", http_client=None)
+        hs = self.setup_test_homeserver("server", federation_http_client=None)
         self.handler = hs.get_device_handler()
         self.registration = hs.get_registration_handler()
         self.auth = hs.get_auth()
diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index ee6ef5e6fa..770d225ed5 100644
--- a/tests/handlers/test_directory.py
+++ b/tests/handlers/test_directory.py
@@ -42,8 +42,6 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
         self.mock_registry.register_query_handler = register_query_handler
 
         hs = self.setup_test_homeserver(
-            http_client=None,
-            resource_for_federation=Mock(),
             federation_client=self.mock_federation,
             federation_registry=self.mock_registry,
         )
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index bf866dacf3..d0452e1490 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -37,7 +37,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
     ]
 
     def make_homeserver(self, reactor, clock):
-        hs = self.setup_test_homeserver(http_client=None)
+        hs = self.setup_test_homeserver(federation_http_client=None)
         self.handler = hs.get_federation_handler()
         self.store = hs.get_datastore()
         return hs
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index a308c46da9..1d99a45436 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -17,30 +17,15 @@ from urllib.parse import parse_qs, urlparse
 
 from mock import Mock, patch
 
-import attr
 import pymacaroons
 
-from twisted.python.failure import Failure
-from twisted.web._newclient import ResponseDone
-
 from synapse.handlers.oidc_handler import OidcError, OidcMappingProvider
 from synapse.handlers.sso import MappingException
 from synapse.types import UserID
 
+from tests.test_utils import FakeResponse
 from tests.unittest import HomeserverTestCase, override_config
 
-
-@attr.s
-class FakeResponse:
-    code = attr.ib()
-    body = attr.ib()
-    phrase = attr.ib()
-
-    def deliverBody(self, protocol):
-        protocol.dataReceived(self.body)
-        protocol.connectionLost(Failure(ResponseDone()))
-
-
 # These are a few constants that are used as config parameters in the tests.
 ISSUER = "https://issuer/"
 CLIENT_ID = "test-client-id"
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index 8ed67640f8..0794b32c9c 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -463,7 +463,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
 
     def make_homeserver(self, reactor, clock):
         hs = self.setup_test_homeserver(
-            "server", http_client=None, federation_sender=Mock()
+            "server", federation_http_client=None, federation_sender=Mock()
         )
         return hs
 
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index a69fa28b41..919547556b 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -44,8 +44,6 @@ class ProfileTestCase(unittest.TestCase):
 
         hs = yield setup_test_homeserver(
             self.addCleanup,
-            http_client=None,
-            resource_for_federation=Mock(),
             federation_client=self.mock_federation,
             federation_server=Mock(),
             federation_registry=self.mock_registry,
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index abbdf2d524..f21de958f1 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -15,18 +15,20 @@
 
 
 import json
+from typing import Dict
 
 from mock import ANY, Mock, call
 
 from twisted.internet import defer
+from twisted.web.resource import Resource
 
 from synapse.api.errors import AuthError
+from synapse.federation.transport.server import TransportLayerServer
 from synapse.types import UserID, create_requester
 
 from tests import unittest
 from tests.test_utils import make_awaitable
 from tests.unittest import override_config
-from tests.utils import register_federation_servlets
 
 # Some local users to test with
 U_APPLE = UserID.from_string("@apple:test")
@@ -53,8 +55,6 @@ def _make_edu_transaction_json(edu_type, content):
 
 
 class TypingNotificationsTestCase(unittest.HomeserverTestCase):
-    servlets = [register_federation_servlets]
-
     def make_homeserver(self, reactor, clock):
         # we mock out the keyring so as to skip the authentication check on the
         # federation API call.
@@ -70,13 +70,18 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
 
         hs = self.setup_test_homeserver(
             notifier=Mock(),
-            http_client=mock_federation_client,
+            federation_http_client=mock_federation_client,
             keyring=mock_keyring,
             replication_streams={},
         )
 
         return hs
 
+    def create_resource_dict(self) -> Dict[str, Resource]:
+        d = super().create_resource_dict()
+        d["/_matrix/federation"] = TransportLayerServer(self.hs)
+        return d
+
     def prepare(self, reactor, clock, hs):
         mock_notifier = hs.get_notifier()
         self.on_new_event = mock_notifier.on_new_event
@@ -192,7 +197,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
             )
         )
 
-        put_json = self.hs.get_http_client().put_json
+        put_json = self.hs.get_federation_http_client().put_json
         put_json.assert_called_once_with(
             "farm",
             path="/_matrix/federation/v1/send/1000000",
@@ -270,7 +275,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
 
         self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])])
 
-        put_json = self.hs.get_http_client().put_json
+        put_json = self.hs.get_federation_http_client().put_json
         put_json.assert_called_once_with(
             "farm",
             path="/_matrix/federation/v1/send/1000000",
diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
index 8b5ad4574f..626acdcaa3 100644
--- a/tests/http/federation/test_matrix_federation_agent.py
+++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -17,6 +17,7 @@ import logging
 from mock import Mock
 
 import treq
+from netaddr import IPSet
 from service_identity import VerificationError
 from zope.interface import implementer
 
@@ -103,6 +104,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
             reactor=self.reactor,
             tls_client_options_factory=self.tls_factory,
             user_agent="test-agent",  # Note that this is unused since _well_known_resolver is provided.
+            ip_blacklist=IPSet(),
             _srv_resolver=self.mock_resolver,
             _well_known_resolver=self.well_known_resolver,
         )
@@ -736,6 +738,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
             reactor=self.reactor,
             tls_client_options_factory=tls_factory,
             user_agent=b"test-agent",  # This is unused since _well_known_resolver is passed below.
+            ip_blacklist=IPSet(),
             _srv_resolver=self.mock_resolver,
             _well_known_resolver=WellKnownResolver(
                 self.reactor,
diff --git a/tests/push/test_http.py b/tests/push/test_http.py
index f118430309..8b4af74c51 100644
--- a/tests/push/test_http.py
+++ b/tests/push/test_http.py
@@ -18,6 +18,7 @@ from twisted.internet.defer import Deferred
 
 import synapse.rest.admin
 from synapse.logging.context import make_deferred_yieldable
+from synapse.push import PusherConfigException
 from synapse.rest.client.v1 import login, room
 from synapse.rest.client.v2_alpha import receipts
 
@@ -34,6 +35,11 @@ class HTTPPusherTests(HomeserverTestCase):
     user_id = True
     hijack_auth = False
 
+    def default_config(self):
+        config = super().default_config()
+        config["start_pushers"] = True
+        return config
+
     def make_homeserver(self, reactor, clock):
         self.push_attempts = []
 
@@ -46,13 +52,49 @@ class HTTPPusherTests(HomeserverTestCase):
 
         m.post_json_get_json = post_json_get_json
 
-        config = self.default_config()
-        config["start_pushers"] = True
-
-        hs = self.setup_test_homeserver(config=config, proxied_http_client=m)
+        hs = self.setup_test_homeserver(proxied_blacklisted_http_client=m)
 
         return hs
 
+    def test_invalid_configuration(self):
+        """Invalid push configurations should be rejected."""
+        # Register the user who gets notified
+        user_id = self.register_user("user", "pass")
+        access_token = self.login("user", "pass")
+
+        # Register the pusher
+        user_tuple = self.get_success(
+            self.hs.get_datastore().get_user_by_access_token(access_token)
+        )
+        token_id = user_tuple.token_id
+
+        def test_data(data):
+            self.get_failure(
+                self.hs.get_pusherpool().add_pusher(
+                    user_id=user_id,
+                    access_token=token_id,
+                    kind="http",
+                    app_id="m.http",
+                    app_display_name="HTTP Push Notifications",
+                    device_display_name="pushy push",
+                    pushkey="a@example.com",
+                    lang=None,
+                    data=data,
+                ),
+                PusherConfigException,
+            )
+
+        # Data must be provided with a URL.
+        test_data(None)
+        test_data({})
+        test_data({"url": 1})
+        # A bare domain name isn't accepted.
+        test_data({"url": "example.com"})
+        # A URL without a path isn't accepted.
+        test_data({"url": "http://example.com"})
+        # A url with an incorrect path isn't accepted.
+        test_data({"url": "http://example.com/foo"})
+
     def test_sends_http(self):
         """
         The HTTP pusher will send pushes for each message to a HTTP endpoint
@@ -82,7 +124,7 @@ class HTTPPusherTests(HomeserverTestCase):
                 device_display_name="pushy push",
                 pushkey="a@example.com",
                 lang=None,
-                data={"url": "example.com"},
+                data={"url": "http://example.com/_matrix/push/v1/notify"},
             )
         )
 
@@ -117,7 +159,9 @@ class HTTPPusherTests(HomeserverTestCase):
 
         # One push was attempted to be sent -- it'll be the first message
         self.assertEqual(len(self.push_attempts), 1)
-        self.assertEqual(self.push_attempts[0][1], "example.com")
+        self.assertEqual(
+            self.push_attempts[0][1], "http://example.com/_matrix/push/v1/notify"
+        )
         self.assertEqual(
             self.push_attempts[0][2]["notification"]["content"]["body"], "Hi!"
         )
@@ -137,7 +181,9 @@ class HTTPPusherTests(HomeserverTestCase):
 
         # Now it'll try and send the second push message, which will be the second one
         self.assertEqual(len(self.push_attempts), 2)
-        self.assertEqual(self.push_attempts[1][1], "example.com")
+        self.assertEqual(
+            self.push_attempts[1][1], "http://example.com/_matrix/push/v1/notify"
+        )
         self.assertEqual(
             self.push_attempts[1][2]["notification"]["content"]["body"], "There!"
         )
@@ -194,7 +240,7 @@ class HTTPPusherTests(HomeserverTestCase):
                 device_display_name="pushy push",
                 pushkey="a@example.com",
                 lang=None,
-                data={"url": "example.com"},
+                data={"url": "http://example.com/_matrix/push/v1/notify"},
             )
         )
 
@@ -230,7 +276,9 @@ class HTTPPusherTests(HomeserverTestCase):
 
         # Check our push made it with high priority
         self.assertEqual(len(self.push_attempts), 1)
-        self.assertEqual(self.push_attempts[0][1], "example.com")
+        self.assertEqual(
+            self.push_attempts[0][1], "http://example.com/_matrix/push/v1/notify"
+        )
         self.assertEqual(self.push_attempts[0][2]["notification"]["prio"], "high")
 
         # Add yet another person — we want to make this room not a 1:1
@@ -268,7 +316,9 @@ class HTTPPusherTests(HomeserverTestCase):
         # Advance time a bit, so the pusher will register something has happened
         self.pump()
         self.assertEqual(len(self.push_attempts), 2)
-        self.assertEqual(self.push_attempts[1][1], "example.com")
+        self.assertEqual(
+            self.push_attempts[1][1], "http://example.com/_matrix/push/v1/notify"
+        )
         self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "high")
 
     def test_sends_high_priority_for_one_to_one_only(self):
@@ -310,7 +360,7 @@ class HTTPPusherTests(HomeserverTestCase):
                 device_display_name="pushy push",
                 pushkey="a@example.com",
                 lang=None,
-                data={"url": "example.com"},
+                data={"url": "http://example.com/_matrix/push/v1/notify"},
             )
         )
 
@@ -326,7 +376,9 @@ class HTTPPusherTests(HomeserverTestCase):
 
         # Check our push made it with high priority — this is a one-to-one room
         self.assertEqual(len(self.push_attempts), 1)
-        self.assertEqual(self.push_attempts[0][1], "example.com")
+        self.assertEqual(
+            self.push_attempts[0][1], "http://example.com/_matrix/push/v1/notify"
+        )
         self.assertEqual(self.push_attempts[0][2]["notification"]["prio"], "high")
 
         # Yet another user joins
@@ -345,7 +397,9 @@ class HTTPPusherTests(HomeserverTestCase):
         # Advance time a bit, so the pusher will register something has happened
         self.pump()
         self.assertEqual(len(self.push_attempts), 2)
-        self.assertEqual(self.push_attempts[1][1], "example.com")
+        self.assertEqual(
+            self.push_attempts[1][1], "http://example.com/_matrix/push/v1/notify"
+        )
 
         # check that this is low-priority
         self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low")
@@ -392,7 +446,7 @@ class HTTPPusherTests(HomeserverTestCase):
                 device_display_name="pushy push",
                 pushkey="a@example.com",
                 lang=None,
-                data={"url": "example.com"},
+                data={"url": "http://example.com/_matrix/push/v1/notify"},
             )
         )
 
@@ -408,7 +462,9 @@ class HTTPPusherTests(HomeserverTestCase):
 
         # Check our push made it with high priority
         self.assertEqual(len(self.push_attempts), 1)
-        self.assertEqual(self.push_attempts[0][1], "example.com")
+        self.assertEqual(
+            self.push_attempts[0][1], "http://example.com/_matrix/push/v1/notify"
+        )
         self.assertEqual(self.push_attempts[0][2]["notification"]["prio"], "high")
 
         # Send another event, this time with no mention
@@ -417,7 +473,9 @@ class HTTPPusherTests(HomeserverTestCase):
         # Advance time a bit, so the pusher will register something has happened
         self.pump()
         self.assertEqual(len(self.push_attempts), 2)
-        self.assertEqual(self.push_attempts[1][1], "example.com")
+        self.assertEqual(
+            self.push_attempts[1][1], "http://example.com/_matrix/push/v1/notify"
+        )
 
         # check that this is low-priority
         self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low")
@@ -465,7 +523,7 @@ class HTTPPusherTests(HomeserverTestCase):
                 device_display_name="pushy push",
                 pushkey="a@example.com",
                 lang=None,
-                data={"url": "example.com"},
+                data={"url": "http://example.com/_matrix/push/v1/notify"},
             )
         )
 
@@ -485,7 +543,9 @@ class HTTPPusherTests(HomeserverTestCase):
 
         # Check our push made it with high priority
         self.assertEqual(len(self.push_attempts), 1)
-        self.assertEqual(self.push_attempts[0][1], "example.com")
+        self.assertEqual(
+            self.push_attempts[0][1], "http://example.com/_matrix/push/v1/notify"
+        )
         self.assertEqual(self.push_attempts[0][2]["notification"]["prio"], "high")
 
         # Send another event, this time as someone without the power of @room
@@ -496,7 +556,9 @@ class HTTPPusherTests(HomeserverTestCase):
         # Advance time a bit, so the pusher will register something has happened
         self.pump()
         self.assertEqual(len(self.push_attempts), 2)
-        self.assertEqual(self.push_attempts[1][1], "example.com")
+        self.assertEqual(
+            self.push_attempts[1][1], "http://example.com/_matrix/push/v1/notify"
+        )
 
         # check that this is low-priority
         self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low")
@@ -570,7 +632,7 @@ class HTTPPusherTests(HomeserverTestCase):
                 device_display_name="pushy push",
                 pushkey="a@example.com",
                 lang=None,
-                data={"url": "example.com"},
+                data={"url": "http://example.com/_matrix/push/v1/notify"},
             )
         )
 
@@ -589,7 +651,9 @@ class HTTPPusherTests(HomeserverTestCase):
 
         # Check our push made it
         self.assertEqual(len(self.push_attempts), 1)
-        self.assertEqual(self.push_attempts[0][1], "example.com")
+        self.assertEqual(
+            self.push_attempts[0][1], "http://example.com/_matrix/push/v1/notify"
+        )
 
         # Check that the unread count for the room is 0
         #
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index 295c5d58a6..3379189785 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import Any, Callable, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple
 
 import attr
 
@@ -21,6 +21,7 @@ from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
 from twisted.internet.protocol import Protocol
 from twisted.internet.task import LoopingCall
 from twisted.web.http import HTTPChannel
+from twisted.web.resource import Resource
 
 from synapse.app.generic_worker import (
     GenericWorkerReplicationHandler,
@@ -28,7 +29,7 @@ from synapse.app.generic_worker import (
 )
 from synapse.http.server import JsonResource
 from synapse.http.site import SynapseRequest, SynapseSite
-from synapse.replication.http import ReplicationRestResource, streams
+from synapse.replication.http import ReplicationRestResource
 from synapse.replication.tcp.handler import ReplicationCommandHandler
 from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
 from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
@@ -54,10 +55,6 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
     if not hiredis:
         skip = "Requires hiredis"
 
-    servlets = [
-        streams.register_servlets,
-    ]
-
     def prepare(self, reactor, clock, hs):
         # build a replication server
         server_factory = ReplicationStreamProtocolFactory(hs)
@@ -67,7 +64,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
         # Make a new HomeServer object for the worker
         self.reactor.lookups["testserv"] = "1.2.3.4"
         self.worker_hs = self.setup_test_homeserver(
-            http_client=None,
+            federation_http_client=None,
             homeserver_to_use=GenericWorkerServer,
             config=self._get_worker_hs_config(),
             reactor=self.reactor,
@@ -88,6 +85,11 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
         self._client_transport = None
         self._server_transport = None
 
+    def create_resource_dict(self) -> Dict[str, Resource]:
+        d = super().create_resource_dict()
+        d["/_synapse/replication"] = ReplicationRestResource(self.hs)
+        return d
+
     def _get_worker_hs_config(self) -> dict:
         config = self.default_config()
         config["worker_app"] = "synapse.app.generic_worker"
@@ -264,7 +266,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
             worker_app: Type of worker, e.g. `synapse.app.federation_sender`.
             extra_config: Any extra config to use for this instances.
             **kwargs: Options that get passed to `self.setup_test_homeserver`,
-                useful to e.g. pass some mocks for things like `http_client`
+                useful to e.g. pass some mocks for things like `federation_http_client`
 
         Returns:
             The new worker HomeServer instance.
diff --git a/tests/replication/test_auth.py b/tests/replication/test_auth.py
new file mode 100644
index 0000000000..fe9e4d5f9a
--- /dev/null
+++ b/tests/replication/test_auth.py
@@ -0,0 +1,119 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from typing import Tuple
+
+from synapse.http.site import SynapseRequest
+from synapse.rest.client.v2_alpha import register
+
+from tests.replication._base import BaseMultiWorkerStreamTestCase
+from tests.server import FakeChannel, make_request
+from tests.unittest import override_config
+
+logger = logging.getLogger(__name__)
+
+
+class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase):
+    """Test the authentication of HTTP calls between workers."""
+
+    servlets = [register.register_servlets]
+
+    def make_homeserver(self, reactor, clock):
+        config = self.default_config()
+        # This isn't a real configuration option but is used to provide the main
+        # homeserver and worker homeserver different options.
+        main_replication_secret = config.pop("main_replication_secret", None)
+        if main_replication_secret:
+            config["worker_replication_secret"] = main_replication_secret
+        return self.setup_test_homeserver(config=config)
+
+    def _get_worker_hs_config(self) -> dict:
+        config = self.default_config()
+        config["worker_app"] = "synapse.app.client_reader"
+        config["worker_replication_host"] = "testserv"
+        config["worker_replication_http_port"] = "8765"
+
+        return config
+
+    def _test_register(self) -> Tuple[SynapseRequest, FakeChannel]:
+        """Run the actual test:
+
+        1. Create a worker homeserver.
+        2. Start registration by providing a user/password.
+        3. Complete registration by providing dummy auth (this hits the main synapse).
+        4. Return the final request.
+
+        """
+        worker_hs = self.make_worker_hs("synapse.app.client_reader")
+        site = self._hs_to_site[worker_hs]
+
+        request_1, channel_1 = make_request(
+            self.reactor,
+            site,
+            "POST",
+            "register",
+            {"username": "user", "type": "m.login.password", "password": "bar"},
+        )  # type: SynapseRequest, FakeChannel
+        self.assertEqual(request_1.code, 401)
+
+        # Grab the session
+        session = channel_1.json_body["session"]
+
+        # also complete the dummy auth
+        return make_request(
+            self.reactor,
+            site,
+            "POST",
+            "register",
+            {"auth": {"session": session, "type": "m.login.dummy"}},
+        )
+
+    def test_no_auth(self):
+        """With no authentication the request should finish.
+        """
+        request, channel = self._test_register()
+        self.assertEqual(request.code, 200)
+
+        # We're given a registered user.
+        self.assertEqual(channel.json_body["user_id"], "@user:test")
+
+    @override_config({"main_replication_secret": "my-secret"})
+    def test_missing_auth(self):
+        """If the main process expects a secret that is not provided, an error results.
+        """
+        request, channel = self._test_register()
+        self.assertEqual(request.code, 500)
+
+    @override_config(
+        {
+            "main_replication_secret": "my-secret",
+            "worker_replication_secret": "wrong-secret",
+        }
+    )
+    def test_unauthorized(self):
+        """If the main process receives the wrong secret, an error results.
+        """
+        request, channel = self._test_register()
+        self.assertEqual(request.code, 500)
+
+    @override_config({"worker_replication_secret": "my-secret"})
+    def test_authorized(self):
+        """The request should finish when the worker provides the authentication header.
+        """
+        request, channel = self._test_register()
+        self.assertEqual(request.code, 200)
+
+        # We're given a registered user.
+        self.assertEqual(channel.json_body["user_id"], "@user:test")
diff --git a/tests/replication/test_client_reader_shard.py b/tests/replication/test_client_reader_shard.py
index 96801db473..fdaad3d8ad 100644
--- a/tests/replication/test_client_reader_shard.py
+++ b/tests/replication/test_client_reader_shard.py
@@ -14,27 +14,20 @@
 # limitations under the License.
 import logging
 
-from synapse.api.constants import LoginType
 from synapse.http.site import SynapseRequest
 from synapse.rest.client.v2_alpha import register
 
 from tests.replication._base import BaseMultiWorkerStreamTestCase
-from tests.rest.client.v2_alpha.test_auth import DummyRecaptchaChecker
 from tests.server import FakeChannel, make_request
 
 logger = logging.getLogger(__name__)
 
 
 class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
-    """Base class for tests of the replication streams"""
+    """Test using one or more client readers for registration."""
 
     servlets = [register.register_servlets]
 
-    def prepare(self, reactor, clock, hs):
-        self.recaptcha_checker = DummyRecaptchaChecker(hs)
-        auth_handler = hs.get_auth_handler()
-        auth_handler.checkers[LoginType.RECAPTCHA] = self.recaptcha_checker
-
     def _get_worker_hs_config(self) -> dict:
         config = self.default_config()
         config["worker_app"] = "synapse.app.client_reader"
diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py
index 779745ae9d..fffdb742c8 100644
--- a/tests/replication/test_federation_sender_shard.py
+++ b/tests/replication/test_federation_sender_shard.py
@@ -50,7 +50,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
         self.make_worker_hs(
             "synapse.app.federation_sender",
             {"send_federation": True},
-            http_client=mock_client,
+            federation_http_client=mock_client,
         )
 
         user = self.register_user("user", "pass")
@@ -81,7 +81,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
                 "worker_name": "sender1",
                 "federation_sender_instances": ["sender1", "sender2"],
             },
-            http_client=mock_client1,
+            federation_http_client=mock_client1,
         )
 
         mock_client2 = Mock(spec=["put_json"])
@@ -93,7 +93,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
                 "worker_name": "sender2",
                 "federation_sender_instances": ["sender1", "sender2"],
             },
-            http_client=mock_client2,
+            federation_http_client=mock_client2,
         )
 
         user = self.register_user("user2", "pass")
@@ -144,7 +144,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
                 "worker_name": "sender1",
                 "federation_sender_instances": ["sender1", "sender2"],
             },
-            http_client=mock_client1,
+            federation_http_client=mock_client1,
         )
 
         mock_client2 = Mock(spec=["put_json"])
@@ -156,7 +156,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
                 "worker_name": "sender2",
                 "federation_sender_instances": ["sender1", "sender2"],
             },
-            http_client=mock_client2,
+            federation_http_client=mock_client2,
         )
 
         user = self.register_user("user3", "pass")
diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py
index 67c27a089f..800ad94a04 100644
--- a/tests/replication/test_pusher_shard.py
+++ b/tests/replication/test_pusher_shard.py
@@ -67,7 +67,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
                 device_display_name="pushy push",
                 pushkey="a@example.com",
                 lang=None,
-                data={"url": "https://push.example.com/push"},
+                data={"url": "https://push.example.com/_matrix/push/v1/notify"},
             )
         )
 
@@ -98,7 +98,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
         self.make_worker_hs(
             "synapse.app.pusher",
             {"start_pushers": True},
-            proxied_http_client=http_client_mock,
+            proxied_blacklisted_http_client=http_client_mock,
         )
 
         event_id = self._create_pusher_and_send_msg("user")
@@ -109,7 +109,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
         http_client_mock.post_json_get_json.assert_called_once()
         self.assertEqual(
             http_client_mock.post_json_get_json.call_args[0][0],
-            "https://push.example.com/push",
+            "https://push.example.com/_matrix/push/v1/notify",
         )
         self.assertEqual(
             event_id,
@@ -133,7 +133,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
                 "worker_name": "pusher1",
                 "pusher_instances": ["pusher1", "pusher2"],
             },
-            proxied_http_client=http_client_mock1,
+            proxied_blacklisted_http_client=http_client_mock1,
         )
 
         http_client_mock2 = Mock(spec_set=["post_json_get_json"])
@@ -148,7 +148,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
                 "worker_name": "pusher2",
                 "pusher_instances": ["pusher1", "pusher2"],
             },
-            proxied_http_client=http_client_mock2,
+            proxied_blacklisted_http_client=http_client_mock2,
         )
 
         # We choose a user name that we know should go to pusher1.
@@ -161,7 +161,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
         http_client_mock2.post_json_get_json.assert_not_called()
         self.assertEqual(
             http_client_mock1.post_json_get_json.call_args[0][0],
-            "https://push.example.com/push",
+            "https://push.example.com/_matrix/push/v1/notify",
         )
         self.assertEqual(
             event_id,
@@ -183,7 +183,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
         http_client_mock2.post_json_get_json.assert_called_once()
         self.assertEqual(
             http_client_mock2.post_json_get_json.call_args[0][0],
-            "https://push.example.com/push",
+            "https://push.example.com/_matrix/push/v1/notify",
         )
         self.assertEqual(
             event_id,
diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index 4f76f8f768..67d8878395 100644
--- a/tests/rest/admin/test_admin.py
+++ b/tests/rest/admin/test_admin.py
@@ -210,7 +210,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
         }
         config["media_storage_providers"] = [provider_config]
 
-        hs = self.setup_test_homeserver(config=config, http_client=client)
+        hs = self.setup_test_homeserver(config=config, federation_http_client=client)
 
         return hs
 
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 54d46f4bd3..ba1438cdc7 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -561,7 +561,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
                 "admin": True,
                 "displayname": "Bob's name",
                 "threepids": [{"medium": "email", "address": "bob@bob.bob"}],
-                "avatar_url": None,
+                "avatar_url": "mxc://fibble/wibble",
             }
         )
 
@@ -578,6 +578,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
         self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
         self.assertEqual(True, channel.json_body["admin"])
+        self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
 
         # Get user
         request, channel = self.make_request(
@@ -592,6 +593,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual(True, channel.json_body["admin"])
         self.assertEqual(False, channel.json_body["is_guest"])
         self.assertEqual(False, channel.json_body["deactivated"])
+        self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
 
     def test_create_user(self):
         """
@@ -606,6 +608,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
                 "admin": False,
                 "displayname": "Bob's name",
                 "threepids": [{"medium": "email", "address": "bob@bob.bob"}],
+                "avatar_url": "mxc://fibble/wibble",
             }
         )
 
@@ -622,6 +625,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
         self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
         self.assertEqual(False, channel.json_body["admin"])
+        self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
 
         # Get user
         request, channel = self.make_request(
@@ -636,6 +640,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual(False, channel.json_body["admin"])
         self.assertEqual(False, channel.json_body["is_guest"])
         self.assertEqual(False, channel.json_body["deactivated"])
+        self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
 
     @override_config(
         {"limit_usage_by_mau": True, "max_mau_value": 2, "mau_trial_days": 0}
@@ -1256,7 +1261,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
                 device_display_name="pushy push",
                 pushkey="a@example.com",
                 lang=None,
-                data={"url": "example.com"},
+                data={"url": "https://example.com/_matrix/push/v1/notify"},
             )
         )
 
diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py
index 5d5c24d01c..11cd8efe21 100644
--- a/tests/rest/client/v1/test_presence.py
+++ b/tests/rest/client/v1/test_presence.py
@@ -38,7 +38,7 @@ class PresenceTestCase(unittest.HomeserverTestCase):
 
         hs = self.setup_test_homeserver(
             "red",
-            http_client=None,
+            federation_http_client=None,
             federation_client=Mock(),
             presence_handler=presence_handler,
         )
diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py
index 383a9eafac..2a3b483eaf 100644
--- a/tests/rest/client/v1/test_profile.py
+++ b/tests/rest/client/v1/test_profile.py
@@ -63,7 +63,7 @@ class MockHandlerProfileTestCase(unittest.TestCase):
         hs = yield setup_test_homeserver(
             self.addCleanup,
             "test",
-            http_client=None,
+            federation_http_client=None,
             resource_for_client=self.mock_resource,
             federation=Mock(),
             federation_client=Mock(),
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index 49f1073c88..e67de41c18 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -45,7 +45,7 @@ class RoomBase(unittest.HomeserverTestCase):
     def make_homeserver(self, reactor, clock):
 
         self.hs = self.setup_test_homeserver(
-            "red", http_client=None, federation_client=Mock(),
+            "red", federation_http_client=None, federation_client=Mock(),
         )
 
         self.hs.get_federation_handler = Mock()
diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py
index bbd30f594b..ae0207366b 100644
--- a/tests/rest/client/v1/test_typing.py
+++ b/tests/rest/client/v1/test_typing.py
@@ -39,7 +39,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
     def make_homeserver(self, reactor, clock):
 
         hs = self.setup_test_homeserver(
-            "red", http_client=None, federation_client=Mock(),
+            "red", federation_http_client=None, federation_client=Mock(),
         )
 
         self.event_source = hs.get_event_sources().sources["typing"]
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index 737c38c396..5a18af8d34 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -2,7 +2,7 @@
 # Copyright 2014-2016 OpenMarket Ltd
 # Copyright 2017 Vector Creations Ltd
 # Copyright 2018-2019 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019-2020 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -17,17 +17,23 @@
 # limitations under the License.
 
 import json
+import re
 import time
+import urllib.parse
 from typing import Any, Dict, Optional
 
+from mock import patch
+
 import attr
 
 from twisted.web.resource import Resource
 from twisted.web.server import Site
 
 from synapse.api.constants import Membership
+from synapse.types import JsonDict
 
 from tests.server import FakeSite, make_request
+from tests.test_utils import FakeResponse
 
 
 @attr.s
@@ -344,3 +350,111 @@ class RestHelper:
         )
 
         return channel.json_body
+
+    def login_via_oidc(self, remote_user_id: str) -> JsonDict:
+        """Log in (as a new user) via OIDC
+
+        Returns the result of the final token login.
+
+        Requires that "oidc_config" in the homeserver config be set appropriately
+        (TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a
+        "public_base_url".
+
+        Also requires the login servlet and the OIDC callback resource to be mounted at
+        the normal places.
+        """
+        client_redirect_url = "https://x"
+
+        # first hit the redirect url (which will issue a cookie and state)
+        _, channel = make_request(
+            self.hs.get_reactor(),
+            self.site,
+            "GET",
+            "/login/sso/redirect?redirectUrl=" + client_redirect_url,
+        )
+        # that will redirect to the OIDC IdP, but we skip that and go straight
+        # back to synapse's OIDC callback resource. However, we do need the "state"
+        # param that synapse passes to the IdP via query params, and the cookie that
+        # synapse passes to the client.
+        assert channel.code == 302
+        oauth_uri = channel.headers.getRawHeaders("Location")[0]
+        params = urllib.parse.parse_qs(urllib.parse.urlparse(oauth_uri).query)
+        redirect_uri = "%s?%s" % (
+            urllib.parse.urlparse(params["redirect_uri"][0]).path,
+            urllib.parse.urlencode({"state": params["state"][0], "code": "TEST_CODE"}),
+        )
+        cookies = {}
+        for h in channel.headers.getRawHeaders("Set-Cookie"):
+            parts = h.split(";")
+            k, v = parts[0].split("=", maxsplit=1)
+            cookies[k] = v
+
+        # before we hit the callback uri, stub out some methods in the http client so
+        # that we don't have to handle full HTTPS requests.
+
+        # (expected url, json response) pairs, in the order we expect them.
+        expected_requests = [
+            # first we get a hit to the token endpoint, which we tell to return
+            # a dummy OIDC access token
+            ("https://issuer.test/token", {"access_token": "TEST"}),
+            # and then one to the user_info endpoint, which returns our remote user id.
+            ("https://issuer.test/userinfo", {"sub": remote_user_id}),
+        ]
+
+        async def mock_req(method: str, uri: str, data=None, headers=None):
+            (expected_uri, resp_obj) = expected_requests.pop(0)
+            assert uri == expected_uri
+            resp = FakeResponse(
+                code=200, phrase=b"OK", body=json.dumps(resp_obj).encode("utf-8"),
+            )
+            return resp
+
+        with patch.object(self.hs.get_proxied_http_client(), "request", mock_req):
+            # now hit the callback URI with the right params and a made-up code
+            _, channel = make_request(
+                self.hs.get_reactor(),
+                self.site,
+                "GET",
+                redirect_uri,
+                custom_headers=[
+                    ("Cookie", "%s=%s" % (k, v)) for (k, v) in cookies.items()
+                ],
+            )
+
+        # expect a confirmation page
+        assert channel.code == 200
+
+        # fish the matrix login token out of the body of the confirmation page
+        m = re.search(
+            'a href="%s.*loginToken=([^"]*)"' % (client_redirect_url,),
+            channel.result["body"].decode("utf-8"),
+        )
+        assert m
+        login_token = m.group(1)
+
+        # finally, submit the matrix login token to the login API, which gives us our
+        # matrix access token and device id.
+        _, channel = make_request(
+            self.hs.get_reactor(),
+            self.site,
+            "POST",
+            "/login",
+            content={"type": "m.login.token", "token": login_token},
+        )
+        assert channel.code == 200
+        return channel.json_body
+
+
+# an 'oidc_config' suitable for login_with_oidc.
+TEST_OIDC_CONFIG = {
+    "enabled": True,
+    "discover": False,
+    "issuer": "https://issuer.test",
+    "client_id": "test-client-id",
+    "client_secret": "test-client-secret",
+    "scopes": ["profile"],
+    "authorization_endpoint": "https://z",
+    "token_endpoint": "https://issuer.test/token",
+    "userinfo_endpoint": "https://issuer.test/userinfo",
+    "user_mapping_provider": {"config": {"localpart_template": "{{ user.sub }}"}},
+}
diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py
index 77246e478f..ac67a9de29 100644
--- a/tests/rest/client/v2_alpha/test_auth.py
+++ b/tests/rest/client/v2_alpha/test_auth.py
@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 from typing import List, Union
 
 from twisted.internet.defer import succeed
@@ -22,9 +23,11 @@ from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
 from synapse.http.site import SynapseRequest
 from synapse.rest.client.v1 import login
 from synapse.rest.client.v2_alpha import auth, devices, register
-from synapse.types import JsonDict
+from synapse.rest.oidc import OIDCResource
+from synapse.types import JsonDict, UserID
 
 from tests import unittest
+from tests.rest.client.v1.utils import TEST_OIDC_CONFIG
 from tests.server import FakeChannel
 
 
@@ -156,27 +159,45 @@ class UIAuthTests(unittest.HomeserverTestCase):
         register.register_servlets,
     ]
 
+    def default_config(self):
+        config = super().default_config()
+
+        # we enable OIDC as a way of testing SSO flows
+        oidc_config = {}
+        oidc_config.update(TEST_OIDC_CONFIG)
+        oidc_config["allow_existing_users"] = True
+
+        config["oidc_config"] = oidc_config
+        config["public_baseurl"] = "https://synapse.test"
+        return config
+
+    def create_resource_dict(self):
+        resource_dict = super().create_resource_dict()
+        # mount the OIDC resource at /_synapse/oidc
+        resource_dict["/_synapse/oidc"] = OIDCResource(self.hs)
+        return resource_dict
+
     def prepare(self, reactor, clock, hs):
         self.user_pass = "pass"
         self.user = self.register_user("test", self.user_pass)
         self.user_tok = self.login("test", self.user_pass)
 
-    def get_device_ids(self) -> List[str]:
+    def get_device_ids(self, access_token: str) -> List[str]:
         # Get the list of devices so one can be deleted.
-        request, channel = self.make_request(
-            "GET", "devices", access_token=self.user_tok,
-        )  # type: SynapseRequest, FakeChannel
-
-        # Get the ID of the device.
-        self.assertEqual(request.code, 200)
+        _, channel = self.make_request("GET", "devices", access_token=access_token,)
+        self.assertEqual(channel.code, 200)
         return [d["device_id"] for d in channel.json_body["devices"]]
 
     def delete_device(
-        self, device: str, expected_response: int, body: Union[bytes, JsonDict] = b""
+        self,
+        access_token: str,
+        device: str,
+        expected_response: int,
+        body: Union[bytes, JsonDict] = b"",
     ) -> FakeChannel:
         """Delete an individual device."""
         request, channel = self.make_request(
-            "DELETE", "devices/" + device, body, access_token=self.user_tok
+            "DELETE", "devices/" + device, body, access_token=access_token,
         )  # type: SynapseRequest, FakeChannel
 
         # Ensure the response is sane.
@@ -201,11 +222,11 @@ class UIAuthTests(unittest.HomeserverTestCase):
         """
         Test user interactive authentication outside of registration.
         """
-        device_id = self.get_device_ids()[0]
+        device_id = self.get_device_ids(self.user_tok)[0]
 
         # Attempt to delete this device.
         # Returns a 401 as per the spec
-        channel = self.delete_device(device_id, 401)
+        channel = self.delete_device(self.user_tok, device_id, 401)
 
         # Grab the session
         session = channel.json_body["session"]
@@ -214,6 +235,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
 
         # Make another request providing the UI auth flow.
         self.delete_device(
+            self.user_tok,
             device_id,
             200,
             {
@@ -233,12 +255,13 @@ class UIAuthTests(unittest.HomeserverTestCase):
         UIA - check that still works.
         """
 
-        device_id = self.get_device_ids()[0]
-        channel = self.delete_device(device_id, 401)
+        device_id = self.get_device_ids(self.user_tok)[0]
+        channel = self.delete_device(self.user_tok, device_id, 401)
         session = channel.json_body["session"]
 
         # Make another request providing the UI auth flow.
         self.delete_device(
+            self.user_tok,
             device_id,
             200,
             {
@@ -264,7 +287,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
         # Create a second login.
         self.login("test", self.user_pass)
 
-        device_ids = self.get_device_ids()
+        device_ids = self.get_device_ids(self.user_tok)
         self.assertEqual(len(device_ids), 2)
 
         # Attempt to delete the first device.
@@ -298,12 +321,12 @@ class UIAuthTests(unittest.HomeserverTestCase):
         # Create a second login.
         self.login("test", self.user_pass)
 
-        device_ids = self.get_device_ids()
+        device_ids = self.get_device_ids(self.user_tok)
         self.assertEqual(len(device_ids), 2)
 
         # Attempt to delete the first device.
         # Returns a 401 as per the spec
-        channel = self.delete_device(device_ids[0], 401)
+        channel = self.delete_device(self.user_tok, device_ids[0], 401)
 
         # Grab the session
         session = channel.json_body["session"]
@@ -313,6 +336,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
         # Make another request providing the UI auth flow, but try to delete the
         # second device. This results in an error.
         self.delete_device(
+            self.user_tok,
             device_ids[1],
             403,
             {
@@ -324,3 +348,39 @@ class UIAuthTests(unittest.HomeserverTestCase):
                 },
             },
         )
+
+    def test_does_not_offer_password_for_sso_user(self):
+        login_resp = self.helper.login_via_oidc("username")
+        user_tok = login_resp["access_token"]
+        device_id = login_resp["device_id"]
+
+        # now call the device deletion API: we should get the option to auth with SSO
+        # and not password.
+        channel = self.delete_device(user_tok, device_id, 401)
+
+        flows = channel.json_body["flows"]
+        self.assertEqual(flows, [{"stages": ["m.login.sso"]}])
+
+    def test_does_not_offer_sso_for_password_user(self):
+        # now call the device deletion API: we should get the option to auth with SSO
+        # and not password.
+        device_ids = self.get_device_ids(self.user_tok)
+        channel = self.delete_device(self.user_tok, device_ids[0], 401)
+
+        flows = channel.json_body["flows"]
+        self.assertEqual(flows, [{"stages": ["m.login.password"]}])
+
+    def test_offers_both_flows_for_upgraded_user(self):
+        """A user that had a password and then logged in with SSO should get both flows
+        """
+        login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart)
+        self.assertEqual(login_resp["user_id"], self.user)
+
+        device_ids = self.get_device_ids(self.user_tok)
+        channel = self.delete_device(self.user_tok, device_ids[0], 401)
+
+        flows = channel.json_body["flows"]
+        # we have no particular expectations of ordering here
+        self.assertIn({"stages": ["m.login.password"]}, flows)
+        self.assertIn({"stages": ["m.login.sso"]}, flows)
+        self.assertEqual(len(flows), 2)
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index 8f0c2430e8..bcb21d0ced 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -121,6 +121,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
 
         self.assertEquals(channel.result["code"], b"403", channel.result)
         self.assertEquals(channel.json_body["error"], "Registration has been disabled")
+        self.assertEquals(channel.json_body["errcode"], "M_FORBIDDEN")
 
     def test_POST_guest_registration(self):
         self.hs.config.macaroon_secret_key = "test"
diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py
index fbcf8d5b86..5e90d656f7 100644
--- a/tests/rest/key/v2/test_remote_key_resource.py
+++ b/tests/rest/key/v2/test_remote_key_resource.py
@@ -39,7 +39,7 @@ from tests.utils import default_config
 class BaseRemoteKeyResourceTestCase(unittest.HomeserverTestCase):
     def make_homeserver(self, reactor, clock):
         self.http_client = Mock()
-        return self.setup_test_homeserver(http_client=self.http_client)
+        return self.setup_test_homeserver(federation_http_client=self.http_client)
 
     def create_test_resource(self):
         return create_resource_tree(
@@ -172,7 +172,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
             }
         ]
         self.hs2 = self.setup_test_homeserver(
-            http_client=self.http_client2, config=config
+            federation_http_client=self.http_client2, config=config
         )
 
         # wire up outbound POST /key/v2/query requests from hs2 so that they
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index 2a3b2a8f27..6f0677d335 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -214,7 +214,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
         }
         config["media_storage_providers"] = [provider_config]
 
-        hs = self.setup_test_homeserver(config=config, http_client=client)
+        hs = self.setup_test_homeserver(config=config, federation_http_client=client)
 
         return hs
 
@@ -362,3 +362,16 @@ class MediaRepoTests(unittest.HomeserverTestCase):
                     "error": "Not found [b'example.com', b'12345']",
                 },
             )
+
+    def test_x_robots_tag_header(self):
+        """
+        Tests that the `X-Robots-Tag` header is present, which informs web crawlers
+        to not index, archive, or follow links in media.
+        """
+        channel = self._req(b"inline; filename=out" + self.test_image.extension)
+
+        headers = channel.headers
+        self.assertEqual(
+            headers.getRawHeaders(b"X-Robots-Tag"),
+            [b"noindex, nofollow, noarchive, noimageindex"],
+        )
diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py
index ccdc8c2ecf..529b6bcded 100644
--- a/tests/rest/media/v1/test_url_preview.py
+++ b/tests/rest/media/v1/test_url_preview.py
@@ -18,41 +18,15 @@ import re
 
 from mock import patch
 
-import attr
-
 from twisted.internet._resolver import HostResolution
 from twisted.internet.address import IPv4Address, IPv6Address
 from twisted.internet.error import DNSLookupError
-from twisted.python.failure import Failure
 from twisted.test.proto_helpers import AccumulatingProtocol
-from twisted.web._newclient import ResponseDone
 
 from tests import unittest
 from tests.server import FakeTransport
 
 
-@attr.s
-class FakeResponse:
-    version = attr.ib()
-    code = attr.ib()
-    phrase = attr.ib()
-    headers = attr.ib()
-    body = attr.ib()
-    absoluteURI = attr.ib()
-
-    @property
-    def request(self):
-        @attr.s
-        class FakeTransport:
-            absoluteURI = self.absoluteURI
-
-        return FakeTransport()
-
-    def deliverBody(self, protocol):
-        protocol.dataReceived(self.body)
-        protocol.connectionLost(Failure(ResponseDone()))
-
-
 class URLPreviewTests(unittest.HomeserverTestCase):
 
     hijack_auth = True
diff --git a/tests/server.py b/tests/server.py
index a51ad0c14e..4faf32e335 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -216,8 +216,9 @@ def make_request(
         and not path.startswith(b"/_matrix")
         and not path.startswith(b"/_synapse")
     ):
+        if path.startswith(b"/"):
+            path = path[1:]
         path = b"/_matrix/client/r0/" + path
-        path = path.replace(b"//", b"/")
 
     if not path.startswith(b"/"):
         path = b"/" + path
@@ -258,6 +259,7 @@ def make_request(
         for k, v in custom_headers:
             req.requestHeaders.addRawHeader(k, v)
 
+    req.parseCookies()
     req.requestReceived(method, path, b"1.1")
 
     if await_result:
diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py
index ad9bbef9d2..09f4f32a02 100644
--- a/tests/state/test_v2.py
+++ b/tests/state/test_v2.py
@@ -24,7 +24,11 @@ from synapse.api.constants import EventTypes, JoinRules, Membership
 from synapse.api.room_versions import RoomVersions
 from synapse.event_auth import auth_types_for_event
 from synapse.events import make_event_from_dict
-from synapse.state.v2 import lexicographical_topological_sort, resolve_events_with_store
+from synapse.state.v2 import (
+    _get_auth_chain_difference,
+    lexicographical_topological_sort,
+    resolve_events_with_store,
+)
 from synapse.types import EventID
 
 from tests import unittest
@@ -587,6 +591,134 @@ class SimpleParamStateTestCase(unittest.TestCase):
         self.assert_dict(self.expected_combined_state, state)
 
 
+class AuthChainDifferenceTestCase(unittest.TestCase):
+    """We test that `_get_auth_chain_difference` correctly handles unpersisted
+    events.
+    """
+
+    def test_simple(self):
+        # Test getting the auth difference for a simple chain with a single
+        # unpersisted event:
+        #
+        #  Unpersisted | Persisted
+        #              |
+        #           C -|-> B -> A
+
+        a = FakeEvent(
+            id="A", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+        ).to_event([], [])
+
+        b = FakeEvent(
+            id="B", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+        ).to_event([a.event_id], [])
+
+        c = FakeEvent(
+            id="C", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+        ).to_event([b.event_id], [])
+
+        persisted_events = {a.event_id: a, b.event_id: b}
+        unpersited_events = {c.event_id: c}
+
+        state_sets = [{"a": a.event_id, "b": b.event_id}, {"c": c.event_id}]
+
+        store = TestStateResolutionStore(persisted_events)
+
+        diff_d = _get_auth_chain_difference(
+            ROOM_ID, state_sets, unpersited_events, store
+        )
+        difference = self.successResultOf(defer.ensureDeferred(diff_d))
+
+        self.assertEqual(difference, {c.event_id})
+
+    def test_multiple_unpersisted_chain(self):
+        # Test getting the auth difference for a simple chain with multiple
+        # unpersisted events:
+        #
+        #  Unpersisted | Persisted
+        #              |
+        #      D -> C -|-> B -> A
+
+        a = FakeEvent(
+            id="A", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+        ).to_event([], [])
+
+        b = FakeEvent(
+            id="B", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+        ).to_event([a.event_id], [])
+
+        c = FakeEvent(
+            id="C", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+        ).to_event([b.event_id], [])
+
+        d = FakeEvent(
+            id="D", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+        ).to_event([c.event_id], [])
+
+        persisted_events = {a.event_id: a, b.event_id: b}
+        unpersited_events = {c.event_id: c, d.event_id: d}
+
+        state_sets = [
+            {"a": a.event_id, "b": b.event_id},
+            {"c": c.event_id, "d": d.event_id},
+        ]
+
+        store = TestStateResolutionStore(persisted_events)
+
+        diff_d = _get_auth_chain_difference(
+            ROOM_ID, state_sets, unpersited_events, store
+        )
+        difference = self.successResultOf(defer.ensureDeferred(diff_d))
+
+        self.assertEqual(difference, {d.event_id, c.event_id})
+
+    def test_unpersisted_events_different_sets(self):
+        # Test getting the auth difference for with multiple unpersisted events
+        # in different branches:
+        #
+        #  Unpersisted | Persisted
+        #              |
+        #     D --> C -|-> B -> A
+        #     E ----^ -|---^
+        #              |
+
+        a = FakeEvent(
+            id="A", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+        ).to_event([], [])
+
+        b = FakeEvent(
+            id="B", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+        ).to_event([a.event_id], [])
+
+        c = FakeEvent(
+            id="C", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+        ).to_event([b.event_id], [])
+
+        d = FakeEvent(
+            id="D", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+        ).to_event([c.event_id], [])
+
+        e = FakeEvent(
+            id="E", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+        ).to_event([c.event_id, b.event_id], [])
+
+        persisted_events = {a.event_id: a, b.event_id: b}
+        unpersited_events = {c.event_id: c, d.event_id: d, e.event_id: e}
+
+        state_sets = [
+            {"a": a.event_id, "b": b.event_id, "e": e.event_id},
+            {"c": c.event_id, "d": d.event_id},
+        ]
+
+        store = TestStateResolutionStore(persisted_events)
+
+        diff_d = _get_auth_chain_difference(
+            ROOM_ID, state_sets, unpersited_events, store
+        )
+        difference = self.successResultOf(defer.ensureDeferred(diff_d))
+
+        self.assertEqual(difference, {d.event_id, e.event_id})
+
+
 def pairwise(iterable):
     "s -> (s0,s1), (s1,s2), (s2, s3), ..."
     a, b = itertools.tee(iterable)
@@ -647,7 +779,7 @@ class TestStateResolutionStore:
 
         return list(result)
 
-    def get_auth_chain_difference(self, auth_sets):
+    def get_auth_chain_difference(self, room_id, auth_sets):
         chains = [frozenset(self._get_auth_chain(a)) for a in auth_sets]
 
         common = set(chains[0]).intersection(*chains[1:])
diff --git a/tests/storage/test_e2e_room_keys.py b/tests/storage/test_e2e_room_keys.py
index 35dafbb904..3d7760d5d9 100644
--- a/tests/storage/test_e2e_room_keys.py
+++ b/tests/storage/test_e2e_room_keys.py
@@ -26,7 +26,7 @@ room_key = {
 
 class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
     def make_homeserver(self, reactor, clock):
-        hs = self.setup_test_homeserver("server", http_client=None)
+        hs = self.setup_test_homeserver("server", federation_http_client=None)
         self.store = hs.get_datastore()
         return hs
 
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index d4c3b867e3..482506d731 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -202,34 +202,41 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
         # Now actually test that various combinations give the right result:
 
         difference = self.get_success(
-            self.store.get_auth_chain_difference([{"a"}, {"b"}])
+            self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}])
         )
         self.assertSetEqual(difference, {"a", "b"})
 
         difference = self.get_success(
-            self.store.get_auth_chain_difference([{"a"}, {"b"}, {"c"}])
+            self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"c"}])
         )
         self.assertSetEqual(difference, {"a", "b", "c", "e", "f"})
 
         difference = self.get_success(
-            self.store.get_auth_chain_difference([{"a", "c"}, {"b"}])
+            self.store.get_auth_chain_difference(room_id, [{"a", "c"}, {"b"}])
         )
         self.assertSetEqual(difference, {"a", "b", "c"})
 
         difference = self.get_success(
-            self.store.get_auth_chain_difference([{"a"}, {"b"}, {"d"}])
+            self.store.get_auth_chain_difference(room_id, [{"a", "c"}, {"b", "c"}])
+        )
+        self.assertSetEqual(difference, {"a", "b"})
+
+        difference = self.get_success(
+            self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"d"}])
         )
         self.assertSetEqual(difference, {"a", "b", "d", "e"})
 
         difference = self.get_success(
-            self.store.get_auth_chain_difference([{"a"}, {"b"}, {"c"}, {"d"}])
+            self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"c"}, {"d"}])
         )
         self.assertSetEqual(difference, {"a", "b", "c", "d", "e", "f"})
 
         difference = self.get_success(
-            self.store.get_auth_chain_difference([{"a"}, {"b"}, {"e"}])
+            self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"e"}])
         )
         self.assertSetEqual(difference, {"a", "b"})
 
-        difference = self.get_success(self.store.get_auth_chain_difference([{"a"}]))
+        difference = self.get_success(
+            self.store.get_auth_chain_difference(room_id, [{"a"}])
+        )
         self.assertSetEqual(difference, set())
diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py
index cc1f3c53c5..a06ad2c03e 100644
--- a/tests/storage/test_purge.py
+++ b/tests/storage/test_purge.py
@@ -27,7 +27,7 @@ class PurgeTests(HomeserverTestCase):
     servlets = [room.register_servlets]
 
     def make_homeserver(self, reactor, clock):
-        hs = self.setup_test_homeserver("server", http_client=None)
+        hs = self.setup_test_homeserver("server", federation_http_client=None)
         return hs
 
     def prepare(self, reactor, clock, hs):
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index d4f9e809db..a6303bf0ee 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -14,9 +14,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
-from mock import Mock
-
 from canonicaljson import json
 
 from twisted.internet import defer
@@ -30,12 +27,10 @@ from tests.utils import create_room
 
 
 class RedactionTestCase(unittest.HomeserverTestCase):
-    def make_homeserver(self, reactor, clock):
-        config = self.default_config()
+    def default_config(self):
+        config = super().default_config()
         config["redaction_retention_period"] = "30d"
-        return self.setup_test_homeserver(
-            resource_for_federation=Mock(), http_client=None, config=config
-        )
+        return config
 
     def prepare(self, reactor, clock, hs):
         self.store = hs.get_datastore()
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index ff972daeaa..d2aed66f6d 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -14,8 +14,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from unittest.mock import Mock
-
 from synapse.api.constants import Membership
 from synapse.rest.admin import register_servlets_for_client_rest_resource
 from synapse.rest.client.v1 import login, room
@@ -34,12 +32,6 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
         room.register_servlets,
     ]
 
-    def make_homeserver(self, reactor, clock):
-        hs = self.setup_test_homeserver(
-            resource_for_federation=Mock(), http_client=None
-        )
-        return hs
-
     def prepare(self, reactor, clock, hs: TestHomeServer):
 
         # We can't test the RoomMemberStore on its own without the other event
diff --git a/tests/test_federation.py b/tests/test_federation.py
index 1ce4ea3a01..fa45f8b3b7 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -37,7 +37,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
         self.hs_clock = Clock(self.reactor)
         self.homeserver = setup_test_homeserver(
             self.addCleanup,
-            http_client=self.http_client,
+            federation_http_client=self.http_client,
             clock=self.hs_clock,
             reactor=self.reactor,
         )
diff --git a/tests/test_preview.py b/tests/test_preview.py
index 7f67ee9e1f..a883d707df 100644
--- a/tests/test_preview.py
+++ b/tests/test_preview.py
@@ -56,7 +56,7 @@ class PreviewTestCase(unittest.TestCase):
 
         desc = summarize_paragraphs(example_paras, min_size=200, max_size=500)
 
-        self.assertEquals(
+        self.assertEqual(
             desc,
             "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:"
             " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in"
@@ -69,7 +69,7 @@ class PreviewTestCase(unittest.TestCase):
 
         desc = summarize_paragraphs(example_paras[1:], min_size=200, max_size=500)
 
-        self.assertEquals(
+        self.assertEqual(
             desc,
             "Tromsø lies in Northern Norway. The municipality has a population of"
             " (2015) 72,066, but with an annual influx of students it has over 75,000"
@@ -96,7 +96,7 @@ class PreviewTestCase(unittest.TestCase):
 
         desc = summarize_paragraphs(example_paras, min_size=200, max_size=500)
 
-        self.assertEquals(
+        self.assertEqual(
             desc,
             "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:"
             " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in"
@@ -122,7 +122,7 @@ class PreviewTestCase(unittest.TestCase):
         ]
 
         desc = summarize_paragraphs(example_paras, min_size=200, max_size=500)
-        self.assertEquals(
+        self.assertEqual(
             desc,
             "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:"
             " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in"
@@ -149,7 +149,7 @@ class PreviewUrlTestCase(unittest.TestCase):
 
         og = decode_and_calc_og(html, "http://example.com/test.html")
 
-        self.assertEquals(og, {"og:title": "Foo", "og:description": "Some text."})
+        self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
 
     def test_comment(self):
         html = """
@@ -164,7 +164,7 @@ class PreviewUrlTestCase(unittest.TestCase):
 
         og = decode_and_calc_og(html, "http://example.com/test.html")
 
-        self.assertEquals(og, {"og:title": "Foo", "og:description": "Some text."})
+        self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
 
     def test_comment2(self):
         html = """
@@ -182,7 +182,7 @@ class PreviewUrlTestCase(unittest.TestCase):
 
         og = decode_and_calc_og(html, "http://example.com/test.html")
 
-        self.assertEquals(
+        self.assertEqual(
             og,
             {
                 "og:title": "Foo",
@@ -203,7 +203,7 @@ class PreviewUrlTestCase(unittest.TestCase):
 
         og = decode_and_calc_og(html, "http://example.com/test.html")
 
-        self.assertEquals(og, {"og:title": "Foo", "og:description": "Some text."})
+        self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
 
     def test_missing_title(self):
         html = """
@@ -216,7 +216,7 @@ class PreviewUrlTestCase(unittest.TestCase):
 
         og = decode_and_calc_og(html, "http://example.com/test.html")
 
-        self.assertEquals(og, {"og:title": None, "og:description": "Some text."})
+        self.assertEqual(og, {"og:title": None, "og:description": "Some text."})
 
     def test_h1_as_title(self):
         html = """
@@ -230,7 +230,7 @@ class PreviewUrlTestCase(unittest.TestCase):
 
         og = decode_and_calc_og(html, "http://example.com/test.html")
 
-        self.assertEquals(og, {"og:title": "Title", "og:description": "Some text."})
+        self.assertEqual(og, {"og:title": "Title", "og:description": "Some text."})
 
     def test_missing_title_and_broken_h1(self):
         html = """
@@ -244,4 +244,9 @@ class PreviewUrlTestCase(unittest.TestCase):
 
         og = decode_and_calc_og(html, "http://example.com/test.html")
 
-        self.assertEquals(og, {"og:title": None, "og:description": "Some text."})
+        self.assertEqual(og, {"og:title": None, "og:description": "Some text."})
+
+    def test_empty(self):
+        html = ""
+        og = decode_and_calc_og(html, "http://example.com/test.html")
+        self.assertEqual(og, {})
diff --git a/tests/test_server.py b/tests/test_server.py
index c387a85f2e..6b2d2f0401 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -38,7 +38,10 @@ class JsonResourceTests(unittest.TestCase):
         self.reactor = ThreadedMemoryReactorClock()
         self.hs_clock = Clock(self.reactor)
         self.homeserver = setup_test_homeserver(
-            self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.reactor
+            self.addCleanup,
+            federation_http_client=None,
+            clock=self.hs_clock,
+            reactor=self.reactor,
         )
 
     def test_handler_for_request(self):
diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py
index d232b72264..6873d45eb6 100644
--- a/tests/test_utils/__init__.py
+++ b/tests/test_utils/__init__.py
@@ -22,6 +22,11 @@ import warnings
 from asyncio import Future
 from typing import Any, Awaitable, Callable, TypeVar
 
+import attr
+
+from twisted.python.failure import Failure
+from twisted.web.client import ResponseDone
+
 TV = TypeVar("TV")
 
 
@@ -80,3 +85,25 @@ def setup_awaitable_errors() -> Callable[[], None]:
     sys.unraisablehook = unraisablehook  # type: ignore
 
     return cleanup
+
+
+@attr.s
+class FakeResponse:
+    """A fake twisted.web.IResponse object
+
+    there is a similar class at treq.test.test_response, but it lacks a `phrase`
+    attribute, and didn't support deliverBody until recently.
+    """
+
+    # HTTP response code
+    code = attr.ib(type=int)
+
+    # HTTP response phrase (eg b'OK' for a 200)
+    phrase = attr.ib(type=bytes)
+
+    # body of the response
+    body = attr.ib(type=bytes)
+
+    def deliverBody(self, protocol):
+        protocol.dataReceived(self.body)
+        protocol.connectionLost(Failure(ResponseDone()))
diff --git a/tests/unittest.py b/tests/unittest.py
index a9d59e31f7..102b0a1f34 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -20,7 +20,7 @@ import hmac
 import inspect
 import logging
 import time
-from typing import Optional, Tuple, Type, TypeVar, Union, overload
+from typing import Dict, Optional, Tuple, Type, TypeVar, Union, overload
 
 from mock import Mock, patch
 
@@ -46,6 +46,7 @@ from synapse.logging.context import (
 )
 from synapse.server import HomeServer
 from synapse.types import UserID, create_requester
+from synapse.util.httpresourcetree import create_resource_tree
 from synapse.util.ratelimitutils import FederationRateLimiter
 
 from tests.server import FakeChannel, get_clock, make_request, setup_test_homeserver
@@ -320,15 +321,28 @@ class HomeserverTestCase(TestCase):
         """
         Create a the root resource for the test server.
 
-        The default implementation creates a JsonResource and calls each function in
-        `servlets` to register servletes against it
+        The default calls `self.create_resource_dict` and builds the resultant dict
+        into a tree.
         """
-        resource = JsonResource(self.hs)
+        root_resource = Resource()
+        create_resource_tree(self.create_resource_dict(), root_resource)
+        return root_resource
 
-        for servlet in self.servlets:
-            servlet(self.hs, resource)
+    def create_resource_dict(self) -> Dict[str, Resource]:
+        """Create a resource tree for the test server
 
-        return resource
+        A resource tree is a mapping from path to twisted.web.resource.
+
+        The default implementation creates a JsonResource and calls each function in
+        `servlets` to register servlets against it.
+        """
+        servlet_resource = JsonResource(self.hs)
+        for servlet in self.servlets:
+            servlet(self.hs, servlet_resource)
+        return {
+            "/_matrix/client": servlet_resource,
+            "/_synapse/admin": servlet_resource,
+        }
 
     def default_config(self):
         """
@@ -691,13 +705,29 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
     A federating homeserver that authenticates incoming requests as `other.example.com`.
     """
 
-    def prepare(self, reactor, clock, homeserver):
+    def create_resource_dict(self) -> Dict[str, Resource]:
+        d = super().create_resource_dict()
+        d["/_matrix/federation"] = TestTransportLayerServer(self.hs)
+        return d
+
+
+class TestTransportLayerServer(JsonResource):
+    """A test implementation of TransportLayerServer
+
+    authenticates incoming requests as `other.example.com`.
+    """
+
+    def __init__(self, hs):
+        super().__init__(hs)
+
         class Authenticator:
             def authenticate_request(self, request, content):
                 return succeed("other.example.com")
 
+        authenticator = Authenticator()
+
         ratelimiter = FederationRateLimiter(
-            clock,
+            hs.get_clock(),
             FederationRateLimitConfig(
                 window_size=1,
                 sleep_limit=1,
@@ -706,11 +736,8 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
                 concurrent_requests=1000,
             ),
         )
-        federation_server.register_servlets(
-            homeserver, self.resource, Authenticator(), ratelimiter
-        )
 
-        return super().prepare(reactor, clock, homeserver)
+        federation_server.register_servlets(hs, self, authenticator, ratelimiter)
 
 
 def override_config(extra_config):
diff --git a/tests/utils.py b/tests/utils.py
index c8d3ffbaba..977eeaf6ee 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -20,13 +20,12 @@ import os
 import time
 import uuid
 import warnings
-from inspect import getcallargs
 from typing import Type
 from urllib import parse as urlparse
 
 from mock import Mock, patch
 
-from twisted.internet import defer, reactor
+from twisted.internet import defer
 
 from synapse.api.constants import EventTypes
 from synapse.api.errors import CodeMessageException, cs_error
@@ -34,7 +33,6 @@ from synapse.api.room_versions import RoomVersions
 from synapse.config.database import DatabaseConnectionConfig
 from synapse.config.homeserver import HomeServerConfig
 from synapse.config.server import DEFAULT_ROOM_VERSION
-from synapse.federation.transport import server as federation_server
 from synapse.http.server import HttpServer
 from synapse.logging.context import current_context, set_current_context
 from synapse.server import HomeServer
@@ -42,7 +40,6 @@ from synapse.storage import DataStore
 from synapse.storage.database import LoggingDatabaseConnection
 from synapse.storage.engines import PostgresEngine, create_engine
 from synapse.storage.prepare_database import prepare_database
-from synapse.util.ratelimitutils import FederationRateLimiter
 
 # set this to True to run the tests against postgres instead of sqlite.
 #
@@ -342,32 +339,9 @@ def setup_test_homeserver(
 
     hs.get_auth_handler().validate_hash = validate_hash
 
-    fed = kwargs.get("resource_for_federation", None)
-    if fed:
-        register_federation_servlets(hs, fed)
-
     return hs
 
 
-def register_federation_servlets(hs, resource):
-    federation_server.register_servlets(
-        hs,
-        resource=resource,
-        authenticator=federation_server.Authenticator(hs),
-        ratelimiter=FederationRateLimiter(
-            hs.get_clock(), config=hs.config.rc_federation
-        ),
-    )
-
-
-def get_mock_call_args(pattern_func, mock_func):
-    """ Return the arguments the mock function was called with interpreted
-    by the pattern functions argument list.
-    """
-    invoked_args, invoked_kargs = mock_func.call_args
-    return getcallargs(pattern_func, *invoked_args, **invoked_kargs)
-
-
 def mock_getRawHeaders(headers=None):
     headers = headers if headers is not None else {}
 
@@ -553,86 +527,6 @@ class MockClock:
         return d
 
 
-def _format_call(args, kwargs):
-    return ", ".join(
-        ["%r" % (a) for a in args] + ["%s=%r" % (k, v) for k, v in kwargs.items()]
-    )
-
-
-class DeferredMockCallable:
-    """A callable instance that stores a set of pending call expectations and
-    return values for them. It allows a unit test to assert that the given set
-    of function calls are eventually made, by awaiting on them to be called.
-    """
-
-    def __init__(self):
-        self.expectations = []
-        self.calls = []
-
-    def __call__(self, *args, **kwargs):
-        self.calls.append((args, kwargs))
-
-        if not self.expectations:
-            raise ValueError(
-                "%r has no pending calls to handle call(%s)"
-                % (self, _format_call(args, kwargs))
-            )
-
-        for (call, result, d) in self.expectations:
-            if args == call[1] and kwargs == call[2]:
-                d.callback(None)
-                return result
-
-        failure = AssertionError(
-            "Was not expecting call(%s)" % (_format_call(args, kwargs))
-        )
-
-        for _, _, d in self.expectations:
-            try:
-                d.errback(failure)
-            except Exception:
-                pass
-
-        raise failure
-
-    def expect_call_and_return(self, call, result):
-        self.expectations.append((call, result, defer.Deferred()))
-
-    @defer.inlineCallbacks
-    def await_calls(self, timeout=1000):
-        deferred = defer.DeferredList(
-            [d for _, _, d in self.expectations], fireOnOneErrback=True
-        )
-
-        timer = reactor.callLater(
-            timeout / 1000,
-            deferred.errback,
-            AssertionError(
-                "%d pending calls left: %s"
-                % (
-                    len([e for e in self.expectations if not e[2].called]),
-                    [e for e in self.expectations if not e[2].called],
-                )
-            ),
-        )
-
-        yield deferred
-
-        timer.cancel()
-
-        self.calls = []
-
-    def assert_had_no_calls(self):
-        if self.calls:
-            calls = self.calls
-            self.calls = []
-
-            raise AssertionError(
-                "Expected not to received any calls, got:\n"
-                + "\n".join(["call(%s)" % _format_call(c[0], c[1]) for c in calls])
-            )
-
-
 async def create_room(hs, room_id: str, creator_id: str):
     """Creates and persist a creation event for the given room
     """