summary refs log tree commit diff
path: root/synapse/handlers
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers')
-rw-r--r--synapse/handlers/appservice.py6
-rw-r--r--synapse/handlers/auth.py273
-rw-r--r--synapse/handlers/device.py7
-rw-r--r--synapse/handlers/federation.py66
-rw-r--r--synapse/handlers/groups_local.py1
-rw-r--r--synapse/handlers/profile.py21
-rw-r--r--synapse/handlers/register.py3
-rw-r--r--synapse/handlers/room_list.py9
-rw-r--r--synapse/handlers/sync.py6
-rw-r--r--synapse/handlers/user_directory.py4
10 files changed, 268 insertions, 128 deletions
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 05af54d31b..543bf28aec 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -70,11 +70,10 @@ class ApplicationServicesHandler(object):
         with Measure(self.clock, "notify_interested_services"):
             self.is_processing = True
             try:
-                upper_bound = self.current_max
                 limit = 100
                 while True:
                     upper_bound, events = yield self.store.get_new_events_for_appservice(
-                        upper_bound, limit
+                        self.current_max, limit
                     )
 
                     if not events:
@@ -105,9 +104,6 @@ class ApplicationServicesHandler(object):
                             )
 
                     yield self.store.set_appservice_last_pos(upper_bound)
-
-                    if len(events) < limit:
-                        break
             finally:
                 self.is_processing = False
 
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 9cef9d184b..7a0ba6ef35 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -13,13 +13,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 from twisted.internet import defer
 
 from ._base import BaseHandler
 from synapse.api.constants import LoginType
-from synapse.types import UserID
 from synapse.api.errors import AuthError, LoginError, Codes, StoreError, SynapseError
+from synapse.module_api import ModuleApi
+from synapse.types import UserID
 from synapse.util.async import run_on_reactor
 from synapse.util.caches.expiringcache import ExpiringCache
 
@@ -63,10 +63,7 @@ class AuthHandler(BaseHandler):
             reset_expiry_on_get=True,
         )
 
-        account_handler = _AccountHandler(
-            hs, check_user_exists=self.check_user_exists
-        )
-
+        account_handler = ModuleApi(hs, self)
         self.password_providers = [
             module(config=config, account_handler=account_handler)
             for module, config in hs.config.password_providers
@@ -75,14 +72,24 @@ class AuthHandler(BaseHandler):
         logger.info("Extra password_providers: %r", self.password_providers)
 
         self.hs = hs  # FIXME better possibility to access registrationHandler later?
-        self.device_handler = hs.get_device_handler()
         self.macaroon_gen = hs.get_macaroon_generator()
+        self._password_enabled = hs.config.password_enabled
+
+        login_types = set()
+        if self._password_enabled:
+            login_types.add(LoginType.PASSWORD)
+        for provider in self.password_providers:
+            if hasattr(provider, "get_supported_login_types"):
+                login_types.update(
+                    provider.get_supported_login_types().keys()
+                )
+        self._supported_login_types = frozenset(login_types)
 
     @defer.inlineCallbacks
     def check_auth(self, flows, clientdict, clientip):
         """
         Takes a dictionary sent by the client in the login / registration
-        protocol and handles the login flow.
+        protocol and handles the User-Interactive Auth flow.
 
         As a side effect, this function fills in the 'creds' key on the user's
         session with a map, which maps each auth-type (str) to the relevant
@@ -260,16 +267,19 @@ class AuthHandler(BaseHandler):
         sess = self._get_session_info(session_id)
         return sess.setdefault('serverdict', {}).get(key, default)
 
+    @defer.inlineCallbacks
     def _check_password_auth(self, authdict, _):
         if "user" not in authdict or "password" not in authdict:
             raise LoginError(400, "", Codes.MISSING_PARAM)
 
         user_id = authdict["user"]
         password = authdict["password"]
-        if not user_id.startswith('@'):
-            user_id = UserID(user_id, self.hs.hostname).to_string()
 
-        return self._check_password(user_id, password)
+        (canonical_id, callback) = yield self.validate_login(user_id, {
+            "type": LoginType.PASSWORD,
+            "password": password,
+        })
+        defer.returnValue(canonical_id)
 
     @defer.inlineCallbacks
     def _check_recaptcha(self, authdict, clientip):
@@ -398,26 +408,8 @@ class AuthHandler(BaseHandler):
 
         return self.sessions[session_id]
 
-    def validate_password_login(self, user_id, password):
-        """
-        Authenticates the user with their username and password.
-
-        Used only by the v1 login API.
-
-        Args:
-            user_id (str): complete @user:id
-            password (str): Password
-        Returns:
-            defer.Deferred: (str) canonical user id
-        Raises:
-            StoreError if there was a problem accessing the database
-            LoginError if there was an authentication problem.
-        """
-        return self._check_password(user_id, password)
-
     @defer.inlineCallbacks
-    def get_access_token_for_user_id(self, user_id, device_id=None,
-                                     initial_display_name=None):
+    def get_access_token_for_user_id(self, user_id, device_id=None):
         """
         Creates a new access token for the user with the given user ID.
 
@@ -431,13 +423,10 @@ class AuthHandler(BaseHandler):
             device_id (str|None): the device ID to associate with the tokens.
                None to leave the tokens unassociated with a device (deprecated:
                we should always have a device ID)
-            initial_display_name (str): display name to associate with the
-               device if it needs re-registering
         Returns:
               The access token for the user's session.
         Raises:
             StoreError if there was a problem storing the token.
-            LoginError if there was an authentication problem.
         """
         logger.info("Logging in user %s on device %s", user_id, device_id)
         access_token = yield self.issue_access_token(user_id, device_id)
@@ -447,9 +436,11 @@ class AuthHandler(BaseHandler):
         # really don't want is active access_tokens without a record of the
         # device, so we double-check it here.
         if device_id is not None:
-            yield self.device_handler.check_device_registered(
-                user_id, device_id, initial_display_name
-            )
+            try:
+                yield self.store.get_device(user_id, device_id)
+            except StoreError:
+                yield self.store.delete_access_token(access_token)
+                raise StoreError(400, "Login raced against device deletion")
 
         defer.returnValue(access_token)
 
@@ -501,29 +492,115 @@ class AuthHandler(BaseHandler):
             )
         defer.returnValue(result)
 
+    def get_supported_login_types(self):
+        """Get a the login types supported for the /login API
+
+        By default this is just 'm.login.password' (unless password_enabled is
+        False in the config file), but password auth providers can provide
+        other login types.
+
+        Returns:
+            Iterable[str]: login types
+        """
+        return self._supported_login_types
+
     @defer.inlineCallbacks
-    def _check_password(self, user_id, password):
-        """Authenticate a user against the LDAP and local databases.
+    def validate_login(self, username, login_submission):
+        """Authenticates the user for the /login API
 
-        user_id is checked case insensitively against the local database, but
-        will throw if there are multiple inexact matches.
+        Also used by the user-interactive auth flow to validate
+        m.login.password auth types.
 
         Args:
-            user_id (str): complete @user:id
+            username (str): username supplied by the user
+            login_submission (dict): the whole of the login submission
+                (including 'type' and other relevant fields)
         Returns:
-            (str) the canonical_user_id
+            Deferred[str, func]: canonical user id, and optional callback
+                to be called once the access token and device id are issued
         Raises:
-            LoginError if login fails
+            StoreError if there was a problem accessing the database
+            SynapseError if there was a problem with the request
+            LoginError if there was an authentication problem.
         """
+
+        if username.startswith('@'):
+            qualified_user_id = username
+        else:
+            qualified_user_id = UserID(
+                username, self.hs.hostname
+            ).to_string()
+
+        login_type = login_submission.get("type")
+        known_login_type = False
+
+        # special case to check for "password" for the check_password interface
+        # for the auth providers
+        password = login_submission.get("password")
+        if login_type == LoginType.PASSWORD:
+            if not self._password_enabled:
+                raise SynapseError(400, "Password login has been disabled.")
+            if not password:
+                raise SynapseError(400, "Missing parameter: password")
+
         for provider in self.password_providers:
-            is_valid = yield provider.check_password(user_id, password)
-            if is_valid:
-                defer.returnValue(user_id)
+            if (hasattr(provider, "check_password")
+                    and login_type == LoginType.PASSWORD):
+                known_login_type = True
+                is_valid = yield provider.check_password(
+                    qualified_user_id, password,
+                )
+                if is_valid:
+                    defer.returnValue(qualified_user_id)
+
+            if (not hasattr(provider, "get_supported_login_types")
+                    or not hasattr(provider, "check_auth")):
+                # this password provider doesn't understand custom login types
+                continue
+
+            supported_login_types = provider.get_supported_login_types()
+            if login_type not in supported_login_types:
+                # this password provider doesn't understand this login type
+                continue
+
+            known_login_type = True
+            login_fields = supported_login_types[login_type]
+
+            missing_fields = []
+            login_dict = {}
+            for f in login_fields:
+                if f not in login_submission:
+                    missing_fields.append(f)
+                else:
+                    login_dict[f] = login_submission[f]
+            if missing_fields:
+                raise SynapseError(
+                    400, "Missing parameters for login type %s: %s" % (
+                        login_type,
+                        missing_fields,
+                    ),
+                )
+
+            result = yield provider.check_auth(
+                username, login_type, login_dict,
+            )
+            if result:
+                if isinstance(result, str):
+                    result = (result, None)
+                defer.returnValue(result)
+
+        if login_type == LoginType.PASSWORD:
+            known_login_type = True
+
+            canonical_user_id = yield self._check_local_password(
+                qualified_user_id, password,
+            )
 
-        canonical_user_id = yield self._check_local_password(user_id, password)
+            if canonical_user_id:
+                defer.returnValue((canonical_user_id, None))
 
-        if canonical_user_id:
-            defer.returnValue(canonical_user_id)
+        if not known_login_type:
+            raise SynapseError(400, "Unknown login type %s" % login_type)
 
         # unknown username or invalid password. We raise a 403 here, but note
         # that if we're doing user-interactive login, it turns all LoginErrors
@@ -584,14 +661,81 @@ class AuthHandler(BaseHandler):
             if e.code == 404:
                 raise SynapseError(404, "Unknown user", Codes.NOT_FOUND)
             raise e
-        yield self.store.user_delete_access_tokens(
-            user_id, except_access_token_id
+        yield self.delete_access_tokens_for_user(
+            user_id, except_token_id=except_access_token_id,
         )
         yield self.hs.get_pusherpool().remove_pushers_by_user(
             user_id, except_access_token_id
         )
 
     @defer.inlineCallbacks
+    def deactivate_account(self, user_id):
+        """Deactivate a user's account
+
+        Args:
+            user_id (str): ID of user to be deactivated
+
+        Returns:
+            Deferred
+        """
+        # FIXME: Theoretically there is a race here wherein user resets
+        # password using threepid.
+        yield self.delete_access_tokens_for_user(user_id)
+        yield self.store.user_delete_threepids(user_id)
+        yield self.store.user_set_password_hash(user_id, None)
+
+    @defer.inlineCallbacks
+    def delete_access_token(self, access_token):
+        """Invalidate a single access token
+
+        Args:
+            access_token (str): access token to be deleted
+
+        Returns:
+            Deferred
+        """
+        user_info = yield self.auth.get_user_by_access_token(access_token)
+        yield self.store.delete_access_token(access_token)
+
+        # see if any of our auth providers want to know about this
+        for provider in self.password_providers:
+            if hasattr(provider, "on_logged_out"):
+                yield provider.on_logged_out(
+                    user_id=str(user_info["user"]),
+                    device_id=user_info["device_id"],
+                    access_token=access_token,
+                )
+
+    @defer.inlineCallbacks
+    def delete_access_tokens_for_user(self, user_id, except_token_id=None,
+                                      device_id=None):
+        """Invalidate access tokens belonging to a user
+
+        Args:
+            user_id (str):  ID of user the tokens belong to
+            except_token_id (str|None): access_token ID which should *not* be
+                deleted
+            device_id (str|None):  ID of device the tokens are associated with.
+                If None, tokens associated with any device (or no device) will
+                be deleted
+        Returns:
+            Deferred
+        """
+        tokens_and_devices = yield self.store.user_delete_access_tokens(
+            user_id, except_token_id=except_token_id, device_id=device_id,
+        )
+
+        # see if any of our auth providers want to know about this
+        for provider in self.password_providers:
+            if hasattr(provider, "on_logged_out"):
+                for token, device_id in tokens_and_devices:
+                    yield provider.on_logged_out(
+                        user_id=user_id,
+                        device_id=device_id,
+                        access_token=token,
+                    )
+
+    @defer.inlineCallbacks
     def add_threepid(self, user_id, medium, address, validated_at):
         # 'Canonicalise' email addresses down to lower case.
         # We've now moving towards the Home Server being the entity that
@@ -696,30 +840,3 @@ class MacaroonGeneartor(object):
         macaroon.add_first_party_caveat("gen = 1")
         macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
         return macaroon
-
-
-class _AccountHandler(object):
-    """A proxy object that gets passed to password auth providers so they
-    can register new users etc if necessary.
-    """
-    def __init__(self, hs, check_user_exists):
-        self.hs = hs
-
-        self._check_user_exists = check_user_exists
-
-    def check_user_exists(self, user_id):
-        """Check if user exissts.
-
-        Returns:
-            Deferred(bool)
-        """
-        return self._check_user_exists(user_id)
-
-    def register(self, localpart):
-        """Registers a new user with given localpart
-
-        Returns:
-            Deferred: a 2-tuple of (user_id, access_token)
-        """
-        reg = self.hs.get_handlers().registration_handler
-        return reg.register(localpart=localpart)
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index dac4b3f4e0..579d8477ba 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -34,6 +34,7 @@ class DeviceHandler(BaseHandler):
 
         self.hs = hs
         self.state = hs.get_state_handler()
+        self._auth_handler = hs.get_auth_handler()
         self.federation_sender = hs.get_federation_sender()
         self.federation = hs.get_replication_layer()
 
@@ -159,9 +160,8 @@ class DeviceHandler(BaseHandler):
             else:
                 raise
 
-        yield self.store.user_delete_access_tokens(
+        yield self._auth_handler.delete_access_tokens_for_user(
             user_id, device_id=device_id,
-            delete_refresh_tokens=True,
         )
 
         yield self.store.delete_e2e_keys_by_device(
@@ -194,9 +194,8 @@ class DeviceHandler(BaseHandler):
         # Delete access tokens and e2e keys for each device. Not optimised as it is not
         # considered as part of a critical path.
         for device_id in device_ids:
-            yield self.store.user_delete_access_tokens(
+            yield self._auth_handler.delete_access_tokens_for_user(
                 user_id, device_id=device_id,
-                delete_refresh_tokens=True,
             )
             yield self.store.delete_e2e_keys_by_device(
                 user_id=user_id, device_id=device_id
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 8b1e606754..ac70730885 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -1706,6 +1706,17 @@ class FederationHandler(BaseHandler):
     @defer.inlineCallbacks
     @log_function
     def do_auth(self, origin, event, context, auth_events):
+        """
+
+        Args:
+            origin (str):
+            event (synapse.events.FrozenEvent):
+            context (synapse.events.snapshot.EventContext):
+            auth_events (dict[(str, str)->str]):
+
+        Returns:
+            defer.Deferred[None]
+        """
         # Check if we have all the auth events.
         current_state = set(e.event_id for e in auth_events.values())
         event_auth_events = set(e_id for e_id, _ in event.auth_events)
@@ -1817,16 +1828,9 @@ class FederationHandler(BaseHandler):
                 current_state = set(e.event_id for e in auth_events.values())
                 different_auth = event_auth_events - current_state
 
-                context.current_state_ids = dict(context.current_state_ids)
-                context.current_state_ids.update({
-                    k: a.event_id for k, a in auth_events.items()
-                    if k != event_key
-                })
-                context.prev_state_ids = dict(context.prev_state_ids)
-                context.prev_state_ids.update({
-                    k: a.event_id for k, a in auth_events.items()
-                })
-                context.state_group = self.store.get_next_state_group()
+                self._update_context_for_auth_events(
+                    context, auth_events, event_key,
+                )
 
         if different_auth and not event.internal_metadata.is_outlier():
             logger.info("Different auth after resolution: %s", different_auth)
@@ -1906,16 +1910,9 @@ class FederationHandler(BaseHandler):
                 # 4. Look at rejects and their proofs.
                 # TODO.
 
-                context.current_state_ids = dict(context.current_state_ids)
-                context.current_state_ids.update({
-                    k: a.event_id for k, a in auth_events.items()
-                    if k != event_key
-                })
-                context.prev_state_ids = dict(context.prev_state_ids)
-                context.prev_state_ids.update({
-                    k: a.event_id for k, a in auth_events.items()
-                })
-                context.state_group = self.store.get_next_state_group()
+                self._update_context_for_auth_events(
+                    context, auth_events, event_key,
+                )
 
         try:
             self.auth.check(event, auth_events=auth_events)
@@ -1923,6 +1920,35 @@ class FederationHandler(BaseHandler):
             logger.warn("Failed auth resolution for %r because %s", event, e)
             raise e
 
+    def _update_context_for_auth_events(self, context, auth_events,
+                                        event_key):
+        """Update the state_ids in an event context after auth event resolution
+
+        Args:
+            context (synapse.events.snapshot.EventContext): event context
+                to be updated
+
+            auth_events (dict[(str, str)->str]): Events to update in the event
+                context.
+
+            event_key ((str, str)): (type, state_key) for the current event.
+                this will not be included in the current_state in the context.
+        """
+        state_updates = {
+            k: a.event_id for k, a in auth_events.iteritems()
+            if k != event_key
+        }
+        context.current_state_ids = dict(context.current_state_ids)
+        context.current_state_ids.update(state_updates)
+        if context.delta_ids is not None:
+            context.delta_ids = dict(context.delta_ids)
+            context.delta_ids.update(state_updates)
+        context.prev_state_ids = dict(context.prev_state_ids)
+        context.prev_state_ids.update({
+            k: a.event_id for k, a in auth_events.iteritems()
+        })
+        context.state_group = self.store.get_next_state_group()
+
     @defer.inlineCallbacks
     def construct_auth_difference(self, local_auth, remote_auth):
         """ Given a local and remote auth chain, find the differences. This
diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py
index 6699d0888f..da00aeb0f4 100644
--- a/synapse/handlers/groups_local.py
+++ b/synapse/handlers/groups_local.py
@@ -71,6 +71,7 @@ class GroupsLocalHandler(object):
     get_invited_users_in_group = _create_rerouter("get_invited_users_in_group")
 
     add_room_to_group = _create_rerouter("add_room_to_group")
+    update_room_in_group = _create_rerouter("update_room_in_group")
     remove_room_from_group = _create_rerouter("remove_room_from_group")
 
     update_group_summary_room = _create_rerouter("update_group_summary_room")
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 62b9bd503e..5e5b1952dd 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -17,7 +17,6 @@ import logging
 
 from twisted.internet import defer
 
-import synapse.types
 from synapse.api.errors import SynapseError, AuthError, CodeMessageException
 from synapse.types import UserID, get_domain_from_id
 from ._base import BaseHandler
@@ -140,7 +139,7 @@ class ProfileHandler(BaseHandler):
             target_user.localpart, new_displayname
         )
 
-        yield self._update_join_states(requester)
+        yield self._update_join_states(requester, target_user)
 
     @defer.inlineCallbacks
     def get_avatar_url(self, target_user):
@@ -184,7 +183,7 @@ class ProfileHandler(BaseHandler):
             target_user.localpart, new_avatar_url
         )
 
-        yield self._update_join_states(requester)
+        yield self._update_join_states(requester, target_user)
 
     @defer.inlineCallbacks
     def on_profile_query(self, args):
@@ -209,28 +208,24 @@ class ProfileHandler(BaseHandler):
         defer.returnValue(response)
 
     @defer.inlineCallbacks
-    def _update_join_states(self, requester):
-        user = requester.user
-        if not self.hs.is_mine(user):
+    def _update_join_states(self, requester, target_user):
+        if not self.hs.is_mine(target_user):
             return
 
         yield self.ratelimit(requester)
 
         room_ids = yield self.store.get_rooms_for_user(
-            user.to_string(),
+            target_user.to_string(),
         )
 
         for room_id in room_ids:
             handler = self.hs.get_handlers().room_member_handler
             try:
-                # Assume the user isn't a guest because we don't let guests set
-                # profile or avatar data.
-                # XXX why are we recreating `requester` here for each room?
-                # what was wrong with the `requester` we were passed?
-                requester = synapse.types.create_requester(user)
+                # Assume the target_user isn't a guest,
+                # because we don't let guests set profile or avatar data.
                 yield handler.update_membership(
                     requester,
-                    user,
+                    target_user,
                     room_id,
                     "join",  # We treat a profile update like a join.
                     ratelimit=False,  # Try to hide that these events aren't atomic.
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 49dc33c147..f6e7e58563 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -36,6 +36,7 @@ class RegistrationHandler(BaseHandler):
         super(RegistrationHandler, self).__init__(hs)
 
         self.auth = hs.get_auth()
+        self._auth_handler = hs.get_auth_handler()
         self.profile_handler = hs.get_profile_handler()
         self.captcha_client = CaptchaServerHttpClient(hs)
 
@@ -416,7 +417,7 @@ class RegistrationHandler(BaseHandler):
                 create_profile_with_localpart=user.localpart,
             )
         else:
-            yield self.store.user_delete_access_tokens(user_id=user_id)
+            yield self._auth_handler.delete_access_tokens_for_user(user_id)
             yield self.store.add_access_token_to_user(user_id=user_id, token=token)
 
         if displayname is not None:
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index 41e1781df7..2cf34e51cb 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -20,6 +20,7 @@ from ._base import BaseHandler
 from synapse.api.constants import (
     EventTypes, JoinRules,
 )
+from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
 from synapse.util.async import concurrently_execute
 from synapse.util.caches.descriptors import cachedInlineCallbacks
 from synapse.util.caches.response_cache import ResponseCache
@@ -70,6 +71,7 @@ class RoomListHandler(BaseHandler):
         if search_filter:
             # We explicitly don't bother caching searches or requests for
             # appservice specific lists.
+            logger.info("Bypassing cache as search request.")
             return self._get_public_room_list(
                 limit, since_token, search_filter, network_tuple=network_tuple,
             )
@@ -77,13 +79,16 @@ class RoomListHandler(BaseHandler):
         key = (limit, since_token, network_tuple)
         result = self.response_cache.get(key)
         if not result:
+            logger.info("No cached result, calculating one.")
             result = self.response_cache.set(
                 key,
-                self._get_public_room_list(
+                preserve_fn(self._get_public_room_list)(
                     limit, since_token, network_tuple=network_tuple
                 )
             )
-        return result
+        else:
+            logger.info("Using cached deferred result.")
+        return make_deferred_yieldable(result)
 
     @defer.inlineCallbacks
     def _get_public_room_list(self, limit=None, since_token=None,
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 219529936f..b12988f3c9 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -15,7 +15,7 @@
 
 from synapse.api.constants import Membership, EventTypes
 from synapse.util.async import concurrently_execute
-from synapse.util.logcontext import LoggingContext
+from synapse.util.logcontext import LoggingContext, make_deferred_yieldable, preserve_fn
 from synapse.util.metrics import Measure, measure_func
 from synapse.util.caches.response_cache import ResponseCache
 from synapse.push.clientformat import format_push_rules_for_user
@@ -184,11 +184,11 @@ class SyncHandler(object):
         if not result:
             result = self.response_cache.set(
                 sync_config.request_key,
-                self._wait_for_sync_for_user(
+                preserve_fn(self._wait_for_sync_for_user)(
                     sync_config, since_token, timeout, full_state
                 )
             )
-        return result
+        return make_deferred_yieldable(result)
 
     @defer.inlineCallbacks
     def _wait_for_sync_for_user(self, sync_config, since_token, timeout,
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index 2a49456bfc..b5be5d9623 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -152,7 +152,7 @@ class UserDirectoyHandler(object):
 
         for room_id in room_ids:
             logger.info("Handling room %d/%d", num_processed_rooms, len(room_ids))
-            yield self._handle_intial_room(room_id)
+            yield self._handle_initial_room(room_id)
             num_processed_rooms += 1
             yield sleep(self.INITIAL_SLEEP_MS / 1000.)
 
@@ -166,7 +166,7 @@ class UserDirectoyHandler(object):
         yield self.store.update_user_directory_stream_pos(new_pos)
 
     @defer.inlineCallbacks
-    def _handle_intial_room(self, room_id):
+    def _handle_initial_room(self, room_id):
         """Called when we initially fill out user_directory one room at a time
         """
         is_in_room = yield self.store.is_host_joined(room_id, self.server_name)