author:    Brendan Abolivier <babolivier@matrix.org>  2020-09-03 15:30:00 +0100
committer: Brendan Abolivier <babolivier@matrix.org>  2020-09-03 15:30:00 +0100
commit:    505ea932f50b8995bbf101b45bafe7456c7534d5
tree:      6161e859947944cb13bc6e92be10197b32ebd89a /synapse
parent:    Merge branch 'develop' into matrix-org-hotfixes
parent:    Remove useless changelog about reverting a #8239.
download:  synapse-505ea932f50b8995bbf101b45bafe7456c7534d5.tar.xz

Merge branch 'develop' into matrix-org-hotfixes
Diffstat (limited to 'synapse')
-rw-r--r--  synapse/app/_base.py  7
-rw-r--r--  synapse/app/admin_cmd.py  20
-rw-r--r--  synapse/app/homeserver.py  18
-rw-r--r--  synapse/appservice/api.py  19
-rw-r--r--  synapse/config/_base.py  21
-rw-r--r--  synapse/config/_base.pyi  1
-rw-r--r--  synapse/config/workers.py  37
-rw-r--r--  synapse/events/__init__.py  4
-rw-r--r--  synapse/events/builder.py  19
-rw-r--r--  synapse/handlers/device.py  4
-rw-r--r--  synapse/handlers/e2e_keys.py  4
-rw-r--r--  synapse/handlers/federation.py  44
-rw-r--r--  synapse/handlers/message.py  27
-rw-r--r--  synapse/handlers/pagination.py  49
-rw-r--r--  synapse/handlers/profile.py  6
-rw-r--r--  synapse/handlers/room.py  14
-rw-r--r--  synapse/handlers/room_member.py  19
-rw-r--r--  synapse/handlers/sync.py  45
-rw-r--r--  synapse/handlers/user_directory.py  8
-rw-r--r--  synapse/http/federation/matrix_federation_agent.py  4
-rw-r--r--  synapse/http/federation/well_known_resolver.py  57
-rw-r--r--  synapse/push/bulk_push_rule_evaluator.py  92
-rw-r--r--  synapse/push/push_tools.py  2
-rw-r--r--  synapse/python_dependencies.py  4
-rw-r--r--  synapse/replication/http/federation.py  12
-rw-r--r--  synapse/replication/slave/storage/devices.py  3
-rw-r--r--  synapse/replication/tcp/client.py  6
-rw-r--r--  synapse/replication/tcp/handler.py  2
-rw-r--r--  synapse/replication/tcp/streams/events.py  4
-rw-r--r--  synapse/rest/__init__.py  4
-rw-r--r--  synapse/rest/client/v2_alpha/shared_rooms.py  68
-rw-r--r--  synapse/rest/client/v2_alpha/sync.py  1
-rw-r--r--  synapse/rest/client/versions.py  2
-rw-r--r--  synapse/storage/background_updates.py  4
-rw-r--r--  synapse/storage/database.py  87
-rw-r--r--  synapse/storage/databases/__init__.py  2
-rw-r--r--  synapse/storage/databases/main/__init__.py  53
-rw-r--r--  synapse/storage/databases/main/account_data.py  59
-rw-r--r--  synapse/storage/databases/main/client_ips.py  4
-rw-r--r--  synapse/storage/databases/main/devices.py  102
-rw-r--r--  synapse/storage/databases/main/directory.py  6
-rw-r--r--  synapse/storage/databases/main/end_to_end_keys.py  194
-rw-r--r--  synapse/storage/databases/main/event_federation.py  2
-rw-r--r--  synapse/storage/databases/main/event_push_actions.py  258
-rw-r--r--  synapse/storage/databases/main/events.py  8
-rw-r--r--  synapse/storage/databases/main/events_worker.py  114
-rw-r--r--  synapse/storage/databases/main/filtering.py  5
-rw-r--r--  synapse/storage/databases/main/media_repository.py  31
-rw-r--r--  synapse/storage/databases/main/openid.py  8
-rw-r--r--  synapse/storage/databases/main/profile.py  6
-rw-r--r--  synapse/storage/databases/main/purge_events.py  30
-rw-r--r--  synapse/storage/databases/main/push_rule.py  10
-rw-r--r--  synapse/storage/databases/main/receipts.py  14
-rw-r--r--  synapse/storage/databases/main/registration.py  240
-rw-r--r--  synapse/storage/databases/main/relations.py  103
-rw-r--r--  synapse/storage/databases/main/room.py  49
-rw-r--r--  synapse/storage/databases/main/roommember.py  58
-rw-r--r--  synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql  16
-rw-r--r--  synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres  26
-rw-r--r--  synapse/storage/databases/main/schema/delta/58/15unread_count.sql  26
-rw-r--r--  synapse/storage/databases/main/search.py  15
-rw-r--r--  synapse/storage/databases/main/signatures.py  40
-rw-r--r--  synapse/storage/databases/main/stats.py  34
-rw-r--r--  synapse/storage/databases/main/stream.py  67
-rw-r--r--  synapse/storage/databases/main/tags.py  4
-rw-r--r--  synapse/storage/databases/main/ui_auth.py  4
-rw-r--r--  synapse/storage/databases/main/user_directory.py  91
-rw-r--r--  synapse/storage/databases/main/user_erasure_store.py  8
-rw-r--r--  synapse/storage/util/id_generators.py  49
-rw-r--r--  synapse/util/async_helpers.py  16
70 files changed, 1576 insertions, 894 deletions
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 2b2cd795e0..a43dc5b2c9 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -334,6 +334,13 @@ def install_dns_limiter(reactor, max_dns_requests_in_flight=100):
     This is to work around https://twistedmatrix.com/trac/ticket/9620, where we
     can run out of file descriptors and enter an infinite loop if we attempt to
     do too many DNS queries at once
+
+    XXX: I'm confused by this. reactor.nameResolver does not use twisted.names unless
+    you explicitly install twisted.names as the resolver; rather it uses a GAIResolver
+    backed by the reactor's default threadpool (which is limited to 10 threads). So
+    (a) I don't understand why twisted ticket 9620 is relevant, and (b) I don't
+    understand why we would run out of FDs if we did too many lookups at once.
+    -- richvdh 2020/08/29
     """
     new_resolver = _LimitedHostnameResolver(
         reactor.nameResolver, max_dns_requests_in_flight
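
The limiter pattern referenced in this hunk is generic: cap the number of name resolutions in flight and queue the rest. Below is a minimal standalone sketch of the same idea using asyncio rather than Twisted's resolver interface (illustrative only, not Synapse's implementation):

    import asyncio
    import socket

    async def limited_resolve(limiter, hostname, port):
        # Queue on the semaphore so only a bounded number of getaddrinfo
        # calls are in flight at once.
        async with limiter:
            loop = asyncio.get_running_loop()
            return await loop.getaddrinfo(hostname, port, type=socket.SOCK_STREAM)

    async def main():
        limiter = asyncio.Semaphore(100)
        print(await limited_resolve(limiter, "example.org", 443))

    asyncio.run(main())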
diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py
index a37818fe9a..b6c9085670 100644
--- a/synapse/app/admin_cmd.py
+++ b/synapse/app/admin_cmd.py
@@ -79,8 +79,7 @@ class AdminCmdServer(HomeServer):
         pass
 
 
-@defer.inlineCallbacks
-def export_data_command(hs, args):
+async def export_data_command(hs, args):
     """Export data for a user.
 
     Args:
@@ -91,10 +90,8 @@ def export_data_command(hs, args):
     user_id = args.user_id
     directory = args.output_directory
 
-    res = yield defer.ensureDeferred(
-        hs.get_handlers().admin_handler.export_user_data(
-            user_id, FileExfiltrationWriter(user_id, directory=directory)
-        )
+    res = await hs.get_handlers().admin_handler.export_user_data(
+        user_id, FileExfiltrationWriter(user_id, directory=directory)
     )
     print(res)
 
@@ -232,14 +229,15 @@ def start(config_options):
     # We also make sure that `_base.start` gets run before we actually run the
     # command.
 
-    @defer.inlineCallbacks
-    def run(_reactor):
+    async def run():
         with LoggingContext("command"):
-            yield _base.start(ss, [])
-            yield args.func(ss, args)
+            _base.start(ss, [])
+            await args.func(ss, args)
 
     _base.start_worker_reactor(
-        "synapse-admin-cmd", config, run_command=lambda: task.react(run)
+        "synapse-admin-cmd",
+        config,
+        run_command=lambda: task.react(lambda _reactor: defer.ensureDeferred(run())),
     )
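
The new run_command lambda combines two Twisted APIs: task.react, which runs a Deferred-returning function to completion and then stops the reactor, and defer.ensureDeferred, which bridges a native coroutine into a Deferred. A minimal sketch of the same wiring:

    from twisted.internet import defer, task

    async def main(reactor):
        # Async setup and commands would run here; when the coroutine
        # finishes, task.react stops the reactor and exits.
        print("running under the reactor")

    # task.react expects a callable returning a Deferred, so ensureDeferred
    # bridges the coroutine, mirroring the run_command lambda above.
    task.react(lambda reactor: defer.ensureDeferred(main(reactor)))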
 
 
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 98d0d14a12..6014adc850 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -411,26 +411,24 @@ def setup(config_options):
 
         return provision
 
-    @defer.inlineCallbacks
-    def reprovision_acme():
+    async def reprovision_acme():
         """
         Provision a certificate from ACME, if required, and reload the TLS
         certificate if it's renewed.
         """
-        reprovisioned = yield defer.ensureDeferred(do_acme())
+        reprovisioned = await do_acme()
         if reprovisioned:
             _base.refresh_certificate(hs)
 
-    @defer.inlineCallbacks
-    def start():
+    async def start():
         try:
             # Run the ACME provisioning code, if it's enabled.
             if hs.config.acme_enabled:
                 acme = hs.get_acme_handler()
                 # Start up the webservices which we will respond to ACME
                 # challenges with, and then provision.
-                yield defer.ensureDeferred(acme.start_listening())
-                yield defer.ensureDeferred(do_acme())
+                await acme.start_listening()
+                await do_acme()
 
                 # Check if it needs to be reprovisioned every day.
                 hs.get_clock().looping_call(reprovision_acme, 24 * 60 * 60 * 1000)
@@ -439,8 +437,8 @@ def setup(config_options):
             if hs.config.oidc_enabled:
                 oidc = hs.get_oidc_handler()
                 # Loading the provider metadata also ensures the provider config is valid.
-                yield defer.ensureDeferred(oidc.load_metadata())
-                yield defer.ensureDeferred(oidc.load_jwks())
+                await oidc.load_metadata()
+                await oidc.load_jwks()
 
             _base.start(hs, config.listeners)
 
@@ -456,7 +454,7 @@ def setup(config_options):
                 reactor.stop()
             sys.exit(1)
 
-    reactor.callWhenRunning(start)
+    reactor.callWhenRunning(lambda: defer.ensureDeferred(start()))
 
     return hs
 
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index e72a0b9ac0..bb6fa8299a 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -14,18 +14,20 @@
 # limitations under the License.
 import logging
 import urllib
+from typing import TYPE_CHECKING, Optional
 
 from prometheus_client import Counter
 
-from twisted.internet import defer
-
 from synapse.api.constants import EventTypes, ThirdPartyEntityKind
 from synapse.api.errors import CodeMessageException
 from synapse.events.utils import serialize_event
 from synapse.http.client import SimpleHttpClient
-from synapse.types import ThirdPartyInstanceID
+from synapse.types import JsonDict, ThirdPartyInstanceID
 from synapse.util.caches.response_cache import ResponseCache
 
+if TYPE_CHECKING:
+    from synapse.appservice import ApplicationService
+
 logger = logging.getLogger(__name__)
 
 sent_transactions_counter = Counter(
@@ -163,19 +165,20 @@ class ApplicationServiceApi(SimpleHttpClient):
             logger.warning("query_3pe to %s threw exception %s", uri, ex)
             return []
 
-    def get_3pe_protocol(self, service, protocol):
+    async def get_3pe_protocol(
+        self, service: "ApplicationService", protocol: str
+    ) -> Optional[JsonDict]:
         if service.url is None:
             return {}
 
-        @defer.inlineCallbacks
-        def _get():
+        async def _get() -> Optional[JsonDict]:
             uri = "%s%s/thirdparty/protocol/%s" % (
                 service.url,
                 APP_SERVICE_PREFIX,
                 urllib.parse.quote(protocol),
             )
             try:
-                info = yield defer.ensureDeferred(self.get_json(uri, {}))
+                info = await self.get_json(uri, {})
 
                 if not _is_valid_3pe_metadata(info):
                     logger.warning(
@@ -196,7 +199,7 @@ class ApplicationServiceApi(SimpleHttpClient):
                 return None
 
         key = (service.id, protocol)
-        return self.protocol_meta_cache.wrap(key, _get)
+        return await self.protocol_meta_cache.wrap(key, _get)
 
     async def push_bulk(self, service, events, txn_id=None):
         if service.url is None:
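
This hunk shows the transformation applied throughout the commit: @defer.inlineCallbacks generators become native coroutines, and the defer.ensureDeferred wrapper at each call site disappears because await works on coroutines directly. Schematically (get_json stands in for any async-def method):

    from twisted.internet import defer

    # Before: a generator coroutine driven by inlineCallbacks.
    @defer.inlineCallbacks
    def get_info_old(client, uri):
        info = yield defer.ensureDeferred(client.get_json(uri, {}))
        return info

    # After: a native coroutine awaiting the other coroutine directly.
    async def get_info_new(client, uri):
        info = await client.get_json(uri, {})
        return info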
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index 1417487427..73f0717b0d 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -832,11 +832,26 @@ class ShardedWorkerHandlingConfig:
     def should_handle(self, instance_name: str, key: str) -> bool:
         """Whether this instance is responsible for handling the given key.
         """
-
-        # If multiple instances are not defined we always return true.
+        # If multiple instances are not defined we always return true
         if not self.instances or len(self.instances) == 1:
             return True
 
+        return self.get_instance(key) == instance_name
+
+    def get_instance(self, key: str) -> str:
+        """Get the instance responsible for handling the given key.
+
+        Note: For things like federation sending, the config for which instance
+        is sending is known only to the sender instance if there is only one.
+        Therefore `should_handle` should be used where possible.
+        """
+
+        if not self.instances:
+            return "master"
+
+        if len(self.instances) == 1:
+            return self.instances[0]
+
         # We shard by taking the hash, modulo it by the number of instances and
         # then checking whether this instance matches the instance at that
         # index.
@@ -846,7 +861,7 @@ class ShardedWorkerHandlingConfig:
         dest_hash = sha256(key.encode("utf8")).digest()
         dest_int = int.from_bytes(dest_hash, byteorder="little")
         remainder = dest_int % (len(self.instances))
-        return self.instances[remainder] == instance_name
+        return self.instances[remainder]
 
 
 __all__ = ["Config", "RootConfig", "ShardedWorkerHandlingConfig"]
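
To make the sharding scheme concrete: get_instance hashes the key with SHA-256, interprets the digest as a little-endian integer, and indexes the instance list modulo its length, so a given key always lands on the same instance. A standalone sketch (hypothetical pick_instance helper):

    from hashlib import sha256

    def pick_instance(instances, key):
        # Mirror of get_instance above: hash, little-endian int, modulo.
        dest_hash = sha256(key.encode("utf8")).digest()
        dest_int = int.from_bytes(dest_hash, byteorder="little")
        return instances[dest_int % len(instances)]

    # The same room ID always maps to the same writer:
    print(pick_instance(["worker1", "worker2", "worker3"], "!room:example.org"))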
diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi
index eb911e8f9f..b8faafa9bd 100644
--- a/synapse/config/_base.pyi
+++ b/synapse/config/_base.pyi
@@ -142,3 +142,4 @@ class ShardedWorkerHandlingConfig:
     instances: List[str]
     def __init__(self, instances: List[str]) -> None: ...
     def should_handle(self, instance_name: str, key: str) -> bool: ...
+    def get_instance(self, key: str) -> str: ...
diff --git a/synapse/config/workers.py b/synapse/config/workers.py
index c784a71508..f23e42cdf9 100644
--- a/synapse/config/workers.py
+++ b/synapse/config/workers.py
@@ -13,12 +13,24 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import List, Union
+
 import attr
 
 from ._base import Config, ConfigError, ShardedWorkerHandlingConfig
 from .server import ListenerConfig, parse_listener_def
 
 
+def _instance_to_list_converter(obj: Union[str, List[str]]) -> List[str]:
+    """Helper for allowing parsing a string or list of strings to a config
+    option expecting a list of strings.
+    """
+
+    if isinstance(obj, str):
+        return [obj]
+    return obj
+
+
 @attr.s
 class InstanceLocationConfig:
     """The host and port to talk to an instance via HTTP replication.
@@ -33,11 +45,13 @@ class WriterLocations:
     """Specifies the instances that write various streams.
 
     Attributes:
-        events: The instance that writes to the event and backfill streams.
-        events: The instance that writes to the typing stream.
+        events: The instances that write to the event and backfill streams.
+        typing: The instance that writes to the typing stream.
     """
 
-    events = attr.ib(default="master", type=str)
+    events = attr.ib(
+        default=["master"], type=List[str], converter=_instance_to_list_converter
+    )
     typing = attr.ib(default="master", type=str)
 
 
@@ -105,15 +119,18 @@ class WorkerConfig(Config):
         writers = config.get("stream_writers") or {}
         self.writers = WriterLocations(**writers)
 
-        # Check that the configured writer for events and typing also appears in
+        # Check that the configured writers for events and typing also appear in
         # `instance_map`.
         for stream in ("events", "typing"):
-            instance = getattr(self.writers, stream)
-            if instance != "master" and instance not in self.instance_map:
-                raise ConfigError(
-                    "Instance %r is configured to write %s but does not appear in `instance_map` config."
-                    % (instance, stream)
-                )
+            instances = _instance_to_list_converter(getattr(self.writers, stream))
+            for instance in instances:
+                if instance != "master" and instance not in self.instance_map:
+                    raise ConfigError(
+                        "Instance %r is configured to write %s but does not appear in `instance_map` config."
+                        % (instance, stream)
+                    )
+
+        self.events_shard_config = ShardedWorkerHandlingConfig(self.writers.events)
 
     def generate_config_section(self, config_dir_path, server_name, **kwargs):
         return """\
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index 67db763dbf..62ea44fa49 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -18,7 +18,7 @@
 import abc
 import os
 from distutils.util import strtobool
-from typing import Dict, Optional, Type
+from typing import Dict, Optional, Tuple, Type
 
 from unpaddedbase64 import encode_base64
 
@@ -120,7 +120,7 @@ class _EventInternalMetadata(object):
     # be here
     before = DictProperty("before")  # type: str
     after = DictProperty("after")  # type: str
-    order = DictProperty("order")  # type: int
+    order = DictProperty("order")  # type: Tuple[int, int]
 
     def get_dict(self) -> JsonDict:
         return dict(self._dict)
diff --git a/synapse/events/builder.py b/synapse/events/builder.py
index 9ed24380dd..7878cd7044 100644
--- a/synapse/events/builder.py
+++ b/synapse/events/builder.py
@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Optional
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import attr
 from nacl.signing import SigningKey
@@ -97,14 +97,14 @@ class EventBuilder(object):
     def is_state(self):
         return self._state_key is not None
 
-    async def build(self, prev_event_ids):
+    async def build(self, prev_event_ids: List[str]) -> EventBase:
         """Transform into a fully signed and hashed event
 
         Args:
-            prev_event_ids (list[str]): The event IDs to use as the prev events
+            prev_event_ids: The event IDs to use as the prev events
 
         Returns:
-            FrozenEvent
+            The signed and hashed event.
         """
 
         state_ids = await self._state.get_current_state_ids(
@@ -114,8 +114,13 @@ class EventBuilder(object):
 
         format_version = self.room_version.event_format
         if format_version == EventFormatVersions.V1:
-            auth_events = await self._store.add_event_hashes(auth_ids)
-            prev_events = await self._store.add_event_hashes(prev_event_ids)
+            # The types of auth/prev events changes between event versions.
+            auth_events = await self._store.add_event_hashes(
+                auth_ids
+            )  # type: Union[List[str], List[Tuple[str, Dict[str, str]]]]
+            prev_events = await self._store.add_event_hashes(
+                prev_event_ids
+            )  # type: Union[List[str], List[Tuple[str, Dict[str, str]]]]
         else:
             auth_events = auth_ids
             prev_events = prev_event_ids
@@ -138,7 +143,7 @@ class EventBuilder(object):
             "unsigned": self.unsigned,
             "depth": depth,
             "prev_state": [],
-        }
+        }  # type: Dict[str, Any]
 
         if self.is_state():
             event_dict["state_key"] = self._state_key
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index db417d60de..ee4666337a 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -234,7 +234,9 @@ class DeviceWorkerHandler(BaseHandler):
         return result
 
     async def on_federation_query_user_devices(self, user_id):
-        stream_id, devices = await self.store.get_devices_with_keys_by_user(user_id)
+        stream_id, devices = await self.store.get_e2e_device_keys_for_federation_query(
+            user_id
+        )
         master_key = await self.store.get_e2e_cross_signing_key(user_id, "master")
         self_signing_key = await self.store.get_e2e_cross_signing_key(
             user_id, "self_signing"
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index d8def45e38..dfd1c78549 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -353,7 +353,7 @@ class E2eKeysHandler(object):
             # make sure that each queried user appears in the result dict
             result_dict[user_id] = {}
 
-        results = await self.store.get_e2e_device_keys(local_query)
+        results = await self.store.get_e2e_device_keys_for_cs_api(local_query)
 
         # Build the result structure
         for user_id, device_keys in results.items():
@@ -734,7 +734,7 @@ class E2eKeysHandler(object):
             # fetch our stored devices.  This is used to 1. verify
             # signatures on the master key, and 2. to compare with what
             # was sent if the device was signed
-            devices = await self.store.get_e2e_device_keys([(user_id, None)])
+            devices = await self.store.get_e2e_device_keys_for_cs_api([(user_id, None)])
 
             if user_id not in devices:
                 raise NotFoundError("No device keys found")
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 16389a0dca..bd8efbb768 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -923,7 +923,8 @@ class FederationHandler(BaseHandler):
                 )
             )
 
-        await self._handle_new_events(dest, ev_infos, backfilled=True)
+        if ev_infos:
+            await self._handle_new_events(dest, room_id, ev_infos, backfilled=True)
 
         # Step 2: Persist the rest of the events in the chunk one by one
         events.sort(key=lambda e: e.depth)
@@ -1216,7 +1217,7 @@ class FederationHandler(BaseHandler):
             event_infos.append(_NewEventInfo(event, None, auth))
 
         await self._handle_new_events(
-            destination, event_infos,
+            destination, room_id, event_infos,
         )
 
     def _sanity_check_event(self, ev):
@@ -1363,15 +1364,15 @@ class FederationHandler(BaseHandler):
             )
 
             max_stream_id = await self._persist_auth_tree(
-                origin, auth_chain, state, event, room_version_obj
+                origin, room_id, auth_chain, state, event, room_version_obj
             )
 
             # We wait here until this instance has seen the events come down
             # replication (if we're using replication) as the below uses caches.
-            #
-            # TODO: Currently the events stream is written to from master
             await self._replication.wait_for_stream_position(
-                self.config.worker.writers.events, "events", max_stream_id
+                self.config.worker.events_shard_config.get_instance(room_id),
+                "events",
+                max_stream_id,
             )
 
             # Check whether this room is the result of an upgrade of a room we already know
@@ -1625,7 +1626,7 @@ class FederationHandler(BaseHandler):
         )
 
         context = await self.state_handler.compute_event_context(event)
-        await self.persist_events_and_notify([(event, context)])
+        await self.persist_events_and_notify(event.room_id, [(event, context)])
 
         return event
 
@@ -1652,7 +1653,9 @@ class FederationHandler(BaseHandler):
         await self.federation_client.send_leave(host_list, event)
 
         context = await self.state_handler.compute_event_context(event)
-        stream_id = await self.persist_events_and_notify([(event, context)])
+        stream_id = await self.persist_events_and_notify(
+            event.room_id, [(event, context)]
+        )
 
         return event, stream_id
 
@@ -1900,7 +1903,7 @@ class FederationHandler(BaseHandler):
                 )
 
             await self.persist_events_and_notify(
-                [(event, context)], backfilled=backfilled
+                event.room_id, [(event, context)], backfilled=backfilled
             )
         except Exception:
             run_in_background(
@@ -1913,6 +1916,7 @@ class FederationHandler(BaseHandler):
     async def _handle_new_events(
         self,
         origin: str,
+        room_id: str,
         event_infos: Iterable[_NewEventInfo],
         backfilled: bool = False,
     ) -> None:
@@ -1944,6 +1948,7 @@ class FederationHandler(BaseHandler):
         )
 
         await self.persist_events_and_notify(
+            room_id,
             [
                 (ev_info.event, context)
                 for ev_info, context in zip(event_infos, contexts)
@@ -1954,6 +1959,7 @@ class FederationHandler(BaseHandler):
     async def _persist_auth_tree(
         self,
         origin: str,
+        room_id: str,
         auth_events: List[EventBase],
         state: List[EventBase],
         event: EventBase,
@@ -1968,6 +1974,7 @@ class FederationHandler(BaseHandler):
 
         Args:
             origin: Where the events came from
+            room_id
             auth_events
             state
             event
@@ -2042,17 +2049,20 @@ class FederationHandler(BaseHandler):
                 events_to_context[e.event_id].rejected = RejectedReason.AUTH_ERROR
 
         await self.persist_events_and_notify(
+            room_id,
             [
                 (e, events_to_context[e.event_id])
                 for e in itertools.chain(auth_events, state)
-            ]
+            ],
         )
 
         new_event_context = await self.state_handler.compute_event_context(
             event, old_state=state
         )
 
-        return await self.persist_events_and_notify([(event, new_event_context)])
+        return await self.persist_events_and_notify(
+            room_id, [(event, new_event_context)]
+        )
 
     async def _prep_event(
         self,
@@ -2903,6 +2913,7 @@ class FederationHandler(BaseHandler):
 
     async def persist_events_and_notify(
         self,
+        room_id: str,
         event_and_contexts: Sequence[Tuple[EventBase, EventContext]],
         backfilled: bool = False,
     ) -> int:
@@ -2910,14 +2921,19 @@ class FederationHandler(BaseHandler):
         necessary.
 
         Args:
-            event_and_contexts:
+            room_id: The room ID of events being persisted.
+            event_and_contexts: Sequence of events with their associated
+                context that should be persisted. All events must belong to
+                the same room.
             backfilled: Whether these events are a result of
                 backfilling or not
         """
-        if self.config.worker.writers.events != self._instance_name:
+        instance = self.config.worker.events_shard_config.get_instance(room_id)
+        if instance != self._instance_name:
             result = await self._send_events(
-                instance_name=self.config.worker.writers.events,
+                instance_name=instance,
                 store=self.store,
+                room_id=room_id,
                 event_and_contexts=event_and_contexts,
                 backfilled=backfilled,
             )
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 7a48c69163..0016af44be 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -49,14 +49,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.replication.http.send_event import ReplicationSendEventRestServlet
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
 from synapse.storage.state import StateFilter
-from synapse.types import (
-    Collection,
-    Requester,
-    RoomAlias,
-    StreamToken,
-    UserID,
-    create_requester,
-)
+from synapse.types import Requester, RoomAlias, StreamToken, UserID, create_requester
 from synapse.util import json_decoder
 from synapse.util.async_helpers import Linearizer
 from synapse.util.frozenutils import frozendict_json_encoder
@@ -383,9 +376,8 @@ class EventCreationHandler(object):
         self.notifier = hs.get_notifier()
         self.config = hs.config
         self.require_membership_for_aliases = hs.config.require_membership_for_aliases
-        self._is_event_writer = (
-            self.config.worker.writers.events == hs.get_instance_name()
-        )
+        self._events_shard_config = self.config.worker.events_shard_config
+        self._instance_name = hs.get_instance_name()
 
         self.room_invite_state_types = self.hs.config.room_invite_state_types
 
@@ -448,7 +440,7 @@ class EventCreationHandler(object):
         event_dict: dict,
         token_id: Optional[str] = None,
         txn_id: Optional[str] = None,
-        prev_event_ids: Optional[Collection[str]] = None,
+        prev_event_ids: Optional[List[str]] = None,
         require_consent: bool = True,
     ) -> Tuple[EventBase, EventContext]:
         """
@@ -788,7 +780,7 @@ class EventCreationHandler(object):
         self,
         builder: EventBuilder,
         requester: Optional[Requester] = None,
-        prev_event_ids: Optional[Collection[str]] = None,
+        prev_event_ids: Optional[List[str]] = None,
     ) -> Tuple[EventBase, EventContext]:
         """Create a new event for a local client
 
@@ -913,9 +905,10 @@ class EventCreationHandler(object):
 
         try:
             # If we're a worker we need to hit out to the master.
-            if not self._is_event_writer:
+            writer_instance = self._events_shard_config.get_instance(event.room_id)
+            if writer_instance != self._instance_name:
                 result = await self.send_event(
-                    instance_name=self.config.worker.writers.events,
+                    instance_name=writer_instance,
                     event_id=event.event_id,
                     store=self.store,
                     requester=requester,
@@ -983,7 +976,9 @@ class EventCreationHandler(object):
 
         This should only be run on the instance in charge of persisting events.
         """
-        assert self._is_event_writer
+        assert self._events_shard_config.should_handle(
+            self._instance_name, event.room_id
+        )
 
         if ratelimit:
             # We check if this is a room admin redacting an event so that we
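
The check above replaces the old single-writer flag: each worker resolves the writer for the event's room and either persists locally or forwards over replication HTTP. A hypothetical condensation of the control flow (send_remote and persist_local stand in for the replication servlet and the local persistence path):

    async def route_event(shard_config, instance_name, event, send_remote, persist_local):
        writer = shard_config.get_instance(event.room_id)
        if writer != instance_name:
            # Another instance owns this room's events stream; forward it.
            return await send_remote(instance_name=writer, event=event)
        # We are the writer for this room; persist locally.
        return await persist_local(event)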
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index ac3418d69d..5a1aa7d830 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -14,15 +14,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+from typing import Any, Dict, Optional
 
 from twisted.python.failure import Failure
 
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.errors import SynapseError
+from synapse.api.filtering import Filter
 from synapse.logging.context import run_in_background
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.state import StateFilter
-from synapse.types import RoomStreamToken
+from synapse.streams.config import PaginationConfig
+from synapse.types import Requester, RoomStreamToken
 from synapse.util.async_helpers import ReadWriteLock
 from synapse.util.stringutils import random_string
 from synapse.visibility import filter_events_for_client
@@ -247,15 +250,16 @@ class PaginationHandler(object):
         )
         return purge_id
 
-    async def _purge_history(self, purge_id, room_id, token, delete_local_events):
+    async def _purge_history(
+        self, purge_id: str, room_id: str, token: str, delete_local_events: bool
+    ) -> None:
         """Carry out a history purge on a room.
 
         Args:
-            purge_id (str): The id for this purge
-            room_id (str): The room to purge from
-            token (str): topological token to delete events before
-            delete_local_events (bool): True to delete local events as well as
-                remote ones
+            purge_id: The id for this purge
+            room_id: The room to purge from
+            token: topological token to delete events before
+            delete_local_events: True to delete local events as well as remote ones
         """
         self._purges_in_progress_by_room.add(room_id)
         try:
@@ -291,9 +295,9 @@ class PaginationHandler(object):
         """
         return self._purges_by_id.get(purge_id)
 
-    async def purge_room(self, room_id):
+    async def purge_room(self, room_id: str) -> None:
         """Purge the given room from the database"""
-        with (await self.pagination_lock.write(room_id)):
+        with await self.pagination_lock.write(room_id):
             # check we know about the room
             await self.store.get_room_version_id(room_id)
 
@@ -307,23 +311,22 @@ class PaginationHandler(object):
 
     async def get_messages(
         self,
-        requester,
-        room_id=None,
-        pagin_config=None,
-        as_client_event=True,
-        event_filter=None,
-    ):
+        requester: Requester,
+        room_id: Optional[str] = None,
+        pagin_config: Optional[PaginationConfig] = None,
+        as_client_event: bool = True,
+        event_filter: Optional[Filter] = None,
+    ) -> Dict[str, Any]:
         """Get messages in a room.
 
         Args:
-            requester (Requester): The user requesting messages.
-            room_id (str): The room they want messages from.
-            pagin_config (synapse.api.streams.PaginationConfig): The pagination
-                config rules to apply, if any.
-            as_client_event (bool): True to get events in client-server format.
-            event_filter (Filter): Filter to apply to results or None
+            requester: The user requesting messages.
+            room_id: The room they want messages from.
+            pagin_config: The pagination config rules to apply, if any.
+            as_client_event: True to get events in client-server format.
+            event_filter: Filter to apply to results or None
         Returns:
-            dict: Pagination API results
+            Pagination API results
         """
         user_id = requester.user.to_string()
 
@@ -343,7 +346,7 @@ class PaginationHandler(object):
 
         source_config = pagin_config.get_source_config("room")
 
-        with (await self.pagination_lock.read(room_id)):
+        with await self.pagination_lock.read(room_id):
             (
                 membership,
                 member_event_id,
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 96c9d6bab4..0cb8fad89a 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -161,6 +161,9 @@ class BaseProfileHandler(BaseHandler):
                     Codes.FORBIDDEN,
                 )
 
+        if not isinstance(new_displayname, str):
+            raise SynapseError(400, "Invalid displayname")
+
         if len(new_displayname) > MAX_DISPLAYNAME_LEN:
             raise SynapseError(
                 400, "Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN,)
@@ -235,6 +238,9 @@ class BaseProfileHandler(BaseHandler):
                     400, "Changing avatar is disabled on this server", Codes.FORBIDDEN
                 )
 
+        if not isinstance(new_avatar_url, str):
+            raise SynapseError(400, "Invalid avatar URL")
+
         if len(new_avatar_url) > MAX_AVATAR_URL_LEN:
             raise SynapseError(
                 400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN,)
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 9d5b1828df..55794c3057 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -804,7 +804,9 @@ class RoomCreationHandler(BaseHandler):
 
         # Always wait for room creation to propagate before returning
         await self._replication.wait_for_stream_position(
-            self.hs.config.worker.writers.events, "events", last_stream_id
+            self.hs.config.worker.events_shard_config.get_instance(room_id),
+            "events",
+            last_stream_id,
         )
 
         return result, last_stream_id
@@ -1260,10 +1262,10 @@ class RoomShutdownHandler(object):
             # We now wait for the create room to come back in via replication so
             # that we can assume that all the joins/invites have propagated before
             # we try and auto join below.
-            #
-            # TODO: Currently the events stream is written to from master
             await self._replication.wait_for_stream_position(
-                self.hs.config.worker.writers.events, "events", stream_id
+                self.hs.config.worker.events_shard_config.get_instance(new_room_id),
+                "events",
+                stream_id,
             )
         else:
             new_room_id = None
@@ -1293,7 +1295,9 @@ class RoomShutdownHandler(object):
 
                 # Wait for leave to come in over replication before trying to forget.
                 await self._replication.wait_for_stream_position(
-                    self.hs.config.worker.writers.events, "events", stream_id
+                    self.hs.config.worker.events_shard_config.get_instance(room_id),
+                    "events",
+                    stream_id,
                 )
 
                 await self.room_member_handler.forget(target_requester.user, room_id)
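
All three wait_for_stream_position call sites in this file follow the same persist-then-wait shape, now parameterised by the room's writer. A hypothetical sketch of that shape (do_write stands in for whichever write produced the stream position):

    async def persist_and_wait(replication, shard_config, room_id, do_write):
        stream_id = await do_write()
        # Wait until the room's writer has streamed the new events back over
        # replication, so local caches are safe to read afterwards.
        instance = shard_config.get_instance(room_id)
        await replication.wait_for_stream_position(instance, "events", stream_id)
        return stream_id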
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 1017ae6b19..ed1d1bd83d 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -38,15 +38,7 @@ from synapse.events.builder import create_local_event_from_event_dict
 from synapse.events.snapshot import EventContext
 from synapse.events.validator import EventValidator
 from synapse.storage.roommember import RoomsForUser
-from synapse.types import (
-    Collection,
-    JsonDict,
-    Requester,
-    RoomAlias,
-    RoomID,
-    StateMap,
-    UserID,
-)
+from synapse.types import JsonDict, Requester, RoomAlias, RoomID, StateMap, UserID
 from synapse.util.async_helpers import Linearizer
 from synapse.util.distributor import user_joined_room, user_left_room
 
@@ -91,13 +83,6 @@ class RoomMemberHandler(object):
         self._enable_lookup = hs.config.enable_3pid_lookup
         self.allow_per_room_profiles = self.config.allow_per_room_profiles
 
-        self._event_stream_writer_instance = hs.config.worker.writers.events
-        self._is_on_event_persistence_instance = (
-            self._event_stream_writer_instance == hs.get_instance_name()
-        )
-        if self._is_on_event_persistence_instance:
-            self.persist_event_storage = hs.get_storage().persistence
-
         self._join_rate_limiter_local = Ratelimiter(
             clock=self.clock,
             rate_hz=hs.config.ratelimiting.rc_joins_local.per_second,
@@ -185,7 +170,7 @@ class RoomMemberHandler(object):
         target: UserID,
         room_id: str,
         membership: str,
-        prev_event_ids: Collection[str],
+        prev_event_ids: List[str],
         txn_id: Optional[str] = None,
         ratelimit: bool = True,
         content: Optional[dict] = None,
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 8118206f8e..c281ff163a 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -16,7 +16,7 @@
 
 import itertools
 import logging
-from typing import Any, Dict, FrozenSet, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Any, Dict, FrozenSet, List, Optional, Set, Tuple
 
 import attr
 from prometheus_client import Counter
@@ -44,6 +44,9 @@ from synapse.util.caches.response_cache import ResponseCache
 from synapse.util.metrics import Measure, measure_func
 from synapse.visibility import filter_events_for_client
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 # Debug logger for https://github.com/matrix-org/synapse/issues/4422
@@ -96,7 +99,12 @@ class TimelineBatch:
     __bool__ = __nonzero__  # python3
 
 
-@attr.s(slots=True, frozen=True)
+# We can't freeze this class, because we need to update it after it's instantiated to
+# update its unread count. This is because we calculate the unread count for a room only
+# if there are updates for it, which we check after the instance has been created.
+# This should not be a big deal because we update the notification counts afterwards as
+# well anyway.
+@attr.s(slots=True)
 class JoinedSyncResult:
     room_id = attr.ib(type=str)
     timeline = attr.ib(type=TimelineBatch)
@@ -105,6 +113,7 @@ class JoinedSyncResult:
     account_data = attr.ib(type=List[JsonDict])
     unread_notifications = attr.ib(type=JsonDict)
     summary = attr.ib(type=Optional[JsonDict])
+    unread_count = attr.ib(type=int)
 
     def __nonzero__(self) -> bool:
         """Make the result appear empty if there are no updates. This is used
@@ -239,7 +248,7 @@ class SyncResult:
 
 
 class SyncHandler(object):
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs_config = hs.config
         self.store = hs.get_datastore()
         self.notifier = hs.get_notifier()
@@ -714,9 +723,8 @@ class SyncHandler(object):
         ]
 
         missing_hero_state = await self.store.get_events(missing_hero_event_ids)
-        missing_hero_state = missing_hero_state.values()
 
-        for s in missing_hero_state:
+        for s in missing_hero_state.values():
             cache.set(s.state_key, s.event_id)
             state[(EventTypes.Member, s.state_key)] = s
 
@@ -934,7 +942,7 @@ class SyncHandler(object):
 
     async def unread_notifs_for_room_id(
         self, room_id: str, sync_config: SyncConfig
-    ) -> Optional[Dict[str, str]]:
+    ) -> Dict[str, int]:
         with Measure(self.clock, "unread_notifs_for_room_id"):
             last_unread_event_id = await self.store.get_last_receipt_event_id_for_user(
                 user_id=sync_config.user.to_string(),
@@ -942,15 +950,10 @@ class SyncHandler(object):
                 receipt_type="m.read",
             )
 
-            if last_unread_event_id:
-                notifs = await self.store.get_unread_event_push_actions_by_room_for_user(
-                    room_id, sync_config.user.to_string(), last_unread_event_id
-                )
-                return notifs
-
-        # There is no new information in this period, so your notification
-        # count is whatever it was last time.
-        return None
+            notifs = await self.store.get_unread_event_push_actions_by_room_for_user(
+                room_id, sync_config.user.to_string(), last_unread_event_id
+            )
+            return notifs
 
     async def generate_sync_result(
         self,
@@ -1773,7 +1776,7 @@ class SyncHandler(object):
         ignored_users: Set[str],
         room_builder: "RoomSyncResultBuilder",
         ephemeral: List[JsonDict],
-        tags: Optional[List[JsonDict]],
+        tags: Optional[Dict[str, Dict[str, Any]]],
         account_data: Dict[str, JsonDict],
         always_include: bool = False,
     ):
@@ -1889,7 +1892,7 @@ class SyncHandler(object):
             )
 
         if room_builder.rtype == "joined":
-            unread_notifications = {}  # type: Dict[str, str]
+            unread_notifications = {}  # type: Dict[str, int]
             room_sync = JoinedSyncResult(
                 room_id=room_id,
                 timeline=batch,
@@ -1898,14 +1901,16 @@ class SyncHandler(object):
                 account_data=account_data_events,
                 unread_notifications=unread_notifications,
                 summary=summary,
+                unread_count=0,
             )
 
             if room_sync or always_include:
                 notifs = await self.unread_notifs_for_room_id(room_id, sync_config)
 
-                if notifs is not None:
-                    unread_notifications["notification_count"] = notifs["notify_count"]
-                    unread_notifications["highlight_count"] = notifs["highlight_count"]
+                unread_notifications["notification_count"] = notifs["notify_count"]
+                unread_notifications["highlight_count"] = notifs["highlight_count"]
+
+                room_sync.unread_count = notifs["unread_count"]
 
                 sync_result_builder.joined.append(room_sync)
 
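The un-freezing of JoinedSyncResult is forced by attrs semantics: a frozen class raises on any attribute assignment, and room_sync.unread_count is assigned after construction above. A small demonstration:

    import attr

    @attr.s(slots=True, frozen=True)
    class Frozen:
        count = attr.ib(type=int)

    f = Frozen(count=0)
    try:
        f.count = 1  # rejected: frozen attrs classes are immutable
    except attr.exceptions.FrozenInstanceError:
        print("cannot mutate a frozen attrs class after construction")
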
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index 521b6d620d..e21f8dbc58 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -234,7 +234,7 @@ class UserDirectoryHandler(StateDeltasHandler):
     async def _handle_room_publicity_change(
         self, room_id, prev_event_id, event_id, typ
     ):
-        """Handle a room having potentially changed from/to world_readable/publically
+        """Handle a room having potentially changed from/to world_readable/publicly
         joinable.
 
         Args:
@@ -388,9 +388,15 @@ class UserDirectoryHandler(StateDeltasHandler):
 
         prev_name = prev_event.content.get("displayname")
         new_name = event.content.get("displayname")
+        # If the new name is of an unexpected form, do not update the directory.
+        if not isinstance(new_name, str):
+            new_name = prev_name
 
         prev_avatar = prev_event.content.get("avatar_url")
         new_avatar = event.content.get("avatar_url")
+        # If the new avatar is of an unexpected form, do not update the directory.
+        if not isinstance(new_avatar, str):
+            new_avatar = prev_avatar
 
         if prev_name != new_name or prev_avatar != new_avatar:
             await self.store.update_profile_in_user_dir(user_id, new_name, new_avatar)
diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
index 369bf9c2fc..782d39d4ca 100644
--- a/synapse/http/federation/matrix_federation_agent.py
+++ b/synapse/http/federation/matrix_federation_agent.py
@@ -134,8 +134,8 @@ class MatrixFederationAgent(object):
             and not _is_ip_literal(parsed_uri.hostname)
             and not parsed_uri.port
         ):
-            well_known_result = yield self._well_known_resolver.get_well_known(
-                parsed_uri.hostname
+            well_known_result = yield defer.ensureDeferred(
+                self._well_known_resolver.get_well_known(parsed_uri.hostname)
             )
             delegated_server = well_known_result.delegated_server
 
diff --git a/synapse/http/federation/well_known_resolver.py b/synapse/http/federation/well_known_resolver.py
index e701dcc961..37c29c008a 100644
--- a/synapse/http/federation/well_known_resolver.py
+++ b/synapse/http/federation/well_known_resolver.py
@@ -16,6 +16,7 @@
 import logging
 import random
 import time
+from typing import Callable, Dict, Optional, Tuple
 
 import attr
 
@@ -23,6 +24,7 @@ from twisted.internet import defer
 from twisted.web.client import RedirectAgent, readBody
 from twisted.web.http import stringToDatetime
 from twisted.web.http_headers import Headers
+from twisted.web.iweb import IResponse
 
 from synapse.logging.context import make_deferred_yieldable
 from synapse.util import Clock, json_decoder
@@ -99,15 +101,14 @@ class WellKnownResolver(object):
         self._well_known_agent = RedirectAgent(agent)
         self.user_agent = user_agent
 
-    @defer.inlineCallbacks
-    def get_well_known(self, server_name):
+    async def get_well_known(self, server_name: bytes) -> WellKnownLookupResult:
         """Attempt to fetch and parse a .well-known file for the given server
 
         Args:
-            server_name (bytes): name of the server, from the requested url
+            server_name: name of the server, from the requested url
 
         Returns:
-            Deferred[WellKnownLookupResult]: The result of the lookup
+            The result of the lookup
         """
 
         if server_name == b"kde.org":
@@ -128,7 +129,9 @@ class WellKnownResolver(object):
         # requests for the same server in parallel?
         try:
             with Measure(self._clock, "get_well_known"):
-                result, cache_period = yield self._fetch_well_known(server_name)
+                result, cache_period = await self._fetch_well_known(
+                    server_name
+                )  # type: Tuple[Optional[bytes], float]
 
         except _FetchWellKnownFailure as e:
             if prev_result and e.temporary:
@@ -157,18 +160,17 @@ class WellKnownResolver(object):
 
         return WellKnownLookupResult(delegated_server=result)
 
-    @defer.inlineCallbacks
-    def _fetch_well_known(self, server_name):
+    async def _fetch_well_known(self, server_name: bytes) -> Tuple[bytes, float]:
         """Actually fetch and parse a .well-known, without checking the cache
 
         Args:
-            server_name (bytes): name of the server, from the requested url
+            server_name: name of the server, from the requested url
 
         Raises:
             _FetchWellKnownFailure if we fail to lookup a result
 
         Returns:
-            Deferred[Tuple[bytes,int]]: The lookup result and cache period.
+            The lookup result and cache period.
         """
 
         had_valid_well_known = self._had_valid_well_known_cache.get(server_name, False)
@@ -176,7 +178,7 @@ class WellKnownResolver(object):
         # We do this in two steps to differentiate between possibly transient
         # errors (e.g. can't connect to host, 503 response) and more permanent
         # errors (such as getting a 404 response).
-        response, body = yield self._make_well_known_request(
+        response, body = await self._make_well_known_request(
             server_name, retry=had_valid_well_known
         )
 
@@ -219,20 +221,20 @@ class WellKnownResolver(object):
 
         return result, cache_period
 
-    @defer.inlineCallbacks
-    def _make_well_known_request(self, server_name, retry):
+    async def _make_well_known_request(
+        self, server_name: bytes, retry: bool
+    ) -> Tuple[IResponse, bytes]:
         """Make the well known request.
 
         This will retry the request if requested and it fails (e.g. due to a
         connection error or a 5xx response).
 
         Args:
-            server_name (bytes)
-            retry (bool): Whether to retry the request if it fails.
+            server_name: name of the server, from the requested url
+            retry: Whether to retry the request if it fails.
 
         Returns:
-            Deferred[tuple[IResponse, bytes]] Returns the response object and
-            body. Response may be a non-200 response.
+            The response object and body. The response may be a non-200 response.
         """
         uri = b"https://%s/.well-known/matrix/server" % (server_name,)
         uri_str = uri.decode("ascii")
@@ -247,12 +249,12 @@ class WellKnownResolver(object):
 
             logger.info("Fetching %s", uri_str)
             try:
-                response = yield make_deferred_yieldable(
+                response = await make_deferred_yieldable(
                     self._well_known_agent.request(
                         b"GET", uri, headers=Headers(headers)
                     )
                 )
-                body = yield make_deferred_yieldable(readBody(response))
+                body = await make_deferred_yieldable(readBody(response))
 
                 if 500 <= response.code < 600:
                     raise Exception("Non-200 response %s" % (response.code,))
@@ -269,21 +271,24 @@ class WellKnownResolver(object):
                 logger.info("Error fetching %s: %s. Retrying", uri_str, e)
 
             # Sleep briefly in the hopes that they come back up
-            yield self._clock.sleep(0.5)
+            await self._clock.sleep(0.5)
 
 
-def _cache_period_from_headers(headers, time_now=time.time):
+def _cache_period_from_headers(
+    headers: Headers, time_now: Callable[[], float] = time.time
+) -> Optional[float]:
     cache_controls = _parse_cache_control(headers)
 
     if b"no-store" in cache_controls:
         return 0
 
     if b"max-age" in cache_controls:
-        try:
-            max_age = int(cache_controls[b"max-age"])
-            return max_age
-        except ValueError:
-            pass
+        max_age = cache_controls[b"max-age"]
+        if max_age:
+            try:
+                return int(max_age)
+            except ValueError:
+                pass
 
     expires = headers.getRawHeaders(b"expires")
     if expires is not None:
@@ -299,7 +304,7 @@ def _cache_period_from_headers(headers, time_now=time.time):
     return None
 
 
-def _parse_cache_control(headers):
+def _parse_cache_control(headers: Headers) -> Dict[bytes, Optional[bytes]]:
     cache_controls = {}
     for hdr in headers.getRawHeaders(b"cache-control", []):
         for directive in hdr.split(b","):
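
_parse_cache_control is cut off by the hunk boundary here, but the overall header handling is simple enough to restate standalone. A sketch assuming plain bytes input rather than a twisted Headers object (hypothetical function names):

    from typing import Dict, Optional

    def parse_cache_control(header: bytes) -> Dict[bytes, Optional[bytes]]:
        # b"no-store, max-age=60" -> {b"no-store": None, b"max-age": b"60"}
        cache_controls = {}
        for directive in header.split(b","):
            splits = [x.strip() for x in directive.split(b"=", 1)]
            cache_controls[splits[0].lower()] = splits[1] if len(splits) > 1 else None
        return cache_controls

    def cache_period(header: bytes) -> Optional[float]:
        controls = parse_cache_control(header)
        if b"no-store" in controls:
            return 0
        max_age = controls.get(b"max-age")
        if max_age:  # guard against a bare "max-age" with no value
            try:
                return int(max_age)
            except ValueError:
                pass
        return None

    assert cache_period(b"max-age=3600") == 3600
    assert cache_period(b"no-store, max-age=60") == 0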
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index e7fcee0e87..e7fa02b78b 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -19,8 +19,10 @@ from collections import namedtuple
 
 from prometheus_client import Counter
 
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import EventTypes, Membership, RelationTypes
 from synapse.event_auth import get_user_power_level
+from synapse.events import EventBase
+from synapse.events.snapshot import EventContext
 from synapse.state import POWER_KEY
 from synapse.util.async_helpers import Linearizer
 from synapse.util.caches import register_cache
@@ -51,6 +53,48 @@ push_rules_delta_state_cache_metric = register_cache(
 )
 
 
+STATE_EVENT_TYPES_TO_MARK_UNREAD = {
+    EventTypes.Topic,
+    EventTypes.Name,
+    EventTypes.RoomAvatar,
+    EventTypes.Tombstone,
+}
+
+
+def _should_count_as_unread(event: EventBase, context: EventContext) -> bool:
+    # Exclude rejected and soft-failed events.
+    if context.rejected or event.internal_metadata.is_soft_failed():
+        return False
+
+    # Exclude notices.
+    if (
+        not event.is_state()
+        and event.type == EventTypes.Message
+        and event.content.get("msgtype") == "m.notice"
+    ):
+        return False
+
+    # Exclude edits.
+    relates_to = event.content.get("m.relates_to", {})
+    if relates_to.get("rel_type") == RelationTypes.REPLACE:
+        return False
+
+    # Mark events that have a non-empty string body as unread.
+    body = event.content.get("body")
+    if isinstance(body, str) and body:
+        return True
+
+    # Mark some state events as unread.
+    if event.is_state() and event.type in STATE_EVENT_TYPES_TO_MARK_UNREAD:
+        return True
+
+    # Mark encrypted events as unread.
+    if not event.is_state() and event.type == EventTypes.Encrypted:
+        return True
+
+    return False
+
+
 class BulkPushRuleEvaluator(object):
     """Calculates the outcome of push rules for an event for all users in the
     room at once.
@@ -133,9 +177,12 @@ class BulkPushRuleEvaluator(object):
         return pl_event.content if pl_event else {}, sender_level
 
     async def action_for_event_by_user(self, event, context) -> None:
-        """Given an event and context, evaluate the push rules and insert the
-        results into the event_push_actions_staging table.
+        """Given an event and context, evaluate the push rules, check if the message
+        should increment the unread count, and insert the results into the
+        event_push_actions_staging table.
         """
+        count_as_unread = _should_count_as_unread(event, context)
+
         rules_by_user = await self._get_rules_for_event(event, context)
         actions_by_user = {}
 
@@ -172,6 +219,8 @@ class BulkPushRuleEvaluator(object):
                 if event.type == EventTypes.Member and event.state_key == uid:
                     display_name = event.content.get("displayname", None)
 
+            actions_by_user[uid] = []
+
             for rule in rules:
                 if "enabled" in rule and not rule["enabled"]:
                     continue
@@ -189,7 +238,9 @@ class BulkPushRuleEvaluator(object):
         # Mark in the DB staging area the push actions for users who should be
         # notified for this event. (This will then get handled when we persist
         # the event)
-        await self.store.add_push_actions_to_staging(event.event_id, actions_by_user)
+        await self.store.add_push_actions_to_staging(
+            event.event_id, actions_by_user, count_as_unread,
+        )
 
 
 def _condition_checker(evaluator, conditions, uid, display_name, cache):
@@ -369,8 +420,8 @@ class RulesForRoom(object):
         Args:
             ret_rules_by_user (dict): Partially filled dict of push rules. Gets
                 updated with any new rules.
-            member_event_ids (list): List of event ids for membership events that
-                have happened since the last time we filled rules_by_user
+            member_event_ids (dict): Dict of user id to event id for membership events
+                that have happened since the last time we filled rules_by_user
             state_group: The state group we are currently computing push rules
                 for. Used when updating the cache.
         """
@@ -390,34 +441,19 @@ class RulesForRoom(object):
         if logger.isEnabledFor(logging.DEBUG):
             logger.debug("Found members %r: %r", self.room_id, members.values())
 
-        interested_in_user_ids = {
+        user_ids = {
             user_id
             for user_id, membership in members.values()
             if membership == Membership.JOIN
         }
 
-        logger.debug("Joined: %r", interested_in_user_ids)
-
-        if_users_with_pushers = await self.store.get_if_users_have_pushers(
-            interested_in_user_ids, on_invalidate=self.invalidate_all_cb
-        )
-
-        user_ids = {
-            uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher
-        }
-
-        logger.debug("With pushers: %r", user_ids)
-
-        users_with_receipts = await self.store.get_users_with_read_receipts_in_room(
-            self.room_id, on_invalidate=self.invalidate_all_cb
-        )
-
-        logger.debug("With receipts: %r", users_with_receipts)
+        logger.debug("Joined: %r", user_ids)
 
-        # any users with pushers must be ours: they have pushers
-        for uid in users_with_receipts:
-            if uid in interested_in_user_ids:
-                user_ids.add(uid)
+        # Previously we only considered users with pushers or read receipts in that
+        # room. We can't do this any more, because unread counts are now calculated
+        # from push actions, which don't depend on the user having a pusher or
+        # having sent a read receipt into the room. So we just filter for local
+        # users here.
+        user_ids = list(filter(self.is_mine_id, user_ids))
 
         rules_by_user = await self.store.bulk_get_push_rules(
             user_ids, on_invalidate=self.invalidate_all_cb
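
Taken together, _should_count_as_unread above defines which events bump the
MSC2654 per-room unread counter. The following is a minimal standalone sketch of
the same rules, using plain dicts in place of EventBase/EventContext; it is an
illustration only, and omits the rejected/soft-failed check, which needs a real
event context:

    STATE_TYPES_TO_MARK_UNREAD = {
        "m.room.topic",
        "m.room.name",
        "m.room.avatar",
        "m.room.tombstone",
    }


    def should_count_as_unread(event: dict) -> bool:
        content = event.get("content", {})
        is_state = "state_key" in event

        # Exclude notices.
        if (
            not is_state
            and event["type"] == "m.room.message"
            and content.get("msgtype") == "m.notice"
        ):
            return False

        # Exclude edits (m.replace relations).
        if content.get("m.relates_to", {}).get("rel_type") == "m.replace":
            return False

        # Non-empty string bodies count, as do a few visible state changes
        # and encrypted events (whose bodies we cannot inspect).
        body = content.get("body")
        if isinstance(body, str) and body:
            return True
        if is_state and event["type"] in STATE_TYPES_TO_MARK_UNREAD:
            return True
        if not is_state and event["type"] == "m.room.encrypted":
            return True

        return False


    assert should_count_as_unread(
        {"type": "m.room.message", "content": {"msgtype": "m.text", "body": "hi"}}
    )
    assert not should_count_as_unread(
        {"type": "m.room.message", "content": {"msgtype": "m.notice", "body": "beep"}}
    )
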
diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py
index d0145666bf..f7a25571f3 100644
--- a/synapse/push/push_tools.py
+++ b/synapse/push/push_tools.py
@@ -36,7 +36,7 @@ async def get_badge_count(store, user_id):
             )
             # return one badge count per conversation, as count per
             # message is so noisy as to be almost useless
-            badge += 1 if notifs["notify_count"] else 0
+            badge += 1 if notifs["unread_count"] else 0
     return badge
 
 
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index dd77a44b8d..2d995ec456 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -66,7 +66,9 @@ REQUIREMENTS = [
     "msgpack>=0.5.2",
     "phonenumbers>=8.2.0",
     "prometheus_client>=0.0.18,<0.9.0",
-    # we use attr.validators.deep_iterable, which arrived in 19.1.0
+    # we use attr.validators.deep_iterable, which arrived in 19.1.0 (Note:
+    # Fedora 31 only has 19.1, so if we want to upgrade we should wait until 33
+    # is out in November.)
     "attrs>=19.1.0",
     "netaddr>=0.7.18",
     "Jinja2>=2.9",
diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py
index 6b56315148..5c8be747e1 100644
--- a/synapse/replication/http/federation.py
+++ b/synapse/replication/http/federation.py
@@ -65,10 +65,11 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
         self.federation_handler = hs.get_handlers().federation_handler
 
     @staticmethod
-    async def _serialize_payload(store, event_and_contexts, backfilled):
+    async def _serialize_payload(store, room_id, event_and_contexts, backfilled):
         """
         Args:
             store
+            room_id (str)
             event_and_contexts (list[tuple[FrozenEvent, EventContext]])
             backfilled (bool): Whether or not the events are the result of
                 backfilling
@@ -88,7 +89,11 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
                 }
             )
 
-        payload = {"events": event_payloads, "backfilled": backfilled}
+        payload = {
+            "events": event_payloads,
+            "backfilled": backfilled,
+            "room_id": room_id,
+        }
 
         return payload
 
@@ -96,6 +101,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
         with Measure(self.clock, "repl_fed_send_events_parse"):
             content = parse_json_object_from_request(request)
 
+            room_id = content["room_id"]
             backfilled = content["backfilled"]
 
             event_payloads = content["events"]
@@ -120,7 +126,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
         logger.info("Got %d events from federation", len(event_and_contexts))
 
         max_stream_id = await self.federation_handler.persist_events_and_notify(
-            event_and_contexts, backfilled
+            room_id, event_and_contexts, backfilled
         )
 
         return 200, {"max_stream_id": max_stream_id}
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index 596c72eb92..3b788c9625 100644
--- a/synapse/replication/slave/storage/devices.py
+++ b/synapse/replication/slave/storage/devices.py
@@ -48,6 +48,9 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
             "DeviceListFederationStreamChangeCache", device_list_max
         )
 
+    def get_device_stream_token(self) -> int:
+        return self._device_list_id_gen.get_current_token()
+
     def process_replication_rows(self, stream_name, instance_name, token, rows):
         if stream_name == DeviceListsStream.NAME:
             self._device_list_id_gen.advance(instance_name, token)
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index fcf8ebf1e7..d6ecf5b327 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -14,7 +14,6 @@
 # limitations under the License.
 """A replication client for use by synapse workers.
 """
-import heapq
 import logging
 from typing import TYPE_CHECKING, Dict, List, Tuple
 
@@ -219,9 +218,8 @@ class ReplicationDataHandler:
 
         waiting_list = self._streams_to_waiters.setdefault(stream_name, [])
 
-        # We insert into the list using heapq as it is more efficient than
-        # pushing then resorting each time.
-        heapq.heappush(waiting_list, (position, deferred))
+        waiting_list.append((position, deferred))
+        waiting_list.sort(key=lambda t: t[0])
 
         # We measure here to get in flight counts and average waiting time.
         with Measure(self._clock, "repl.wait_for_stream_position"):
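
The waiting list is now kept ordered by appending and re-sorting instead of
heap-pushing. A toy sketch of the pattern (names are illustrative): because the
list is nearly sorted on every insert, Timsort makes the re-sort cheap, and
unlike a heap the list can be scanned in order without destructive pops.

    from typing import Any, List, Tuple

    waiting_list: List[Tuple[int, Any]] = []


    def add_waiter(position: int, deferred: Any) -> None:
        # Append then sort on position; the nearly-sorted input keeps this
        # close to linear time in practice.
        waiting_list.append((position, deferred))
        waiting_list.sort(key=lambda t: t[0])


    add_waiter(7, "later")
    add_waiter(3, "sooner")
    assert [pos for pos, _ in waiting_list] == [3, 7]
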
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 1c303f3a46..b323841f73 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -109,7 +109,7 @@ class ReplicationCommandHandler:
             if isinstance(stream, (EventsStream, BackfillStream)):
                 # Only add EventStream and BackfillStream as a source on the
                 # instance in charge of event persistence.
-                if hs.config.worker.writers.events == hs.get_instance_name():
+                if hs.get_instance_name() in hs.config.worker.writers.events:
                     self._streams_to_replicate.append(stream)
 
                 continue
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index 16c63ff4ec..3705618b4f 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -19,7 +19,7 @@ from typing import List, Tuple, Type
 
 import attr
 
-from ._base import Stream, StreamUpdateResult, Token, current_token_without_instance
+from ._base import Stream, StreamUpdateResult, Token
 
 """Handling of the 'events' replication stream
 
@@ -117,7 +117,7 @@ class EventsStream(Stream):
         self._store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
-            current_token_without_instance(self._store.get_current_events_token),
+            self._store._stream_id_gen.get_current_token_for_writer,
             self._update_function,
         )
 
diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py
index 46e458e95b..87f927890c 100644
--- a/synapse/rest/__init__.py
+++ b/synapse/rest/__init__.py
@@ -50,6 +50,7 @@ from synapse.rest.client.v2_alpha import (
     room_keys,
     room_upgrade_rest_servlet,
     sendtodevice,
+    shared_rooms,
     sync,
     tags,
     thirdparty,
@@ -125,3 +126,6 @@ class ClientRestResource(JsonResource):
         synapse.rest.admin.register_servlets_for_client_rest_resource(
             hs, client_resource
         )
+
+        # unstable
+        shared_rooms.register_servlets(hs, client_resource)
diff --git a/synapse/rest/client/v2_alpha/shared_rooms.py b/synapse/rest/client/v2_alpha/shared_rooms.py
new file mode 100644
index 0000000000..2492634dac
--- /dev/null
+++ b/synapse/rest/client/v2_alpha/shared_rooms.py
@@ -0,0 +1,68 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Half-Shot
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from synapse.api.errors import Codes, SynapseError
+from synapse.http.servlet import RestServlet
+from synapse.types import UserID
+
+from ._base import client_patterns
+
+logger = logging.getLogger(__name__)
+
+
+class UserSharedRoomsServlet(RestServlet):
+    """
+    GET /uk.half-shot.msc2666/user/shared_rooms/{user_id} HTTP/1.1
+    """
+
+    PATTERNS = client_patterns(
+        "/uk.half-shot.msc2666/user/shared_rooms/(?P<user_id>[^/]*)",
+        releases=(),  # This is an unstable feature
+    )
+
+    def __init__(self, hs):
+        super(UserSharedRoomsServlet, self).__init__()
+        self.auth = hs.get_auth()
+        self.store = hs.get_datastore()
+        self.user_directory_active = hs.config.update_user_directory
+
+    async def on_GET(self, request, user_id):
+
+        if not self.user_directory_active:
+            raise SynapseError(
+                code=400,
+                msg="The user directory is disabled on this server. Cannot determine shared rooms.",
+                errcode=Codes.FORBIDDEN,
+            )
+
+        UserID.from_string(user_id)
+
+        requester = await self.auth.get_user_by_req(request)
+        if user_id == requester.user.to_string():
+            raise SynapseError(
+                code=400,
+                msg="You cannot request a list of shared rooms with yourself",
+                errcode=Codes.FORBIDDEN,
+            )
+        rooms = await self.store.get_shared_rooms_for_users(
+            requester.user.to_string(), user_id
+        )
+
+        return 200, {"joined": list(rooms)}
+
+
+def register_servlets(hs, http_server):
+    UserSharedRoomsServlet(hs).register(http_server)
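
Because the servlet is registered with releases=(), the endpoint is only
reachable under the unstable prefix. A hypothetical client call (the homeserver
URL, access token and user ID below are made up):

    import requests

    HOMESERVER = "https://matrix.example.org"  # hypothetical
    ACCESS_TOKEN = "syt_exampletoken"  # hypothetical

    resp = requests.get(
        HOMESERVER
        + "/_matrix/client/unstable"
        + "/uk.half-shot.msc2666/user/shared_rooms/@alice:example.org",
        headers={"Authorization": "Bearer " + ACCESS_TOKEN},
    )
    resp.raise_for_status()
    print(resp.json())  # e.g. {"joined": ["!abc123:example.org"]}
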
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index 96488b131a..a0b00135e1 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -425,6 +425,7 @@ class SyncRestServlet(RestServlet):
             result["ephemeral"] = {"events": ephemeral_events}
             result["unread_notifications"] = room.unread_notifications
             result["summary"] = room.summary
+            result["org.matrix.msc2654.unread_count"] = room.unread_count
 
         return result
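
Each joined-room entry in the sync response now exposes the unread count under
an unstable MSC2654-prefixed key. An illustrative fragment of what a client
sees (the counts are made up):

    room_result = {
        "unread_notifications": {"notification_count": 2, "highlight_count": 0},
        "summary": {},  # elided
        "org.matrix.msc2654.unread_count": 5,
    }
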
 
diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index 0d668df0b6..24ac57f35d 100644
--- a/synapse/rest/client/versions.py
+++ b/synapse/rest/client/versions.py
@@ -60,6 +60,8 @@ class VersionsRestServlet(RestServlet):
                     "org.matrix.e2e_cross_signing": True,
                     # Implements additional endpoints as described in MSC2432
                     "org.matrix.msc2432": True,
+                    # Implements additional endpoints as described in MSC2666
+                    "uk.half-shot.msc2666": True,
                 },
             },
         )
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 0db900fa0e..67a89cd51a 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -433,7 +433,7 @@ class BackgroundUpdater(object):
             "background_updates", keyvalues={"update_name": update_name}
         )
 
-    def _background_update_progress(self, update_name: str, progress: dict):
+    async def _background_update_progress(self, update_name: str, progress: dict):
         """Update the progress of a background update
 
         Args:
@@ -441,7 +441,7 @@ class BackgroundUpdater(object):
             progress: The progress of the update.
         """
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "background_update_progress",
             self._background_update_progress_txn,
             update_name,
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 7ab370efef..78ca6d8346 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -28,6 +28,7 @@ from typing import (
     Optional,
     Tuple,
     TypeVar,
+    cast,
     overload,
 )
 
@@ -35,7 +36,6 @@ from prometheus_client import Histogram
 from typing_extensions import Literal
 
 from twisted.enterprise import adbapi
-from twisted.internet import defer
 
 from synapse.api.errors import StoreError
 from synapse.config.database import DatabaseConnectionConfig
@@ -507,8 +507,9 @@ class DatabasePool(object):
             self._txn_perf_counters.update(desc, duration)
             sql_txn_timer.labels(desc).observe(duration)
 
-    @defer.inlineCallbacks
-    def runInteraction(self, desc: str, func: Callable, *args: Any, **kwargs: Any):
+    async def runInteraction(
+        self, desc: str, func: "Callable[..., R]", *args: Any, **kwargs: Any
+    ) -> R:
         """Starts a transaction on the database and runs a given function
 
         Arguments:
@@ -521,7 +522,7 @@ class DatabasePool(object):
             kwargs: named args to pass to `func`
 
         Returns:
-            Deferred: The result of func
+            The result of func
         """
         after_callbacks = []  # type: List[_CallbackListEntry]
         exception_callbacks = []  # type: List[_CallbackListEntry]
@@ -530,16 +531,14 @@ class DatabasePool(object):
             logger.warning("Starting db txn '%s' from sentinel context", desc)
 
         try:
-            result = yield defer.ensureDeferred(
-                self.runWithConnection(
-                    self.new_transaction,
-                    desc,
-                    after_callbacks,
-                    exception_callbacks,
-                    func,
-                    *args,
-                    **kwargs
-                )
+            result = await self.runWithConnection(
+                self.new_transaction,
+                desc,
+                after_callbacks,
+                exception_callbacks,
+                func,
+                *args,
+                **kwargs
             )
 
             for after_callback, after_args, after_kwargs in after_callbacks:
@@ -549,7 +548,7 @@ class DatabasePool(object):
                 after_callback(*after_args, **after_kwargs)
             raise
 
-        return result
+        return cast(R, result)
 
     async def runWithConnection(
         self, func: "Callable[..., R]", *args: Any, **kwargs: Any
@@ -604,6 +603,18 @@ class DatabasePool(object):
         results = [dict(zip(col_headers, row)) for row in cursor]
         return results
 
+    @overload
+    async def execute(
+        self, desc: str, decoder: Literal[None], query: str, *args: Any
+    ) -> List[Tuple[Any, ...]]:
+        ...
+
+    @overload
+    async def execute(
+        self, desc: str, decoder: Callable[[Cursor], R], query: str, *args: Any
+    ) -> R:
+        ...
+
     async def execute(
         self,
         desc: str,
@@ -1088,6 +1099,28 @@ class DatabasePool(object):
             desc, self.simple_select_one_txn, table, keyvalues, retcols, allow_none
         )
 
+    @overload
+    async def simple_select_one_onecol(
+        self,
+        table: str,
+        keyvalues: Dict[str, Any],
+        retcol: Iterable[str],
+        allow_none: Literal[False] = False,
+        desc: str = "simple_select_one_onecol",
+    ) -> Any:
+        ...
+
+    @overload
+    async def simple_select_one_onecol(
+        self,
+        table: str,
+        keyvalues: Dict[str, Any],
+        retcol: Iterable[str],
+        allow_none: Literal[True] = True,
+        desc: str = "simple_select_one_onecol",
+    ) -> Optional[Any]:
+        ...
+
     async def simple_select_one_onecol(
         self,
         table: str,
@@ -1116,6 +1149,30 @@ class DatabasePool(object):
             allow_none=allow_none,
         )
 
+    @overload
+    @classmethod
+    def simple_select_one_onecol_txn(
+        cls,
+        txn: LoggingTransaction,
+        table: str,
+        keyvalues: Dict[str, Any],
+        retcol: Iterable[str],
+        allow_none: Literal[False] = False,
+    ) -> Any:
+        ...
+
+    @overload
+    @classmethod
+    def simple_select_one_onecol_txn(
+        cls,
+        txn: LoggingTransaction,
+        table: str,
+        keyvalues: Dict[str, Any],
+        retcol: Iterable[str],
+        allow_none: Literal[True] = True,
+    ) -> Optional[Any]:
+        ...
+
     @classmethod
     def simple_select_one_onecol_txn(
         cls,
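
The overloads above lean on typing_extensions.Literal so that mypy can pick the
return type from the allow_none argument. A self-contained sketch of the same
pattern, with simplified signatures and an illustrative function name:

    from typing import Any, Optional, overload

    from typing_extensions import Literal


    @overload
    def select_one(allow_none: Literal[False] = False) -> Any:
        ...


    @overload
    def select_one(allow_none: Literal[True] = True) -> Optional[Any]:
        ...


    def select_one(allow_none: bool = False) -> Optional[Any]:
        row: Optional[dict] = {"name": "alice"}  # stand-in for a DB lookup
        if row is None and not allow_none:
            raise RuntimeError("not found")
        return row


    found = select_one()  # mypy infers Any: never None
    maybe = select_one(allow_none=True)  # mypy infers Optional[Any]
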
diff --git a/synapse/storage/databases/__init__.py b/synapse/storage/databases/__init__.py
index 0ac854aee2..c73d54fb67 100644
--- a/synapse/storage/databases/__init__.py
+++ b/synapse/storage/databases/__init__.py
@@ -68,7 +68,7 @@ class Databases(object):
 
                     # If we're on a process that can persist events also
                     # instantiate a `PersistEventsStore`
-                    if hs.config.worker.writers.events == hs.get_instance_name():
+                    if hs.get_instance_name() in hs.config.worker.writers.events:
                         persist_events = PersistEventsStore(hs, database, main)
 
                 if "state" in database_config.databases:
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 70cf15dd7f..99890ffbf3 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -18,7 +18,7 @@
 import calendar
 import logging
 import time
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Tuple
 
 from synapse.api.constants import PresenceState
 from synapse.config.homeserver import HomeServerConfig
@@ -264,6 +264,9 @@ class DataStore(
         # Used in _generate_user_daily_visits to keep track of progress
         self._last_user_visit_update = self._get_start_of_day()
 
+    def get_device_stream_token(self) -> int:
+        return self._device_list_id_gen.get_current_token()
+
     def take_presence_startup_info(self):
         active_on_startup = self._presence_on_startup
         self._presence_on_startup = None
@@ -291,16 +294,16 @@ class DataStore(
 
         return [UserPresenceState(**row) for row in rows]
 
-    def count_daily_users(self):
+    async def count_daily_users(self) -> int:
         """
         Counts the number of users who used this homeserver in the last 24 hours.
         """
         yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24)
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "count_daily_users", self._count_users, yesterday
         )
 
-    def count_monthly_users(self):
+    async def count_monthly_users(self) -> int:
         """
         Counts the number of users who used this homeserver in the last 30 days.
         Note this method is intended for phonehome metrics only and is different
@@ -308,7 +311,7 @@ class DataStore(
         amongst other things, includes a 3 day grace period before a user counts.
         """
         thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "count_monthly_users", self._count_users, thirty_days_ago
         )
 
@@ -327,15 +330,15 @@ class DataStore(
         (count,) = txn.fetchone()
         return count
 
-    def count_r30_users(self):
+    async def count_r30_users(self) -> Dict[str, int]:
         """
         Counts the number of 30 day retained users, defined as:-
          * Users who have created their accounts more than 30 days ago
          * Where last seen at most 30 days ago
          * Where account creation and last_seen are > 30 days apart
 
-         Returns counts globaly for a given user as well as breaking
-         by platform
+        Returns:
+             A mapping of counts globally as well as broken out by platform.
         """
 
         def _count_r30_users(txn):
@@ -408,7 +411,7 @@ class DataStore(
 
             return results
 
-        return self.db_pool.runInteraction("count_r30_users", _count_r30_users)
+        return await self.db_pool.runInteraction("count_r30_users", _count_r30_users)
 
     def _get_start_of_day(self):
         """
@@ -418,7 +421,7 @@ class DataStore(
         today_start = calendar.timegm((now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0))
         return today_start * 1000
 
-    def generate_user_daily_visits(self):
+    async def generate_user_daily_visits(self) -> None:
         """
         Generates daily visit data for use in cohort/ retention analysis
         """
@@ -473,7 +476,7 @@ class DataStore(
             # frequently
             self._last_user_visit_update = now
 
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "generate_user_daily_visits", _generate_user_daily_visits
         )
 
@@ -497,22 +500,28 @@ class DataStore(
             desc="get_users",
         )
 
-    def get_users_paginate(
-        self, start, limit, user_id=None, name=None, guests=True, deactivated=False
-    ):
+    async def get_users_paginate(
+        self,
+        start: int,
+        limit: int,
+        user_id: Optional[str] = None,
+        name: Optional[str] = None,
+        guests: bool = True,
+        deactivated: bool = False,
+    ) -> Tuple[List[Dict[str, Any]], int]:
         """Function to retrieve a paginated list of users from
         users list. This will return a json list of users and the
         total number of users matching the filter criteria.
 
         Args:
-            start (int): start number to begin the query from
-            limit (int): number of rows to retrieve
-            user_id (string): search for user_id. ignored if name is not None
-            name (string): search for local part of user_id or display name
-            guests (bool): whether to in include guest users
-            deactivated (bool): whether to include deactivated users
+            start: start number to begin the query from
+            limit: number of rows to retrieve
+            user_id: search for user_id. Ignored if name is not None
+            name: search for local part of user_id or display name
+            guests: whether to include guest users
+            deactivated: whether to include deactivated users
         Returns:
-            defer.Deferred: resolves to list[dict[str, Any]], int
+            A tuple of a list of dicts, each describing one user, and the total
+            count of users matching the filter criteria.
         """
 
         def get_users_paginate_txn(txn):
@@ -555,7 +564,7 @@ class DataStore(
             users = self.db_pool.cursor_to_dict(txn)
             return users, count
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_users_paginate_txn", get_users_paginate_txn
         )
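
The conversions in this file all follow one shape: a method that used to hand
back the Deferred from runInteraction becomes a coroutine that awaits it and
gains concrete annotations. A runnable schematic of the before/after, with a
stand-in pool and an illustrative store method:

    import asyncio
    from typing import Any, Callable


    class FakePool:
        """Stand-in for DatabasePool, just enough to run this example."""

        async def runInteraction(self, desc: str, func: Callable[..., Any]) -> Any:
            return func(None)  # the real pool supplies a transaction object


    class WidgetStore:
        def __init__(self) -> None:
            self.db_pool = FakePool()

        # Before: `def count_widgets(self): return self.db_pool.runInteraction(...)`
        # handed back an awaitable with an opaque type. After the conversion:
        async def count_widgets(self) -> int:
            return await self.db_pool.runInteraction(
                "count_widgets", self._count_widgets_txn
            )

        @staticmethod
        def _count_widgets_txn(txn: Any) -> int:
            return 42  # a real txn function would run SQL here


    loop = asyncio.get_event_loop()
    print(loop.run_until_complete(WidgetStore().count_widgets()))  # -> 42
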
 
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 04042a2c98..4436b1a83d 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -16,9 +16,7 @@
 
 import abc
 import logging
-from typing import List, Optional, Tuple
-
-from twisted.internet import defer
+from typing import Dict, List, Optional, Tuple
 
 from synapse.storage._base import SQLBaseStore, db_to_json
 from synapse.storage.database import DatabasePool
@@ -58,14 +56,16 @@ class AccountDataWorkerStore(SQLBaseStore):
         raise NotImplementedError()
 
     @cached()
-    def get_account_data_for_user(self, user_id):
+    async def get_account_data_for_user(
+        self, user_id: str
+    ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
         """Get all the client account_data for a user.
 
         Args:
-            user_id(str): The user to get the account_data for.
+            user_id: The user to get the account_data for.
         Returns:
-            A deferred pair of a dict of global account_data and a dict
-            mapping from room_id string to per room account_data dicts.
+            A 2-tuple of a dict of global account_data and a dict mapping from
+            room_id string to per room account_data dicts.
         """
 
         def get_account_data_for_user_txn(txn):
@@ -94,7 +94,7 @@ class AccountDataWorkerStore(SQLBaseStore):
 
             return global_account_data, by_room
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_account_data_for_user", get_account_data_for_user_txn
         )
 
@@ -120,14 +120,16 @@ class AccountDataWorkerStore(SQLBaseStore):
             return None
 
     @cached(num_args=2)
-    def get_account_data_for_room(self, user_id, room_id):
+    async def get_account_data_for_room(
+        self, user_id: str, room_id: str
+    ) -> Dict[str, JsonDict]:
         """Get all the client account_data for a user for a room.
 
         Args:
-            user_id(str): The user to get the account_data for.
-            room_id(str): The room to get the account_data for.
+            user_id: The user to get the account_data for.
+            room_id: The room to get the account_data for.
         Returns:
-            A deferred dict of the room account_data
+            A dict of the room account_data
         """
 
         def get_account_data_for_room_txn(txn):
@@ -142,21 +144,22 @@ class AccountDataWorkerStore(SQLBaseStore):
                 row["account_data_type"]: db_to_json(row["content"]) for row in rows
             }
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_account_data_for_room", get_account_data_for_room_txn
         )
 
     @cached(num_args=3, max_entries=5000)
-    def get_account_data_for_room_and_type(self, user_id, room_id, account_data_type):
+    async def get_account_data_for_room_and_type(
+        self, user_id: str, room_id: str, account_data_type: str
+    ) -> Optional[JsonDict]:
         """Get the client account_data of given type for a user for a room.
 
         Args:
-            user_id(str): The user to get the account_data for.
-            room_id(str): The room to get the account_data for.
-            account_data_type (str): The account data type to get.
+            user_id: The user to get the account_data for.
+            room_id: The room to get the account_data for.
+            account_data_type: The account data type to get.
         Returns:
-            A deferred of the room account_data for that type, or None if
-            there isn't any set.
+            The room account_data for that type, or None if there isn't any set.
         """
 
         def get_account_data_for_room_and_type_txn(txn):
@@ -174,7 +177,7 @@ class AccountDataWorkerStore(SQLBaseStore):
 
             return db_to_json(content_json) if content_json else None
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn
         )
 
@@ -238,12 +241,14 @@ class AccountDataWorkerStore(SQLBaseStore):
             "get_updated_room_account_data", get_updated_room_account_data_txn
         )
 
-    def get_updated_account_data_for_user(self, user_id, stream_id):
+    async def get_updated_account_data_for_user(
+        self, user_id: str, stream_id: int
+    ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
         """Get all the client account_data for a that's changed for a user
 
         Args:
-            user_id(str): The user to get the account_data for.
-            stream_id(int): The point in the stream since which to get updates
+            user_id: The user to get the account_data for.
+            stream_id: The point in the stream since which to get updates
         Returns:
             A 2-tuple of a dict of global account_data and a dict mapping from
             room_id string to per room account_data dicts.
@@ -277,9 +282,9 @@ class AccountDataWorkerStore(SQLBaseStore):
             user_id, int(stream_id)
         )
         if not changed:
-            return defer.succeed(({}, {}))
+            return ({}, {})
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_updated_account_data_for_user", get_updated_account_data_for_user_txn
         )
 
@@ -416,7 +421,7 @@ class AccountDataStore(AccountDataWorkerStore):
 
         return self._account_data_id_gen.get_current_token()
 
-    def _update_max_stream_id(self, next_id: int):
+    async def _update_max_stream_id(self, next_id: int) -> None:
         """Update the max stream_id
 
         Args:
@@ -435,4 +440,4 @@ class AccountDataStore(AccountDataWorkerStore):
             )
             txn.execute(update_max_id_sql, (next_id, next_id))
 
-        return self.db_pool.runInteraction("update_account_data_max_stream_id", _update)
+        await self.db_pool.runInteraction("update_account_data_max_stream_id", _update)
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index 4e2b2a85ee..d568789124 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -396,7 +396,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
         self._batch_row_update[key] = (user_agent, device_id, now)
 
     @wrap_as_background_process("update_client_ips")
-    def _update_client_ips_batch(self):
+    async def _update_client_ips_batch(self) -> None:
 
         # If the DB pool has already terminated, don't try updating
         if not self.db_pool.is_running():
@@ -405,7 +405,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
         to_update = self._batch_row_update
         self._batch_row_update = {}
 
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
         )
 
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index def96637a2..f8fe948122 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -14,6 +14,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import abc
 import logging
 from typing import Any, Dict, Iterable, List, Optional, Set, Tuple
 
@@ -101,7 +102,7 @@ class DeviceWorkerStore(SQLBaseStore):
             update included in the response), and the list of updates, where
             each update is a pair of EDU type and EDU contents.
         """
-        now_stream_id = self._device_list_id_gen.get_current_token()
+        now_stream_id = self.get_device_stream_token()
 
         has_changed = self._device_list_federation_stream_cache.has_entity_changed(
             destination, int(from_stream_id)
@@ -254,9 +255,7 @@ class DeviceWorkerStore(SQLBaseStore):
             List of objects representing a device update EDU
         """
         devices = (
-            await self.db_pool.runInteraction(
-                "_get_e2e_device_keys_txn",
-                self._get_e2e_device_keys_txn,
+            await self.get_e2e_device_keys_and_signatures(
                 query_map.keys(),
                 include_all_devices=True,
                 include_deleted_devices=True,
@@ -292,17 +291,17 @@ class DeviceWorkerStore(SQLBaseStore):
                 prev_id = stream_id
 
                 if device is not None:
-                    key_json = device.get("key_json", None)
+                    key_json = device.key_json
                     if key_json:
                         result["keys"] = db_to_json(key_json)
 
-                        if "signatures" in device:
-                            for sig_user_id, sigs in device["signatures"].items():
+                        if device.signatures:
+                            for sig_user_id, sigs in device.signatures.items():
                                 result["keys"].setdefault("signatures", {}).setdefault(
                                     sig_user_id, {}
                                 ).update(sigs)
 
-                    device_display_name = device.get("device_display_name", None)
+                    device_display_name = device.display_name
                     if device_display_name:
                         result["device_display_name"] = device_display_name
                 else:
@@ -312,9 +311,9 @@ class DeviceWorkerStore(SQLBaseStore):
 
         return results
 
-    def _get_last_device_update_for_remote_user(
+    async def _get_last_device_update_for_remote_user(
         self, destination: str, user_id: str, from_stream_id: int
-    ):
+    ) -> int:
         def f(txn):
             prev_sent_id_sql = """
                 SELECT coalesce(max(stream_id), 0) as stream_id
@@ -325,12 +324,16 @@ class DeviceWorkerStore(SQLBaseStore):
             rows = txn.fetchall()
             return rows[0][0]
 
-        return self.db_pool.runInteraction("get_last_device_update_for_remote_user", f)
+        return await self.db_pool.runInteraction(
+            "get_last_device_update_for_remote_user", f
+        )
 
-    def mark_as_sent_devices_by_remote(self, destination: str, stream_id: int):
+    async def mark_as_sent_devices_by_remote(
+        self, destination: str, stream_id: int
+    ) -> None:
         """Mark that updates have successfully been sent to the destination.
         """
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "mark_as_sent_devices_by_remote",
             self._mark_as_sent_devices_by_remote_txn,
             destination,
@@ -412,8 +415,10 @@ class DeviceWorkerStore(SQLBaseStore):
             },
         )
 
+    @abc.abstractmethod
     def get_device_stream_token(self) -> int:
-        return self._device_list_id_gen.get_current_token()
+        """Get the current stream id from the _device_list_id_gen"""
+        ...
 
     @trace
     async def get_user_devices_from_cache(
@@ -481,51 +486,6 @@ class DeviceWorkerStore(SQLBaseStore):
             device["device_id"]: db_to_json(device["content"]) for device in devices
         }
 
-    def get_devices_with_keys_by_user(self, user_id: str):
-        """Get all devices (with any device keys) for a user
-
-        Returns:
-            Deferred which resolves to (stream_id, devices)
-        """
-        return self.db_pool.runInteraction(
-            "get_devices_with_keys_by_user",
-            self._get_devices_with_keys_by_user_txn,
-            user_id,
-        )
-
-    def _get_devices_with_keys_by_user_txn(
-        self, txn: LoggingTransaction, user_id: str
-    ) -> Tuple[int, List[JsonDict]]:
-        now_stream_id = self._device_list_id_gen.get_current_token()
-
-        devices = self._get_e2e_device_keys_txn(txn, [(user_id, None)])
-
-        if devices:
-            user_devices = devices[user_id]
-            results = []
-            for device_id, device in user_devices.items():
-                result = {"device_id": device_id}
-
-                key_json = device.get("key_json", None)
-                if key_json:
-                    result["keys"] = db_to_json(key_json)
-
-                    if "signatures" in device:
-                        for sig_user_id, sigs in device["signatures"].items():
-                            result["keys"].setdefault("signatures", {}).setdefault(
-                                sig_user_id, {}
-                            ).update(sigs)
-
-                device_display_name = device.get("device_display_name", None)
-                if device_display_name:
-                    result["device_display_name"] = device_display_name
-
-                results.append(result)
-
-            return now_stream_id, results
-
-        return now_stream_id, []
-
     async def get_users_whose_devices_changed(
         self, from_key: str, user_ids: Iterable[str]
     ) -> Set[str]:
@@ -726,7 +686,7 @@ class DeviceWorkerStore(SQLBaseStore):
             desc="make_remote_user_device_cache_as_stale",
         )
 
-    def mark_remote_user_device_list_as_unsubscribed(self, user_id: str):
+    async def mark_remote_user_device_list_as_unsubscribed(self, user_id: str) -> None:
         """Mark that we no longer track device lists for remote user.
         """
 
@@ -740,7 +700,7 @@ class DeviceWorkerStore(SQLBaseStore):
                 txn, self.get_device_list_last_stream_id_for_remote, (user_id,)
             )
 
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "mark_remote_user_device_list_as_unsubscribed",
             _mark_remote_user_device_list_as_unsubscribed_txn,
         )
@@ -1001,9 +961,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             desc="update_device",
         )
 
-    def update_remote_device_list_cache_entry(
+    async def update_remote_device_list_cache_entry(
         self, user_id: str, device_id: str, content: JsonDict, stream_id: int
-    ):
+    ) -> None:
         """Updates a single device in the cache of a remote user's devicelist.
 
         Note: assumes that we are the only thread that can be updating this user's
@@ -1014,11 +974,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             device_id: ID of the device being updated
             content: new data on this device
             stream_id: the version of the device list
-
-        Returns:
-            Deferred[None]
         """
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "update_remote_device_list_cache_entry",
             self._update_remote_device_list_cache_entry_txn,
             user_id,
@@ -1070,9 +1027,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             lock=False,
         )
 
-    def update_remote_device_list_cache(
+    async def update_remote_device_list_cache(
         self, user_id: str, devices: List[dict], stream_id: int
-    ):
+    ) -> None:
         """Replace the entire cache of the remote user's devices.
 
         Note: assumes that we are the only thread that can be updating this user's
@@ -1082,11 +1039,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             user_id: User to update device list for
             devices: list of device objects supplied over federation
             stream_id: the version of the device list
-
-        Returns:
-            Deferred[None]
         """
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "update_remote_device_list_cache",
             self._update_remote_device_list_cache_txn,
             user_id,
@@ -1096,7 +1050,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
 
     def _update_remote_device_list_cache_txn(
         self, txn: LoggingTransaction, user_id: str, devices: List[dict], stream_id: int
-    ):
+    ) -> None:
         self.db_pool.simple_delete_txn(
             txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id}
         )
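
Making get_device_stream_token abstract works because the master store and the
replication worker track the device-list stream with different id generators,
so each concrete class supplies its own implementation (see the matching
additions in main/__init__.py and replication/slave/storage/devices.py above).
A small sketch of the mixin pattern, with illustrative classes:

    import abc


    class DeviceWorkerMixin:
        """Shared logic that needs the token but doesn't own the generator."""

        @abc.abstractmethod
        def get_device_stream_token(self) -> int:
            """Get the current stream id from the device-list id generator."""
            ...

        def describe(self) -> str:
            # Shared code relies on the concrete store providing the token.
            return "device stream at %d" % (self.get_device_stream_token(),)


    class MasterStore(DeviceWorkerMixin):
        def get_device_stream_token(self) -> int:
            return 7  # the real store consults its id generator here


    print(MasterStore().describe())  # -> device stream at 7
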
diff --git a/synapse/storage/databases/main/directory.py b/synapse/storage/databases/main/directory.py
index 405b5eafa5..e5060d4c46 100644
--- a/synapse/storage/databases/main/directory.py
+++ b/synapse/storage/databases/main/directory.py
@@ -159,9 +159,9 @@ class DirectoryStore(DirectoryWorkerStore):
 
         return room_id
 
-    def update_aliases_for_room(
+    async def update_aliases_for_room(
         self, old_room_id: str, new_room_id: str, creator: Optional[str] = None,
-    ):
+    ) -> None:
         """Repoint all of the aliases for a given room, to a different room.
 
         Args:
@@ -189,6 +189,6 @@ class DirectoryStore(DirectoryWorkerStore):
                 txn, self.get_aliases_for_room, (new_room_id,)
             )
 
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "_update_aliases_for_room_txn", _update_aliases_for_room_txn
         )
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index af0b85e2c9..cc0b15ae07 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -14,8 +14,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import abc
 from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
 
+import attr
 from canonicaljson import encode_canonical_json
 
 from twisted.enterprise.adbapi import Connection
@@ -23,6 +25,7 @@ from twisted.enterprise.adbapi import Connection
 from synapse.logging.opentracing import log_kv, set_tag, trace
 from synapse.storage._base import SQLBaseStore, db_to_json
 from synapse.storage.database import make_in_list_sql_clause
+from synapse.types import JsonDict
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.iterutils import batch_iter
@@ -31,19 +34,67 @@ if TYPE_CHECKING:
     from synapse.handlers.e2e_keys import SignatureListItem
 
 
+@attr.s
+class DeviceKeyLookupResult:
+    """The type returned by get_e2e_device_keys_and_signatures"""
+
+    display_name = attr.ib(type=Optional[str])
+
+    # the key data from e2e_device_keys_json. Typically includes fields like
+    # "algorithm", "keys" (including the curve25519 identity key and the ed25519 signing
+    # key) and "signatures" (a signature of the structure by the ed25519 key)
+    key_json = attr.ib(type=Optional[str])
+
+    # cross-signing sigs
+    signatures = attr.ib(type=Optional[Dict], default=None)
+
+
 class EndToEndKeyWorkerStore(SQLBaseStore):
+    async def get_e2e_device_keys_for_federation_query(
+        self, user_id: str
+    ) -> Tuple[int, List[JsonDict]]:
+        """Get all devices (with any device keys) for a user
+
+        Returns:
+            (stream_id, devices)
+        """
+        now_stream_id = self.get_device_stream_token()
+
+        devices = await self.get_e2e_device_keys_and_signatures([(user_id, None)])
+
+        if devices:
+            user_devices = devices[user_id]
+            results = []
+            for device_id, device in user_devices.items():
+                result = {"device_id": device_id}
+
+                key_json = device.key_json
+                if key_json:
+                    result["keys"] = db_to_json(key_json)
+
+                    if device.signatures:
+                        for sig_user_id, sigs in device.signatures.items():
+                            result["keys"].setdefault("signatures", {}).setdefault(
+                                sig_user_id, {}
+                            ).update(sigs)
+
+                device_display_name = device.display_name
+                if device_display_name:
+                    result["device_display_name"] = device_display_name
+
+                results.append(result)
+
+            return now_stream_id, results
+
+        return now_stream_id, []
+
     @trace
-    async def get_e2e_device_keys(
-        self, query_list, include_all_devices=False, include_deleted_devices=False
-    ):
-        """Fetch a list of device keys.
+    async def get_e2e_device_keys_for_cs_api(
+        self, query_list: List[Tuple[str, Optional[str]]]
+    ) -> Dict[str, Dict[str, JsonDict]]:
+        """Fetch a list of device keys, formatted suitably for the C/S API.
         Args:
             query_list(list): List of pairs of user_ids and device_ids.
-            include_all_devices (bool): whether to include entries for devices
-                that don't have device keys
-            include_deleted_devices (bool): whether to include null entries for
-                devices which no longer exist (but were in the query_list).
-                This option only takes effect if include_all_devices is true.
         Returns:
             Dict mapping from user-id to dict mapping from device_id to
             key data.  The key data will be a dict in the same format as the
@@ -53,13 +104,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
         if not query_list:
             return {}
 
-        results = await self.db_pool.runInteraction(
-            "get_e2e_device_keys",
-            self._get_e2e_device_keys_txn,
-            query_list,
-            include_all_devices,
-            include_deleted_devices,
-        )
+        results = await self.get_e2e_device_keys_and_signatures(query_list)
 
         # Build the result structure, un-jsonify the results, and add the
         # "unsigned" section
@@ -67,13 +112,13 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
         for user_id, device_keys in results.items():
             rv[user_id] = {}
             for device_id, device_info in device_keys.items():
-                r = db_to_json(device_info.pop("key_json"))
+                r = db_to_json(device_info.key_json)
                 r["unsigned"] = {}
-                display_name = device_info["device_display_name"]
+                display_name = device_info.display_name
                 if display_name is not None:
                     r["unsigned"]["device_display_name"] = display_name
-                if "signatures" in device_info:
-                    for sig_user_id, sigs in device_info["signatures"].items():
+                if device_info.signatures:
+                    for sig_user_id, sigs in device_info.signatures.items():
                         r.setdefault("signatures", {}).setdefault(
                             sig_user_id, {}
                         ).update(sigs)
@@ -82,12 +127,45 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
         return rv
 
     @trace
-    def _get_e2e_device_keys_txn(
-        self, txn, query_list, include_all_devices=False, include_deleted_devices=False
-    ):
+    async def get_e2e_device_keys_and_signatures(
+        self,
+        query_list: List[Tuple[str, Optional[str]]],
+        include_all_devices: bool = False,
+        include_deleted_devices: bool = False,
+    ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
+        """Fetch a list of device keys, together with their cross-signatures.
+
+        Args:
+            query_list: List of pairs of user_ids and device_ids. Device id can be None
+                to indicate "all devices for this user"
+
+            include_all_devices: whether to return devices without device keys
+
+            include_deleted_devices: whether to include null entries for
+                devices which no longer exist (but were in the query_list).
+                This option only takes effect if include_all_devices is true.
+
+        Returns:
+            Dict mapping from user-id to dict mapping from device_id to
+            key data.
+        """
         set_tag("include_all_devices", include_all_devices)
         set_tag("include_deleted_devices", include_deleted_devices)
 
+        result = await self.db_pool.runInteraction(
+            "get_e2e_device_keys",
+            self._get_e2e_device_keys_and_signatures_txn,
+            query_list,
+            include_all_devices,
+            include_deleted_devices,
+        )
+
+        log_kv(result)
+        return result
+
+    def _get_e2e_device_keys_and_signatures_txn(
+        self, txn, query_list, include_all_devices=False, include_deleted_devices=False
+    ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
         query_clauses = []
         query_params = []
         signature_query_clauses = []
@@ -119,7 +197,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
 
         sql = (
             "SELECT user_id, device_id, "
-            "    d.display_name AS device_display_name, "
+            "    d.display_name, "
             "    k.key_json"
             " FROM devices d"
             "    %s JOIN e2e_device_keys_json k USING (user_id, device_id)"
@@ -130,13 +208,14 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
         )
 
         txn.execute(sql, query_params)
-        rows = self.db_pool.cursor_to_dict(txn)
 
-        result = {}
-        for row in rows:
+        result = {}  # type: Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]
+        for (user_id, device_id, display_name, key_json) in txn:
             if include_deleted_devices:
-                deleted_devices.remove((row["user_id"], row["device_id"]))
-            result.setdefault(row["user_id"], {})[row["device_id"]] = row
+                deleted_devices.remove((user_id, device_id))
+            result.setdefault(user_id, {})[device_id] = DeviceKeyLookupResult(
+                display_name, key_json
+            )
 
         if include_deleted_devices:
             for user_id, device_id in deleted_devices:
@@ -167,13 +246,15 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
                 # note that target_device_result will be None for deleted devices.
                 continue
 
-            target_device_signatures = target_device_result.setdefault("signatures", {})
+            target_device_signatures = target_device_result.signatures
+            if target_device_signatures is None:
+                target_device_signatures = target_device_result.signatures = {}
+
             signing_user_signatures = target_device_signatures.setdefault(
                 signing_user_id, {}
             )
             signing_user_signatures[signing_key_id] = signature
 
-        log_kv(result)
         return result
 
     async def get_e2e_one_time_keys(
@@ -252,10 +333,12 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
         )
 
     @cached(max_entries=10000)
-    def count_e2e_one_time_keys(self, user_id, device_id):
+    async def count_e2e_one_time_keys(
+        self, user_id: str, device_id: str
+    ) -> Dict[str, int]:
         """ Count the number of one time keys the server has for a device
         Returns:
-            Dict mapping from algorithm to number of keys for that algorithm.
+            A mapping from algorithm to number of keys for that algorithm.
         """
 
         def _count_e2e_one_time_keys(txn):
@@ -270,7 +353,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
                 result[algorithm] = key_count
             return result
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "count_e2e_one_time_keys", _count_e2e_one_time_keys
         )
 
@@ -308,7 +391,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
         list_name="user_ids",
         num_args=1,
     )
-    def _get_bare_e2e_cross_signing_keys_bulk(
+    async def _get_bare_e2e_cross_signing_keys_bulk(
         self, user_ids: List[str]
     ) -> Dict[str, Dict[str, dict]]:
         """Returns the cross-signing keys for a set of users.  The output of this
@@ -316,16 +399,15 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
         the signatures for the calling user need to be fetched.
 
         Args:
-            user_ids (list[str]): the users whose keys are being requested
+            user_ids: the users whose keys are being requested
 
         Returns:
-            dict[str, dict[str, dict]]: mapping from user ID to key type to key
-                data.  If a user's cross-signing keys were not found, either
-                their user ID will not be in the dict, or their user ID will map
-                to None.
+            A mapping from user ID to key type to key data. If a user's cross-signing
+            keys were not found, either their user ID will not be in the dict, or
+            their user ID will map to None.
 
         """
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_bare_e2e_cross_signing_keys_bulk",
             self._get_bare_e2e_cross_signing_keys_bulk_txn,
             user_ids,
@@ -541,9 +623,16 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
             _get_all_user_signature_changes_for_remotes_txn,
         )
 
+    @abc.abstractmethod
+    def get_device_stream_token(self) -> int:
+        """Get the current stream id from the _device_list_id_gen"""
+        ...
+
 
 class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
-    def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys):
+    async def set_e2e_device_keys(
+        self, user_id: str, device_id: str, time_now: int, device_keys: JsonDict
+    ) -> bool:
         """Stores device keys for a device. Returns whether there was a change
         or the keys were already in the database.
         """
@@ -579,12 +668,21 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
             log_kv({"message": "Device keys stored."})
             return True
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "set_e2e_device_keys", _set_e2e_device_keys_txn
         )
 
-    def claim_e2e_one_time_keys(self, query_list):
-        """Take a list of one time keys out of the database"""
+    async def claim_e2e_one_time_keys(
+        self, query_list: Iterable[Tuple[str, str, str]]
+    ) -> Dict[str, Dict[str, Dict[str, bytes]]]:
+        """Take a list of one time keys out of the database.
+
+        Args:
+            query_list: An iterable of tuples of (user ID, device ID, algorithm).
+
+        Returns:
+            A map of user ID -> a map of device ID -> a map of key ID -> JSON bytes.
+        """
 
         @trace
         def _claim_e2e_one_time_keys(txn):
@@ -620,11 +718,11 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
                 )
             return result
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "claim_e2e_one_time_keys", _claim_e2e_one_time_keys
         )
 
-    def delete_e2e_keys_by_device(self, user_id, device_id):
+    async def delete_e2e_keys_by_device(self, user_id: str, device_id: str) -> None:
         def delete_e2e_keys_by_device_txn(txn):
             log_kv(
                 {
@@ -647,7 +745,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
                 txn, self.count_e2e_one_time_keys, (user_id, device_id)
             )
 
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
         )
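
Device rows are now returned as small attrs objects rather than dicts, with any
cross-signing signatures merged in afterwards. A standalone sketch of that
shape, with illustrative user, device and key IDs:

    from typing import Dict, Optional

    import attr


    @attr.s
    class DeviceKeyLookupResult:
        display_name = attr.ib(type=Optional[str])
        key_json = attr.ib(type=Optional[str])
        signatures = attr.ib(type=Optional[Dict], default=None)


    # One entry per (user, device), as produced by the keys query.
    result = {
        "@alice:example.org": {
            "DEVICE1": DeviceKeyLookupResult("Alice's phone", '{"keys": {}}'),
        }
    }

    # Merging a cross-signing signature into a device, as the txn above does.
    device = result["@alice:example.org"]["DEVICE1"]
    sigs = device.signatures
    if sigs is None:
        sigs = device.signatures = {}
    sigs.setdefault("@alice:example.org", {})["ed25519:AAAA"] = "signature"

    print(device)
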
 
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 0b69aa6a94..4c3c162acf 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -438,7 +438,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
         """
 
         if stream_ordering <= self.stream_ordering_month_ago:
-            raise StoreError(400, "stream_ordering too old")
+            raise StoreError(400, "stream_ordering too old %s" % (stream_ordering,))
 
         sql = """
                 SELECT event_id FROM stream_ordering_to_exterm
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index e8834b2162..001d06378d 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -15,7 +15,9 @@
 # limitations under the License.
 
 import logging
-from typing import List
+from typing import Dict, List, Optional, Tuple, Union
+
+import attr
 
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json
@@ -88,8 +90,26 @@ class EventPushActionsWorkerStore(SQLBaseStore):
 
     @cached(num_args=3, tree=True, max_entries=5000)
     async def get_unread_event_push_actions_by_room_for_user(
-        self, room_id, user_id, last_read_event_id
-    ):
+        self, room_id: str, user_id: str, last_read_event_id: Optional[str],
+    ) -> Dict[str, int]:
+        """Get the notification count, the highlight count and the unread message count
+        for a given user in a given room after the given read receipt.
+
+        Note that this function assumes that the user is currently a member of the room,
+        since it's either called by the sync handler to handle joined room entries, or by
+        the HTTP pusher to calculate the badge of unread joined rooms.
+
+        Args:
+            room_id: The room to retrieve the counts in.
+            user_id: The user to retrieve the counts for.
+            last_read_event_id: The event associated with the latest read receipt for
+                this user in this room. None if no receipt for this user in this room.
+
+        Returns:
+            A dict containing the counts mentioned earlier in this docstring,
+            respectively under the keys "notify_count", "highlight_count" and
+            "unread_count".
+        """
         return await self.db_pool.runInteraction(
             "get_unread_event_push_actions_by_room",
             self._get_unread_counts_by_receipt_txn,
@@ -99,69 +119,71 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         )
 
     def _get_unread_counts_by_receipt_txn(
-        self, txn, room_id, user_id, last_read_event_id
+        self, txn, room_id, user_id, last_read_event_id,
     ):
-        sql = (
-            "SELECT stream_ordering"
-            " FROM events"
-            " WHERE room_id = ? AND event_id = ?"
-        )
-        txn.execute(sql, (room_id, last_read_event_id))
-        results = txn.fetchall()
-        if len(results) == 0:
-            return {"notify_count": 0, "highlight_count": 0}
+        stream_ordering = None
+
+        if last_read_event_id is not None:
+            stream_ordering = self.get_stream_id_for_event_txn(
+                txn, last_read_event_id, allow_none=True,
+            )
+
+        if stream_ordering is None:
+            # Either last_read_event_id is None, or it's an event we don't have (e.g.
+            # because it's been purged), in which case retrieve the stream ordering for
+            # the latest membership event from this user in this room (which we assume is
+            # a join).
+            event_id = self.db_pool.simple_select_one_onecol_txn(
+                txn=txn,
+                table="local_current_membership",
+                keyvalues={"room_id": room_id, "user_id": user_id},
+                retcol="event_id",
+            )
 
-        stream_ordering = results[0][0]
+            stream_ordering = self.get_stream_id_for_event_txn(txn, event_id)
 
         return self._get_unread_counts_by_pos_txn(
             txn, room_id, user_id, stream_ordering
         )
 
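The fallback above (the receipt's stream ordering if known, else the stream ordering of the user's latest membership event) can be condensed into a small pure function. A sketch with hypothetical lookup callables standing in for the real txn helpers:

    from typing import Callable, Optional

    def stream_ordering_for_counts(
        last_read_event_id: Optional[str],
        get_stream_id: Callable[[str], Optional[int]],
        get_join_event_id: Callable[[], str],
    ) -> Optional[int]:
        stream_ordering = None
        if last_read_event_id is not None:
            # May come back None if the receipt's event has been purged.
            stream_ordering = get_stream_id(last_read_event_id)
        if stream_ordering is None:
            # Fall back to the user's latest membership event, assumed to
            # be a join since the user is a current member of the room.
            stream_ordering = get_stream_id(get_join_event_id())
        return stream_ordering

    # No receipt at all: counting starts at the join event's stream ordering.
    print(stream_ordering_for_counts(None, {"$join": 7}.get, lambda: "$join"))
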
     def _get_unread_counts_by_pos_txn(self, txn, room_id, user_id, stream_ordering):
-
-        # First get number of notifications.
-        # We don't need to put a notif=1 clause as all rows always have
-        # notif=1
         sql = (
-            "SELECT count(*)"
+            "SELECT"
+            "   COUNT(CASE WHEN notif = 1 THEN 1 END),"
+            "   COUNT(CASE WHEN highlight = 1 THEN 1 END),"
+            "   COUNT(CASE WHEN unread = 1 THEN 1 END)"
             " FROM event_push_actions ea"
-            " WHERE"
-            " user_id = ?"
-            " AND room_id = ?"
-            " AND stream_ordering > ?"
+            " WHERE user_id = ?"
+            "   AND room_id = ?"
+            "   AND stream_ordering > ?"
         )
 
         txn.execute(sql, (user_id, room_id, stream_ordering))
         row = txn.fetchone()
-        notify_count = row[0] if row else 0
+
+        (notif_count, highlight_count, unread_count) = (0, 0, 0)
+
+        if row:
+            (notif_count, highlight_count, unread_count) = row
 
         txn.execute(
             """
-            SELECT notif_count FROM event_push_summary
-            WHERE room_id = ? AND user_id = ? AND stream_ordering > ?
-        """,
+                SELECT notif_count, unread_count FROM event_push_summary
+                WHERE room_id = ? AND user_id = ? AND stream_ordering > ?
+            """,
             (room_id, user_id, stream_ordering),
         )
-        rows = txn.fetchall()
-        if rows:
-            notify_count += rows[0][0]
-
-        # Now get the number of highlights
-        sql = (
-            "SELECT count(*)"
-            " FROM event_push_actions ea"
-            " WHERE"
-            " highlight = 1"
-            " AND user_id = ?"
-            " AND room_id = ?"
-            " AND stream_ordering > ?"
-        )
-
-        txn.execute(sql, (user_id, room_id, stream_ordering))
         row = txn.fetchone()
-        highlight_count = row[0] if row else 0
 
-        return {"notify_count": notify_count, "highlight_count": highlight_count}
+        if row:
+            notif_count += row[0]
+            unread_count += row[1]
+
+        return {
+            "notify_count": notif_count,
+            "unread_count": unread_count,
+            "highlight_count": highlight_count,
+        }
 
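The rewritten query replaces three separate COUNT(*) scans of event_push_actions with a single pass using conditional aggregation. The same pattern, runnable against an illustrative sqlite3 schema (not Synapse's real one):

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute(
        "CREATE TABLE event_push_actions "
        "(user_id TEXT, room_id TEXT, stream_ordering INTEGER,"
        " notif INTEGER, highlight INTEGER, unread INTEGER)"
    )
    conn.executemany(
        "INSERT INTO event_push_actions VALUES (?, ?, ?, ?, ?, ?)",
        [
            ("@u:hs", "!r:hs", 10, 1, 0, 1),
            ("@u:hs", "!r:hs", 11, 1, 1, 1),
            ("@u:hs", "!r:hs", 12, 0, 0, 1),
        ],
    )

    # Each COUNT(CASE WHEN ...) counts only rows matching its flag, so one
    # table scan yields all three counters.
    row = conn.execute(
        "SELECT"
        "  COUNT(CASE WHEN notif = 1 THEN 1 END),"
        "  COUNT(CASE WHEN highlight = 1 THEN 1 END),"
        "  COUNT(CASE WHEN unread = 1 THEN 1 END)"
        " FROM event_push_actions"
        " WHERE user_id = ? AND room_id = ? AND stream_ordering > ?",
        ("@u:hs", "!r:hs", 9),
    ).fetchone()

    print(row)  # (2, 1, 3)
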
     async def get_push_action_users_in_range(
         self, min_stream_ordering, max_stream_ordering
@@ -222,6 +244,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
                 "   AND ep.user_id = ?"
                 "   AND ep.stream_ordering > ?"
                 "   AND ep.stream_ordering <= ?"
+                "   AND ep.notif = 1"
                 " ORDER BY ep.stream_ordering ASC LIMIT ?"
             )
             args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
@@ -250,6 +273,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
                 "   AND ep.user_id = ?"
                 "   AND ep.stream_ordering > ?"
                 "   AND ep.stream_ordering <= ?"
+                "   AND ep.notif = 1"
                 " ORDER BY ep.stream_ordering ASC LIMIT ?"
             )
             args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
@@ -324,6 +348,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
                 "   AND ep.user_id = ?"
                 "   AND ep.stream_ordering > ?"
                 "   AND ep.stream_ordering <= ?"
+                "   AND ep.notif = 1"
                 " ORDER BY ep.stream_ordering DESC LIMIT ?"
             )
             args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
@@ -352,6 +377,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
                 "   AND ep.user_id = ?"
                 "   AND ep.stream_ordering > ?"
                 "   AND ep.stream_ordering <= ?"
+                "   AND ep.notif = 1"
                 " ORDER BY ep.stream_ordering DESC LIMIT ?"
             )
             args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
@@ -383,62 +409,66 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         # Now return the first `limit`
         return notifs[:limit]
 
-    def get_if_maybe_push_in_range_for_user(self, user_id, min_stream_ordering):
+    async def get_if_maybe_push_in_range_for_user(
+        self, user_id: str, min_stream_ordering: int
+    ) -> bool:
         """A fast check to see if there might be something to push for the
         user since the given stream ordering. May return false positives.
 
         Useful to know whether to bother starting a pusher on start up or not.
 
         Args:
-            user_id (str)
-            min_stream_ordering (int)
+            user_id
+            min_stream_ordering
 
         Returns:
-            Deferred[bool]: True if there may be push to process, False if
-            there definitely isn't.
+            True if there may be push to process, False if there definitely isn't.
         """
 
         def _get_if_maybe_push_in_range_for_user_txn(txn):
             sql = """
                 SELECT 1 FROM event_push_actions
-                WHERE user_id = ? AND stream_ordering > ?
+                WHERE user_id = ? AND stream_ordering > ? AND notif = 1
                 LIMIT 1
             """
 
             txn.execute(sql, (user_id, min_stream_ordering))
             return bool(txn.fetchone())
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_if_maybe_push_in_range_for_user",
             _get_if_maybe_push_in_range_for_user_txn,
         )
 
-    async def add_push_actions_to_staging(self, event_id, user_id_actions):
+    async def add_push_actions_to_staging(
+        self,
+        event_id: str,
+        user_id_actions: Dict[str, List[Union[dict, str]]],
+        count_as_unread: bool,
+    ) -> None:
         """Add the push actions for the event to the push action staging area.
 
         Args:
-            event_id (str)
-            user_id_actions (dict[str, list[dict|str])]): A dictionary mapping
-                user_id to list of push actions, where an action can either be
-                a string or dict.
-
-        Returns:
-            Deferred
+            event_id
+            user_id_actions: A mapping of user_id to list of push actions, where
+                an action can either be a string or dict.
+            count_as_unread: Whether this event should increment unread counts.
         """
-
         if not user_id_actions:
             return
 
         # This is a helper function for generating the necessary tuple that
-        # can be used to inert into the `event_push_actions_staging` table.
+        # can be used to insert into the `event_push_actions_staging` table.
         def _gen_entry(user_id, actions):
             is_highlight = 1 if _action_has_highlight(actions) else 0
+            notif = 1 if "notify" in actions else 0
             return (
                 event_id,  # event_id column
                 user_id,  # user_id column
                 _serialize_action(actions, is_highlight),  # actions column
-                1,  # notif column
+                notif,  # notif column
                 is_highlight,  # highlight column
+                int(count_as_unread),  # unread column
             )
 
         def _add_push_actions_to_staging_txn(txn):
@@ -447,8 +477,8 @@ class EventPushActionsWorkerStore(SQLBaseStore):
 
             sql = """
                 INSERT INTO event_push_actions_staging
-                    (event_id, user_id, actions, notif, highlight)
-                VALUES (?, ?, ?, ?, ?)
+                    (event_id, user_id, actions, notif, highlight, unread)
+                VALUES (?, ?, ?, ?, ?, ?)
             """
 
             txn.executemany(
@@ -507,7 +537,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             "Found stream ordering 1 day ago: it's %d", self.stream_ordering_day_ago
         )
 
-    def find_first_stream_ordering_after_ts(self, ts):
+    async def find_first_stream_ordering_after_ts(self, ts: int) -> int:
         """Gets the stream ordering corresponding to a given timestamp.
 
         Specifically, finds the stream_ordering of the first event that was
@@ -516,13 +546,12 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         relatively slow.
 
         Args:
-            ts (int): timestamp in millis
+            ts: timestamp in millis
 
         Returns:
-            Deferred[int]: stream ordering of the first event received on/after
-                the timestamp
+            stream ordering of the first event received on/after the timestamp
         """
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "_find_first_stream_ordering_after_ts_txn",
             self._find_first_stream_ordering_after_ts_txn,
             ts,
@@ -813,24 +842,63 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
         # Calculate the new counts that should be upserted into event_push_summary
         sql = """
             SELECT user_id, room_id,
-                coalesce(old.notif_count, 0) + upd.notif_count,
+                coalesce(old.%s, 0) + upd.cnt,
                 upd.stream_ordering,
                 old.user_id
             FROM (
-                SELECT user_id, room_id, count(*) as notif_count,
+                SELECT user_id, room_id, count(*) as cnt,
                     max(stream_ordering) as stream_ordering
                 FROM event_push_actions
                 WHERE ? <= stream_ordering AND stream_ordering < ?
                     AND highlight = 0
+                    AND %s = 1
                 GROUP BY user_id, room_id
             ) AS upd
             LEFT JOIN event_push_summary AS old USING (user_id, room_id)
         """
 
-        txn.execute(sql, (old_rotate_stream_ordering, rotate_to_stream_ordering))
-        rows = txn.fetchall()
+        # First get the count of unread messages.
+        txn.execute(
+            sql % ("unread_count", "unread"),
+            (old_rotate_stream_ordering, rotate_to_stream_ordering),
+        )
+
+        # We need to merge the results from the two queries (one retrieving the
+        # unread count, the other the notification count) into a single object,
+        # because we might not have the same number of rows in each. To do
+        # this, we use a dict indexed on the user ID and room ID to make it easier to
+        # populate.
+        summaries = {}  # type: Dict[Tuple[str, str], _EventPushSummary]
+        for row in txn:
+            summaries[(row[0], row[1])] = _EventPushSummary(
+                unread_count=row[2],
+                stream_ordering=row[3],
+                old_user_id=row[4],
+                notif_count=0,
+            )
+
+        # Then get the count of notifications.
+        txn.execute(
+            sql % ("notif_count", "notif"),
+            (old_rotate_stream_ordering, rotate_to_stream_ordering),
+        )
+
+        for row in txn:
+            if (row[0], row[1]) in summaries:
+                summaries[(row[0], row[1])].notif_count = row[2]
+            else:
+                # Because the rules on notifying are different from the rules on marking
+                # a message unread, we might end up with messages that notify but aren't
+                # marked unread, so we might not have a summary for this (user, room)
+                # tuple to complete.
+                summaries[(row[0], row[1])] = _EventPushSummary(
+                    unread_count=0,
+                    stream_ordering=row[3],
+                    old_user_id=row[4],
+                    notif_count=row[2],
+                )
 
-        logger.info("Rotating notifications, handling %d rows", len(rows))
+        logger.info("Rotating notifications, handling %d rows", len(summaries))
 
         # If the `old.user_id` above is NULL then we know there isn't already an
         # entry in the table, so we simply insert it. Otherwise we update the
@@ -840,22 +908,34 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
             table="event_push_summary",
             values=[
                 {
-                    "user_id": row[0],
-                    "room_id": row[1],
-                    "notif_count": row[2],
-                    "stream_ordering": row[3],
+                    "user_id": user_id,
+                    "room_id": room_id,
+                    "notif_count": summary.notif_count,
+                    "unread_count": summary.unread_count,
+                    "stream_ordering": summary.stream_ordering,
                 }
-                for row in rows
-                if row[4] is None
+                for ((user_id, room_id), summary) in summaries.items()
+                if summary.old_user_id is None
             ],
         )
 
         txn.executemany(
             """
-                UPDATE event_push_summary SET notif_count = ?, stream_ordering = ?
+                UPDATE event_push_summary
+                SET notif_count = ?, unread_count = ?, stream_ordering = ?
                 WHERE user_id = ? AND room_id = ?
             """,
-            ((row[2], row[3], row[0], row[1]) for row in rows if row[4] is not None),
+            (
+                (
+                    summary.notif_count,
+                    summary.unread_count,
+                    summary.stream_ordering,
+                    user_id,
+                    room_id,
+                )
+                for ((user_id, room_id), summary) in summaries.items()
+                if summary.old_user_id is not None
+            ),
         )
 
         txn.execute(
@@ -881,3 +961,15 @@ def _action_has_highlight(actions):
             pass
 
     return False
+
+
+@attr.s
+class _EventPushSummary:
+    """Summary of pending event push actions for a given user in a given room.
+    Used in _rotate_notifs_before_txn to manipulate results from event_push_actions.
+    """
+
+    unread_count = attr.ib(type=int)
+    stream_ordering = attr.ib(type=int)
+    old_user_id = attr.ib(type=str)
+    notif_count = attr.ib(type=int)
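
Rows from the unread query and the notification query may not pair up one-to-one, which is why _rotate_notifs_before_txn merges them into a dict keyed on (user ID, room ID). The same merge logic in isolation, with a stand-in dataclass and made-up rows:

    from dataclasses import dataclass
    from typing import Dict, Tuple

    @dataclass
    class Summary:  # stand-in for _EventPushSummary above
        unread_count: int
        notif_count: int

    summaries: Dict[Tuple[str, str], Summary] = {}

    # First pass: rows from the unread-count query.
    for user_id, room_id, cnt in [("@u:hs", "!a:hs", 3)]:
        summaries[(user_id, room_id)] = Summary(unread_count=cnt, notif_count=0)

    # Second pass: rows from the notification-count query. A (user, room)
    # pair may notify without being marked unread, so create the entry if
    # it is missing rather than assuming it exists.
    for user_id, room_id, cnt in [("@u:hs", "!a:hs", 1), ("@u:hs", "!b:hs", 2)]:
        key = (user_id, room_id)
        if key in summaries:
            summaries[key].notif_count = cnt
        else:
            summaries[key] = Summary(unread_count=0, notif_count=cnt)

    print(summaries)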
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 6313b41eef..b94fe7ac17 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -97,6 +97,7 @@ class PersistEventsStore:
         self.store = main_data_store
         self.database_engine = db.engine
         self._clock = hs.get_clock()
+        self._instance_name = hs.get_instance_name()
 
         self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
         self.is_mine_id = hs.is_mine_id
@@ -108,7 +109,7 @@ class PersistEventsStore:
 
         # This should only exist on instances that are configured to write
         assert (
-            hs.config.worker.writers.events == hs.get_instance_name()
+            hs.get_instance_name() in hs.config.worker.writers.events
         ), "Can only instantiate EventsStore on master"
 
     async def _persist_events_and_state_updates(
@@ -800,6 +801,7 @@ class PersistEventsStore:
             table="events",
             values=[
                 {
+                    "instance_name": self._instance_name,
                     "stream_ordering": event.internal_metadata.stream_ordering,
                     "topological_ordering": event.depth,
                     "depth": event.depth,
@@ -1296,9 +1298,9 @@ class PersistEventsStore:
         sql = """
             INSERT INTO event_push_actions (
                 room_id, event_id, user_id, actions, stream_ordering,
-                topological_ordering, notif, highlight
+                topological_ordering, notif, highlight, unread
             )
-            SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight
+            SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight, unread
             FROM event_push_actions_staging
             WHERE event_id = ?
         """
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index e6247d682d..17f5997b89 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -42,7 +42,8 @@ from synapse.replication.tcp.streams import BackfillStream
 from synapse.replication.tcp.streams.events import EventsStream
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
 from synapse.storage.database import DatabasePool
-from synapse.storage.util.id_generators import StreamIdGenerator
+from synapse.storage.engines import PostgresEngine
+from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
 from synapse.types import Collection, get_domain_from_id
 from synapse.util.caches.descriptors import Cache, cached
 from synapse.util.iterutils import batch_iter
@@ -78,27 +79,54 @@ class EventsWorkerStore(SQLBaseStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
         super(EventsWorkerStore, self).__init__(database, db_conn, hs)
 
-        if hs.config.worker.writers.events == hs.get_instance_name():
-            # We are the process in charge of generating stream ids for events,
-            # so instantiate ID generators based on the database
-            self._stream_id_gen = StreamIdGenerator(
-                db_conn, "events", "stream_ordering",
+        if isinstance(database.engine, PostgresEngine):
+            # If we're using Postgres then we can use `MultiWriterIdGenerator`
+            # regardless of whether this process writes to the streams or not.
+            self._stream_id_gen = MultiWriterIdGenerator(
+                db_conn=db_conn,
+                db=database,
+                instance_name=hs.get_instance_name(),
+                table="events",
+                instance_column="instance_name",
+                id_column="stream_ordering",
+                sequence_name="events_stream_seq",
             )
-            self._backfill_id_gen = StreamIdGenerator(
-                db_conn,
-                "events",
-                "stream_ordering",
-                step=-1,
-                extra_tables=[("ex_outlier_stream", "event_stream_ordering")],
+            self._backfill_id_gen = MultiWriterIdGenerator(
+                db_conn=db_conn,
+                db=database,
+                instance_name=hs.get_instance_name(),
+                table="events",
+                instance_column="instance_name",
+                id_column="stream_ordering",
+                sequence_name="events_backfill_stream_seq",
+                positive=False,
             )
         else:
-            # Another process is in charge of persisting events and generating
-            # stream IDs: rely on the replication streams to let us know which
-            # IDs we can process.
-            self._stream_id_gen = SlavedIdTracker(db_conn, "events", "stream_ordering")
-            self._backfill_id_gen = SlavedIdTracker(
-                db_conn, "events", "stream_ordering", step=-1
-            )
+            # We shouldn't be running in worker mode with SQLite, but it's useful
+            # to support it for unit tests.
+            #
+            # If this process is the writer then we need to use
+            # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets
+            # updated over replication. (Multiple writers are not supported for
+            # SQLite).
+            if hs.get_instance_name() in hs.config.worker.writers.events:
+                self._stream_id_gen = StreamIdGenerator(
+                    db_conn, "events", "stream_ordering",
+                )
+                self._backfill_id_gen = StreamIdGenerator(
+                    db_conn,
+                    "events",
+                    "stream_ordering",
+                    step=-1,
+                    extra_tables=[("ex_outlier_stream", "event_stream_ordering")],
+                )
+            else:
+                self._stream_id_gen = SlavedIdTracker(
+                    db_conn, "events", "stream_ordering"
+                )
+                self._backfill_id_gen = SlavedIdTracker(
+                    db_conn, "events", "stream_ordering", step=-1
+                )
 
         self._get_event_cache = Cache(
             "*getEvent*",
@@ -823,20 +851,24 @@ class EventsWorkerStore(SQLBaseStore):
 
         return event_dict
 
-    def _maybe_redact_event_row(self, original_ev, redactions, event_map):
+    def _maybe_redact_event_row(
+        self,
+        original_ev: EventBase,
+        redactions: Iterable[str],
+        event_map: Dict[str, EventBase],
+    ) -> Optional[EventBase]:
         """Given an event object and a list of possible redacting event ids,
         determine whether to honour any of those redactions and if so return a redacted
         event.
 
         Args:
-             original_ev (EventBase):
-             redactions (iterable[str]): list of event ids of potential redaction events
-             event_map (dict[str, EventBase]): other events which have been fetched, in
-                 which we can look up the redaaction events. Map from event id to event.
+             original_ev: The original event.
+             redactions: list of event ids of potential redaction events
+             event_map: other events which have been fetched, in which we can
+                look up the redaction events. Map from event id to event.
 
         Returns:
-            Deferred[EventBase|None]: if the event should be redacted, a pruned
-                event object. Otherwise, None.
+            If the event should be redacted, a pruned event object. Otherwise, None.
         """
         if original_ev.type == "m.room.create":
             # we choose to ignore redactions of m.room.create events.
@@ -946,17 +978,17 @@ class EventsWorkerStore(SQLBaseStore):
         row = txn.fetchone()
         return row[0] if row else 0
 
-    def get_current_state_event_counts(self, room_id):
+    async def get_current_state_event_counts(self, room_id: str) -> int:
         """
         Gets the current number of state events in a room.
 
         Args:
-            room_id (str)
+            room_id: The room ID to query.
 
         Returns:
-            Deferred[int]
+            The current number of state events.
         """
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_current_state_event_counts",
             self._get_current_state_event_counts_txn,
             room_id,
@@ -991,7 +1023,9 @@ class EventsWorkerStore(SQLBaseStore):
         """The current maximum token that events have reached"""
         return self._stream_id_gen.get_current_token()
 
-    def get_all_new_forward_event_rows(self, last_id, current_id, limit):
+    async def get_all_new_forward_event_rows(
+        self, last_id: int, current_id: int, limit: int
+    ) -> List[Tuple]:
         """Returns new events, for the Events replication stream
 
         Args:
@@ -999,7 +1033,7 @@ class EventsWorkerStore(SQLBaseStore):
             current_id: the maximum stream_id to return up to
             limit: the maximum number of rows to return
 
-        Returns: Deferred[List[Tuple]]
+        Returns:
             a list of events stream rows. Each tuple consists of a stream id as
             the first element, followed by fields suitable for casting into an
             EventsStreamRow.
@@ -1020,18 +1054,20 @@ class EventsWorkerStore(SQLBaseStore):
             txn.execute(sql, (last_id, current_id, limit))
             return txn.fetchall()
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_all_new_forward_event_rows", get_all_new_forward_event_rows
         )
 
-    def get_ex_outlier_stream_rows(self, last_id, current_id):
+    async def get_ex_outlier_stream_rows(
+        self, last_id: int, current_id: int
+    ) -> List[Tuple]:
         """Returns de-outliered events, for the Events replication stream
 
         Args:
             last_id: the last stream_id from the previous batch.
             current_id: the maximum stream_id to return up to
 
-        Returns: Deferred[List[Tuple]]
+        Returns:
             a list of events stream rows. Each tuple consists of a stream id as
             the first element, followed by fields suitable for casting into an
             EventsStreamRow.
@@ -1054,7 +1090,7 @@ class EventsWorkerStore(SQLBaseStore):
             txn.execute(sql, (last_id, current_id))
             return txn.fetchall()
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_ex_outlier_stream_rows", get_ex_outlier_stream_rows_txn
         )
 
@@ -1226,11 +1262,11 @@ class EventsWorkerStore(SQLBaseStore):
 
         return (int(res["topological_ordering"]), int(res["stream_ordering"]))
 
-    def get_next_event_to_expire(self):
+    async def get_next_event_to_expire(self) -> Optional[Tuple[str, int]]:
         """Retrieve the entry with the lowest expiry timestamp in the event_expiry
         table, or None if there's no more event to expire.
 
-        Returns: Deferred[Optional[Tuple[str, int]]]
+        Returns:
             A tuple containing the event ID as its first element and an expiry timestamp
             as its second one, if there's at least one row in the event_expiry table.
             None otherwise.
@@ -1246,6 +1282,6 @@ class EventsWorkerStore(SQLBaseStore):
 
             return txn.fetchone()
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
         )
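
The branching above boils down to: Postgres always gets MultiWriterIdGenerator; on SQLite the writer gets StreamIdGenerator and everyone else gets SlavedIdTracker. A condensed sketch of that decision with stand-in classes (the real generators live in synapse.storage.util.id_generators):

    from typing import List

    class MultiWriterIdGen: ...
    class StreamIdGen: ...
    class SlavedIdTracker: ...

    def pick_id_gen(is_postgres: bool, instance_name: str, writers: List[str]):
        if is_postgres:
            # Postgres sequences make ID allocation safe even with multiple
            # writers, so every process can use the same generator.
            return MultiWriterIdGen()
        if instance_name in writers:
            # SQLite writer: single-process stream ID allocation.
            return StreamIdGen()
        # SQLite non-writer: just track IDs observed over replication.
        return SlavedIdTracker()

    print(type(pick_id_gen(True, "worker1", ["master"])).__name__)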
diff --git a/synapse/storage/databases/main/filtering.py b/synapse/storage/databases/main/filtering.py
index 45a1760170..d2f5b9a502 100644
--- a/synapse/storage/databases/main/filtering.py
+++ b/synapse/storage/databases/main/filtering.py
@@ -17,6 +17,7 @@ from canonicaljson import encode_canonical_json
 
 from synapse.api.errors import Codes, SynapseError
 from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.types import JsonDict
 from synapse.util.caches.descriptors import cached
 
 
@@ -40,7 +41,7 @@ class FilteringStore(SQLBaseStore):
 
         return db_to_json(def_json)
 
-    def add_user_filter(self, user_localpart, user_filter):
+    async def add_user_filter(self, user_localpart: str, user_filter: JsonDict) -> str:
         def_json = encode_canonical_json(user_filter)
 
         # Need an atomic transaction to SELECT the maximal ID so far then
@@ -71,4 +72,4 @@ class FilteringStore(SQLBaseStore):
 
             return filter_id
 
-        return self.db_pool.runInteraction("add_user_filter", _do_txn)
+        return await self.db_pool.runInteraction("add_user_filter", _do_txn)
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index 3919ecad69..86557d5512 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, Iterable, List, Optional, Tuple
 
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import DatabasePool
@@ -93,7 +93,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             desc="mark_local_media_as_safe",
         )
 
-    def get_url_cache(self, url, ts):
+    async def get_url_cache(self, url: str, ts: int) -> Optional[Dict[str, Any]]:
         """Get the media_id and ts for a cached URL as of the given timestamp
         Returns:
             None if the URL isn't cached.
@@ -139,7 +139,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
                 )
             )
 
-        return self.db_pool.runInteraction("get_url_cache", get_url_cache_txn)
+        return await self.db_pool.runInteraction("get_url_cache", get_url_cache_txn)
 
     async def store_url_cache(
         self, url, response_code, etag, expires_ts, og, media_id, download_ts
@@ -237,12 +237,17 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             desc="store_cached_remote_media",
         )
 
-    def update_cached_last_access_time(self, local_media, remote_media, time_ms):
+    async def update_cached_last_access_time(
+        self,
+        local_media: Iterable[str],
+        remote_media: Iterable[Tuple[str, str]],
+        time_ms: int,
+    ):
         """Updates the last access time of the given media
 
         Args:
-            local_media (iterable[str]): Set of media_ids
-            remote_media (iterable[(str, str)]): Set of (server_name, media_id)
+            local_media: Set of media_ids
+            remote_media: Set of (server_name, media_id)
             time_ms: Current time in milliseconds
         """
 
@@ -267,7 +272,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
 
             txn.executemany(sql, ((time_ms, media_id) for media_id in local_media))
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "update_cached_last_access_time", update_cache_txn
         )
 
@@ -325,7 +330,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             "get_remote_media_before", self.db_pool.cursor_to_dict, sql, before_ts
         )
 
-    def delete_remote_media(self, media_origin, media_id):
+    async def delete_remote_media(self, media_origin: str, media_id: str) -> None:
         def delete_remote_media_txn(txn):
             self.db_pool.simple_delete_txn(
                 txn,
@@ -338,11 +343,11 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
                 keyvalues={"media_origin": media_origin, "media_id": media_id},
             )
 
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "delete_remote_media", delete_remote_media_txn
         )
 
-    def get_expired_url_cache(self, now_ts):
+    async def get_expired_url_cache(self, now_ts: int) -> List[str]:
         sql = (
             "SELECT media_id FROM local_media_repository_url_cache"
             " WHERE expires_ts < ?"
@@ -354,7 +359,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             txn.execute(sql, (now_ts,))
             return [row[0] for row in txn]
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_expired_url_cache", _get_expired_url_cache_txn
         )
 
@@ -371,7 +376,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             "delete_url_cache", _delete_url_cache_txn
         )
 
-    def get_url_cache_media_before(self, before_ts):
+    async def get_url_cache_media_before(self, before_ts: int) -> List[str]:
         sql = (
             "SELECT media_id FROM local_media_repository"
             " WHERE created_ts < ? AND url_cache IS NOT NULL"
@@ -383,7 +388,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             txn.execute(sql, (before_ts,))
             return [row[0] for row in txn]
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_url_cache_media_before", _get_url_cache_media_before_txn
         )
 
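update_cached_last_access_time passes a generator of parameter tuples to executemany, issuing one UPDATE per media ID without building an intermediate list. The same idiom in a standalone sketch:

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute(
        "CREATE TABLE local_media_repository (media_id TEXT, last_access_ts INTEGER)"
    )
    conn.executemany(
        "INSERT INTO local_media_repository VALUES (?, ?)",
        [("abc", 0), ("def", 0)],
    )

    time_ms = 1_599_000_000_000  # made-up "now" in milliseconds
    media_ids = ["abc", "def"]

    # One UPDATE per media_id, parameters supplied lazily by a generator.
    conn.executemany(
        "UPDATE local_media_repository SET last_access_ts = ? WHERE media_id = ?",
        ((time_ms, media_id) for media_id in media_ids),
    )
    print(conn.execute("SELECT * FROM local_media_repository").fetchall())
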
diff --git a/synapse/storage/databases/main/openid.py b/synapse/storage/databases/main/openid.py
index 4db8949da7..2aac64901b 100644
--- a/synapse/storage/databases/main/openid.py
+++ b/synapse/storage/databases/main/openid.py
@@ -1,3 +1,5 @@
+from typing import Optional
+
 from synapse.storage._base import SQLBaseStore
 
 
@@ -15,7 +17,9 @@ class OpenIdStore(SQLBaseStore):
             desc="insert_open_id_token",
         )
 
-    def get_user_id_for_open_id_token(self, token, ts_now_ms):
+    async def get_user_id_for_open_id_token(
+        self, token: str, ts_now_ms: int
+    ) -> Optional[str]:
         def get_user_id_for_token_txn(txn):
             sql = (
                 "SELECT user_id FROM open_id_tokens"
@@ -30,6 +34,6 @@ class OpenIdStore(SQLBaseStore):
             else:
                 return rows[0][0]
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_user_id_for_token", get_user_id_for_token_txn
         )
diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py
index 301875a672..d2e0685e9e 100644
--- a/synapse/storage/databases/main/profile.py
+++ b/synapse/storage/databases/main/profile.py
@@ -138,7 +138,9 @@ class ProfileStore(ProfileWorkerStore):
                 desc="delete_remote_profile_cache",
             )
 
-    def get_remote_profile_cache_entries_that_expire(self, last_checked):
+    async def get_remote_profile_cache_entries_that_expire(
+        self, last_checked: int
+    ) -> List[Dict[str, str]]:
         """Get all users who haven't been checked since `last_checked`
         """
 
@@ -153,7 +155,7 @@ class ProfileStore(ProfileWorkerStore):
 
             return self.db_pool.cursor_to_dict(txn)
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_remote_profile_cache_entries_that_expire",
             _get_remote_profile_cache_entries_that_expire_txn,
         )
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index 3526b6fd66..ea833829ae 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 import logging
-from typing import Any, Tuple
+from typing import Any, List, Set, Tuple
 
 from synapse.api.errors import SynapseError
 from synapse.storage._base import SQLBaseStore
@@ -25,25 +25,24 @@ logger = logging.getLogger(__name__)
 
 
 class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
-    def purge_history(self, room_id, token, delete_local_events):
+    async def purge_history(
+        self, room_id: str, token: str, delete_local_events: bool
+    ) -> Set[int]:
         """Deletes room history before a certain point
 
         Args:
-            room_id (str):
-
-            token (str): A topological token to delete events before
-
-            delete_local_events (bool):
+            room_id:
+            token: A topological token to delete events before
+            delete_local_events:
                 if True, we will delete local events as well as remote ones
                 (instead of just marking them as outliers and deleting their
                 state groups).
 
         Returns:
-            Deferred[set[int]]: The set of state groups that are referenced by
-            deleted events.
+            The set of state groups that are referenced by deleted events.
         """
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "purge_history",
             self._purge_history_txn,
             room_id,
@@ -283,17 +282,18 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
 
         return referenced_state_groups
 
-    def purge_room(self, room_id):
+    async def purge_room(self, room_id: str) -> List[int]:
         """Deletes all record of a room
 
         Args:
-            room_id (str)
+            room_id
 
         Returns:
-            Deferred[List[int]]: The list of state groups to delete.
+            The list of state groups to delete.
         """
-
-        return self.db_pool.runInteraction("purge_room", self._purge_room_txn, room_id)
+        return await self.db_pool.runInteraction(
+            "purge_room", self._purge_room_txn, room_id
+        )
 
     def _purge_room_txn(self, txn, room_id):
         # First we fetch all the state groups that should be deleted, before
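
purge_history and purge_room now return their state-group results directly instead of a Deferred, so callers await them and run the cleanup afterwards. A hedged sketch of that caller-side contract; FakeStore and the printed "cleanup" stand in for the real store and the real state-group deletion, which lives elsewhere in the storage layer:

    import asyncio
    from typing import Set

    class FakeStore:
        # Mimics the awaitable contract of PurgeEventsStore.purge_history.
        async def purge_history(
            self, room_id: str, token: str, delete_local_events: bool
        ) -> Set[int]:
            return {101, 102}

    async def main() -> None:
        store = FakeStore()
        state_groups = await store.purge_history("!r:hs", "t3_42", False)
        # Hypothetical follow-up: delete the state groups referenced only
        # by the purged events.
        print("state groups to delete:", sorted(state_groups))

    asyncio.run(main())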
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 2fb5b02d7d..0de802a86b 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -18,8 +18,6 @@ import abc
 import logging
 from typing import List, Tuple, Union
 
-from twisted.internet import defer
-
 from synapse.push.baserules import list_with_base_rules
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.storage._base import SQLBaseStore, db_to_json
@@ -149,9 +147,11 @@ class PushRulesWorkerStore(
         )
         return {r["rule_id"]: False if r["enabled"] == 0 else True for r in results}
 
-    def have_push_rules_changed_for_user(self, user_id, last_id):
+    async def have_push_rules_changed_for_user(
+        self, user_id: str, last_id: int
+    ) -> bool:
         if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id):
-            return defer.succeed(False)
+            return False
         else:
 
             def have_push_rules_changed_txn(txn):
@@ -163,7 +163,7 @@ class PushRulesWorkerStore(
                 (count,) = txn.fetchone()
                 return bool(count)
 
-            return self.db_pool.runInteraction(
+            return await self.db_pool.runInteraction(
                 "have_push_rules_changed", have_push_rules_changed_txn
             )
 
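have_push_rules_changed_for_user shows the recurring shape of this whole changeset: a function that used to wrap plain values with defer.succeed becomes an async def that simply returns them. A minimal before/after sketch:

    import asyncio

    # Before (Twisted style), a non-async method had to return a Deferred:
    #     return defer.succeed(False)
    # After, an async def returns the value directly:

    async def have_push_rules_changed(cache_says_changed: bool) -> bool:
        if not cache_says_changed:
            return False  # no Deferred wrapper needed
        return True  # stand-in for the awaited runInteraction call

    print(asyncio.run(have_push_rules_changed(False)))  # False
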
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 436f22ad2d..4a0d5a320e 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -276,12 +276,14 @@ class ReceiptsWorkerStore(SQLBaseStore):
         }
         return results
 
-    def get_users_sent_receipts_between(self, last_id: int, current_id: int):
+    async def get_users_sent_receipts_between(
+        self, last_id: int, current_id: int
+    ) -> List[str]:
         """Get all users who sent receipts between `last_id` exclusive and
         `current_id` inclusive.
 
         Returns:
-            Deferred[List[str]]
+            The list of users.
         """
 
         if last_id == current_id:
@@ -296,7 +298,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
 
             return [r[0] for r in txn]
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_users_sent_receipts_between", _get_users_sent_receipts_between_txn
         )
 
@@ -553,8 +555,10 @@ class ReceiptsStore(ReceiptsWorkerStore):
 
         return stream_id, max_persisted_id
 
-    def insert_graph_receipt(self, room_id, receipt_type, user_id, event_ids, data):
-        return self.db_pool.runInteraction(
+    async def insert_graph_receipt(
+        self, room_id, receipt_type, user_id, event_ids, data
+    ):
+        return await self.db_pool.runInteraction(
             "insert_graph_receipt",
             self.insert_graph_receipt_txn,
             room_id,
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 12689f4308..01f20c03c2 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -17,7 +17,7 @@
 
 import logging
 import re
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Tuple
 
 from synapse.api.constants import UserTypes
 from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
@@ -84,22 +84,22 @@ class RegistrationWorkerStore(SQLBaseStore):
         return is_trial
 
     @cached()
-    def get_user_by_access_token(self, token):
+    async def get_user_by_access_token(self, token: str) -> Optional[dict]:
         """Get a user from the given access token.
 
         Args:
-            token (str): The access token of a user.
+            token: The access token of a user.
         Returns:
-            defer.Deferred: None, if the token did not match, otherwise dict
-                including the keys `name`, `is_guest`, `device_id`, `token_id`,
-                `valid_until_ms`.
+            None if the token did not match, otherwise a dict
+            including the keys `name`, `is_guest`, `device_id`, `token_id`,
+            `valid_until_ms`.
         """
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_user_by_access_token", self._query_for_auth, token
         )
 
     @cached()
-    async def get_expiration_ts_for_user(self, user_id: str) -> Optional[None]:
+    async def get_expiration_ts_for_user(self, user_id: str) -> Optional[int]:
         """Get the expiration timestamp for the account bearing a given user ID.
 
         Args:
@@ -281,13 +281,12 @@ class RegistrationWorkerStore(SQLBaseStore):
 
         return bool(res) if res else False
 
-    def set_server_admin(self, user, admin):
+    async def set_server_admin(self, user: UserID, admin: bool) -> None:
         """Sets whether a user is an admin of this homeserver.
 
         Args:
-            user (UserID): user ID of the user to test
-            admin (bool): true iff the user is to be a server admin,
-                false otherwise.
+            user: user ID of the user to modify
+            admin: true iff the user is to be a server admin, false otherwise.
         """
 
         def set_server_admin_txn(txn):
@@ -298,7 +297,7 @@ class RegistrationWorkerStore(SQLBaseStore):
                 txn, self.get_user_by_id, (user.to_string(),)
             )
 
-        return self.db_pool.runInteraction("set_server_admin", set_server_admin_txn)
+        await self.db_pool.runInteraction("set_server_admin", set_server_admin_txn)
 
     def _query_for_auth(self, txn, token):
         sql = (
@@ -364,9 +363,11 @@ class RegistrationWorkerStore(SQLBaseStore):
         )
         return True if res == UserTypes.SUPPORT else False
 
-    def get_users_by_id_case_insensitive(self, user_id):
+    async def get_users_by_id_case_insensitive(self, user_id: str) -> Dict[str, str]:
         """Gets users that match user_id case insensitively.
-        Returns a mapping of user_id -> password_hash.
+
+        Returns:
+             A mapping of user_id -> password_hash.
         """
 
         def f(txn):
@@ -374,7 +375,7 @@ class RegistrationWorkerStore(SQLBaseStore):
             txn.execute(sql, (user_id,))
             return dict(txn)
 
-        return self.db_pool.runInteraction("get_users_by_id_case_insensitive", f)
+        return await self.db_pool.runInteraction("get_users_by_id_case_insensitive", f)
 
     async def get_user_by_external_id(
         self, auth_provider: str, external_id: str
@@ -408,7 +409,7 @@ class RegistrationWorkerStore(SQLBaseStore):
 
         return await self.db_pool.runInteraction("count_users", _count_users)
 
-    def count_daily_user_type(self):
+    async def count_daily_user_type(self) -> Dict[str, int]:
         """
         Counts 1) native non guest users
                2) native guests users
@@ -437,7 +438,7 @@ class RegistrationWorkerStore(SQLBaseStore):
                 results[row[0]] = row[1]
             return results
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "count_daily_user_type", _count_daily_user_type
         )
 
@@ -663,24 +664,29 @@ class RegistrationWorkerStore(SQLBaseStore):
         # Convert the integer into a boolean.
         return res == 1
 
-    def get_threepid_validation_session(
-        self, medium, client_secret, address=None, sid=None, validated=True
-    ):
+    async def get_threepid_validation_session(
+        self,
+        medium: Optional[str],
+        client_secret: str,
+        address: Optional[str] = None,
+        sid: Optional[str] = None,
+        validated: Optional[bool] = True,
+    ) -> Optional[Dict[str, Any]]:
         """Gets a session_id and last_send_attempt (if available) for a
         combination of validation metadata
 
         Args:
-            medium (str|None): The medium of the 3PID
-            address (str|None): The address of the 3PID
-            sid (str|None): The ID of the validation session
-            client_secret (str): A unique string provided by the client to help identify this
+            medium: The medium of the 3PID
+            client_secret: A unique string provided by the client to help identify this
                 validation attempt
-            validated (bool|None): Whether sessions should be filtered by
+            address: The address of the 3PID
+            sid: The ID of the validation session
+            validated: Whether sessions should be filtered by
                 whether they have been validated already or not. None to
                 perform no filtering
 
         Returns:
-            Deferred[dict|None]: A dict containing the following:
+            A dict containing the following:
                 * address - address of the 3pid
                 * medium - medium of the 3pid
                 * client_secret - a secret provided by the client for this validation session
@@ -726,17 +732,17 @@ class RegistrationWorkerStore(SQLBaseStore):
 
             return rows[0]
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_threepid_validation_session", get_threepid_validation_session_txn
         )
 
-    def delete_threepid_session(self, session_id):
+    async def delete_threepid_session(self, session_id: str) -> None:
         """Removes a threepid validation session from the database. This can
         be done after validation has been performed and whatever action was
         waiting on it has been carried out
 
         Args:
-            session_id (str): The ID of the session to delete
+            session_id: The ID of the session to delete
         """
 
         def delete_threepid_session_txn(txn):
@@ -751,7 +757,7 @@ class RegistrationWorkerStore(SQLBaseStore):
                 keyvalues={"session_id": session_id},
             )
 
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "delete_threepid_session", delete_threepid_session_txn
         )
 
@@ -941,43 +947,40 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
             desc="add_access_token_to_user",
         )
 
-    def register_user(
+    async def register_user(
         self,
-        user_id,
-        password_hash=None,
-        was_guest=False,
-        make_guest=False,
-        appservice_id=None,
-        create_profile_with_displayname=None,
-        admin=False,
-        user_type=None,
-        shadow_banned=False,
-    ):
+        user_id: str,
+        password_hash: Optional[str] = None,
+        was_guest: bool = False,
+        make_guest: bool = False,
+        appservice_id: Optional[str] = None,
+        create_profile_with_displayname: Optional[str] = None,
+        admin: bool = False,
+        user_type: Optional[str] = None,
+        shadow_banned: bool = False,
+    ) -> None:
         """Attempts to register an account.
 
         Args:
-            user_id (str): The desired user ID to register.
-            password_hash (str|None): Optional. The password hash for this user.
-            was_guest (bool): Optional. Whether this is a guest account being
-                upgraded to a non-guest account.
-            make_guest (boolean): True if the the new user should be guest,
-                false to add a regular user account.
-            appservice_id (str): The ID of the appservice registering the user.
-            create_profile_with_displayname (unicode): Optionally create a profile for
+            user_id: The desired user ID to register.
+            password_hash: Optional. The password hash for this user.
+            was_guest: Whether this is a guest account being upgraded to a
+                non-guest account.
+            make_guest: True if the new user should be a guest, false to add a
+                regular user account.
+            appservice_id: The ID of the appservice registering the user.
+            create_profile_with_displayname: Optionally create a profile for
                 the user, setting their displayname to the given value
-            admin (boolean): is an admin user?
-            user_type (str|None): type of user. One of the values from
-                api.constants.UserTypes, or None for a normal user.
-            shadow_banned (bool): Whether the user is shadow-banned,
-                i.e. they may be told their requests succeeded but we ignore them.
+            admin: Whether the user should be registered as a server admin.
+            user_type: type of user. One of the values from api.constants.UserTypes,
+                or None for a normal user.
+            shadow_banned: Whether the user is shadow-banned, i.e. they may be
+                told their requests succeeded but we ignore them.
 
         Raises:
             StoreError if the user_id could not be registered.
-
-        Returns:
-            Deferred
         """
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "register_user",
             self._register_user,
             user_id,
@@ -1101,7 +1104,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
             desc="record_user_external_id",
         )
 
-    def user_set_password_hash(self, user_id, password_hash):
+    async def user_set_password_hash(self, user_id: str, password_hash: str) -> None:
         """
         NB. This does *not* evict any cache because the one use for this
             removes most of the entries subsequently anyway so it would be
@@ -1114,17 +1117,18 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
             )
             self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
 
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "user_set_password_hash", user_set_password_hash_txn
         )
 
-    def user_set_consent_version(self, user_id, consent_version):
+    async def user_set_consent_version(
+        self, user_id: str, consent_version: str
+    ) -> None:
         """Updates the user table to record privacy policy consent
 
         Args:
-            user_id (str): full mxid of the user to update
-            consent_version (str): version of the policy the user has consented
-                to
+            user_id: full mxid of the user to update
+            consent_version: version of the policy the user has consented to
 
         Raises:
             StoreError(404) if user not found
@@ -1139,16 +1143,17 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
             )
             self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
 
-        return self.db_pool.runInteraction("user_set_consent_version", f)
+        await self.db_pool.runInteraction("user_set_consent_version", f)
 
-    def user_set_consent_server_notice_sent(self, user_id, consent_version):
+    async def user_set_consent_server_notice_sent(
+        self, user_id: str, consent_version: str
+    ) -> None:
         """Updates the user table to record that we have sent the user a server
         notice about privacy policy consent
 
         Args:
-            user_id (str): full mxid of the user to update
-            consent_version (str): version of the policy we have notified the
-                user about
+            user_id: full mxid of the user to update
+            consent_version: version of the policy we have notified the user about
 
         Raises:
             StoreError(404) if user not found
@@ -1163,22 +1168,25 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
             )
             self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
 
-        return self.db_pool.runInteraction("user_set_consent_server_notice_sent", f)
+        await self.db_pool.runInteraction("user_set_consent_server_notice_sent", f)
 
-    def user_delete_access_tokens(self, user_id, except_token_id=None, device_id=None):
+    async def user_delete_access_tokens(
+        self,
+        user_id: str,
+        except_token_id: Optional[str] = None,
+        device_id: Optional[str] = None,
+    ) -> List[Tuple[str, int, Optional[str]]]:
         """
         Invalidate access tokens belonging to a user
 
         Args:
-            user_id (str):  ID of user the tokens belong to
-            except_token_id (str): list of access_tokens IDs which should
-                *not* be deleted
-            device_id (str|None):  ID of device the tokens are associated with.
+            user_id: ID of user the tokens belong to
+            except_token_id: access_tokens ID which should *not* be deleted
+            device_id: ID of device the tokens are associated with.
                 If None, tokens associated with any device (or no device) will
                 be deleted
         Returns:
-            defer.Deferred[list[str, int, str|None, int]]: a list of
-                (token, token id, device id) for each of the deleted tokens
+            A list of tuples of (token, token id, device id), one for each of the deleted tokens
         """
 
         def f(txn):
@@ -1209,9 +1217,9 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
 
             return tokens_and_devices
 
-        return self.db_pool.runInteraction("user_delete_access_tokens", f)
+        return await self.db_pool.runInteraction("user_delete_access_tokens", f)
 
-    def delete_access_token(self, access_token):
+    async def delete_access_token(self, access_token: str) -> None:
         def f(txn):
             self.db_pool.simple_delete_one_txn(
                 txn, table="access_tokens", keyvalues={"token": access_token}
@@ -1221,7 +1229,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
                 txn, self.get_user_by_access_token, (access_token,)
             )
 
-        return self.db_pool.runInteraction("delete_access_token", f)
+        await self.db_pool.runInteraction("delete_access_token", f)
 
     @cached()
     async def is_guest(self, user_id: str) -> bool:
@@ -1272,24 +1280,25 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
             desc="get_users_pending_deactivation",
         )
 
-    def validate_threepid_session(self, session_id, client_secret, token, current_ts):
+    async def validate_threepid_session(
+        self, session_id: str, client_secret: str, token: str, current_ts: int
+    ) -> Optional[str]:
         """Attempt to validate a threepid session using a token
 
         Args:
-            session_id (str): The id of a validation session
-            client_secret (str): A unique string provided by the client to
-                help identify this validation attempt
-            token (str): A validation token
-            current_ts (int): The current unix time in milliseconds. Used for
-                checking token expiry status
+            session_id: The id of a validation session
+            client_secret: A unique string provided by the client to help identify
+                this validation attempt
+            token: A validation token
+            current_ts: The current unix time in milliseconds. Used for checking
+                token expiry status
 
         Raises:
             ThreepidValidationError: if a matching validation token was not found or has
                 expired
 
         Returns:
-            deferred str|None: A str representing a link to redirect the user
-            to if there is one.
+            A str representing a link to redirect the user to if there is one.
         """
 
         # Insert everything into a transaction in order to run atomically
@@ -1359,36 +1368,35 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
             return next_link
 
         # Return next_link if it exists
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "validate_threepid_session_txn", validate_threepid_session_txn
         )
 
-    def start_or_continue_validation_session(
+    async def start_or_continue_validation_session(
         self,
-        medium,
-        address,
-        session_id,
-        client_secret,
-        send_attempt,
-        next_link,
-        token,
-        token_expires,
-    ):
+        medium: str,
+        address: str,
+        session_id: str,
+        client_secret: str,
+        send_attempt: int,
+        next_link: Optional[str],
+        token: str,
+        token_expires: int,
+    ) -> None:
         """Creates a new threepid validation session if it does not already
         exist and associates a new validation token with it
 
         Args:
-            medium (str): The medium of the 3PID
-            address (str): The address of the 3PID
-            session_id (str): The id of this validation session
-            client_secret (str): A unique string provided by the client to
-                help identify this validation attempt
-            send_attempt (int): The latest send_attempt on this session
-            next_link (str|None): The link to redirect the user to upon
-                successful validation
-            token (str): The validation token
-            token_expires (int): The timestamp for which after the token
-                will no longer be valid
+            medium: The medium of the 3PID
+            address: The address of the 3PID
+            session_id: The id of this validation session
+            client_secret: A unique string provided by the client to help
+                identify this validation attempt
+            send_attempt: The latest send_attempt on this session
+            next_link: The link to redirect the user to upon successful validation
+            token: The validation token
+            token_expires: The timestamp after which the token will no
+                longer be valid
         """
 
         def start_or_continue_validation_session_txn(txn):
@@ -1417,12 +1425,12 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
                 },
             )
 
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "start_or_continue_validation_session",
             start_or_continue_validation_session_txn,
         )
 
-    def cull_expired_threepid_validation_tokens(self):
+    async def cull_expired_threepid_validation_tokens(self) -> None:
         """Remove threepid validation tokens with expiry dates that have passed"""
 
         def cull_expired_threepid_validation_tokens_txn(txn, ts):
@@ -1430,9 +1438,9 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
             DELETE FROM threepid_validation_token WHERE
             expires < ?
             """
-            return txn.execute(sql, (ts,))
+            txn.execute(sql, (ts,))
 
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "cull_expired_threepid_validation_tokens",
             cull_expired_threepid_validation_tokens_txn,
             self.clock.time_msec(),
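
The registration.py hunks above all apply the same mechanical conversion:
methods that used to hand back the Deferred from runInteraction directly are
now declared async and await it, so callers receive plain values. A minimal
standalone sketch of the pattern; FakeDatabasePool and FakeStore are toy
stand-ins, not Synapse classes (the real DatabasePool.runInteraction
dispatches the function to a database thread with a transaction object):

    import asyncio
    from typing import Any, Callable

    class FakeDatabasePool:
        """Toy stand-in for synapse.storage.database.DatabasePool."""

        async def runInteraction(
            self, desc: str, func: Callable[..., Any], *args: Any
        ) -> Any:
            # The real version runs `func` in a DB thread inside a
            # transaction; here we just call it inline with a dummy txn.
            return func(None, *args)

    class FakeStore:
        def __init__(self) -> None:
            self.db_pool = FakeDatabasePool()

        # Before: `return self.db_pool.runInteraction(...)` returned a
        # Deferred. After: the method is async and awaits the interaction.
        async def user_set_password_hash(
            self, user_id: str, password_hash: str
        ) -> None:
            def txn_fn(txn):
                pass  # the real txn runs an UPDATE and invalidates caches

            await self.db_pool.runInteraction("user_set_password_hash", txn_fn)

    asyncio.run(FakeStore().user_set_password_hash("@alice:example.com", "<hash>"))
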
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index a9ceffc20e..5cd61547f7 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -34,38 +34,33 @@ logger = logging.getLogger(__name__)
 
 class RelationsWorkerStore(SQLBaseStore):
     @cached(tree=True)
-    def get_relations_for_event(
+    async def get_relations_for_event(
         self,
-        event_id,
-        relation_type=None,
-        event_type=None,
-        aggregation_key=None,
-        limit=5,
-        direction="b",
-        from_token=None,
-        to_token=None,
-    ):
+        event_id: str,
+        relation_type: Optional[str] = None,
+        event_type: Optional[str] = None,
+        aggregation_key: Optional[str] = None,
+        limit: int = 5,
+        direction: str = "b",
+        from_token: Optional[RelationPaginationToken] = None,
+        to_token: Optional[RelationPaginationToken] = None,
+    ) -> PaginationChunk:
         """Get a list of relations for an event, ordered by topological ordering.
 
         Args:
-            event_id (str): Fetch events that relate to this event ID.
-            relation_type (str|None): Only fetch events with this relation
-                type, if given.
-            event_type (str|None): Only fetch events with this event type, if
-                given.
-            aggregation_key (str|None): Only fetch events with this aggregation
-                key, if given.
-            limit (int): Only fetch the most recent `limit` events.
-            direction (str): Whether to fetch the most recent first (`"b"`) or
-                the oldest first (`"f"`).
-            from_token (RelationPaginationToken|None): Fetch rows from the given
-                token, or from the start if None.
-            to_token (RelationPaginationToken|None): Fetch rows up to the given
-                token, or up to the end if None.
+            event_id: Fetch events that relate to this event ID.
+            relation_type: Only fetch events with this relation type, if given.
+            event_type: Only fetch events with this event type, if given.
+            aggregation_key: Only fetch events with this aggregation key, if given.
+            limit: Only fetch the most recent `limit` events.
+            direction: Whether to fetch the most recent first (`"b"`) or the
+                oldest first (`"f"`).
+            from_token: Fetch rows from the given token, or from the start if None.
+            to_token: Fetch rows up to the given token, or up to the end if None.
 
         Returns:
-            Deferred[PaginationChunk]: List of event IDs that match relations
-            requested. The rows are of the form `{"event_id": "..."}`.
+            List of event IDs that match relations requested. The rows are of
+            the form `{"event_id": "..."}`.
         """
 
         where_clause = ["relates_to_id = ?"]
@@ -131,20 +126,20 @@ class RelationsWorkerStore(SQLBaseStore):
                 chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
             )
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_recent_references_for_event", _get_recent_references_for_event_txn
         )
 
     @cached(tree=True)
-    def get_aggregation_groups_for_event(
+    async def get_aggregation_groups_for_event(
         self,
-        event_id,
-        event_type=None,
-        limit=5,
-        direction="b",
-        from_token=None,
-        to_token=None,
-    ):
+        event_id: str,
+        event_type: Optional[str] = None,
+        limit: int = 5,
+        direction: str = "b",
+        from_token: Optional[AggregationPaginationToken] = None,
+        to_token: Optional[AggregationPaginationToken] = None,
+    ) -> PaginationChunk:
         """Get a list of annotations on the event, grouped by event type and
         aggregation key, sorted by count.
 
@@ -152,21 +147,17 @@ class RelationsWorkerStore(SQLBaseStore):
         on an event.
 
         Args:
-            event_id (str): Fetch events that relate to this event ID.
-            event_type (str|None): Only fetch events with this event type, if
-                given.
-            limit (int): Only fetch the `limit` groups.
-            direction (str): Whether to fetch the highest count first (`"b"`) or
+            event_id: Fetch events that relate to this event ID.
+            event_type: Only fetch events with this event type, if given.
+            limit: Only fetch the `limit` groups.
+            direction: Whether to fetch the highest count first (`"b"`) or
                 the lowest count first (`"f"`).
-            from_token (AggregationPaginationToken|None): Fetch rows from the
-                given token, or from the start if None.
-            to_token (AggregationPaginationToken|None): Fetch rows up to the
-                given token, or up to the end if None.
-
+            from_token: Fetch rows from the given token, or from the start if None.
+            to_token: Fetch rows up to the given token, or up to the end if None.
 
         Returns:
-            Deferred[PaginationChunk]: List of groups of annotations that
-            match. Each row is a dict with `type`, `key` and `count` fields.
+            List of groups of annotations that match. Each row is a dict with
+            `type`, `key` and `count` fields.
         """
 
         where_clause = ["relates_to_id = ?", "relation_type = ?"]
@@ -225,7 +216,7 @@ class RelationsWorkerStore(SQLBaseStore):
                 chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
             )
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn
         )
 
@@ -279,18 +270,20 @@ class RelationsWorkerStore(SQLBaseStore):
 
         return await self.get_event(edit_id, allow_none=True)
 
-    def has_user_annotated_event(self, parent_id, event_type, aggregation_key, sender):
+    async def has_user_annotated_event(
+        self, parent_id: str, event_type: str, aggregation_key: str, sender: str
+    ) -> bool:
         """Check if a user has already annotated an event with the same key
         (e.g. already liked an event).
 
         Args:
-            parent_id (str): The event being annotated
-            event_type (str): The event type of the annotation
-            aggregation_key (str): The aggregation key of the annotation
-            sender (str): The sender of the annotation
+            parent_id: The event being annotated
+            event_type: The event type of the annotation
+            aggregation_key: The aggregation key of the annotation
+            sender: The sender of the annotation
 
         Returns:
-            Deferred[bool]
+            True if the event is already annotated.
         """
 
         sql = """
@@ -319,7 +312,7 @@ class RelationsWorkerStore(SQLBaseStore):
 
             return bool(txn.fetchone())
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_if_user_has_annotated_event", _get_if_user_has_annotated_event
         )
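
The relations store methods above page through rows using pagination tokens
and a direction flag. A rough standalone model of that contract, assuming a
toy in-memory event list and plain (topological, stream) tuples in place of
RelationPaginationToken:

    from typing import List, Optional, Tuple

    Token = Tuple[int, int]  # (topological_ordering, stream_ordering)

    def paginate(
        rows: List[Token],
        limit: int,
        direction: str = "b",
        from_token: Optional[Token] = None,
        to_token: Optional[Token] = None,
    ) -> List[Token]:
        # "b" walks from most recent to oldest, "f" the reverse, mirroring
        # the `direction` argument of get_relations_for_event.
        rows = sorted(rows, reverse=(direction == "b"))
        if from_token is not None:
            rows = [
                r for r in rows
                if (r < from_token if direction == "b" else r > from_token)
            ]
        if to_token is not None:
            rows = [
                r for r in rows
                if (r > to_token if direction == "b" else r < to_token)
            ]
        return rows[:limit]

    events = [(1, 1), (2, 2), (3, 3), (4, 4)]
    assert paginate(events, limit=2) == [(4, 4), (3, 3)]
    assert paginate(events, limit=2, direction="f") == [(1, 1), (2, 2)]
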
 
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index a92641c339..717df97301 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -89,7 +89,7 @@ class RoomWorkerStore(SQLBaseStore):
             allow_none=True,
         )
 
-    def get_room_with_stats(self, room_id: str):
+    async def get_room_with_stats(self, room_id: str) -> Optional[Dict[str, Any]]:
         """Retrieve room with statistics.
 
         Args:
@@ -121,7 +121,7 @@ class RoomWorkerStore(SQLBaseStore):
             res["public"] = bool(res["public"])
             return res
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_room_with_stats", get_room_with_stats_txn, room_id
         )
 
@@ -133,13 +133,17 @@ class RoomWorkerStore(SQLBaseStore):
             desc="get_public_room_ids",
         )
 
-    def count_public_rooms(self, network_tuple, ignore_non_federatable):
+    async def count_public_rooms(
+        self,
+        network_tuple: Optional[ThirdPartyInstanceID],
+        ignore_non_federatable: bool,
+    ) -> int:
         """Counts the number of public rooms as tracked in the room_stats_current
         and room_stats_state tables.
 
         Args:
-            network_tuple (ThirdPartyInstanceID|None)
-            ignore_non_federatable (bool): If true filters out non-federatable rooms
+            network_tuple
+            ignore_non_federatable: If true, filters out non-federatable rooms
         """
 
         def _count_public_rooms_txn(txn):
@@ -183,7 +187,7 @@ class RoomWorkerStore(SQLBaseStore):
             txn.execute(sql, query_args)
             return txn.fetchone()[0]
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "count_public_rooms", _count_public_rooms_txn
         )
 
@@ -586,15 +590,14 @@ class RoomWorkerStore(SQLBaseStore):
 
         return row
 
-    def get_media_mxcs_in_room(self, room_id):
+    async def get_media_mxcs_in_room(self, room_id: str) -> Tuple[List[str], List[str]]:
         """Retrieves all the local and remote media MXC URIs in a given room
 
         Args:
-            room_id (str)
+            room_id
 
         Returns:
-            The local and remote media as a lists of tuples where the key is
-            the hostname and the value is the media ID.
+            The local and remote media as lists of media IDs.
         """
 
         def _get_media_mxcs_in_room_txn(txn):
@@ -610,11 +613,13 @@ class RoomWorkerStore(SQLBaseStore):
 
             return local_media_mxcs, remote_media_mxcs
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_media_ids_in_room", _get_media_mxcs_in_room_txn
         )
 
-    def quarantine_media_ids_in_room(self, room_id, quarantined_by):
+    async def quarantine_media_ids_in_room(
+        self, room_id: str, quarantined_by: str
+    ) -> int:
         """For a room loops through all events with media and quarantines
         the associated media
         """
@@ -627,7 +632,7 @@ class RoomWorkerStore(SQLBaseStore):
                 txn, local_mxcs, remote_mxcs, quarantined_by
             )
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "quarantine_media_in_room", _quarantine_media_in_room_txn
         )
 
@@ -690,9 +695,9 @@ class RoomWorkerStore(SQLBaseStore):
 
         return local_media_mxcs, remote_media_mxcs
 
-    def quarantine_media_by_id(
+    async def quarantine_media_by_id(
         self, server_name: str, media_id: str, quarantined_by: str,
-    ):
+    ) -> int:
         """quarantines a single local or remote media id
 
         Args:
@@ -711,11 +716,13 @@ class RoomWorkerStore(SQLBaseStore):
                 txn, local_mxcs, remote_mxcs, quarantined_by
             )
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "quarantine_media_by_user", _quarantine_media_by_id_txn
         )
 
-    def quarantine_media_ids_by_user(self, user_id: str, quarantined_by: str):
+    async def quarantine_media_ids_by_user(
+        self, user_id: str, quarantined_by: str
+    ) -> int:
         """quarantines all local media associated with a single user
 
         Args:
@@ -727,7 +734,7 @@ class RoomWorkerStore(SQLBaseStore):
             local_media_ids = self._get_media_ids_by_user_txn(txn, user_id)
             return self._quarantine_media_txn(txn, local_media_ids, [], quarantined_by)
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "quarantine_media_by_user", _quarantine_media_by_user_txn
         )
 
@@ -1284,8 +1291,8 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
             )
         self.hs.get_notifier().on_new_replication_data()
 
-    def get_room_count(self):
-        """Retrieve a list of all rooms
+    async def get_room_count(self) -> int:
+        """Retrieve the total number of rooms.
         """
 
         def f(txn):
@@ -1294,7 +1301,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
             row = txn.fetchone()
             return row[0] or 0
 
-        return self.db_pool.runInteraction("get_rooms", f)
+        return await self.db_pool.runInteraction("get_rooms", f)
 
     async def add_event_report(
         self,
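
The quarantine helpers above now return an int; a hedged usage sketch,
assuming (as the transaction helper suggests) that the returned value is the
number of media items quarantined, and with `store` standing in for a
RoomStore instance:

    async def quarantine_room_media(store, room_id: str, admin_user_id: str) -> dict:
        # quarantine_media_ids_in_room returns an int count, so a caller
        # (e.g. a hypothetical admin handler) can surface it directly.
        num_quarantined = await store.quarantine_media_ids_in_room(
            room_id, quarantined_by=admin_user_id
        )
        return {"num_quarantined": num_quarantined}
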
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 161edbeccb..c46f5cd524 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -15,7 +15,7 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, Awaitable, Iterable, List, Optional, Set
+from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Set
 
 from synapse.api.constants import EventTypes, Membership
 from synapse.events import EventBase
@@ -152,8 +152,8 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             )
 
     @cached(max_entries=100000, iterable=True)
-    def get_users_in_room(self, room_id: str):
-        return self.db_pool.runInteraction(
+    async def get_users_in_room(self, room_id: str) -> List[str]:
+        return await self.db_pool.runInteraction(
             "get_users_in_room", self.get_users_in_room_txn, room_id
         )
 
@@ -180,14 +180,13 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         return [r[0] for r in txn]
 
     @cached(max_entries=100000)
-    def get_room_summary(self, room_id: str):
+    async def get_room_summary(self, room_id: str) -> Dict[str, MemberSummary]:
         """ Get the details of a room roughly suitable for use by the room
         summary extension to /sync. Useful when lazy loading room members.
         Args:
             room_id: The room ID to query
         Returns:
-            Deferred[dict[str, MemberSummary]:
-                dict of membership states, pointing to a MemberSummary named tuple.
+            dict of membership states, pointing to a MemberSummary named tuple.
         """
 
         def _get_room_summary_txn(txn):
@@ -261,20 +260,22 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
             return res
 
-        return self.db_pool.runInteraction("get_room_summary", _get_room_summary_txn)
+        return await self.db_pool.runInteraction(
+            "get_room_summary", _get_room_summary_txn
+        )
 
     @cached()
-    def get_invited_rooms_for_local_user(self, user_id: str) -> Awaitable[RoomsForUser]:
+    async def get_invited_rooms_for_local_user(self, user_id: str) -> RoomsForUser:
         """Get all the rooms the *local* user is invited to.
 
         Args:
             user_id: The user ID.
 
         Returns:
-            A awaitable list of RoomsForUser.
+            A list of RoomsForUser.
         """
 
-        return self.get_rooms_for_local_user_where_membership_is(
+        return await self.get_rooms_for_local_user_where_membership_is(
             user_id, [Membership.INVITE]
         )
 
@@ -297,8 +298,8 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         return None
 
     async def get_rooms_for_local_user_where_membership_is(
-        self, user_id: str, membership_list: List[str]
-    ) -> Optional[List[RoomsForUser]]:
+        self, user_id: str, membership_list: Collection[str]
+    ) -> List[RoomsForUser]:
         """Get all the rooms for this *local* user where the membership for this user
         matches one in the membership list.
 
@@ -313,7 +314,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             The RoomsForUser that the user matches the membership types.
         """
         if not membership_list:
-            return None
+            return []
 
         rooms = await self.db_pool.runInteraction(
             "get_rooms_for_local_user_where_membership_is",
@@ -357,7 +358,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         return results
 
     @cached(max_entries=500000, iterable=True)
-    def get_rooms_for_user_with_stream_ordering(self, user_id: str):
+    async def get_rooms_for_user_with_stream_ordering(
+        self, user_id: str
+    ) -> FrozenSet[GetRoomsForUserWithStreamOrdering]:
         """Returns a set of room_ids the user is currently joined to.
 
         For a remote user, this only returns rooms that this server is
         currently participating in.
             user_id
 
         Returns:
-            Deferred[frozenset[GetRoomsForUserWithStreamOrdering]]: Returns
-            the rooms the user is in currently, along with the stream ordering
-            of the most recent join for that user and room.
+            The rooms the user is in currently, along with the stream
+            ordering of the most recent join for that user and room.
         """
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_rooms_for_user_with_stream_ordering",
             self._get_rooms_for_user_with_stream_ordering_txn,
             user_id,
         )
 
-    def _get_rooms_for_user_with_stream_ordering_txn(self, txn, user_id: str):
+    def _get_rooms_for_user_with_stream_ordering_txn(
+        self, txn, user_id: str
+    ) -> FrozenSet[GetRoomsForUserWithStreamOrdering]:
         # We use `current_state_events` here and not `local_current_membership`
         # as a) this gets called with remote users and b) this only gets called
         # for rooms the server is participating in.
@@ -404,9 +408,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             """
 
         txn.execute(sql, (user_id, Membership.JOIN))
-        results = frozenset(GetRoomsForUserWithStreamOrdering(*row) for row in txn)
-
-        return results
+        return frozenset(GetRoomsForUserWithStreamOrdering(*row) for row in txn)
 
     async def get_users_server_still_shares_room_with(
         self, user_ids: Collection[str]
@@ -711,14 +713,14 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         return count == 0
 
     @cached()
-    def get_forgotten_rooms_for_user(self, user_id: str):
+    async def get_forgotten_rooms_for_user(self, user_id: str) -> Set[str]:
         """Gets all rooms the user has forgotten.
 
         Args:
-            user_id
+            user_id: The user ID to query the rooms of.
 
         Returns:
-            Deferred[set[str]]
+            The forgotten rooms.
         """
 
         def _get_forgotten_rooms_for_user_txn(txn):
@@ -744,7 +746,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             txn.execute(sql, (user_id,))
             return {row[0] for row in txn if row[1] == 0}
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn
         )
 
@@ -973,7 +975,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
         super(RoomMemberStore, self).__init__(database, db_conn, hs)
 
-    def forget(self, user_id: str, room_id: str):
+    async def forget(self, user_id: str, room_id: str) -> None:
         """Indicate that user_id wishes to discard history for room_id."""
 
         def f(txn):
@@ -994,7 +996,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
                 txn, self.get_forgotten_rooms_for_user, (user_id,)
             )
 
-        return self.db_pool.runInteraction("forget_membership", f)
+        await self.db_pool.runInteraction("forget_membership", f)
 
 
 class _JoinedHostsCache(object):
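
Several of the methods above are both `@cached` and now async; the cache
stores the method's eventual result, and `forget` explicitly invalidates
`get_forgotten_rooms_for_user` after writing. A toy sketch of that
behaviour, assuming a much-simplified decorator (the real one lives in
synapse.util.caches.descriptors and also handles concurrent calls and
invalidation over replication):

    import asyncio
    from functools import wraps

    def cached(func):
        cache = {}

        @wraps(func)
        async def wrapper(self, key):
            if key not in cache:
                cache[key] = await func(self, key)
            return cache[key]

        wrapper.invalidate = lambda key: cache.pop(key, None)
        return wrapper

    class ToyStore:
        def __init__(self):
            self.calls = 0

        @cached
        async def get_forgotten_rooms_for_user(self, user_id: str) -> set:
            self.calls += 1
            return {"!room:example.com"}

    async def main():
        store = ToyStore()
        await store.get_forgotten_rooms_for_user("@alice:example.com")
        await store.get_forgotten_rooms_for_user("@alice:example.com")
        assert store.calls == 1  # second call served from the cache
        # forget() invalidates after the write, roughly like:
        store.get_forgotten_rooms_for_user.invalidate("@alice:example.com")
        await store.get_forgotten_rooms_for_user("@alice:example.com")
        assert store.calls == 2

    asyncio.run(main())
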
diff --git a/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql
new file mode 100644
index 0000000000..98ff76d709
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql
@@ -0,0 +1,16 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ALTER TABLE events ADD COLUMN instance_name TEXT;
diff --git a/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres
new file mode 100644
index 0000000000..97c1e6a0c5
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres
@@ -0,0 +1,26 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE SEQUENCE IF NOT EXISTS events_stream_seq;
+
+SELECT setval('events_stream_seq', (
+    SELECT COALESCE(MAX(stream_ordering), 1) FROM events
+));
+
+CREATE SEQUENCE IF NOT EXISTS events_backfill_stream_seq;
+
+SELECT setval('events_backfill_stream_seq', (
+    SELECT COALESCE(-MIN(stream_ordering), 1) FROM events
+));
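
These two sequences back the forward and backfill event streams: forward
stream orderings grow upwards from MAX(stream_ordering), while backfilled
events get negative orderings, which is why the second sequence is seeded
from -MIN(stream_ordering). The new `positive` flag on MultiWriterIdGenerator
(see the id_generators.py hunks below) consumes them. A standalone sketch of
the sign handling, with a toy counter in place of the postgres sequence:

    class ToyIdGenerator:
        """Models the positive/negative behaviour of MultiWriterIdGenerator."""

        def __init__(self, positive: bool = True) -> None:
            # IDs are tracked internally as positive numbers and only
            # negated on the way out, matching the comment in the real code.
            self._return_factor = 1 if positive else -1
            self._next = 1  # stands in for nextval() on the sequence

        def get_next(self) -> int:
            self._next += 1
            return self._return_factor * self._next

    forward = ToyIdGenerator(positive=True)
    backfill = ToyIdGenerator(positive=False)
    assert forward.get_next() == 2    # forward stream counts up: 2, 3, ...
    assert backfill.get_next() == -2  # backfill counts down: -2, -3, ...
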
diff --git a/synapse/storage/databases/main/schema/delta/58/15unread_count.sql b/synapse/storage/databases/main/schema/delta/58/15unread_count.sql
new file mode 100644
index 0000000000..b451e8663a
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/15unread_count.sql
@@ -0,0 +1,26 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- We're hijacking the push actions to store unread messages and unread counts (specified
+-- in MSC2654) because doing otherwise would result in either performance issues or
+-- reimplementing a substantial part of the push actions logic.
+
+-- Add columns to event_push_actions and event_push_actions_staging to track unread
+-- messages and calculate unread counts.
+ALTER TABLE event_push_actions_staging ADD COLUMN unread SMALLINT NOT NULL DEFAULT 0;
+ALTER TABLE event_push_actions ADD COLUMN unread SMALLINT NOT NULL DEFAULT 0;
+
+-- Add column to event_push_summary
+ALTER TABLE event_push_summary ADD COLUMN unread_count BIGINT NOT NULL DEFAULT 0;
\ No newline at end of file
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index dcbdeab36e..9c5f0229c1 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -16,9 +16,10 @@
 import logging
 import re
 from collections import namedtuple
-from typing import List, Optional
+from typing import List, Optional, Set
 
 from synapse.api.errors import SynapseError
+from synapse.events import EventBase
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
 from synapse.storage.database import DatabasePool
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
@@ -620,7 +621,9 @@ class SearchStore(SearchBackgroundUpdateStore):
             "count": count,
         }
 
-    def _find_highlights_in_postgres(self, search_query, events):
+    async def _find_highlights_in_postgres(
+        self, search_query: str, events: List[EventBase]
+    ) -> Set[str]:
         """Given a list of events and a search term, return a list of words
         that match from the content of the event.
 
@@ -628,11 +631,11 @@ class SearchStore(SearchBackgroundUpdateStore):
         highlight the matching parts.
 
         Args:
-            search_query (str)
-            events (list): A list of events
+            search_query
+            events: A list of events
 
         Returns:
-            deferred : A set of strings.
+            A set of strings.
         """
 
         def f(txn):
@@ -685,7 +688,7 @@ class SearchStore(SearchBackgroundUpdateStore):
 
             return highlight_words
 
-        return self.db_pool.runInteraction("_find_highlights", f)
+        return await self.db_pool.runInteraction("_find_highlights", f)
 
 
 def _to_postgres_options(options_dict):
diff --git a/synapse/storage/databases/main/signatures.py b/synapse/storage/databases/main/signatures.py
index be191dd870..c8c67953e4 100644
--- a/synapse/storage/databases/main/signatures.py
+++ b/synapse/storage/databases/main/signatures.py
@@ -13,9 +13,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Dict, Iterable, List, Tuple
+
 from unpaddedbase64 import encode_base64
 
 from synapse.storage._base import SQLBaseStore
+from synapse.storage.types import Cursor
 from synapse.util.caches.descriptors import cached, cachedList
 
 
@@ -29,16 +32,37 @@ class SignatureWorkerStore(SQLBaseStore):
     @cachedList(
         cached_method_name="get_event_reference_hash", list_name="event_ids", num_args=1
     )
-    def get_event_reference_hashes(self, event_ids):
+    async def get_event_reference_hashes(
+        self, event_ids: Iterable[str]
+    ) -> Dict[str, Dict[str, bytes]]:
+        """Get all hashes for given events.
+
+        Args:
+            event_ids: The event IDs to get hashes for.
+
+        Returns:
+             A mapping of event ID to a mapping of algorithm to hash.
+        """
+
         def f(txn):
             return {
                 event_id: self._get_event_reference_hashes_txn(txn, event_id)
                 for event_id in event_ids
             }
 
-        return self.db_pool.runInteraction("get_event_reference_hashes", f)
+        return await self.db_pool.runInteraction("get_event_reference_hashes", f)
 
-    async def add_event_hashes(self, event_ids):
+    async def add_event_hashes(
+        self, event_ids: Iterable[str]
+    ) -> List[Tuple[str, Dict[str, str]]]:
+        """
+
+        Args:
+            event_ids: The event IDs
+
+        Returns:
+            A list of tuples of event ID and a mapping of algorithm to base-64 encoded hash.
+        """
         hashes = await self.get_event_reference_hashes(event_ids)
         hashes = {
             e_id: {k: encode_base64(v) for k, v in h.items() if k == "sha256"}
@@ -47,13 +71,15 @@ class SignatureWorkerStore(SQLBaseStore):
 
         return list(hashes.items())
 
-    def _get_event_reference_hashes_txn(self, txn, event_id):
+    def _get_event_reference_hashes_txn(
+        self, txn: Cursor, event_id: str
+    ) -> Dict[str, bytes]:
         """Get all the hashes for a given PDU.
         Args:
-            txn (cursor):
-            event_id (str): Id for the Event.
+            txn:
+            event_id: Id for the Event.
         Returns:
-            A dict[unicode, bytes] of algorithm -> hash.
+            A mapping of algorithm -> hash.
         """
         query = (
             "SELECT algorithm, hash"
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index 9b9bc304a8..55a250ef06 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -224,14 +224,32 @@ class StatsStore(StateDeltasStore):
         )
 
     async def update_room_state(self, room_id: str, fields: Dict[str, Any]) -> None:
-        """
+        """Update the state of a room.
+
+        fields can contain the following keys with string values:
+        * join_rules
+        * history_visibility
+        * encryption
+        * name
+        * topic
+        * avatar
+        * canonical_alias
+
+        An is_federatable key can also be included with a boolean value.
+
         Args:
-            room_id
-            fields
+            room_id: The room ID to update the state of.
+            fields: The fields to update. This can include a partial list of the
+                above fields to only update some room information.
         """
-
-        # For whatever reason some of the fields may contain null bytes, which
-        # postgres isn't a fan of, so we replace those fields with null.
+        # Ensure that the values to update are valid, they should be strings and
+        # not contain any null bytes.
+        #
+        # Invalid data gets overwritten with null.
+        #
+        # Note that a missing value should not be overwritten (it keeps the
+        # previous value).
+        sentinel = object()
         for col in (
             "join_rules",
             "history_visibility",
@@ -241,8 +259,8 @@ class StatsStore(StateDeltasStore):
             "avatar",
             "canonical_alias",
         ):
-            field = fields.get(col)
-            if field and "\0" in field:
+            field = fields.get(col, sentinel)
+            if field is not sentinel and (not isinstance(field, str) or "\0" in field):
                 fields[col] = None
 
         await self.db_pool.simple_upsert(
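
The sentinel object above lets update_room_state tell "key absent" apart
from "key present but invalid", which a bare fields.get(col) cannot (it
returns None for both). A standalone illustration of the pattern:

    sentinel = object()

    def sanitise(fields: dict) -> dict:
        for col in ("name", "topic"):
            field = fields.get(col, sentinel)
            # Absent keys are left alone so the stored value is kept;
            # present-but-invalid values are overwritten with None.
            if field is not sentinel and (
                not isinstance(field, str) or "\0" in field
            ):
                fields[col] = None
        return fields

    assert sanitise({"name": "ok"}) == {"name": "ok"}     # valid: kept
    assert sanitise({"name": "bad\0"}) == {"name": None}  # null byte: nulled
    assert sanitise({"topic": 42}) == {"topic": None}     # wrong type: nulled
    assert sanitise({}) == {}                             # absent: untouched
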
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 24f44a7e36..be6df8a6d1 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -39,7 +39,7 @@ what sort order was used:
 import abc
 import logging
 from collections import namedtuple
-from typing import Dict, Iterable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
 
 from twisted.internet import defer
 
@@ -47,12 +47,19 @@ from synapse.api.filtering import Filter
 from synapse.events import EventBase
 from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool, make_in_list_sql_clause
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingTransaction,
+    make_in_list_sql_clause,
+)
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
-from synapse.types import RoomStreamToken
+from synapse.types import Collection, RoomStreamToken
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -202,7 +209,7 @@ def _make_generic_sql_bound(
     )
 
 
-def filter_to_clause(event_filter: Filter) -> Tuple[str, List[str]]:
+def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]:
     # NB: This may create SQL clauses that don't optimise well (and we don't
     # have indices on all possible clauses). E.g. it may create
     # "room_id == X AND room_id != X", which postgres doesn't optimise.
@@ -260,7 +267,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
     __metaclass__ = abc.ABCMeta
 
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         super(StreamWorkerStore, self).__init__(database, db_conn, hs)
 
         self._instance_name = hs.get_instance_name()
@@ -293,16 +300,16 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         self._stream_order_on_start = self.get_room_max_stream_ordering()
 
     @abc.abstractmethod
-    def get_room_max_stream_ordering(self):
+    def get_room_max_stream_ordering(self) -> int:
         raise NotImplementedError()
 
     @abc.abstractmethod
-    def get_room_min_stream_ordering(self):
+    def get_room_min_stream_ordering(self) -> int:
         raise NotImplementedError()
 
     async def get_room_events_stream_for_rooms(
         self,
-        room_ids: Iterable[str],
+        room_ids: Collection[str],
         from_key: str,
         to_key: str,
         limit: int = 0,
@@ -356,19 +363,21 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
         return results
 
-    def get_rooms_that_changed(self, room_ids, from_key):
+    def get_rooms_that_changed(
+        self, room_ids: Collection[str], from_key: str
+    ) -> Set[str]:
         """Given a list of rooms and a token, return rooms where there may have
         been changes.
 
         Args:
-            room_ids (list)
-            from_key (str): The room_key portion of a StreamToken
+            room_ids
+            from_key: The room_key portion of a StreamToken
         """
-        from_key = RoomStreamToken.parse_stream_token(from_key).stream
+        from_id = RoomStreamToken.parse_stream_token(from_key).stream
         return {
             room_id
             for room_id in room_ids
-            if self._events_stream_cache.has_entity_changed(room_id, from_key)
+            if self._events_stream_cache.has_entity_changed(room_id, from_id)
         }
 
     async def get_room_events_stream_for_room(
@@ -440,7 +449,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
         return ret, key
 
-    async def get_membership_changes_for_user(self, user_id, from_key, to_key):
+    async def get_membership_changes_for_user(
+        self, user_id: str, from_key: str, to_key: str
+    ) -> List[EventBase]:
         from_id = RoomStreamToken.parse_stream_token(from_key).stream
         to_id = RoomStreamToken.parse_stream_token(to_key).stream
 
@@ -593,8 +604,19 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         Returns:
             A stream ID.
         """
-        return await self.db_pool.simple_select_one_onecol(
-            table="events", keyvalues={"event_id": event_id}, retcol="stream_ordering"
+        return await self.db_pool.runInteraction(
+            "get_stream_id_for_event", self.get_stream_id_for_event_txn, event_id,
+        )
+
+    def get_stream_id_for_event_txn(
+        self, txn: LoggingTransaction, event_id: str, allow_none: bool = False,
+    ) -> Optional[int]:
+        return self.db_pool.simple_select_one_onecol_txn(
+            txn=txn,
+            table="events",
+            keyvalues={"event_id": event_id},
+            retcol="stream_ordering",
+            allow_none=allow_none,
         )
 
     async def get_stream_token_for_event(self, event_id: str) -> str:
@@ -646,7 +668,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         )
         return row[0][0] if row else 0
 
-    def _get_max_topological_txn(self, txn, room_id):
+    def _get_max_topological_txn(self, txn: LoggingTransaction, room_id: str) -> int:
         txn.execute(
             "SELECT MAX(topological_ordering) FROM events WHERE room_id = ?",
             (room_id,),
@@ -719,7 +741,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
     def _get_events_around_txn(
         self,
-        txn,
+        txn: LoggingTransaction,
         room_id: str,
         event_id: str,
         before_limit: int,
@@ -747,6 +769,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             retcols=["stream_ordering", "topological_ordering"],
         )
 
+        # This cannot happen as `allow_none=False`.
+        assert results is not None
+
         # Paginating backwards includes the event at the token, but paginating
         # forward doesn't.
         before_token = RoomStreamToken(
@@ -856,7 +881,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             desc="update_federation_out_pos",
         )
 
-    def _reset_federation_positions_txn(self, txn) -> None:
+    def _reset_federation_positions_txn(self, txn: LoggingTransaction) -> None:
         """Fiddles with the `federation_stream_position` table to make it match
         the configured federation sender instances during start up.
         """
@@ -895,7 +920,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             GROUP BY type
         """
         txn.execute(sql)
-        min_positions = dict(txn)  # Map from type -> min position
+        min_positions = {typ: pos for typ, pos in txn}  # Map from type -> min position
 
         # Ensure we do actually have some values here
         assert set(min_positions) == {"federation", "events"}
@@ -922,7 +947,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
     def _paginate_room_events_txn(
         self,
-        txn,
+        txn: LoggingTransaction,
         room_id: str,
         from_token: RoomStreamToken,
         to_token: Optional[RoomStreamToken] = None,
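
The stream.py hunks above parse stream keys eagerly, and the
get_rooms_that_changed fix stops the parsed integer from shadowing the
incoming key. A toy model of the token handling, assuming the
"s<stream_ordering>" wire format used by room stream tokens
(ToyRoomStreamToken is an illustrative stand-in, not the real class):

    from typing import NamedTuple

    class ToyRoomStreamToken(NamedTuple):
        stream: int

        @classmethod
        def parse_stream_token(cls, key: str) -> "ToyRoomStreamToken":
            if not key.startswith("s"):
                raise ValueError("not a stream token: %r" % (key,))
            return cls(stream=int(key[1:]))

    from_key = "s1234"
    from_id = ToyRoomStreamToken.parse_stream_token(from_key).stream
    assert from_id == 1234
    # from_key keeps its original string form, so later code that needs the
    # token itself still works.
    assert from_key == "s1234"
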
diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py
index 0c34bbf21a..96ffe26cc9 100644
--- a/synapse/storage/databases/main/tags.py
+++ b/synapse/storage/databases/main/tags.py
@@ -43,7 +43,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
             "room_tags", {"user_id": user_id}, ["room_id", "tag", "content"]
         )
 
-        tags_by_room = {}
+        tags_by_room = {}  # type: Dict[str, Dict[str, JsonDict]]
         for row in rows:
             room_tags = tags_by_room.setdefault(row["room_id"], {})
             room_tags[row["tag"]] = db_to_json(row["content"])
@@ -123,7 +123,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
 
     async def get_updated_tags(
         self, user_id: str, stream_id: int
-    ) -> Dict[str, List[str]]:
+    ) -> Dict[str, Dict[str, JsonDict]]:
         """Get all the tags for the rooms where the tags have changed since the
         given version
 
diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py
index 9eef8e57c5..b89668d561 100644
--- a/synapse/storage/databases/main/ui_auth.py
+++ b/synapse/storage/databases/main/ui_auth.py
@@ -290,7 +290,7 @@ class UIAuthWorkerStore(SQLBaseStore):
 
 
 class UIAuthStore(UIAuthWorkerStore):
-    def delete_old_ui_auth_sessions(self, expiration_time: int):
+    async def delete_old_ui_auth_sessions(self, expiration_time: int) -> None:
         """
         Remove sessions which were last used earlier than the expiration time.
 
@@ -299,7 +299,7 @@ class UIAuthStore(UIAuthWorkerStore):
                 This is an epoch time in milliseconds.
 
         """
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "delete_old_ui_auth_sessions",
             self._delete_old_ui_auth_sessions_txn,
             expiration_time,
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index a9f2e93614..f2f9a5799a 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -15,7 +15,7 @@
 
 import logging
 import re
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Iterable, Optional, Set, Tuple
 
 from synapse.api.constants import EventTypes, JoinRules
 from synapse.storage.database import DatabasePool
@@ -365,10 +365,17 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
 
         return False
 
-    def update_profile_in_user_dir(self, user_id, display_name, avatar_url):
+    async def update_profile_in_user_dir(
+        self, user_id: str, display_name: Optional[str], avatar_url: Optional[str]
+    ) -> None:
         """
         Update or add a user's profile in the user directory.
         """
+        # If the display name or avatar URL are unexpected types, overwrite them.
+        if not isinstance(display_name, str):
+            display_name = None
+        if not isinstance(avatar_url, str):
+            avatar_url = None
 
         def _update_profile_in_user_dir_txn(txn):
             new_entry = self.db_pool.simple_upsert_txn(
@@ -458,17 +465,19 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
 
             txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
 
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "update_profile_in_user_dir", _update_profile_in_user_dir_txn
         )
 
-    def add_users_who_share_private_room(self, room_id, user_id_tuples):
+    async def add_users_who_share_private_room(
+        self, room_id: str, user_id_tuples: Iterable[Tuple[str, str]]
+    ) -> None:
         """Insert entries into the users_who_share_private_rooms table. The first
         user should be a local user.
 
         Args:
-            room_id (str)
-            user_id_tuples([(str, str)]): iterable of 2-tuple of user IDs.
+            room_id
+            user_id_tuples: iterable of 2-tuples of user IDs.
         """
 
         def _add_users_who_share_room_txn(txn):
@@ -484,17 +493,19 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
                 value_values=None,
             )
 
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "add_users_who_share_room", _add_users_who_share_room_txn
         )
 
-    def add_users_in_public_rooms(self, room_id, user_ids):
+    async def add_users_in_public_rooms(
+        self, room_id: str, user_ids: Iterable[str]
+    ) -> None:
         """Insert entries into the users_who_share_private_rooms table. The first
         user should be a local user.
 
         Args:
-            room_id (str)
-            user_ids (list[str])
+            room_id
+            user_ids
         """
 
         def _add_users_in_public_rooms_txn(txn):
@@ -508,11 +519,11 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
                 value_values=None,
             )
 
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "add_users_in_public_rooms", _add_users_in_public_rooms_txn
         )
 
-    def delete_all_from_user_dir(self):
+    async def delete_all_from_user_dir(self) -> None:
         """Delete the entire user directory
         """
 
@@ -523,7 +534,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
             txn.execute("DELETE FROM users_who_share_private_rooms")
             txn.call_after(self.get_user_in_directory.invalidate_all)
 
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "delete_all_from_user_dir", _delete_all_from_user_dir_txn
         )
 
@@ -555,7 +566,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
         super(UserDirectoryStore, self).__init__(database, db_conn, hs)
 
-    def remove_from_user_dir(self, user_id):
+    async def remove_from_user_dir(self, user_id: str) -> None:
         def _remove_from_user_dir_txn(txn):
             self.db_pool.simple_delete_txn(
                 txn, table="user_directory", keyvalues={"user_id": user_id}
@@ -578,7 +589,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
             )
             txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
 
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "remove_from_user_dir", _remove_from_user_dir_txn
         )
 
@@ -605,14 +616,14 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
 
         return user_ids
 
-    def remove_user_who_share_room(self, user_id, room_id):
+    async def remove_user_who_share_room(self, user_id: str, room_id: str) -> None:
         """
         Deletes entries in the users_who_share_*_rooms table. The first
         user should be a local user.
 
         Args:
-            user_id (str)
-            room_id (str)
+            user_id
+            room_id
         """
 
         def _remove_user_who_share_room_txn(txn):
@@ -632,7 +643,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
                 keyvalues={"user_id": user_id, "room_id": room_id},
             )
 
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "remove_user_who_share_room", _remove_user_who_share_room_txn
         )
 
@@ -664,6 +675,48 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
         users.update(rows)
         return list(users)
 
+    @cached()
+    async def get_shared_rooms_for_users(
+        self, user_id: str, other_user_id: str
+    ) -> Set[str]:
+        """
+        Returns the rooms that a local user shares with another local or remote user.
+
+        Args:
+            user_id: The MXID of a local user
+            other_user_id: The MXID of the other user
+
+        Returns:
+            A set of room IDs that the users share.
+        """
+
+        def _get_shared_rooms_for_users_txn(txn):
+            txn.execute(
+                """
+                SELECT p1.room_id
+                FROM users_in_public_rooms as p1
+                INNER JOIN users_in_public_rooms as p2
+                    ON p1.room_id = p2.room_id
+                    AND p1.user_id = ?
+                    AND p2.user_id = ?
+                UNION
+                SELECT room_id
+                FROM users_who_share_private_rooms
+                WHERE
+                    user_id = ?
+                    AND other_user_id = ?
+                """,
+                (user_id, other_user_id, user_id, other_user_id),
+            )
+            rows = self.db_pool.cursor_to_dict(txn)
+            return rows
+
+        rows = await self.db_pool.runInteraction(
+            "get_shared_rooms_for_users", _get_shared_rooms_for_users_txn
+        )
+
+        return {row["room_id"] for row in rows}
+
     async def get_user_directory_stream_pos(self) -> int:
         return await self.db_pool.simple_select_one_onecol(
             table="user_directory_stream_pos",
diff --git a/synapse/storage/databases/main/user_erasure_store.py b/synapse/storage/databases/main/user_erasure_store.py
index e3547e53b3..2f7c95fc74 100644
--- a/synapse/storage/databases/main/user_erasure_store.py
+++ b/synapse/storage/databases/main/user_erasure_store.py
@@ -66,7 +66,7 @@ class UserErasureWorkerStore(SQLBaseStore):
 
 
 class UserErasureStore(UserErasureWorkerStore):
-    def mark_user_erased(self, user_id: str) -> None:
+    async def mark_user_erased(self, user_id: str) -> None:
         """Indicate that user_id wishes their message history to be erased.
 
         Args:
@@ -84,9 +84,9 @@ class UserErasureStore(UserErasureWorkerStore):
 
             self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,))
 
-        return self.db_pool.runInteraction("mark_user_erased", f)
+        await self.db_pool.runInteraction("mark_user_erased", f)
 
-    def mark_user_not_erased(self, user_id: str) -> None:
+    async def mark_user_not_erased(self, user_id: str) -> None:
         """Indicate that user_id is no longer erased.
 
         Args:
@@ -106,4 +106,4 @@ class UserErasureStore(UserErasureWorkerStore):
 
             self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,))
 
-        return self.db_pool.runInteraction("mark_user_not_erased", f)
+        await self.db_pool.runInteraction("mark_user_not_erased", f)
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index b27a4843d0..8fd21c2bf8 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -185,6 +185,8 @@ class MultiWriterIdGenerator:
         id_column: Column that stores the stream ID.
         sequence_name: The name of the postgres sequence used to generate new
             IDs.
+        positive: Whether the IDs are positive (true) or negative (false).
+            When using negative IDs we go backwards from -1 to -2, -3, etc.
     """
 
     def __init__(
@@ -196,13 +198,19 @@ class MultiWriterIdGenerator:
         instance_column: str,
         id_column: str,
         sequence_name: str,
+        positive: bool = True,
     ):
         self._db = db
         self._instance_name = instance_name
+        self._positive = positive
+        self._return_factor = 1 if positive else -1
 
         # We lock as some functions may be called from DB threads.
         self._lock = threading.Lock()
 
+        # Note: If we are a negative stream then we still store all the IDs as
+        # positive to make life easier for us, and simply negate the IDs when we
+        # return them.
         self._current_positions = self._load_current_ids(
             db_conn, table, instance_column, id_column
         )
@@ -223,8 +231,12 @@ class MultiWriterIdGenerator:
         # gaps should be relatively rare it's still worth doing the book keeping
         # that allows us to skip forwards when there are gapless runs of
         # positions.
+        #
+        # We start at 1 here as a) the first generated stream ID will be 2, and
+        # b) other parts of the code assume that stream IDs are strictly greater
+        # than 0.
         self._persisted_upto_position = (
-            min(self._current_positions.values()) if self._current_positions else 0
+            min(self._current_positions.values()) if self._current_positions else 1
         )
         self._known_persisted_positions = []  # type: List[int]
 
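
A minimal sketch of the gapless-run bookkeeping described above: finished
positions can arrive out of order, and the persisted-upto marker only advances
while the next integer is known. This illustrates the idea rather than the
class's exact implementation:

def advance_persisted_upto(persisted_upto, known, new_id):
    known.add(new_id)
    # Consume a gapless run starting just above the current marker.
    while persisted_upto + 1 in known:
        persisted_upto += 1
        known.discard(persisted_upto)
    return persisted_upto

pos, known = 1, set()
for completed in (3, 2, 5, 4):  # IDs finish persisting out of order
    pos = advance_persisted_upto(pos, known, completed)
print(pos)  # 5: once 2..5 form a gapless run above 1, the marker skips ahead
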
@@ -233,13 +245,16 @@ class MultiWriterIdGenerator:
     def _load_current_ids(
         self, db_conn, table: str, instance_column: str, id_column: str
     ) -> Dict[str, int]:
+        # If this is a positive stream, aggregate via MAX; for a negative
+        # stream, use MIN *and* negate the result to get a positive number.
         sql = """
-            SELECT %(instance)s, MAX(%(id)s) FROM %(table)s
+            SELECT %(instance)s, %(agg)s(%(id)s) FROM %(table)s
             GROUP BY %(instance)s
         """ % {
             "instance": instance_column,
             "id": id_column,
             "table": table,
+            "agg": "MAX" if self._positive else "-MIN",
         }
 
         cur = db_conn.cursor()
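
For illustration, this is the SQL the template above renders in each mode. The
"-MIN" trick relies on SQL accepting a unary minus before an aggregate, so the
most negative stored ID comes back as the largest positive position. Table and
column names here are made up:

template = """
    SELECT %(instance)s, %(agg)s(%(id)s) FROM %(table)s
    GROUP BY %(instance)s
"""
for positive in (True, False):
    print(template % {
        "instance": "instance_name",
        "id": "stream_id",
        "table": "events",
        "agg": "MAX" if positive else "-MIN",
    })
# positive=True  -> SELECT instance_name, MAX(stream_id) FROM events ...
# positive=False -> SELECT instance_name, -MIN(stream_id) FROM events ...
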
@@ -269,15 +284,16 @@ class MultiWriterIdGenerator:
         # Assert the fetched ID is actually greater than what we currently
         # believe the ID to be. If not, then the sequence and table have got
         # out of sync somehow.
-        assert self.get_current_token_for_writer(self._instance_name) < next_id
-
         with self._lock:
+            assert self._current_positions.get(self._instance_name, 0) < next_id
+
             self._unfinished_ids.add(next_id)
 
         @contextlib.contextmanager
         def manager():
             try:
-                yield next_id
+                # Multiply by the return factor so that the ID has the correct sign.
+                yield self._return_factor * next_id
             finally:
                 self._mark_id_as_finished(next_id)
 
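
A toy version of the manager() pattern above: the caller only ever sees the
signed ID, while the finally block marks the raw positive ID as finished even
if the caller's block raises. The names here are illustrative:

import contextlib

finished = []

@contextlib.contextmanager
def get_next(next_id, return_factor):
    try:
        yield return_factor * next_id  # caller sees e.g. -7 on a negative stream
    finally:
        finished.append(next_id)       # bookkeeping always uses the positive ID

with get_next(7, -1) as stream_id:
    assert stream_id == -7
assert finished == [7]
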
@@ -296,15 +312,15 @@ class MultiWriterIdGenerator:
         # Assert the fetched ID is actually greater than any ID we've already
         # seen. If not, then the sequence and table have got out of sync
         # somehow.
-        assert max(self.get_positions().values(), default=0) < min(next_ids)
-
         with self._lock:
+            assert max(self._current_positions.values(), default=0) < min(next_ids)
+
             self._unfinished_ids.update(next_ids)
 
         @contextlib.contextmanager
         def manager():
             try:
-                yield next_ids
+                yield [self._return_factor * i for i in next_ids]
             finally:
                 for i in next_ids:
                     self._mark_id_as_finished(i)
@@ -327,7 +343,7 @@ class MultiWriterIdGenerator:
         txn.call_after(self._mark_id_as_finished, next_id)
         txn.call_on_exception(self._mark_id_as_finished, next_id)
 
-        return next_id
+        return self._return_factor * next_id
 
     def _mark_id_as_finished(self, next_id: int):
         """The ID has finished being processed so we should advance the
@@ -350,29 +366,32 @@ class MultiWriterIdGenerator:
         equal to it have been successfully persisted.
         """
 
-        # Currently we don't support this operation, as it's not obvious how to
-        # condense the stream positions of multiple writers into a single int.
-        raise NotImplementedError()
+        return self.get_persisted_upto_position()
 
     def get_current_token_for_writer(self, instance_name: str) -> int:
         """Returns the position of the given writer.
         """
 
         with self._lock:
-            return self._current_positions.get(instance_name, 0)
+            return self._return_factor * self._current_positions.get(instance_name, 0)
 
     def get_positions(self) -> Dict[str, int]:
         """Get a copy of the current positon map.
         """
 
         with self._lock:
-            return dict(self._current_positions)
+            return {
+                name: self._return_factor * i
+                for name, i in self._current_positions.items()
+            }
 
     def advance(self, instance_name: str, new_id: int):
         """Advance the postion of the named writer to the given ID, if greater
         than existing entry.
         """
 
+        new_id *= self._return_factor
+
         with self._lock:
             self._current_positions[instance_name] = max(
                 new_id, self._current_positions.get(instance_name, 0)
@@ -390,7 +409,7 @@ class MultiWriterIdGenerator:
         """
 
         with self._lock:
-            return self._persisted_upto_position
+            return self._return_factor * self._persisted_upto_position
 
     def _add_persisted_position(self, new_id: int):
         """Record that we have persisted a position.
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index f562770922..dfefbd996d 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -20,6 +20,7 @@ from contextlib import contextmanager
 from typing import Dict, Sequence, Set, Union
 
 import attr
+from typing_extensions import ContextManager
 
 from twisted.internet import defer
 from twisted.internet.defer import CancelledError
@@ -338,11 +339,11 @@ class Linearizer(object):
 
 
 class ReadWriteLock(object):
-    """A deferred style read write lock.
+    """An async read write lock.
 
     Example:
 
-        with (yield read_write_lock.read("test_key")):
+        with await read_write_lock.read("test_key"):
             # do some work
     """
 
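
With read() and write() now coroutines that return a context manager, a caller
awaits to acquire (or queue for) the lock and then scopes the critical section
with a plain "with" block. A sketch of both paths, with an illustrative key:

lock = ReadWriteLock()

async def read_path():
    with await lock.read("some-key"):
        ...  # concurrent readers allowed; queued behind any pending writer

async def write_path():
    with await lock.write("some-key"):
        ...  # exclusive: waits for current readers and earlier writers
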
@@ -365,8 +366,7 @@ class ReadWriteLock(object):
         # Latest writer queued
         self.key_to_current_writer = {}  # type: Dict[str, defer.Deferred]
 
-    @defer.inlineCallbacks
-    def read(self, key):
+    async def read(self, key: str) -> ContextManager:
         new_defer = defer.Deferred()
 
         curr_readers = self.key_to_current_readers.setdefault(key, set())
@@ -376,7 +376,8 @@ class ReadWriteLock(object):
 
         # We wait for the latest writer to finish writing. We can safely ignore
         # any existing readers... as they're readers.
-        yield make_deferred_yieldable(curr_writer)
+        if curr_writer:
+            await make_deferred_yieldable(curr_writer)
 
         @contextmanager
         def _ctx_manager():
@@ -388,8 +389,7 @@ class ReadWriteLock(object):
 
         return _ctx_manager()
 
-    @defer.inlineCallbacks
-    def write(self, key):
+    async def write(self, key: str) -> ContextManager:
         new_defer = defer.Deferred()
 
         curr_readers = self.key_to_current_readers.get(key, set())
@@ -405,7 +405,7 @@ class ReadWriteLock(object):
         curr_readers.clear()
         self.key_to_current_writer[key] = new_defer
 
-        yield make_deferred_yieldable(defer.gatherResults(to_wait_on))
+        await make_deferred_yieldable(defer.gatherResults(to_wait_on))
 
         @contextmanager
         def _ctx_manager():