summary refs log tree commit diff
path: root/synapse/handlers
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers')
-rw-r--r--synapse/handlers/_base.py2
-rw-r--r--synapse/handlers/account_data.py15
-rw-r--r--synapse/handlers/account_validity.py86
-rw-r--r--synapse/handlers/admin.py41
-rw-r--r--synapse/handlers/deactivate_account.py54
-rw-r--r--synapse/handlers/e2e_keys.py35
-rw-r--r--synapse/handlers/federation.py221
-rw-r--r--synapse/handlers/initial_sync.py115
-rw-r--r--synapse/handlers/message.py15
-rw-r--r--synapse/handlers/pagination.py27
-rw-r--r--synapse/handlers/presence.py2
-rw-r--r--synapse/handlers/profile.py6
-rw-r--r--synapse/handlers/room.py14
-rw-r--r--synapse/handlers/room_member.py4
-rw-r--r--synapse/handlers/saml_handler.py198
-rw-r--r--synapse/handlers/search.py34
-rw-r--r--synapse/handlers/typing.py3
17 files changed, 497 insertions, 375 deletions
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index d15c6282fb..51413d910e 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -134,7 +134,7 @@ class BaseHandler(object):
             guest_access = event.content.get("guest_access", "forbidden")
             if guest_access != "can_join":
                 if context:
-                    current_state_ids = yield context.get_current_state_ids(self.store)
+                    current_state_ids = yield context.get_current_state_ids()
                     current_state = yield self.store.get_events(
                         list(current_state_ids.values())
                     )
diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py
index 2d7e6df6e4..a8d3fbc6de 100644
--- a/synapse/handlers/account_data.py
+++ b/synapse/handlers/account_data.py
@@ -13,8 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from twisted.internet import defer
-
 
 class AccountDataEventSource(object):
     def __init__(self, hs):
@@ -23,15 +21,14 @@ class AccountDataEventSource(object):
     def get_current_key(self, direction="f"):
         return self.store.get_max_account_data_stream_id()
 
-    @defer.inlineCallbacks
-    def get_new_events(self, user, from_key, **kwargs):
+    async def get_new_events(self, user, from_key, **kwargs):
         user_id = user.to_string()
         last_stream_id = from_key
 
-        current_stream_id = yield self.store.get_max_account_data_stream_id()
+        current_stream_id = self.store.get_max_account_data_stream_id()
 
         results = []
-        tags = yield self.store.get_updated_tags(user_id, last_stream_id)
+        tags = await self.store.get_updated_tags(user_id, last_stream_id)
 
         for room_id, room_tags in tags.items():
             results.append(
@@ -41,7 +38,7 @@ class AccountDataEventSource(object):
         (
             account_data,
             room_account_data,
-        ) = yield self.store.get_updated_account_data_for_user(user_id, last_stream_id)
+        ) = await self.store.get_updated_account_data_for_user(user_id, last_stream_id)
 
         for account_data_type, content in account_data.items():
             results.append({"type": account_data_type, "content": content})
@@ -53,7 +50,3 @@ class AccountDataEventSource(object):
                 )
 
         return results, current_stream_id
-
-    @defer.inlineCallbacks
-    def get_pagination_rows(self, user, config, key):
-        return [], config.to_id
diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py
index d04e0fe576..829f52eca1 100644
--- a/synapse/handlers/account_validity.py
+++ b/synapse/handlers/account_validity.py
@@ -18,8 +18,7 @@ import email.utils
 import logging
 from email.mime.multipart import MIMEMultipart
 from email.mime.text import MIMEText
-
-from twisted.internet import defer
+from typing import List
 
 from synapse.api.errors import StoreError
 from synapse.logging.context import make_deferred_yieldable
@@ -78,42 +77,39 @@ class AccountValidityHandler(object):
                 # run as a background process to make sure that the database transactions
                 # have a logcontext to report to
                 return run_as_background_process(
-                    "send_renewals", self.send_renewal_emails
+                    "send_renewals", self._send_renewal_emails
                 )
 
             self.clock.looping_call(send_emails, 30 * 60 * 1000)
 
-    @defer.inlineCallbacks
-    def send_renewal_emails(self):
+    async def _send_renewal_emails(self):
         """Gets the list of users whose account is expiring in the amount of time
         configured in the ``renew_at`` parameter from the ``account_validity``
         configuration, and sends renewal emails to all of these users as long as they
         have an email 3PID attached to their account.
         """
-        expiring_users = yield self.store.get_users_expiring_soon()
+        expiring_users = await self.store.get_users_expiring_soon()
 
         if expiring_users:
             for user in expiring_users:
-                yield self._send_renewal_email(
+                await self._send_renewal_email(
                     user_id=user["user_id"], expiration_ts=user["expiration_ts_ms"]
                 )
 
-    @defer.inlineCallbacks
-    def send_renewal_email_to_user(self, user_id):
-        expiration_ts = yield self.store.get_expiration_ts_for_user(user_id)
-        yield self._send_renewal_email(user_id, expiration_ts)
+    async def send_renewal_email_to_user(self, user_id: str):
+        expiration_ts = await self.store.get_expiration_ts_for_user(user_id)
+        await self._send_renewal_email(user_id, expiration_ts)
 
-    @defer.inlineCallbacks
-    def _send_renewal_email(self, user_id, expiration_ts):
+    async def _send_renewal_email(self, user_id: str, expiration_ts: int):
         """Sends out a renewal email to every email address attached to the given user
         with a unique link allowing them to renew their account.
 
         Args:
-            user_id (str): ID of the user to send email(s) to.
-            expiration_ts (int): Timestamp in milliseconds for the expiration date of
+            user_id: ID of the user to send email(s) to.
+            expiration_ts: Timestamp in milliseconds for the expiration date of
                 this user's account (used in the email templates).
         """
-        addresses = yield self._get_email_addresses_for_user(user_id)
+        addresses = await self._get_email_addresses_for_user(user_id)
 
         # Stop right here if the user doesn't have at least one email address.
         # In this case, they will have to ask their server admin to renew their
@@ -125,7 +121,7 @@ class AccountValidityHandler(object):
             return
 
         try:
-            user_display_name = yield self.store.get_profile_displayname(
+            user_display_name = await self.store.get_profile_displayname(
                 UserID.from_string(user_id).localpart
             )
             if user_display_name is None:
@@ -133,7 +129,7 @@ class AccountValidityHandler(object):
         except StoreError:
             user_display_name = user_id
 
-        renewal_token = yield self._get_renewal_token(user_id)
+        renewal_token = await self._get_renewal_token(user_id)
         url = "%s_matrix/client/unstable/account_validity/renew?token=%s" % (
             self.hs.config.public_baseurl,
             renewal_token,
@@ -165,7 +161,7 @@ class AccountValidityHandler(object):
 
             logger.info("Sending renewal email to %s", address)
 
-            yield make_deferred_yieldable(
+            await make_deferred_yieldable(
                 self.sendmail(
                     self.hs.config.email_smtp_host,
                     self._raw_from,
@@ -180,19 +176,18 @@ class AccountValidityHandler(object):
                 )
             )
 
-        yield self.store.set_renewal_mail_status(user_id=user_id, email_sent=True)
+        await self.store.set_renewal_mail_status(user_id=user_id, email_sent=True)
 
-    @defer.inlineCallbacks
-    def _get_email_addresses_for_user(self, user_id):
+    async def _get_email_addresses_for_user(self, user_id: str) -> List[str]:
         """Retrieve the list of email addresses attached to a user's account.
 
         Args:
-            user_id (str): ID of the user to lookup email addresses for.
+            user_id: ID of the user to lookup email addresses for.
 
         Returns:
-            defer.Deferred[list[str]]: Email addresses for this account.
+            Email addresses for this account.
         """
-        threepids = yield self.store.user_get_threepids(user_id)
+        threepids = await self.store.user_get_threepids(user_id)
 
         addresses = []
         for threepid in threepids:
@@ -201,16 +196,15 @@ class AccountValidityHandler(object):
 
         return addresses
 
-    @defer.inlineCallbacks
-    def _get_renewal_token(self, user_id):
+    async def _get_renewal_token(self, user_id: str) -> str:
         """Generates a 32-byte long random string that will be inserted into the
         user's renewal email's unique link, then saves it into the database.
 
         Args:
-            user_id (str): ID of the user to generate a string for.
+            user_id: ID of the user to generate a string for.
 
         Returns:
-            defer.Deferred[str]: The generated string.
+            The generated string.
 
         Raises:
             StoreError(500): Couldn't generate a unique string after 5 attempts.
@@ -219,52 +213,52 @@ class AccountValidityHandler(object):
         while attempts < 5:
             try:
                 renewal_token = stringutils.random_string(32)
-                yield self.store.set_renewal_token_for_user(user_id, renewal_token)
+                await self.store.set_renewal_token_for_user(user_id, renewal_token)
                 return renewal_token
             except StoreError:
                 attempts += 1
         raise StoreError(500, "Couldn't generate a unique string as refresh string.")
 
-    @defer.inlineCallbacks
-    def renew_account(self, renewal_token):
+    async def renew_account(self, renewal_token: str) -> bool:
         """Renews the account attached to a given renewal token by pushing back the
         expiration date by the current validity period in the server's configuration.
 
         Args:
-            renewal_token (str): Token sent with the renewal request.
+            renewal_token: Token sent with the renewal request.
         Returns:
-            bool: Whether the provided token is valid.
+            Whether the provided token is valid.
         """
         try:
-            user_id = yield self.store.get_user_from_renewal_token(renewal_token)
+            user_id = await self.store.get_user_from_renewal_token(renewal_token)
         except StoreError:
-            defer.returnValue(False)
+            return False
 
         logger.debug("Renewing an account for user %s", user_id)
-        yield self.renew_account_for_user(user_id)
+        await self.renew_account_for_user(user_id)
 
-        defer.returnValue(True)
+        return True
 
-    @defer.inlineCallbacks
-    def renew_account_for_user(self, user_id, expiration_ts=None, email_sent=False):
+    async def renew_account_for_user(
+        self, user_id: str, expiration_ts: int = None, email_sent: bool = False
+    ) -> int:
         """Renews the account attached to a given user by pushing back the
         expiration date by the current validity period in the server's
         configuration.
 
         Args:
-            renewal_token (str): Token sent with the renewal request.
-            expiration_ts (int): New expiration date. Defaults to now + validity period.
-            email_sent (bool): Whether an email has been sent for this validity period.
+            renewal_token: Token sent with the renewal request.
+            expiration_ts: New expiration date. Defaults to now + validity period.
+            email_sen: Whether an email has been sent for this validity period.
                 Defaults to False.
 
         Returns:
-            defer.Deferred[int]: New expiration date for this account, as a timestamp
-                in milliseconds since epoch.
+            New expiration date for this account, as a timestamp in
+            milliseconds since epoch.
         """
         if expiration_ts is None:
             expiration_ts = self.clock.time_msec() + self._account_validity.period
 
-        yield self.store.set_account_validity_for_user(
+        await self.store.set_account_validity_for_user(
             user_id=user_id, expiration_ts=expiration_ts, email_sent=email_sent
         )
 
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index 14449b9a1e..1a4ba12385 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -15,8 +15,6 @@
 
 import logging
 
-from twisted.internet import defer
-
 from synapse.api.constants import Membership
 from synapse.types import RoomStreamToken
 from synapse.visibility import filter_events_for_client
@@ -33,11 +31,10 @@ class AdminHandler(BaseHandler):
         self.storage = hs.get_storage()
         self.state_store = self.storage.state
 
-    @defer.inlineCallbacks
-    def get_whois(self, user):
+    async def get_whois(self, user):
         connections = []
 
-        sessions = yield self.store.get_user_ip_and_agents(user)
+        sessions = await self.store.get_user_ip_and_agents(user)
         for session in sessions:
             connections.append(
                 {
@@ -54,20 +51,18 @@ class AdminHandler(BaseHandler):
 
         return ret
 
-    @defer.inlineCallbacks
-    def get_users(self):
+    async def get_users(self):
         """Function to retrieve a list of users in users table.
 
         Args:
         Returns:
             defer.Deferred: resolves to list[dict[str, Any]]
         """
-        ret = yield self.store.get_users()
+        ret = await self.store.get_users()
 
         return ret
 
-    @defer.inlineCallbacks
-    def get_users_paginate(self, start, limit, name, guests, deactivated):
+    async def get_users_paginate(self, start, limit, name, guests, deactivated):
         """Function to retrieve a paginated list of users from
         users list. This will return a json list of users.
 
@@ -80,14 +75,13 @@ class AdminHandler(BaseHandler):
         Returns:
             defer.Deferred: resolves to json list[dict[str, Any]]
         """
-        ret = yield self.store.get_users_paginate(
+        ret = await self.store.get_users_paginate(
             start, limit, name, guests, deactivated
         )
 
         return ret
 
-    @defer.inlineCallbacks
-    def search_users(self, term):
+    async def search_users(self, term):
         """Function to search users list for one or more users with
         the matched term.
 
@@ -96,7 +90,7 @@ class AdminHandler(BaseHandler):
         Returns:
             defer.Deferred: resolves to list[dict[str, Any]]
         """
-        ret = yield self.store.search_users(term)
+        ret = await self.store.search_users(term)
 
         return ret
 
@@ -119,8 +113,7 @@ class AdminHandler(BaseHandler):
         """
         return self.store.set_server_admin(user, admin)
 
-    @defer.inlineCallbacks
-    def export_user_data(self, user_id, writer):
+    async def export_user_data(self, user_id, writer):
         """Write all data we have on the user to the given writer.
 
         Args:
@@ -132,7 +125,7 @@ class AdminHandler(BaseHandler):
             The returned value is that returned by `writer.finished()`.
         """
         # Get all rooms the user is in or has been in
-        rooms = yield self.store.get_rooms_for_user_where_membership_is(
+        rooms = await self.store.get_rooms_for_user_where_membership_is(
             user_id,
             membership_list=(
                 Membership.JOIN,
@@ -145,7 +138,7 @@ class AdminHandler(BaseHandler):
         # We only try and fetch events for rooms the user has been in. If
         # they've been e.g. invited to a room without joining then we handle
         # those seperately.
-        rooms_user_has_been_in = yield self.store.get_rooms_user_has_been_in(user_id)
+        rooms_user_has_been_in = await self.store.get_rooms_user_has_been_in(user_id)
 
         for index, room in enumerate(rooms):
             room_id = room.room_id
@@ -154,7 +147,7 @@ class AdminHandler(BaseHandler):
                 "[%s] Handling room %s, %d/%d", user_id, room_id, index + 1, len(rooms)
             )
 
-            forgotten = yield self.store.did_forget(user_id, room_id)
+            forgotten = await self.store.did_forget(user_id, room_id)
             if forgotten:
                 logger.info("[%s] User forgot room %d, ignoring", user_id, room_id)
                 continue
@@ -166,7 +159,7 @@ class AdminHandler(BaseHandler):
 
                 if room.membership == Membership.INVITE:
                     event_id = room.event_id
-                    invite = yield self.store.get_event(event_id, allow_none=True)
+                    invite = await self.store.get_event(event_id, allow_none=True)
                     if invite:
                         invited_state = invite.unsigned["invite_room_state"]
                         writer.write_invite(room_id, invite, invited_state)
@@ -177,7 +170,7 @@ class AdminHandler(BaseHandler):
             # were joined. We estimate that point by looking at the
             # stream_ordering of the last membership if it wasn't a join.
             if room.membership == Membership.JOIN:
-                stream_ordering = yield self.store.get_room_max_stream_ordering()
+                stream_ordering = self.store.get_room_max_stream_ordering()
             else:
                 stream_ordering = room.stream_ordering
 
@@ -203,7 +196,7 @@ class AdminHandler(BaseHandler):
             # events that we have and then filtering, this isn't the most
             # efficient method perhaps but it does guarantee we get everything.
             while True:
-                events, _ = yield self.store.paginate_room_events(
+                events, _ = await self.store.paginate_room_events(
                     room_id, from_key, to_key, limit=100, direction="f"
                 )
                 if not events:
@@ -211,7 +204,7 @@ class AdminHandler(BaseHandler):
 
                 from_key = events[-1].internal_metadata.after
 
-                events = yield filter_events_for_client(self.storage, user_id, events)
+                events = await filter_events_for_client(self.storage, user_id, events)
 
                 writer.write_events(room_id, events)
 
@@ -247,7 +240,7 @@ class AdminHandler(BaseHandler):
             for event_id in extremities:
                 if not event_to_unseen_prevs[event_id]:
                     continue
-                state = yield self.state_store.get_state_for_event(event_id)
+                state = await self.state_store.get_state_for_event(event_id)
                 writer.write_state(room_id, event_id, state)
 
         return writer.finished()
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index 6dedaaff8d..4426967f88 100644
--- a/synapse/handlers/deactivate_account.py
+++ b/synapse/handlers/deactivate_account.py
@@ -15,8 +15,6 @@
 # limitations under the License.
 import logging
 
-from twisted.internet import defer
-
 from synapse.api.errors import SynapseError
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.types import UserID, create_requester
@@ -46,8 +44,7 @@ class DeactivateAccountHandler(BaseHandler):
 
         self._account_validity_enabled = hs.config.account_validity.enabled
 
-    @defer.inlineCallbacks
-    def deactivate_account(self, user_id, erase_data, id_server=None):
+    async def deactivate_account(self, user_id, erase_data, id_server=None):
         """Deactivate a user's account
 
         Args:
@@ -74,11 +71,11 @@ class DeactivateAccountHandler(BaseHandler):
         identity_server_supports_unbinding = True
 
         # Retrieve the 3PIDs this user has bound to an identity server
-        threepids = yield self.store.user_get_bound_threepids(user_id)
+        threepids = await self.store.user_get_bound_threepids(user_id)
 
         for threepid in threepids:
             try:
-                result = yield self._identity_handler.try_unbind_threepid(
+                result = await self._identity_handler.try_unbind_threepid(
                     user_id,
                     {
                         "medium": threepid["medium"],
@@ -91,33 +88,33 @@ class DeactivateAccountHandler(BaseHandler):
                 # Do we want this to be a fatal error or should we carry on?
                 logger.exception("Failed to remove threepid from ID server")
                 raise SynapseError(400, "Failed to remove threepid from ID server")
-            yield self.store.user_delete_threepid(
+            await self.store.user_delete_threepid(
                 user_id, threepid["medium"], threepid["address"]
             )
 
         # Remove all 3PIDs this user has bound to the homeserver
-        yield self.store.user_delete_threepids(user_id)
+        await self.store.user_delete_threepids(user_id)
 
         # delete any devices belonging to the user, which will also
         # delete corresponding access tokens.
-        yield self._device_handler.delete_all_devices_for_user(user_id)
+        await self._device_handler.delete_all_devices_for_user(user_id)
         # then delete any remaining access tokens which weren't associated with
         # a device.
-        yield self._auth_handler.delete_access_tokens_for_user(user_id)
+        await self._auth_handler.delete_access_tokens_for_user(user_id)
 
-        yield self.store.user_set_password_hash(user_id, None)
+        await self.store.user_set_password_hash(user_id, None)
 
         # Add the user to a table of users pending deactivation (ie.
         # removal from all the rooms they're a member of)
-        yield self.store.add_user_pending_deactivation(user_id)
+        await self.store.add_user_pending_deactivation(user_id)
 
         # delete from user directory
-        yield self.user_directory_handler.handle_user_deactivated(user_id)
+        await self.user_directory_handler.handle_user_deactivated(user_id)
 
         # Mark the user as erased, if they asked for that
         if erase_data:
             logger.info("Marking %s as erased", user_id)
-            yield self.store.mark_user_erased(user_id)
+            await self.store.mark_user_erased(user_id)
 
         # Now start the process that goes through that list and
         # parts users from rooms (if it isn't already running)
@@ -125,30 +122,29 @@ class DeactivateAccountHandler(BaseHandler):
 
         # Reject all pending invites for the user, so that the user doesn't show up in the
         # "invited" section of rooms' members list.
-        yield self._reject_pending_invites_for_user(user_id)
+        await self._reject_pending_invites_for_user(user_id)
 
         # Remove all information on the user from the account_validity table.
         if self._account_validity_enabled:
-            yield self.store.delete_account_validity_for_user(user_id)
+            await self.store.delete_account_validity_for_user(user_id)
 
         # Mark the user as deactivated.
-        yield self.store.set_user_deactivated_status(user_id, True)
+        await self.store.set_user_deactivated_status(user_id, True)
 
         return identity_server_supports_unbinding
 
-    @defer.inlineCallbacks
-    def _reject_pending_invites_for_user(self, user_id):
+    async def _reject_pending_invites_for_user(self, user_id):
         """Reject pending invites addressed to a given user ID.
 
         Args:
             user_id (str): The user ID to reject pending invites for.
         """
         user = UserID.from_string(user_id)
-        pending_invites = yield self.store.get_invited_rooms_for_user(user_id)
+        pending_invites = await self.store.get_invited_rooms_for_user(user_id)
 
         for room in pending_invites:
             try:
-                yield self._room_member_handler.update_membership(
+                await self._room_member_handler.update_membership(
                     create_requester(user),
                     user,
                     room.room_id,
@@ -180,8 +176,7 @@ class DeactivateAccountHandler(BaseHandler):
         if not self._user_parter_running:
             run_as_background_process("user_parter_loop", self._user_parter_loop)
 
-    @defer.inlineCallbacks
-    def _user_parter_loop(self):
+    async def _user_parter_loop(self):
         """Loop that parts deactivated users from rooms
 
         Returns:
@@ -191,19 +186,18 @@ class DeactivateAccountHandler(BaseHandler):
         logger.info("Starting user parter")
         try:
             while True:
-                user_id = yield self.store.get_user_pending_deactivation()
+                user_id = await self.store.get_user_pending_deactivation()
                 if user_id is None:
                     break
                 logger.info("User parter parting %r", user_id)
-                yield self._part_user(user_id)
-                yield self.store.del_user_pending_deactivation(user_id)
+                await self._part_user(user_id)
+                await self.store.del_user_pending_deactivation(user_id)
                 logger.info("User parter finished parting %r", user_id)
             logger.info("User parter finished: stopping")
         finally:
             self._user_parter_running = False
 
-    @defer.inlineCallbacks
-    def _part_user(self, user_id):
+    async def _part_user(self, user_id):
         """Causes the given user_id to leave all the rooms they're joined to
 
         Returns:
@@ -211,11 +205,11 @@ class DeactivateAccountHandler(BaseHandler):
         """
         user = UserID.from_string(user_id)
 
-        rooms_for_user = yield self.store.get_rooms_for_user(user_id)
+        rooms_for_user = await self.store.get_rooms_for_user(user_id)
         for room_id in rooms_for_user:
             logger.info("User parter parting %r from %r", user_id, room_id)
             try:
-                yield self._room_member_handler.update_membership(
+                await self._room_member_handler.update_membership(
                     create_requester(user),
                     user,
                     room_id,
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 57a10daefd..2d889364d4 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -264,6 +264,7 @@ class E2eKeysHandler(object):
 
         return ret
 
+    @defer.inlineCallbacks
     def get_cross_signing_keys_from_cache(self, query, from_user_id):
         """Get cross-signing keys for users from the database
 
@@ -283,14 +284,32 @@ class E2eKeysHandler(object):
         self_signing_keys = {}
         user_signing_keys = {}
 
-        # Currently a stub, implementation coming in https://github.com/matrix-org/synapse/pull/6486
-        return defer.succeed(
-            {
-                "master_keys": master_keys,
-                "self_signing_keys": self_signing_keys,
-                "user_signing_keys": user_signing_keys,
-            }
-        )
+        user_ids = list(query)
+
+        keys = yield self.store.get_e2e_cross_signing_keys_bulk(user_ids, from_user_id)
+
+        for user_id, user_info in keys.items():
+            if user_info is None:
+                continue
+            if "master" in user_info:
+                master_keys[user_id] = user_info["master"]
+            if "self_signing" in user_info:
+                self_signing_keys[user_id] = user_info["self_signing"]
+
+        if (
+            from_user_id in keys
+            and keys[from_user_id] is not None
+            and "user_signing" in keys[from_user_id]
+        ):
+            # users can see other users' master and self-signing keys, but can
+            # only see their own user-signing keys
+            user_signing_keys[from_user_id] = keys[from_user_id]["user_signing"]
+
+        return {
+            "master_keys": master_keys,
+            "self_signing_keys": self_signing_keys,
+            "user_signing_keys": user_signing_keys,
+        }
 
     @trace
     @defer.inlineCallbacks
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 6fb453ce60..72a0febc2b 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -19,7 +19,7 @@
 
 import itertools
 import logging
-from typing import Dict, Iterable, Optional, Sequence, Tuple
+from typing import Dict, Iterable, List, Optional, Sequence, Tuple
 
 import six
 from six import iteritems, itervalues
@@ -63,6 +63,7 @@ from synapse.replication.http.federation import (
 )
 from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet
 from synapse.state import StateResolutionStore, resolve_events_with_store
+from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
 from synapse.types import UserID, get_domain_from_id
 from synapse.util.async_helpers import Linearizer, concurrently_execute
 from synapse.util.distributor import user_joined_room
@@ -163,8 +164,7 @@ class FederationHandler(BaseHandler):
 
         self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
 
-    @defer.inlineCallbacks
-    def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False):
+    async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None:
         """ Process a PDU received via a federation /send/ transaction, or
         via backfill of missing prev_events
 
@@ -174,17 +174,15 @@ class FederationHandler(BaseHandler):
             pdu (FrozenEvent): received PDU
             sent_to_us_directly (bool): True if this event was pushed to us; False if
                 we pulled it as the result of a missing prev_event.
-
-        Returns (Deferred): completes with None
         """
 
         room_id = pdu.room_id
         event_id = pdu.event_id
 
-        logger.info("[%s %s] handling received PDU: %s", room_id, event_id, pdu)
+        logger.info("handling received PDU: %s", pdu)
 
         # We reprocess pdus when we have seen them only as outliers
-        existing = yield self.store.get_event(
+        existing = await self.store.get_event(
             event_id, allow_none=True, allow_rejected=True
         )
 
@@ -228,7 +226,7 @@ class FederationHandler(BaseHandler):
         #
         # Note that if we were never in the room then we would have already
         # dropped the event, since we wouldn't know the room version.
-        is_in_room = yield self.auth.check_host_in_room(room_id, self.server_name)
+        is_in_room = await self.auth.check_host_in_room(room_id, self.server_name)
         if not is_in_room:
             logger.info(
                 "[%s %s] Ignoring PDU from %s as we're not in the room",
@@ -243,12 +241,12 @@ class FederationHandler(BaseHandler):
         # Get missing pdus if necessary.
         if not pdu.internal_metadata.is_outlier():
             # We only backfill backwards to the min depth.
-            min_depth = yield self.get_min_depth_for_context(pdu.room_id)
+            min_depth = await self.get_min_depth_for_context(pdu.room_id)
 
             logger.debug("[%s %s] min_depth: %d", room_id, event_id, min_depth)
 
             prevs = set(pdu.prev_event_ids())
-            seen = yield self.store.have_seen_events(prevs)
+            seen = await self.store.have_seen_events(prevs)
 
             if min_depth and pdu.depth < min_depth:
                 # This is so that we don't notify the user about this
@@ -268,7 +266,7 @@ class FederationHandler(BaseHandler):
                         len(missing_prevs),
                         shortstr(missing_prevs),
                     )
-                    with (yield self._room_pdu_linearizer.queue(pdu.room_id)):
+                    with (await self._room_pdu_linearizer.queue(pdu.room_id)):
                         logger.info(
                             "[%s %s] Acquired room lock to fetch %d missing prev_events",
                             room_id,
@@ -276,13 +274,19 @@ class FederationHandler(BaseHandler):
                             len(missing_prevs),
                         )
 
-                        yield self._get_missing_events_for_pdu(
-                            origin, pdu, prevs, min_depth
-                        )
+                        try:
+                            await self._get_missing_events_for_pdu(
+                                origin, pdu, prevs, min_depth
+                            )
+                        except Exception as e:
+                            raise Exception(
+                                "Error fetching missing prev_events for %s: %s"
+                                % (event_id, e)
+                            )
 
                         # Update the set of things we've seen after trying to
                         # fetch the missing stuff
-                        seen = yield self.store.have_seen_events(prevs)
+                        seen = await self.store.have_seen_events(prevs)
 
                         if not prevs - seen:
                             logger.info(
@@ -290,14 +294,6 @@ class FederationHandler(BaseHandler):
                                 room_id,
                                 event_id,
                             )
-                elif missing_prevs:
-                    logger.info(
-                        "[%s %s] Not recursively fetching %d missing prev_events: %s",
-                        room_id,
-                        event_id,
-                        len(missing_prevs),
-                        shortstr(missing_prevs),
-                    )
 
             if prevs - seen:
                 # We've still not been able to get all of the prev_events for this event.
@@ -342,12 +338,18 @@ class FederationHandler(BaseHandler):
                         affected=pdu.event_id,
                     )
 
+                logger.info(
+                    "Event %s is missing prev_events: calculating state for a "
+                    "backwards extremity",
+                    event_id,
+                )
+
                 # Calculate the state after each of the previous events, and
                 # resolve them to find the correct state at the current event.
                 event_map = {event_id: pdu}
                 try:
                     # Get the state of the events we know about
-                    ours = yield self.state_store.get_state_groups_ids(room_id, seen)
+                    ours = await self.state_store.get_state_groups_ids(room_id, seen)
 
                     # state_maps is a list of mappings from (type, state_key) to event_id
                     state_maps = list(
@@ -361,17 +363,14 @@ class FederationHandler(BaseHandler):
                     # know about
                     for p in prevs - seen:
                         logger.info(
-                            "[%s %s] Requesting state at missing prev_event %s",
-                            room_id,
-                            event_id,
-                            p,
+                            "Requesting state at missing prev_event %s", event_id,
                         )
 
                         with nested_logging_context(p):
                             # note that if any of the missing prevs share missing state or
                             # auth events, the requests to fetch those events are deduped
                             # by the get_pdu_cache in federation_client.
-                            (remote_state, _,) = yield self._get_state_for_room(
+                            (remote_state, _,) = await self._get_state_for_room(
                                 origin, room_id, p, include_event_in_state=True
                             )
 
@@ -383,8 +382,8 @@ class FederationHandler(BaseHandler):
                             for x in remote_state:
                                 event_map[x.event_id] = x
 
-                    room_version = yield self.store.get_room_version(room_id)
-                    state_map = yield resolve_events_with_store(
+                    room_version = await self.store.get_room_version(room_id)
+                    state_map = await resolve_events_with_store(
                         room_id,
                         room_version,
                         state_maps,
@@ -397,10 +396,10 @@ class FederationHandler(BaseHandler):
 
                     # First though we need to fetch all the events that are in
                     # state_map, so we can build up the state below.
-                    evs = yield self.store.get_events(
+                    evs = await self.store.get_events(
                         list(state_map.values()),
                         get_prev_content=False,
-                        check_redacted=False,
+                        redact_behaviour=EventRedactBehaviour.AS_IS,
                     )
                     event_map.update(evs)
 
@@ -420,10 +419,9 @@ class FederationHandler(BaseHandler):
                         affected=event_id,
                     )
 
-        yield self._process_received_pdu(origin, pdu, state=state)
+        await self._process_received_pdu(origin, pdu, state=state)
 
-    @defer.inlineCallbacks
-    def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth):
+    async def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth):
         """
         Args:
             origin (str): Origin of the pdu. Will be called to get the missing events
@@ -435,12 +433,12 @@ class FederationHandler(BaseHandler):
         room_id = pdu.room_id
         event_id = pdu.event_id
 
-        seen = yield self.store.have_seen_events(prevs)
+        seen = await self.store.have_seen_events(prevs)
 
         if not prevs - seen:
             return
 
-        latest = yield self.store.get_latest_event_ids_in_room(room_id)
+        latest = await self.store.get_latest_event_ids_in_room(room_id)
 
         # We add the prev events that we have seen to the latest
         # list to ensure the remote server doesn't give them to us
@@ -504,7 +502,7 @@ class FederationHandler(BaseHandler):
         # All that said: Let's try increasing the timout to 60s and see what happens.
 
         try:
-            missing_events = yield self.federation_client.get_missing_events(
+            missing_events = await self.federation_client.get_missing_events(
                 origin,
                 room_id,
                 earliest_events_ids=list(latest),
@@ -543,7 +541,7 @@ class FederationHandler(BaseHandler):
             )
             with nested_logging_context(ev.event_id):
                 try:
-                    yield self.on_receive_pdu(origin, ev, sent_to_us_directly=False)
+                    await self.on_receive_pdu(origin, ev, sent_to_us_directly=False)
                 except FederationError as e:
                     if e.code == 403:
                         logger.warning(
@@ -555,29 +553,30 @@ class FederationHandler(BaseHandler):
                     else:
                         raise
 
-    @defer.inlineCallbacks
-    @log_function
-    def _get_state_for_room(
-        self, destination, room_id, event_id, include_event_in_state
-    ):
+    async def _get_state_for_room(
+        self,
+        destination: str,
+        room_id: str,
+        event_id: str,
+        include_event_in_state: bool = False,
+    ) -> Tuple[List[EventBase], List[EventBase]]:
         """Requests all of the room state at a given event from a remote homeserver.
 
         Args:
-            destination (str): The remote homeserver to query for the state.
-            room_id (str): The id of the room we're interested in.
-            event_id (str): The id of the event we want the state at.
+            destination: The remote homeserver to query for the state.
+            room_id: The id of the room we're interested in.
+            event_id: The id of the event we want the state at.
             include_event_in_state: if true, the event itself will be included in the
                 returned state event list.
 
         Returns:
-            Deferred[Tuple[List[EventBase], List[EventBase]]]:
-                A list of events in the state, and a list of events in the auth chain
-                for the given event.
+            A list of events in the state, possibly including the event itself, and
+            a list of events in the auth chain for the given event.
         """
         (
             state_event_ids,
             auth_event_ids,
-        ) = yield self.federation_client.get_room_state_ids(
+        ) = await self.federation_client.get_room_state_ids(
             destination, room_id, event_id=event_id
         )
 
@@ -586,15 +585,15 @@ class FederationHandler(BaseHandler):
         if include_event_in_state:
             desired_events.add(event_id)
 
-        event_map = yield self._get_events_from_store_or_dest(
+        event_map = await self._get_events_from_store_or_dest(
             destination, room_id, desired_events
         )
 
         failed_to_fetch = desired_events - event_map.keys()
         if failed_to_fetch:
             logger.warning(
-                "Failed to fetch missing state/auth events for %s: %s",
-                room_id,
+                "Failed to fetch missing state/auth events for %s %s",
+                event_id,
                 failed_to_fetch,
             )
 
@@ -614,15 +613,11 @@ class FederationHandler(BaseHandler):
 
         return remote_state, auth_chain
 
-    @defer.inlineCallbacks
-    def _get_events_from_store_or_dest(self, destination, room_id, event_ids):
+    async def _get_events_from_store_or_dest(
+        self, destination: str, room_id: str, event_ids: Iterable[str]
+    ) -> Dict[str, EventBase]:
         """Fetch events from a remote destination, checking if we already have them.
 
-        Args:
-            destination (str)
-            room_id (str)
-            event_ids (Iterable[str])
-
         Persists any events we don't already have as outliers.
 
         If we fail to fetch any of the events, a warning will be logged, and the event
@@ -630,10 +625,9 @@ class FederationHandler(BaseHandler):
         be in the given room.
 
         Returns:
-            Deferred[dict[str, EventBase]]: A deferred resolving to a map
-            from event_id to event
+            map from event_id to event
         """
-        fetched_events = yield self.store.get_events(event_ids, allow_rejected=True)
+        fetched_events = await self.store.get_events(event_ids, allow_rejected=True)
 
         missing_events = set(event_ids) - fetched_events.keys()
 
@@ -644,14 +638,14 @@ class FederationHandler(BaseHandler):
                 room_id,
             )
 
-            yield self._get_events_and_persist(
+            await self._get_events_and_persist(
                 destination=destination, room_id=room_id, events=missing_events
             )
 
             # we need to make sure we re-load from the database to get the rejected
             # state correct.
             fetched_events.update(
-                (yield self.store.get_events(missing_events, allow_rejected=True))
+                (await self.store.get_events(missing_events, allow_rejected=True))
             )
 
         # check for events which were in the wrong room.
@@ -677,12 +671,14 @@ class FederationHandler(BaseHandler):
                 bad_room_id,
                 room_id,
             )
+
             del fetched_events[bad_event_id]
 
         return fetched_events
 
-    @defer.inlineCallbacks
-    def _process_received_pdu(self, origin, event, state):
+    async def _process_received_pdu(
+        self, origin: str, event: EventBase, state: Optional[Iterable[EventBase]],
+    ):
         """ Called when we have a new pdu. We need to do auth checks and put it
         through the StateHandler.
 
@@ -701,15 +697,15 @@ class FederationHandler(BaseHandler):
         logger.debug("[%s %s] Processing event: %s", room_id, event_id, event)
 
         try:
-            context = yield self._handle_new_event(origin, event, state=state)
+            context = await self._handle_new_event(origin, event, state=state)
         except AuthError as e:
             raise FederationError("ERROR", e.code, e.msg, affected=event.event_id)
 
-        room = yield self.store.get_room(room_id)
+        room = await self.store.get_room(room_id)
 
         if not room:
             try:
-                yield self.store.store_room(
+                await self.store.store_room(
                     room_id=room_id, room_creator_user_id="", is_public=False
                 )
             except StoreError:
@@ -722,11 +718,11 @@ class FederationHandler(BaseHandler):
                 # changing their profile info.
                 newly_joined = True
 
-                prev_state_ids = yield context.get_prev_state_ids(self.store)
+                prev_state_ids = await context.get_prev_state_ids()
 
                 prev_state_id = prev_state_ids.get((event.type, event.state_key))
                 if prev_state_id:
-                    prev_state = yield self.store.get_event(
+                    prev_state = await self.store.get_event(
                         prev_state_id, allow_none=True
                     )
                     if prev_state and prev_state.membership == Membership.JOIN:
@@ -734,11 +730,10 @@ class FederationHandler(BaseHandler):
 
                 if newly_joined:
                     user = UserID.from_string(event.state_key)
-                    yield self.user_joined_room(user, room_id)
+                    await self.user_joined_room(user, room_id)
 
     @log_function
-    @defer.inlineCallbacks
-    def backfill(self, dest, room_id, limit, extremities):
+    async def backfill(self, dest, room_id, limit, extremities):
         """ Trigger a backfill request to `dest` for the given `room_id`
 
         This will attempt to get more events from the remote. If the other side
@@ -755,7 +750,7 @@ class FederationHandler(BaseHandler):
         if dest == self.server_name:
             raise SynapseError(400, "Can't backfill from self.")
 
-        events = yield self.federation_client.backfill(
+        events = await self.federation_client.backfill(
             dest, room_id, limit=limit, extremities=extremities
         )
 
@@ -770,7 +765,7 @@ class FederationHandler(BaseHandler):
         #     self._sanity_check_event(ev)
 
         # Don't bother processing events we already have.
-        seen_events = yield self.store.have_events_in_timeline(
+        seen_events = await self.store.have_events_in_timeline(
             set(e.event_id for e in events)
         )
 
@@ -796,7 +791,7 @@ class FederationHandler(BaseHandler):
         state_events = {}
         events_to_state = {}
         for e_id in edges:
-            state, auth = yield self._get_state_for_room(
+            state, auth = await self._get_state_for_room(
                 destination=dest,
                 room_id=room_id,
                 event_id=e_id,
@@ -843,7 +838,7 @@ class FederationHandler(BaseHandler):
                 )
             )
 
-        yield self._handle_new_events(dest, ev_infos, backfilled=True)
+        await self._handle_new_events(dest, ev_infos, backfilled=True)
 
         # Step 2: Persist the rest of the events in the chunk one by one
         events.sort(key=lambda e: e.depth)
@@ -859,16 +854,15 @@ class FederationHandler(BaseHandler):
             # We store these one at a time since each event depends on the
             # previous to work out the state.
             # TODO: We can probably do something more clever here.
-            yield self._handle_new_event(dest, event, backfilled=True)
+            await self._handle_new_event(dest, event, backfilled=True)
 
         return events
 
-    @defer.inlineCallbacks
-    def maybe_backfill(self, room_id, current_depth):
+    async def maybe_backfill(self, room_id, current_depth):
         """Checks the database to see if we should backfill before paginating,
         and if so do.
         """
-        extremities = yield self.store.get_oldest_events_with_depth_in_room(room_id)
+        extremities = await self.store.get_oldest_events_with_depth_in_room(room_id)
 
         if not extremities:
             logger.debug("Not backfilling as no extremeties found.")
@@ -900,15 +894,17 @@ class FederationHandler(BaseHandler):
         #   state *before* the event, ignoring the special casing certain event
         #   types have.
 
-        forward_events = yield self.store.get_successor_events(list(extremities))
+        forward_events = await self.store.get_successor_events(list(extremities))
 
-        extremities_events = yield self.store.get_events(
-            forward_events, check_redacted=False, get_prev_content=False
+        extremities_events = await self.store.get_events(
+            forward_events,
+            redact_behaviour=EventRedactBehaviour.AS_IS,
+            get_prev_content=False,
         )
 
         # We set `check_history_visibility_only` as we might otherwise get false
         # positives from users having been erased.
-        filtered_extremities = yield filter_events_for_server(
+        filtered_extremities = await filter_events_for_server(
             self.storage,
             self.server_name,
             list(extremities_events.values()),
@@ -938,7 +934,7 @@ class FederationHandler(BaseHandler):
         # First we try hosts that are already in the room
         # TODO: HEURISTIC ALERT.
 
-        curr_state = yield self.state_handler.get_current_state(room_id)
+        curr_state = await self.state_handler.get_current_state(room_id)
 
         def get_domains_from_state(state):
             """Get joined domains from state
@@ -977,12 +973,11 @@ class FederationHandler(BaseHandler):
             domain for domain, depth in curr_domains if domain != self.server_name
         ]
 
-        @defer.inlineCallbacks
-        def try_backfill(domains):
+        async def try_backfill(domains):
             # TODO: Should we try multiple of these at a time?
             for dom in domains:
                 try:
-                    yield self.backfill(
+                    await self.backfill(
                         dom, room_id, limit=100, extremities=extremities
                     )
                     # If this succeeded then we probably already have the
@@ -1013,7 +1008,7 @@ class FederationHandler(BaseHandler):
 
             return False
 
-        success = yield try_backfill(likely_domains)
+        success = await try_backfill(likely_domains)
         if success:
             return True
 
@@ -1027,7 +1022,7 @@ class FederationHandler(BaseHandler):
 
         logger.debug("calling resolve_state_groups in _maybe_backfill")
         resolve = preserve_fn(self.state_handler.resolve_state_groups_for_events)
-        states = yield make_deferred_yieldable(
+        states = await make_deferred_yieldable(
             defer.gatherResults(
                 [resolve(room_id, [e]) for e in event_ids], consumeErrors=True
             )
@@ -1037,7 +1032,7 @@ class FederationHandler(BaseHandler):
         # event_ids.
         states = dict(zip(event_ids, [s.state for s in states]))
 
-        state_map = yield self.store.get_events(
+        state_map = await self.store.get_events(
             [e_id for ids in itervalues(states) for e_id in itervalues(ids)],
             get_prev_content=False,
         )
@@ -1053,7 +1048,7 @@ class FederationHandler(BaseHandler):
         for e_id, _ in sorted_extremeties_tuple:
             likely_domains = get_domains_from_state(states[e_id])
 
-            success = yield try_backfill(
+            success = await try_backfill(
                 [dom for dom, _ in likely_domains if dom not in tried_domains]
             )
             if success:
@@ -1063,8 +1058,7 @@ class FederationHandler(BaseHandler):
 
         return False
 
-    @defer.inlineCallbacks
-    def _get_events_and_persist(
+    async def _get_events_and_persist(
         self, destination: str, room_id: str, events: Iterable[str]
     ):
         """Fetch the given events from a server, and persist them as outliers.
@@ -1072,7 +1066,7 @@ class FederationHandler(BaseHandler):
         Logs a warning if we can't find the given event.
         """
 
-        room_version = yield self.store.get_room_version(room_id)
+        room_version = await self.store.get_room_version(room_id)
 
         event_infos = []
 
@@ -1108,9 +1102,9 @@ class FederationHandler(BaseHandler):
                         e,
                     )
 
-        yield concurrently_execute(get_event, events, 5)
+        await concurrently_execute(get_event, events, 5)
 
-        yield self._handle_new_events(
+        await self._handle_new_events(
             destination, event_infos,
         )
 
@@ -1253,7 +1247,7 @@ class FederationHandler(BaseHandler):
             # Check whether this room is the result of an upgrade of a room we already know
             # about. If so, migrate over user information
             predecessor = yield self.store.get_room_predecessor(room_id)
-            if not predecessor:
+            if not predecessor or not isinstance(predecessor.get("room_id"), str):
                 return
             old_room_id = predecessor["room_id"]
             logger.debug(
@@ -1281,8 +1275,7 @@ class FederationHandler(BaseHandler):
 
         return True
 
-    @defer.inlineCallbacks
-    def _handle_queued_pdus(self, room_queue):
+    async def _handle_queued_pdus(self, room_queue):
         """Process PDUs which got queued up while we were busy send_joining.
 
         Args:
@@ -1298,7 +1291,7 @@ class FederationHandler(BaseHandler):
                     p.room_id,
                 )
                 with nested_logging_context(p.event_id):
-                    yield self.on_receive_pdu(origin, p, sent_to_us_directly=True)
+                    await self.on_receive_pdu(origin, p, sent_to_us_directly=True)
             except Exception as e:
                 logger.warning(
                     "Error handling queued PDU %s from %s: %s", p.event_id, origin, e
@@ -1428,7 +1421,7 @@ class FederationHandler(BaseHandler):
                 user = UserID.from_string(event.state_key)
                 yield self.user_joined_room(user, event.room_id)
 
-        prev_state_ids = yield context.get_prev_state_ids(self.store)
+        prev_state_ids = yield context.get_prev_state_ids()
 
         state_ids = list(prev_state_ids.values())
         auth_chain = yield self.store.get_auth_chain(state_ids)
@@ -1496,7 +1489,7 @@ class FederationHandler(BaseHandler):
     @defer.inlineCallbacks
     def do_remotely_reject_invite(self, target_hosts, room_id, user_id, content):
         origin, event, event_format_version = yield self._make_and_verify_event(
-            target_hosts, room_id, user_id, "leave", content=content,
+            target_hosts, room_id, user_id, "leave", content=content
         )
         # Mark as outlier as we don't have any state for this event; we're not
         # even in the room.
@@ -1937,7 +1930,7 @@ class FederationHandler(BaseHandler):
         context = yield self.state_handler.compute_event_context(event, old_state=state)
 
         if not auth_events:
-            prev_state_ids = yield context.get_prev_state_ids(self.store)
+            prev_state_ids = yield context.get_prev_state_ids()
             auth_events_ids = yield self.auth.compute_auth_events(
                 event, prev_state_ids, for_verification=True
             )
@@ -2346,12 +2339,12 @@ class FederationHandler(BaseHandler):
             k: a.event_id for k, a in iteritems(auth_events) if k != event_key
         }
 
-        current_state_ids = yield context.get_current_state_ids(self.store)
+        current_state_ids = yield context.get_current_state_ids()
         current_state_ids = dict(current_state_ids)
 
         current_state_ids.update(state_updates)
 
-        prev_state_ids = yield context.get_prev_state_ids(self.store)
+        prev_state_ids = yield context.get_prev_state_ids()
         prev_state_ids = dict(prev_state_ids)
 
         prev_state_ids.update({k: a.event_id for k, a in iteritems(auth_events)})
@@ -2635,7 +2628,7 @@ class FederationHandler(BaseHandler):
             event.content["third_party_invite"]["signed"]["token"],
         )
         original_invite = None
-        prev_state_ids = yield context.get_prev_state_ids(self.store)
+        prev_state_ids = yield context.get_prev_state_ids()
         original_invite_id = prev_state_ids.get(key)
         if original_invite_id:
             original_invite = yield self.store.get_event(
@@ -2683,7 +2676,7 @@ class FederationHandler(BaseHandler):
         signed = event.content["third_party_invite"]["signed"]
         token = signed["token"]
 
-        prev_state_ids = yield context.get_prev_state_ids(self.store)
+        prev_state_ids = yield context.get_prev_state_ids()
         invite_event_id = prev_state_ids.get((EventTypes.ThirdPartyInvite, token))
 
         invite_event = None
@@ -2857,7 +2850,7 @@ class FederationHandler(BaseHandler):
                 room_id=room_id, user_id=user.to_string(), change="joined"
             )
         else:
-            return user_joined_room(self.distributor, user, room_id)
+            return defer.succeed(user_joined_room(self.distributor, user, room_id))
 
     @defer.inlineCallbacks
     def get_room_complexity(self, remote_room_hosts, room_id):
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 81dce96f4b..44ec3e66ae 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -26,7 +26,7 @@ from synapse.streams.config import PaginationConfig
 from synapse.types import StreamToken, UserID
 from synapse.util import unwrapFirstError
 from synapse.util.async_helpers import concurrently_execute
-from synapse.util.caches.snapshot_cache import SnapshotCache
+from synapse.util.caches.response_cache import ResponseCache
 from synapse.visibility import filter_events_for_client
 
 from ._base import BaseHandler
@@ -41,7 +41,7 @@ class InitialSyncHandler(BaseHandler):
         self.state = hs.get_state_handler()
         self.clock = hs.get_clock()
         self.validator = EventValidator()
-        self.snapshot_cache = SnapshotCache()
+        self.snapshot_cache = ResponseCache(hs, "initial_sync_cache")
         self._event_serializer = hs.get_event_client_serializer()
         self.storage = hs.get_storage()
         self.state_store = self.storage.state
@@ -79,21 +79,17 @@ class InitialSyncHandler(BaseHandler):
             as_client_event,
             include_archived,
         )
-        now_ms = self.clock.time_msec()
-        result = self.snapshot_cache.get(now_ms, key)
-        if result is not None:
-            return result
 
-        return self.snapshot_cache.set(
-            now_ms,
+        return self.snapshot_cache.wrap(
             key,
-            self._snapshot_all_rooms(
-                user_id, pagin_config, as_client_event, include_archived
-            ),
+            self._snapshot_all_rooms,
+            user_id,
+            pagin_config,
+            as_client_event,
+            include_archived,
         )
 
-    @defer.inlineCallbacks
-    def _snapshot_all_rooms(
+    async def _snapshot_all_rooms(
         self,
         user_id=None,
         pagin_config=None,
@@ -105,7 +101,7 @@ class InitialSyncHandler(BaseHandler):
         if include_archived:
             memberships.append(Membership.LEAVE)
 
-        room_list = yield self.store.get_rooms_for_user_where_membership_is(
+        room_list = await self.store.get_rooms_for_user_where_membership_is(
             user_id=user_id, membership_list=memberships
         )
 
@@ -113,33 +109,32 @@ class InitialSyncHandler(BaseHandler):
 
         rooms_ret = []
 
-        now_token = yield self.hs.get_event_sources().get_current_token()
+        now_token = await self.hs.get_event_sources().get_current_token()
 
         presence_stream = self.hs.get_event_sources().sources["presence"]
         pagination_config = PaginationConfig(from_token=now_token)
-        presence, _ = yield presence_stream.get_pagination_rows(
+        presence, _ = await presence_stream.get_pagination_rows(
             user, pagination_config.get_source_config("presence"), None
         )
 
         receipt_stream = self.hs.get_event_sources().sources["receipt"]
-        receipt, _ = yield receipt_stream.get_pagination_rows(
+        receipt, _ = await receipt_stream.get_pagination_rows(
             user, pagination_config.get_source_config("receipt"), None
         )
 
-        tags_by_room = yield self.store.get_tags_for_user(user_id)
+        tags_by_room = await self.store.get_tags_for_user(user_id)
 
-        account_data, account_data_by_room = yield self.store.get_account_data_for_user(
+        account_data, account_data_by_room = await self.store.get_account_data_for_user(
             user_id
         )
 
-        public_room_ids = yield self.store.get_public_room_ids()
+        public_room_ids = await self.store.get_public_room_ids()
 
         limit = pagin_config.limit
         if limit is None:
             limit = 10
 
-        @defer.inlineCallbacks
-        def handle_room(event):
+        async def handle_room(event):
             d = {
                 "room_id": event.room_id,
                 "membership": event.membership,
@@ -152,8 +147,8 @@ class InitialSyncHandler(BaseHandler):
                 time_now = self.clock.time_msec()
                 d["inviter"] = event.sender
 
-                invite_event = yield self.store.get_event(event.event_id)
-                d["invite"] = yield self._event_serializer.serialize_event(
+                invite_event = await self.store.get_event(event.event_id)
+                d["invite"] = await self._event_serializer.serialize_event(
                     invite_event, time_now, as_client_event
                 )
 
@@ -177,7 +172,7 @@ class InitialSyncHandler(BaseHandler):
                         lambda states: states[event.event_id]
                     )
 
-                (messages, token), current_state = yield make_deferred_yieldable(
+                (messages, token), current_state = await make_deferred_yieldable(
                     defer.gatherResults(
                         [
                             run_in_background(
@@ -191,7 +186,7 @@ class InitialSyncHandler(BaseHandler):
                     )
                 ).addErrback(unwrapFirstError)
 
-                messages = yield filter_events_for_client(
+                messages = await filter_events_for_client(
                     self.storage, user_id, messages
                 )
 
@@ -201,7 +196,7 @@ class InitialSyncHandler(BaseHandler):
 
                 d["messages"] = {
                     "chunk": (
-                        yield self._event_serializer.serialize_events(
+                        await self._event_serializer.serialize_events(
                             messages, time_now=time_now, as_client_event=as_client_event
                         )
                     ),
@@ -209,7 +204,7 @@ class InitialSyncHandler(BaseHandler):
                     "end": end_token.to_string(),
                 }
 
-                d["state"] = yield self._event_serializer.serialize_events(
+                d["state"] = await self._event_serializer.serialize_events(
                     current_state.values(),
                     time_now=time_now,
                     as_client_event=as_client_event,
@@ -232,7 +227,7 @@ class InitialSyncHandler(BaseHandler):
             except Exception:
                 logger.exception("Failed to get snapshot")
 
-        yield concurrently_execute(handle_room, room_list, 10)
+        await concurrently_execute(handle_room, room_list, 10)
 
         account_data_events = []
         for account_data_type, content in account_data.items():
@@ -256,8 +251,7 @@ class InitialSyncHandler(BaseHandler):
 
         return ret
 
-    @defer.inlineCallbacks
-    def room_initial_sync(self, requester, room_id, pagin_config=None):
+    async def room_initial_sync(self, requester, room_id, pagin_config=None):
         """Capture the a snapshot of a room. If user is currently a member of
         the room this will be what is currently in the room. If the user left
         the room this will be what was in the room when they left.
@@ -274,32 +268,32 @@ class InitialSyncHandler(BaseHandler):
             A JSON serialisable dict with the snapshot of the room.
         """
 
-        blocked = yield self.store.is_room_blocked(room_id)
+        blocked = await self.store.is_room_blocked(room_id)
         if blocked:
             raise SynapseError(403, "This room has been blocked on this server")
 
         user_id = requester.user.to_string()
 
-        membership, member_event_id = yield self._check_in_room_or_world_readable(
+        membership, member_event_id = await self._check_in_room_or_world_readable(
             room_id, user_id
         )
         is_peeking = member_event_id is None
 
         if membership == Membership.JOIN:
-            result = yield self._room_initial_sync_joined(
+            result = await self._room_initial_sync_joined(
                 user_id, room_id, pagin_config, membership, is_peeking
             )
         elif membership == Membership.LEAVE:
-            result = yield self._room_initial_sync_parted(
+            result = await self._room_initial_sync_parted(
                 user_id, room_id, pagin_config, membership, member_event_id, is_peeking
             )
 
         account_data_events = []
-        tags = yield self.store.get_tags_for_room(user_id, room_id)
+        tags = await self.store.get_tags_for_room(user_id, room_id)
         if tags:
             account_data_events.append({"type": "m.tag", "content": {"tags": tags}})
 
-        account_data = yield self.store.get_account_data_for_room(user_id, room_id)
+        account_data = await self.store.get_account_data_for_room(user_id, room_id)
         for account_data_type, content in account_data.items():
             account_data_events.append({"type": account_data_type, "content": content})
 
@@ -307,11 +301,10 @@ class InitialSyncHandler(BaseHandler):
 
         return result
 
-    @defer.inlineCallbacks
-    def _room_initial_sync_parted(
+    async def _room_initial_sync_parted(
         self, user_id, room_id, pagin_config, membership, member_event_id, is_peeking
     ):
-        room_state = yield self.state_store.get_state_for_events([member_event_id])
+        room_state = await self.state_store.get_state_for_events([member_event_id])
 
         room_state = room_state[member_event_id]
 
@@ -319,13 +312,13 @@ class InitialSyncHandler(BaseHandler):
         if limit is None:
             limit = 10
 
-        stream_token = yield self.store.get_stream_token_for_event(member_event_id)
+        stream_token = await self.store.get_stream_token_for_event(member_event_id)
 
-        messages, token = yield self.store.get_recent_events_for_room(
+        messages, token = await self.store.get_recent_events_for_room(
             room_id, limit=limit, end_token=stream_token
         )
 
-        messages = yield filter_events_for_client(
+        messages = await filter_events_for_client(
             self.storage, user_id, messages, is_peeking=is_peeking
         )
 
@@ -339,13 +332,13 @@ class InitialSyncHandler(BaseHandler):
             "room_id": room_id,
             "messages": {
                 "chunk": (
-                    yield self._event_serializer.serialize_events(messages, time_now)
+                    await self._event_serializer.serialize_events(messages, time_now)
                 ),
                 "start": start_token.to_string(),
                 "end": end_token.to_string(),
             },
             "state": (
-                yield self._event_serializer.serialize_events(
+                await self._event_serializer.serialize_events(
                     room_state.values(), time_now
                 )
             ),
@@ -353,19 +346,18 @@ class InitialSyncHandler(BaseHandler):
             "receipts": [],
         }
 
-    @defer.inlineCallbacks
-    def _room_initial_sync_joined(
+    async def _room_initial_sync_joined(
         self, user_id, room_id, pagin_config, membership, is_peeking
     ):
-        current_state = yield self.state.get_current_state(room_id=room_id)
+        current_state = await self.state.get_current_state(room_id=room_id)
 
         # TODO: These concurrently
         time_now = self.clock.time_msec()
-        state = yield self._event_serializer.serialize_events(
+        state = await self._event_serializer.serialize_events(
             current_state.values(), time_now
         )
 
-        now_token = yield self.hs.get_event_sources().get_current_token()
+        now_token = await self.hs.get_event_sources().get_current_token()
 
         limit = pagin_config.limit if pagin_config else None
         if limit is None:
@@ -380,28 +372,26 @@ class InitialSyncHandler(BaseHandler):
 
         presence_handler = self.hs.get_presence_handler()
 
-        @defer.inlineCallbacks
-        def get_presence():
+        async def get_presence():
             # If presence is disabled, return an empty list
             if not self.hs.config.use_presence:
                 return []
 
-            states = yield presence_handler.get_states(
+            states = await presence_handler.get_states(
                 [m.user_id for m in room_members], as_event=True
             )
 
             return states
 
-        @defer.inlineCallbacks
-        def get_receipts():
-            receipts = yield self.store.get_linearized_receipts_for_room(
+        async def get_receipts():
+            receipts = await self.store.get_linearized_receipts_for_room(
                 room_id, to_key=now_token.receipt_key
             )
             if not receipts:
                 receipts = []
             return receipts
 
-        presence, receipts, (messages, token) = yield make_deferred_yieldable(
+        presence, receipts, (messages, token) = await make_deferred_yieldable(
             defer.gatherResults(
                 [
                     run_in_background(get_presence),
@@ -417,7 +407,7 @@ class InitialSyncHandler(BaseHandler):
             ).addErrback(unwrapFirstError)
         )
 
-        messages = yield filter_events_for_client(
+        messages = await filter_events_for_client(
             self.storage, user_id, messages, is_peeking=is_peeking
         )
 
@@ -430,7 +420,7 @@ class InitialSyncHandler(BaseHandler):
             "room_id": room_id,
             "messages": {
                 "chunk": (
-                    yield self._event_serializer.serialize_events(messages, time_now)
+                    await self._event_serializer.serialize_events(messages, time_now)
                 ),
                 "start": start_token.to_string(),
                 "end": end_token.to_string(),
@@ -444,18 +434,17 @@ class InitialSyncHandler(BaseHandler):
 
         return ret
 
-    @defer.inlineCallbacks
-    def _check_in_room_or_world_readable(self, room_id, user_id):
+    async def _check_in_room_or_world_readable(self, room_id, user_id):
         try:
             # check_user_was_in_room will return the most recent membership
             # event for the user if:
             #  * The user is a non-guest user, and was ever in the room
             #  * The user is a guest user, and has joined the room
             # else it will throw.
-            member_event = yield self.auth.check_user_was_in_room(room_id, user_id)
+            member_event = await self.auth.check_user_was_in_room(room_id, user_id)
             return member_event.membership, member_event.event_id
         except AuthError:
-            visibility = yield self.state_handler.get_current_state(
+            visibility = await self.state_handler.get_current_state(
                 room_id, EventTypes.RoomHistoryVisibility, ""
             )
             if (
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 54fa216d83..4ad752205f 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -46,6 +46,7 @@ from synapse.events.validator import EventValidator
 from synapse.logging.context import run_in_background
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.replication.http.send_event import ReplicationSendEventRestServlet
+from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
 from synapse.storage.state import StateFilter
 from synapse.types import RoomAlias, UserID, create_requester
 from synapse.util.async_helpers import Linearizer
@@ -514,7 +515,7 @@ class EventCreationHandler(object):
             # federation as well as those created locally. As of room v3, aliases events
             # can be created by users that are not in the room, therefore we have to
             # tolerate them in event_auth.check().
-            prev_state_ids = yield context.get_prev_state_ids(self.store)
+            prev_state_ids = yield context.get_prev_state_ids()
             prev_event_id = prev_state_ids.get((EventTypes.Member, event.sender))
             prev_event = (
                 yield self.store.get_event(prev_event_id, allow_none=True)
@@ -664,7 +665,7 @@ class EventCreationHandler(object):
         If so, returns the version of the event in context.
         Otherwise, returns None.
         """
-        prev_state_ids = yield context.get_prev_state_ids(self.store)
+        prev_state_ids = yield context.get_prev_state_ids()
         prev_event_id = prev_state_ids.get((event.type, event.state_key))
         if not prev_event_id:
             return
@@ -875,7 +876,7 @@ class EventCreationHandler(object):
             if event.type == EventTypes.Redaction:
                 original_event = yield self.store.get_event(
                     event.redacts,
-                    check_redacted=False,
+                    redact_behaviour=EventRedactBehaviour.AS_IS,
                     get_prev_content=False,
                     allow_rejected=False,
                     allow_none=True,
@@ -913,7 +914,7 @@ class EventCreationHandler(object):
                 def is_inviter_member_event(e):
                     return e.type == EventTypes.Member and e.sender == event.sender
 
-                current_state_ids = yield context.get_current_state_ids(self.store)
+                current_state_ids = yield context.get_current_state_ids()
 
                 state_to_include_ids = [
                     e_id
@@ -952,7 +953,7 @@ class EventCreationHandler(object):
         if event.type == EventTypes.Redaction:
             original_event = yield self.store.get_event(
                 event.redacts,
-                check_redacted=False,
+                redact_behaviour=EventRedactBehaviour.AS_IS,
                 get_prev_content=False,
                 allow_rejected=False,
                 allow_none=True,
@@ -966,7 +967,7 @@ class EventCreationHandler(object):
                 if original_event.room_id != event.room_id:
                     raise SynapseError(400, "Cannot redact event from a different room")
 
-            prev_state_ids = yield context.get_prev_state_ids(self.store)
+            prev_state_ids = yield context.get_prev_state_ids()
             auth_events_ids = yield self.auth.compute_auth_events(
                 event, prev_state_ids, for_verification=True
             )
@@ -988,7 +989,7 @@ class EventCreationHandler(object):
                 event.internal_metadata.recheck_redaction = False
 
         if event.type == EventTypes.Create:
-            prev_state_ids = yield context.get_prev_state_ids(self.store)
+            prev_state_ids = yield context.get_prev_state_ids()
             if prev_state_ids:
                 raise AuthError(403, "Changing the room create event is forbidden")
 
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 8514ddc600..00a6afc963 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -280,8 +280,7 @@ class PaginationHandler(object):
 
             await self.storage.purge_events.purge_room(room_id)
 
-    @defer.inlineCallbacks
-    def get_messages(
+    async def get_messages(
         self,
         requester,
         room_id=None,
@@ -307,7 +306,7 @@ class PaginationHandler(object):
             room_token = pagin_config.from_token.room_key
         else:
             pagin_config.from_token = (
-                yield self.hs.get_event_sources().get_current_token_for_pagination()
+                await self.hs.get_event_sources().get_current_token_for_pagination()
             )
             room_token = pagin_config.from_token.room_key
 
@@ -319,11 +318,11 @@ class PaginationHandler(object):
 
         source_config = pagin_config.get_source_config("room")
 
-        with (yield self.pagination_lock.read(room_id)):
+        with (await self.pagination_lock.read(room_id)):
             (
                 membership,
                 member_event_id,
-            ) = yield self.auth.check_in_room_or_world_readable(room_id, user_id)
+            ) = await self.auth.check_in_room_or_world_readable(room_id, user_id)
 
             if source_config.direction == "b":
                 # if we're going backwards, we might need to backfill. This
@@ -331,7 +330,7 @@ class PaginationHandler(object):
                 if room_token.topological:
                     max_topo = room_token.topological
                 else:
-                    max_topo = yield self.store.get_max_topological_token(
+                    max_topo = await self.store.get_max_topological_token(
                         room_id, room_token.stream
                     )
 
@@ -339,18 +338,18 @@ class PaginationHandler(object):
                     # If they have left the room then clamp the token to be before
                     # they left the room, to save the effort of loading from the
                     # database.
-                    leave_token = yield self.store.get_topological_token_for_event(
+                    leave_token = await self.store.get_topological_token_for_event(
                         member_event_id
                     )
                     leave_token = RoomStreamToken.parse(leave_token)
                     if leave_token.topological < max_topo:
                         source_config.from_key = str(leave_token)
 
-                yield self.hs.get_handlers().federation_handler.maybe_backfill(
+                await self.hs.get_handlers().federation_handler.maybe_backfill(
                     room_id, max_topo
                 )
 
-            events, next_key = yield self.store.paginate_room_events(
+            events, next_key = await self.store.paginate_room_events(
                 room_id=room_id,
                 from_key=source_config.from_key,
                 to_key=source_config.to_key,
@@ -365,7 +364,7 @@ class PaginationHandler(object):
             if event_filter:
                 events = event_filter.filter(events)
 
-            events = yield filter_events_for_client(
+            events = await filter_events_for_client(
                 self.storage, user_id, events, is_peeking=(member_event_id is None)
             )
 
@@ -385,19 +384,19 @@ class PaginationHandler(object):
                 (EventTypes.Member, event.sender) for event in events
             )
 
-            state_ids = yield self.state_store.get_state_ids_for_event(
+            state_ids = await self.state_store.get_state_ids_for_event(
                 events[0].event_id, state_filter=state_filter
             )
 
             if state_ids:
-                state = yield self.store.get_events(list(state_ids.values()))
+                state = await self.store.get_events(list(state_ids.values()))
                 state = state.values()
 
         time_now = self.clock.time_msec()
 
         chunk = {
             "chunk": (
-                yield self._event_serializer.serialize_events(
+                await self._event_serializer.serialize_events(
                     events, time_now, as_client_event=as_client_event
                 )
             ),
@@ -406,7 +405,7 @@ class PaginationHandler(object):
         }
 
         if state:
-            chunk["state"] = yield self._event_serializer.serialize_events(
+            chunk["state"] = await self._event_serializer.serialize_events(
                 state, time_now, as_client_event=as_client_event
             )
 
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index eda15bc623..240c4add12 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -230,7 +230,7 @@ class PresenceHandler(object):
         is some spurious presence changes that will self-correct.
         """
         # If the DB pool has already terminated, don't try updating
-        if not self.hs.get_db_pool().running:
+        if not self.store.database.is_running():
             return
 
         logger.info(
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 1e5a4613c9..f9579d69ee 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -295,12 +295,16 @@ class BaseProfileHandler(BaseHandler):
                 be found to be in any room the server is in, and therefore the query
                 is denied.
         """
+
         # Implementation of MSC1301: don't allow looking up profiles if the
         # requester isn't in the same room as the target. We expect requester to
         # be None when this function is called outside of a profile query, e.g.
         # when building a membership event. In this case, we must allow the
         # lookup.
-        if not self.hs.config.require_auth_for_profile_requests or not requester:
+        if (
+            not self.hs.config.limit_profile_requests_to_users_who_share_rooms
+            or not requester
+        ):
             return
 
         # Always allow the user to query their own profile.
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 60b8bbc7a5..89c9118b26 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -184,7 +184,7 @@ class RoomCreationHandler(BaseHandler):
             requester, tombstone_event, tombstone_context
         )
 
-        old_room_state = yield tombstone_context.get_current_state_ids(self.store)
+        old_room_state = yield tombstone_context.get_current_state_ids()
 
         # update any aliases
         yield self._move_aliases_to_new_room(
@@ -1011,15 +1011,3 @@ class RoomEventSource(object):
 
     def get_current_key_for_room(self, room_id):
         return self.store.get_room_events_max_id(room_id)
-
-    @defer.inlineCallbacks
-    def get_pagination_rows(self, user, config, key):
-        events, next_key = yield self.store.paginate_room_events(
-            room_id=key,
-            from_key=config.from_key,
-            to_key=config.to_key,
-            direction=config.direction,
-            limit=config.limit,
-        )
-
-        return (events, next_key)
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 7b7270fc61..44c5e3239c 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -193,7 +193,7 @@ class RoomMemberHandler(object):
             requester, event, context, extra_users=[target], ratelimit=ratelimit
         )
 
-        prev_state_ids = yield context.get_prev_state_ids(self.store)
+        prev_state_ids = yield context.get_prev_state_ids()
 
         prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None)
 
@@ -601,7 +601,7 @@ class RoomMemberHandler(object):
         if prev_event is not None:
             return
 
-        prev_state_ids = yield context.get_prev_state_ids(self.store)
+        prev_state_ids = yield context.get_prev_state_ids()
         if event.membership == Membership.JOIN:
             if requester.is_guest:
                 guest_can_join = yield self._can_guest_join(prev_state_ids)
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index cc9e6b9bd0..0082f85c26 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -13,20 +13,36 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+import re
+from typing import Tuple
 
 import attr
 import saml2
+import saml2.response
 from saml2.client import Saml2Client
 
 from synapse.api.errors import SynapseError
+from synapse.config import ConfigError
 from synapse.http.servlet import parse_string
 from synapse.rest.client.v1.login import SSOAuthHandler
-from synapse.types import UserID, map_username_to_mxid_localpart
+from synapse.types import (
+    UserID,
+    map_username_to_mxid_localpart,
+    mxid_localpart_allowed_characters,
+)
 from synapse.util.async_helpers import Linearizer
 
 logger = logging.getLogger(__name__)
 
 
+@attr.s
+class Saml2SessionData:
+    """Data we track about SAML2 sessions"""
+
+    # time the session was created, in milliseconds
+    creation_time = attr.ib()
+
+
 class SamlHandler:
     def __init__(self, hs):
         self._saml_client = Saml2Client(hs.config.saml2_sp_config)
@@ -37,11 +53,14 @@ class SamlHandler:
         self._datastore = hs.get_datastore()
         self._hostname = hs.hostname
         self._saml2_session_lifetime = hs.config.saml2_session_lifetime
-        self._mxid_source_attribute = hs.config.saml2_mxid_source_attribute
         self._grandfathered_mxid_source_attribute = (
             hs.config.saml2_grandfathered_mxid_source_attribute
         )
-        self._mxid_mapper = hs.config.saml2_mxid_mapper
+
+        # plugin to do custom mapping from saml response to mxid
+        self._user_mapping_provider = hs.config.saml2_user_mapping_provider_class(
+            hs.config.saml2_user_mapping_provider_config
+        )
 
         # identifier for the external_ids table
         self._auth_provider_id = "saml"
@@ -118,22 +137,10 @@ class SamlHandler:
             remote_user_id = saml2_auth.ava["uid"][0]
         except KeyError:
             logger.warning("SAML2 response lacks a 'uid' attestation")
-            raise SynapseError(400, "uid not in SAML2 response")
-
-        try:
-            mxid_source = saml2_auth.ava[self._mxid_source_attribute][0]
-        except KeyError:
-            logger.warning(
-                "SAML2 response lacks a '%s' attestation", self._mxid_source_attribute
-            )
-            raise SynapseError(
-                400, "%s not in SAML2 response" % (self._mxid_source_attribute,)
-            )
+            raise SynapseError(400, "'uid' not in SAML2 response")
 
         self._outstanding_requests_dict.pop(saml2_auth.in_response_to, None)
 
-        displayName = saml2_auth.ava.get("displayName", [None])[0]
-
         with (await self._mapping_lock.queue(self._auth_provider_id)):
             # first of all, check if we already have a mapping for this user
             logger.info(
@@ -173,22 +180,46 @@ class SamlHandler:
                     )
                     return registered_user_id
 
-            # figure out a new mxid for this user
-            base_mxid_localpart = self._mxid_mapper(mxid_source)
+            # Map saml response to user attributes using the configured mapping provider
+            for i in range(1000):
+                attribute_dict = self._user_mapping_provider.saml_response_to_user_attributes(
+                    saml2_auth, i
+                )
+
+                logger.debug(
+                    "Retrieved SAML attributes from user mapping provider: %s "
+                    "(attempt %d)",
+                    attribute_dict,
+                    i,
+                )
+
+                localpart = attribute_dict.get("mxid_localpart")
+                if not localpart:
+                    logger.error(
+                        "SAML mapping provider plugin did not return a "
+                        "mxid_localpart object"
+                    )
+                    raise SynapseError(500, "Error parsing SAML2 response")
 
-            suffix = 0
-            while True:
-                localpart = base_mxid_localpart + (str(suffix) if suffix else "")
+                displayname = attribute_dict.get("displayname")
+
+                # Check if this mxid already exists
                 if not await self._datastore.get_users_by_id_case_insensitive(
                     UserID(localpart, self._hostname).to_string()
                 ):
+                    # This mxid is free
                     break
-                suffix += 1
-            logger.info("Allocating mxid for new user with localpart %s", localpart)
+            else:
+                # Unable to generate a username in 1000 iterations
+                # Break and return error to the user
+                raise SynapseError(
+                    500, "Unable to generate a Matrix ID from the SAML response"
+                )
 
             registered_user_id = await self._registration_handler.register_user(
-                localpart=localpart, default_display_name=displayName
+                localpart=localpart, default_display_name=displayname
             )
+
             await self._datastore.record_user_external_id(
                 self._auth_provider_id, remote_user_id, registered_user_id
             )
@@ -205,9 +236,120 @@ class SamlHandler:
             del self._outstanding_requests_dict[reqid]
 
 
+DOT_REPLACE_PATTERN = re.compile(
+    ("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),))
+)
+
+
+def dot_replace_for_mxid(username: str) -> str:
+    username = username.lower()
+    username = DOT_REPLACE_PATTERN.sub(".", username)
+
+    # regular mxids aren't allowed to start with an underscore either
+    username = re.sub("^_", "", username)
+    return username
+
+
+MXID_MAPPER_MAP = {
+    "hexencode": map_username_to_mxid_localpart,
+    "dotreplace": dot_replace_for_mxid,
+}
+
+
 @attr.s
-class Saml2SessionData:
-    """Data we track about SAML2 sessions"""
+class SamlConfig(object):
+    mxid_source_attribute = attr.ib()
+    mxid_mapper = attr.ib()
 
-    # time the session was created, in milliseconds
-    creation_time = attr.ib()
+
+class DefaultSamlMappingProvider(object):
+    __version__ = "0.0.1"
+
+    def __init__(self, parsed_config: SamlConfig):
+        """The default SAML user mapping provider
+
+        Args:
+            parsed_config: Module configuration
+        """
+        self._mxid_source_attribute = parsed_config.mxid_source_attribute
+        self._mxid_mapper = parsed_config.mxid_mapper
+
+    def saml_response_to_user_attributes(
+        self, saml_response: saml2.response.AuthnResponse, failures: int = 0,
+    ) -> dict:
+        """Maps some text from a SAML response to attributes of a new user
+
+        Args:
+            saml_response: A SAML auth response object
+
+            failures: How many times a call to this function with this
+                saml_response has resulted in a failure
+
+        Returns:
+            dict: A dict containing new user attributes. Possible keys:
+                * mxid_localpart (str): Required. The localpart of the user's mxid
+                * displayname (str): The displayname of the user
+        """
+        try:
+            mxid_source = saml_response.ava[self._mxid_source_attribute][0]
+        except KeyError:
+            logger.warning(
+                "SAML2 response lacks a '%s' attestation", self._mxid_source_attribute,
+            )
+            raise SynapseError(
+                400, "%s not in SAML2 response" % (self._mxid_source_attribute,)
+            )
+
+        # Use the configured mapper for this mxid_source
+        base_mxid_localpart = self._mxid_mapper(mxid_source)
+
+        # Append suffix integer if last call to this function failed to produce
+        # a usable mxid
+        localpart = base_mxid_localpart + (str(failures) if failures else "")
+
+        # Retrieve the display name from the saml response
+        # If displayname is None, the mxid_localpart will be used instead
+        displayname = saml_response.ava.get("displayName", [None])[0]
+
+        return {
+            "mxid_localpart": localpart,
+            "displayname": displayname,
+        }
+
+    @staticmethod
+    def parse_config(config: dict) -> SamlConfig:
+        """Parse the dict provided by the homeserver's config
+        Args:
+            config: A dictionary containing configuration options for this provider
+        Returns:
+            SamlConfig: A custom config object for this module
+        """
+        # Parse config options and use defaults where necessary
+        mxid_source_attribute = config.get("mxid_source_attribute", "uid")
+        mapping_type = config.get("mxid_mapping", "hexencode")
+
+        # Retrieve the associating mapping function
+        try:
+            mxid_mapper = MXID_MAPPER_MAP[mapping_type]
+        except KeyError:
+            raise ConfigError(
+                "saml2_config.user_mapping_provider.config: '%s' is not a valid "
+                "mxid_mapping value" % (mapping_type,)
+            )
+
+        return SamlConfig(mxid_source_attribute, mxid_mapper)
+
+    @staticmethod
+    def get_saml_attributes(config: SamlConfig) -> Tuple[set, set]:
+        """Returns the required attributes of a SAML
+
+        Args:
+            config: A SamlConfig object containing configuration params for this provider
+
+        Returns:
+            tuple[set,set]: The first set equates to the saml auth response
+                attributes that are required for the module to function, whereas the
+                second set consists of those attributes which can be used if
+                available, but are not necessary
+        """
+        return {"uid", config.mxid_source_attribute}, {"displayName"}
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index 56ed262a1f..ef750d1497 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -21,7 +21,7 @@ from unpaddedbase64 import decode_base64, encode_base64
 from twisted.internet import defer
 
 from synapse.api.constants import EventTypes, Membership
-from synapse.api.errors import SynapseError
+from synapse.api.errors import NotFoundError, SynapseError
 from synapse.api.filtering import Filter
 from synapse.storage.state import StateFilter
 from synapse.visibility import filter_events_for_client
@@ -37,6 +37,7 @@ class SearchHandler(BaseHandler):
         self._event_serializer = hs.get_event_client_serializer()
         self.storage = hs.get_storage()
         self.state_store = self.storage.state
+        self.auth = hs.get_auth()
 
     @defer.inlineCallbacks
     def get_old_rooms_from_upgraded_room(self, room_id):
@@ -53,23 +54,38 @@ class SearchHandler(BaseHandler):
             room_id (str): id of the room to search through.
 
         Returns:
-            Deferred[iterable[unicode]]: predecessor room ids
+            Deferred[iterable[str]]: predecessor room ids
         """
 
         historical_room_ids = []
 
-        while True:
-            predecessor = yield self.store.get_room_predecessor(room_id)
+        # The initial room must have been known for us to get this far
+        predecessor = yield self.store.get_room_predecessor(room_id)
 
-            # If no predecessor, assume we've hit a dead end
+        while True:
             if not predecessor:
+                # We have reached the end of the chain of predecessors
+                break
+
+            if not isinstance(predecessor.get("room_id"), str):
+                # This predecessor object is malformed. Exit here
+                break
+
+            predecessor_room_id = predecessor["room_id"]
+
+            # Don't add it to the list until we have checked that we are in the room
+            try:
+                next_predecessor_room = yield self.store.get_room_predecessor(
+                    predecessor_room_id
+                )
+            except NotFoundError:
+                # The predecessor is not a known room, so we are done here
                 break
 
-            # Add predecessor's room ID
-            historical_room_ids.append(predecessor["room_id"])
+            historical_room_ids.append(predecessor_room_id)
 
-            # Scan through the old room for further predecessors
-            room_id = predecessor["room_id"]
+            # And repeat
+            predecessor = next_predecessor_room
 
         return historical_room_ids
 
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 6f78454322..b635c339ed 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -317,6 +317,3 @@ class TypingNotificationEventSource(object):
 
     def get_current_key(self):
         return self.get_typing_handler()._latest_room_serial
-
-    def get_pagination_rows(self, user, pagination_config, key):
-        return [], pagination_config.from_key