summary refs log tree commit diff
path: root/synapse/handlers
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers')
-rw-r--r--synapse/handlers/account_data.py7
-rw-r--r--synapse/handlers/appservice.py5
-rw-r--r--synapse/handlers/device.py13
-rw-r--r--synapse/handlers/directory.py2
-rw-r--r--synapse/handlers/e2e_keys.py174
-rw-r--r--synapse/handlers/federation.py77
-rw-r--r--synapse/handlers/initial_sync.py4
-rw-r--r--synapse/handlers/message.py14
-rw-r--r--synapse/handlers/pagination.py13
-rw-r--r--synapse/handlers/register.py4
-rw-r--r--synapse/handlers/room.py37
-rw-r--r--synapse/handlers/room_member.py116
-rw-r--r--synapse/handlers/search.py12
-rw-r--r--synapse/handlers/stats.py5
-rw-r--r--synapse/handlers/sync.py16
-rw-r--r--synapse/handlers/ui_auth/checkers.py2
16 files changed, 357 insertions, 144 deletions
diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py
index 38bc67191c..2d7e6df6e4 100644
--- a/synapse/handlers/account_data.py
+++ b/synapse/handlers/account_data.py
@@ -38,9 +38,10 @@ class AccountDataEventSource(object):
                 {"type": "m.tag", "content": {"tags": room_tags}, "room_id": room_id}
             )
 
-        account_data, room_account_data = (
-            yield self.store.get_updated_account_data_for_user(user_id, last_stream_id)
-        )
+        (
+            account_data,
+            room_account_data,
+        ) = yield self.store.get_updated_account_data_for_user(user_id, last_stream_id)
 
         for account_data_type, content in account_data.items():
             results.append({"type": account_data_type, "content": content})
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 3e9b298154..fe62f78e67 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -73,7 +73,10 @@ class ApplicationServicesHandler(object):
             try:
                 limit = 100
                 while True:
-                    upper_bound, events = yield self.store.get_new_events_for_appservice(
+                    (
+                        upper_bound,
+                        events,
+                    ) = yield self.store.get_new_events_for_appservice(
                         self.current_max, limit
                     )
 
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 16b4617f68..26ef5e150c 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -459,7 +459,18 @@ class DeviceHandler(DeviceWorkerHandler):
     @defer.inlineCallbacks
     def on_federation_query_user_devices(self, user_id):
         stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id)
-        return {"user_id": user_id, "stream_id": stream_id, "devices": devices}
+        master_key = yield self.store.get_e2e_cross_signing_key(user_id, "master")
+        self_signing_key = yield self.store.get_e2e_cross_signing_key(
+            user_id, "self_signing"
+        )
+
+        return {
+            "user_id": user_id,
+            "stream_id": stream_id,
+            "devices": devices,
+            "master_key": master_key,
+            "self_signing_key": self_signing_key,
+        }
 
     @defer.inlineCallbacks
     def user_left_room(self, user, room_id):
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index 526379c6f7..c4632f8984 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -250,7 +250,7 @@ class DirectoryHandler(BaseHandler):
                     ignore_backoff=True,
                 )
             except CodeMessageException as e:
-                logging.warn("Error retrieving alias")
+                logging.warning("Error retrieving alias")
                 if e.code == 404:
                     result = None
                 else:
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 5ea54f60be..f09a0b73c8 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -36,6 +36,8 @@ from synapse.types import (
     get_verify_key_from_cross_signing_key,
 )
 from synapse.util import unwrapFirstError
+from synapse.util.async_helpers import Linearizer
+from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.util.retryutils import NotRetryingDestination
 
 logger = logging.getLogger(__name__)
@@ -49,10 +51,19 @@ class E2eKeysHandler(object):
         self.is_mine = hs.is_mine
         self.clock = hs.get_clock()
 
+        self._edu_updater = SigningKeyEduUpdater(hs, self)
+
+        federation_registry = hs.get_federation_registry()
+
+        # FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec
+        federation_registry.register_edu_handler(
+            "org.matrix.signing_key_update",
+            self._edu_updater.incoming_signing_key_update,
+        )
         # doesn't really work as part of the generic query API, because the
         # query request requires an object POST, but we abuse the
         # "query handler" interface.
-        hs.get_federation_registry().register_query_handler(
+        federation_registry.register_query_handler(
             "client_keys", self.on_federation_query_client_keys
         )
 
@@ -119,9 +130,10 @@ class E2eKeysHandler(object):
                 else:
                     query_list.append((user_id, None))
 
-            user_ids_not_in_cache, remote_results = (
-                yield self.store.get_user_devices_from_cache(query_list)
-            )
+            (
+                user_ids_not_in_cache,
+                remote_results,
+            ) = yield self.store.get_user_devices_from_cache(query_list)
             for user_id, devices in iteritems(remote_results):
                 user_devices = results.setdefault(user_id, {})
                 for device_id, device in iteritems(devices):
@@ -207,13 +219,15 @@ class E2eKeysHandler(object):
                     if user_id in destination_query:
                         results[user_id] = keys
 
-                for user_id, key in remote_result["master_keys"].items():
-                    if user_id in destination_query:
-                        cross_signing_keys["master_keys"][user_id] = key
+                if "master_keys" in remote_result:
+                    for user_id, key in remote_result["master_keys"].items():
+                        if user_id in destination_query:
+                            cross_signing_keys["master_keys"][user_id] = key
 
-                for user_id, key in remote_result["self_signing_keys"].items():
-                    if user_id in destination_query:
-                        cross_signing_keys["self_signing_keys"][user_id] = key
+                if "self_signing_keys" in remote_result:
+                    for user_id, key in remote_result["self_signing_keys"].items():
+                        if user_id in destination_query:
+                            cross_signing_keys["self_signing_keys"][user_id] = key
 
             except Exception as e:
                 failure = _exception_to_failure(e)
@@ -251,7 +265,7 @@ class E2eKeysHandler(object):
 
         Returns:
             defer.Deferred[dict[str, dict[str, dict]]]: map from
-                (master|self_signing|user_signing) -> user_id -> key
+                (master_keys|self_signing_keys|user_signing_keys) -> user_id -> key
         """
         master_keys = {}
         self_signing_keys = {}
@@ -343,7 +357,16 @@ class E2eKeysHandler(object):
         """
         device_keys_query = query_body.get("device_keys", {})
         res = yield self.query_local_devices(device_keys_query)
-        return {"device_keys": res}
+        ret = {"device_keys": res}
+
+        # add in the cross-signing keys
+        cross_signing_keys = yield self.get_cross_signing_keys_from_cache(
+            device_keys_query, None
+        )
+
+        ret.update(cross_signing_keys)
+
+        return ret
 
     @trace
     @defer.inlineCallbacks
@@ -688,17 +711,21 @@ class E2eKeysHandler(object):
 
         try:
             # get our self-signing key to verify the signatures
-            _, self_signing_key_id, self_signing_verify_key = yield self._get_e2e_cross_signing_verify_key(
-                user_id, "self_signing"
-            )
+            (
+                _,
+                self_signing_key_id,
+                self_signing_verify_key,
+            ) = yield self._get_e2e_cross_signing_verify_key(user_id, "self_signing")
 
             # get our master key, since we may have received a signature of it.
             # We need to fetch it here so that we know what its key ID is, so
             # that we can check if a signature that was sent is a signature of
             # the master key or of a device
-            master_key, _, master_verify_key = yield self._get_e2e_cross_signing_verify_key(
-                user_id, "master"
-            )
+            (
+                master_key,
+                _,
+                master_verify_key,
+            ) = yield self._get_e2e_cross_signing_verify_key(user_id, "master")
 
             # fetch our stored devices.  This is used to 1. verify
             # signatures on the master key, and 2. to compare with what
@@ -838,9 +865,11 @@ class E2eKeysHandler(object):
 
         try:
             # get our user-signing key to verify the signatures
-            user_signing_key, user_signing_key_id, user_signing_verify_key = yield self._get_e2e_cross_signing_verify_key(
-                user_id, "user_signing"
-            )
+            (
+                user_signing_key,
+                user_signing_key_id,
+                user_signing_verify_key,
+            ) = yield self._get_e2e_cross_signing_verify_key(user_id, "user_signing")
         except SynapseError as e:
             failure = _exception_to_failure(e)
             for user, devicemap in signatures.items():
@@ -859,7 +888,11 @@ class E2eKeysHandler(object):
             try:
                 # get the target user's master key, to make sure it matches
                 # what was sent
-                master_key, master_key_id, _ = yield self._get_e2e_cross_signing_verify_key(
+                (
+                    master_key,
+                    master_key_id,
+                    _,
+                ) = yield self._get_e2e_cross_signing_verify_key(
                     target_user, "master", user_id
                 )
 
@@ -1047,3 +1080,100 @@ class SignatureListItem:
     target_user_id = attr.ib()
     target_device_id = attr.ib()
     signature = attr.ib()
+
+
+class SigningKeyEduUpdater(object):
+    """Handles incoming signing key updates from federation and updates the DB"""
+
+    def __init__(self, hs, e2e_keys_handler):
+        self.store = hs.get_datastore()
+        self.federation = hs.get_federation_client()
+        self.clock = hs.get_clock()
+        self.e2e_keys_handler = e2e_keys_handler
+
+        self._remote_edu_linearizer = Linearizer(name="remote_signing_key")
+
+        # user_id -> list of updates waiting to be handled.
+        self._pending_updates = {}
+
+        # Recently seen stream ids. We don't bother keeping these in the DB,
+        # but they're useful to have them about to reduce the number of spurious
+        # resyncs.
+        self._seen_updates = ExpiringCache(
+            cache_name="signing_key_update_edu",
+            clock=self.clock,
+            max_len=10000,
+            expiry_ms=30 * 60 * 1000,
+            iterable=True,
+        )
+
+    @defer.inlineCallbacks
+    def incoming_signing_key_update(self, origin, edu_content):
+        """Called on incoming signing key update from federation. Responsible for
+        parsing the EDU and adding to pending updates list.
+
+        Args:
+            origin (string): the server that sent the EDU
+            edu_content (dict): the contents of the EDU
+        """
+
+        user_id = edu_content.pop("user_id")
+        master_key = edu_content.pop("master_key", None)
+        self_signing_key = edu_content.pop("self_signing_key", None)
+
+        if get_domain_from_id(user_id) != origin:
+            logger.warning("Got signing key update edu for %r from %r", user_id, origin)
+            return
+
+        room_ids = yield self.store.get_rooms_for_user(user_id)
+        if not room_ids:
+            # We don't share any rooms with this user. Ignore update, as we
+            # probably won't get any further updates.
+            return
+
+        self._pending_updates.setdefault(user_id, []).append(
+            (master_key, self_signing_key)
+        )
+
+        yield self._handle_signing_key_updates(user_id)
+
+    @defer.inlineCallbacks
+    def _handle_signing_key_updates(self, user_id):
+        """Actually handle pending updates.
+
+        Args:
+            user_id (string): the user whose updates we are processing
+        """
+
+        device_handler = self.e2e_keys_handler.device_handler
+
+        with (yield self._remote_edu_linearizer.queue(user_id)):
+            pending_updates = self._pending_updates.pop(user_id, [])
+            if not pending_updates:
+                # This can happen since we batch updates
+                return
+
+            device_ids = []
+
+            logger.info("pending updates: %r", pending_updates)
+
+            for master_key, self_signing_key in pending_updates:
+                if master_key:
+                    yield self.store.set_e2e_cross_signing_key(
+                        user_id, "master", master_key
+                    )
+                    _, verify_key = get_verify_key_from_cross_signing_key(master_key)
+                    # verify_key is a VerifyKey from signedjson, which uses
+                    # .version to denote the portion of the key ID after the
+                    # algorithm and colon, which is the device ID
+                    device_ids.append(verify_key.version)
+                if self_signing_key:
+                    yield self.store.set_e2e_cross_signing_key(
+                        user_id, "self_signing", self_signing_key
+                    )
+                    _, verify_key = get_verify_key_from_cross_signing_key(
+                        self_signing_key
+                    )
+                    device_ids.append(verify_key.version)
+
+            yield device_handler.notify_device_update(user_id, device_ids)
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index d2d9f8c26a..8cafcfdab0 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -45,6 +45,7 @@ from synapse.api.errors import (
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions
 from synapse.crypto.event_signing import compute_event_signature
 from synapse.event_auth import auth_types_for_event
+from synapse.events.snapshot import EventContext
 from synapse.events.validator import EventValidator
 from synapse.logging.context import (
     make_deferred_yieldable,
@@ -352,10 +353,11 @@ class FederationHandler(BaseHandler):
                             # note that if any of the missing prevs share missing state or
                             # auth events, the requests to fetch those events are deduped
                             # by the get_pdu_cache in federation_client.
-                            remote_state, got_auth_chain = (
-                                yield self.federation_client.get_state_for_room(
-                                    origin, room_id, p
-                                )
+                            (
+                                remote_state,
+                                got_auth_chain,
+                            ) = yield self.federation_client.get_state_for_room(
+                                origin, room_id, p
                             )
 
                             # we want the state *after* p; get_state_for_room returns the
@@ -1105,7 +1107,7 @@ class FederationHandler(BaseHandler):
     @defer.inlineCallbacks
     def do_invite_join(self, target_hosts, room_id, joinee, content):
         """ Attempts to join the `joinee` to the room `room_id` via the
-        server `target_host`.
+        servers contained in `target_hosts`.
 
         This first triggers a /make_join/ request that returns a partial
         event that we can fill out and sign. This is then sent to the
@@ -1114,6 +1116,15 @@ class FederationHandler(BaseHandler):
 
         We suspend processing of any received events from this room until we
         have finished processing the join.
+
+        Args:
+            target_hosts (Iterable[str]): List of servers to attempt to join the room with.
+
+            room_id (str): The ID of the room to join.
+
+            joinee (str): The User ID of the joining user.
+
+            content (dict): The event content to use for the join event.
         """
         logger.debug("Joining %s to %s", joinee, room_id)
 
@@ -1173,6 +1184,22 @@ class FederationHandler(BaseHandler):
 
             yield self._persist_auth_tree(origin, auth_chain, state, event)
 
+            # Check whether this room is the result of an upgrade of a room we already know
+            # about. If so, migrate over user information
+            predecessor = yield self.store.get_room_predecessor(room_id)
+            if not predecessor:
+                return
+            old_room_id = predecessor["room_id"]
+            logger.debug(
+                "Found predecessor for %s during remote join: %s", room_id, old_room_id
+            )
+
+            # We retrieve the room member handler here as to not cause a cyclic dependency
+            member_handler = self.hs.get_room_member_handler()
+            yield member_handler.transfer_room_state_on_room_upgrade(
+                old_room_id, room_id
+            )
+
             logger.debug("Finished joining %s to %s", joinee, room_id)
         finally:
             room_queue = self.room_queues[room_id]
@@ -1845,14 +1872,7 @@ class FederationHandler(BaseHandler):
                 if c and c.type == EventTypes.Create:
                     auth_events[(c.type, c.state_key)] = c
 
-        try:
-            yield self.do_auth(origin, event, context, auth_events=auth_events)
-        except AuthError as e:
-            logger.warning(
-                "[%s %s] Rejecting: %s", event.room_id, event.event_id, e.msg
-            )
-
-            context.rejected = RejectedReason.AUTH_ERROR
+        context = yield self.do_auth(origin, event, context, auth_events=auth_events)
 
         if not context.rejected:
             yield self._check_for_soft_fail(event, state, backfilled)
@@ -2021,12 +2041,12 @@ class FederationHandler(BaseHandler):
 
                 Also NB that this function adds entries to it.
         Returns:
-            defer.Deferred[None]
+            defer.Deferred[EventContext]: updated context object
         """
         room_version = yield self.store.get_room_version(event.room_id)
 
         try:
-            yield self._update_auth_events_and_context_for_auth(
+            context = yield self._update_auth_events_and_context_for_auth(
                 origin, event, context, auth_events
             )
         except Exception:
@@ -2044,7 +2064,9 @@ class FederationHandler(BaseHandler):
             event_auth.check(room_version, event, auth_events=auth_events)
         except AuthError as e:
             logger.warning("Failed auth resolution for %r because %s", event, e)
-            raise e
+            context.rejected = RejectedReason.AUTH_ERROR
+
+        return context
 
     @defer.inlineCallbacks
     def _update_auth_events_and_context_for_auth(
@@ -2068,7 +2090,7 @@ class FederationHandler(BaseHandler):
             auth_events (dict[(str, str)->synapse.events.EventBase]):
 
         Returns:
-            defer.Deferred[None]
+            defer.Deferred[EventContext]: updated context
         """
         event_auth_events = set(event.auth_event_ids())
 
@@ -2107,7 +2129,7 @@ class FederationHandler(BaseHandler):
                     # The other side isn't around or doesn't implement the
                     # endpoint, so lets just bail out.
                     logger.info("Failed to get event auth from remote: %s", e)
-                    return
+                    return context
 
                 seen_remotes = yield self.store.have_seen_events(
                     [e.event_id for e in remote_auth_chain]
@@ -2148,7 +2170,7 @@ class FederationHandler(BaseHandler):
 
         if event.internal_metadata.is_outlier():
             logger.info("Skipping auth_event fetch for outlier")
-            return
+            return context
 
         # FIXME: Assumes we have and stored all the state for all the
         # prev_events
@@ -2157,7 +2179,7 @@ class FederationHandler(BaseHandler):
         )
 
         if not different_auth:
-            return
+            return context
 
         logger.info(
             "auth_events refers to events which are not in our calculated auth "
@@ -2204,10 +2226,12 @@ class FederationHandler(BaseHandler):
 
             auth_events.update(new_state)
 
-            yield self._update_context_for_auth_events(
+            context = yield self._update_context_for_auth_events(
                 event, context, auth_events, event_key
             )
 
+        return context
+
     @defer.inlineCallbacks
     def _update_context_for_auth_events(self, event, context, auth_events, event_key):
         """Update the state_ids in an event context after auth event resolution,
@@ -2216,14 +2240,16 @@ class FederationHandler(BaseHandler):
         Args:
             event (Event): The event we're handling the context for
 
-            context (synapse.events.snapshot.EventContext): event context
-                to be updated
+            context (synapse.events.snapshot.EventContext): initial event context
 
             auth_events (dict[(str, str)->str]): Events to update in the event
                 context.
 
             event_key ((str, str)): (type, state_key) for the current event.
                 this will not be included in the current_state in the context.
+
+        Returns:
+            Deferred[EventContext]: new event context
         """
         state_updates = {
             k: a.event_id for k, a in iteritems(auth_events) if k != event_key
@@ -2248,7 +2274,7 @@ class FederationHandler(BaseHandler):
             current_state_ids=current_state_ids,
         )
 
-        yield context.update_state(
+        return EventContext.with_state(
             state_group=state_group,
             current_state_ids=current_state_ids,
             prev_state_ids=prev_state_ids,
@@ -2441,6 +2467,8 @@ class FederationHandler(BaseHandler):
                 raise e
 
             yield self._check_signature(event, context)
+
+            # We retrieve the room member handler here as to not cause a cyclic dependency
             member_handler = self.hs.get_room_member_handler()
             yield member_handler.send_membership_event(None, event, context)
         else:
@@ -2501,6 +2529,7 @@ class FederationHandler(BaseHandler):
         # though the sender isn't a local user.
         event.internal_metadata.send_on_behalf_of = get_domain_from_id(event.sender)
 
+        # We retrieve the room member handler here as to not cause a cyclic dependency
         member_handler = self.hs.get_room_member_handler()
         yield member_handler.send_membership_event(None, event, context)
 
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 49c9e031f9..81dce96f4b 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -128,8 +128,8 @@ class InitialSyncHandler(BaseHandler):
 
         tags_by_room = yield self.store.get_tags_for_user(user_id)
 
-        account_data, account_data_by_room = (
-            yield self.store.get_account_data_for_user(user_id)
+        account_data, account_data_by_room = yield self.store.get_account_data_for_user(
+            user_id
         )
 
         public_room_ids = yield self.store.get_public_room_ids()
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 0d546d2487..d682dc2b7a 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -76,9 +76,10 @@ class MessageHandler(object):
         Raises:
             SynapseError if something went wrong.
         """
-        membership, membership_event_id = yield self.auth.check_in_room_or_world_readable(
-            room_id, user_id
-        )
+        (
+            membership,
+            membership_event_id,
+        ) = yield self.auth.check_in_room_or_world_readable(room_id, user_id)
 
         if membership == Membership.JOIN:
             data = yield self.state.get_current_state(room_id, event_type, state_key)
@@ -153,9 +154,10 @@ class MessageHandler(object):
                     % (user_id, room_id, at_token),
                 )
         else:
-            membership, membership_event_id = (
-                yield self.auth.check_in_room_or_world_readable(room_id, user_id)
-            )
+            (
+                membership,
+                membership_event_id,
+            ) = yield self.auth.check_in_room_or_world_readable(room_id, user_id)
 
             if membership == Membership.JOIN:
                 state_ids = yield self.store.get_filtered_current_state_ids(
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 6d8b04efe3..260a4351ca 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -214,9 +214,10 @@ class PaginationHandler(object):
         source_config = pagin_config.get_source_config("room")
 
         with (yield self.pagination_lock.read(room_id)):
-            membership, member_event_id = yield self.auth.check_in_room_or_world_readable(
-                room_id, user_id
-            )
+            (
+                membership,
+                member_event_id,
+            ) = yield self.auth.check_in_room_or_world_readable(room_id, user_id)
 
             if source_config.direction == "b":
                 # if we're going backwards, we might need to backfill. This
@@ -299,10 +300,8 @@ class PaginationHandler(object):
         }
 
         if state:
-            chunk["state"] = (
-                yield self._event_serializer.serialize_events(
-                    state, time_now, as_client_event=as_client_event
-                )
+            chunk["state"] = yield self._event_serializer.serialize_events(
+                state, time_now, as_client_event=as_client_event
             )
 
         return chunk
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 53410f120b..cff6b0d375 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -396,8 +396,8 @@ class RegistrationHandler(BaseHandler):
             room_id = room_identifier
         elif RoomAlias.is_valid(room_identifier):
             room_alias = RoomAlias.from_string(room_identifier)
-            room_id, remote_room_hosts = (
-                yield room_member_handler.lookup_room_alias(room_alias)
+            room_id, remote_room_hosts = yield room_member_handler.lookup_room_alias(
+                room_alias
             )
             room_id = room_id.to_string()
         else:
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 650bd28abb..e92b2eafd5 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -129,6 +129,7 @@ class RoomCreationHandler(BaseHandler):
             old_room_id,
             new_version,  # args for _upgrade_room
         )
+
         return ret
 
     @defer.inlineCallbacks
@@ -147,21 +148,22 @@ class RoomCreationHandler(BaseHandler):
 
         # we create and auth the tombstone event before properly creating the new
         # room, to check our user has perms in the old room.
-        tombstone_event, tombstone_context = (
-            yield self.event_creation_handler.create_event(
-                requester,
-                {
-                    "type": EventTypes.Tombstone,
-                    "state_key": "",
-                    "room_id": old_room_id,
-                    "sender": user_id,
-                    "content": {
-                        "body": "This room has been replaced",
-                        "replacement_room": new_room_id,
-                    },
+        (
+            tombstone_event,
+            tombstone_context,
+        ) = yield self.event_creation_handler.create_event(
+            requester,
+            {
+                "type": EventTypes.Tombstone,
+                "state_key": "",
+                "room_id": old_room_id,
+                "sender": user_id,
+                "content": {
+                    "body": "This room has been replaced",
+                    "replacement_room": new_room_id,
                 },
-                token_id=requester.access_token_id,
-            )
+            },
+            token_id=requester.access_token_id,
         )
         old_room_version = yield self.store.get_room_version(old_room_id)
         yield self.auth.check_from_context(
@@ -188,7 +190,12 @@ class RoomCreationHandler(BaseHandler):
             requester, old_room_id, new_room_id, old_room_state
         )
 
-        # and finally, shut down the PLs in the old room, and update them in the new
+        # Copy over user push rules, tags and migrate room directory state
+        yield self.room_member_handler.transfer_room_state_on_room_upgrade(
+            old_room_id, new_room_id
+        )
+
+        # finally, shut down the PLs in the old room, and update them in the new
         # room.
         yield self._update_upgraded_room_pls(
             requester, old_room_id, new_room_id, old_room_state
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 380e2fad5e..06d09c2947 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -203,10 +203,6 @@ class RoomMemberHandler(object):
                 prev_member_event = yield self.store.get_event(prev_member_event_id)
                 newly_joined = prev_member_event.membership != Membership.JOIN
             if newly_joined:
-                # Copy over user state if we're joining an upgraded room
-                yield self.copy_user_state_if_room_upgrade(
-                    room_id, requester.user.to_string()
-                )
                 yield self._user_joined_room(target, room_id)
         elif event.membership == Membership.LEAVE:
             if prev_member_event_id:
@@ -455,11 +451,6 @@ class RoomMemberHandler(object):
                     requester, remote_room_hosts, room_id, target, content
                 )
 
-                # Copy over user state if this is a join on an remote upgraded room
-                yield self.copy_user_state_if_room_upgrade(
-                    room_id, requester.user.to_string()
-                )
-
                 return remote_join_response
 
         elif effective_membership_state == Membership.LEAVE:
@@ -498,36 +489,72 @@ class RoomMemberHandler(object):
         return res
 
     @defer.inlineCallbacks
-    def copy_user_state_if_room_upgrade(self, new_room_id, user_id):
-        """Copy user-specific information when they join a new room if that new room is the
+    def transfer_room_state_on_room_upgrade(self, old_room_id, room_id):
+        """Upon our server becoming aware of an upgraded room, either by upgrading a room
+        ourselves or joining one, we can transfer over information from the previous room.
+
+        Copies user state (tags/push rules) for every local user that was in the old room, as
+        well as migrating the room directory state.
+
+        Args:
+            old_room_id (str): The ID of the old room
+
+            room_id (str): The ID of the new room
+
+        Returns:
+            Deferred
+        """
+        # Find all local users that were in the old room and copy over each user's state
+        users = yield self.store.get_users_in_room(old_room_id)
+        yield self.copy_user_state_on_room_upgrade(old_room_id, room_id, users)
+
+        # Add new room to the room directory if the old room was there
+        # Remove old room from the room directory
+        old_room = yield self.store.get_room(old_room_id)
+        if old_room and old_room["is_public"]:
+            yield self.store.set_room_is_public(old_room_id, False)
+            yield self.store.set_room_is_public(room_id, True)
+
+    @defer.inlineCallbacks
+    def copy_user_state_on_room_upgrade(self, old_room_id, new_room_id, user_ids):
+        """Copy user-specific information when they join a new room when that new room is the
         result of a room upgrade
 
         Args:
-            new_room_id (str): The ID of the room the user is joining
-            user_id (str): The ID of the user
+            old_room_id (str): The ID of upgraded room
+            new_room_id (str): The ID of the new room
+            user_ids (Iterable[str]): User IDs to copy state for
 
         Returns:
             Deferred
         """
-        # Check if the new room is an upgraded room
-        predecessor = yield self.store.get_room_predecessor(new_room_id)
-        if not predecessor:
-            return
 
         logger.debug(
-            "Found predecessor for %s: %s. Copying over room tags and push " "rules",
+            "Copying over room tags and push rules from %s to %s for users %s",
+            old_room_id,
             new_room_id,
-            predecessor,
+            user_ids,
         )
 
-        # It is an upgraded room. Copy over old tags
-        yield self.copy_room_tags_and_direct_to_room(
-            predecessor["room_id"], new_room_id, user_id
-        )
-        # Copy over push rules
-        yield self.store.copy_push_rules_from_room_to_room_for_user(
-            predecessor["room_id"], new_room_id, user_id
-        )
+        for user_id in user_ids:
+            try:
+                # It is an upgraded room. Copy over old tags
+                yield self.copy_room_tags_and_direct_to_room(
+                    old_room_id, new_room_id, user_id
+                )
+                # Copy over push rules
+                yield self.store.copy_push_rules_from_room_to_room_for_user(
+                    old_room_id, new_room_id, user_id
+                )
+            except Exception:
+                logger.exception(
+                    "Error copying tags and/or push rules from rooms %s to %s for user %s. "
+                    "Skipping...",
+                    old_room_id,
+                    new_room_id,
+                    user_id,
+                )
+                continue
 
     @defer.inlineCallbacks
     def send_membership_event(self, requester, event, context, ratelimit=True):
@@ -759,22 +786,25 @@ class RoomMemberHandler(object):
         if room_avatar_event:
             room_avatar_url = room_avatar_event.content.get("url", "")
 
-        token, public_keys, fallback_public_key, display_name = (
-            yield self.identity_handler.ask_id_server_for_third_party_invite(
-                requester=requester,
-                id_server=id_server,
-                medium=medium,
-                address=address,
-                room_id=room_id,
-                inviter_user_id=user.to_string(),
-                room_alias=canonical_room_alias,
-                room_avatar_url=room_avatar_url,
-                room_join_rules=room_join_rules,
-                room_name=room_name,
-                inviter_display_name=inviter_display_name,
-                inviter_avatar_url=inviter_avatar_url,
-                id_access_token=id_access_token,
-            )
+        (
+            token,
+            public_keys,
+            fallback_public_key,
+            display_name,
+        ) = yield self.identity_handler.ask_id_server_for_third_party_invite(
+            requester=requester,
+            id_server=id_server,
+            medium=medium,
+            address=address,
+            room_id=room_id,
+            inviter_user_id=user.to_string(),
+            room_alias=canonical_room_alias,
+            room_avatar_url=room_avatar_url,
+            room_join_rules=room_join_rules,
+            room_name=room_name,
+            inviter_display_name=inviter_display_name,
+            inviter_avatar_url=inviter_avatar_url,
+            id_access_token=id_access_token,
         )
 
         yield self.event_creation_handler.create_and_send_nonmember_event(
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index f4d8a60774..56ed262a1f 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -396,15 +396,11 @@ class SearchHandler(BaseHandler):
         time_now = self.clock.time_msec()
 
         for context in contexts.values():
-            context["events_before"] = (
-                yield self._event_serializer.serialize_events(
-                    context["events_before"], time_now
-                )
+            context["events_before"] = yield self._event_serializer.serialize_events(
+                context["events_before"], time_now
             )
-            context["events_after"] = (
-                yield self._event_serializer.serialize_events(
-                    context["events_after"], time_now
-                )
+            context["events_after"] = yield self._event_serializer.serialize_events(
+                context["events_after"], time_now
             )
 
         state_results = {}
diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py
index 26bc276692..7f7d56390e 100644
--- a/synapse/handlers/stats.py
+++ b/synapse/handlers/stats.py
@@ -108,7 +108,10 @@ class StatsHandler(StateDeltasHandler):
                 user_deltas = {}
 
             # Then count deltas for total_events and total_event_bytes.
-            room_count, user_count = yield self.store.get_changes_room_total_events_and_bytes(
+            (
+                room_count,
+                user_count,
+            ) = yield self.store.get_changes_room_total_events_and_bytes(
                 self.pos, max_pos
             )
 
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 43a082dcda..b536d410e5 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -1206,10 +1206,11 @@ class SyncHandler(object):
         since_token = sync_result_builder.since_token
 
         if since_token and not sync_result_builder.full_state:
-            account_data, account_data_by_room = (
-                yield self.store.get_updated_account_data_for_user(
-                    user_id, since_token.account_data_key
-                )
+            (
+                account_data,
+                account_data_by_room,
+            ) = yield self.store.get_updated_account_data_for_user(
+                user_id, since_token.account_data_key
             )
 
             push_rules_changed = yield self.store.have_push_rules_changed_for_user(
@@ -1221,9 +1222,10 @@ class SyncHandler(object):
                     sync_config.user
                 )
         else:
-            account_data, account_data_by_room = (
-                yield self.store.get_account_data_for_user(sync_config.user.to_string())
-            )
+            (
+                account_data,
+                account_data_by_room,
+            ) = yield self.store.get_account_data_for_user(sync_config.user.to_string())
 
             account_data["m.push_rules"] = yield self.push_rules_for_user(
                 sync_config.user
diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py
index 29aa1e5aaf..8363d887a9 100644
--- a/synapse/handlers/ui_auth/checkers.py
+++ b/synapse/handlers/ui_auth/checkers.py
@@ -81,7 +81,7 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker):
     def __init__(self, hs):
         super().__init__(hs)
         self._enabled = bool(hs.config.recaptcha_private_key)
-        self._http_client = hs.get_simple_http_client()
+        self._http_client = hs.get_proxied_http_client()
         self._url = hs.config.recaptcha_siteverify_api
         self._secret = hs.config.recaptcha_private_key