summary refs log tree commit diff
diff options
context:
space:
mode:
authorAndrew Morgan <andrew@amorgan.xyz>2020-10-19 19:12:39 +0100
committerAndrew Morgan <andrew@amorgan.xyz>2020-10-19 19:12:39 +0100
commit687d30b2eddd4feb530879c9499f248197ffba00 (patch)
treee3ee7460dcf934f1450359d29179079724a57f53
parentMerge commit '3c01724b3' into anoa/dinsic_release_1_21_x (diff)
parentRemove `ChainedIdGenerator`. (#8123) (diff)
downloadsynapse-687d30b2eddd4feb530879c9499f248197ffba00.tar.xz
Merge commit 'c9c544cda' into anoa/dinsic_release_1_21_x
* commit 'c9c544cda':
  Remove `ChainedIdGenerator`. (#8123)
  Switch the JSON byte producer from a pull to a push producer. (#8116)
  Updated docs: Added note about missing 308 redirect support. (#8120)
  Be stricter about JSON that is accepted by Synapse (#8106)
  Convert runWithConnection to async. (#8121)
  Remove the unused inlineCallbacks code-paths in the caching code (#8119)
  Separate `get_current_token` into two. (#8113)
  Convert events worker database to async/await. (#8071)
  Add a link to the matrix-synapse-rest-password-provider. (#8111)
-rw-r--r--changelog.d/8071.misc1
-rw-r--r--changelog.d/8106.bugfix1
-rw-r--r--changelog.d/8111.doc1
-rw-r--r--changelog.d/8113.misc1
-rw-r--r--changelog.d/8116.feature1
-rw-r--r--changelog.d/8119.misc1
-rw-r--r--changelog.d/8120.doc1
-rw-r--r--changelog.d/8121.misc1
-rw-r--r--changelog.d/8123.misc1
-rw-r--r--docs/federate.md12
-rw-r--r--docs/password_auth_providers.md1
-rw-r--r--synapse/api/errors.py6
-rw-r--r--synapse/event_auth.py2
-rw-r--r--synapse/federation/federation_server.py5
-rw-r--r--synapse/federation/sender/transaction_manager.py5
-rw-r--r--synapse/handlers/e2e_keys.py8
-rw-r--r--synapse/handlers/federation.py16
-rw-r--r--synapse/handlers/identity.py41
-rw-r--r--synapse/handlers/message.py11
-rw-r--r--synapse/handlers/oidc_handler.py6
-rw-r--r--synapse/handlers/profile.py8
-rw-r--r--synapse/handlers/room_member.py2
-rw-r--r--synapse/handlers/ui_auth/checkers.py5
-rw-r--r--synapse/http/client.py11
-rw-r--r--synapse/http/federation/well_known_resolver.py5
-rw-r--r--synapse/http/server.py75
-rw-r--r--synapse/http/servlet.py5
-rw-r--r--synapse/logging/opentracing.py7
-rw-r--r--synapse/metrics/background_process_metrics.py2
-rw-r--r--synapse/replication/slave/storage/_slaved_id_tracker.py8
-rw-r--r--synapse/replication/slave/storage/push_rule.py10
-rw-r--r--synapse/replication/tcp/commands.py12
-rw-r--r--synapse/replication/tcp/streams/_base.py4
-rw-r--r--synapse/rest/client/v1/push_rule.py2
-rw-r--r--synapse/rest/client/v1/room.py11
-rw-r--r--synapse/rest/client/v2_alpha/account.py18
-rw-r--r--synapse/rest/client/v2_alpha/sync.py5
-rw-r--r--synapse/rest/key/v2/remote_key_resource.py8
-rw-r--r--synapse/spam_checker_api/__init__.py2
-rw-r--r--synapse/state/__init__.py2
-rw-r--r--synapse/storage/_base.py7
-rw-r--r--synapse/storage/database.py27
-rw-r--r--synapse/storage/databases/main/cache.py4
-rw-r--r--synapse/storage/databases/main/event_federation.py30
-rw-r--r--synapse/storage/databases/main/events_worker.py148
-rw-r--r--synapse/storage/databases/main/push_rule.py36
-rw-r--r--synapse/storage/databases/main/stream.py1
-rw-r--r--synapse/storage/util/id_generators.py82
-rw-r--r--synapse/streams/events.py2
-rw-r--r--synapse/util/__init__.py14
-rw-r--r--synapse/util/caches/descriptors.py54
-rw-r--r--tests/handlers/test_profile.py31
-rw-r--r--tests/rest/client/v1/test_login.py16
-rw-r--r--tests/rest/client/v2_alpha/test_register.py4
-rw-r--r--tests/server_notices/test_resource_limits_server_notices.py6
-rw-r--r--tests/storage/test_appservice.py3
-rw-r--r--tests/storage/test_cleanup_extrems.py3
-rw-r--r--tests/storage/test_id_generators.py16
-rw-r--r--tests/storage/test_main.py8
-rw-r--r--tests/util/caches/test_descriptors.py12
60 files changed, 409 insertions, 419 deletions
diff --git a/changelog.d/8071.misc b/changelog.d/8071.misc
new file mode 100644
index 0000000000..dfe4c03171
--- /dev/null
+++ b/changelog.d/8071.misc
@@ -0,0 +1 @@
+Convert various parts of the codebase to async/await.
diff --git a/changelog.d/8106.bugfix b/changelog.d/8106.bugfix
new file mode 100644
index 0000000000..c46c60448f
--- /dev/null
+++ b/changelog.d/8106.bugfix
@@ -0,0 +1 @@
+Fix a long-standing bug where invalid JSON would be accepted by Synapse.
diff --git a/changelog.d/8111.doc b/changelog.d/8111.doc
new file mode 100644
index 0000000000..d3f7435452
--- /dev/null
+++ b/changelog.d/8111.doc
@@ -0,0 +1 @@
+Link to matrix-synapse-rest-password-provider in the password provider documentation.
diff --git a/changelog.d/8113.misc b/changelog.d/8113.misc
new file mode 100644
index 0000000000..00bec4f8ef
--- /dev/null
+++ b/changelog.d/8113.misc
@@ -0,0 +1 @@
+Separate `get_current_token` into two since there are two different use cases for it.
diff --git a/changelog.d/8116.feature b/changelog.d/8116.feature
new file mode 100644
index 0000000000..b1eaf1e78a
--- /dev/null
+++ b/changelog.d/8116.feature
@@ -0,0 +1 @@
+Iteratively encode JSON to avoid blocking the reactor.
diff --git a/changelog.d/8119.misc b/changelog.d/8119.misc
new file mode 100644
index 0000000000..dfe4c03171
--- /dev/null
+++ b/changelog.d/8119.misc
@@ -0,0 +1 @@
+Convert various parts of the codebase to async/await.
diff --git a/changelog.d/8120.doc b/changelog.d/8120.doc
new file mode 100644
index 0000000000..877ef79fd2
--- /dev/null
+++ b/changelog.d/8120.doc
@@ -0,0 +1 @@
+Updated documentation to note that Synapse does not follow `HTTP 308` redirects due to an upstream library not supporting them. Contributed by Ryan Cole.
\ No newline at end of file
diff --git a/changelog.d/8121.misc b/changelog.d/8121.misc
new file mode 100644
index 0000000000..dfe4c03171
--- /dev/null
+++ b/changelog.d/8121.misc
@@ -0,0 +1 @@
+Convert various parts of the codebase to async/await.
diff --git a/changelog.d/8123.misc b/changelog.d/8123.misc
new file mode 100644
index 0000000000..7245122896
--- /dev/null
+++ b/changelog.d/8123.misc
@@ -0,0 +1 @@
+Remove `ChainedIdGenerator`.
diff --git a/docs/federate.md b/docs/federate.md
index a0786b9cf7..b15cd724d1 100644
--- a/docs/federate.md
+++ b/docs/federate.md
@@ -47,6 +47,18 @@ you invite them to. This can be caused by an incorrectly-configured reverse
 proxy: see [reverse_proxy.md](<reverse_proxy.md>) for instructions on how to correctly
 configure a reverse proxy.
 
+### Known issues
+
+**HTTP `308 Permanent Redirect` redirects are not followed**: Due to missing features
+in the HTTP library used by Synapse, 308 redirects are currently not followed by
+federating servers, which can cause `M_UNKNOWN` or `401 Unauthorized` errors. This
+may affect users who are redirecting apex-to-www (e.g. `example.com` -> `www.example.com`),
+and especially users of the Kubernetes *Nginx Ingress* module, which uses 308 redirect
+codes by default. For those Kubernetes users, [this Stackoverflow post](https://stackoverflow.com/a/52617528/5096871) 
+might be helpful. For other users, switching to a `301 Moved Permanently` code may be
+an option. 308 redirect codes will be supported properly in a future
+release of Synapse.
+
 ## Running a demo federation of Synapses
 
 If you want to get up and running quickly with a trio of homeservers in a
diff --git a/docs/password_auth_providers.md b/docs/password_auth_providers.md
index fef1d47e85..7d98d9f255 100644
--- a/docs/password_auth_providers.md
+++ b/docs/password_auth_providers.md
@@ -14,6 +14,7 @@ password auth provider module implementations:
 
 * [matrix-synapse-ldap3](https://github.com/matrix-org/matrix-synapse-ldap3/)
 * [matrix-synapse-shared-secret-auth](https://github.com/devture/matrix-synapse-shared-secret-auth)
+* [matrix-synapse-rest-password-provider](https://github.com/ma1uta/matrix-synapse-rest-password-provider)
 
 ## Required methods
 
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index a3b0c0a3e4..28a078a7b4 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -22,10 +22,10 @@ import typing
 from http import HTTPStatus
 from typing import Dict, List, Optional, Union
 
-from canonicaljson import json
-
 from twisted.web import http
 
+from synapse.util import json_decoder
+
 if typing.TYPE_CHECKING:
     from synapse.types import JsonDict
 
@@ -594,7 +594,7 @@ class HttpResponseException(CodeMessageException):
         # try to parse the body as json, to get better errcode/msg, but
         # default to M_UNKNOWN with the HTTP status as the error text
         try:
-            j = json.loads(self.response.decode("utf-8"))
+            j = json_decoder.decode(self.response.decode("utf-8"))
         except ValueError:
             j = {}
 
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index c0981eee62..8c907ad596 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -47,7 +47,7 @@ def check(
     Args:
         room_version_obj: the version of the room
         event: the event being checked.
-        auth_events (dict: event-key -> event): the existing room state.
+        auth_events: the existing room state.
 
     Raises:
         AuthError if the checks fail
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 11c5d63298..630f571cd4 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -28,7 +28,6 @@ from typing import (
     Union,
 )
 
-from canonicaljson import json
 from prometheus_client import Counter, Histogram
 
 from twisted.internet import defer
@@ -63,7 +62,7 @@ from synapse.replication.http.federation import (
     ReplicationGetQueryRestServlet,
 )
 from synapse.types import JsonDict, get_domain_from_id
-from synapse.util import glob_to_regex, unwrapFirstError
+from synapse.util import glob_to_regex, json_decoder, unwrapFirstError
 from synapse.util.async_helpers import Linearizer, concurrently_execute
 from synapse.util.caches.response_cache import ResponseCache
 
@@ -551,7 +550,7 @@ class FederationServer(FederationBase):
             for device_id, keys in device_keys.items():
                 for key_id, json_str in keys.items():
                     json_result.setdefault(user_id, {})[device_id] = {
-                        key_id: json.loads(json_str)
+                        key_id: json_decoder.decode(json_str)
                     }
 
         logger.info(
diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py
index c7f6cb3d73..9bd534a313 100644
--- a/synapse/federation/sender/transaction_manager.py
+++ b/synapse/federation/sender/transaction_manager.py
@@ -15,8 +15,6 @@
 import logging
 from typing import TYPE_CHECKING, List, Tuple
 
-from canonicaljson import json
-
 from synapse.api.errors import HttpResponseException
 from synapse.events import EventBase
 from synapse.federation.persistence import TransactionActions
@@ -28,6 +26,7 @@ from synapse.logging.opentracing import (
     tags,
     whitelisted_homeserver,
 )
+from synapse.util import json_decoder
 from synapse.util.metrics import measure_func
 
 if TYPE_CHECKING:
@@ -71,7 +70,7 @@ class TransactionManager(object):
         for edu in pending_edus:
             context = edu.get_context()
             if context:
-                span_contexts.append(extract_text_map(json.loads(context)))
+                span_contexts.append(extract_text_map(json_decoder.decode(context)))
             if keep_destination:
                 edu.strip_context()
 
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 84169c1022..d8def45e38 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -19,7 +19,7 @@ import logging
 from typing import Dict, List, Optional, Tuple
 
 import attr
-from canonicaljson import encode_canonical_json, json
+from canonicaljson import encode_canonical_json
 from signedjson.key import VerifyKey, decode_verify_key_bytes
 from signedjson.sign import SignatureVerifyException, verify_signed_json
 from unpaddedbase64 import decode_base64
@@ -35,7 +35,7 @@ from synapse.types import (
     get_domain_from_id,
     get_verify_key_from_cross_signing_key,
 )
-from synapse.util import unwrapFirstError
+from synapse.util import json_decoder, unwrapFirstError
 from synapse.util.async_helpers import Linearizer
 from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.util.retryutils import NotRetryingDestination
@@ -404,7 +404,7 @@ class E2eKeysHandler(object):
             for device_id, keys in device_keys.items():
                 for key_id, json_bytes in keys.items():
                     json_result.setdefault(user_id, {})[device_id] = {
-                        key_id: json.loads(json_bytes)
+                        key_id: json_decoder.decode(json_bytes)
                     }
 
         @trace
@@ -1186,7 +1186,7 @@ def _exception_to_failure(e):
 
 
 def _one_time_keys_match(old_key_json, new_key):
-    old_key = json.loads(old_key_json)
+    old_key = json_decoder.decode(old_key_json)
 
     # if either is a string rather than an object, they must match exactly
     if not isinstance(old_key, dict) or not isinstance(new_key, dict):
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 3b5681eb06..29863c029b 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -1787,9 +1787,7 @@ class FederationHandler(BaseHandler):
         """Returns the state at the event. i.e. not including said event.
         """
 
-        event = await self.store.get_event(
-            event_id, allow_none=False, check_room_id=room_id
-        )
+        event = await self.store.get_event(event_id, check_room_id=room_id)
 
         state_groups = await self.state_store.get_state_groups(room_id, [event_id])
 
@@ -1815,9 +1813,7 @@ class FederationHandler(BaseHandler):
     async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> List[str]:
         """Returns the state at the event. i.e. not including said event.
         """
-        event = await self.store.get_event(
-            event_id, allow_none=False, check_room_id=room_id
-        )
+        event = await self.store.get_event(event_id, check_room_id=room_id)
 
         state_groups = await self.state_store.get_state_groups_ids(room_id, [event_id])
 
@@ -2165,9 +2161,9 @@ class FederationHandler(BaseHandler):
         auth_types = auth_types_for_event(event)
         current_state_ids = [e for k, e in current_state_ids.items() if k in auth_types]
 
-        current_auth_events = await self.store.get_events(current_state_ids)
+        auth_events_map = await self.store.get_events(current_state_ids)
         current_auth_events = {
-            (e.type, e.state_key): e for e in current_auth_events.values()
+            (e.type, e.state_key): e for e in auth_events_map.values()
         }
 
         try:
@@ -2183,9 +2179,7 @@ class FederationHandler(BaseHandler):
         if not in_room:
             raise AuthError(403, "Host not in room.")
 
-        event = await self.store.get_event(
-            event_id, allow_none=False, check_room_id=room_id
-        )
+        event = await self.store.get_event(event_id, check_room_id=room_id)
 
         # Just go through and process each event in `remote_auth_chain`. We
         # don't want to fall into the trap of `missing` being wrong.
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index ef930dba55..b5676b248b 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -21,9 +21,6 @@ import logging
 import urllib.parse
 from typing import Awaitable, Callable, Dict, List, Optional, Tuple
 
-from canonicaljson import json
-
-from twisted.internet import defer
 from twisted.internet.error import TimeoutError
 
 from synapse.api.errors import (
@@ -37,6 +34,7 @@ from synapse.api.errors import (
 from synapse.config.emailconfig import ThreepidBehaviour
 from synapse.http.client import SimpleHttpClient
 from synapse.types import JsonDict, Requester
+from synapse.util import json_decoder
 from synapse.util.hash import sha256_and_url_safe_base64
 from synapse.util.stringutils import assert_valid_client_secret, random_string
 
@@ -197,7 +195,7 @@ class IdentityHandler(BaseHandler):
         except TimeoutError:
             raise SynapseError(500, "Timed out contacting identity server")
         except CodeMessageException as e:
-            data = json.loads(e.msg)  # XXX WAT?
+            data = json_decoder.decode(e.msg)  # XXX WAT?
             return data
 
         logger.info("Got 404 when POSTing JSON %s, falling back to v1 URL", bind_url)
@@ -620,18 +618,19 @@ class IdentityHandler(BaseHandler):
     # the CS API. They should be consolidated with those in RoomMemberHandler
     # https://github.com/matrix-org/synapse-dinsic/issues/25
 
-    @defer.inlineCallbacks
-    def proxy_lookup_3pid(self, id_server, medium, address):
+    async def proxy_lookup_3pid(
+        self, id_server: str, medium: str, address: str
+    ) -> JsonDict:
         """Looks up a 3pid in the passed identity server.
 
         Args:
-            id_server (str): The server name (including port, if required)
+            id_server: The server name (including port, if required)
                 of the identity server to use.
-            medium (str): The type of the third party identifier (e.g. "email").
-            address (str): The third party identifier (e.g. "foo@example.com").
+            medium: The type of the third party identifier (e.g. "email").
+            address: The third party identifier (e.g. "foo@example.com").
 
         Returns:
-            Deferred[dict]: The result of the lookup. See
+            The result of the lookup. See
             https://matrix.org/docs/spec/identity_service/r0.1.0.html#association-lookup
             for details
         """
@@ -643,16 +642,11 @@ class IdentityHandler(BaseHandler):
         id_server_url = self.rewrite_id_server_url(id_server, add_https=True)
 
         try:
-            data = yield self.http_client.get_json(
+            data = await self.http_client.get_json(
                 "%s/_matrix/identity/api/v1/lookup" % (id_server_url,),
                 {"medium": medium, "address": address},
             )
 
-            if "mxid" in data:
-                if "signatures" not in data:
-                    raise AuthError(401, "No signatures on 3pid binding")
-                yield self._verify_any_signature(data, id_server)
-
         except HttpResponseException as e:
             logger.info("Proxied lookup failed: %r", e)
             raise e.to_synapse_error()
@@ -662,18 +656,19 @@ class IdentityHandler(BaseHandler):
 
         return data
 
-    @defer.inlineCallbacks
-    def proxy_bulk_lookup_3pid(self, id_server, threepids):
+    async def proxy_bulk_lookup_3pid(
+        self, id_server: str, threepids: List[List[str]]
+    ) -> JsonDict:
         """Looks up given 3pids in the passed identity server.
 
         Args:
-            id_server (str): The server name (including port, if required)
+            id_server: The server name (including port, if required)
                 of the identity server to use.
-            threepids ([[str, str]]): The third party identifiers to lookup, as
+            threepids: The third party identifiers to lookup, as
                 a list of 2-string sized lists ([medium, address]).
 
         Returns:
-            Deferred[dict]: The result of the lookup. See
+            The result of the lookup. See
             https://matrix.org/docs/spec/identity_service/r0.1.0.html#association-lookup
             for details
         """
@@ -685,7 +680,7 @@ class IdentityHandler(BaseHandler):
         id_server_url = self.rewrite_id_server_url(id_server, add_https=True)
 
         try:
-            data = yield self.http_client.post_json_get_json(
+            data = await self.http_client.post_json_get_json(
                 "%s/_matrix/identity/api/v1/bulk_lookup" % (id_server_url,),
                 {"threepids": threepids},
             )
@@ -697,7 +692,7 @@ class IdentityHandler(BaseHandler):
             logger.info("Failed to contact %s: %s", id_server, e)
             raise ProxiedRequestError(503, "Failed to contact identity server")
 
-        defer.returnValue(data)
+        return data
 
     async def lookup_3pid(
         self,
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 465b2a19bc..d5b12403f9 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -17,7 +17,7 @@
 import logging
 from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
 
-from canonicaljson import encode_canonical_json, json
+from canonicaljson import encode_canonical_json
 
 from twisted.internet.interfaces import IDelayedCall
 
@@ -55,6 +55,7 @@ from synapse.types import (
     UserID,
     create_requester,
 )
+from synapse.util import json_decoder
 from synapse.util.async_helpers import Linearizer
 from synapse.util.frozenutils import frozendict_json_encoder
 from synapse.util.metrics import measure_func
@@ -867,7 +868,7 @@ class EventCreationHandler(object):
         # Ensure that we can round trip before trying to persist in db
         try:
             dump = frozendict_json_encoder.encode(event.content)
-            json.loads(dump)
+            json_decoder.decode(dump)
         except Exception:
             logger.exception("Failed to encode content: %r", event.content)
             raise
@@ -963,7 +964,7 @@ class EventCreationHandler(object):
                     allow_none=True,
                 )
 
-                is_admin_redaction = (
+                is_admin_redaction = bool(
                     original_event and event.sender != original_event.sender
                 )
 
@@ -1083,8 +1084,8 @@ class EventCreationHandler(object):
             auth_events_ids = self.auth.compute_auth_events(
                 event, prev_state_ids, for_verification=True
             )
-            auth_events = await self.store.get_events(auth_events_ids)
-            auth_events = {(e.type, e.state_key): e for e in auth_events.values()}
+            auth_events_map = await self.store.get_events(auth_events_ids)
+            auth_events = {(e.type, e.state_key): e for e in auth_events_map.values()}
 
             room_version = await self.store.get_room_version_id(event.room_id)
             room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index 87d28a7ae9..dd3703cbd2 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -12,7 +12,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import json
 import logging
 from typing import TYPE_CHECKING, Dict, Generic, List, Optional, Tuple, TypeVar
 from urllib.parse import urlencode
@@ -39,6 +38,7 @@ from synapse.http.server import respond_with_html
 from synapse.http.site import SynapseRequest
 from synapse.logging.context import make_deferred_yieldable
 from synapse.types import UserID, map_username_to_mxid_localpart
+from synapse.util import json_decoder
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -367,7 +367,7 @@ class OidcHandler:
             # and check for an error field. If not, we respond with a generic
             # error message.
             try:
-                resp = json.loads(resp_body.decode("utf-8"))
+                resp = json_decoder.decode(resp_body.decode("utf-8"))
                 error = resp["error"]
                 description = resp.get("error_description", error)
             except (ValueError, KeyError):
@@ -384,7 +384,7 @@ class OidcHandler:
 
         # Since it is a not a 5xx code, body should be a valid JSON. It will
         # raise if not.
-        resp = json.loads(resp_body.decode("utf-8"))
+        resp = json_decoder.decode(resp_body.decode("utf-8"))
 
         if "error" in resp:
             error = resp["error"]
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index b05aa89455..4f3198896e 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -133,8 +133,12 @@ class BaseProfileHandler(BaseHandler):
         body = {"batchnum": batchnum, "batch": batch, "origin_server": self.hs.hostname}
         signed_body = sign_json(body, self.hs.hostname, self.hs.config.signing_key[0])
         try:
-            yield self.http_client.post_json_get_json(url, signed_body)
-            yield self.store.update_replication_batch_for_host(host, batchnum)
+            yield defer.ensureDeferred(
+                self.http_client.post_json_get_json(url, signed_body)
+            )
+            yield defer.ensureDeferred(
+                self.store.update_replication_batch_for_host(host, batchnum)
+            )
             logger.info("Sucessfully replicated profile batch %d to %s", batchnum, host)
         except Exception:
             # This will get retried when the looping call next comes around
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 3cd821340b..ca9b644d0b 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -747,7 +747,7 @@ class RoomMemberHandler(object):
 
         guest_access = await self.store.get_event(guest_access_id)
 
-        return (
+        return bool(
             guest_access
             and guest_access.content
             and "guest_access" in guest_access.content
diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py
index a011e9fe29..9146dc1a3b 100644
--- a/synapse/handlers/ui_auth/checkers.py
+++ b/synapse/handlers/ui_auth/checkers.py
@@ -16,13 +16,12 @@
 import logging
 from typing import Any
 
-from canonicaljson import json
-
 from twisted.web.client import PartialDownloadError
 
 from synapse.api.constants import LoginType
 from synapse.api.errors import Codes, LoginError, SynapseError
 from synapse.config.emailconfig import ThreepidBehaviour
+from synapse.util import json_decoder
 
 logger = logging.getLogger(__name__)
 
@@ -117,7 +116,7 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker):
         except PartialDownloadError as pde:
             # Twisted is silly
             data = pde.response
-            resp_body = json.loads(data.decode("utf-8"))
+            resp_body = json_decoder.decode(data.decode("utf-8"))
 
         if "success" in resp_body:
             # Note that we do NOT check the hostname here: we explicitly
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 8aeb70cdec..dad01a8e56 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -19,7 +19,7 @@ import urllib
 from io import BytesIO
 
 import treq
-from canonicaljson import encode_canonical_json, json
+from canonicaljson import encode_canonical_json
 from netaddr import IPAddress
 from prometheus_client import Counter
 from zope.interface import implementer, provider
@@ -47,6 +47,7 @@ from synapse.http import (
 from synapse.http.proxyagent import ProxyAgent
 from synapse.logging.context import make_deferred_yieldable
 from synapse.logging.opentracing import set_tag, start_active_span, tags
+from synapse.util import json_decoder
 from synapse.util.async_helpers import timeout_deferred
 
 logger = logging.getLogger(__name__)
@@ -391,7 +392,7 @@ class SimpleHttpClient(object):
         body = await make_deferred_yieldable(readBody(response))
 
         if 200 <= response.code < 300:
-            return json.loads(body.decode("utf-8"))
+            return json_decoder.decode(body.decode("utf-8"))
         else:
             raise HttpResponseException(
                 response.code, response.phrase.decode("ascii", errors="replace"), body
@@ -433,7 +434,7 @@ class SimpleHttpClient(object):
         body = await make_deferred_yieldable(readBody(response))
 
         if 200 <= response.code < 300:
-            return json.loads(body.decode("utf-8"))
+            return json_decoder.decode(body.decode("utf-8"))
         else:
             raise HttpResponseException(
                 response.code, response.phrase.decode("ascii", errors="replace"), body
@@ -463,7 +464,7 @@ class SimpleHttpClient(object):
             actual_headers.update(headers)
 
         body = await self.get_raw(uri, args, headers=headers)
-        return json.loads(body.decode("utf-8"))
+        return json_decoder.decode(body.decode("utf-8"))
 
     async def put_json(self, uri, json_body, args={}, headers=None):
         """ Puts some json to the given URI.
@@ -506,7 +507,7 @@ class SimpleHttpClient(object):
         body = await make_deferred_yieldable(readBody(response))
 
         if 200 <= response.code < 300:
-            return json.loads(body.decode("utf-8"))
+            return json_decoder.decode(body.decode("utf-8"))
         else:
             raise HttpResponseException(
                 response.code, response.phrase.decode("ascii", errors="replace"), body
diff --git a/synapse/http/federation/well_known_resolver.py b/synapse/http/federation/well_known_resolver.py
index 89a3b041ce..f794315deb 100644
--- a/synapse/http/federation/well_known_resolver.py
+++ b/synapse/http/federation/well_known_resolver.py
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import json
 import logging
 import random
 import time
@@ -26,7 +25,7 @@ from twisted.web.http import stringToDatetime
 from twisted.web.http_headers import Headers
 
 from synapse.logging.context import make_deferred_yieldable
-from synapse.util import Clock
+from synapse.util import Clock, json_decoder
 from synapse.util.caches.ttlcache import TTLCache
 from synapse.util.metrics import Measure
 
@@ -181,7 +180,7 @@ class WellKnownResolver(object):
             if response.code != 200:
                 raise Exception("Non-200 response %s" % (response.code,))
 
-            parsed_body = json.loads(body.decode("utf-8"))
+            parsed_body = json_decoder.decode(body.decode("utf-8"))
             logger.info("Response from .well-known: %s", parsed_body)
 
             result = parsed_body["m.server"].encode("ascii")
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 37fdf14405..8d791bd2ca 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -500,7 +500,7 @@ class RootOptionsRedirectResource(OptionsResource, RootRedirect):
     pass
 
 
-@implementer(interfaces.IPullProducer)
+@implementer(interfaces.IPushProducer)
 class _ByteProducer:
     """
     Iteratively write bytes to the request.
@@ -515,52 +515,64 @@ class _ByteProducer:
     ):
         self._request = request
         self._iterator = iterator
+        self._paused = False
 
-    def start(self) -> None:
-        self._request.registerProducer(self, False)
+        # Register the producer and start producing data.
+        self._request.registerProducer(self, True)
+        self.resumeProducing()
 
     def _send_data(self, data: List[bytes]) -> None:
         """
-        Send a list of strings as a response to the request.
+        Send a list of bytes as a chunk of a response.
         """
         if not data:
             return
         self._request.write(b"".join(data))
 
+    def pauseProducing(self) -> None:
+        self._paused = True
+
     def resumeProducing(self) -> None:
         # We've stopped producing in the meantime (note that this might be
         # re-entrant after calling write).
         if not self._request:
             return
 
-        # Get the next chunk and write it to the request.
-        #
-        # The output of the JSON encoder is coalesced until min_chunk_size is
-        # reached. (This is because JSON encoders produce a very small output
-        # per iteration.)
-        #
-        # Note that buffer stores a list of bytes (instead of appending to
-        # bytes) to hopefully avoid many allocations.
-        buffer = []
-        buffered_bytes = 0
-        while buffered_bytes < self.min_chunk_size:
-            try:
-                data = next(self._iterator)
-                buffer.append(data)
-                buffered_bytes += len(data)
-            except StopIteration:
-                # The entire JSON object has been serialized, write any
-                # remaining data, finalize the producer and the request, and
-                # clean-up any references.
-                self._send_data(buffer)
-                self._request.unregisterProducer()
-                self._request.finish()
-                self.stopProducing()
-                return
-
-        self._send_data(buffer)
+        self._paused = False
+
+        # Write until there's backpressure telling us to stop.
+        while not self._paused:
+            # Get the next chunk and write it to the request.
+            #
+            # The output of the JSON encoder is buffered and coalesced until
+            # min_chunk_size is reached. This is because JSON encoders produce
+            # very small output per iteration and the Request object converts
+            # each call to write() to a separate chunk. Without this there would
+            # be an explosion in bytes written (e.g. b"{" becoming "1\r\n{\r\n").
+            #
+            # Note that buffer stores a list of bytes (instead of appending to
+            # bytes) to hopefully avoid many allocations.
+            buffer = []
+            buffered_bytes = 0
+            while buffered_bytes < self.min_chunk_size:
+                try:
+                    data = next(self._iterator)
+                    buffer.append(data)
+                    buffered_bytes += len(data)
+                except StopIteration:
+                    # The entire JSON object has been serialized, write any
+                    # remaining data, finalize the producer and the request, and
+                    # clean-up any references.
+                    self._send_data(buffer)
+                    self._request.unregisterProducer()
+                    self._request.finish()
+                    self.stopProducing()
+                    return
+
+            self._send_data(buffer)
 
     def stopProducing(self) -> None:
+        # Clear a circular reference.
         self._request = None
 
 
@@ -620,8 +632,7 @@ def respond_with_json(
     if send_cors:
         set_cors_headers(request)
 
-    producer = _ByteProducer(request, encoder(json_object))
-    producer.start()
+    _ByteProducer(request, encoder(json_object))
     return NOT_DONE_YET
 
 
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index a34e5ead88..53acba56cb 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -17,9 +17,8 @@
 
 import logging
 
-from canonicaljson import json
-
 from synapse.api.errors import Codes, SynapseError
+from synapse.util import json_decoder
 
 logger = logging.getLogger(__name__)
 
@@ -215,7 +214,7 @@ def parse_json_value_from_request(request, allow_empty_body=False):
         return None
 
     try:
-        content = json.loads(content_bytes.decode("utf-8"))
+        content = json_decoder.decode(content_bytes.decode("utf-8"))
     except Exception as e:
         logger.warning("Unable to parse JSON: %s", e)
         raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index 21dbd9f415..abe532d350 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -177,6 +177,7 @@ from canonicaljson import json
 from twisted.internet import defer
 
 from synapse.config import ConfigError
+from synapse.util import json_decoder
 
 if TYPE_CHECKING:
     from synapse.http.site import SynapseRequest
@@ -499,7 +500,9 @@ def start_active_span_from_edu(
     if opentracing is None:
         return _noop_context_manager()
 
-    carrier = json.loads(edu_content.get("context", "{}")).get("opentracing", {})
+    carrier = json_decoder.decode(edu_content.get("context", "{}")).get(
+        "opentracing", {}
+    )
     context = opentracing.tracer.extract(opentracing.Format.TEXT_MAP, carrier)
     _references = [
         opentracing.child_of(span_context_from_string(x))
@@ -699,7 +702,7 @@ def span_context_from_string(carrier):
     Returns:
         The active span context decoded from a string.
     """
-    carrier = json.loads(carrier)
+    carrier = json_decoder.decode(carrier)
     return opentracing.tracer.extract(opentracing.Format.TEXT_MAP, carrier)
 
 
diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py
index f766d16db6..4cd7932e5b 100644
--- a/synapse/metrics/background_process_metrics.py
+++ b/synapse/metrics/background_process_metrics.py
@@ -175,7 +175,7 @@ def run_as_background_process(desc: str, func, *args, **kwargs):
     It returns a Deferred which completes when the function completes, but it doesn't
     follow the synapse logcontext rules, which makes it appropriate for passing to
     clock.looping_call and friends (or for firing-and-forgetting in the middle of a
-    normal synapse inlineCallbacks function).
+    normal synapse async function).
 
     Args:
         desc: a description for this background process type
diff --git a/synapse/replication/slave/storage/_slaved_id_tracker.py b/synapse/replication/slave/storage/_slaved_id_tracker.py
index 9d1d173b2f..d43eaf3a29 100644
--- a/synapse/replication/slave/storage/_slaved_id_tracker.py
+++ b/synapse/replication/slave/storage/_slaved_id_tracker.py
@@ -33,3 +33,11 @@ class SlavedIdTracker(object):
             int
         """
         return self._current
+
+    def get_current_token_for_writer(self, instance_name: str) -> int:
+        """Returns the position of the given writer.
+
+        For streams with single writers this is equivalent to
+        `get_current_token`.
+        """
+        return self.get_current_token()
diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py
index 590187df46..90d90833f9 100644
--- a/synapse/replication/slave/storage/push_rule.py
+++ b/synapse/replication/slave/storage/push_rule.py
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.replication.tcp.streams import PushRulesStream
 from synapse.storage.databases.main.push_rule import PushRulesWorkerStore
 
@@ -21,16 +22,13 @@ from .events import SlavedEventStore
 
 
 class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
-    def get_push_rules_stream_token(self):
-        return (
-            self._push_rules_stream_id_gen.get_current_token(),
-            self._stream_id_gen.get_current_token(),
-        )
-
     def get_max_push_rules_stream_id(self):
         return self._push_rules_stream_id_gen.get_current_token()
 
     def process_replication_rows(self, stream_name, instance_name, token, rows):
+        # We assert this for the benefit of mypy
+        assert isinstance(self._push_rules_stream_id_gen, SlavedIdTracker)
+
         if stream_name == PushRulesStream.NAME:
             self._push_rules_stream_id_gen.advance(token)
             for row in rows:
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index d853e4447e..8cd47770c1 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -21,9 +21,7 @@ import abc
 import logging
 from typing import Tuple, Type
 
-from canonicaljson import json
-
-from synapse.util import json_encoder as _json_encoder
+from synapse.util import json_decoder, json_encoder
 
 logger = logging.getLogger(__name__)
 
@@ -125,7 +123,7 @@ class RdataCommand(Command):
             stream_name,
             instance_name,
             None if token == "batch" else int(token),
-            json.loads(row_json),
+            json_decoder.decode(row_json),
         )
 
     def to_line(self):
@@ -134,7 +132,7 @@ class RdataCommand(Command):
                 self.stream_name,
                 self.instance_name,
                 str(self.token) if self.token is not None else "batch",
-                _json_encoder.encode(self.row),
+                json_encoder.encode(self.row),
             )
         )
 
@@ -359,7 +357,7 @@ class UserIpCommand(Command):
     def from_line(cls, line):
         user_id, jsn = line.split(" ", 1)
 
-        access_token, ip, user_agent, device_id, last_seen = json.loads(jsn)
+        access_token, ip, user_agent, device_id, last_seen = json_decoder.decode(jsn)
 
         return cls(user_id, access_token, ip, user_agent, device_id, last_seen)
 
@@ -367,7 +365,7 @@ class UserIpCommand(Command):
         return (
             self.user_id
             + " "
-            + _json_encoder.encode(
+            + json_encoder.encode(
                 (
                     self.access_token,
                     self.ip,
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 7a42de3f7d..8c3caf30c9 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -352,7 +352,7 @@ class PushRulesStream(Stream):
         )
 
     def _current_token(self, instance_name: str) -> int:
-        push_rules_token, _ = self.store.get_push_rules_stream_token()
+        push_rules_token = self.store.get_max_push_rules_stream_id()
         return push_rules_token
 
 
@@ -405,7 +405,7 @@ class CachesStream(Stream):
         store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
-            store.get_cache_stream_token,
+            store.get_cache_stream_token_for_writer,
             store.get_all_updated_caches,
         )
 
diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py
index e2df638cc5..e781a3bcf4 100644
--- a/synapse/rest/client/v1/push_rule.py
+++ b/synapse/rest/client/v1/push_rule.py
@@ -159,7 +159,7 @@ class PushRuleRestServlet(RestServlet):
         return 200, {}
 
     def notify_user(self, user_id):
-        stream_id, _ = self.store.get_push_rules_stream_token()
+        stream_id = self.store.get_max_push_rules_stream_id()
         self.notifier.on_new_event("push_rules_key", stream_id, users=[user_id])
 
     async def set_rule_attr(self, user_id, spec, val):
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index ed82144475..bc914d920e 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -21,8 +21,6 @@ import re
 from typing import List, Optional
 from urllib import parse as urlparse
 
-from canonicaljson import json
-
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.errors import (
     AuthError,
@@ -46,6 +44,7 @@ from synapse.rest.client.v2_alpha._base import client_patterns
 from synapse.storage.state import StateFilter
 from synapse.streams.config import PaginationConfig
 from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID
+from synapse.util import json_decoder
 
 MYPY = False
 if MYPY:
@@ -519,7 +518,9 @@ class RoomMessageListRestServlet(RestServlet):
         filter_str = parse_string(request, b"filter", encoding="utf-8")
         if filter_str:
             filter_json = urlparse.unquote(filter_str)
-            event_filter = Filter(json.loads(filter_json))  # type: Optional[Filter]
+            event_filter = Filter(
+                json_decoder.decode(filter_json)
+            )  # type: Optional[Filter]
             if (
                 event_filter
                 and event_filter.filter_json.get("event_format", "client")
@@ -631,7 +632,9 @@ class RoomEventContextServlet(RestServlet):
         filter_str = parse_string(request, b"filter", encoding="utf-8")
         if filter_str:
             filter_json = urlparse.unquote(filter_str)
-            event_filter = Filter(json.loads(filter_json))  # type: Optional[Filter]
+            event_filter = Filter(
+                json_decoder.decode(filter_json)
+            )  # type: Optional[Filter]
         else:
             event_filter = None
 
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index 7363d8d989..6b945e1849 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -1005,12 +1005,11 @@ class ThreepidLookupRestServlet(RestServlet):
         self.auth = hs.get_auth()
         self.identity_handler = hs.get_handlers().identity_handler
 
-    @defer.inlineCallbacks
-    def on_GET(self, request):
+    async def on_GET(self, request):
         """Proxy a /_matrix/identity/api/v1/lookup request to an identity
         server
         """
-        yield self.auth.get_user_by_req(request)
+        await self.auth.get_user_by_req(request)
 
         # Verify query parameters
         query_params = request.args
@@ -1023,9 +1022,9 @@ class ThreepidLookupRestServlet(RestServlet):
 
         # Proxy the request to the identity server. lookup_3pid handles checking
         # if the lookup is allowed so we don't need to do it here.
-        ret = yield self.identity_handler.proxy_lookup_3pid(id_server, medium, address)
+        ret = await self.identity_handler.proxy_lookup_3pid(id_server, medium, address)
 
-        defer.returnValue((200, ret))
+        return 200, ret
 
 
 class ThreepidBulkLookupRestServlet(RestServlet):
@@ -1036,12 +1035,11 @@ class ThreepidBulkLookupRestServlet(RestServlet):
         self.auth = hs.get_auth()
         self.identity_handler = hs.get_handlers().identity_handler
 
-    @defer.inlineCallbacks
-    def on_POST(self, request):
+    async def on_POST(self, request):
         """Proxy a /_matrix/identity/api/v1/bulk_lookup request to an identity
         server
         """
-        yield self.auth.get_user_by_req(request)
+        await self.auth.get_user_by_req(request)
 
         body = parse_json_object_from_request(request)
 
@@ -1049,11 +1047,11 @@ class ThreepidBulkLookupRestServlet(RestServlet):
 
         # Proxy the request to the identity server. lookup_3pid handles checking
         # if the lookup is allowed so we don't need to do it here.
-        ret = yield self.identity_handler.proxy_bulk_lookup_3pid(
+        ret = await self.identity_handler.proxy_bulk_lookup_3pid(
             body["id_server"], body["threepids"]
         )
 
-        defer.returnValue((200, ret))
+        return 200, ret
 
 
 def assert_valid_next_link(hs: "HomeServer", next_link: str):
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index a5c24fbd63..96488b131a 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -16,8 +16,6 @@
 import itertools
 import logging
 
-from canonicaljson import json
-
 from synapse.api.constants import PresenceState
 from synapse.api.errors import Codes, StoreError, SynapseError
 from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection
@@ -29,6 +27,7 @@ from synapse.handlers.presence import format_user_presence_state
 from synapse.handlers.sync import SyncConfig
 from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
 from synapse.types import StreamToken
+from synapse.util import json_decoder
 
 from ._base import client_patterns, set_timeline_upper_limit
 
@@ -125,7 +124,7 @@ class SyncRestServlet(RestServlet):
             filter_collection = DEFAULT_FILTER_COLLECTION
         elif filter_id.startswith("{"):
             try:
-                filter_object = json.loads(filter_id)
+                filter_object = json_decoder.decode(filter_id)
                 set_timeline_upper_limit(
                     filter_object, self.hs.config.filter_timeline_limit
                 )
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index e266204f95..5db7f81c2d 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -15,19 +15,19 @@
 import logging
 from typing import Dict, Set
 
-from canonicaljson import json
 from signedjson.sign import sign_json
 
 from synapse.api.errors import Codes, SynapseError
 from synapse.crypto.keyring import ServerKeyFetcher
 from synapse.http.server import DirectServeJsonResource, respond_with_json
 from synapse.http.servlet import parse_integer, parse_json_object_from_request
+from synapse.util import json_decoder
 
 logger = logging.getLogger(__name__)
 
 
 class RemoteKey(DirectServeJsonResource):
-    """HTTP resource for retreiving the TLS certificate and NACL signature
+    """HTTP resource for retrieving the TLS certificate and NACL signature
     verification keys for a collection of servers. Checks that the reported
     X.509 TLS certificate matches the one used in the HTTPS connection. Checks
     that the NACL signature for the remote server is valid. Returns a dict of
@@ -209,13 +209,15 @@ class RemoteKey(DirectServeJsonResource):
                     # Cast to bytes since postgresql returns a memoryview.
                     json_results.add(bytes(result["key_json"]))
 
+        # If there is a cache miss, request the missing keys, then recurse (and
+        # ensure the result is sent).
         if cache_misses and query_remote_on_cache_miss:
             await self.fetcher.get_keys(cache_misses)
             await self.query_keys(request, query, query_remote_on_cache_miss=False)
         else:
             signed_keys = []
             for key_json in json_results:
-                key_json = json.loads(key_json.decode("utf-8"))
+                key_json = json_decoder.decode(key_json.decode("utf-8"))
                 for signing_key in self.config.key_server_signing_keys:
                     key_json = sign_json(key_json, self.config.server_name, signing_key)
 
diff --git a/synapse/spam_checker_api/__init__.py b/synapse/spam_checker_api/__init__.py
index 9b78924d96..4d9b13ac04 100644
--- a/synapse/spam_checker_api/__init__.py
+++ b/synapse/spam_checker_api/__init__.py
@@ -51,5 +51,5 @@ class SpamCheckerApi(object):
         state_ids = yield self._store.get_filtered_current_state_ids(
             room_id=room_id, state_filter=StateFilter.from_types(types)
         )
-        state = yield self._store.get_events(state_ids.values())
+        state = yield defer.ensureDeferred(self._store.get_events(state_ids.values()))
         return state.values()
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index a1d3884667..dba8d91eef 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -641,7 +641,7 @@ class StateResolutionStore(object):
             allow_rejected (bool): If True return rejected events.
 
         Returns:
-            Deferred[dict[str, FrozenEvent]]: Dict from event_id to event.
+            Awaitable[dict[str, FrozenEvent]]: Dict from event_id to event.
         """
 
         return self.store.get_events(
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 6814bf5fcf..ab49d227de 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -19,12 +19,11 @@ import random
 from abc import ABCMeta
 from typing import Any, Optional
 
-from canonicaljson import json
-
 from synapse.storage.database import LoggingTransaction  # noqa: F401
 from synapse.storage.database import make_in_list_sql_clause  # noqa: F401
 from synapse.storage.database import DatabasePool
 from synapse.types import Collection, get_domain_from_id
+from synapse.util import json_decoder
 
 logger = logging.getLogger(__name__)
 
@@ -99,13 +98,13 @@ def db_to_json(db_content):
     if isinstance(db_content, memoryview):
         db_content = db_content.tobytes()
 
-    # Decode it to a Unicode string before feeding it to json.loads, since
+    # Decode it to a Unicode string before feeding it to the JSON decoder, since
     # Python 3.5 does not support deserializing bytes.
     if isinstance(db_content, (bytes, bytearray)):
         db_content = db_content.decode("utf8")
 
     try:
-        return json.loads(db_content)
+        return json_decoder.decode(db_content)
     except Exception:
         logging.warning("Tried to decode '%r' as JSON and failed", db_content)
         raise
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 8a9e06efcf..b9aef96b08 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -516,14 +516,16 @@ class DatabasePool(object):
             logger.warning("Starting db txn '%s' from sentinel context", desc)
 
         try:
-            result = yield self.runWithConnection(
-                self.new_transaction,
-                desc,
-                after_callbacks,
-                exception_callbacks,
-                func,
-                *args,
-                **kwargs
+            result = yield defer.ensureDeferred(
+                self.runWithConnection(
+                    self.new_transaction,
+                    desc,
+                    after_callbacks,
+                    exception_callbacks,
+                    func,
+                    *args,
+                    **kwargs
+                )
             )
 
             for after_callback, after_args, after_kwargs in after_callbacks:
@@ -535,8 +537,7 @@ class DatabasePool(object):
 
         return result
 
-    @defer.inlineCallbacks
-    def runWithConnection(self, func: Callable, *args: Any, **kwargs: Any):
+    async def runWithConnection(self, func: Callable, *args: Any, **kwargs: Any) -> Any:
         """Wraps the .runWithConnection() method on the underlying db_pool.
 
         Arguments:
@@ -547,7 +548,7 @@ class DatabasePool(object):
             kwargs: named args to pass to `func`
 
         Returns:
-            Deferred: The result of func
+            The result of func
         """
         parent_context = current_context()  # type: Optional[LoggingContextOrSentinel]
         if not parent_context:
@@ -570,12 +571,10 @@ class DatabasePool(object):
 
                 return func(conn, *args, **kwargs)
 
-        result = yield make_deferred_yieldable(
+        return await make_deferred_yieldable(
             self._db_pool.runWithConnection(inner_func, *args, **kwargs)
         )
 
-        return result
-
     @staticmethod
     def cursor_to_dict(cursor):
         """Converts a SQL cursor into an list of dicts.
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 10de446065..1e7637a6f5 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -299,8 +299,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
                 },
             )
 
-    def get_cache_stream_token(self, instance_name):
+    def get_cache_stream_token_for_writer(self, instance_name: str) -> int:
         if self._cache_id_gen:
-            return self._cache_id_gen.get_current_token(instance_name)
+            return self._cache_id_gen.get_current_token_for_writer(instance_name)
         else:
             return 0
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 431bd76693..4826be630c 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -30,7 +30,7 @@ logger = logging.getLogger(__name__)
 
 
 class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
-    def get_auth_chain(self, event_ids, include_given=False):
+    async def get_auth_chain(self, event_ids, include_given=False):
         """Get auth events for given event_ids. The events *must* be state events.
 
         Args:
@@ -40,9 +40,10 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
         Returns:
             list of events
         """
-        return self.get_auth_chain_ids(
+        event_ids = await self.get_auth_chain_ids(
             event_ids, include_given=include_given
-        ).addCallback(self.get_events_as_list)
+        )
+        return await self.get_events_as_list(event_ids)
 
     def get_auth_chain_ids(
         self,
@@ -459,7 +460,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
             "get_forward_extremeties_for_room", get_forward_extremeties_for_room_txn
         )
 
-    def get_backfill_events(self, room_id, event_list, limit):
+    async def get_backfill_events(self, room_id, event_list, limit):
         """Get a list of Events for a given topic that occurred before (and
         including) the events in event_list. Return a list of max size `limit`
 
@@ -469,17 +470,15 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
             event_list (list)
             limit (int)
         """
-        return (
-            self.db_pool.runInteraction(
-                "get_backfill_events",
-                self._get_backfill_events,
-                room_id,
-                event_list,
-                limit,
-            )
-            .addCallback(self.get_events_as_list)
-            .addCallback(lambda l: sorted(l, key=lambda e: -e.depth))
+        event_ids = await self.db_pool.runInteraction(
+            "get_backfill_events",
+            self._get_backfill_events,
+            room_id,
+            event_list,
+            limit,
         )
+        events = await self.get_events_as_list(event_ids)
+        return sorted(events, key=lambda e: -e.depth)
 
     def _get_backfill_events(self, txn, room_id, event_list, limit):
         logger.debug("_get_backfill_events: %s, %r, %s", room_id, event_list, limit)
@@ -540,8 +539,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
             latest_events,
             limit,
         )
-        events = await self.get_events_as_list(ids)
-        return events
+        return await self.get_events_as_list(ids)
 
     def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limit):
 
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 8c63a0dc4d..4a3333c0db 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -19,9 +19,10 @@ import itertools
 import logging
 import threading
 from collections import namedtuple
-from typing import List, Optional, Tuple
+from typing import Dict, Iterable, List, Optional, Tuple, overload
 
 from constantly import NamedConstant, Names
+from typing_extensions import Literal
 
 from twisted.internet import defer
 
@@ -32,7 +33,7 @@ from synapse.api.room_versions import (
     EventFormatVersions,
     RoomVersions,
 )
-from synapse.events import make_event_from_dict
+from synapse.events import EventBase, make_event_from_dict
 from synapse.events.utils import prune_event
 from synapse.logging.context import PreserveLoggingContext, current_context
 from synapse.metrics.background_process_metrics import run_as_background_process
@@ -42,8 +43,8 @@ from synapse.replication.tcp.streams.events import EventsStream
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
 from synapse.storage.database import DatabasePool
 from synapse.storage.util.id_generators import StreamIdGenerator
-from synapse.types import get_domain_from_id
-from synapse.util.caches.descriptors import Cache, cachedInlineCallbacks
+from synapse.types import Collection, get_domain_from_id
+from synapse.util.caches.descriptors import Cache, cached
 from synapse.util.iterutils import batch_iter
 from synapse.util.metrics import Measure
 
@@ -137,8 +138,33 @@ class EventsWorkerStore(SQLBaseStore):
             desc="get_received_ts",
         )
 
-    @defer.inlineCallbacks
-    def get_event(
+    # Inform mypy that if allow_none is False (the default) then get_event
+    # always returns an EventBase.
+    @overload
+    async def get_event(
+        self,
+        event_id: str,
+        redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
+        get_prev_content: bool = False,
+        allow_rejected: bool = False,
+        allow_none: Literal[False] = False,
+        check_room_id: Optional[str] = None,
+    ) -> EventBase:
+        ...
+
+    @overload
+    async def get_event(
+        self,
+        event_id: str,
+        redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
+        get_prev_content: bool = False,
+        allow_rejected: bool = False,
+        allow_none: Literal[True] = False,
+        check_room_id: Optional[str] = None,
+    ) -> Optional[EventBase]:
+        ...
+
+    async def get_event(
         self,
         event_id: str,
         redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
@@ -146,7 +172,7 @@ class EventsWorkerStore(SQLBaseStore):
         allow_rejected: bool = False,
         allow_none: bool = False,
         check_room_id: Optional[str] = None,
-    ):
+    ) -> Optional[EventBase]:
         """Get an event from the database by event_id.
 
         Args:
@@ -171,12 +197,12 @@ class EventsWorkerStore(SQLBaseStore):
                 If there is a mismatch, behave as per allow_none.
 
         Returns:
-            Deferred[EventBase|None]
+            The event, or None if the event was not found.
         """
         if not isinstance(event_id, str):
             raise TypeError("Invalid event event_id %r" % (event_id,))
 
-        events = yield self.get_events_as_list(
+        events = await self.get_events_as_list(
             [event_id],
             redact_behaviour=redact_behaviour,
             get_prev_content=get_prev_content,
@@ -194,14 +220,13 @@ class EventsWorkerStore(SQLBaseStore):
 
         return event
 
-    @defer.inlineCallbacks
-    def get_events(
+    async def get_events(
         self,
-        event_ids: List[str],
+        event_ids: Iterable[str],
         redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
         get_prev_content: bool = False,
         allow_rejected: bool = False,
-    ):
+    ) -> Dict[str, EventBase]:
         """Get events from the database
 
         Args:
@@ -220,9 +245,9 @@ class EventsWorkerStore(SQLBaseStore):
                 omits rejeted events from the response.
 
         Returns:
-            Deferred : Dict from event_id to event.
+            A mapping from event_id to event.
         """
-        events = yield self.get_events_as_list(
+        events = await self.get_events_as_list(
             event_ids,
             redact_behaviour=redact_behaviour,
             get_prev_content=get_prev_content,
@@ -231,14 +256,13 @@ class EventsWorkerStore(SQLBaseStore):
 
         return {e.event_id: e for e in events}
 
-    @defer.inlineCallbacks
-    def get_events_as_list(
+    async def get_events_as_list(
         self,
-        event_ids: List[str],
+        event_ids: Collection[str],
         redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
         get_prev_content: bool = False,
         allow_rejected: bool = False,
-    ):
+    ) -> List[EventBase]:
         """Get events from the database and return in a list in the same order
         as given by `event_ids` arg.
 
@@ -259,8 +283,8 @@ class EventsWorkerStore(SQLBaseStore):
                 omits rejected events from the response.
 
         Returns:
-            Deferred[list[EventBase]]: List of events fetched from the database. The
-            events are in the same order as `event_ids` arg.
+            List of events fetched from the database. The events are in the same
+            order as `event_ids` arg.
 
             Note that the returned list may be smaller than the list of event
             IDs if not all events could be fetched.
@@ -270,7 +294,7 @@ class EventsWorkerStore(SQLBaseStore):
             return []
 
         # there may be duplicates so we cast the list to a set
-        event_entry_map = yield self._get_events_from_cache_or_db(
+        event_entry_map = await self._get_events_from_cache_or_db(
             set(event_ids), allow_rejected=allow_rejected
         )
 
@@ -305,7 +329,7 @@ class EventsWorkerStore(SQLBaseStore):
                     continue
 
                 redacted_event_id = entry.event.redacts
-                event_map = yield self._get_events_from_cache_or_db([redacted_event_id])
+                event_map = await self._get_events_from_cache_or_db([redacted_event_id])
                 original_event_entry = event_map.get(redacted_event_id)
                 if not original_event_entry:
                     # we don't have the redacted event (or it was rejected).
@@ -371,7 +395,7 @@ class EventsWorkerStore(SQLBaseStore):
 
             if get_prev_content:
                 if "replaces_state" in event.unsigned:
-                    prev = yield self.get_event(
+                    prev = await self.get_event(
                         event.unsigned["replaces_state"],
                         get_prev_content=False,
                         allow_none=True,
@@ -383,8 +407,7 @@ class EventsWorkerStore(SQLBaseStore):
 
         return events
 
-    @defer.inlineCallbacks
-    def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False):
+    async def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False):
         """Fetch a bunch of events from the cache or the database.
 
         If events are pulled from the database, they will be cached for future lookups.
@@ -399,7 +422,7 @@ class EventsWorkerStore(SQLBaseStore):
                 rejected events are omitted from the response.
 
         Returns:
-            Deferred[Dict[str, _EventCacheEntry]]:
+            Dict[str, _EventCacheEntry]:
                 map from event id to result
         """
         event_entry_map = self._get_events_from_cache(
@@ -417,7 +440,7 @@ class EventsWorkerStore(SQLBaseStore):
             # the events have been redacted, and if so pulling the redaction event out
             # of the database to check it.
             #
-            missing_events = yield self._get_events_from_db(
+            missing_events = await self._get_events_from_db(
                 missing_events_ids, allow_rejected=allow_rejected
             )
 
@@ -525,8 +548,7 @@ class EventsWorkerStore(SQLBaseStore):
                 with PreserveLoggingContext():
                     self.hs.get_reactor().callFromThread(fire, event_list, e)
 
-    @defer.inlineCallbacks
-    def _get_events_from_db(self, event_ids, allow_rejected=False):
+    async def _get_events_from_db(self, event_ids, allow_rejected=False):
         """Fetch a bunch of events from the database.
 
         Returned events will be added to the cache for future lookups.
@@ -540,7 +562,7 @@ class EventsWorkerStore(SQLBaseStore):
                 rejected events are omitted from the response.
 
         Returns:
-            Deferred[Dict[str, _EventCacheEntry]]:
+            Dict[str, _EventCacheEntry]:
                 map from event id to result. May return extra events which
                 weren't asked for.
         """
@@ -548,7 +570,7 @@ class EventsWorkerStore(SQLBaseStore):
         events_to_fetch = event_ids
 
         while events_to_fetch:
-            row_map = yield self._enqueue_events(events_to_fetch)
+            row_map = await self._enqueue_events(events_to_fetch)
 
             # we need to recursively fetch any redactions of those events
             redaction_ids = set()
@@ -574,8 +596,20 @@ class EventsWorkerStore(SQLBaseStore):
             if not allow_rejected and rejected_reason:
                 continue
 
-            d = db_to_json(row["json"])
-            internal_metadata = db_to_json(row["internal_metadata"])
+            # If the event or metadata cannot be parsed, log the error and act
+            # as if the event is unknown.
+            try:
+                d = db_to_json(row["json"])
+            except ValueError:
+                logger.error("Unable to parse json from event: %s", event_id)
+                continue
+            try:
+                internal_metadata = db_to_json(row["internal_metadata"])
+            except ValueError:
+                logger.error(
+                    "Unable to parse internal_metadata from event: %s", event_id
+                )
+                continue
 
             format_version = row["format_version"]
             if format_version is None:
@@ -650,8 +684,7 @@ class EventsWorkerStore(SQLBaseStore):
 
         return result_map
 
-    @defer.inlineCallbacks
-    def _enqueue_events(self, events):
+    async def _enqueue_events(self, events):
         """Fetches events from the database using the _event_fetch_list. This
         allows batch and bulk fetching of events - it allows us to fetch events
         without having to create a new transaction for each request for events.
@@ -660,7 +693,7 @@ class EventsWorkerStore(SQLBaseStore):
             events (Iterable[str]): events to be fetched.
 
         Returns:
-            Deferred[Dict[str, Dict]]: map from event id to row data from the database.
+            Dict[str, Dict]: map from event id to row data from the database.
                 May contain events that weren't requested.
         """
 
@@ -683,7 +716,7 @@ class EventsWorkerStore(SQLBaseStore):
 
         logger.debug("Loading %d events: %s", len(events), events)
         with PreserveLoggingContext():
-            row_map = yield events_d
+            row_map = await events_d
         logger.debug("Loaded %d events (%d rows)", len(events), len(row_map))
 
         return row_map
@@ -842,33 +875,29 @@ class EventsWorkerStore(SQLBaseStore):
         # no valid redaction found for this event
         return None
 
-    @defer.inlineCallbacks
-    def have_events_in_timeline(self, event_ids):
+    async def have_events_in_timeline(self, event_ids):
         """Given a list of event ids, check if we have already processed and
         stored them as non outliers.
         """
-        rows = yield defer.ensureDeferred(
-            self.db_pool.simple_select_many_batch(
-                table="events",
-                retcols=("event_id",),
-                column="event_id",
-                iterable=list(event_ids),
-                keyvalues={"outlier": False},
-                desc="have_events_in_timeline",
-            )
+        rows = await self.db_pool.simple_select_many_batch(
+            table="events",
+            retcols=("event_id",),
+            column="event_id",
+            iterable=list(event_ids),
+            keyvalues={"outlier": False},
+            desc="have_events_in_timeline",
         )
 
         return {r["event_id"] for r in rows}
 
-    @defer.inlineCallbacks
-    def have_seen_events(self, event_ids):
+    async def have_seen_events(self, event_ids):
         """Given a list of event ids, check if we have already processed them.
 
         Args:
             event_ids (iterable[str]):
 
         Returns:
-            Deferred[set[str]]: The events we have already seen.
+            set[str]: The events we have already seen.
         """
         results = set()
 
@@ -884,7 +913,7 @@ class EventsWorkerStore(SQLBaseStore):
         # break the input up into chunks of 100
         input_iterator = iter(event_ids)
         for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), []):
-            yield self.db_pool.runInteraction(
+            await self.db_pool.runInteraction(
                 "have_seen_events", have_seen_events_txn, chunk
             )
         return results
@@ -914,8 +943,7 @@ class EventsWorkerStore(SQLBaseStore):
             room_id,
         )
 
-    @defer.inlineCallbacks
-    def get_room_complexity(self, room_id):
+    async def get_room_complexity(self, room_id):
         """
         Get a rough approximation of the complexity of the room. This is used by
         remote servers to decide whether they wish to join the room or not.
@@ -926,9 +954,9 @@ class EventsWorkerStore(SQLBaseStore):
             room_id (str)
 
         Returns:
-            Deferred[dict[str:int]] of complexity version to complexity.
+            dict[str:int] of complexity version to complexity.
         """
-        state_events = yield self.get_current_state_event_counts(room_id)
+        state_events = await self.get_current_state_event_counts(room_id)
 
         # Call this one "v1", so we can introduce new ones as we want to develop
         # it.
@@ -1165,9 +1193,9 @@ class EventsWorkerStore(SQLBaseStore):
         to_2, so_2 = await self.get_event_ordering(event_id2)
         return (to_1, so_1) > (to_2, so_2)
 
-    @cachedInlineCallbacks(max_entries=5000)
-    def get_event_ordering(self, event_id):
-        res = yield self.db_pool.simple_select_one(
+    @cached(max_entries=5000)
+    async def get_event_ordering(self, event_id):
+        res = await self.db_pool.simple_select_one(
             table="events",
             retcols=["topological_ordering", "stream_ordering"],
             keyvalues={"event_id": event_id},
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index c2289a9557..a585e54812 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -30,7 +30,7 @@ from synapse.storage.databases.main.pusher import PusherWorkerStore
 from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
 from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
 from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
-from synapse.storage.util.id_generators import ChainedIdGenerator
+from synapse.storage.util.id_generators import StreamIdGenerator
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -82,9 +82,9 @@ class PushRulesWorkerStore(
         super(PushRulesWorkerStore, self).__init__(database, db_conn, hs)
 
         if hs.config.worker.worker_app is None:
-            self._push_rules_stream_id_gen = ChainedIdGenerator(
-                self._stream_id_gen, db_conn, "push_rules_stream", "stream_id"
-            )  # type: Union[ChainedIdGenerator, SlavedIdTracker]
+            self._push_rules_stream_id_gen = StreamIdGenerator(
+                db_conn, "push_rules_stream", "stream_id"
+            )  # type: Union[StreamIdGenerator, SlavedIdTracker]
         else:
             self._push_rules_stream_id_gen = SlavedIdTracker(
                 db_conn, "push_rules_stream", "stream_id"
@@ -338,8 +338,9 @@ class PushRuleStore(PushRulesWorkerStore):
     ) -> None:
         conditions_json = json_encoder.encode(conditions)
         actions_json = json_encoder.encode(actions)
-        with self._push_rules_stream_id_gen.get_next() as ids:
-            stream_id, event_stream_ordering = ids
+        with self._push_rules_stream_id_gen.get_next() as stream_id:
+            event_stream_ordering = self._stream_id_gen.get_current_token()
+
             if before or after:
                 await self.db_pool.runInteraction(
                     "_add_push_rule_relative_txn",
@@ -559,8 +560,9 @@ class PushRuleStore(PushRulesWorkerStore):
                 txn, stream_id, event_stream_ordering, user_id, rule_id, op="DELETE"
             )
 
-        with self._push_rules_stream_id_gen.get_next() as ids:
-            stream_id, event_stream_ordering = ids
+        with self._push_rules_stream_id_gen.get_next() as stream_id:
+            event_stream_ordering = self._stream_id_gen.get_current_token()
+
             await self.db_pool.runInteraction(
                 "delete_push_rule",
                 delete_push_rule_txn,
@@ -569,8 +571,9 @@ class PushRuleStore(PushRulesWorkerStore):
             )
 
     async def set_push_rule_enabled(self, user_id, rule_id, enabled) -> None:
-        with self._push_rules_stream_id_gen.get_next() as ids:
-            stream_id, event_stream_ordering = ids
+        with self._push_rules_stream_id_gen.get_next() as stream_id:
+            event_stream_ordering = self._stream_id_gen.get_current_token()
+
             await self.db_pool.runInteraction(
                 "_set_push_rule_enabled_txn",
                 self._set_push_rule_enabled_txn,
@@ -643,8 +646,9 @@ class PushRuleStore(PushRulesWorkerStore):
                 data={"actions": actions_json},
             )
 
-        with self._push_rules_stream_id_gen.get_next() as ids:
-            stream_id, event_stream_ordering = ids
+        with self._push_rules_stream_id_gen.get_next() as stream_id:
+            event_stream_ordering = self._stream_id_gen.get_current_token()
+
             await self.db_pool.runInteraction(
                 "set_push_rule_actions",
                 set_push_rule_actions_txn,
@@ -673,11 +677,5 @@ class PushRuleStore(PushRulesWorkerStore):
             self.push_rules_stream_cache.entity_has_changed, user_id, stream_id
         )
 
-    def get_push_rules_stream_token(self):
-        """Get the position of the push rules stream.
-        Returns a pair of a stream id for the push_rules stream and the
-        room stream ordering it corresponds to."""
-        return self._push_rules_stream_id_gen.get_current_token()
-
     def get_max_push_rules_stream_id(self):
-        return self.get_push_rules_stream_token()[0]
+        return self._push_rules_stream_id_gen.get_current_token()
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 4377bddb8c..497f607703 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -379,7 +379,6 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         limit: int = 0,
         order: str = "DESC",
     ) -> Tuple[List[EventBase], str]:
-
         """Get new room events in stream ordering since `from_key`.
 
         Args:
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index e2ddd01290..0bf772d4d1 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -16,7 +16,7 @@
 import contextlib
 import threading
 from collections import deque
-from typing import Dict, Set, Tuple
+from typing import Dict, Set
 
 from typing_extensions import Deque
 
@@ -158,63 +158,13 @@ class StreamIdGenerator(object):
 
             return self._current
 
+    def get_current_token_for_writer(self, instance_name: str) -> int:
+        """Returns the position of the given writer.
 
-class ChainedIdGenerator(object):
-    """Used to generate new stream ids where the stream must be kept in sync
-    with another stream. It generates pairs of IDs, the first element is an
-    integer ID for this stream, the second element is the ID for the stream
-    that this stream needs to be kept in sync with."""
-
-    def __init__(self, chained_generator, db_conn, table, column):
-        self.chained_generator = chained_generator
-        self._table = table
-        self._lock = threading.Lock()
-        self._current_max = _load_current_id(db_conn, table, column)
-        self._unfinished_ids = deque()  # type: Deque[Tuple[int, int]]
-
-    def get_next(self):
-        """
-        Usage:
-            with stream_id_gen.get_next() as (stream_id, chained_id):
-                # ... persist event ...
-        """
-        with self._lock:
-            self._current_max += 1
-            next_id = self._current_max
-            chained_id = self.chained_generator.get_current_token()
-
-            self._unfinished_ids.append((next_id, chained_id))
-
-        @contextlib.contextmanager
-        def manager():
-            try:
-                yield (next_id, chained_id)
-            finally:
-                with self._lock:
-                    self._unfinished_ids.remove((next_id, chained_id))
-
-        return manager()
-
-    def get_current_token(self):
-        """Returns the maximum stream id such that all stream ids less than or
-        equal to it have been successfully persisted.
+        For streams with single writers this is equivalent to
+        `get_current_token`.
         """
-        with self._lock:
-            if self._unfinished_ids:
-                stream_id, chained_id = self._unfinished_ids[0]
-                return stream_id - 1, chained_id
-
-            return self._current_max, self.chained_generator.get_current_token()
-
-    def advance(self, token: int):
-        """Stub implementation for advancing the token when receiving updates
-        over replication; raises an exception as this instance should be the
-        only source of updates.
-        """
-
-        raise Exception(
-            "Attempted to advance token on source for table %r", self._table
-        )
+        return self.get_current_token()
 
 
 class MultiWriterIdGenerator:
@@ -298,7 +248,7 @@ class MultiWriterIdGenerator:
         # Assert the fetched ID is actually greater than what we currently
         # believe the ID to be. If not, then the sequence and table have got
         # out of sync somehow.
-        assert self.get_current_token() < next_id
+        assert self.get_current_token_for_writer(self._instance_name) < next_id
 
         with self._lock:
             self._unfinished_ids.add(next_id)
@@ -344,16 +294,18 @@ class MultiWriterIdGenerator:
                 curr = self._current_positions.get(self._instance_name, 0)
                 self._current_positions[self._instance_name] = max(curr, next_id)
 
-    def get_current_token(self, instance_name: str = None) -> int:
-        """Gets the current position of a named writer (defaults to current
-        instance).
-
-        Returns 0 if we don't have a position for the named writer (likely due
-        to it being a new writer).
+    def get_current_token(self) -> int:
+        """Returns the maximum stream id such that all stream ids less than or
+        equal to it have been successfully persisted.
         """
 
-        if instance_name is None:
-            instance_name = self._instance_name
+        # Currently we don't support this operation, as it's not obvious how to
+        # condense the stream positions of multiple writers into a single int.
+        raise NotImplementedError()
+
+    def get_current_token_for_writer(self, instance_name: str) -> int:
+        """Returns the position of the given writer.
+        """
 
         with self._lock:
             return self._current_positions.get(instance_name, 0)
diff --git a/synapse/streams/events.py b/synapse/streams/events.py
index 393e34b9fb..7ab46f42bf 100644
--- a/synapse/streams/events.py
+++ b/synapse/streams/events.py
@@ -39,7 +39,7 @@ class EventSources(object):
         self.store = hs.get_datastore()
 
     def get_current_token(self) -> StreamToken:
-        push_rules_key, _ = self.store.get_push_rules_stream_token()
+        push_rules_key = self.store.get_max_push_rules_stream_id()
         to_device_key = self.store.get_to_device_stream_token()
         device_list_key = self.store.get_device_stream_token()
         groups_key = self.store.get_group_stream_token()
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index b3f76428b6..b2a22dbd5c 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -25,8 +25,18 @@ from synapse.logging import context
 
 logger = logging.getLogger(__name__)
 
-# Create a custom encoder to reduce the whitespace produced by JSON encoding.
-json_encoder = json.JSONEncoder(separators=(",", ":"))
+
+def _reject_invalid_json(val):
+    """Do not allow Infinity, -Infinity, or NaN values in JSON."""
+    raise json.JSONDecodeError("Invalid JSON value: '%s'" % val)
+
+
+# Create a custom encoder to reduce the whitespace produced by JSON encoding and
+# ensure that valid JSON is produced.
+json_encoder = json.JSONEncoder(allow_nan=False, separators=(",", ":"))
+
+# Create a custom decoder to reject Python extensions to JSON.
+json_decoder = json.JSONDecoder(parse_constant=_reject_invalid_json)
 
 
 def unwrapFirstError(failure):
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index c2d72a82cf..49d9fddcf0 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -285,16 +285,9 @@ class Cache(object):
 
 
 class _CacheDescriptorBase(object):
-    def __init__(
-        self, orig: _CachedFunction, num_args, inlineCallbacks, cache_context=False
-    ):
+    def __init__(self, orig: _CachedFunction, num_args, cache_context=False):
         self.orig = orig
 
-        if inlineCallbacks:
-            self.function_to_call = defer.inlineCallbacks(orig)
-        else:
-            self.function_to_call = orig
-
         arg_spec = inspect.getfullargspec(orig)
         all_args = arg_spec.args
 
@@ -364,7 +357,7 @@ class CacheDescriptor(_CacheDescriptorBase):
     invalidated) by adding a special "cache_context" argument to the function
     and passing that as a kwarg to all caches called. For example::
 
-        @cachedInlineCallbacks(cache_context=True)
+        @cached(cache_context=True)
         def foo(self, key, cache_context):
             r1 = yield self.bar1(key, on_invalidate=cache_context.invalidate)
             r2 = yield self.bar2(key, on_invalidate=cache_context.invalidate)
@@ -382,17 +375,11 @@ class CacheDescriptor(_CacheDescriptorBase):
         max_entries=1000,
         num_args=None,
         tree=False,
-        inlineCallbacks=False,
         cache_context=False,
         iterable=False,
     ):
 
-        super(CacheDescriptor, self).__init__(
-            orig,
-            num_args=num_args,
-            inlineCallbacks=inlineCallbacks,
-            cache_context=cache_context,
-        )
+        super().__init__(orig, num_args=num_args, cache_context=cache_context)
 
         self.max_entries = max_entries
         self.tree = tree
@@ -465,9 +452,7 @@ class CacheDescriptor(_CacheDescriptorBase):
                     observer = defer.succeed(cached_result_d)
 
             except KeyError:
-                ret = defer.maybeDeferred(
-                    preserve_fn(self.function_to_call), obj, *args, **kwargs
-                )
+                ret = defer.maybeDeferred(preserve_fn(self.orig), obj, *args, **kwargs)
 
                 def onErr(f):
                     cache.invalidate(cache_key)
@@ -510,9 +495,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
     of results.
     """
 
-    def __init__(
-        self, orig, cached_method_name, list_name, num_args=None, inlineCallbacks=False
-    ):
+    def __init__(self, orig, cached_method_name, list_name, num_args=None):
         """
         Args:
             orig (function)
@@ -521,12 +504,8 @@ class CacheListDescriptor(_CacheDescriptorBase):
             num_args (int): number of positional arguments (excluding ``self``,
                 but including list_name) to use as cache keys. Defaults to all
                 named args of the function.
-            inlineCallbacks (bool): Whether orig is a generator that should
-                be wrapped by defer.inlineCallbacks
         """
-        super(CacheListDescriptor, self).__init__(
-            orig, num_args=num_args, inlineCallbacks=inlineCallbacks
-        )
+        super().__init__(orig, num_args=num_args)
 
         self.list_name = list_name
 
@@ -631,7 +610,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
 
                 cached_defers.append(
                     defer.maybeDeferred(
-                        preserve_fn(self.function_to_call), **args_to_call
+                        preserve_fn(self.orig), **args_to_call
                     ).addCallbacks(complete_all, errback)
                 )
 
@@ -695,21 +674,7 @@ def cached(
     )
 
 
-def cachedInlineCallbacks(
-    max_entries=1000, num_args=None, tree=False, cache_context=False, iterable=False
-):
-    return lambda orig: CacheDescriptor(
-        orig,
-        max_entries=max_entries,
-        num_args=num_args,
-        tree=tree,
-        inlineCallbacks=True,
-        cache_context=cache_context,
-        iterable=iterable,
-    )
-
-
-def cachedList(cached_method_name, list_name, num_args=None, inlineCallbacks=False):
+def cachedList(cached_method_name, list_name, num_args=None):
     """Creates a descriptor that wraps a function in a `CacheListDescriptor`.
 
     Used to do batch lookups for an already created cache. A single argument
@@ -725,8 +690,6 @@ def cachedList(cached_method_name, list_name, num_args=None, inlineCallbacks=Fal
             do batch lookups in the cache.
         num_args (int): Number of arguments to use as the key in the cache
             (including list_name). Defaults to all named parameters.
-        inlineCallbacks (bool): Should the function be wrapped in an
-            `defer.inlineCallbacks`?
 
     Example:
 
@@ -744,5 +707,4 @@ def cachedList(cached_method_name, list_name, num_args=None, inlineCallbacks=Fal
         cached_method_name=cached_method_name,
         list_name=list_name,
         num_args=num_args,
-        inlineCallbacks=inlineCallbacks,
     )
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index ac629b8a28..d7f0c19c4c 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -71,7 +71,9 @@ class ProfileTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_get_my_name(self):
-        yield self.store.set_profile_displayname(self.frank.localpart, "Frank", 1)
+        yield defer.ensureDeferred(
+            self.store.set_profile_displayname(self.frank.localpart, "Frank", 1)
+        )
 
         displayname = yield defer.ensureDeferred(
             self.handler.get_displayname(self.frank)
@@ -112,10 +114,17 @@ class ProfileTestCase(unittest.TestCase):
         self.hs.config.enable_set_displayname = False
 
         # Setting displayname for the first time is allowed
-        yield self.store.set_profile_displayname(self.frank.localpart, "Frank", 1)
+        yield defer.ensureDeferred(
+            self.store.set_profile_displayname(self.frank.localpart, "Frank", 1)
+        )
 
         self.assertEquals(
-            (yield self.store.get_profile_displayname(self.frank.localpart)), "Frank",
+            (
+                yield defer.ensureDeferred(
+                    self.store.get_profile_displayname(self.frank.localpart)
+                )
+            ),
+            "Frank",
         )
 
         # Setting displayname a second time is forbidden
@@ -158,7 +167,9 @@ class ProfileTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def test_incoming_fed_query(self):
         yield defer.ensureDeferred(self.store.create_profile("caroline"))
-        yield self.store.set_profile_displayname("caroline", "Caroline", 1)
+        yield defer.ensureDeferred(
+            self.store.set_profile_displayname("caroline", "Caroline", 1)
+        )
 
         response = yield defer.ensureDeferred(
             self.query_handlers["profile"](
@@ -170,8 +181,10 @@ class ProfileTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_get_my_avatar(self):
-        yield self.store.set_profile_avatar_url(
-            self.frank.localpart, "http://my.server/me.png", 1
+        yield defer.ensureDeferred(
+            self.store.set_profile_avatar_url(
+                self.frank.localpart, "http://my.server/me.png", 1
+            )
         )
         avatar_url = yield defer.ensureDeferred(self.handler.get_avatar_url(self.frank))
 
@@ -211,8 +224,10 @@ class ProfileTestCase(unittest.TestCase):
         self.hs.config.enable_set_avatar_url = False
 
         # Setting displayname for the first time is allowed
-        yield self.store.set_profile_avatar_url(
-            self.frank.localpart, "http://my.server/me.png", 1
+        yield defer.ensureDeferred(
+            self.store.set_profile_avatar_url(
+                self.frank.localpart, "http://my.server/me.png", 1
+            )
         )
 
         self.assertEquals(
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index db52725cfe..2668662c9e 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -62,8 +62,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
                 "identifier": {"type": "m.id.user", "user": "kermit" + str(i)},
                 "password": "monkey",
             }
-            request_data = json.dumps(params)
-            request, channel = self.make_request(b"POST", LOGIN_URL, request_data)
+            request, channel = self.make_request(b"POST", LOGIN_URL, params)
             self.render(request)
 
             if i == 5:
@@ -76,14 +75,13 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
         # than 1min.
         self.assertTrue(retry_after_ms < 6000)
 
-        self.reactor.advance(retry_after_ms / 1000.0)
+        self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
 
         params = {
             "type": "m.login.password",
             "identifier": {"type": "m.id.user", "user": "kermit" + str(i)},
             "password": "monkey",
         }
-        request_data = json.dumps(params)
         request, channel = self.make_request(b"POST", LOGIN_URL, params)
         self.render(request)
 
@@ -111,8 +109,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
                 "identifier": {"type": "m.id.user", "user": "kermit"},
                 "password": "monkey",
             }
-            request_data = json.dumps(params)
-            request, channel = self.make_request(b"POST", LOGIN_URL, request_data)
+            request, channel = self.make_request(b"POST", LOGIN_URL, params)
             self.render(request)
 
             if i == 5:
@@ -132,7 +129,6 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
             "identifier": {"type": "m.id.user", "user": "kermit"},
             "password": "monkey",
         }
-        request_data = json.dumps(params)
         request, channel = self.make_request(b"POST", LOGIN_URL, params)
         self.render(request)
 
@@ -160,8 +156,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
                 "identifier": {"type": "m.id.user", "user": "kermit"},
                 "password": "notamonkey",
             }
-            request_data = json.dumps(params)
-            request, channel = self.make_request(b"POST", LOGIN_URL, request_data)
+            request, channel = self.make_request(b"POST", LOGIN_URL, params)
             self.render(request)
 
             if i == 5:
@@ -174,14 +169,13 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
         # than 1min.
         self.assertTrue(retry_after_ms < 6000)
 
-        self.reactor.advance(retry_after_ms / 1000.0)
+        self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
 
         params = {
             "type": "m.login.password",
             "identifier": {"type": "m.id.user", "user": "kermit"},
             "password": "notamonkey",
         }
-        request_data = json.dumps(params)
         request, channel = self.make_request(b"POST", LOGIN_URL, params)
         self.render(request)
 
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index fce79b38f2..ecf697e5e0 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -156,7 +156,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
             else:
                 self.assertEquals(channel.result["code"], b"200", channel.result)
 
-        self.reactor.advance(retry_after_ms / 1000.0)
+        self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
 
         request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
         self.render(request)
@@ -182,7 +182,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
             else:
                 self.assertEquals(channel.result["code"], b"200", channel.result)
 
-        self.reactor.advance(retry_after_ms / 1000.0)
+        self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
 
         request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
         self.render(request)
diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py
index 2858d13558..23db821fb7 100644
--- a/tests/server_notices/test_resource_limits_server_notices.py
+++ b/tests/server_notices/test_resource_limits_server_notices.py
@@ -104,7 +104,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
             type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
         )
         self._rlsn._store.get_events = Mock(
-            return_value=defer.succeed({"123": mock_event})
+            return_value=make_awaitable({"123": mock_event})
         )
         self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
         # Would be better to check the content, but once == remove blocking event
@@ -122,7 +122,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
             type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
         )
         self._rlsn._store.get_events = Mock(
-            return_value=defer.succeed({"123": mock_event})
+            return_value=make_awaitable({"123": mock_event})
         )
 
         self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
@@ -217,7 +217,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
             type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
         )
         self._rlsn._store.get_events = Mock(
-            return_value=defer.succeed({"123": mock_event})
+            return_value=make_awaitable({"123": mock_event})
         )
         self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
 
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index a425e66f37..17fbde284a 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -31,6 +31,7 @@ from synapse.storage.databases.main.appservice import (
 )
 
 from tests import unittest
+from tests.test_utils import make_awaitable
 from tests.utils import setup_test_homeserver
 
 
@@ -357,7 +358,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
         other_events = [Mock(event_id="e5"), Mock(event_id="e6")]
 
         # we aren't testing store._base stuff here, so mock this out
-        self.store.get_events_as_list = Mock(return_value=defer.succeed(events))
+        self.store.get_events_as_list = Mock(return_value=make_awaitable(events))
 
         yield self._insert_txn(self.as_list[1]["id"], 9, other_events)
         yield self._insert_txn(service.id, 10, events)
diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py
index 8e9a650f9f..43639ca286 100644
--- a/tests/storage/test_cleanup_extrems.py
+++ b/tests/storage/test_cleanup_extrems.py
@@ -353,6 +353,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
         self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion[
             "3"
         ] = 300000
+
         self.event_creator_handler._expire_rooms_to_exclude_from_dummy_event_insertion()
         # All entries within time frame
         self.assertEqual(
@@ -362,7 +363,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
             3,
         )
         # Oldest room to expire
-        self.pump(1)
+        self.pump(1.01)
         self.event_creator_handler._expire_rooms_to_exclude_from_dummy_event_insertion()
         self.assertEqual(
             len(
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index e845410dae..7a05194653 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -88,7 +88,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         id_gen = self._create_id_generator()
 
         self.assertEqual(id_gen.get_positions(), {"master": 7})
-        self.assertEqual(id_gen.get_current_token("master"), 7)
+        self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
 
         # Try allocating a new ID gen and check that we only see position
         # advanced after we leave the context manager.
@@ -98,12 +98,12 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
                 self.assertEqual(stream_id, 8)
 
                 self.assertEqual(id_gen.get_positions(), {"master": 7})
-                self.assertEqual(id_gen.get_current_token("master"), 7)
+                self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
 
         self.get_success(_get_next_async())
 
         self.assertEqual(id_gen.get_positions(), {"master": 8})
-        self.assertEqual(id_gen.get_current_token("master"), 8)
+        self.assertEqual(id_gen.get_current_token_for_writer("master"), 8)
 
     def test_multi_instance(self):
         """Test that reads and writes from multiple processes are handled
@@ -116,8 +116,8 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         second_id_gen = self._create_id_generator("second")
 
         self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
-        self.assertEqual(first_id_gen.get_current_token("first"), 3)
-        self.assertEqual(first_id_gen.get_current_token("second"), 7)
+        self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 3)
+        self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7)
 
         # Try allocating a new ID gen and check that we only see position
         # advanced after we leave the context manager.
@@ -166,7 +166,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         id_gen = self._create_id_generator()
 
         self.assertEqual(id_gen.get_positions(), {"master": 7})
-        self.assertEqual(id_gen.get_current_token("master"), 7)
+        self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
 
         # Try allocating a new ID gen and check that we only see position
         # advanced after we leave the context manager.
@@ -176,9 +176,9 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
             self.assertEqual(stream_id, 8)
 
             self.assertEqual(id_gen.get_positions(), {"master": 7})
-            self.assertEqual(id_gen.get_current_token("master"), 7)
+            self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
 
         self.get_success(self.db_pool.runInteraction("test", _get_next_txn))
 
         self.assertEqual(id_gen.get_positions(), {"master": 8})
-        self.assertEqual(id_gen.get_current_token("master"), 8)
+        self.assertEqual(id_gen.get_current_token_for_writer("master"), 8)
diff --git a/tests/storage/test_main.py b/tests/storage/test_main.py
index 619ee0da15..745fa15e26 100644
--- a/tests/storage/test_main.py
+++ b/tests/storage/test_main.py
@@ -34,10 +34,12 @@ class DataStoreTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_get_users_paginate(self):
-        yield self.store.register_user(self.user.to_string(), "pass")
+        yield defer.ensureDeferred(
+            self.store.register_user(self.user.to_string(), "pass")
+        )
         yield defer.ensureDeferred(self.store.create_profile(self.user.localpart))
-        yield self.store.set_profile_displayname(
-            self.user.localpart, self.displayname, 1
+        yield defer.ensureDeferred(
+            self.store.set_profile_displayname(self.user.localpart, self.displayname, 1)
         )
 
         users, total = yield self.store.get_users_paginate(
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 4d2b9e0d64..0363735d4f 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -366,11 +366,11 @@ class CachedListDescriptorTestCase(unittest.TestCase):
             def fn(self, arg1, arg2):
                 pass
 
-            @descriptors.cachedList("fn", "args1", inlineCallbacks=True)
-            def list_fn(self, args1, arg2):
+            @descriptors.cachedList("fn", "args1")
+            async def list_fn(self, args1, arg2):
                 assert current_context().request == "c1"
                 # we want this to behave like an asynchronous function
-                yield run_on_reactor()
+                await run_on_reactor()
                 assert current_context().request == "c1"
                 return self.mock(args1, arg2)
 
@@ -416,10 +416,10 @@ class CachedListDescriptorTestCase(unittest.TestCase):
             def fn(self, arg1, arg2):
                 pass
 
-            @descriptors.cachedList("fn", "args1", inlineCallbacks=True)
-            def list_fn(self, args1, arg2):
+            @descriptors.cachedList("fn", "args1")
+            async def list_fn(self, args1, arg2):
                 # we want this to behave like an asynchronous function
-                yield run_on_reactor()
+                await run_on_reactor()
                 return self.mock(args1, arg2)
 
         obj = Cls()