-rw-r--r--  .gitignore | 1
-rw-r--r--  changelog.d/8504.bugfix | 1
-rw-r--r--  changelog.d/8544.feature | 1
-rw-r--r--  changelog.d/8545.bugfix | 1
-rw-r--r--  changelog.d/8561.misc | 1
-rw-r--r--  changelog.d/8562.misc | 1
-rw-r--r--  changelog.d/8563.misc | 1
-rw-r--r--  changelog.d/8564.feature | 1
-rw-r--r--  changelog.d/8566.misc | 1
-rw-r--r--  changelog.d/8567.bugfix | 1
-rw-r--r--  changelog.d/8568.misc | 1
-rw-r--r--  changelog.d/8569.misc | 1
-rw-r--r--  changelog.d/8571.misc | 1
-rw-r--r--  changelog.d/8572.misc | 1
-rw-r--r--  changelog.d/8577.misc | 1
-rw-r--r--  changelog.d/8578.misc | 1
-rw-r--r--  changelog.d/8583.misc | 1
-rw-r--r--  changelog.d/8585.bugfix | 1
-rw-r--r--  changelog.d/8587.misc | 1
-rw-r--r--  changelog.d/8589.removal | 1
-rw-r--r--  changelog.d/8590.misc | 1
-rw-r--r--  changelog.d/8591.misc | 1
-rw-r--r--  changelog.d/8592.misc | 1
-rw-r--r--  changelog.d/8593.misc | 1
-rw-r--r--  changelog.d/8594.misc | 1
-rw-r--r--  changelog.d/8599.feature | 1
-rw-r--r--  changelog.d/8600.misc | 1
-rw-r--r--  changelog.d/8606.feature | 1
-rw-r--r--  changelog.d/8609.misc | 1
-rw-r--r--  mypy.ini | 4
-rwxr-xr-x  scripts/synapse_port_db | 5
-rwxr-xr-x  setup.py | 2
-rw-r--r--  synapse/api/auth.py | 6
-rw-r--r--  synapse/appservice/scheduler.py | 18
-rw-r--r--  synapse/events/builder.py | 2
-rw-r--r--  synapse/handlers/_base.py | 20
-rw-r--r--  synapse/handlers/account_validity.py | 12
-rw-r--r--  synapse/handlers/auth.py | 8
-rw-r--r--  synapse/handlers/initial_sync.py | 8
-rw-r--r--  synapse/handlers/message.py | 7
-rw-r--r--  synapse/handlers/profile.py | 84
-rw-r--r--  synapse/logging/_remote.py | 225
-rw-r--r--  synapse/logging/_terse_json.py | 199
-rw-r--r--  synapse/metrics/background_process_metrics.py | 11
-rw-r--r--  synapse/push/bulk_push_rule_evaluator.py | 2
-rw-r--r--  synapse/push/mailer.py | 16
-rw-r--r--  synapse/push/push_rule_evaluator.py | 18
-rw-r--r--  synapse/replication/slave/storage/client_ips.py | 10
-rw-r--r--  synapse/res/templates/notif.html | 56
-rw-r--r--  synapse/res/templates/notif.txt | 24
-rw-r--r--  synapse/res/templates/notif_mail.html | 26
-rw-r--r--  synapse/res/templates/notif_mail.txt | 6
-rw-r--r--  synapse/res/templates/room.html | 26
-rw-r--r--  synapse/res/templates/room.txt | 12
-rw-r--r--  synapse/rest/client/v1/login.py | 2
-rw-r--r--  synapse/storage/_base.py | 12
-rw-r--r--  synapse/storage/database.py | 2
-rw-r--r--  synapse/storage/databases/main/client_ips.py | 8
-rw-r--r--  synapse/storage/databases/main/devices.py | 8
-rw-r--r--  synapse/storage/databases/main/events.py | 4
-rw-r--r--  synapse/storage/databases/main/events_worker.py | 24
-rw-r--r--  synapse/storage/databases/main/profile.py | 6
-rw-r--r--  synapse/storage/databases/main/pusher.py | 2
-rw-r--r--  synapse/storage/databases/main/receipts.py | 11
-rw-r--r--  synapse/storage/databases/main/roommember.py | 18
-rw-r--r--  synapse/storage/databases/main/schema/delta/58/21as_device_stream.sql (renamed from synapse/storage/databases/main/schema/delta/59/19as_device_stream.sql) | 5
-rw-r--r--  synapse/storage/databases/main/schema/delta/58/21drop_device_max_stream_id.sql | 1
-rw-r--r--  synapse/storage/types.py | 4
-rw-r--r--  synapse/util/caches/__init__.py | 13
-rw-r--r--  synapse/util/caches/deferred_cache.py | 146
-rw-r--r--  synapse/util/caches/descriptors.py | 44
-rw-r--r--  synapse/util/caches/dictionary_cache.py | 29
-rw-r--r--  synapse/util/caches/lrucache.py | 135
-rw-r--r--  synmark/__init__.py | 9
-rw-r--r--  tests/appservice/test_scheduler.py | 41
-rw-r--r--  tests/logging/test_terse_json.py | 2
-rw-r--r--  tests/push/test_email.py | 15
-rw-r--r--  tests/rest/client/v2_alpha/test_account.py | 1
-rw-r--r--  tests/rest/client/v2_alpha/test_auth.py | 1
-rw-r--r--  tests/storage/test__base.py | 228
-rw-r--r--  tests/util/caches/test_deferred_cache.py | 140
-rw-r--r--  tests/util/caches/test_descriptors.py | 281
-rw-r--r--  tests/util/test_lrucache.py | 12
-rw-r--r--  tox.ini | 5
84 files changed, 1264 insertions, 781 deletions
diff --git a/.gitignore b/.gitignore
index af36c00cfa..9bb5bdd647 100644
--- a/.gitignore
+++ b/.gitignore
@@ -21,6 +21,7 @@ _trial_temp*/
 /.python-version
 /*.signing.key
 /env/
+/.venv*/
 /homeserver*.yaml
 /logs
 /media_store/
diff --git a/changelog.d/8504.bugfix b/changelog.d/8504.bugfix
new file mode 100644
index 0000000000..2bd0dbb8b4
--- /dev/null
+++ b/changelog.d/8504.bugfix
@@ -0,0 +1 @@
+Expose the `uk.half-shot.msc2778.login.application_service` login type to clients in the login API. This feature was added in v1.21.0, but was not exposed as a potential login flow.
diff --git a/changelog.d/8544.feature b/changelog.d/8544.feature
new file mode 100644
index 0000000000..542993110b
--- /dev/null
+++ b/changelog.d/8544.feature
@@ -0,0 +1 @@
+Allow running background tasks in a separate worker process.
diff --git a/changelog.d/8545.bugfix b/changelog.d/8545.bugfix
new file mode 100644
index 0000000000..64ba307df0
--- /dev/null
+++ b/changelog.d/8545.bugfix
@@ -0,0 +1 @@
+Fix a long-standing bug where email notifications for encrypted messages were blank.
diff --git a/changelog.d/8561.misc b/changelog.d/8561.misc
new file mode 100644
index 0000000000..a40dedfa8e
--- /dev/null
+++ b/changelog.d/8561.misc
@@ -0,0 +1 @@
+Move metric registration code down into `LruCache`.
diff --git a/changelog.d/8562.misc b/changelog.d/8562.misc
new file mode 100644
index 0000000000..ebdbddb500
--- /dev/null
+++ b/changelog.d/8562.misc
@@ -0,0 +1 @@
+Add type annotations for `LruCache`.
diff --git a/changelog.d/8563.misc b/changelog.d/8563.misc
new file mode 100644
index 0000000000..eeba8e5fee
--- /dev/null
+++ b/changelog.d/8563.misc
@@ -0,0 +1 @@
+Replace `DeferredCache` with the lighter-weight `LruCache` where possible.
diff --git a/changelog.d/8564.feature b/changelog.d/8564.feature
new file mode 100644
index 0000000000..45342e66ad
--- /dev/null
+++ b/changelog.d/8564.feature
@@ -0,0 +1 @@
+Support modifying event content in `ThirdPartyRules` modules.
diff --git a/changelog.d/8566.misc b/changelog.d/8566.misc
new file mode 100644
index 0000000000..453cf48ffa
--- /dev/null
+++ b/changelog.d/8566.misc
@@ -0,0 +1 @@
+Add virtualenv-generated folders to `.gitignore`.
\ No newline at end of file
diff --git a/changelog.d/8567.bugfix b/changelog.d/8567.bugfix
new file mode 100644
index 0000000000..4d835df6fd
--- /dev/null
+++ b/changelog.d/8567.bugfix
@@ -0,0 +1 @@
+Fix increase in the number of `There was no active span...` errors logged when using OpenTracing.
diff --git a/changelog.d/8568.misc b/changelog.d/8568.misc
new file mode 100644
index 0000000000..0ed7db92d3
--- /dev/null
+++ b/changelog.d/8568.misc
@@ -0,0 +1 @@
+Add `get_immediate` method to `DeferredCache`.
diff --git a/changelog.d/8569.misc b/changelog.d/8569.misc
new file mode 100644
index 0000000000..3b6e0625e5
--- /dev/null
+++ b/changelog.d/8569.misc
@@ -0,0 +1 @@
+Fix mypy not properly checking across the codebase; additionally, fix a typing assertion error in `handlers/auth.py`.
\ No newline at end of file
diff --git a/changelog.d/8571.misc b/changelog.d/8571.misc
new file mode 100644
index 0000000000..f6a65057e0
--- /dev/null
+++ b/changelog.d/8571.misc
@@ -0,0 +1 @@
+Fix `synmark` benchmark runner.
diff --git a/changelog.d/8572.misc b/changelog.d/8572.misc
new file mode 100644
index 0000000000..ea2a6d340d
--- /dev/null
+++ b/changelog.d/8572.misc
@@ -0,0 +1 @@
+Modify `DeferredCache.get()` to return `Deferred`s instead of `ObservableDeferred`s.
diff --git a/changelog.d/8577.misc b/changelog.d/8577.misc
new file mode 100644
index 0000000000..75fe563a02
--- /dev/null
+++ b/changelog.d/8577.misc
@@ -0,0 +1 @@
+Adjust a protocol-type definition to fit `sqlite3` assertions.
\ No newline at end of file
diff --git a/changelog.d/8578.misc b/changelog.d/8578.misc
new file mode 100644
index 0000000000..e93462255b
--- /dev/null
+++ b/changelog.d/8578.misc
@@ -0,0 +1 @@
+Support macOS on the `synmark` benchmark runner.
diff --git a/changelog.d/8583.misc b/changelog.d/8583.misc
new file mode 100644
index 0000000000..d24973f09a
--- /dev/null
+++ b/changelog.d/8583.misc
@@ -0,0 +1 @@
+Update `mypy` static type checker to 0.790.
\ No newline at end of file
diff --git a/changelog.d/8585.bugfix b/changelog.d/8585.bugfix
new file mode 100644
index 0000000000..e97e6ac1d8
--- /dev/null
+++ b/changelog.d/8585.bugfix
@@ -0,0 +1 @@
+Fix a bug that prevented errors encountered during execution of the `synapse_port_db` from being correctly printed.
\ No newline at end of file
diff --git a/changelog.d/8587.misc b/changelog.d/8587.misc
new file mode 100644
index 0000000000..9e56551a34
--- /dev/null
+++ b/changelog.d/8587.misc
@@ -0,0 +1 @@
+Re-organize the structured logging code to separate the TCP transport handling from the JSON formatting.
diff --git a/changelog.d/8589.removal b/changelog.d/8589.removal
new file mode 100644
index 0000000000..b80f29d6bb
--- /dev/null
+++ b/changelog.d/8589.removal
@@ -0,0 +1 @@
+Drop unused `device_max_stream_id` table.
diff --git a/changelog.d/8590.misc b/changelog.d/8590.misc
new file mode 100644
index 0000000000..4abcccb326
--- /dev/null
+++ b/changelog.d/8590.misc
@@ -0,0 +1 @@
+Implement [MSC2409](https://github.com/matrix-org/matrix-doc/pull/2409) to send typing, read receipts, and presence events to appservices.
diff --git a/changelog.d/8591.misc b/changelog.d/8591.misc
new file mode 100644
index 0000000000..8f16bc3e7e
--- /dev/null
+++ b/changelog.d/8591.misc
@@ -0,0 +1 @@
+ Move metric registration code down into `LruCache`.
diff --git a/changelog.d/8592.misc b/changelog.d/8592.misc
new file mode 100644
index 0000000000..099e8fb7bb
--- /dev/null
+++ b/changelog.d/8592.misc
@@ -0,0 +1 @@
+Remove extraneous unittest logging decorators from unit tests.
\ No newline at end of file
diff --git a/changelog.d/8593.misc b/changelog.d/8593.misc
new file mode 100644
index 0000000000..d266ba19a4
--- /dev/null
+++ b/changelog.d/8593.misc
@@ -0,0 +1 @@
+Minor optimisations in caching code.
diff --git a/changelog.d/8594.misc b/changelog.d/8594.misc
new file mode 100644
index 0000000000..d266ba19a4
--- /dev/null
+++ b/changelog.d/8594.misc
@@ -0,0 +1 @@
+Minor optimisations in caching code.
diff --git a/changelog.d/8599.feature b/changelog.d/8599.feature
new file mode 100644
index 0000000000..542993110b
--- /dev/null
+++ b/changelog.d/8599.feature
@@ -0,0 +1 @@
+Allow running background tasks in a separate worker process.
diff --git a/changelog.d/8600.misc b/changelog.d/8600.misc
new file mode 100644
index 0000000000..a5a922e641
--- /dev/null
+++ b/changelog.d/8600.misc
@@ -0,0 +1 @@
+Update `mypy` static type checker to 0.790.
diff --git a/changelog.d/8606.feature b/changelog.d/8606.feature
new file mode 100644
index 0000000000..fad723c108
--- /dev/null
+++ b/changelog.d/8606.feature
@@ -0,0 +1 @@
+Limit appservice transactions to 100 persistent and 100 ephemeral events.
diff --git a/changelog.d/8609.misc b/changelog.d/8609.misc
new file mode 100644
index 0000000000..5e3f3c1993
--- /dev/null
+++ b/changelog.d/8609.misc
@@ -0,0 +1 @@
+Add type hints to profile and base handler.
diff --git a/mypy.ini b/mypy.ini
index b5db54ee3b..5e9f7b1259 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -15,8 +15,9 @@ files =
   synapse/events/builder.py,
   synapse/events/spamcheck.py,
   synapse/federation,
-  synapse/handlers/appservice.py,
+  synapse/handlers/_base.py,
   synapse/handlers/account_data.py,
+  synapse/handlers/appservice.py,
   synapse/handlers/auth.py,
   synapse/handlers/cas_handler.py,
   synapse/handlers/deactivate_account.py,
@@ -32,6 +33,7 @@ files =
   synapse/handlers/pagination.py,
   synapse/handlers/password_policy.py,
   synapse/handlers/presence.py,
+  synapse/handlers/profile.py,
   synapse/handlers/read_marker.py,
   synapse/handlers/room.py,
   synapse/handlers/room_member.py,
diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db
index 2d0b59ab53..6c7664ad4a 100755
--- a/scripts/synapse_port_db
+++ b/scripts/synapse_port_db
@@ -22,6 +22,7 @@ import logging
 import sys
 import time
 import traceback
+from typing import Optional
 
 import yaml
 
@@ -152,7 +153,7 @@ IGNORED_TABLES = {
 
 # Error returned by the run function. Used at the top-level part of the script to
 # handle errors and return codes.
-end_error = None
+end_error = None  # type: Optional[str]
 # The exec_info for the error, if any. If error is defined but not exec_info the script
 # will show only the error message without the stacktrace, if exec_info is defined but
 # not the error then the script will show nothing outside of what's printed in the run
@@ -635,7 +636,7 @@ class Porter(object):
             self.progress.done()
         except Exception as e:
             global end_error_exec_info
-            end_error = e
+            end_error = str(e)
             end_error_exec_info = sys.exc_info()
             logger.exception("")
         finally:
diff --git a/setup.py b/setup.py
index 08843fe2a3..2f4a3170d2 100755
--- a/setup.py
+++ b/setup.py
@@ -102,6 +102,8 @@ CONDITIONAL_REQUIREMENTS["lint"] = [
     "flake8",
 ]
 
+CONDITIONAL_REQUIREMENTS["mypy"] = ["mypy==0.790", "mypy-zope==0.2.8"]
+
 # Dependencies which are exclusively required by unit test code. This is
 # NOT a list of all modules that are necessary to run the unit tests.
 # Tests assume that all optional dependencies are installed.
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 1071a0576e..bff87fabde 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -34,7 +34,6 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
 from synapse.events import EventBase
 from synapse.logging import opentracing as opentracing
 from synapse.types import StateMap, UserID
-from synapse.util.caches import register_cache
 from synapse.util.caches.lrucache import LruCache
 from synapse.util.metrics import Measure
 
@@ -70,8 +69,9 @@ class Auth:
         self.store = hs.get_datastore()
         self.state = hs.get_state_handler()
 
-        self.token_cache = LruCache(10000)
-        register_cache("cache", "token_cache", self.token_cache)
+        self.token_cache = LruCache(
+            10000, "token_cache"
+        )  # type: LruCache[str, Tuple[str, bool]]
 
         self._auth_blocking = AuthBlocking(self.hs)
 
diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py
index ad3c408519..58291afc22 100644
--- a/synapse/appservice/scheduler.py
+++ b/synapse/appservice/scheduler.py
@@ -60,6 +60,13 @@ from synapse.types import JsonDict
 logger = logging.getLogger(__name__)
 
 
+# Maximum number of events to provide in an AS transaction.
+MAX_PERSISTENT_EVENTS_PER_TRANSACTION = 100
+
+# Maximum number of ephemeral events to provide in an AS transaction.
+MAX_EPHEMERAL_EVENTS_PER_TRANSACTION = 100
+
+
 class ApplicationServiceScheduler:
     """ Public facing API for this module. Does the required DI to tie the
     components together. This also serves as the "event_pool", which in this
@@ -136,10 +143,17 @@ class _ServiceQueuer:
         self.requests_in_flight.add(service.id)
         try:
             while True:
-                events = self.queued_events.pop(service.id, [])
-                ephemeral = self.queued_ephemeral.pop(service.id, [])
+                all_events = self.queued_events.get(service.id, [])
+                events = all_events[:MAX_PERSISTENT_EVENTS_PER_TRANSACTION]
+                del all_events[:MAX_PERSISTENT_EVENTS_PER_TRANSACTION]
+
+                all_events_ephemeral = self.queued_ephemeral.get(service.id, [])
+                ephemeral = all_events_ephemeral[:MAX_EPHEMERAL_EVENTS_PER_TRANSACTION]
+                del all_events_ephemeral[:MAX_EPHEMERAL_EVENTS_PER_TRANSACTION]
+
                 if not events and not ephemeral:
                     return
+
                 try:
                     await self.txn_ctrl.send(service, events, ephemeral)
                 except Exception:
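
The batching added above means a single transaction now carries at most MAX_PERSISTENT_EVENTS_PER_TRANSACTION persistent and MAX_EPHEMERAL_EVENTS_PER_TRANSACTION ephemeral events, with the remainder left on the queue for the next loop iteration (see changelog.d/8606.feature). A minimal standalone sketch of the slice-and-delete pattern, with hypothetical names:

    from typing import Dict, List

    MAX_PER_TXN = 100

    def take_batch(queues: Dict[str, List[str]], service_id: str) -> List[str]:
        # get() rather than pop(): the list object stays in the dict, so
        # events appended while a send is in flight are not lost.
        queued = queues.get(service_id, [])
        batch = queued[:MAX_PER_TXN]   # copy out the first chunk
        del queued[:MAX_PER_TXN]       # drop only what was taken
        return batch

    queues = {"as1": ["event%i" % i for i in range(250)]}
    assert len(take_batch(queues, "as1")) == 100
    assert len(queues["as1"]) == 150
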
diff --git a/synapse/events/builder.py b/synapse/events/builder.py
index df4f950fec..07df258e6e 100644
--- a/synapse/events/builder.py
+++ b/synapse/events/builder.py
@@ -98,7 +98,7 @@ class EventBuilder:
         return self._state_key is not None
 
     async def build(
-        self, prev_event_ids: List[str], auth_event_ids: Optional[List[str]]
+        self, prev_event_ids: List[str], auth_event_ids: Optional[List[str]],
     ) -> EventBase:
         """Transform into a fully signed and hashed event
 
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index 0206320e96..bd8e71ae56 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import logging
+from typing import TYPE_CHECKING, Optional
 
 import synapse.state
 import synapse.storage
@@ -22,6 +23,9 @@ from synapse.api.constants import EventTypes, Membership
 from synapse.api.ratelimiting import Ratelimiter
 from synapse.types import UserID
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -30,11 +34,7 @@ class BaseHandler:
     Common base class for the event handlers.
     """
 
-    def __init__(self, hs):
-        """
-        Args:
-            hs (synapse.server.HomeServer):
-        """
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()  # type: synapse.storage.DataStore
         self.auth = hs.get_auth()
         self.notifier = hs.get_notifier()
@@ -56,7 +56,7 @@ class BaseHandler:
                 clock=self.clock,
                 rate_hz=self.hs.config.rc_admin_redaction.per_second,
                 burst_count=self.hs.config.rc_admin_redaction.burst_count,
-            )
+            )  # type: Optional[Ratelimiter]
         else:
             self.admin_redaction_ratelimiter = None
 
@@ -127,15 +127,15 @@ class BaseHandler:
             if guest_access != "can_join":
                 if context:
                     current_state_ids = await context.get_current_state_ids()
-                    current_state = await self.store.get_events(
+                    current_state_dict = await self.store.get_events(
                         list(current_state_ids.values())
                     )
+                    current_state = list(current_state_dict.values())
                 else:
-                    current_state = await self.state_handler.get_current_state(
+                    current_state_map = await self.state_handler.get_current_state(
                         event.room_id
                     )
-
-                current_state = list(current_state.values())
+                    current_state = list(current_state_map.values())
 
                 logger.info("maybe_kick_guest_users %r", current_state)
                 await self.kick_guest_users(current_state)
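
Both the renames in this hunk and the `_do_validate_hash` change further down exist to satisfy mypy: reusing one variable name for values of different types, or closing over such a variable in a nested function, defeats type narrowing (this is the assertion error changelog.d/8569.misc mentions). A toy illustration of the closure case:

    from typing import Union

    def check_password(stored_hash: Union[str, bytes]) -> bool:
        def _do_validate(checked_hash: bytes) -> bool:
            # The parameter annotation pins the type to bytes; a closure
            # over stored_hash would still see Union[str, bytes] here.
            return checked_hash == b"$2b$12$example"

        if isinstance(stored_hash, str):
            stored_hash = stored_hash.encode("ascii")
        return _do_validate(stored_hash)

    assert check_password("$2b$12$example") is True
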
diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py
index f33044e97a..fd4f762f33 100644
--- a/synapse/handlers/account_validity.py
+++ b/synapse/handlers/account_validity.py
@@ -22,7 +22,7 @@ from typing import List
 
 from synapse.api.errors import StoreError
 from synapse.logging.context import make_deferred_yieldable
-from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.types import UserID
 from synapse.util import stringutils
 
@@ -63,16 +63,10 @@ class AccountValidityHandler:
             self._raw_from = email.utils.parseaddr(self._from_string)[1]
 
             # Check the renewal emails to send and send them every 30min.
-            def send_emails():
-                # run as a background process to make sure that the database transactions
-                # have a logcontext to report to
-                return run_as_background_process(
-                    "send_renewals", self._send_renewal_emails
-                )
-
             if hs.config.run_background_tasks:
-                self.clock.looping_call(send_emails, 30 * 60 * 1000)
+                self.clock.looping_call(self._send_renewal_emails, 30 * 60 * 1000)
 
+    @wrap_as_background_process("send_renewals")
     async def _send_renewal_emails(self):
         """Gets the list of users whose account is expiring in the amount of time
         configured in the ``renew_at`` parameter from the ``account_validity``
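
`wrap_as_background_process` replaces the hand-written `send_emails` wrapper, so the looping call can reference the coroutine method directly. A sketch of the decorator's assumed shape (the real implementation lives in `synapse/metrics/background_process_metrics.py`; this is a guess at its structure, not a copy):

    import functools

    from synapse.metrics.background_process_metrics import run_as_background_process

    def wrap_as_background_process(desc: str):
        def decorator(func):
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                # Gives looping_call something synchronous to invoke, while
                # the coroutine runs with its own logcontext for database
                # transactions to report to.
                return run_as_background_process(desc, func, *args, **kwargs)
            return wrapper
        return decorator
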
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 1d1ddc2245..8619fbb982 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -1122,20 +1122,22 @@ class AuthHandler(BaseHandler):
             Whether self.hash(password) == stored_hash.
         """
 
-        def _do_validate_hash():
+        def _do_validate_hash(checked_hash: bytes):
             # Normalise the Unicode in the password
             pw = unicodedata.normalize("NFKC", password)
 
             return bcrypt.checkpw(
                 pw.encode("utf8") + self.hs.config.password_pepper.encode("utf8"),
-                stored_hash,
+                checked_hash,
             )
 
         if stored_hash:
             if not isinstance(stored_hash, bytes):
                 stored_hash = stored_hash.encode("ascii")
 
-            return await defer_to_thread(self.hs.get_reactor(), _do_validate_hash)
+            return await defer_to_thread(
+                self.hs.get_reactor(), _do_validate_hash, stored_hash
+            )
         else:
             return False
 
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 98075f48d2..cb11754bf8 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -293,6 +293,10 @@ class InitialSyncHandler(BaseHandler):
                 user_id, room_id, pagin_config, membership, is_peeking
             )
         elif membership == Membership.LEAVE:
+            # The member_event_id will always be available if membership is set
+            # to leave.
+            assert member_event_id
+
             result = await self._room_initial_sync_parted(
                 user_id, room_id, pagin_config, membership, member_event_id, is_peeking
             )
@@ -315,7 +319,7 @@ class InitialSyncHandler(BaseHandler):
         user_id: str,
         room_id: str,
         pagin_config: PaginationConfig,
-        membership: Membership,
+        membership: str,
         member_event_id: str,
         is_peeking: bool,
     ) -> JsonDict:
@@ -367,7 +371,7 @@ class InitialSyncHandler(BaseHandler):
         user_id: str,
         room_id: str,
         pagin_config: PaginationConfig,
-        membership: Membership,
+        membership: str,
         is_peeking: bool,
     ) -> JsonDict:
         current_state = await self.state.get_current_state(room_id=room_id)
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index e37bca3d12..e4f2115617 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -1364,7 +1364,12 @@ class EventCreationHandler:
         for k, v in original_event.internal_metadata.get_dict().items():
             setattr(builder.internal_metadata, k, v)
 
-        event = await builder.build(prev_event_ids=original_event.prev_event_ids())
+        # the event type hasn't changed, so there's no point in re-calculating the
+        # auth events.
+        event = await builder.build(
+            prev_event_ids=original_event.prev_event_ids(),
+            auth_event_ids=original_event.auth_event_ids(),
+        )
 
         # we rebuild the event context, to be on the safe side. If nothing else,
         # delta_ids might need an update.
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index b784938755..92700b589c 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -12,9 +12,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import logging
 import random
+from typing import TYPE_CHECKING, Optional
 
 from synapse.api.errors import (
     AuthError,
@@ -24,11 +24,20 @@ from synapse.api.errors import (
     StoreError,
     SynapseError,
 )
-from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.types import UserID, create_requester, get_domain_from_id
+from synapse.metrics.background_process_metrics import wrap_as_background_process
+from synapse.types import (
+    JsonDict,
+    Requester,
+    UserID,
+    create_requester,
+    get_domain_from_id,
+)
 
 from ._base import BaseHandler
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 MAX_DISPLAYNAME_LEN = 256
@@ -45,7 +54,7 @@ class ProfileHandler(BaseHandler):
     PROFILE_UPDATE_MS = 60 * 1000
     PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
 
         self.federation = hs.get_federation_client()
@@ -57,10 +66,10 @@ class ProfileHandler(BaseHandler):
 
         if hs.config.run_background_tasks:
             self.clock.looping_call(
-                self._start_update_remote_profile_cache, self.PROFILE_UPDATE_MS
+                self._update_remote_profile_cache, self.PROFILE_UPDATE_MS
             )
 
-    async def get_profile(self, user_id):
+    async def get_profile(self, user_id: str) -> JsonDict:
         target_user = UserID.from_string(user_id)
 
         if self.hs.is_mine(target_user):
@@ -91,7 +100,7 @@ class ProfileHandler(BaseHandler):
             except HttpResponseException as e:
                 raise e.to_synapse_error()
 
-    async def get_profile_from_cache(self, user_id):
+    async def get_profile_from_cache(self, user_id: str) -> JsonDict:
         """Get the profile information from our local cache. If the user is
         ours then the profile information will always be correct. Otherwise,
         it may be out of date/missing.
@@ -115,7 +124,7 @@ class ProfileHandler(BaseHandler):
             profile = await self.store.get_from_remote_profile_cache(user_id)
             return profile or {}
 
-    async def get_displayname(self, target_user):
+    async def get_displayname(self, target_user: UserID) -> str:
         if self.hs.is_mine(target_user):
             try:
                 displayname = await self.store.get_profile_displayname(
@@ -143,15 +152,19 @@ class ProfileHandler(BaseHandler):
             return result["displayname"]
 
     async def set_displayname(
-        self, target_user, requester, new_displayname, by_admin=False
-    ):
+        self,
+        target_user: UserID,
+        requester: Requester,
+        new_displayname: str,
+        by_admin: bool = False,
+    ) -> None:
         """Set the displayname of a user
 
         Args:
-            target_user (UserID): the user whose displayname is to be changed.
-            requester (Requester): The user attempting to make this change.
-            new_displayname (str): The displayname to give this user.
-            by_admin (bool): Whether this change was made by an administrator.
+            target_user: the user whose displayname is to be changed.
+            requester: The user attempting to make this change.
+            new_displayname: The displayname to give this user.
+            by_admin: Whether this change was made by an administrator.
         """
         if not self.hs.is_mine(target_user):
             raise SynapseError(400, "User is not hosted on this homeserver")
@@ -176,8 +189,9 @@ class ProfileHandler(BaseHandler):
                 400, "Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN,)
             )
 
+        displayname_to_set = new_displayname  # type: Optional[str]
         if new_displayname == "":
-            new_displayname = None
+            displayname_to_set = None
 
         # If the admin changes the display name of a user, the requesting user cannot send
         # the join event to update the displayname in the rooms.
@@ -185,7 +199,9 @@ class ProfileHandler(BaseHandler):
         if by_admin:
             requester = create_requester(target_user)
 
-        await self.store.set_profile_displayname(target_user.localpart, new_displayname)
+        await self.store.set_profile_displayname(
+            target_user.localpart, displayname_to_set
+        )
 
         if self.hs.config.user_directory_search_all_users:
             profile = await self.store.get_profileinfo(target_user.localpart)
@@ -195,7 +211,7 @@ class ProfileHandler(BaseHandler):
 
         await self._update_join_states(requester, target_user)
 
-    async def get_avatar_url(self, target_user):
+    async def get_avatar_url(self, target_user: UserID) -> str:
         if self.hs.is_mine(target_user):
             try:
                 avatar_url = await self.store.get_profile_avatar_url(
@@ -222,15 +238,19 @@ class ProfileHandler(BaseHandler):
             return result["avatar_url"]
 
     async def set_avatar_url(
-        self, target_user, requester, new_avatar_url, by_admin=False
+        self,
+        target_user: UserID,
+        requester: Requester,
+        new_avatar_url: str,
+        by_admin: bool = False,
     ):
         """Set a new avatar URL for a user.
 
         Args:
-            target_user (UserID): the user whose avatar URL is to be changed.
-            requester (Requester): The user attempting to make this change.
-            new_avatar_url (str): The avatar URL to give this user.
-            by_admin (bool): Whether this change was made by an administrator.
+            target_user: the user whose avatar URL is to be changed.
+            requester: The user attempting to make this change.
+            new_avatar_url: The avatar URL to give this user.
+            by_admin: Whether this change was made by an administrator.
         """
         if not self.hs.is_mine(target_user):
             raise SynapseError(400, "User is not hosted on this homeserver")
@@ -267,7 +287,7 @@ class ProfileHandler(BaseHandler):
 
         await self._update_join_states(requester, target_user)
 
-    async def on_profile_query(self, args):
+    async def on_profile_query(self, args: JsonDict) -> JsonDict:
         user = UserID.from_string(args["user_id"])
         if not self.hs.is_mine(user):
             raise SynapseError(400, "User is not hosted on this homeserver")
@@ -292,7 +312,9 @@ class ProfileHandler(BaseHandler):
 
         return response
 
-    async def _update_join_states(self, requester, target_user):
+    async def _update_join_states(
+        self, requester: Requester, target_user: UserID
+    ) -> None:
         if not self.hs.is_mine(target_user):
             return
 
@@ -323,15 +345,17 @@ class ProfileHandler(BaseHandler):
                     "Failed to update join event for room %s - %s", room_id, str(e)
                 )
 
-    async def check_profile_query_allowed(self, target_user, requester=None):
+    async def check_profile_query_allowed(
+        self, target_user: UserID, requester: Optional[UserID] = None
+    ) -> None:
         """Checks whether a profile query is allowed. If the
         'require_auth_for_profile_requests' config flag is set to True and a
         'requester' is provided, the query is only allowed if the two users
         share a room.
 
         Args:
-            target_user (UserID): The owner of the queried profile.
-            requester (None|UserID): The user querying for the profile.
+            target_user: The owner of the queried profile.
+            requester: The user querying for the profile.
 
         Raises:
            SynapseError(403): The two users share no room, or one user couldn't
@@ -370,11 +394,7 @@ class ProfileHandler(BaseHandler):
                 raise SynapseError(403, "Profile isn't available", Codes.FORBIDDEN)
             raise
 
-    def _start_update_remote_profile_cache(self):
-        return run_as_background_process(
-            "Update remote profile", self._update_remote_profile_cache
-        )
-
+    @wrap_as_background_process("Update remote profile")
     async def _update_remote_profile_cache(self):
         """Called periodically to check profiles of remote users we haven't
         checked in a while.
diff --git a/synapse/logging/_remote.py b/synapse/logging/_remote.py
new file mode 100644
index 0000000000..0caf325916
--- /dev/null
+++ b/synapse/logging/_remote.py
@@ -0,0 +1,225 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+import traceback
+from collections import deque
+from ipaddress import IPv4Address, IPv6Address, ip_address
+from math import floor
+from typing import Callable, Optional
+
+import attr
+from zope.interface import implementer
+
+from twisted.application.internet import ClientService
+from twisted.internet.defer import Deferred
+from twisted.internet.endpoints import (
+    HostnameEndpoint,
+    TCP4ClientEndpoint,
+    TCP6ClientEndpoint,
+)
+from twisted.internet.interfaces import IPushProducer, ITransport
+from twisted.internet.protocol import Factory, Protocol
+from twisted.logger import ILogObserver, Logger, LogLevel
+
+
+@attr.s
+@implementer(IPushProducer)
+class LogProducer:
+    """
+    An IPushProducer that writes logs from its buffer to its transport when it
+    is resumed.
+
+    Args:
+        buffer: Log buffer to read logs from.
+        transport: Transport to write to.
+        format_event: A callable to format the log entry to a string.
+    """
+
+    transport = attr.ib(type=ITransport)
+    format_event = attr.ib(type=Callable[[dict], str])
+    _buffer = attr.ib(type=deque)
+    _paused = attr.ib(default=False, type=bool, init=False)
+
+    def pauseProducing(self):
+        self._paused = True
+
+    def stopProducing(self):
+        self._paused = True
+        self._buffer = deque()
+
+    def resumeProducing(self):
+        self._paused = False
+
+        while self._paused is False and (self._buffer and self.transport.connected):
+            try:
+                # Request the next event and format it.
+                event = self._buffer.popleft()
+                msg = self.format_event(event)
+
+                # Send it as a new line over the transport.
+                self.transport.write(msg.encode("utf8"))
+            except Exception:
+                # Something has gone wrong writing to the transport -- log it
+                # and break out of the while.
+                traceback.print_exc(file=sys.__stderr__)
+                break
+
+
+@attr.s
+@implementer(ILogObserver)
+class TCPLogObserver:
+    """
+    An IObserver that writes JSON logs to a TCP target.
+
+    Args:
+        hs (HomeServer): The homeserver that is being logged for.
+        host: The host of the logging target.
+        port: The logging target's port.
+        format_event: A callable to format the log entry to a string.
+        maximum_buffer: The maximum buffer size.
+    """
+
+    hs = attr.ib()
+    host = attr.ib(type=str)
+    port = attr.ib(type=int)
+    format_event = attr.ib(type=Callable[[dict], str])
+    maximum_buffer = attr.ib(type=int)
+    _buffer = attr.ib(default=attr.Factory(deque), type=deque)
+    _connection_waiter = attr.ib(default=None, type=Optional[Deferred])
+    _logger = attr.ib(default=attr.Factory(Logger))
+    _producer = attr.ib(default=None, type=Optional[LogProducer])
+
+    def start(self) -> None:
+
+        # Connect without DNS lookups if it's a direct IP.
+        try:
+            ip = ip_address(self.host)
+            if isinstance(ip, IPv4Address):
+                endpoint = TCP4ClientEndpoint(
+                    self.hs.get_reactor(), self.host, self.port
+                )
+            elif isinstance(ip, IPv6Address):
+                endpoint = TCP6ClientEndpoint(
+                    self.hs.get_reactor(), self.host, self.port
+                )
+            else:
+                raise ValueError("Unknown IP address provided: %s" % (self.host,))
+        except ValueError:
+            endpoint = HostnameEndpoint(self.hs.get_reactor(), self.host, self.port)
+
+        factory = Factory.forProtocol(Protocol)
+        self._service = ClientService(endpoint, factory, clock=self.hs.get_reactor())
+        self._service.startService()
+        self._connect()
+
+    def stop(self):
+        self._service.stopService()
+
+    def _connect(self) -> None:
+        """
+        Triggers an attempt to connect then write to the remote if not already writing.
+        """
+        if self._connection_waiter:
+            return
+
+        self._connection_waiter = self._service.whenConnected(failAfterFailures=1)
+
+        @self._connection_waiter.addErrback
+        def fail(r):
+            r.printTraceback(file=sys.__stderr__)
+            self._connection_waiter = None
+            self._connect()
+
+        @self._connection_waiter.addCallback
+        def writer(r):
+            # We have a connection. If we already have a producer, and its
+            # transport is the same, just trigger a resumeProducing.
+            if self._producer and r.transport is self._producer.transport:
+                self._producer.resumeProducing()
+                self._connection_waiter = None
+                return
+
+            # If the producer is still producing, stop it.
+            if self._producer:
+                self._producer.stopProducing()
+
+            # Make a new producer and start it.
+            self._producer = LogProducer(
+                buffer=self._buffer,
+                transport=r.transport,
+                format_event=self.format_event,
+            )
+            r.transport.registerProducer(self._producer, True)
+            self._producer.resumeProducing()
+            self._connection_waiter = None
+
+    def _handle_pressure(self) -> None:
+        """
+        Handle backpressure by shedding events.
+
+        Events will be shed from the buffer, in this order, until it is
+        below the maximum:
+            - Shed DEBUG events
+            - Shed INFO events
+            - Shed the middle 50% of the events.
+        """
+        if len(self._buffer) <= self.maximum_buffer:
+            return
+
+        # Strip out DEBUGs
+        self._buffer = deque(
+            filter(lambda event: event["log_level"] != LogLevel.debug, self._buffer)
+        )
+
+        if len(self._buffer) <= self.maximum_buffer:
+            return
+
+        # Strip out INFOs
+        self._buffer = deque(
+            filter(lambda event: event["log_level"] != LogLevel.info, self._buffer)
+        )
+
+        if len(self._buffer) <= self.maximum_buffer:
+            return
+
+        # Cut the middle entries out
+        buffer_split = floor(self.maximum_buffer / 2)
+
+        old_buffer = self._buffer
+        self._buffer = deque()
+
+        for i in range(buffer_split):
+            self._buffer.append(old_buffer.popleft())
+
+        end_buffer = []
+        for i in range(buffer_split):
+            end_buffer.append(old_buffer.pop())
+
+        self._buffer.extend(reversed(end_buffer))
+
+    def __call__(self, event: dict) -> None:
+        self._buffer.append(event)
+
+        # Handle backpressure, if it exists.
+        try:
+            self._handle_pressure()
+        except Exception:
+            # If handling backpressure fails, clear the buffer and log the
+            # exception.
+            self._buffer.clear()
+            self._logger.failure("Failed clearing backpressure")
+
+        # Try and write immediately.
+        self._connect()
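
The last resort in `_handle_pressure` keeps the oldest and newest `maximum_buffer / 2` entries and drops everything in between. That step in isolation, as a runnable sketch:

    from collections import deque
    from math import floor

    def cut_middle(buffer: deque, maximum: int) -> deque:
        # Keep the first half-maximum and last half-maximum entries, in
        # their original order, and drop the middle.
        half = floor(maximum / 2)
        head = [buffer.popleft() for _ in range(half)]
        tail = [buffer.pop() for _ in range(half)]
        return deque(head + list(reversed(tail)))

    assert list(cut_middle(deque(range(10)), 4)) == [0, 1, 8, 9]
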
diff --git a/synapse/logging/_terse_json.py b/synapse/logging/_terse_json.py
index 1b8916cfa2..9b46956ca9 100644
--- a/synapse/logging/_terse_json.py
+++ b/synapse/logging/_terse_json.py
@@ -18,26 +18,11 @@ Log formatters that output terse JSON.
 """
 
 import json
-import sys
-import traceback
-from collections import deque
-from ipaddress import IPv4Address, IPv6Address, ip_address
-from math import floor
-from typing import IO, Optional
+from typing import IO
 
-import attr
-from zope.interface import implementer
+from twisted.logger import FileLogObserver
 
-from twisted.application.internet import ClientService
-from twisted.internet.defer import Deferred
-from twisted.internet.endpoints import (
-    HostnameEndpoint,
-    TCP4ClientEndpoint,
-    TCP6ClientEndpoint,
-)
-from twisted.internet.interfaces import IPushProducer, ITransport
-from twisted.internet.protocol import Factory, Protocol
-from twisted.logger import FileLogObserver, ILogObserver, Logger
+from synapse.logging._remote import TCPLogObserver
 
 _encoder = json.JSONEncoder(ensure_ascii=False, separators=(",", ":"))
 
@@ -150,180 +135,22 @@ def TerseJSONToConsoleLogObserver(outFile: IO[str], metadata: dict) -> FileLogOb
     return FileLogObserver(outFile, formatEvent)
 
 
-@attr.s
-@implementer(IPushProducer)
-class LogProducer:
+def TerseJSONToTCPLogObserver(
+    hs, host: str, port: int, metadata: dict, maximum_buffer: int
+) -> TCPLogObserver:
     """
-    An IPushProducer that writes logs from its buffer to its transport when it
-    is resumed.
-
-    Args:
-        buffer: Log buffer to read logs from.
-        transport: Transport to write to.
-    """
-
-    transport = attr.ib(type=ITransport)
-    _buffer = attr.ib(type=deque)
-    _paused = attr.ib(default=False, type=bool, init=False)
-
-    def pauseProducing(self):
-        self._paused = True
-
-    def stopProducing(self):
-        self._paused = True
-        self._buffer = deque()
-
-    def resumeProducing(self):
-        self._paused = False
-
-        while self._paused is False and (self._buffer and self.transport.connected):
-            try:
-                event = self._buffer.popleft()
-                self.transport.write(_encoder.encode(event).encode("utf8"))
-                self.transport.write(b"\n")
-            except Exception:
-                # Something has gone wrong writing to the transport -- log it
-                # and break out of the while.
-                traceback.print_exc(file=sys.__stderr__)
-                break
-
-
-@attr.s
-@implementer(ILogObserver)
-class TerseJSONToTCPLogObserver:
-    """
-    An IObserver that writes JSON logs to a TCP target.
+    A log observer that formats events to a flattened JSON representation.
 
     Args:
         hs (HomeServer): The homeserver that is being logged for.
         host: The host of the logging target.
         port: The logging target's port.
-        metadata: Metadata to be added to each log entry.
+        metadata: Metadata to be added to each log object.
+        maximum_buffer: The maximum buffer size.
     """
 
-    hs = attr.ib()
-    host = attr.ib(type=str)
-    port = attr.ib(type=int)
-    metadata = attr.ib(type=dict)
-    maximum_buffer = attr.ib(type=int)
-    _buffer = attr.ib(default=attr.Factory(deque), type=deque)
-    _connection_waiter = attr.ib(default=None, type=Optional[Deferred])
-    _logger = attr.ib(default=attr.Factory(Logger))
-    _producer = attr.ib(default=None, type=Optional[LogProducer])
-
-    def start(self) -> None:
-
-        # Connect without DNS lookups if it's a direct IP.
-        try:
-            ip = ip_address(self.host)
-            if isinstance(ip, IPv4Address):
-                endpoint = TCP4ClientEndpoint(
-                    self.hs.get_reactor(), self.host, self.port
-                )
-            elif isinstance(ip, IPv6Address):
-                endpoint = TCP6ClientEndpoint(
-                    self.hs.get_reactor(), self.host, self.port
-                )
-        except ValueError:
-            endpoint = HostnameEndpoint(self.hs.get_reactor(), self.host, self.port)
-
-        factory = Factory.forProtocol(Protocol)
-        self._service = ClientService(endpoint, factory, clock=self.hs.get_reactor())
-        self._service.startService()
-        self._connect()
-
-    def stop(self):
-        self._service.stopService()
-
-    def _connect(self) -> None:
-        """
-        Triggers an attempt to connect then write to the remote if not already writing.
-        """
-        if self._connection_waiter:
-            return
-
-        self._connection_waiter = self._service.whenConnected(failAfterFailures=1)
-
-        @self._connection_waiter.addErrback
-        def fail(r):
-            r.printTraceback(file=sys.__stderr__)
-            self._connection_waiter = None
-            self._connect()
-
-        @self._connection_waiter.addCallback
-        def writer(r):
-            # We have a connection. If we already have a producer, and its
-            # transport is the same, just trigger a resumeProducing.
-            if self._producer and r.transport is self._producer.transport:
-                self._producer.resumeProducing()
-                self._connection_waiter = None
-                return
-
-            # If the producer is still producing, stop it.
-            if self._producer:
-                self._producer.stopProducing()
-
-            # Make a new producer and start it.
-            self._producer = LogProducer(buffer=self._buffer, transport=r.transport)
-            r.transport.registerProducer(self._producer, True)
-            self._producer.resumeProducing()
-            self._connection_waiter = None
-
-    def _handle_pressure(self) -> None:
-        """
-        Handle backpressure by shedding events.
-
-        The buffer will, in this order, until the buffer is below the maximum:
-            - Shed DEBUG events
-            - Shed INFO events
-            - Shed the middle 50% of the events.
-        """
-        if len(self._buffer) <= self.maximum_buffer:
-            return
-
-        # Strip out DEBUGs
-        self._buffer = deque(
-            filter(lambda event: event["level"] != "DEBUG", self._buffer)
-        )
-
-        if len(self._buffer) <= self.maximum_buffer:
-            return
-
-        # Strip out INFOs
-        self._buffer = deque(
-            filter(lambda event: event["level"] != "INFO", self._buffer)
-        )
-
-        if len(self._buffer) <= self.maximum_buffer:
-            return
-
-        # Cut the middle entries out
-        buffer_split = floor(self.maximum_buffer / 2)
-
-        old_buffer = self._buffer
-        self._buffer = deque()
-
-        for i in range(buffer_split):
-            self._buffer.append(old_buffer.popleft())
-
-        end_buffer = []
-        for i in range(buffer_split):
-            end_buffer.append(old_buffer.pop())
-
-        self._buffer.extend(reversed(end_buffer))
-
-    def __call__(self, event: dict) -> None:
-        flattened = flatten_event(event, self.metadata, include_time=True)
-        self._buffer.append(flattened)
-
-        # Handle backpressure, if it exists.
-        try:
-            self._handle_pressure()
-        except Exception:
-            # If handling backpressure fails,clear the buffer and log the
-            # exception.
-            self._buffer.clear()
-            self._logger.failure("Failed clearing backpressure")
+    def formatEvent(_event: dict) -> str:
+        flattened = flatten_event(_event, metadata, include_time=True)
+        return _encoder.encode(flattened) + "\n"
 
-        # Try and write immediately.
-        self._connect()
+    return TCPLogObserver(hs, host, port, formatEvent, maximum_buffer)
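
After the split, this module's TCP entry point is just a formatter closure handed to the transport-agnostic `TCPLogObserver`. A minimal sketch of the formatting half, with a plain dict merge standing in for `flatten_event`:

    import json

    _encoder = json.JSONEncoder(ensure_ascii=False, separators=(",", ":"))

    def format_event(event: dict, metadata: dict) -> str:
        # Stand-in for flatten_event: merge the metadata into the event
        # and emit one compact JSON object per line.
        flattened = {**event, **metadata}
        return _encoder.encode(flattened) + "\n"

    line = format_event({"log": "hello"}, {"server_name": "example.org"})
    assert line == '{"log":"hello","server_name":"example.org"}\n'
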
diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py
index 5b73463504..ea5f1c7b62 100644
--- a/synapse/metrics/background_process_metrics.py
+++ b/synapse/metrics/background_process_metrics.py
@@ -24,6 +24,7 @@ from prometheus_client.core import REGISTRY, Counter, Gauge
 from twisted.internet import defer
 
 from synapse.logging.context import LoggingContext, PreserveLoggingContext
+from synapse.logging.opentracing import start_active_span
 
 if TYPE_CHECKING:
     import resource
@@ -197,14 +198,14 @@ def run_as_background_process(desc: str, func, *args, **kwargs):
 
         with BackgroundProcessLoggingContext(desc) as context:
             context.request = "%s-%i" % (desc, count)
-
             try:
-                result = func(*args, **kwargs)
+                with start_active_span(desc, tags={"request_id": context.request}):
+                    result = func(*args, **kwargs)
 
-                if inspect.isawaitable(result):
-                    result = await result
+                    if inspect.isawaitable(result):
+                        result = await result
 
-                return result
+                    return result
             except Exception:
                 logger.exception(
                     "Background process '%s' threw an exception", desc,
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index c440f2545c..a701defcdd 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -496,6 +496,6 @@ class _Invalidation(namedtuple("_Invalidation", ("cache", "room_id"))):
     # dedupe when we add callbacks to lru cache nodes, otherwise the number
     # of callbacks would grow.
     def __call__(self):
-        rules = self.cache.get(self.room_id, None, update_metrics=False)
+        rules = self.cache.get_immediate(self.room_id, None, update_metrics=False)
         if rules:
             rules.invalidate_all()
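
Changelog entries 8568 and 8572 explain this swap: `DeferredCache.get()` now returns a `Deferred`, so synchronous callers such as this invalidation callback need the new `get_immediate()`. A toy sketch of the distinction (not the real class):

    from typing import Any, Dict

    class SketchCache:
        def __init__(self) -> None:
            self._data = {}  # type: Dict[Any, Any]

        async def get(self, key: Any) -> Any:
            # In the real DeferredCache this returns a Deferred, which a
            # synchronous callback has no way to wait on.
            return self._data[key]

        def get_immediate(self, key: Any, default: Any = None, update_metrics: bool = True) -> Any:
            # Plain synchronous lookup; update_metrics mirrors the call
            # site above but is ignored in this toy.
            return self._data.get(key, default)

    cache = SketchCache()
    cache._data["!room:example.org"] = "rules"
    assert cache.get_immediate("!room:example.org", None, update_metrics=False) == "rules"
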
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index 455a1acb46..155791b754 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -387,8 +387,8 @@ class Mailer:
         return ret
 
     async def get_message_vars(self, notif, event, room_state_ids):
-        if event.type != EventTypes.Message:
-            return
+        if event.type != EventTypes.Message and event.type != EventTypes.Encrypted:
+            return None
 
         sender_state_event_id = room_state_ids[("m.room.member", event.sender)]
         sender_state_event = await self.store.get_event(sender_state_event_id)
@@ -399,10 +399,8 @@ class Mailer:
         # sender_hash % the number of default images to choose from
         sender_hash = string_ordinal_total(event.sender)
 
-        msgtype = event.content.get("msgtype")
-
         ret = {
-            "msgtype": msgtype,
+            "event_type": event.type,
             "is_historical": event.event_id != notif["event_id"],
             "id": event.event_id,
             "ts": event.origin_server_ts,
@@ -411,6 +409,14 @@ class Mailer:
             "sender_hash": sender_hash,
         }
 
+        # Encrypted messages don't have any additional useful information.
+        if event.type == EventTypes.Encrypted:
+            return ret
+
+        msgtype = event.content.get("msgtype")
+
+        ret["msgtype"] = msgtype
+
         if msgtype == "m.text":
             self.add_text_message_vars(ret, event)
         elif msgtype == "m.image":
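
The reworked `get_message_vars` now returns early with a minimal dict for `m.room.encrypted` events, so the templates can render a placeholder instead of a blank notification (the bug changelog.d/8545.bugfix describes); only `m.room.message` events go on to get a `msgtype` and body. The control flow, condensed into a sketch with literal event types in place of `EventTypes`:

    from typing import Optional

    MESSAGE = "m.room.message"
    ENCRYPTED = "m.room.encrypted"

    def message_vars(event_type: str, content: dict) -> Optional[dict]:
        # Other event types produce no notification variables at all.
        if event_type not in (MESSAGE, ENCRYPTED):
            return None

        ret = {"event_type": event_type}

        # Encrypted payloads carry no readable msgtype or body, so stop
        # here and let the template show "An encrypted message."
        if event_type == ENCRYPTED:
            return ret

        ret["msgtype"] = content.get("msgtype")
        return ret

    assert message_vars(ENCRYPTED, {}) == {"event_type": ENCRYPTED}
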
diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py
index 3a68ce636f..2ce9e444ab 100644
--- a/synapse/push/push_rule_evaluator.py
+++ b/synapse/push/push_rule_evaluator.py
@@ -16,11 +16,10 @@
 
 import logging
 import re
-from typing import Any, Dict, List, Optional, Pattern, Union
+from typing import Any, Dict, List, Optional, Pattern, Tuple, Union
 
 from synapse.events import EventBase
 from synapse.types import UserID
-from synapse.util.caches import register_cache
 from synapse.util.caches.lrucache import LruCache
 
 logger = logging.getLogger(__name__)
@@ -174,20 +173,21 @@ class PushRuleEvaluatorForEvent:
         # Similar to _glob_matches, but do not treat display_name as a glob.
         r = regex_cache.get((display_name, False, True), None)
         if not r:
-            r = re.escape(display_name)
-            r = _re_word_boundary(r)
-            r = re.compile(r, flags=re.IGNORECASE)
+            r1 = re.escape(display_name)
+            r1 = _re_word_boundary(r1)
+            r = re.compile(r1, flags=re.IGNORECASE)
             regex_cache[(display_name, False, True)] = r
 
-        return r.search(body)
+        return bool(r.search(body))
 
     def _get_value(self, dotted_key: str) -> Optional[str]:
         return self._value_cache.get(dotted_key, None)
 
 
 # Caches (string, is_glob, word_boundary) -> regex for push. See _glob_matches
-regex_cache = LruCache(50000)
-register_cache("cache", "regex_push_cache", regex_cache)
+regex_cache = LruCache(
+    50000, "regex_push_cache"
+)  # type: LruCache[Tuple[str, bool, bool], Pattern]
 
 
 def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool:
@@ -205,7 +205,7 @@ def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool:
         if not r:
             r = _glob_to_re(glob, word_boundary)
             regex_cache[(glob, True, word_boundary)] = r
-        return r.search(value)
+        return bool(r.search(value))
     except re.error:
         logger.warning("Failed to parse glob to regex: %r", glob)
         return False
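
Both matching paths now coerce the `re.Match` result to `bool`, tightening the return type, and share a compiled-pattern cache keyed on `(pattern, is_glob, word_boundary)`. The display-name path reduced to a sketch, with a plain dict standing in for the `LruCache` and a plausible word-boundary wrapper:

    import re
    from typing import Dict, Pattern, Tuple

    regex_cache = {}  # type: Dict[Tuple[str, bool, bool], Pattern]

    def display_name_matches(display_name: str, body: str) -> bool:
        key = (display_name, False, True)
        r = regex_cache.get(key)
        if not r:
            # Escape the name (it is not a glob) and require non-word
            # characters, or start/end of string, around it.
            pattern = r"(^|\W)" + re.escape(display_name) + r"(\W|$)"
            r = re.compile(pattern, flags=re.IGNORECASE)
            regex_cache[key] = r
        return bool(r.search(body))

    assert display_name_matches("alice", "hi Alice!") is True
    assert display_name_matches("alice", "malice aforethought") is False
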
diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py
index 4b0ea0cc01..0f5b7adef7 100644
--- a/synapse/replication/slave/storage/client_ips.py
+++ b/synapse/replication/slave/storage/client_ips.py
@@ -15,7 +15,7 @@
 
 from synapse.storage.database import DatabasePool
 from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY
-from synapse.util.caches.deferred_cache import DeferredCache
+from synapse.util.caches.lrucache import LruCache
 
 from ._base import BaseSlavedStore
 
@@ -24,9 +24,9 @@ class SlavedClientIpStore(BaseSlavedStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
         super().__init__(database, db_conn, hs)
 
-        self.client_ip_last_seen = DeferredCache(
-            name="client_ip_last_seen", keylen=4, max_entries=50000
-        )  # type: DeferredCache[tuple, int]
+        self.client_ip_last_seen = LruCache(
+            cache_name="client_ip_last_seen", keylen=4, max_size=50000
+        )  # type: LruCache[tuple, int]
 
     async def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id):
         now = int(self._clock.time_msec())
@@ -41,7 +41,7 @@ class SlavedClientIpStore(BaseSlavedStore):
         if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
             return
 
-        self.client_ip_last_seen.prefill(key, now)
+        self.client_ip_last_seen.set(key, now)
 
         self.hs.get_tcp_replication().send_user_ip(
             user_id, access_token, ip, user_agent, device_id, now
diff --git a/synapse/res/templates/notif.html b/synapse/res/templates/notif.html
index 1a6c70b562..6d76064d13 100644
--- a/synapse/res/templates/notif.html
+++ b/synapse/res/templates/notif.html
@@ -1,41 +1,47 @@
-{% for message in notif.messages %}
+{%- for message in notif.messages %}
     <tr class="{{ "historical_message" if message.is_historical else "message" }}">
         <td class="sender_avatar">
-            {% if loop.index0 == 0 or notif.messages[loop.index0 - 1].sender_name != notif.messages[loop.index0].sender_name %}
-                {% if message.sender_avatar_url %}
+            {%- if loop.index0 == 0 or notif.messages[loop.index0 - 1].sender_name != notif.messages[loop.index0].sender_name %}
+                {%- if message.sender_avatar_url %}
                     <img alt="" class="sender_avatar" src="{{ message.sender_avatar_url|mxc_to_http(32,32) }}"  />
-                {% else %}
-                    {% if message.sender_hash % 3 == 0 %}
+                {%- else %}
+                    {%- if message.sender_hash % 3 == 0 %}
                         <img class="sender_avatar" src="https://riot.im/img/external/avatar-1.png"  />
-                    {% elif message.sender_hash % 3 == 1 %}
+                    {%- elif message.sender_hash % 3 == 1 %}
                         <img class="sender_avatar" src="https://riot.im/img/external/avatar-2.png"  />
-                    {% else %}
+                    {%- else %}
                         <img class="sender_avatar" src="https://riot.im/img/external/avatar-3.png"  />
-                    {% endif %}
-                {% endif %}
-            {% endif %}
+                    {%- endif %}
+                {%- endif %}
+            {%- endif %}
         </td>
         <td class="message_contents">
-            {% if loop.index0 == 0 or notif.messages[loop.index0 - 1].sender_name != notif.messages[loop.index0].sender_name %}
-                <div class="sender_name">{% if message.msgtype == "m.emote" %}*{% endif %} {{ message.sender_name }}</div>
-            {% endif %}
+            {%- if loop.index0 == 0 or notif.messages[loop.index0 - 1].sender_name != notif.messages[loop.index0].sender_name %}
+                <div class="sender_name">{%- if message.msgtype == "m.emote" %}*{%- endif %} {{ message.sender_name }}</div>
+            {%- endif %}
             <div class="message_body">
-                {% if message.msgtype == "m.text" %}
-                    {{ message.body_text_html }}
-                {% elif message.msgtype == "m.emote" %}
-                    {{ message.body_text_html }}
-                {% elif message.msgtype == "m.notice" %}
-                    {{ message.body_text_html }}
-                {% elif message.msgtype == "m.image" %}
-                    <img src="{{ message.image_url|mxc_to_http(640, 480, scale) }}" />
-                {% elif message.msgtype == "m.file" %}
-                    <span class="filename">{{ message.body_text_plain }}</span>
-                {% endif %}
+                {%- if message.event_type == "m.room.encrypted" %}
+                    An encrypted message.
+                {%- elif message.event_type == "m.room.message" %}
+                    {%- if message.msgtype == "m.text" %}
+                        {{ message.body_text_html }}
+                    {%- elif message.msgtype == "m.emote" %}
+                        {{ message.body_text_html }}
+                    {%- elif message.msgtype == "m.notice" %}
+                        {{ message.body_text_html }}
+                    {%- elif message.msgtype == "m.image" %}
+                        <img src="{{ message.image_url|mxc_to_http(640, 480, scale) }}" />
+                    {%- elif message.msgtype == "m.file" %}
+                        <span class="filename">{{ message.body_text_plain }}</span>
+                    {%- else %}
+                        A message with unrecognised content.
+                    {%- endif %}
+                {%- endif %}
             </div>
         </td>
         <td class="message_time">{{ message.ts|format_ts("%H:%M") }}</td>
     </tr>
-{% endfor %}
+{%- endfor %}
 <tr class="notif_link">
     <td></td>
     <td>
diff --git a/synapse/res/templates/notif.txt b/synapse/res/templates/notif.txt
index a37bee9833..1ee7da3c50 100644
--- a/synapse/res/templates/notif.txt
+++ b/synapse/res/templates/notif.txt
@@ -1,16 +1,22 @@
-{% for message in notif.messages %}
-{% if message.msgtype == "m.emote" %}* {% endif %}{{ message.sender_name }} ({{ message.ts|format_ts("%H:%M") }})
-{% if message.msgtype == "m.text" %}
+{%- for message in notif.messages %}
+{%- if message.event_type == "m.room.encrypted" %}
+An encrypted message.
+{%- elif message.event_type == "m.room.message" %}
+{%- if message.msgtype == "m.emote" %}* {%- endif %}{{ message.sender_name }} ({{ message.ts|format_ts("%H:%M") }})
+{%- if message.msgtype == "m.text" %}
 {{ message.body_text_plain }}
-{% elif message.msgtype == "m.emote" %}
+{%- elif message.msgtype == "m.emote" %}
 {{ message.body_text_plain }}
-{% elif message.msgtype == "m.notice" %}
+{%- elif message.msgtype == "m.notice" %}
 {{ message.body_text_plain }}
-{% elif message.msgtype == "m.image" %}
+{%- elif message.msgtype == "m.image" %}
 {{ message.body_text_plain }}
-{% elif message.msgtype == "m.file" %}
+{%- elif message.msgtype == "m.file" %}
 {{ message.body_text_plain }}
-{% endif %}
-{% endfor %}
+{%- else %}
+A message with unrecognised content.
+{%- endif %}
+{%- endif %}
+{%- endfor %}
 
 View {{ room.title }} at {{ notif.link }}
diff --git a/synapse/res/templates/notif_mail.html b/synapse/res/templates/notif_mail.html
index a2dfeb9e9f..27d4182790 100644
--- a/synapse/res/templates/notif_mail.html
+++ b/synapse/res/templates/notif_mail.html
@@ -2,8 +2,8 @@
 <html lang="en">
     <head>
         <style type="text/css">
-            {% include 'mail.css' without context %}
-            {% include "mail-%s.css" % app_name ignore missing without context %}
+            {%- include 'mail.css' without context %}
+            {%- include "mail-%s.css" % app_name ignore missing without context %}
         </style>
     </head>
     <body>
@@ -18,21 +18,21 @@
                                 <div class="summarytext">{{ summary_text }}</div>
                             </td>
                             <td class="logo">
-                                {% if app_name == "Riot" %}
+                                {%- if app_name == "Riot" %}
                                     <img src="http://riot.im/img/external/riot-logo-email.png" width="83" height="83" alt="[Riot]"/>
-                                {% elif app_name == "Vector" %}
+                                {%- elif app_name == "Vector" %}
                                     <img src="http://matrix.org/img/vector-logo-email.png" width="64" height="83" alt="[Vector]"/>
-                                {% elif app_name == "Element" %}
+                                {%- elif app_name == "Element" %}
                                     <img src="https://static.element.io/images/email-logo.png" width="83" height="83" alt="[Element]"/>
-                                {% else %}
+                                {%- else %}
                                     <img src="http://matrix.org/img/matrix-120x51.png" width="120" height="51" alt="[matrix]"/>
-                                {% endif %}
+                                {%- endif %}
                             </td>
                         </tr>
                     </table>
-                    {% for room in rooms %}
-                        {% include 'room.html' with context %}
-                    {% endfor %}
+                    {%- for room in rooms %}
+                        {%- include 'room.html' with context %}
+                    {%- endfor %}
                     <div class="footer">
                         <a href="{{ unsubscribe_link }}">Unsubscribe</a>
                         <br/>
@@ -41,12 +41,12 @@
                             Sending email at {{ reason.now|format_ts("%c") }} due to activity in room {{ reason.room_name }} because
                             an event was received at {{ reason.received_at|format_ts("%c") }}
                             which is more than {{ "%.1f"|format(reason.delay_before_mail_ms / (60*1000)) }} ({{ reason.delay_before_mail_ms }}) mins ago,
-                            {% if reason.last_sent_ts %}
+                            {%- if reason.last_sent_ts %}
                                 and the last time we sent a mail for this room was {{ reason.last_sent_ts|format_ts("%c") }},
                                 which is more than {{ "%.1f"|format(reason.throttle_ms / (60*1000)) }} (current throttle_ms) mins ago.
-                            {% else %}
+                            {%- else %}
                                 and we don't have a last time we sent a mail for this room.
-                            {% endif %}
+                            {%- endif %}
                         </div>
                     </div>
                 </td>
diff --git a/synapse/res/templates/notif_mail.txt b/synapse/res/templates/notif_mail.txt
index 24843042a5..df3c253979 100644
--- a/synapse/res/templates/notif_mail.txt
+++ b/synapse/res/templates/notif_mail.txt
@@ -2,9 +2,9 @@ Hi {{ user_display_name }},
 
 {{ summary_text }}
 
-{% for room in rooms %}
-{% include 'room.txt' with context %}
-{% endfor %}
+{%- for room in rooms %}
+{%- include 'room.txt' with context %}
+{%- endfor %}
 
 You can disable these notifications at {{ unsubscribe_link }}
 
diff --git a/synapse/res/templates/room.html b/synapse/res/templates/room.html
index b8525fef88..4fc6f6ac9b 100644
--- a/synapse/res/templates/room.html
+++ b/synapse/res/templates/room.html
@@ -1,23 +1,23 @@
 <table class="room">
     <tr class="room_header">
         <td class="room_avatar">
-            {% if room.avatar_url %}
+            {%- if room.avatar_url %}
                 <img alt="" src="{{ room.avatar_url|mxc_to_http(48,48) }}" />
-            {% else %}
-                {% if room.hash % 3 == 0 %}
+            {%- else %}
+                {%- if room.hash % 3 == 0 %}
                     <img alt="" src="https://riot.im/img/external/avatar-1.png"  />
-                {% elif room.hash % 3 == 1 %}
+                {%- elif room.hash % 3 == 1 %}
                     <img alt="" src="https://riot.im/img/external/avatar-2.png"  />
-                {% else %}
+                {%- else %}
                     <img alt="" src="https://riot.im/img/external/avatar-3.png"  />
-                {% endif %}
-            {% endif %}
+                {%- endif %}
+            {%- endif %}
         </td>
         <td class="room_name" colspan="2">
             {{ room.title }}
         </td>
     </tr>
-    {% if room.invite %}
+    {%- if room.invite %}
         <tr>
             <td></td>
             <td>
@@ -25,9 +25,9 @@
             </td>
             <td></td>
         </tr>
-    {% else %}
-        {% for notif in room.notifs %}
-            {% include 'notif.html' with context %}
-        {% endfor %}
-    {% endif %}
+    {%- else %}
+        {%- for notif in room.notifs %}
+            {%- include 'notif.html' with context %}
+        {%- endfor %}
+    {%- endif %}
 </table>
diff --git a/synapse/res/templates/room.txt b/synapse/res/templates/room.txt
index 84648c710e..df841e9e6f 100644
--- a/synapse/res/templates/room.txt
+++ b/synapse/res/templates/room.txt
@@ -1,9 +1,9 @@
 {{ room.title }}
 
-{% if room.invite %}
+{%- if room.invite %}
     You've been invited, join at {{ room.link }}
-{% else %}
-    {% for notif in room.notifs %}
-        {% include 'notif.txt' with context %}
-    {% endfor %}
-{% endif %}
+{%- else %}
+    {%- for notif in room.notifs %}
+        {%- include 'notif.txt' with context %}
+    {%- endfor %}
+{%- endif %}
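All of the template changes above swap `{% ... %}` for `{%- ... %}`: the leading `-` tells Jinja2 to strip the whitespace (including the newline) immediately before the tag, which stops the rendered notification emails — particularly the plain-text ones — filling up with blank lines. A tiny standalone illustration:

    from jinja2 import Template

    print(repr(Template("a {% if True %} b {% endif %} c").render()))
    # 'a  b  c'  -- whitespace around the tags survives into the output
    print(repr(Template("a {%- if True %} b {%- endif %} c").render()))
    # 'a b c'    -- the '-' strips the whitespace before each tag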
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index d7deb9300d..b82a4e978a 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -110,6 +110,8 @@ class LoginRestServlet(RestServlet):
             ({"type": t} for t in self.auth_handler.get_supported_login_types())
         )
 
+        flows.append({"type": LoginRestServlet.APPSERVICE_TYPE})
+
         return 200, {"flows": flows}
 
     def on_OPTIONS(self, request: SynapseRequest):
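With this addition, the flow list returned by `GET /_matrix/client/r0/login` also advertises appservice login. Roughly, assuming password login is enabled (the appservice type string here is taken from MSC2778, not from this hunk):

    {
        "flows": [
            {"type": "m.login.password"},
            {"type": "uk.half-shot.msc2778.login.application_service"},
        ]
    }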
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index ab49d227de..2b196ded1b 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -76,14 +76,16 @@ class SQLBaseStore(metaclass=ABCMeta):
         """
 
         try:
-            if key is None:
-                getattr(self, cache_name).invalidate_all()
-            else:
-                getattr(self, cache_name).invalidate(tuple(key))
+            cache = getattr(self, cache_name)
         except AttributeError:
             # We probably haven't pulled in the cache in this worker,
             # which is fine.
-            pass
+            return
+
+        if key is None:
+            cache.invalidate_all()
+        else:
+            cache.invalidate(tuple(key))
 
 
 def db_to_json(db_content):
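The reshuffle narrows the `try` block to the attribute lookup alone: previously, an AttributeError raised from *inside* invalidate() or invalidate_all() would have been silently swallowed along with the expected "cache not present on this worker" case. The shape of the fix, in a runnable miniature:

    class Store:
        pass  # no caches pulled in on this worker

    def attempt_cache_invalidation(store, cache_name, key):
        try:
            cache = getattr(store, cache_name)
        except AttributeError:
            return  # cache not present on this worker: fine
        if key is None:
            cache.invalidate_all()  # an AttributeError raised in here now
        else:                       # propagates instead of being silenced
            cache.invalidate(tuple(key))

    attempt_cache_invalidation(Store(), "client_ip_last_seen", None)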
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 763722d6bc..0217e63108 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -160,7 +160,7 @@ class LoggingDatabaseConnection:
         self.conn.__enter__()
         return self
 
-    def __exit__(self, exc_type, exc_value, traceback) -> bool:
+    def __exit__(self, exc_type, exc_value, traceback) -> Optional[bool]:
         return self.conn.__exit__(exc_type, exc_value, traceback)
 
     # Proxy through any unknown lookups to the DB conn class.
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index ec8260e906..2408432738 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -19,7 +19,7 @@ from typing import Dict, Optional, Tuple
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import DatabasePool, make_tuple_comparison_clause
-from synapse.util.caches.deferred_cache import DeferredCache
+from synapse.util.caches.lrucache import LruCache
 
 logger = logging.getLogger(__name__)
 
@@ -410,8 +410,8 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
 class ClientIpStore(ClientIpWorkerStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
 
-        self.client_ip_last_seen = DeferredCache(
-            name="client_ip_last_seen", keylen=4, max_entries=50000
+        self.client_ip_last_seen = LruCache(
+            cache_name="client_ip_last_seen", keylen=4, max_size=50000
         )
 
         super().__init__(database, db_conn, hs)
@@ -442,7 +442,7 @@ class ClientIpStore(ClientIpWorkerStore):
         if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
             return
 
-        self.client_ip_last_seen.prefill(key, now)
+        self.client_ip_last_seen.set(key, now)
 
         self._batch_row_update[key] = (user_agent, device_id, now)
 
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index e662a20d24..dfb4f87b8f 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -34,8 +34,8 @@ from synapse.storage.database import (
 )
 from synapse.types import Collection, JsonDict, get_verify_key_from_cross_signing_key
 from synapse.util import json_decoder, json_encoder
-from synapse.util.caches.deferred_cache import DeferredCache
 from synapse.util.caches.descriptors import cached, cachedList
+from synapse.util.caches.lrucache import LruCache
 from synapse.util.iterutils import batch_iter
 from synapse.util.stringutils import shortstr
 
@@ -1005,8 +1005,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
 
         # Map of (user_id, device_id) -> bool. If there is an entry that implies
         # the device exists.
-        self.device_id_exists_cache = DeferredCache(
-            name="device_id_exists", keylen=2, max_entries=10000
+        self.device_id_exists_cache = LruCache(
+            cache_name="device_id_exists", keylen=2, max_size=10000
         )
 
     async def store_device(
@@ -1052,7 +1052,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
                 )
                 if hidden:
                     raise StoreError(400, "The device ID is in use", Codes.FORBIDDEN)
-            self.device_id_exists_cache.prefill(key, True)
+            self.device_id_exists_cache.set(key, True)
             return inserted
         except StoreError:
             raise
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index ba3b1769b0..87808c1483 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -1051,9 +1051,7 @@ class PersistEventsStore:
 
         def prefill():
             for cache_entry in to_prefill:
-                self.store._get_event_cache.prefill(
-                    (cache_entry[0].event_id,), cache_entry
-                )
+                self.store._get_event_cache.set((cache_entry[0].event_id,), cache_entry)
 
         txn.call_after(prefill)
 
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index ff150f0be7..6e7f16f39c 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -33,7 +33,10 @@ from synapse.api.room_versions import (
 from synapse.events import EventBase, make_event_from_dict
 from synapse.events.utils import prune_event
 from synapse.logging.context import PreserveLoggingContext, current_context
-from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.metrics.background_process_metrics import (
+    run_as_background_process,
+    wrap_as_background_process,
+)
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.replication.tcp.streams import BackfillStream
 from synapse.replication.tcp.streams.events import EventsStream
@@ -42,8 +45,8 @@ from synapse.storage.database import DatabasePool
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
 from synapse.types import Collection, get_domain_from_id
-from synapse.util.caches.deferred_cache import DeferredCache
 from synapse.util.caches.descriptors import cached
+from synapse.util.caches.lrucache import LruCache
 from synapse.util.iterutils import batch_iter
 from synapse.util.metrics import Measure
 
@@ -137,20 +140,16 @@ class EventsWorkerStore(SQLBaseStore):
                     db_conn, "events", "stream_ordering", step=-1
                 )
 
-        if not hs.config.worker.worker_app:
+        if hs.config.run_background_tasks:
             # We periodically clean out old transaction ID mappings
             self._clock.looping_call(
-                run_as_background_process,
-                5 * 60 * 1000,
-                "_cleanup_old_transaction_ids",
-                self._cleanup_old_transaction_ids,
+                self._cleanup_old_transaction_ids, 5 * 60 * 1000,
             )
 
-        self._get_event_cache = DeferredCache(
-            "*getEvent*",
+        self._get_event_cache = LruCache(
+            cache_name="*getEvent*",
             keylen=3,
-            max_entries=hs.config.caches.event_cache_size,
-            apply_cache_factor_from_config=False,
+            max_size=hs.config.caches.event_cache_size,
         )
 
         self._event_fetch_lock = threading.Condition()
@@ -749,7 +748,7 @@ class EventsWorkerStore(SQLBaseStore):
                 event=original_ev, redacted_event=redacted_event
             )
 
-            self._get_event_cache.prefill((event_id,), cache_entry)
+            self._get_event_cache.set((event_id,), cache_entry)
             result_map[event_id] = cache_entry
 
         return result_map
@@ -1375,6 +1374,7 @@ class EventsWorkerStore(SQLBaseStore):
 
         return mapping
 
+    @wrap_as_background_process("_cleanup_old_transaction_ids")
     async def _cleanup_old_transaction_ids(self):
         """Cleans out transaction id mappings older than 24hrs.
         """
diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py
index 1681caa1f0..a6d1eb908a 100644
--- a/synapse/storage/databases/main/profile.py
+++ b/synapse/storage/databases/main/profile.py
@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Dict, Optional
+from typing import Any, Dict, List, Optional
 
 from synapse.api.errors import StoreError
 from synapse.storage._base import SQLBaseStore
@@ -72,7 +72,7 @@ class ProfileWorkerStore(SQLBaseStore):
         )
 
     async def set_profile_displayname(
-        self, user_localpart: str, new_displayname: str
+        self, user_localpart: str, new_displayname: Optional[str]
     ) -> None:
         await self.db_pool.simple_update_one(
             table="profiles",
@@ -144,7 +144,7 @@ class ProfileWorkerStore(SQLBaseStore):
 
     async def get_remote_profile_cache_entries_that_expire(
         self, last_checked: int
-    ) -> Dict[str, str]:
+    ) -> List[Dict[str, str]]:
         """Get all users who haven't been checked since `last_checked`
         """
 
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index df8609b97b..7997242d90 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -303,7 +303,7 @@ class PusherStore(PusherWorkerStore):
                 lock=False,
             )
 
-            user_has_pusher = self.get_if_user_has_pusher.cache.get(
+            user_has_pusher = self.get_if_user_has_pusher.cache.get_immediate(
                 (user_id,), None, update_metrics=False
             )
 
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 5cdf16521c..ca7917c989 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -25,7 +25,6 @@ from synapse.storage.database import DatabasePool
 from synapse.storage.util.id_generators import StreamIdGenerator
 from synapse.types import JsonDict
 from synapse.util import json_encoder
-from synapse.util.async_helpers import ObservableDeferred
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
@@ -413,18 +412,10 @@ class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
         if receipt_type != "m.read":
             return
 
-        # Returns either an ObservableDeferred or the raw result
-        res = self.get_users_with_read_receipts_in_room.cache.get(
+        res = self.get_users_with_read_receipts_in_room.cache.get_immediate(
             room_id, None, update_metrics=False
         )
 
-        # first handle the ObservableDeferred case
-        if isinstance(res, ObservableDeferred):
-            if res.has_called():
-                res = res.get_result()
-            else:
-                res = None
-
         if res and user_id in res:
             # We'd only be adding to the set, so no point invalidating if the
             # user is already there
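`get_immediate()` (added to DeferredCache later in this diff) returns only *completed* values, so the ObservableDeferred unwrapping above disappears at every call site. In miniature (the cache name and keys are illustrative):

    from twisted.internet import defer

    from synapse.util.caches.deferred_cache import DeferredCache

    cache = DeferredCache("read_receipts_demo")

    cache.set("!room:example.org", defer.succeed({"@user:example.org"}))
    assert cache.get_immediate("!room:example.org", None) == {"@user:example.org"}

    # a key with no *completed* value just yields the default:
    assert cache.get_immediate("!other:example.org", None) is None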
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 20fcdaa529..01d9dbb36f 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -20,7 +20,10 @@ from synapse.api.constants import EventTypes, Membership
 from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
 from synapse.metrics import LaterGauge
-from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.metrics.background_process_metrics import (
+    run_as_background_process,
+    wrap_as_background_process,
+)
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
 from synapse.storage.database import DatabasePool
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
@@ -67,16 +70,10 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         ):
             self._known_servers_count = 1
             self.hs.get_clock().looping_call(
-                run_as_background_process,
-                60 * 1000,
-                "_count_known_servers",
-                self._count_known_servers,
+                self._count_known_servers, 60 * 1000,
             )
             self.hs.get_clock().call_later(
-                1000,
-                run_as_background_process,
-                "_count_known_servers",
-                self._count_known_servers,
+                1000, self._count_known_servers,
             )
             LaterGauge(
                 "synapse_federation_known_servers",
@@ -85,6 +82,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
                 lambda: self._known_servers_count,
             )
 
+    @wrap_as_background_process("_count_known_servers")
     async def _count_known_servers(self):
         """
         Count the servers that this server knows about.
@@ -531,7 +529,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             # If we do then we can reuse that result and simply update it with
             # any membership changes in `delta_ids`
             if context.prev_group and context.delta_ids:
-                prev_res = self._get_joined_users_from_context.cache.get(
+                prev_res = self._get_joined_users_from_context.cache.get_immediate(
                     (room_id, context.prev_group), None
                 )
                 if prev_res and isinstance(prev_res, dict):
diff --git a/synapse/storage/databases/main/schema/delta/59/19as_device_stream.sql b/synapse/storage/databases/main/schema/delta/58/21as_device_stream.sql
index 20f5a95a24..7b84a207fd 100644
--- a/synapse/storage/databases/main/schema/delta/59/19as_device_stream.sql
+++ b/synapse/storage/databases/main/schema/delta/58/21as_device_stream.sql
@@ -13,6 +13,5 @@
  * limitations under the License.
  */
 
-ALTER TABLE application_services_state
-    ADD COLUMN read_receipt_stream_id INT,
-    ADD COLUMN presence_stream_id INT;
\ No newline at end of file
+ALTER TABLE application_services_state ADD COLUMN read_receipt_stream_id INT;
+ALTER TABLE application_services_state ADD COLUMN presence_stream_id INT;
\ No newline at end of file
diff --git a/synapse/storage/databases/main/schema/delta/58/21drop_device_max_stream_id.sql b/synapse/storage/databases/main/schema/delta/58/21drop_device_max_stream_id.sql
new file mode 100644
index 0000000000..01ea6eddcf
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/21drop_device_max_stream_id.sql
@@ -0,0 +1 @@
+DROP TABLE device_max_stream_id;
diff --git a/synapse/storage/types.py b/synapse/storage/types.py
index 970bb1b9da..9cadcba18f 100644
--- a/synapse/storage/types.py
+++ b/synapse/storage/types.py
@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Iterable, Iterator, List, Tuple
+from typing import Any, Iterable, Iterator, List, Optional, Tuple
 
 from typing_extensions import Protocol
 
@@ -65,5 +65,5 @@ class Connection(Protocol):
     def __enter__(self) -> "Connection":
         ...
 
-    def __exit__(self, exc_type, exc_value, traceback) -> bool:
+    def __exit__(self, exc_type, exc_value, traceback) -> Optional[bool]:
         ...
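Typing `__exit__` as `-> bool` would oblige implementations to return an actual bool, but the conventional context manager returns None (falsy) so that exceptions propagate; only an explicit `True` suppresses them. `Optional[bool]` matches that convention. For example:

    from types import TracebackType
    from typing import Optional, Type

    class Managed:
        def __enter__(self) -> "Managed":
            return self

        def __exit__(
            self,
            exc_type: Optional[Type[BaseException]],
            exc_value: Optional[BaseException],
            traceback: Optional[TracebackType],
        ) -> Optional[bool]:
            # Returning None (the usual case) lets any exception propagate;
            # returning True would swallow it.
            return None

    with Managed():
        pass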
diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py
index 8fc05be278..89f0b38535 100644
--- a/synapse/util/caches/__init__.py
+++ b/synapse/util/caches/__init__.py
@@ -16,7 +16,7 @@
 
 import logging
 from sys import intern
-from typing import Callable, Dict, Optional
+from typing import Callable, Dict, Optional, Sized
 
 import attr
 from prometheus_client.core import Gauge
@@ -92,7 +92,7 @@ class CacheMetric:
 def register_cache(
     cache_type: str,
     cache_name: str,
-    cache,
+    cache: Sized,
     collect_callback: Optional[Callable] = None,
     resizable: bool = True,
     resize_callback: Optional[Callable] = None,
@@ -100,12 +100,15 @@ def register_cache(
     """Register a cache object for metric collection and resizing.
 
     Args:
-        cache_type
+        cache_type: a string indicating the "type" of the cache. This is used
+            only for deduplication, so it isn't too important provided it's constant.
         cache_name: name of the cache
-        cache: cache itself
+        cache: cache itself, which must implement __len__(), and may optionally implement
+            a max_size property
         collect_callback: If given, a function which is called during metric
             collection to update additional metrics.
-        resizable: Whether this cache supports being resized.
+        resizable: Whether this cache supports being resized, in which case either
+            resize_callback must be provided, or the cache must support set_max_size().
         resize_callback: A function which can be called to resize the cache.
 
     Returns:
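A toy cache satisfying the contract the expanded docstring spells out: it is Sized (implements `__len__`) and, if registered as resizable without a `resize_callback`, supports `set_max_size()`. The class and names are illustrative only:

    class TinyCache:
        def __init__(self, max_size: int):
            self.max_size = max_size
            self._data = {}  # type: dict

        def __len__(self) -> int:
            return len(self._data)

        def set_max_size(self, max_size: int) -> None:
            self.max_size = max_size

    # register_cache("toy", "tiny_cache", TinyCache(100))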
diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py
index f728cd2cf2..601305487c 100644
--- a/synapse/util/caches/deferred_cache.py
+++ b/synapse/util/caches/deferred_cache.py
@@ -17,14 +17,23 @@
 
 import enum
 import threading
-from typing import Callable, Generic, Iterable, MutableMapping, Optional, TypeVar, cast
+from typing import (
+    Callable,
+    Generic,
+    Iterable,
+    MutableMapping,
+    Optional,
+    TypeVar,
+    Union,
+    cast,
+)
 
 from prometheus_client import Gauge
 
 from twisted.internet import defer
+from twisted.python import failure
 
 from synapse.util.async_helpers import ObservableDeferred
-from synapse.util.caches import register_cache
 from synapse.util.caches.lrucache import LruCache
 from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
 
@@ -34,7 +43,7 @@ cache_pending_metric = Gauge(
     ["name"],
 )
 
-
+T = TypeVar("T")
 KT = TypeVar("KT")
 VT = TypeVar("VT")
 
@@ -49,15 +58,12 @@ class DeferredCache(Generic[KT, VT]):
     """Wraps an LruCache, adding support for Deferred results.
 
     It expects that each entry added with set() will be a Deferred; likewise get()
-    may return an ObservableDeferred.
+    will return a Deferred.
     """
 
     __slots__ = (
         "cache",
-        "name",
-        "keylen",
         "thread",
-        "metrics",
         "_pending_deferred_cache",
     )
 
@@ -89,37 +95,27 @@ class DeferredCache(Generic[KT, VT]):
             cache_type()
         )  # type: MutableMapping[KT, CacheEntry]
 
+        def metrics_cb():
+            cache_pending_metric.labels(name).set(len(self._pending_deferred_cache))
+
         # cache is used for completed results and maps to the result itself, rather than
         # a Deferred.
         self.cache = LruCache(
             max_size=max_entries,
             keylen=keylen,
+            cache_name=name,
             cache_type=cache_type,
             size_callback=(lambda d: len(d)) if iterable else None,
-            evicted_callback=self._on_evicted,
+            metrics_collection_callback=metrics_cb,
             apply_cache_factor_from_config=apply_cache_factor_from_config,
-        )
+        )  # type: LruCache[KT, VT]
 
-        self.name = name
-        self.keylen = keylen
         self.thread = None  # type: Optional[threading.Thread]
-        self.metrics = register_cache(
-            "cache",
-            name,
-            self.cache,
-            collect_callback=self._metrics_collection_callback,
-        )
 
     @property
     def max_entries(self):
         return self.cache.max_size
 
-    def _on_evicted(self, evicted_count):
-        self.metrics.inc_evictions(evicted_count)
-
-    def _metrics_collection_callback(self):
-        cache_pending_metric.labels(self.name).set(len(self._pending_deferred_cache))
-
     def check_thread(self):
         expected_thread = self.thread
         if expected_thread is None:
@@ -133,62 +129,113 @@ class DeferredCache(Generic[KT, VT]):
     def get(
         self,
         key: KT,
-        default=_Sentinel.sentinel,
         callback: Optional[Callable[[], None]] = None,
         update_metrics: bool = True,
-    ):
+    ) -> defer.Deferred:
         """Looks the key up in the caches.
 
+        For symmetry with set(), this method does *not* follow the synapse logcontext
+        rules: the logcontext will not be cleared on return, and the Deferred will run
+        its callbacks in the sentinel context. In other words: wrap the result with
+        make_deferred_yieldable() before `await`ing it.
+
         Args:
-            key(tuple)
-            default: What is returned if key is not in the caches. If not
-                specified then function throws KeyError instead
-            callback(fn): Gets called when the entry in the cache is invalidated
+            key:
+            callback: Gets called when the entry in the cache is invalidated
             update_metrics (bool): whether to update the cache hit rate metrics
 
         Returns:
-            Either an ObservableDeferred or the result itself
+            A Deferred which completes with the result. Note that this may later fail
+            if there is an ongoing set() operation which later completes with a failure.
+
+        Raises:
+            KeyError if the key is not found in the cache
         """
         callbacks = [callback] if callback else []
         val = self._pending_deferred_cache.get(key, _Sentinel.sentinel)
         if val is not _Sentinel.sentinel:
             val.callbacks.update(callbacks)
             if update_metrics:
-                self.metrics.inc_hits()
-            return val.deferred
-
-        val = self.cache.get(key, _Sentinel.sentinel, callbacks=callbacks)
-        if val is not _Sentinel.sentinel:
-            self.metrics.inc_hits()
-            return val
+                m = self.cache.metrics
+                assert m  # we always have a name, so should always have metrics
+                m.inc_hits()
+            return val.deferred.observe()
 
-        if update_metrics:
-            self.metrics.inc_misses()
-
-        if default is _Sentinel.sentinel:
+        val2 = self.cache.get(
+            key, _Sentinel.sentinel, callbacks=callbacks, update_metrics=update_metrics
+        )
+        if val2 is _Sentinel.sentinel:
             raise KeyError()
         else:
-            return default
+            return defer.succeed(val2)
+
+    def get_immediate(
+        self, key: KT, default: T, update_metrics: bool = True
+    ) -> Union[VT, T]:
+        """If we have a *completed* cached value, return it."""
+        return self.cache.get(key, default, update_metrics=update_metrics)
 
     def set(
         self,
         key: KT,
         value: defer.Deferred,
         callback: Optional[Callable[[], None]] = None,
-    ) -> ObservableDeferred:
+    ) -> defer.Deferred:
+        """Adds a new entry to the cache (or updates an existing one).
+
+        The given `value` *must* be a Deferred.
+
+        First any existing entry for the same key is invalidated. Then a new entry
+        is added to the cache for the given key.
+
+        Until the `value` completes, calls to `get()` for the key will also result in an
+        incomplete Deferred, which will ultimately complete with the same result as
+        `value`.
+
+        If `value` completes successfully, subsequent calls to `get()` will then return
+        a completed deferred with the same result. If it *fails*, the cache is
+        invalidated and subsequent calls to `get()` will raise a KeyError.
+
+        If another call to `set()` happens before `value` completes, then (a) any
+        invalidation callbacks registered in the interim will be called, (b) any
+        `get()`s in the interim will continue to complete with the result from the
+        *original* `value`, (c) any future calls to `get()` will complete with the
+        result from the *new* `value`.
+
+        It is expected that `value` does *not* follow the synapse logcontext rules - ie,
+        if it is incomplete, it runs its callbacks in the sentinel context.
+
+        Args:
+            key: Key to be set
+            value: a deferred which will complete with a result to add to the cache
+            callback: An optional callback to be called when the entry is invalidated
+        """
         if not isinstance(value, defer.Deferred):
             raise TypeError("not a Deferred")
 
         callbacks = [callback] if callback else []
         self.check_thread()
-        observable = ObservableDeferred(value, consumeErrors=True)
-        observer = observable.observe()
-        entry = CacheEntry(deferred=observable, callbacks=callbacks)
 
         existing_entry = self._pending_deferred_cache.pop(key, None)
         if existing_entry:
             existing_entry.invalidate()
 
+        # XXX: why don't we invalidate the entry in `self.cache` yet?
+
+        # we can save a whole load of effort if the deferred is ready.
+        if value.called:
+            result = value.result
+            if not isinstance(result, failure.Failure):
+                self.cache.set(key, result, callbacks)
+            return value
+
+        # otherwise, we'll add an entry to the _pending_deferred_cache for now,
+        # and add callbacks to add it to the cache properly later.
+
+        observable = ObservableDeferred(value, consumeErrors=True)
+        observer = observable.observe()
+        entry = CacheEntry(deferred=observable, callbacks=callbacks)
+
         self._pending_deferred_cache[key] = entry
 
         def compare_and_pop():
@@ -232,7 +279,9 @@ class DeferredCache(Generic[KT, VT]):
         # _pending_deferred_cache to the real cache.
         #
         observer.addCallbacks(cb, eb)
-        return observable
+
+        # we return a new Deferred which will be called before any subsequent observers.
+        return observable.observe()
 
     def prefill(self, key: KT, value: VT, callback: Callable[[], None] = None):
         callbacks = [callback] if callback else []
@@ -257,11 +306,12 @@ class DeferredCache(Generic[KT, VT]):
         self.check_thread()
         if not isinstance(key, tuple):
             raise TypeError("The cache key must be a tuple not %r" % (type(key),))
+        key = cast(KT, key)
         self.cache.del_multi(key)
 
         # if we have a pending lookup for this key, remove it from the
         # _pending_deferred_cache, as above
-        entry_dict = self._pending_deferred_cache.pop(cast(KT, key), None)
+        entry_dict = self._pending_deferred_cache.pop(key, None)
         if entry_dict is not None:
             for entry in iterate_tree_cache_entry(entry_dict):
                 entry.invalidate()
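The net behavioural contract after this rewrite, exercised in miniature. This assumes no synapse logcontext is active, since get() and set() deliberately sidestep the logcontext rules, as their docstrings now note:

    from twisted.internet import defer

    from synapse.util.caches.deferred_cache import DeferredCache

    cache = DeferredCache("example")  # type: DeferredCache[str, int]

    d = defer.Deferred()
    cache.set("key", d)         # pending until d completes
    pending = cache.get("key")  # an incomplete Deferred, not a raw value

    results = []
    pending.addCallback(results.append)

    d.callback(42)  # completion moves the value into the underlying LruCache
    assert results == [42]
    assert cache.get_immediate("key", None) == 42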
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 1f43886804..5d7fffee66 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -23,7 +23,6 @@ from twisted.internet import defer
 
 from synapse.logging.context import make_deferred_yieldable, preserve_fn
 from synapse.util import unwrapFirstError
-from synapse.util.async_helpers import ObservableDeferred
 from synapse.util.caches.deferred_cache import DeferredCache
 
 logger = logging.getLogger(__name__)
@@ -156,7 +155,7 @@ class CacheDescriptor(_CacheDescriptorBase):
             keylen=self.num_args,
             tree=self.tree,
             iterable=self.iterable,
-        )  # type: DeferredCache[Tuple, Any]
+        )  # type: DeferredCache[CacheKey, Any]
 
         def get_cache_key_gen(args, kwargs):
             """Given some args/kwargs return a generator that resolves into
@@ -202,32 +201,20 @@ class CacheDescriptor(_CacheDescriptorBase):
 
             cache_key = get_cache_key(args, kwargs)
 
-            # Add our own `cache_context` to argument list if the wrapped function
-            # has asked for one
-            if self.add_cache_context:
-                kwargs["cache_context"] = _CacheContext.get_instance(cache, cache_key)
-
             try:
-                cached_result_d = cache.get(cache_key, callback=invalidate_callback)
-
-                if isinstance(cached_result_d, ObservableDeferred):
-                    observer = cached_result_d.observe()
-                else:
-                    observer = defer.succeed(cached_result_d)
-
+                ret = cache.get(cache_key, callback=invalidate_callback)
             except KeyError:
-                ret = defer.maybeDeferred(preserve_fn(self.orig), obj, *args, **kwargs)
+                # Add our own `cache_context` to argument list if the wrapped function
+                # has asked for one
+                if self.add_cache_context:
+                    kwargs["cache_context"] = _CacheContext.get_instance(
+                        cache, cache_key
+                    )
 
-                def onErr(f):
-                    cache.invalidate(cache_key)
-                    return f
-
-                ret.addErrback(onErr)
-
-                result_d = cache.set(cache_key, ret, callback=invalidate_callback)
-                observer = result_d.observe()
+                ret = defer.maybeDeferred(preserve_fn(self.orig), obj, *args, **kwargs)
+                ret = cache.set(cache_key, ret, callback=invalidate_callback)
 
-            return make_deferred_yieldable(observer)
+            return make_deferred_yieldable(ret)
 
         wrapped = cast(_CachedFunction, _wrapped)
 
@@ -286,7 +273,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
 
     def __get__(self, obj, objtype=None):
         cached_method = getattr(obj, self.cached_method_name)
-        cache = cached_method.cache
+        cache = cached_method.cache  # type: DeferredCache[CacheKey, Any]
         num_args = cached_method.num_args
 
         @functools.wraps(self.orig)
@@ -326,14 +313,11 @@ class CacheListDescriptor(_CacheDescriptorBase):
             for arg in list_args:
                 try:
                     res = cache.get(arg_to_cache_key(arg), callback=invalidate_callback)
-                    if not isinstance(res, ObservableDeferred):
-                        results[arg] = res
-                    elif not res.has_succeeded():
-                        res = res.observe()
+                    if not res.called:
                         res.addCallback(update_results_dict, arg)
                         cached_defers.append(res)
                     else:
-                        results[arg] = res.get_result()
+                        results[arg] = res.result
                 except KeyError:
                     missing.add(arg)
 
diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py
index 8592b93689..588d2d49f2 100644
--- a/synapse/util/caches/dictionary_cache.py
+++ b/synapse/util/caches/dictionary_cache.py
@@ -12,15 +12,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import enum
 import logging
 import threading
 from collections import namedtuple
+from typing import Any
 
 from synapse.util.caches.lrucache import LruCache
 
-from . import register_cache
-
 logger = logging.getLogger(__name__)
 
 
@@ -40,24 +39,25 @@ class DictionaryEntry(namedtuple("DictionaryEntry", ("full", "known_absent", "va
         return len(self.value)
 
 
+class _Sentinel(enum.Enum):
+    # defining a sentinel in this way allows mypy to correctly handle the
+    # type of a dictionary lookup.
+    sentinel = object()
+
+
 class DictionaryCache:
     """Caches key -> dictionary lookups, supporting caching partial dicts, i.e.
     fetching a subset of dictionary keys for a particular key.
     """
 
     def __init__(self, name, max_entries=1000):
-        self.cache = LruCache(max_size=max_entries, size_callback=len)
+        self.cache = LruCache(
+            max_size=max_entries, cache_name=name, size_callback=len
+        )  # type: LruCache[Any, DictionaryEntry]
 
         self.name = name
         self.sequence = 0
         self.thread = None
-        # caches_by_name[name] = self.cache
-
-        class Sentinel:
-            __slots__ = []
-
-        self.sentinel = Sentinel()
-        self.metrics = register_cache("dictionary", name, self.cache)
 
     def check_thread(self):
         expected_thread = self.thread
@@ -80,10 +80,8 @@ class DictionaryCache:
         Returns:
             DictionaryEntry
         """
-        entry = self.cache.get(key, self.sentinel)
-        if entry is not self.sentinel:
-            self.metrics.inc_hits()
-
+        entry = self.cache.get(key, _Sentinel.sentinel)
+        if entry is not _Sentinel.sentinel:
             if dict_keys is None:
                 return DictionaryEntry(
                     entry.full, entry.known_absent, dict(entry.value)
@@ -95,7 +93,6 @@ class DictionaryCache:
                     {k: entry.value[k] for k in dict_keys if k in entry.value},
                 )
 
-        self.metrics.inc_misses()
         return DictionaryEntry(False, set(), {})
 
     def invalidate(self, key):
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index 33eae2b7c4..60bb6ff642 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -15,11 +15,35 @@
 
 import threading
 from functools import wraps
-from typing import Callable, Optional, Type, Union
+from typing import (
+    Any,
+    Callable,
+    Generic,
+    Iterable,
+    Optional,
+    Type,
+    TypeVar,
+    Union,
+    cast,
+    overload,
+)
+
+from typing_extensions import Literal
 
 from synapse.config import cache as cache_config
+from synapse.util.caches import CacheMetric, register_cache
 from synapse.util.caches.treecache import TreeCache
 
+# Function type: the type used for invalidation callbacks
+FT = TypeVar("FT", bound=Callable[..., Any])
+
+# Key and Value type for the cache
+KT = TypeVar("KT")
+VT = TypeVar("VT")
+
+# a general type var, distinct from either KT or VT
+T = TypeVar("T")
+
 
 def enumerate_leaves(node, depth):
     if depth == 0:
@@ -41,29 +65,31 @@ class _Node:
         self.callbacks = callbacks
 
 
-class LruCache:
+class LruCache(Generic[KT, VT]):
     """
-    Least-recently-used cache.
+    Least-recently-used cache, supporting prometheus metrics and invalidation callbacks.
+
     Supports del_multi only if cache_type=TreeCache
     If cache_type=TreeCache, all keys must be tuples.
-
-    Can also set callbacks on objects when getting/setting which are fired
-    when that key gets invalidated/evicted.
     """
 
     def __init__(
         self,
         max_size: int,
+        cache_name: Optional[str] = None,
         keylen: int = 1,
         cache_type: Type[Union[dict, TreeCache]] = dict,
         size_callback: Optional[Callable] = None,
-        evicted_callback: Optional[Callable] = None,
+        metrics_collection_callback: Optional[Callable[[], None]] = None,
         apply_cache_factor_from_config: bool = True,
     ):
         """
         Args:
             max_size: The maximum amount of entries the cache can hold
 
+            cache_name: The name of this cache, for the prometheus metrics. If unset,
+                no metrics will be reported on this cache.
+
             keylen: The length of the tuple used as the cache key. Ignored unless
                 cache_type is `TreeCache`.
 
@@ -73,9 +99,13 @@ class LruCache:
 
             size_callback (func(V) -> int | None):
 
-            evicted_callback (func(int)|None):
-                if not None, called on eviction with the size of the evicted
-                entry
+            metrics_collection_callback:
+                metrics collection callback. This is called early in the metrics
+                collection process, before any of the metrics registered with the
+                prometheus Registry are collected, so can be used to update any dynamic
+                metrics.
+
+                Ignored if cache_name is None.
 
             apply_cache_factor_from_config (bool): If true, `max_size` will be
                 multiplied by a cache factor derived from the homeserver config
@@ -94,6 +124,23 @@ class LruCache:
         else:
             self.max_size = int(max_size)
 
+        # register_cache might call our "set_cache_factor" callback; there's nothing to
+        # do yet when we get resized.
+        self._on_resize = None  # type: Optional[Callable[[], None]]
+
+        if cache_name is not None:
+            metrics = register_cache(
+                "lru_cache",
+                cache_name,
+                self,
+                collect_callback=metrics_collection_callback,
+            )  # type: Optional[CacheMetric]
+        else:
+            metrics = None
+
+        # this is exposed for access from outside this class
+        self.metrics = metrics
+
         list_root = _Node(None, None, None, None)
         list_root.next_node = list_root
         list_root.prev_node = list_root
@@ -105,16 +152,16 @@ class LruCache:
                 todelete = list_root.prev_node
                 evicted_len = delete_node(todelete)
                 cache.pop(todelete.key, None)
-                if evicted_callback:
-                    evicted_callback(evicted_len)
+                if metrics:
+                    metrics.inc_evictions(evicted_len)
 
-        def synchronized(f):
+        def synchronized(f: FT) -> FT:
             @wraps(f)
             def inner(*args, **kwargs):
                 with lock:
                     return f(*args, **kwargs)
 
-            return inner
+            return cast(FT, inner)
 
         cached_cache_len = [0]
         if size_callback is not None:
@@ -168,18 +215,45 @@ class LruCache:
             node.callbacks.clear()
             return deleted_len
 
+        @overload
+        def cache_get(
+            key: KT,
+            default: Literal[None] = None,
+            callbacks: Iterable[Callable[[], None]] = ...,
+            update_metrics: bool = ...,
+        ) -> Optional[VT]:
+            ...
+
+        @overload
+        def cache_get(
+            key: KT,
+            default: T,
+            callbacks: Iterable[Callable[[], None]] = ...,
+            update_metrics: bool = ...,
+        ) -> Union[T, VT]:
+            ...
+
         @synchronized
-        def cache_get(key, default=None, callbacks=[]):
+        def cache_get(
+            key: KT,
+            default: Optional[T] = None,
+            callbacks: Iterable[Callable[[], None]] = [],
+            update_metrics: bool = True,
+        ):
             node = cache.get(key, None)
             if node is not None:
                 move_node_to_front(node)
                 node.callbacks.update(callbacks)
+                if update_metrics and metrics:
+                    metrics.inc_hits()
                 return node.value
             else:
+                if update_metrics and metrics:
+                    metrics.inc_misses()
                 return default
 
         @synchronized
-        def cache_set(key, value, callbacks=[]):
+        def cache_set(key: KT, value: VT, callbacks: Iterable[Callable[[], None]] = []):
             node = cache.get(key, None)
             if node is not None:
                 # We sometimes store large objects, e.g. dicts, which cause
@@ -208,7 +282,7 @@ class LruCache:
             evict()
 
         @synchronized
-        def cache_set_default(key, value):
+        def cache_set_default(key: KT, value: VT) -> VT:
             node = cache.get(key, None)
             if node is not None:
                 return node.value
@@ -217,8 +291,16 @@ class LruCache:
                 evict()
                 return value
 
+        @overload
+        def cache_pop(key: KT, default: Literal[None] = None) -> Optional[VT]:
+            ...
+
+        @overload
+        def cache_pop(key: KT, default: T) -> Union[T, VT]:
+            ...
+
         @synchronized
-        def cache_pop(key, default=None):
+        def cache_pop(key: KT, default: Optional[T] = None):
             node = cache.get(key, None)
             if node:
                 delete_node(node)
@@ -228,18 +310,18 @@ class LruCache:
                 return default
 
         @synchronized
-        def cache_del_multi(key):
+        def cache_del_multi(key: KT) -> None:
             """
             This will only work if constructed with cache_type=TreeCache
             """
             popped = cache.pop(key)
             if popped is None:
                 return
-            for leaf in enumerate_leaves(popped, keylen - len(key)):
+            for leaf in enumerate_leaves(popped, keylen - len(cast(tuple, key))):
                 delete_node(leaf)
 
         @synchronized
-        def cache_clear():
+        def cache_clear() -> None:
             list_root.next_node = list_root
             list_root.prev_node = list_root
             for node in cache.values():
@@ -250,15 +332,21 @@ class LruCache:
                 cached_cache_len[0] = 0
 
         @synchronized
-        def cache_contains(key):
+        def cache_contains(key: KT) -> bool:
             return key in cache
 
         self.sentinel = object()
+
+        # make sure that we clear out any excess entries after we get resized.
         self._on_resize = evict
+
         self.get = cache_get
         self.set = cache_set
         self.setdefault = cache_set_default
         self.pop = cache_pop
+        # `invalidate` is exposed for consistency with DeferredCache, so that this
+        # cache can be invalidated by the cache invalidation replication stream.
+        self.invalidate = cache_pop
         if cache_type is TreeCache:
             self.del_multi = cache_del_multi
         self.len = synchronized(cache_len)
@@ -302,6 +390,7 @@ class LruCache:
         new_size = int(self._original_max_size * factor)
         if new_size != self.max_size:
             self.max_size = new_size
-            self._on_resize()
+            if self._on_resize:
+                self._on_resize()
             return True
         return False
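A short, self-contained exercise of the new LruCache surface: named caches report prometheus hit/miss/eviction metrics via register_cache, `update_metrics=False` skips the hit/miss counting, and `invalidate` is an alias for pop():

    from synapse.util.caches.lrucache import LruCache

    cache = LruCache(
        max_size=10, cache_name="demo", apply_cache_factor_from_config=False
    )  # type: LruCache[str, int]

    cache.set("a", 1)
    assert cache.get("a") == 1       # counted as a hit
    assert cache.get("b", 42) == 42  # counted as a miss
    assert cache.get("c", update_metrics=False) is None  # metrics untouched

    cache.invalidate("a")            # alias for pop(), for DeferredCache parity
    assert "a" not in cache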
diff --git a/synmark/__init__.py b/synmark/__init__.py
index 53698bd5ab..09bc7e7927 100644
--- a/synmark/__init__.py
+++ b/synmark/__init__.py
@@ -15,7 +15,10 @@
 
 import sys
 
-from twisted.internet import epollreactor
+try:
+    from twisted.internet.epollreactor import EPollReactor as Reactor
+except ImportError:
+    from twisted.internet.pollreactor import PollReactor as Reactor
 from twisted.internet.main import installReactor
 
 from synapse.config.homeserver import HomeServerConfig
@@ -41,7 +44,7 @@ async def make_homeserver(reactor, config=None):
     config_obj = HomeServerConfig()
     config_obj.parse_config_dict(config, "", "")
 
-    hs = await setup_test_homeserver(
+    hs = setup_test_homeserver(
         cleanup_tasks.append, config=config_obj, reactor=reactor, clock=clock
     )
     stor = hs.get_datastore()
@@ -63,7 +66,7 @@ def make_reactor():
     Instantiate and install a Twisted reactor suitable for testing (i.e. not the
     default global one).
     """
-    reactor = epollreactor.EPollReactor()
+    reactor = Reactor()
 
     if "twisted.internet.reactor" in sys.modules:
         del sys.modules["twisted.internet.reactor"]
diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py
index 2acb8b7603..97f8cad0dd 100644
--- a/tests/appservice/test_scheduler.py
+++ b/tests/appservice/test_scheduler.py
@@ -260,6 +260,31 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase):
         self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event2], [])
         self.assertEquals(3, self.txn_ctrl.send.call_count)
 
+    def test_send_large_txns(self):
+        srv_1_defer = defer.Deferred()
+        srv_2_defer = defer.Deferred()
+        send_return_list = [srv_1_defer, srv_2_defer]
+
+        def do_send(x, y, z):
+            return make_deferred_yieldable(send_return_list.pop(0))
+
+        self.txn_ctrl.send = Mock(side_effect=do_send)
+
+        service = Mock(id=4, name="service")
+        event_list = [Mock(name="event%i" % (i + 1)) for i in range(200)]
+        for event in event_list:
+            self.queuer.enqueue_event(service, event)
+
+        # Expect the first event to be sent immediately.
+        self.txn_ctrl.send.assert_called_with(service, [event_list[0]], [])
+        srv_1_defer.callback(service)
+        # Then send the next 100 events
+        self.txn_ctrl.send.assert_called_with(service, event_list[1:101], [])
+        srv_2_defer.callback(service)
+        # Then the final 99 events
+        self.txn_ctrl.send.assert_called_with(service, event_list[101:], [])
+        self.assertEquals(3, self.txn_ctrl.send.call_count)
+
     def test_send_single_ephemeral_no_queue(self):
         # Expect the event to be sent immediately.
         service = Mock(id=4, name="service")
@@ -296,3 +321,19 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase):
         # Expect the queued events to be sent
         self.txn_ctrl.send.assert_called_with(service, [], event_list_2 + event_list_3)
         self.assertEquals(2, self.txn_ctrl.send.call_count)
+
+    def test_send_large_txns_ephemeral(self):
+        d = defer.Deferred()
+        self.txn_ctrl.send = Mock(
+            side_effect=lambda x, y, z: make_deferred_yieldable(d)
+        )
+        # Expect the event to be sent immediately.
+        service = Mock(id=4, name="service")
+        first_chunk = [Mock(name="event%i" % (i + 1)) for i in range(100)]
+        second_chunk = [Mock(name="event%i" % (i + 101)) for i in range(50)]
+        event_list = first_chunk + second_chunk
+        self.queuer.enqueue_ephemeral(service, event_list)
+        self.txn_ctrl.send.assert_called_once_with(service, [], first_chunk)
+        d.callback(service)
+        self.txn_ctrl.send.assert_called_with(service, [], second_chunk)
+        self.assertEquals(2, self.txn_ctrl.send.call_count)
diff --git a/tests/logging/test_terse_json.py b/tests/logging/test_terse_json.py
index 4cf81f7128..fd128b88e0 100644
--- a/tests/logging/test_terse_json.py
+++ b/tests/logging/test_terse_json.py
@@ -78,7 +78,7 @@ class TerseJSONTCPTestCase(StructuredLoggingTestBase, HomeserverTestCase):
             "server_name",
             "name",
         ]
-        self.assertEqual(set(log.keys()), set(expected_log_keys))
+        self.assertCountEqual(log.keys(), expected_log_keys)
 
         # It contains the data we expect.
         self.assertEqual(log["name"], "wally")
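`assertCountEqual` is a stricter check than comparing `set()`s: it requires the same elements with the same multiplicity (order still ignored), so a key accidentally duplicated in `expected_log_keys` can no longer slip through. For example:

    set(["a", "a", "b"]) == set(["a", "b"])                # True; duplicate hidden
    # self.assertCountEqual(["a", "a", "b"], ["a", "b"])   # AssertionError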
diff --git a/tests/push/test_email.py b/tests/push/test_email.py
index 3224568640..55545d9341 100644
--- a/tests/push/test_email.py
+++ b/tests/push/test_email.py
@@ -158,8 +158,21 @@ class EmailPusherTests(HomeserverTestCase):
         # We should get emailed about those messages
         self._check_for_mail()
 
+    def test_encrypted_message(self):
+        room = self.helper.create_room_as(self.user_id, tok=self.access_token)
+        self.helper.invite(
+            room=room, src=self.user_id, tok=self.access_token, targ=self.others[0].id
+        )
+        self.helper.join(room=room, user=self.others[0].id, tok=self.others[0].token)
+
+        # The other user sends an encrypted message
+        self.helper.send_event(room, "m.room.encrypted", {}, tok=self.others[0].token)
+
+        # We should get emailed about that message
+        self._check_for_mail()
+
     def _check_for_mail(self):
-        "Check that the user receives an email notification"
+        """Check that the user receives an email notification"""
 
         # Get the stream ordering before it gets sent
         pushers = self.get_success(
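The new test sends an `m.room.encrypted` event with an empty content dict: the server cannot inspect ciphertext, so the email pusher has to notify on the event type alone. That behaviour lines up with the spec's default `.m.rule.encrypted` underride push rule, shown here roughly as the Matrix specification defines it (taken from the spec, not from this diff, and with the actions abbreviated):

    {
        "rule_id": ".m.rule.encrypted",
        "conditions": [
            {"kind": "event_match", "key": "type", "pattern": "m.room.encrypted"}
        ],
        "actions": ["notify"],
    }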
diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py
index ae2cd67f35..66ac4dbe85 100644
--- a/tests/rest/client/v2_alpha/test_account.py
+++ b/tests/rest/client/v2_alpha/test_account.py
@@ -352,7 +352,6 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
         self.render(request)
         self.assertEqual(request.code, 401)
 
-    @unittest.INFO
     def test_pending_invites(self):
         """Tests that deactivating a user rejects every pending invite for them."""
         store = self.hs.get_datastore()
diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py
index 293ccfba2b..86184f0d2e 100644
--- a/tests/rest/client/v2_alpha/test_auth.py
+++ b/tests/rest/client/v2_alpha/test_auth.py
@@ -104,7 +104,6 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
         self.assertEqual(len(attempts), 1)
         self.assertEqual(attempts[0][0]["response"], "a")
 
-    @unittest.INFO
     def test_fallback_captcha(self):
         """Ensure that fallback auth via a captcha works."""
         # Returns a 401 as per the spec
diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index 8e69b1e9cc..1ac4ebc61d 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -15,237 +15,9 @@
 # limitations under the License.
 
 
-from mock import Mock
-
-from twisted.internet import defer
-
-from synapse.util.async_helpers import ObservableDeferred
-from synapse.util.caches.descriptors import cached
-
 from tests import unittest
 
 
-class CacheDecoratorTestCase(unittest.HomeserverTestCase):
-    @defer.inlineCallbacks
-    def test_passthrough(self):
-        class A:
-            @cached()
-            def func(self, key):
-                return key
-
-        a = A()
-
-        self.assertEquals((yield a.func("foo")), "foo")
-        self.assertEquals((yield a.func("bar")), "bar")
-
-    @defer.inlineCallbacks
-    def test_hit(self):
-        callcount = [0]
-
-        class A:
-            @cached()
-            def func(self, key):
-                callcount[0] += 1
-                return key
-
-        a = A()
-        yield a.func("foo")
-
-        self.assertEquals(callcount[0], 1)
-
-        self.assertEquals((yield a.func("foo")), "foo")
-        self.assertEquals(callcount[0], 1)
-
-    @defer.inlineCallbacks
-    def test_invalidate(self):
-        callcount = [0]
-
-        class A:
-            @cached()
-            def func(self, key):
-                callcount[0] += 1
-                return key
-
-        a = A()
-        yield a.func("foo")
-
-        self.assertEquals(callcount[0], 1)
-
-        a.func.invalidate(("foo",))
-
-        yield a.func("foo")
-
-        self.assertEquals(callcount[0], 2)
-
-    def test_invalidate_missing(self):
-        class A:
-            @cached()
-            def func(self, key):
-                return key
-
-        A().func.invalidate(("what",))
-
-    @defer.inlineCallbacks
-    def test_max_entries(self):
-        callcount = [0]
-
-        class A:
-            @cached(max_entries=10)
-            def func(self, key):
-                callcount[0] += 1
-                return key
-
-        a = A()
-
-        for k in range(0, 12):
-            yield a.func(k)
-
-        self.assertEquals(callcount[0], 12)
-
-        # There must have been at least 2 evictions, meaning if we calculate
-        # all 12 values again, we must get called at least 2 more times
-        for k in range(0, 12):
-            yield a.func(k)
-
-        self.assertTrue(
-            callcount[0] >= 14, msg="Expected callcount >= 14, got %d" % (callcount[0])
-        )
-
-    def test_prefill(self):
-        callcount = [0]
-
-        d = defer.succeed(123)
-
-        class A:
-            @cached()
-            def func(self, key):
-                callcount[0] += 1
-                return d
-
-        a = A()
-
-        a.func.prefill(("foo",), ObservableDeferred(d))
-
-        self.assertEquals(a.func("foo").result, d.result)
-        self.assertEquals(callcount[0], 0)
-
-    @defer.inlineCallbacks
-    def test_invalidate_context(self):
-        callcount = [0]
-        callcount2 = [0]
-
-        class A:
-            @cached()
-            def func(self, key):
-                callcount[0] += 1
-                return key
-
-            @cached(cache_context=True)
-            def func2(self, key, cache_context):
-                callcount2[0] += 1
-                return self.func(key, on_invalidate=cache_context.invalidate)
-
-        a = A()
-        yield a.func2("foo")
-
-        self.assertEquals(callcount[0], 1)
-        self.assertEquals(callcount2[0], 1)
-
-        a.func.invalidate(("foo",))
-        yield a.func("foo")
-
-        self.assertEquals(callcount[0], 2)
-        self.assertEquals(callcount2[0], 1)
-
-        yield a.func2("foo")
-
-        self.assertEquals(callcount[0], 2)
-        self.assertEquals(callcount2[0], 2)
-
-    @defer.inlineCallbacks
-    def test_eviction_context(self):
-        callcount = [0]
-        callcount2 = [0]
-
-        class A:
-            @cached(max_entries=2)
-            def func(self, key):
-                callcount[0] += 1
-                return key
-
-            @cached(cache_context=True)
-            def func2(self, key, cache_context):
-                callcount2[0] += 1
-                return self.func(key, on_invalidate=cache_context.invalidate)
-
-        a = A()
-        yield a.func2("foo")
-        yield a.func2("foo2")
-
-        self.assertEquals(callcount[0], 2)
-        self.assertEquals(callcount2[0], 2)
-
-        yield a.func2("foo")
-        self.assertEquals(callcount[0], 2)
-        self.assertEquals(callcount2[0], 2)
-
-        yield a.func("foo3")
-
-        self.assertEquals(callcount[0], 3)
-        self.assertEquals(callcount2[0], 2)
-
-        yield a.func2("foo")
-
-        self.assertEquals(callcount[0], 4)
-        self.assertEquals(callcount2[0], 3)
-
-    @defer.inlineCallbacks
-    def test_double_get(self):
-        callcount = [0]
-        callcount2 = [0]
-
-        class A:
-            @cached()
-            def func(self, key):
-                callcount[0] += 1
-                return key
-
-            @cached(cache_context=True)
-            def func2(self, key, cache_context):
-                callcount2[0] += 1
-                return self.func(key, on_invalidate=cache_context.invalidate)
-
-        a = A()
-        a.func2.cache.cache = Mock(wraps=a.func2.cache.cache)
-
-        yield a.func2("foo")
-
-        self.assertEquals(callcount[0], 1)
-        self.assertEquals(callcount2[0], 1)
-
-        a.func2.invalidate(("foo",))
-        self.assertEquals(a.func2.cache.cache.pop.call_count, 1)
-
-        yield a.func2("foo")
-        a.func2.invalidate(("foo",))
-        self.assertEquals(a.func2.cache.cache.pop.call_count, 2)
-
-        self.assertEquals(callcount[0], 1)
-        self.assertEquals(callcount2[0], 2)
-
-        a.func.invalidate(("foo",))
-        self.assertEquals(a.func2.cache.cache.pop.call_count, 3)
-        yield a.func("foo")
-
-        self.assertEquals(callcount[0], 2)
-        self.assertEquals(callcount2[0], 2)
-
-        yield a.func2("foo")
-
-        self.assertEquals(callcount[0], 2)
-        self.assertEquals(callcount2[0], 3)
-
-
 class UpsertManyTests(unittest.HomeserverTestCase):
     def prepare(self, reactor, clock, hs):
         self.storage = hs.get_datastore()
diff --git a/tests/util/caches/test_deferred_cache.py b/tests/util/caches/test_deferred_cache.py
index 9717be56b6..dadfabd46d 100644
--- a/tests/util/caches/test_deferred_cache.py
+++ b/tests/util/caches/test_deferred_cache.py
@@ -13,15 +13,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import unittest
 from functools import partial
 
 from twisted.internet import defer
 
 from synapse.util.caches.deferred_cache import DeferredCache
 
+from tests.unittest import TestCase
 
-class DeferredCacheTestCase(unittest.TestCase):
+
+class DeferredCacheTestCase(TestCase):
     def test_empty(self):
         cache = DeferredCache("test")
         failed = False
@@ -36,7 +37,118 @@ class DeferredCacheTestCase(unittest.TestCase):
         cache = DeferredCache("test")
         cache.prefill("foo", 123)
 
-        self.assertEquals(cache.get("foo"), 123)
+        self.assertEquals(self.successResultOf(cache.get("foo")), 123)
+
+    def test_hit_deferred(self):
+        cache = DeferredCache("test")
+        origin_d = defer.Deferred()
+        set_d = cache.set("k1", origin_d)
+
+        # get should return an incomplete deferred
+        get_d = cache.get("k1")
+        self.assertFalse(get_d.called)
+
+        # add a callback that will make sure that the set_d gets called before the get_d
+        def check1(r):
+            self.assertTrue(set_d.called)
+            return r
+
+        # TODO: Actually ObservableDeferred *doesn't* run its callbacks in order on py3.8.
+        #   maybe we should fix that?
+        # get_d.addCallback(check1)
+
+        # now fire off all the deferreds
+        origin_d.callback(99)
+        self.assertEqual(self.successResultOf(origin_d), 99)
+        self.assertEqual(self.successResultOf(set_d), 99)
+        self.assertEqual(self.successResultOf(get_d), 99)
+
+    def test_callbacks(self):
+        """Invalidation callbacks are called at the right time"""
+        cache = DeferredCache("test")
+        callbacks = set()
+
+        # start with an entry, with a callback
+        cache.prefill("k1", 10, callback=lambda: callbacks.add("prefill"))
+
+        # now replace that entry with a pending result
+        origin_d = defer.Deferred()
+        set_d = cache.set("k1", origin_d, callback=lambda: callbacks.add("set"))
+
+        # ... and also make a get request
+        get_d = cache.get("k1", callback=lambda: callbacks.add("get"))
+
+        # we don't expect the invalidation callback for the original value to have
+        # been called yet, even though get() will now return a different result.
+        # I'm not sure if that is by design or not.
+        self.assertEqual(callbacks, set())
+
+        # now fire off all the deferreds
+        origin_d.callback(20)
+        self.assertEqual(self.successResultOf(set_d), 20)
+        self.assertEqual(self.successResultOf(get_d), 20)
+
+        # now the original invalidation callback should have been called, but none of
+        # the others
+        self.assertEqual(callbacks, {"prefill"})
+        callbacks.clear()
+
+        # another update should invalidate both the previous results
+        cache.prefill("k1", 30)
+        self.assertEqual(callbacks, {"set", "get"})
+
+    def test_set_fail(self):
+        cache = DeferredCache("test")
+        callbacks = set()
+
+        # start with an entry, with a callback
+        cache.prefill("k1", 10, callback=lambda: callbacks.add("prefill"))
+
+        # now replace that entry with a pending result
+        origin_d = defer.Deferred()
+        set_d = cache.set("k1", origin_d, callback=lambda: callbacks.add("set"))
+
+        # ... and also make a get request
+        get_d = cache.get("k1", callback=lambda: callbacks.add("get"))
+
+        # none of the callbacks should have been called yet
+        self.assertEqual(callbacks, set())
+
+        # oh noes! fails!
+        e = Exception("oops")
+        origin_d.errback(e)
+        self.assertIs(self.failureResultOf(set_d, Exception).value, e)
+        self.assertIs(self.failureResultOf(get_d, Exception).value, e)
+
+        # the callbacks for the failed requests should have been called.
+        # I'm not sure if this is deliberate or not.
+        self.assertEqual(callbacks, {"get", "set"})
+        callbacks.clear()
+
+        # the old value should still be returned now?
+        get_d2 = cache.get("k1", callback=lambda: callbacks.add("get2"))
+        self.assertEqual(self.successResultOf(get_d2), 10)
+
+        # replacing the value now should run the callbacks for those requests
+        # which got the original result
+        cache.prefill("k1", 30)
+        self.assertEqual(callbacks, {"prefill", "get2"})
+
+    def test_get_immediate(self):
+        cache = DeferredCache("test")
+        d1 = defer.Deferred()
+        cache.set("key1", d1)
+
+        # get_immediate should return default
+        v = cache.get_immediate("key1", 1)
+        self.assertEqual(v, 1)
+
+        # now complete the set
+        d1.callback(2)
+
+        # get_immediate should return result
+        v = cache.get_immediate("key1", 1)
+        self.assertEqual(v, 2)
 
     def test_invalidate(self):
         cache = DeferredCache("test")
@@ -66,23 +178,24 @@ class DeferredCacheTestCase(unittest.TestCase):
         d2 = defer.Deferred()
         cache.set("key2", d2, partial(record_callback, 1))
 
-        # lookup should return observable deferreds
-        self.assertFalse(cache.get("key1").has_called())
-        self.assertFalse(cache.get("key2").has_called())
+        # lookup should return pending deferreds
+        self.assertFalse(cache.get("key1").called)
+        self.assertFalse(cache.get("key2").called)
 
         # let one of the lookups complete
         d2.callback("result2")
 
-        # for now at least, the cache will return real results rather than an
-        # observabledeferred
-        self.assertEqual(cache.get("key2"), "result2")
+        # now the cache will return a completed deferred
+        self.assertEqual(self.successResultOf(cache.get("key2")), "result2")
 
         # now do the invalidation
         cache.invalidate_all()
 
-        # lookup should return none
-        self.assertIsNone(cache.get("key1", None))
-        self.assertIsNone(cache.get("key2", None))
+        # lookup should fail
+        with self.assertRaises(KeyError):
+            cache.get("key1")
+        with self.assertRaises(KeyError):
+            cache.get("key2")
 
         # both callbacks should have been callbacked
         self.assertTrue(callback_record[0], "Invalidation callback for key1 not called")
@@ -90,7 +203,8 @@ class DeferredCacheTestCase(unittest.TestCase):
 
         # letting the other lookup complete should do nothing
         d1.callback("result1")
-        self.assertIsNone(cache.get("key1", None))
+        with self.assertRaises(KeyError):
+            cache.get("key1")
 
     def test_eviction(self):
         cache = DeferredCache(
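The updated assertions in this file document the reworked `DeferredCache` contract: `get` returns a (possibly already-completed) `Deferred` and raises `KeyError` on a miss, while the new `get_immediate` is the non-blocking path that still accepts a default. A usage sketch under that contract; `fetch_from_db` is hypothetical:

    cache = DeferredCache("example")
    try:
        d = cache.get("k1")                        # Deferred, may already be done
    except KeyError:
        d = cache.set("k1", fetch_from_db("k1"))   # populate on a miss
    v = cache.get_immediate("k1", None)            # default, never a pending result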
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 3d1f960869..2ad08f541b 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+from typing import Set
 
 import mock
 
@@ -130,6 +131,57 @@ class DescriptorTestCase(unittest.TestCase):
         d = obj.fn(1)
         self.failureResultOf(d, SynapseError)
 
+    def test_cache_with_async_exception(self):
+        """The wrapped function returns a failure
+        """
+
+        class Cls:
+            result = None
+            call_count = 0
+
+            @cached()
+            def fn(self, arg1):
+                self.call_count += 1
+                return self.result
+
+        obj = Cls()
+        callbacks = set()  # type: Set[str]
+
+        # set off an asynchronous request
+        obj.result = origin_d = defer.Deferred()
+
+        d1 = obj.fn(1, on_invalidate=lambda: callbacks.add("d1"))
+        self.assertFalse(d1.called)
+
+        # a second request should also return a deferred, but should not call the
+        # function itself.
+        d2 = obj.fn(1, on_invalidate=lambda: callbacks.add("d2"))
+        self.assertFalse(d2.called)
+        self.assertEqual(obj.call_count, 1)
+
+        # no callbacks yet
+        self.assertEqual(callbacks, set())
+
+        # the original request fails
+        e = Exception("bzz")
+        origin_d.errback(e)
+
+        # ... which should cause the lookups to fail similarly
+        self.assertIs(self.failureResultOf(d1, Exception).value, e)
+        self.assertIs(self.failureResultOf(d2, Exception).value, e)
+
+        # ... and the callbacks to have been, uh, called.
+        self.assertEqual(callbacks, {"d1", "d2"})
+
+        # ... leaving the cache empty
+        self.assertEqual(len(obj.fn.cache.cache), 0)
+
+        # and a second call should work as normal
+        obj.result = defer.succeed(100)
+        d3 = obj.fn(1)
+        self.assertEqual(self.successResultOf(d3), 100)
+        self.assertEqual(obj.call_count, 2)
+
     def test_cache_logcontexts(self):
         """Check that logcontexts are set and restored correctly when
         using the cache."""
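The important property test_cache_with_async_exception pins down is that failures are evicted rather than cached: both observers see the original exception, the invalidation callbacks fire, the cache ends up empty, and the next call re-invokes the wrapped function. A sketch of that shape, not the descriptor's actual code:

    def cache_result(cache, key, d):
        """Cache the pending result `d` under `key`, evicting it on failure."""
        def on_err(failure):
            cache.invalidate(key)  # failures are not cached, so callers can retry
            return failure         # the original exception reaches every observer
        d.addErrback(on_err)
        return d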
@@ -311,6 +363,235 @@ class DescriptorTestCase(unittest.TestCase):
         self.failureResultOf(d, SynapseError)
 
 
+class CacheDecoratorTestCase(unittest.HomeserverTestCase):
+    """More tests for @cached
+
+    The following is a set of tests that got lost in a different file for a while.
+
+    There are probably duplicates of the tests in DescriptorTestCase. Ideally the
+    duplicates would be removed and the two sets of classes combined.
+    """
+
+    @defer.inlineCallbacks
+    def test_passthrough(self):
+        class A:
+            @cached()
+            def func(self, key):
+                return key
+
+        a = A()
+
+        self.assertEquals((yield a.func("foo")), "foo")
+        self.assertEquals((yield a.func("bar")), "bar")
+
+    @defer.inlineCallbacks
+    def test_hit(self):
+        callcount = [0]
+
+        class A:
+            @cached()
+            def func(self, key):
+                callcount[0] += 1
+                return key
+
+        a = A()
+        yield a.func("foo")
+
+        self.assertEquals(callcount[0], 1)
+
+        self.assertEquals((yield a.func("foo")), "foo")
+        self.assertEquals(callcount[0], 1)
+
+    @defer.inlineCallbacks
+    def test_invalidate(self):
+        callcount = [0]
+
+        class A:
+            @cached()
+            def func(self, key):
+                callcount[0] += 1
+                return key
+
+        a = A()
+        yield a.func("foo")
+
+        self.assertEquals(callcount[0], 1)
+
+        a.func.invalidate(("foo",))
+
+        yield a.func("foo")
+
+        self.assertEquals(callcount[0], 2)
+
+    def test_invalidate_missing(self):
+        class A:
+            @cached()
+            def func(self, key):
+                return key
+
+        A().func.invalidate(("what",))
+
+    @defer.inlineCallbacks
+    def test_max_entries(self):
+        callcount = [0]
+
+        class A:
+            @cached(max_entries=10)
+            def func(self, key):
+                callcount[0] += 1
+                return key
+
+        a = A()
+
+        for k in range(0, 12):
+            yield a.func(k)
+
+        self.assertEquals(callcount[0], 12)
+
+        # There must have been at least 2 evictions, meaning if we calculate
+        # all 12 values again, we must get called at least 2 more times
+        for k in range(0, 12):
+            yield a.func(k)
+
+        self.assertTrue(
+            callcount[0] >= 14, msg="Expected callcount >= 14, got %d" % (callcount[0])
+        )
+
+    def test_prefill(self):
+        callcount = [0]
+
+        d = defer.succeed(123)
+
+        class A:
+            @cached()
+            def func(self, key):
+                callcount[0] += 1
+                return d
+
+        a = A()
+
+        a.func.prefill(("foo",), 456)
+
+        self.assertEquals(a.func("foo").result, 456)
+        self.assertEquals(callcount[0], 0)
+
+    @defer.inlineCallbacks
+    def test_invalidate_context(self):
+        callcount = [0]
+        callcount2 = [0]
+
+        class A:
+            @cached()
+            def func(self, key):
+                callcount[0] += 1
+                return key
+
+            @cached(cache_context=True)
+            def func2(self, key, cache_context):
+                callcount2[0] += 1
+                return self.func(key, on_invalidate=cache_context.invalidate)
+
+        a = A()
+        yield a.func2("foo")
+
+        self.assertEquals(callcount[0], 1)
+        self.assertEquals(callcount2[0], 1)
+
+        a.func.invalidate(("foo",))
+        yield a.func("foo")
+
+        self.assertEquals(callcount[0], 2)
+        self.assertEquals(callcount2[0], 1)
+
+        yield a.func2("foo")
+
+        self.assertEquals(callcount[0], 2)
+        self.assertEquals(callcount2[0], 2)
+
+    @defer.inlineCallbacks
+    def test_eviction_context(self):
+        callcount = [0]
+        callcount2 = [0]
+
+        class A:
+            @cached(max_entries=2)
+            def func(self, key):
+                callcount[0] += 1
+                return key
+
+            @cached(cache_context=True)
+            def func2(self, key, cache_context):
+                callcount2[0] += 1
+                return self.func(key, on_invalidate=cache_context.invalidate)
+
+        a = A()
+        yield a.func2("foo")
+        yield a.func2("foo2")
+
+        self.assertEquals(callcount[0], 2)
+        self.assertEquals(callcount2[0], 2)
+
+        yield a.func2("foo")
+        self.assertEquals(callcount[0], 2)
+        self.assertEquals(callcount2[0], 2)
+
+        yield a.func("foo3")
+
+        self.assertEquals(callcount[0], 3)
+        self.assertEquals(callcount2[0], 2)
+
+        yield a.func2("foo")
+
+        self.assertEquals(callcount[0], 4)
+        self.assertEquals(callcount2[0], 3)
+
+    @defer.inlineCallbacks
+    def test_double_get(self):
+        callcount = [0]
+        callcount2 = [0]
+
+        class A:
+            @cached()
+            def func(self, key):
+                callcount[0] += 1
+                return key
+
+            @cached(cache_context=True)
+            def func2(self, key, cache_context):
+                callcount2[0] += 1
+                return self.func(key, on_invalidate=cache_context.invalidate)
+
+        a = A()
+        a.func2.cache.cache = mock.Mock(wraps=a.func2.cache.cache)
+
+        yield a.func2("foo")
+
+        self.assertEquals(callcount[0], 1)
+        self.assertEquals(callcount2[0], 1)
+
+        a.func2.invalidate(("foo",))
+        self.assertEquals(a.func2.cache.cache.pop.call_count, 1)
+
+        yield a.func2("foo")
+        a.func2.invalidate(("foo",))
+        self.assertEquals(a.func2.cache.cache.pop.call_count, 2)
+
+        self.assertEquals(callcount[0], 1)
+        self.assertEquals(callcount2[0], 2)
+
+        a.func.invalidate(("foo",))
+        self.assertEquals(a.func2.cache.cache.pop.call_count, 3)
+        yield a.func("foo")
+
+        self.assertEquals(callcount[0], 2)
+        self.assertEquals(callcount2[0], 2)
+
+        yield a.func2("foo")
+
+        self.assertEquals(callcount[0], 2)
+        self.assertEquals(callcount2[0], 3)
+
+
 class CachedListDescriptorTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def test_cache(self):
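Several of the relocated tests exercise the `cache_context=True` pattern, in which one cached function depends on another and chains invalidation through `on_invalidate`. The shape, exactly as used above:

    class A:
        @cached()
        def func(self, key):
            ...

        @cached(cache_context=True)
        def func2(self, key, cache_context):
            # invalidating func's entry for `key` also invalidates func2's entry
            return self.func(key, on_invalidate=cache_context.invalidate)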
diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py
index 0adb2174af..a739a6aaaf 100644
--- a/tests/util/test_lrucache.py
+++ b/tests/util/test_lrucache.py
@@ -19,7 +19,8 @@ from mock import Mock
 from synapse.util.caches.lrucache import LruCache
 from synapse.util.caches.treecache import TreeCache
 
-from .. import unittest
+from tests import unittest
+from tests.unittest import override_config
 
 
 class LruCacheTestCase(unittest.HomeserverTestCase):
@@ -59,7 +60,7 @@ class LruCacheTestCase(unittest.HomeserverTestCase):
         self.assertEquals(cache.pop("key"), None)
 
     def test_del_multi(self):
-        cache = LruCache(4, 2, cache_type=TreeCache)
+        cache = LruCache(4, keylen=2, cache_type=TreeCache)
         cache[("animal", "cat")] = "mew"
         cache[("animal", "dog")] = "woof"
         cache[("vehicles", "car")] = "vroom"
@@ -83,6 +84,11 @@ class LruCacheTestCase(unittest.HomeserverTestCase):
         cache.clear()
         self.assertEquals(len(cache), 0)
 
+    @override_config({"caches": {"per_cache_factors": {"mycache": 10}}})
+    def test_special_size(self):
+        cache = LruCache(10, "mycache")
+        self.assertEqual(cache.max_size, 100)
+
 
 class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
     def test_get(self):
@@ -160,7 +166,7 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
         m2 = Mock()
         m3 = Mock()
         m4 = Mock()
-        cache = LruCache(4, 2, cache_type=TreeCache)
+        cache = LruCache(4, keylen=2, cache_type=TreeCache)
 
         cache.set(("a", "1"), "value", callbacks=[m1])
         cache.set(("a", "2"), "value", callbacks=[m2])
diff --git a/tox.ini b/tox.ini
index 4d132eff4c..6dcc439a40 100644
--- a/tox.ini
+++ b/tox.ini
@@ -158,12 +158,9 @@ commands=
     coverage html
 
 [testenv:mypy]
-skip_install = True
 deps =
     {[base]deps}
-    mypy==0.782
-    mypy-zope
-extras = all
+extras = all,mypy
 commands = mypy
 
 # To find all folders that pass mypy you run: