summary refs log tree commit diff
path: root/synapse/handlers
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers')
-rw-r--r--synapse/handlers/auth.py279
-rw-r--r--synapse/handlers/federation.py15
-rw-r--r--synapse/handlers/initial_sync.py443
-rw-r--r--synapse/handlers/message.py381
-rw-r--r--synapse/handlers/room_list.py3
-rw-r--r--synapse/handlers/typing.py177
6 files changed, 757 insertions, 541 deletions
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 6986930c0d..3933ce171a 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -31,6 +31,7 @@ import simplejson
 
 try:
     import ldap3
+    import ldap3.core.exceptions
 except ImportError:
     ldap3 = None
     pass
@@ -504,6 +505,144 @@ class AuthHandler(BaseHandler):
             raise LoginError(403, "", errcode=Codes.FORBIDDEN)
         defer.returnValue(user_id)
 
+    def _ldap_simple_bind(self, server, localpart, password):
+        """ Attempt a simple bind with the credentials
+            given by the user against the LDAP server.
+
+            Returns True, LDAP3Connection
+                if the bind was successful
+            Returns False, None
+                if an error occured
+        """
+
+        try:
+            # bind with the the local users ldap credentials
+            bind_dn = "{prop}={value},{base}".format(
+                prop=self.ldap_attributes['uid'],
+                value=localpart,
+                base=self.ldap_base
+            )
+            conn = ldap3.Connection(server, bind_dn, password)
+            logger.debug(
+                "Established LDAP connection in simple bind mode: %s",
+                conn
+            )
+
+            if self.ldap_start_tls:
+                conn.start_tls()
+                logger.debug(
+                    "Upgraded LDAP connection in simple bind mode through StartTLS: %s",
+                    conn
+                )
+
+            if conn.bind():
+                # GOOD: bind okay
+                logger.debug("LDAP Bind successful in simple bind mode.")
+                return True, conn
+
+            # BAD: bind failed
+            logger.info(
+                "Binding against LDAP failed for '%s' failed: %s",
+                localpart, conn.result['description']
+            )
+            conn.unbind()
+            return False, None
+
+        except ldap3.core.exceptions.LDAPException as e:
+            logger.warn("Error during LDAP authentication: %s", e)
+            return False, None
+
+    def _ldap_authenticated_search(self, server, localpart, password):
+        """ Attempt to login with the preconfigured bind_dn
+            and then continue searching and filtering within
+            the base_dn
+
+            Returns (True, LDAP3Connection)
+                if a single matching DN within the base was found
+                that matched the filter expression, and with which
+                a successful bind was achieved
+
+                The LDAP3Connection returned is the instance that was used to
+                verify the password not the one using the configured bind_dn.
+            Returns (False, None)
+                if an error occured
+        """
+
+        try:
+            conn = ldap3.Connection(
+                server,
+                self.ldap_bind_dn,
+                self.ldap_bind_password
+            )
+            logger.debug(
+                "Established LDAP connection in search mode: %s",
+                conn
+            )
+
+            if self.ldap_start_tls:
+                conn.start_tls()
+                logger.debug(
+                    "Upgraded LDAP connection in search mode through StartTLS: %s",
+                    conn
+                )
+
+            if not conn.bind():
+                logger.warn(
+                    "Binding against LDAP with `bind_dn` failed: %s",
+                    conn.result['description']
+                )
+                conn.unbind()
+                return False, None
+
+            # construct search_filter like (uid=localpart)
+            query = "({prop}={value})".format(
+                prop=self.ldap_attributes['uid'],
+                value=localpart
+            )
+            if self.ldap_filter:
+                # combine with the AND expression
+                query = "(&{query}{filter})".format(
+                    query=query,
+                    filter=self.ldap_filter
+                )
+            logger.debug(
+                "LDAP search filter: %s",
+                query
+            )
+            conn.search(
+                search_base=self.ldap_base,
+                search_filter=query
+            )
+
+            if len(conn.response) == 1:
+                # GOOD: found exactly one result
+                user_dn = conn.response[0]['dn']
+                logger.debug('LDAP search found dn: %s', user_dn)
+
+                # unbind and simple bind with user_dn to verify the password
+                # Note: do not use rebind(), for some reason it did not verify
+                #       the password for me!
+                conn.unbind()
+                return self._ldap_simple_bind(server, localpart, password)
+            else:
+                # BAD: found 0 or > 1 results, abort!
+                if len(conn.response) == 0:
+                    logger.info(
+                        "LDAP search returned no results for '%s'",
+                        localpart
+                    )
+                else:
+                    logger.info(
+                        "LDAP search returned too many (%s) results for '%s'",
+                        len(conn.response), localpart
+                    )
+                conn.unbind()
+                return False, None
+
+        except ldap3.core.exceptions.LDAPException as e:
+            logger.warn("Error during LDAP authentication: %s", e)
+            return False, None
+
     @defer.inlineCallbacks
     def _check_ldap_password(self, user_id, password):
         """ Attempt to authenticate a user against an LDAP Server
@@ -516,106 +655,62 @@ class AuthHandler(BaseHandler):
         if not ldap3 or not self.ldap_enabled:
             defer.returnValue(False)
 
-        if self.ldap_mode not in LDAPMode.LIST:
-            raise RuntimeError(
-                'Invalid ldap mode specified: {mode}'.format(
-                    mode=self.ldap_mode
-                )
-            )
+        localpart = UserID.from_string(user_id).localpart
 
         try:
             server = ldap3.Server(self.ldap_uri)
             logger.debug(
-                "Attempting ldap connection with %s",
+                "Attempting LDAP connection with %s",
                 self.ldap_uri
             )
 
-            localpart = UserID.from_string(user_id).localpart
             if self.ldap_mode == LDAPMode.SIMPLE:
-                # bind with the the local users ldap credentials
-                bind_dn = "{prop}={value},{base}".format(
-                    prop=self.ldap_attributes['uid'],
-                    value=localpart,
-                    base=self.ldap_base
+                result, conn = self._ldap_simple_bind(
+                    server=server, localpart=localpart, password=password
                 )
-                conn = ldap3.Connection(server, bind_dn, password)
                 logger.debug(
-                    "Established ldap connection in simple mode: %s",
+                    'LDAP authentication method simple bind returned: %s (conn: %s)',
+                    result,
                     conn
                 )
-
-                if self.ldap_start_tls:
-                    conn.start_tls()
-                    logger.debug(
-                        "Upgraded ldap connection in simple mode through StartTLS: %s",
-                        conn
-                    )
-
-                conn.bind()
-
+                if not result:
+                    defer.returnValue(False)
             elif self.ldap_mode == LDAPMode.SEARCH:
-                # connect with preconfigured credentials and search for local user
-                conn = ldap3.Connection(
-                    server,
-                    self.ldap_bind_dn,
-                    self.ldap_bind_password
+                result, conn = self._ldap_authenticated_search(
+                    server=server, localpart=localpart, password=password
                 )
                 logger.debug(
-                    "Established ldap connection in search mode: %s",
+                    'LDAP auth method authenticated search returned: %s (conn: %s)',
+                    result,
                     conn
                 )
-
-                if self.ldap_start_tls:
-                    conn.start_tls()
-                    logger.debug(
-                        "Upgraded ldap connection in search mode through StartTLS: %s",
-                        conn
+                if not result:
+                    defer.returnValue(False)
+            else:
+                raise RuntimeError(
+                    'Invalid LDAP mode specified: {mode}'.format(
+                        mode=self.ldap_mode
                     )
-
-                conn.bind()
-
-                # find matching dn
-                query = "({prop}={value})".format(
-                    prop=self.ldap_attributes['uid'],
-                    value=localpart
                 )
-                if self.ldap_filter:
-                    query = "(&{query}{filter})".format(
-                        query=query,
-                        filter=self.ldap_filter
-                    )
-                logger.debug("ldap search filter: %s", query)
-                result = conn.search(self.ldap_base, query)
-
-                if result and len(conn.response) == 1:
-                    # found exactly one result
-                    user_dn = conn.response[0]['dn']
-                    logger.debug('ldap search found dn: %s', user_dn)
-
-                    # unbind and reconnect, rebind with found dn
-                    conn.unbind()
-                    conn = ldap3.Connection(
-                        server,
-                        user_dn,
-                        password,
-                        auto_bind=True
-                    )
-                else:
-                    # found 0 or > 1 results, abort!
-                    logger.warn(
-                        "ldap search returned unexpected (%d!=1) amount of results",
-                        len(conn.response)
-                    )
-                    defer.returnValue(False)
 
-            logger.info(
-                "User authenticated against ldap server: %s",
-                conn
-            )
+            try:
+                logger.info(
+                    "User authenticated against LDAP server: %s",
+                    conn
+                )
+            except NameError:
+                logger.warn("Authentication method yielded no LDAP connection, aborting!")
+                defer.returnValue(False)
+
+            # check if user with user_id exists
+            if (yield self.check_user_exists(user_id)):
+                # exists, authentication complete
+                conn.unbind()
+                defer.returnValue(True)
 
-            # check for existing account, if none exists, create one
-            if not (yield self.check_user_exists(user_id)):
-                # query user metadata for account creation
+            else:
+                # does not exist, fetch metadata for account creation from
+                # existing ldap connection
                 query = "({prop}={value})".format(
                     prop=self.ldap_attributes['uid'],
                     value=localpart
@@ -626,9 +721,12 @@ class AuthHandler(BaseHandler):
                         filter=query,
                         user_filter=self.ldap_filter
                     )
-                logger.debug("ldap registration filter: %s", query)
+                logger.debug(
+                    "ldap registration filter: %s",
+                    query
+                )
 
-                result = conn.search(
+                conn.search(
                     search_base=self.ldap_base,
                     search_filter=query,
                     attributes=[
@@ -651,20 +749,27 @@ class AuthHandler(BaseHandler):
                     # TODO: bind email, set displayname with data from ldap directory
 
                     logger.info(
-                        "ldap registration successful: %d: %s (%s, %)",
+                        "Registration based on LDAP data was successful: %d: %s (%s, %)",
                         user_id,
                         localpart,
                         name,
                         mail
                     )
+
+                    defer.returnValue(True)
                 else:
-                    logger.warn(
-                        "ldap registration failed: unexpected (%d!=1) amount of results",
-                        len(conn.response)
-                    )
+                    if len(conn.response) == 0:
+                        logger.warn("LDAP registration failed, no result.")
+                    else:
+                        logger.warn(
+                            "LDAP registration failed, too many results (%s)",
+                            len(conn.response)
+                        )
+
                     defer.returnValue(False)
 
-            defer.returnValue(True)
+            defer.returnValue(False)
+
         except ldap3.core.exceptions.LDAPException as e:
             logger.warn("Error during ldap authentication: %s", e)
             defer.returnValue(False)
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index f7cb3c1bb2..2d801bad47 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -1922,15 +1922,18 @@ class FederationHandler(BaseHandler):
             original_invite = yield self.store.get_event(
                 original_invite_id, allow_none=True
             )
-        if not original_invite:
+        if original_invite:
+            display_name = original_invite.content["display_name"]
+            event_dict["content"]["third_party_invite"]["display_name"] = display_name
+        else:
             logger.info(
-                "Could not find invite event for third_party_invite - "
-                "discarding: %s" % (event_dict,)
+                "Could not find invite event for third_party_invite: %r",
+                event_dict
             )
-            return
+            # We don't discard here as this is not the appropriate place to do
+            # auth checks. If we need the invite and don't have it then the
+            # auth check code will explode appropriately.
 
-        display_name = original_invite.content["display_name"]
-        event_dict["content"]["third_party_invite"]["display_name"] = display_name
         builder = self.event_builder_factory.new(event_dict)
         EventValidator().validate_new(builder)
         message_handler = self.hs.get_handlers().message_handler
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
new file mode 100644
index 0000000000..fbfa5a0281
--- /dev/null
+++ b/synapse/handlers/initial_sync.py
@@ -0,0 +1,443 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+from synapse.api.constants import EventTypes, Membership
+from synapse.api.errors import AuthError, Codes
+from synapse.events.utils import serialize_event
+from synapse.events.validator import EventValidator
+from synapse.streams.config import PaginationConfig
+from synapse.types import (
+    UserID, StreamToken,
+)
+from synapse.util import unwrapFirstError
+from synapse.util.async import concurrently_execute
+from synapse.util.caches.snapshot_cache import SnapshotCache
+from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
+from synapse.visibility import filter_events_for_client
+
+from ._base import BaseHandler
+
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+
+class InitialSyncHandler(BaseHandler):
+    def __init__(self, hs):
+        super(InitialSyncHandler, self).__init__(hs)
+        self.hs = hs
+        self.state = hs.get_state_handler()
+        self.clock = hs.get_clock()
+        self.validator = EventValidator()
+        self.snapshot_cache = SnapshotCache()
+
+    def snapshot_all_rooms(self, user_id=None, pagin_config=None,
+                           as_client_event=True, include_archived=False):
+        """Retrieve a snapshot of all rooms the user is invited or has joined.
+
+        This snapshot may include messages for all rooms where the user is
+        joined, depending on the pagination config.
+
+        Args:
+            user_id (str): The ID of the user making the request.
+            pagin_config (synapse.api.streams.PaginationConfig): The pagination
+            config used to determine how many messages *PER ROOM* to return.
+            as_client_event (bool): True to get events in client-server format.
+            include_archived (bool): True to get rooms that the user has left
+        Returns:
+            A list of dicts with "room_id" and "membership" keys for all rooms
+            the user is currently invited or joined in on. Rooms where the user
+            is joined on, may return a "messages" key with messages, depending
+            on the specified PaginationConfig.
+        """
+        key = (
+            user_id,
+            pagin_config.from_token,
+            pagin_config.to_token,
+            pagin_config.direction,
+            pagin_config.limit,
+            as_client_event,
+            include_archived,
+        )
+        now_ms = self.clock.time_msec()
+        result = self.snapshot_cache.get(now_ms, key)
+        if result is not None:
+            return result
+
+        return self.snapshot_cache.set(now_ms, key, self._snapshot_all_rooms(
+            user_id, pagin_config, as_client_event, include_archived
+        ))
+
+    @defer.inlineCallbacks
+    def _snapshot_all_rooms(self, user_id=None, pagin_config=None,
+                            as_client_event=True, include_archived=False):
+
+        memberships = [Membership.INVITE, Membership.JOIN]
+        if include_archived:
+            memberships.append(Membership.LEAVE)
+
+        room_list = yield self.store.get_rooms_for_user_where_membership_is(
+            user_id=user_id, membership_list=memberships
+        )
+
+        user = UserID.from_string(user_id)
+
+        rooms_ret = []
+
+        now_token = yield self.hs.get_event_sources().get_current_token()
+
+        presence_stream = self.hs.get_event_sources().sources["presence"]
+        pagination_config = PaginationConfig(from_token=now_token)
+        presence, _ = yield presence_stream.get_pagination_rows(
+            user, pagination_config.get_source_config("presence"), None
+        )
+
+        receipt_stream = self.hs.get_event_sources().sources["receipt"]
+        receipt, _ = yield receipt_stream.get_pagination_rows(
+            user, pagination_config.get_source_config("receipt"), None
+        )
+
+        tags_by_room = yield self.store.get_tags_for_user(user_id)
+
+        account_data, account_data_by_room = (
+            yield self.store.get_account_data_for_user(user_id)
+        )
+
+        public_room_ids = yield self.store.get_public_room_ids()
+
+        limit = pagin_config.limit
+        if limit is None:
+            limit = 10
+
+        @defer.inlineCallbacks
+        def handle_room(event):
+            d = {
+                "room_id": event.room_id,
+                "membership": event.membership,
+                "visibility": (
+                    "public" if event.room_id in public_room_ids
+                    else "private"
+                ),
+            }
+
+            if event.membership == Membership.INVITE:
+                time_now = self.clock.time_msec()
+                d["inviter"] = event.sender
+
+                invite_event = yield self.store.get_event(event.event_id)
+                d["invite"] = serialize_event(invite_event, time_now, as_client_event)
+
+            rooms_ret.append(d)
+
+            if event.membership not in (Membership.JOIN, Membership.LEAVE):
+                return
+
+            try:
+                if event.membership == Membership.JOIN:
+                    room_end_token = now_token.room_key
+                    deferred_room_state = self.state_handler.get_current_state(
+                        event.room_id
+                    )
+                elif event.membership == Membership.LEAVE:
+                    room_end_token = "s%d" % (event.stream_ordering,)
+                    deferred_room_state = self.store.get_state_for_events(
+                        [event.event_id], None
+                    )
+                    deferred_room_state.addCallback(
+                        lambda states: states[event.event_id]
+                    )
+
+                (messages, token), current_state = yield preserve_context_over_deferred(
+                    defer.gatherResults(
+                        [
+                            preserve_fn(self.store.get_recent_events_for_room)(
+                                event.room_id,
+                                limit=limit,
+                                end_token=room_end_token,
+                            ),
+                            deferred_room_state,
+                        ]
+                    )
+                ).addErrback(unwrapFirstError)
+
+                messages = yield filter_events_for_client(
+                    self.store, user_id, messages
+                )
+
+                start_token = now_token.copy_and_replace("room_key", token[0])
+                end_token = now_token.copy_and_replace("room_key", token[1])
+                time_now = self.clock.time_msec()
+
+                d["messages"] = {
+                    "chunk": [
+                        serialize_event(m, time_now, as_client_event)
+                        for m in messages
+                    ],
+                    "start": start_token.to_string(),
+                    "end": end_token.to_string(),
+                }
+
+                d["state"] = [
+                    serialize_event(c, time_now, as_client_event)
+                    for c in current_state.values()
+                ]
+
+                account_data_events = []
+                tags = tags_by_room.get(event.room_id)
+                if tags:
+                    account_data_events.append({
+                        "type": "m.tag",
+                        "content": {"tags": tags},
+                    })
+
+                account_data = account_data_by_room.get(event.room_id, {})
+                for account_data_type, content in account_data.items():
+                    account_data_events.append({
+                        "type": account_data_type,
+                        "content": content,
+                    })
+
+                d["account_data"] = account_data_events
+            except:
+                logger.exception("Failed to get snapshot")
+
+        yield concurrently_execute(handle_room, room_list, 10)
+
+        account_data_events = []
+        for account_data_type, content in account_data.items():
+            account_data_events.append({
+                "type": account_data_type,
+                "content": content,
+            })
+
+        ret = {
+            "rooms": rooms_ret,
+            "presence": presence,
+            "account_data": account_data_events,
+            "receipts": receipt,
+            "end": now_token.to_string(),
+        }
+
+        defer.returnValue(ret)
+
+    @defer.inlineCallbacks
+    def room_initial_sync(self, requester, room_id, pagin_config=None):
+        """Capture the a snapshot of a room. If user is currently a member of
+        the room this will be what is currently in the room. If the user left
+        the room this will be what was in the room when they left.
+
+        Args:
+            requester(Requester): The user to get a snapshot for.
+            room_id(str): The room to get a snapshot of.
+            pagin_config(synapse.streams.config.PaginationConfig):
+                The pagination config used to determine how many messages to
+                return.
+        Raises:
+            AuthError if the user wasn't in the room.
+        Returns:
+            A JSON serialisable dict with the snapshot of the room.
+        """
+
+        user_id = requester.user.to_string()
+
+        membership, member_event_id = yield self._check_in_room_or_world_readable(
+            room_id, user_id,
+        )
+        is_peeking = member_event_id is None
+
+        if membership == Membership.JOIN:
+            result = yield self._room_initial_sync_joined(
+                user_id, room_id, pagin_config, membership, is_peeking
+            )
+        elif membership == Membership.LEAVE:
+            result = yield self._room_initial_sync_parted(
+                user_id, room_id, pagin_config, membership, member_event_id, is_peeking
+            )
+
+        account_data_events = []
+        tags = yield self.store.get_tags_for_room(user_id, room_id)
+        if tags:
+            account_data_events.append({
+                "type": "m.tag",
+                "content": {"tags": tags},
+            })
+
+        account_data = yield self.store.get_account_data_for_room(user_id, room_id)
+        for account_data_type, content in account_data.items():
+            account_data_events.append({
+                "type": account_data_type,
+                "content": content,
+            })
+
+        result["account_data"] = account_data_events
+
+        defer.returnValue(result)
+
+    @defer.inlineCallbacks
+    def _room_initial_sync_parted(self, user_id, room_id, pagin_config,
+                                  membership, member_event_id, is_peeking):
+        room_state = yield self.store.get_state_for_events(
+            [member_event_id], None
+        )
+
+        room_state = room_state[member_event_id]
+
+        limit = pagin_config.limit if pagin_config else None
+        if limit is None:
+            limit = 10
+
+        stream_token = yield self.store.get_stream_token_for_event(
+            member_event_id
+        )
+
+        messages, token = yield self.store.get_recent_events_for_room(
+            room_id,
+            limit=limit,
+            end_token=stream_token
+        )
+
+        messages = yield filter_events_for_client(
+            self.store, user_id, messages, is_peeking=is_peeking
+        )
+
+        start_token = StreamToken.START.copy_and_replace("room_key", token[0])
+        end_token = StreamToken.START.copy_and_replace("room_key", token[1])
+
+        time_now = self.clock.time_msec()
+
+        defer.returnValue({
+            "membership": membership,
+            "room_id": room_id,
+            "messages": {
+                "chunk": [serialize_event(m, time_now) for m in messages],
+                "start": start_token.to_string(),
+                "end": end_token.to_string(),
+            },
+            "state": [serialize_event(s, time_now) for s in room_state.values()],
+            "presence": [],
+            "receipts": [],
+        })
+
+    @defer.inlineCallbacks
+    def _room_initial_sync_joined(self, user_id, room_id, pagin_config,
+                                  membership, is_peeking):
+        current_state = yield self.state.get_current_state(
+            room_id=room_id,
+        )
+
+        # TODO: These concurrently
+        time_now = self.clock.time_msec()
+        state = [
+            serialize_event(x, time_now)
+            for x in current_state.values()
+        ]
+
+        now_token = yield self.hs.get_event_sources().get_current_token()
+
+        limit = pagin_config.limit if pagin_config else None
+        if limit is None:
+            limit = 10
+
+        room_members = [
+            m for m in current_state.values()
+            if m.type == EventTypes.Member
+            and m.content["membership"] == Membership.JOIN
+        ]
+
+        presence_handler = self.hs.get_presence_handler()
+
+        @defer.inlineCallbacks
+        def get_presence():
+            states = yield presence_handler.get_states(
+                [m.user_id for m in room_members],
+                as_event=True,
+            )
+
+            defer.returnValue(states)
+
+        @defer.inlineCallbacks
+        def get_receipts():
+            receipts_handler = self.hs.get_handlers().receipts_handler
+            receipts = yield receipts_handler.get_receipts_for_room(
+                room_id,
+                now_token.receipt_key
+            )
+            defer.returnValue(receipts)
+
+        presence, receipts, (messages, token) = yield defer.gatherResults(
+            [
+                preserve_fn(get_presence)(),
+                preserve_fn(get_receipts)(),
+                preserve_fn(self.store.get_recent_events_for_room)(
+                    room_id,
+                    limit=limit,
+                    end_token=now_token.room_key,
+                )
+            ],
+            consumeErrors=True,
+        ).addErrback(unwrapFirstError)
+
+        messages = yield filter_events_for_client(
+            self.store, user_id, messages, is_peeking=is_peeking,
+        )
+
+        start_token = now_token.copy_and_replace("room_key", token[0])
+        end_token = now_token.copy_and_replace("room_key", token[1])
+
+        time_now = self.clock.time_msec()
+
+        ret = {
+            "room_id": room_id,
+            "messages": {
+                "chunk": [serialize_event(m, time_now) for m in messages],
+                "start": start_token.to_string(),
+                "end": end_token.to_string(),
+            },
+            "state": state,
+            "presence": presence,
+            "receipts": receipts,
+        }
+        if not is_peeking:
+            ret["membership"] = membership
+
+        defer.returnValue(ret)
+
+    @defer.inlineCallbacks
+    def _check_in_room_or_world_readable(self, room_id, user_id):
+        try:
+            # check_user_was_in_room will return the most recent membership
+            # event for the user if:
+            #  * The user is a non-guest user, and was ever in the room
+            #  * The user is a guest user, and has joined the room
+            # else it will throw.
+            member_event = yield self.auth.check_user_was_in_room(room_id, user_id)
+            defer.returnValue((member_event.membership, member_event.event_id))
+            return
+        except AuthError:
+            visibility = yield self.state_handler.get_current_state(
+                room_id, EventTypes.RoomHistoryVisibility, ""
+            )
+            if (
+                visibility and
+                visibility.content["history_visibility"] == "world_readable"
+            ):
+                defer.returnValue((Membership.JOIN, None))
+                return
+            raise AuthError(
+                403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN
+            )
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 178209a209..30ea9630f7 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -21,14 +21,11 @@ from synapse.crypto.event_signing import add_hashes_and_signatures
 from synapse.events.utils import serialize_event
 from synapse.events.validator import EventValidator
 from synapse.push.action_generator import ActionGenerator
-from synapse.streams.config import PaginationConfig
 from synapse.types import (
-    UserID, RoomAlias, RoomStreamToken, StreamToken, get_domain_from_id
+    UserID, RoomAlias, RoomStreamToken, get_domain_from_id
 )
-from synapse.util import unwrapFirstError
-from synapse.util.async import concurrently_execute, run_on_reactor, ReadWriteLock
-from synapse.util.caches.snapshot_cache import SnapshotCache
-from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
+from synapse.util.async import run_on_reactor, ReadWriteLock
+from synapse.util.logcontext import preserve_fn
 from synapse.util.metrics import measure_func
 from synapse.visibility import filter_events_for_client
 
@@ -49,7 +46,6 @@ class MessageHandler(BaseHandler):
         self.state = hs.get_state_handler()
         self.clock = hs.get_clock()
         self.validator = EventValidator()
-        self.snapshot_cache = SnapshotCache()
 
         self.pagination_lock = ReadWriteLock()
 
@@ -392,377 +388,6 @@ class MessageHandler(BaseHandler):
             [serialize_event(c, now) for c in room_state.values()]
         )
 
-    def snapshot_all_rooms(self, user_id=None, pagin_config=None,
-                           as_client_event=True, include_archived=False):
-        """Retrieve a snapshot of all rooms the user is invited or has joined.
-
-        This snapshot may include messages for all rooms where the user is
-        joined, depending on the pagination config.
-
-        Args:
-            user_id (str): The ID of the user making the request.
-            pagin_config (synapse.api.streams.PaginationConfig): The pagination
-            config used to determine how many messages *PER ROOM* to return.
-            as_client_event (bool): True to get events in client-server format.
-            include_archived (bool): True to get rooms that the user has left
-        Returns:
-            A list of dicts with "room_id" and "membership" keys for all rooms
-            the user is currently invited or joined in on. Rooms where the user
-            is joined on, may return a "messages" key with messages, depending
-            on the specified PaginationConfig.
-        """
-        key = (
-            user_id,
-            pagin_config.from_token,
-            pagin_config.to_token,
-            pagin_config.direction,
-            pagin_config.limit,
-            as_client_event,
-            include_archived,
-        )
-        now_ms = self.clock.time_msec()
-        result = self.snapshot_cache.get(now_ms, key)
-        if result is not None:
-            return result
-
-        return self.snapshot_cache.set(now_ms, key, self._snapshot_all_rooms(
-            user_id, pagin_config, as_client_event, include_archived
-        ))
-
-    @defer.inlineCallbacks
-    def _snapshot_all_rooms(self, user_id=None, pagin_config=None,
-                            as_client_event=True, include_archived=False):
-
-        memberships = [Membership.INVITE, Membership.JOIN]
-        if include_archived:
-            memberships.append(Membership.LEAVE)
-
-        room_list = yield self.store.get_rooms_for_user_where_membership_is(
-            user_id=user_id, membership_list=memberships
-        )
-
-        user = UserID.from_string(user_id)
-
-        rooms_ret = []
-
-        now_token = yield self.hs.get_event_sources().get_current_token()
-
-        presence_stream = self.hs.get_event_sources().sources["presence"]
-        pagination_config = PaginationConfig(from_token=now_token)
-        presence, _ = yield presence_stream.get_pagination_rows(
-            user, pagination_config.get_source_config("presence"), None
-        )
-
-        receipt_stream = self.hs.get_event_sources().sources["receipt"]
-        receipt, _ = yield receipt_stream.get_pagination_rows(
-            user, pagination_config.get_source_config("receipt"), None
-        )
-
-        tags_by_room = yield self.store.get_tags_for_user(user_id)
-
-        account_data, account_data_by_room = (
-            yield self.store.get_account_data_for_user(user_id)
-        )
-
-        public_room_ids = yield self.store.get_public_room_ids()
-
-        limit = pagin_config.limit
-        if limit is None:
-            limit = 10
-
-        @defer.inlineCallbacks
-        def handle_room(event):
-            d = {
-                "room_id": event.room_id,
-                "membership": event.membership,
-                "visibility": (
-                    "public" if event.room_id in public_room_ids
-                    else "private"
-                ),
-            }
-
-            if event.membership == Membership.INVITE:
-                time_now = self.clock.time_msec()
-                d["inviter"] = event.sender
-
-                invite_event = yield self.store.get_event(event.event_id)
-                d["invite"] = serialize_event(invite_event, time_now, as_client_event)
-
-            rooms_ret.append(d)
-
-            if event.membership not in (Membership.JOIN, Membership.LEAVE):
-                return
-
-            try:
-                if event.membership == Membership.JOIN:
-                    room_end_token = now_token.room_key
-                    deferred_room_state = self.state_handler.get_current_state(
-                        event.room_id
-                    )
-                elif event.membership == Membership.LEAVE:
-                    room_end_token = "s%d" % (event.stream_ordering,)
-                    deferred_room_state = self.store.get_state_for_events(
-                        [event.event_id], None
-                    )
-                    deferred_room_state.addCallback(
-                        lambda states: states[event.event_id]
-                    )
-
-                (messages, token), current_state = yield preserve_context_over_deferred(
-                    defer.gatherResults(
-                        [
-                            preserve_fn(self.store.get_recent_events_for_room)(
-                                event.room_id,
-                                limit=limit,
-                                end_token=room_end_token,
-                            ),
-                            deferred_room_state,
-                        ]
-                    )
-                ).addErrback(unwrapFirstError)
-
-                messages = yield filter_events_for_client(
-                    self.store, user_id, messages
-                )
-
-                start_token = now_token.copy_and_replace("room_key", token[0])
-                end_token = now_token.copy_and_replace("room_key", token[1])
-                time_now = self.clock.time_msec()
-
-                d["messages"] = {
-                    "chunk": [
-                        serialize_event(m, time_now, as_client_event)
-                        for m in messages
-                    ],
-                    "start": start_token.to_string(),
-                    "end": end_token.to_string(),
-                }
-
-                d["state"] = [
-                    serialize_event(c, time_now, as_client_event)
-                    for c in current_state.values()
-                ]
-
-                account_data_events = []
-                tags = tags_by_room.get(event.room_id)
-                if tags:
-                    account_data_events.append({
-                        "type": "m.tag",
-                        "content": {"tags": tags},
-                    })
-
-                account_data = account_data_by_room.get(event.room_id, {})
-                for account_data_type, content in account_data.items():
-                    account_data_events.append({
-                        "type": account_data_type,
-                        "content": content,
-                    })
-
-                d["account_data"] = account_data_events
-            except:
-                logger.exception("Failed to get snapshot")
-
-        yield concurrently_execute(handle_room, room_list, 10)
-
-        account_data_events = []
-        for account_data_type, content in account_data.items():
-            account_data_events.append({
-                "type": account_data_type,
-                "content": content,
-            })
-
-        ret = {
-            "rooms": rooms_ret,
-            "presence": presence,
-            "account_data": account_data_events,
-            "receipts": receipt,
-            "end": now_token.to_string(),
-        }
-
-        defer.returnValue(ret)
-
-    @defer.inlineCallbacks
-    def room_initial_sync(self, requester, room_id, pagin_config=None):
-        """Capture the a snapshot of a room. If user is currently a member of
-        the room this will be what is currently in the room. If the user left
-        the room this will be what was in the room when they left.
-
-        Args:
-            requester(Requester): The user to get a snapshot for.
-            room_id(str): The room to get a snapshot of.
-            pagin_config(synapse.streams.config.PaginationConfig):
-                The pagination config used to determine how many messages to
-                return.
-        Raises:
-            AuthError if the user wasn't in the room.
-        Returns:
-            A JSON serialisable dict with the snapshot of the room.
-        """
-
-        user_id = requester.user.to_string()
-
-        membership, member_event_id = yield self._check_in_room_or_world_readable(
-            room_id, user_id,
-        )
-        is_peeking = member_event_id is None
-
-        if membership == Membership.JOIN:
-            result = yield self._room_initial_sync_joined(
-                user_id, room_id, pagin_config, membership, is_peeking
-            )
-        elif membership == Membership.LEAVE:
-            result = yield self._room_initial_sync_parted(
-                user_id, room_id, pagin_config, membership, member_event_id, is_peeking
-            )
-
-        account_data_events = []
-        tags = yield self.store.get_tags_for_room(user_id, room_id)
-        if tags:
-            account_data_events.append({
-                "type": "m.tag",
-                "content": {"tags": tags},
-            })
-
-        account_data = yield self.store.get_account_data_for_room(user_id, room_id)
-        for account_data_type, content in account_data.items():
-            account_data_events.append({
-                "type": account_data_type,
-                "content": content,
-            })
-
-        result["account_data"] = account_data_events
-
-        defer.returnValue(result)
-
-    @defer.inlineCallbacks
-    def _room_initial_sync_parted(self, user_id, room_id, pagin_config,
-                                  membership, member_event_id, is_peeking):
-        room_state = yield self.store.get_state_for_events(
-            [member_event_id], None
-        )
-
-        room_state = room_state[member_event_id]
-
-        limit = pagin_config.limit if pagin_config else None
-        if limit is None:
-            limit = 10
-
-        stream_token = yield self.store.get_stream_token_for_event(
-            member_event_id
-        )
-
-        messages, token = yield self.store.get_recent_events_for_room(
-            room_id,
-            limit=limit,
-            end_token=stream_token
-        )
-
-        messages = yield filter_events_for_client(
-            self.store, user_id, messages, is_peeking=is_peeking
-        )
-
-        start_token = StreamToken.START.copy_and_replace("room_key", token[0])
-        end_token = StreamToken.START.copy_and_replace("room_key", token[1])
-
-        time_now = self.clock.time_msec()
-
-        defer.returnValue({
-            "membership": membership,
-            "room_id": room_id,
-            "messages": {
-                "chunk": [serialize_event(m, time_now) for m in messages],
-                "start": start_token.to_string(),
-                "end": end_token.to_string(),
-            },
-            "state": [serialize_event(s, time_now) for s in room_state.values()],
-            "presence": [],
-            "receipts": [],
-        })
-
-    @defer.inlineCallbacks
-    def _room_initial_sync_joined(self, user_id, room_id, pagin_config,
-                                  membership, is_peeking):
-        current_state = yield self.state.get_current_state(
-            room_id=room_id,
-        )
-
-        # TODO: These concurrently
-        time_now = self.clock.time_msec()
-        state = [
-            serialize_event(x, time_now)
-            for x in current_state.values()
-        ]
-
-        now_token = yield self.hs.get_event_sources().get_current_token()
-
-        limit = pagin_config.limit if pagin_config else None
-        if limit is None:
-            limit = 10
-
-        room_members = [
-            m for m in current_state.values()
-            if m.type == EventTypes.Member
-            and m.content["membership"] == Membership.JOIN
-        ]
-
-        presence_handler = self.hs.get_presence_handler()
-
-        @defer.inlineCallbacks
-        def get_presence():
-            states = yield presence_handler.get_states(
-                [m.user_id for m in room_members],
-                as_event=True,
-            )
-
-            defer.returnValue(states)
-
-        @defer.inlineCallbacks
-        def get_receipts():
-            receipts_handler = self.hs.get_handlers().receipts_handler
-            receipts = yield receipts_handler.get_receipts_for_room(
-                room_id,
-                now_token.receipt_key
-            )
-            defer.returnValue(receipts)
-
-        presence, receipts, (messages, token) = yield defer.gatherResults(
-            [
-                preserve_fn(get_presence)(),
-                preserve_fn(get_receipts)(),
-                preserve_fn(self.store.get_recent_events_for_room)(
-                    room_id,
-                    limit=limit,
-                    end_token=now_token.room_key,
-                )
-            ],
-            consumeErrors=True,
-        ).addErrback(unwrapFirstError)
-
-        messages = yield filter_events_for_client(
-            self.store, user_id, messages, is_peeking=is_peeking,
-        )
-
-        start_token = now_token.copy_and_replace("room_key", token[0])
-        end_token = now_token.copy_and_replace("room_key", token[1])
-
-        time_now = self.clock.time_msec()
-
-        ret = {
-            "room_id": room_id,
-            "messages": {
-                "chunk": [serialize_event(m, time_now) for m in messages],
-                "start": start_token.to_string(),
-                "end": end_token.to_string(),
-            },
-            "state": state,
-            "presence": presence,
-            "receipts": receipts,
-        }
-        if not is_peeking:
-            ret["membership"] = membership
-
-        defer.returnValue(ret)
-
     @measure_func("_create_new_client_event")
     @defer.inlineCallbacks
     def _create_new_client_event(self, builder, prev_event_ids=None):
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index 5a533682c5..b04aea0110 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -125,6 +125,8 @@ class RoomListHandler(BaseHandler):
             if r not in newly_unpublished and rooms_to_num_joined[room_id] > 0
         ]
 
+        total_room_count = len(rooms_to_scan)
+
         if since_token:
             # Filter out rooms we've already returned previously
             # `since_token.current_limit` is the index of the last room we
@@ -188,6 +190,7 @@ class RoomListHandler(BaseHandler):
 
         results = {
             "chunk": chunk,
+            "total_room_count_estimate": total_room_count,
         }
 
         if since_token:
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 0548b81c34..08313417b2 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -16,10 +16,9 @@
 from twisted.internet import defer
 
 from synapse.api.errors import SynapseError, AuthError
-from synapse.util.logcontext import (
-    PreserveLoggingContext, preserve_fn, preserve_context_over_deferred,
-)
+from synapse.util.logcontext import preserve_fn
 from synapse.util.metrics import Measure
+from synapse.util.wheel_timer import WheelTimer
 from synapse.types import UserID, get_domain_from_id
 
 import logging
@@ -35,6 +34,13 @@ logger = logging.getLogger(__name__)
 RoomMember = namedtuple("RoomMember", ("room_id", "user_id"))
 
 
+# How often we expect remote servers to resend us presence.
+FEDERATION_TIMEOUT = 60 * 1000
+
+# How often to resend typing across federation.
+FEDERATION_PING_INTERVAL = 40 * 1000
+
+
 class TypingHandler(object):
     def __init__(self, hs):
         self.store = hs.get_datastore()
@@ -44,7 +50,10 @@ class TypingHandler(object):
         self.notifier = hs.get_notifier()
         self.state = hs.get_state_handler()
 
+        self.hs = hs
+
         self.clock = hs.get_clock()
+        self.wheel_timer = WheelTimer(bucket_size=5000)
 
         self.federation = hs.get_replication_layer()
 
@@ -53,7 +62,7 @@ class TypingHandler(object):
         hs.get_distributor().observe("user_left_room", self.user_left_room)
 
         self._member_typing_until = {}  # clock time we expect to stop
-        self._member_typing_timer = {}  # deferreds to manage theabove
+        self._member_last_federation_poke = {}
 
         # map room IDs to serial numbers
         self._room_serials = {}
@@ -61,12 +70,41 @@ class TypingHandler(object):
         # map room IDs to sets of users currently typing
         self._room_typing = {}
 
-    def tearDown(self):
-        """Cancels all the pending timers.
-        Normally this shouldn't be needed, but it's required from unit tests
-        to avoid a "Reactor was unclean" warning."""
-        for t in self._member_typing_timer.values():
-            self.clock.cancel_call_later(t)
+        self.clock.looping_call(
+            self._handle_timeouts,
+            5000,
+        )
+
+    def _handle_timeouts(self):
+        logger.info("Checking for typing timeouts")
+
+        now = self.clock.time_msec()
+
+        members = set(self.wheel_timer.fetch(now))
+
+        for member in members:
+            if not self.is_typing(member):
+                # Nothing to do if they're no longer typing
+                continue
+
+            until = self._member_typing_until.get(member, None)
+            if not until or until < now:
+                logger.info("Timing out typing for: %s", member.user_id)
+                preserve_fn(self._stopped_typing)(member)
+                continue
+
+            # Check if we need to resend a keep alive over federation for this
+            # user.
+            if self.hs.is_mine_id(member.user_id):
+                last_fed_poke = self._member_last_federation_poke.get(member, None)
+                if not last_fed_poke or last_fed_poke + FEDERATION_PING_INTERVAL < now:
+                    preserve_fn(self._push_remote)(
+                        member=member,
+                        typing=True
+                    )
+
+    def is_typing(self, member):
+        return member.user_id in self._room_typing.get(member.room_id, [])
 
     @defer.inlineCallbacks
     def started_typing(self, target_user, auth_user, room_id, timeout):
@@ -85,23 +123,17 @@ class TypingHandler(object):
             "%s has started typing in %s", target_user_id, room_id
         )
 
-        until = self.clock.time_msec() + timeout
         member = RoomMember(room_id=room_id, user_id=target_user_id)
 
-        was_present = member in self._member_typing_until
-
-        if member in self._member_typing_timer:
-            self.clock.cancel_call_later(self._member_typing_timer[member])
+        was_present = member.user_id in self._room_typing.get(room_id, set())
 
-        def _cb():
-            logger.debug(
-                "%s has timed out in %s", target_user.to_string(), room_id
-            )
-            self._stopped_typing(member)
+        now = self.clock.time_msec()
+        self._member_typing_until[member] = now + timeout
 
-        self._member_typing_until[member] = until
-        self._member_typing_timer[member] = self.clock.call_later(
-            timeout / 1000.0, _cb
+        self.wheel_timer.insert(
+            now=now,
+            obj=member,
+            then=now + timeout,
         )
 
         if was_present:
@@ -109,8 +141,7 @@ class TypingHandler(object):
             defer.returnValue(None)
 
         yield self._push_update(
-            room_id=room_id,
-            user_id=target_user_id,
+            member=member,
             typing=True,
         )
 
@@ -133,10 +164,6 @@ class TypingHandler(object):
 
         member = RoomMember(room_id=room_id, user_id=target_user_id)
 
-        if member in self._member_typing_timer:
-            self.clock.cancel_call_later(self._member_typing_timer[member])
-            del self._member_typing_timer[member]
-
         yield self._stopped_typing(member)
 
     @defer.inlineCallbacks
@@ -148,57 +175,61 @@ class TypingHandler(object):
 
     @defer.inlineCallbacks
     def _stopped_typing(self, member):
-        if member not in self._member_typing_until:
+        if member.user_id not in self._room_typing.get(member.room_id, set()):
             # No point
             defer.returnValue(None)
 
+        self._member_typing_until.pop(member, None)
+        self._member_last_federation_poke.pop(member, None)
+
         yield self._push_update(
-            room_id=member.room_id,
-            user_id=member.user_id,
+            member=member,
             typing=False,
         )
 
-        del self._member_typing_until[member]
-
-        if member in self._member_typing_timer:
-            # Don't cancel it - either it already expired, or the real
-            # stopped_typing() will cancel it
-            del self._member_typing_timer[member]
+    @defer.inlineCallbacks
+    def _push_update(self, member, typing):
+        if self.hs.is_mine_id(member.user_id):
+            # Only send updates for changes to our own users.
+            yield self._push_remote(member, typing)
+
+        self._push_update_local(
+            member=member,
+            typing=typing
+        )
 
     @defer.inlineCallbacks
-    def _push_update(self, room_id, user_id, typing):
-        users = yield self.state.get_current_user_in_room(room_id)
-        domains = set(get_domain_from_id(u) for u in users)
+    def _push_remote(self, member, typing):
+        users = yield self.state.get_current_user_in_room(member.room_id)
+        self._member_last_federation_poke[member] = self.clock.time_msec()
+
+        now = self.clock.time_msec()
+        self.wheel_timer.insert(
+            now=now,
+            obj=member,
+            then=now + FEDERATION_PING_INTERVAL,
+        )
 
-        deferreds = []
-        for domain in domains:
-            if domain == self.server_name:
-                preserve_fn(self._push_update_local)(
-                    room_id=room_id,
-                    user_id=user_id,
-                    typing=typing
-                )
-            else:
-                deferreds.append(preserve_fn(self.federation.send_edu)(
+        for domain in set(get_domain_from_id(u) for u in users):
+            if domain != self.server_name:
+                self.federation.send_edu(
                     destination=domain,
                     edu_type="m.typing",
                     content={
-                        "room_id": room_id,
-                        "user_id": user_id,
+                        "room_id": member.room_id,
+                        "user_id": member.user_id,
                         "typing": typing,
                     },
-                    key=(room_id, user_id),
-                ))
-
-        yield preserve_context_over_deferred(
-            defer.DeferredList(deferreds, consumeErrors=True)
-        )
+                    key=member,
+                )
 
     @defer.inlineCallbacks
     def _recv_edu(self, origin, content):
         room_id = content["room_id"]
         user_id = content["user_id"]
 
+        member = RoomMember(user_id=user_id, room_id=room_id)
+
         # Check that the string is a valid user id
         user = UserID.from_string(user_id)
 
@@ -213,26 +244,32 @@ class TypingHandler(object):
         domains = set(get_domain_from_id(u) for u in users)
 
         if self.server_name in domains:
+            logger.info("Got typing update from %s: %r", user_id, content)
+            now = self.clock.time_msec()
+            self._member_typing_until[member] = now + FEDERATION_TIMEOUT
+            self.wheel_timer.insert(
+                now=now,
+                obj=member,
+                then=now + FEDERATION_TIMEOUT,
+            )
             self._push_update_local(
-                room_id=room_id,
-                user_id=user_id,
+                member=member,
                 typing=content["typing"]
             )
 
-    def _push_update_local(self, room_id, user_id, typing):
-        room_set = self._room_typing.setdefault(room_id, set())
+    def _push_update_local(self, member, typing):
+        room_set = self._room_typing.setdefault(member.room_id, set())
         if typing:
-            room_set.add(user_id)
+            room_set.add(member.user_id)
         else:
-            room_set.discard(user_id)
+            room_set.discard(member.user_id)
 
         self._latest_room_serial += 1
-        self._room_serials[room_id] = self._latest_room_serial
+        self._room_serials[member.room_id] = self._latest_room_serial
 
-        with PreserveLoggingContext():
-            self.notifier.on_new_event(
-                "typing_key", self._latest_room_serial, rooms=[room_id]
-            )
+        self.notifier.on_new_event(
+            "typing_key", self._latest_room_serial, rooms=[member.room_id]
+        )
 
     def get_all_typing_updates(self, last_id, current_id):
         # TODO: Work out a way to do this without scanning the entire state.