summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/federation/federation_client.py32
-rw-r--r--synapse/federation/federation_server.py43
-rw-r--r--synapse/federation/transport/client.py70
-rw-r--r--synapse/federation/transport/server.py20
-rw-r--r--synapse/handlers/__init__.py2
-rw-r--r--synapse/handlers/auth.py113
-rw-r--r--synapse/handlers/federation.py25
-rw-r--r--synapse/handlers/login.py83
-rw-r--r--synapse/handlers/message.py50
-rw-r--r--synapse/handlers/presence.py89
-rw-r--r--synapse/handlers/register.py10
-rw-r--r--synapse/handlers/sync.py21
-rw-r--r--synapse/push/pusherpool.py11
-rw-r--r--synapse/rest/client/v1/login.py5
-rw-r--r--synapse/rest/client/v2_alpha/account.py8
-rw-r--r--synapse/rest/client/v2_alpha/keys.py100
-rw-r--r--synapse/rest/client/v2_alpha/register.py3
-rw-r--r--synapse/state.py12
-rw-r--r--synapse/storage/_base.py213
-rw-r--r--synapse/storage/directory.py3
-rw-r--r--synapse/storage/event_federation.py3
-rw-r--r--synapse/storage/keys.py3
-rw-r--r--synapse/storage/presence.py33
-rw-r--r--synapse/storage/push_rule.py3
-rw-r--r--synapse/storage/receipts.py3
-rw-r--r--synapse/storage/registration.py15
-rw-r--r--synapse/storage/room.py3
-rw-r--r--synapse/storage/roommember.py3
-rw-r--r--synapse/storage/state.py326
-rw-r--r--synapse/storage/stream.py6
-rw-r--r--synapse/storage/transactions.py3
-rw-r--r--synapse/util/caches/__init__.py27
-rw-r--r--synapse/util/caches/descriptors.py377
-rw-r--r--synapse/util/caches/dictionary_cache.py103
-rw-r--r--synapse/util/caches/expiringcache.py (renamed from synapse/util/expiringcache.py)0
-rw-r--r--synapse/util/caches/lrucache.py (renamed from synapse/util/lrucache.py)0
-rw-r--r--tests/storage/test__base.py2
-rw-r--r--tests/test_state.py2
-rw-r--r--tests/util/test_dict_cache.py101
-rw-r--r--tests/util/test_lrucache.py4
40 files changed, 1386 insertions, 544 deletions
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 7736d14fb5..f5e346cdbc 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -23,7 +23,7 @@ from synapse.api.errors import (
     CodeMessageException, HttpResponseException, SynapseError,
 )
 from synapse.util import unwrapFirstError
-from synapse.util.expiringcache import ExpiringCache
+from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.util.logutils import log_function
 from synapse.events import FrozenEvent
 import synapse.metrics
@@ -134,6 +134,36 @@ class FederationClient(FederationBase):
             destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail
         )
 
+    @log_function
+    def query_client_keys(self, destination, content):
+        """Query device keys for a device hosted on a remote server.
+
+        Args:
+            destination (str): Domain name of the remote homeserver
+            content (dict): The query content.
+
+        Returns:
+            a Deferred which will eventually yield a JSON object from the
+            response
+        """
+        sent_queries_counter.inc("client_device_keys")
+        return self.transport_layer.query_client_keys(destination, content)
+
+    @log_function
+    def claim_client_keys(self, destination, content):
+        """Claims one-time keys for a device hosted on a remote server.
+
+        Args:
+            destination (str): Domain name of the remote homeserver
+            content (dict): The query content.
+
+        Returns:
+            a Deferred which will eventually yield a JSON object from the
+            response
+        """
+        sent_queries_counter.inc("client_one_time_keys")
+        return self.transport_layer.claim_client_keys(destination, content)
+
     @defer.inlineCallbacks
     @log_function
     def backfill(self, dest, context, limit, extremities):
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index cd79e23f4b..725c6f3fa5 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -27,6 +27,7 @@ from synapse.api.errors import FederationError, SynapseError
 
 from synapse.crypto.event_signing import compute_event_signature
 
+import simplejson as json
 import logging
 
 
@@ -314,6 +315,48 @@ class FederationServer(FederationBase):
 
     @defer.inlineCallbacks
     @log_function
+    def on_query_client_keys(self, origin, content):
+        query = []
+        for user_id, device_ids in content.get("device_keys", {}).items():
+            if not device_ids:
+                query.append((user_id, None))
+            else:
+                for device_id in device_ids:
+                    query.append((user_id, device_id))
+
+        results = yield self.store.get_e2e_device_keys(query)
+
+        json_result = {}
+        for user_id, device_keys in results.items():
+            for device_id, json_bytes in device_keys.items():
+                json_result.setdefault(user_id, {})[device_id] = json.loads(
+                    json_bytes
+                )
+
+        defer.returnValue({"device_keys": json_result})
+
+    @defer.inlineCallbacks
+    @log_function
+    def on_claim_client_keys(self, origin, content):
+        query = []
+        for user_id, device_keys in content.get("one_time_keys", {}).items():
+            for device_id, algorithm in device_keys.items():
+                query.append((user_id, device_id, algorithm))
+
+        results = yield self.store.claim_e2e_one_time_keys(query)
+
+        json_result = {}
+        for user_id, device_keys in results.items():
+            for device_id, keys in device_keys.items():
+                for key_id, json_bytes in keys.items():
+                    json_result.setdefault(user_id, {})[device_id] = {
+                        key_id: json.loads(json_bytes)
+                    }
+
+        defer.returnValue({"one_time_keys": json_result})
+
+    @defer.inlineCallbacks
+    @log_function
     def on_get_missing_events(self, origin, room_id, earliest_events,
                               latest_events, limit, min_depth):
         missing_events = yield self.handler.on_get_missing_events(
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 610a4c3163..ced703364b 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -224,6 +224,76 @@ class TransportLayerClient(object):
 
     @defer.inlineCallbacks
     @log_function
+    def query_client_keys(self, destination, query_content):
+        """Query the device keys for a list of user ids hosted on a remote
+        server.
+
+        Request:
+            {
+              "device_keys": {
+                "<user_id>": ["<device_id>"]
+            } }
+
+        Response:
+            {
+              "device_keys": {
+                "<user_id>": {
+                  "<device_id>": {...}
+            } } }
+
+        Args:
+            destination(str): The server to query.
+            query_content(dict): The user ids to query.
+        Returns:
+            A dict containing the device keys.
+        """
+        path = PREFIX + "/user/keys/query"
+
+        content = yield self.client.post_json(
+            destination=destination,
+            path=path,
+            data=query_content,
+        )
+        defer.returnValue(content)
+
+    @defer.inlineCallbacks
+    @log_function
+    def claim_client_keys(self, destination, query_content):
+        """Claim one-time keys for a list of devices hosted on a remote server.
+
+        Request:
+            {
+              "one_time_keys": {
+                "<user_id>": {
+                    "<device_id>": "<algorithm>"
+            } } }
+
+        Response:
+            {
+              "one_time_keys": {
+                "<user_id>": {
+                  "<device_id>": {
+                    "<algorithm>:<key_id>": "<key_base64>"
+            } } } }
+
+        Args:
+            destination(str): The server to query.
+            query_content(dict): The devices and algorithms to claim keys for.
+        Returns:
+            A dict containing the one-time keys.
+        """
+
+        path = PREFIX + "/user/keys/claim"
+
+        content = yield self.client.post_json(
+            destination=destination,
+            path=path,
+            data=query_content,
+        )
+        defer.returnValue(content)
+
+    @defer.inlineCallbacks
+    @log_function
     def get_missing_events(self, destination, room_id, earliest_events,
                            latest_events, limit, min_depth):
         path = PREFIX + "/get_missing_events/%s" % (room_id,)
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index bad93c6b2f..36f250e1a3 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -325,6 +325,24 @@ class FederationInviteServlet(BaseFederationServlet):
         defer.returnValue((200, content))
 
 
+class FederationClientKeysQueryServlet(BaseFederationServlet):
+    PATH = "/user/keys/query"
+
+    @defer.inlineCallbacks
+    def on_POST(self, origin, content, query):
+        response = yield self.handler.on_query_client_keys(origin, content)
+        defer.returnValue((200, response))
+
+
+class FederationClientKeysClaimServlet(BaseFederationServlet):
+    PATH = "/user/keys/claim"
+
+    @defer.inlineCallbacks
+    def on_POST(self, origin, content, query):
+        response = yield self.handler.on_claim_client_keys(origin, content)
+        defer.returnValue((200, response))
+
+
 class FederationQueryAuthServlet(BaseFederationServlet):
     PATH = "/query_auth/([^/]*)/([^/]*)"
 
@@ -373,4 +391,6 @@ SERVLET_CLASSES = (
     FederationQueryAuthServlet,
     FederationGetMissingEventsServlet,
     FederationEventAuthServlet,
+    FederationClientKeysQueryServlet,
+    FederationClientKeysClaimServlet,
 )
diff --git a/synapse/handlers/__init__.py b/synapse/handlers/__init__.py
index dc5b6ef79d..8725c3c420 100644
--- a/synapse/handlers/__init__.py
+++ b/synapse/handlers/__init__.py
@@ -22,7 +22,6 @@ from .room import (
 from .message import MessageHandler
 from .events import EventStreamHandler, EventHandler
 from .federation import FederationHandler
-from .login import LoginHandler
 from .profile import ProfileHandler
 from .presence import PresenceHandler
 from .directory import DirectoryHandler
@@ -54,7 +53,6 @@ class Handlers(object):
         self.profile_handler = ProfileHandler(hs)
         self.presence_handler = PresenceHandler(hs)
         self.room_list_handler = RoomListHandler(hs)
-        self.login_handler = LoginHandler(hs)
         self.directory_handler = DirectoryHandler(hs)
         self.typing_notification_handler = TypingNotificationHandler(hs)
         self.admin_handler = AdminHandler(hs)
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 1ecf7fef17..98d99dd0a8 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -47,17 +47,24 @@ class AuthHandler(BaseHandler):
         self.sessions = {}
 
     @defer.inlineCallbacks
-    def check_auth(self, flows, clientdict, clientip=None):
+    def check_auth(self, flows, clientdict, clientip):
         """
         Takes a dictionary sent by the client in the login / registration
         protocol and handles the login flow.
 
+        As a side effect, this function fills in the 'creds' key on the user's
+        session with a map, which maps each auth-type (str) to the relevant
+        identity authenticated by that auth-type (mostly str, but for captcha, bool).
+
         Args:
-            flows: list of list of stages
-            authdict: The dictionary from the client root level, not the
-                      'auth' key: this method prompts for auth if none is sent.
+            flows (list): A list of login flows. Each flow is an ordered list of
+                          strings representing auth-types. At least one full
+                          flow must be completed in order for auth to be successful.
+            clientdict: The dictionary from the client root level, not the
+                        'auth' key: this method prompts for auth if none is sent.
+            clientip (str): The IP address of the client.
         Returns:
-            A tuple of authed, dict, dict where authed is true if the client
+            A tuple of (authed, dict, dict) where authed is true if the client
             has successfully completed an auth flow. If it is true, the first
             dict contains the authenticated credentials of each stage.
 
@@ -75,7 +82,7 @@ class AuthHandler(BaseHandler):
             del clientdict['auth']
             if 'session' in authdict:
                 sid = authdict['session']
-        sess = self._get_session_info(sid)
+        session = self._get_session_info(sid)
 
         if len(clientdict) > 0:
             # This was designed to allow the client to omit the parameters
@@ -87,20 +94,19 @@ class AuthHandler(BaseHandler):
             # on a home server.
             # Revisit: Assumimg the REST APIs do sensible validation, the data
             # isn't arbintrary.
-            sess['clientdict'] = clientdict
-            self._save_session(sess)
-            pass
-        elif 'clientdict' in sess:
-            clientdict = sess['clientdict']
+            session['clientdict'] = clientdict
+            self._save_session(session)
+        elif 'clientdict' in session:
+            clientdict = session['clientdict']
 
         if not authdict:
             defer.returnValue(
-                (False, self._auth_dict_for_flows(flows, sess), clientdict)
+                (False, self._auth_dict_for_flows(flows, session), clientdict)
             )
 
-        if 'creds' not in sess:
-            sess['creds'] = {}
-        creds = sess['creds']
+        if 'creds' not in session:
+            session['creds'] = {}
+        creds = session['creds']
 
         # check auth type currently being presented
         if 'type' in authdict:
@@ -109,15 +115,15 @@ class AuthHandler(BaseHandler):
             result = yield self.checkers[authdict['type']](authdict, clientip)
             if result:
                 creds[authdict['type']] = result
-                self._save_session(sess)
+                self._save_session(session)
 
         for f in flows:
             if len(set(f) - set(creds.keys())) == 0:
                 logger.info("Auth completed with creds: %r", creds)
-                self._remove_session(sess)
+                self._remove_session(session)
                 defer.returnValue((True, creds, clientdict))
 
-        ret = self._auth_dict_for_flows(flows, sess)
+        ret = self._auth_dict_for_flows(flows, session)
         ret['completed'] = creds.keys()
         defer.returnValue((False, ret, clientdict))
 
@@ -151,22 +157,13 @@ class AuthHandler(BaseHandler):
         if "user" not in authdict or "password" not in authdict:
             raise LoginError(400, "", Codes.MISSING_PARAM)
 
-        user = authdict["user"]
+        user_id = authdict["user"]
         password = authdict["password"]
-        if not user.startswith('@'):
-            user = UserID.create(user, self.hs.hostname).to_string()
+        if not user_id.startswith('@'):
+            user_id = UserID.create(user_id, self.hs.hostname).to_string()
 
-        user_info = yield self.store.get_user_by_id(user_id=user)
-        if not user_info:
-            logger.warn("Attempted to login as %s but they do not exist", user)
-            raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
-
-        stored_hash = user_info["password_hash"]
-        if bcrypt.checkpw(password, stored_hash):
-            defer.returnValue(user)
-        else:
-            logger.warn("Failed password login for user %s", user)
-            raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
+        yield self._check_password(user_id, password)
+        defer.returnValue(user_id)
 
     @defer.inlineCallbacks
     def _check_recaptcha(self, authdict, clientip):
@@ -270,6 +267,69 @@ class AuthHandler(BaseHandler):
 
         return self.sessions[session_id]
 
+    @defer.inlineCallbacks
+    def login_with_password(self, user_id, password):
+        """
+        Authenticates the user with their username and password.
+
+        Used only by the v1 login API.
+
+        Args:
+            user_id (str): User ID
+            password (str): Password
+        Returns:
+            The access token for the user's session.
+        Raises:
+            StoreError if there was a problem storing the token.
+            LoginError if there was an authentication problem.
+        """
+        yield self._check_password(user_id, password)
+
+        reg_handler = self.hs.get_handlers().registration_handler
+        access_token = reg_handler.generate_token(user_id)
+        logger.info("Adding token %s for user %s", access_token, user_id)
+        yield self.store.add_access_token_to_user(user_id, access_token)
+        defer.returnValue(access_token)
+
+    @defer.inlineCallbacks
+    def _check_password(self, user_id, password):
+        """Checks that user_id has passed password, raises LoginError if not.
+
+        The user lookup is asynchronous, so this returns a Deferred; callers
+        must yield on it or a failed check will go unnoticed.
+
+        Args:
+            user_id (str): User ID
+            password (str): Password
+        Raises:
+            LoginError if the user does not exist or the password is wrong.
+        """
+        user_info = yield self.store.get_user_by_id(user_id=user_id)
+        if not user_info:
+            logger.warn("Attempted to login as %s but they do not exist", user_id)
+            raise LoginError(403, "", errcode=Codes.FORBIDDEN)
+
+        stored_hash = user_info["password_hash"]
+        if not bcrypt.checkpw(password, stored_hash):
+            logger.warn("Failed password login for user %s", user_id)
+            raise LoginError(403, "", errcode=Codes.FORBIDDEN)
+
+    @defer.inlineCallbacks
+    def set_password(self, user_id, newpassword):
+        password_hash = bcrypt.hashpw(newpassword, bcrypt.gensalt())
+
+        yield self.store.user_set_password_hash(user_id, password_hash)
+        yield self.store.user_delete_access_tokens(user_id)
+        yield self.hs.get_pusherpool().remove_pushers_by_user(user_id)
+        yield self.store.flush_user(user_id)
+
+    @defer.inlineCallbacks
+    def add_threepid(self, user_id, medium, address, validated_at):
+        yield self.store.user_add_threepid(
+            user_id, medium, address, validated_at,
+            self.hs.get_clock().time_msec()
+        )
+
     def _save_session(self, session):
         # TODO: Persistent storage
         logger.debug("Saving session %s", session)
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index f7155fd8d3..1e3dccf5a8 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -229,15 +229,15 @@ class FederationHandler(BaseHandler):
 
     @defer.inlineCallbacks
     def _filter_events_for_server(self, server_name, room_id, events):
-        states = yield self.store.get_state_for_events(
-            room_id, [e.event_id for e in events],
+        event_to_state = yield self.store.get_state_for_events(
+            room_id, frozenset(e.event_id for e in events),
+            types=(
+                (EventTypes.RoomHistoryVisibility, ""),
+                (EventTypes.Member, None),
+            )
         )
 
-        events_and_states = zip(events, states)
-
-        def redact_disallowed(event_and_state):
-            event, state = event_and_state
-
+        def redact_disallowed(event, state):
             if not state:
                 return event
 
@@ -271,11 +271,10 @@ class FederationHandler(BaseHandler):
 
             return event
 
-        res = map(redact_disallowed, events_and_states)
-
-        logger.info("_filter_events_for_server %r", res)
-
-        defer.returnValue(res)
+        defer.returnValue([
+            redact_disallowed(e, event_to_state[e.event_id])
+            for e in events
+        ])
 
     @log_function
     @defer.inlineCallbacks
@@ -503,7 +502,7 @@ class FederationHandler(BaseHandler):
         event_ids = list(extremities.keys())
 
         states = yield defer.gatherResults([
-            self.state_handler.resolve_state_groups([e])
+            self.state_handler.resolve_state_groups(room_id, [e])
             for e in event_ids
         ])
         states = dict(zip(event_ids, [s[1] for s in states]))
diff --git a/synapse/handlers/login.py b/synapse/handlers/login.py
deleted file mode 100644
index 91d87d503d..0000000000
--- a/synapse/handlers/login.py
+++ /dev/null
@@ -1,83 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2014, 2015 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from twisted.internet import defer
-
-from ._base import BaseHandler
-from synapse.api.errors import LoginError, Codes
-
-import bcrypt
-import logging
-
-logger = logging.getLogger(__name__)
-
-
-class LoginHandler(BaseHandler):
-
-    def __init__(self, hs):
-        super(LoginHandler, self).__init__(hs)
-        self.hs = hs
-
-    @defer.inlineCallbacks
-    def login(self, user, password):
-        """Login as the specified user with the specified password.
-
-        Args:
-            user (str): The user ID.
-            password (str): The password.
-        Returns:
-            The newly allocated access token.
-        Raises:
-            StoreError if there was a problem storing the token.
-            LoginError if there was an authentication problem.
-        """
-        # TODO do this better, it can't go in __init__ else it cyclic loops
-        if not hasattr(self, "reg_handler"):
-            self.reg_handler = self.hs.get_handlers().registration_handler
-
-        # pull out the hash for this user if they exist
-        user_info = yield self.store.get_user_by_id(user_id=user)
-        if not user_info:
-            logger.warn("Attempted to login as %s but they do not exist", user)
-            raise LoginError(403, "", errcode=Codes.FORBIDDEN)
-
-        stored_hash = user_info["password_hash"]
-        if bcrypt.checkpw(password, stored_hash):
-            # generate an access token and store it.
-            token = self.reg_handler._generate_token(user)
-            logger.info("Adding token %s for user %s", token, user)
-            yield self.store.add_access_token_to_user(user, token)
-            defer.returnValue(token)
-        else:
-            logger.warn("Failed password login for user %s", user)
-            raise LoginError(403, "", errcode=Codes.FORBIDDEN)
-
-    @defer.inlineCallbacks
-    def set_password(self, user_id, newpassword, token_id=None):
-        password_hash = bcrypt.hashpw(newpassword, bcrypt.gensalt())
-
-        yield self.store.user_set_password_hash(user_id, password_hash)
-        yield self.store.user_delete_access_tokens_apart_from(user_id, token_id)
-        yield self.hs.get_pusherpool().remove_pushers_by_user_access_token(
-            user_id, token_id
-        )
-        yield self.store.flush_user(user_id)
-
-    @defer.inlineCallbacks
-    def add_threepid(self, user_id, medium, address, validated_at):
-        yield self.store.user_add_threepid(
-            user_id, medium, address, validated_at,
-            self.hs.get_clock().time_msec()
-        )
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 9d6d4f0978..f12465fa2c 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -137,15 +137,15 @@ class MessageHandler(BaseHandler):
 
     @defer.inlineCallbacks
     def _filter_events_for_client(self, user_id, room_id, events):
-        states = yield self.store.get_state_for_events(
-            room_id, [e.event_id for e in events],
+        event_id_to_state = yield self.store.get_state_for_events(
+            room_id, frozenset(e.event_id for e in events),
+            types=(
+                (EventTypes.RoomHistoryVisibility, ""),
+                (EventTypes.Member, user_id),
+            )
         )
 
-        events_and_states = zip(events, states)
-
-        def allowed(event_and_state):
-            event, state = event_and_state
-
+        def allowed(event, state):
             if event.type == EventTypes.RoomHistoryVisibility:
                 return True
 
@@ -175,10 +175,10 @@ class MessageHandler(BaseHandler):
 
             return True
 
-        events_and_states = filter(allowed, events_and_states)
         defer.returnValue([
-            ev
-            for ev, _ in events_and_states
+            event
+            for event in events
+            if allowed(event, event_id_to_state[event.event_id])
         ])
 
     @defer.inlineCallbacks
@@ -401,10 +401,14 @@ class MessageHandler(BaseHandler):
             except:
                 logger.exception("Failed to get snapshot")
 
-        yield defer.gatherResults(
-            [handle_room(e) for e in room_list],
-            consumeErrors=True
-        ).addErrback(unwrapFirstError)
+        # Only do N rooms at once. handle_room must be called inside the loop:
+        # calling it for every room up front starts them all immediately.
+        n = 5
+        for i in range(0, len(room_list), n):
+            yield defer.gatherResults(
+                [handle_room(e) for e in room_list[i:i + n]],
+                consumeErrors=True
+            ).addErrback(unwrapFirstError)
 
         ret = {
             "rooms": rooms_ret,
@@ -456,20 +460,14 @@ class MessageHandler(BaseHandler):
 
         @defer.inlineCallbacks
         def get_presence():
-            presence_defs = yield defer.DeferredList(
-                [
-                    presence_handler.get_state(
-                        target_user=UserID.from_string(m.user_id),
-                        auth_user=auth_user,
-                        as_event=True,
-                        check_auth=False,
-                    )
-                    for m in room_members
-                ],
-                consumeErrors=True,
+            states = yield presence_handler.get_states(
+                target_users=[UserID.from_string(m.user_id) for m in room_members],
+                auth_user=auth_user,
+                as_event=True,
+                check_auth=False,
             )
 
-            defer.returnValue([p for success, p in presence_defs if success])
+            defer.returnValue(states.values())
 
         receipts_handler = self.hs.get_handlers().receipts_handler
 
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 341a516da2..e91e81831e 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -192,6 +192,20 @@ class PresenceHandler(BaseHandler):
 
     @defer.inlineCallbacks
     def get_state(self, target_user, auth_user, as_event=False, check_auth=True):
+        """Get the current presence state of the given user.
+
+        Args:
+            target_user (UserID): The user whose presence we want
+            auth_user (UserID): The user requesting the presence, used for
+                checking if said user is allowed to see the presence of the
+                `target_user`
+            as_event (bool): Format the return as an event or not?
+            check_auth (bool): Perform the auth checks or not?
+
+        Returns:
+            dict: The presence state of the `target_user`, whose format depends
+            on the `as_event` argument.
+        """
         if self.hs.is_mine(target_user):
             if check_auth:
                 visible = yield self.is_presence_visible(
@@ -233,6 +247,81 @@ class PresenceHandler(BaseHandler):
             defer.returnValue(state)
 
     @defer.inlineCallbacks
+    def get_states(self, target_users, auth_user, as_event=False, check_auth=True):
+        """A batched version of the `get_state` method that accepts a list of
+        `target_users`
+
+        Args:
+            target_users (list): The list of UserIDs whose presence we want
+            auth_user (UserID): The user requesting the presence, used for
+                checking if said user is allowed to see the presence of the
+                `target_users`
+            as_event (bool): Format the return as an event or not?
+            check_auth (bool): Perform the auth checks or not?
+
+        Returns:
+            dict: A mapping from user -> presence_state
+        """
+        local_users, remote_users = partitionbool(
+            target_users,
+            lambda u: self.hs.is_mine(u)
+        )
+
+        if check_auth:
+            for user in local_users:
+                visible = yield self.is_presence_visible(
+                    observer_user=auth_user,
+                    observed_user=user
+                )
+
+                if not visible:
+                    raise SynapseError(404, "Presence information not visible")
+
+        results = {}
+        if local_users:
+            for user in local_users:
+                if user in self._user_cachemap:
+                    results[user] = self._user_cachemap[user].get_state()
+
+            local_to_user = {u.localpart: u for u in local_users}
+
+            states = yield self.store.get_presence_states(
+                [u.localpart for u in local_users if u not in results]
+            )
+
+            for local_part, state in states.items():
+                if state is None:
+                    continue
+                res = {"presence": state["state"]}
+                if "status_msg" in state and state["status_msg"]:
+                    res["status_msg"] = state["status_msg"]
+                results[local_to_user[local_part]] = res
+
+        for user in remote_users:
+            # TODO(paul): Have remote server send us permissions set
+            results[user] = self._get_or_offline_usercache(user).get_state()
+
+        for state in results.values():
+            if "last_active" in state:
+                state["last_active_ago"] = int(
+                    self.clock.time_msec() - state.pop("last_active")
+                )
+
+        if as_event:
+            for user, state in results.items():
+                content = state
+                content["user_id"] = user.to_string()
+
+                if "last_active" in content:
+                    content["last_active_ago"] = int(
+                        self._clock.time_msec() - content.pop("last_active")
+                    )
+
+                results[user] = {"type": "m.presence", "content": content}
+
+        defer.returnValue(results)
+
+    @defer.inlineCallbacks
     @log_function
     def set_state(self, target_user, auth_user, state):
         # return
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index f81d75017d..39392d9fdd 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -91,7 +91,7 @@ class RegistrationHandler(BaseHandler):
             user = UserID(localpart, self.hs.hostname)
             user_id = user.to_string()
 
-            token = self._generate_token(user_id)
+            token = self.generate_token(user_id)
             yield self.store.register(
                 user_id=user_id,
                 token=token,
@@ -111,7 +111,7 @@ class RegistrationHandler(BaseHandler):
                     user_id = user.to_string()
                     yield self.check_user_id_is_valid(user_id)
 
-                    token = self._generate_token(user_id)
+                    token = self.generate_token(user_id)
                     yield self.store.register(
                         user_id=user_id,
                         token=token,
@@ -161,7 +161,7 @@ class RegistrationHandler(BaseHandler):
                 400, "Invalid user localpart for this application service.",
                 errcode=Codes.EXCLUSIVE
             )
-        token = self._generate_token(user_id)
+        token = self.generate_token(user_id)
         yield self.store.register(
             user_id=user_id,
             token=token,
@@ -208,7 +208,7 @@ class RegistrationHandler(BaseHandler):
         user_id = user.to_string()
 
         yield self.check_user_id_is_valid(user_id)
-        token = self._generate_token(user_id)
+        token = self.generate_token(user_id)
         try:
             yield self.store.register(
                 user_id=user_id,
@@ -273,7 +273,7 @@ class RegistrationHandler(BaseHandler):
                     errcode=Codes.EXCLUSIVE
                 )
 
-    def _generate_token(self, user_id):
+    def generate_token(self, user_id):
         # urlsafe variant uses _ and - so use . as the separator and replace
         # all =s with .s so http clients don't quote =s when it is used as
         # query params.
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 6cff6230c1..7206ae23d7 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -294,15 +294,15 @@ class SyncHandler(BaseHandler):
 
     @defer.inlineCallbacks
     def _filter_events_for_client(self, user_id, room_id, events):
-        states = yield self.store.get_state_for_events(
-            room_id, [e.event_id for e in events],
+        event_id_to_state = yield self.store.get_state_for_events(
+            room_id, frozenset(e.event_id for e in events),
+            types=(
+                (EventTypes.RoomHistoryVisibility, ""),
+                (EventTypes.Member, user_id),
+            )
         )
 
-        events_and_states = zip(events, states)
-
-        def allowed(event_and_state):
-            event, state = event_and_state
-
+        def allowed(event, state):
             if event.type == EventTypes.RoomHistoryVisibility:
                 return True
 
@@ -331,10 +331,11 @@ class SyncHandler(BaseHandler):
                 return membership == Membership.INVITE
 
             return True
-        events_and_states = filter(allowed, events_and_states)
+
         defer.returnValue([
-            ev
-            for ev, _ in events_and_states
+            event
+            for event in events
+            if allowed(event, event_id_to_state[event.event_id])
         ])
 
     @defer.inlineCallbacks
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index 0ab2f65972..e012c565ee 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -94,17 +94,14 @@ class PusherPool:
                 self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
 
     @defer.inlineCallbacks
-    def remove_pushers_by_user_access_token(self, user_id, not_access_token_id):
+    def remove_pushers_by_user(self, user_id):
         all = yield self.store.get_all_pushers()
         logger.info(
-            "Removing all pushers for user %s except access token %s",
-            user_id, not_access_token_id
+            "Removing all pushers for user %s",
+            user_id,
         )
         for p in all:
-            if (
-                p['user_name'] == user_id and
-                p['access_token'] != not_access_token_id
-            ):
+            if p['user_name'] == user_id:
                 logger.info(
                     "Removing pusher for app id %s, pushkey %s, user %s",
                     p['app_id'], p['pushkey'], p['user_name']
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 998d4d44c6..694072693d 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -78,9 +78,8 @@ class LoginRestServlet(ClientV1RestServlet):
             login_submission["user"] = UserID.create(
                 login_submission["user"], self.hs.hostname).to_string()
 
-        handler = self.handlers.login_handler
-        token = yield handler.login(
-            user=login_submission["user"],
+        token = yield self.handlers.auth_handler.login_with_password(
+            user_id=login_submission["user"],
             password=login_submission["password"])
 
         result = {
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index b082140f1f..897c54b539 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -36,7 +36,6 @@ class PasswordRestServlet(RestServlet):
         self.hs = hs
         self.auth = hs.get_auth()
         self.auth_handler = hs.get_handlers().auth_handler
-        self.login_handler = hs.get_handlers().login_handler
 
     @defer.inlineCallbacks
     def on_POST(self, request):
@@ -47,7 +46,7 @@ class PasswordRestServlet(RestServlet):
         authed, result, params = yield self.auth_handler.check_auth([
             [LoginType.PASSWORD],
             [LoginType.EMAIL_IDENTITY]
-        ], body)
+        ], body, self.hs.get_ip_from_request(request))
 
         if not authed:
             defer.returnValue((401, result))
@@ -79,7 +78,7 @@ class PasswordRestServlet(RestServlet):
             raise SynapseError(400, "", Codes.MISSING_PARAM)
         new_password = params['new_password']
 
-        yield self.login_handler.set_password(
+        yield self.auth_handler.set_password(
             user_id, new_password, None
         )
 
@@ -95,7 +94,6 @@ class ThreepidRestServlet(RestServlet):
     def __init__(self, hs):
         super(ThreepidRestServlet, self).__init__()
         self.hs = hs
-        self.login_handler = hs.get_handlers().login_handler
         self.identity_handler = hs.get_handlers().identity_handler
         self.auth = hs.get_auth()
 
@@ -135,7 +133,7 @@ class ThreepidRestServlet(RestServlet):
                 logger.warn("Couldn't add 3pid: invalid response from ID sevrer")
                 raise SynapseError(500, "Invalid response from ID Server")
 
-        yield self.login_handler.add_threepid(
+        yield self.auth_handler.add_threepid(
             auth_user.to_string(),
             threepid['medium'],
             threepid['address'],
diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py
index 5f3a6207b5..718928eedd 100644
--- a/synapse/rest/client/v2_alpha/keys.py
+++ b/synapse/rest/client/v2_alpha/keys.py
@@ -17,6 +17,7 @@ from twisted.internet import defer
 
 from synapse.api.errors import SynapseError
 from synapse.http.servlet import RestServlet
+from synapse.types import UserID
 from syutil.jsonutil import encode_canonical_json
 
 from ._base import client_v2_pattern
@@ -164,45 +165,63 @@ class KeyQueryServlet(RestServlet):
         super(KeyQueryServlet, self).__init__()
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
+        self.federation = hs.get_replication_layer()
+        self.is_mine = hs.is_mine
 
     @defer.inlineCallbacks
     def on_POST(self, request, user_id, device_id):
-        logger.debug("onPOST")
         yield self.auth.get_user_by_req(request)
         try:
             body = json.loads(request.content.read())
         except:
             raise SynapseError(400, "Invalid key JSON")
-        query = []
-        for user_id, device_ids in body.get("device_keys", {}).items():
-            if not device_ids:
-                query.append((user_id, None))
-            else:
-                for device_id in device_ids:
-                    query.append((user_id, device_id))
-        results = yield self.store.get_e2e_device_keys(query)
-        defer.returnValue(self.json_result(request, results))
+        result = yield self.handle_request(body)
+        defer.returnValue(result)
 
     @defer.inlineCallbacks
     def on_GET(self, request, user_id, device_id):
         auth_user, client_info = yield self.auth.get_user_by_req(request)
         auth_user_id = auth_user.to_string()
-        if not user_id:
-            user_id = auth_user_id
-        if not device_id:
-            device_id = None
-        # Returns a map of user_id->device_id->json_bytes.
-        results = yield self.store.get_e2e_device_keys([(user_id, device_id)])
-        defer.returnValue(self.json_result(request, results))
-
-    def json_result(self, request, results):
+        user_id = user_id if user_id else auth_user_id
+        device_ids = [device_id] if device_id else []
+        result = yield self.handle_request(
+            {"device_keys": {user_id: device_ids}}
+        )
+        defer.returnValue(result)
+
+    @defer.inlineCallbacks
+    def handle_request(self, body):
+        local_query = []
+        remote_queries = {}
+        for user_id, device_ids in body.get("device_keys", {}).items():
+            user = UserID.from_string(user_id)
+            if self.is_mine(user):
+                if not device_ids:
+                    local_query.append((user_id, None))
+                else:
+                    for device_id in device_ids:
+                        local_query.append((user_id, device_id))
+            else:
+                remote_queries.setdefault(user.domain, {})[user_id] = list(
+                    device_ids
+                )
+        results = yield self.store.get_e2e_device_keys(local_query)
+
         json_result = {}
         for user_id, device_keys in results.items():
             for device_id, json_bytes in device_keys.items():
                 json_result.setdefault(user_id, {})[device_id] = json.loads(
                     json_bytes
                 )
-        return (200, {"device_keys": json_result})
+
+        for destination, device_keys in remote_queries.items():
+            remote_result = yield self.federation.query_client_keys(
+                destination, {"device_keys": device_keys}
+            )
+            for user_id, keys in remote_result["device_keys"].items():
+                if user_id in device_keys:
+                    json_result[user_id] = keys
+        defer.returnValue((200, {"device_keys": json_result}))
 
 
 class OneTimeKeyServlet(RestServlet):
@@ -236,14 +255,16 @@ class OneTimeKeyServlet(RestServlet):
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
+        self.federation = hs.get_replication_layer()
+        self.is_mine = hs.is_mine
 
     @defer.inlineCallbacks
     def on_GET(self, request, user_id, device_id, algorithm):
         yield self.auth.get_user_by_req(request)
-        results = yield self.store.claim_e2e_one_time_keys(
-            [(user_id, device_id, algorithm)]
+        result = yield self.handle_request(
+            {"one_time_keys": {user_id: {device_id: algorithm}}}
         )
-        defer.returnValue(self.json_result(request, results))
+        defer.returnValue(result)
 
     @defer.inlineCallbacks
     def on_POST(self, request, user_id, device_id, algorithm):
@@ -252,14 +273,24 @@ class OneTimeKeyServlet(RestServlet):
             body = json.loads(request.content.read())
         except:
             raise SynapseError(400, "Invalid key JSON")
-        query = []
+        result = yield self.handle_request(body)
+        defer.returnValue(result)
+
+    @defer.inlineCallbacks
+    def handle_request(self, body):
+        local_query = []
+        remote_queries = {}
         for user_id, device_keys in body.get("one_time_keys", {}).items():
-            for device_id, algorithm in device_keys.items():
-                query.append((user_id, device_id, algorithm))
-        results = yield self.store.claim_e2e_one_time_keys(query)
-        defer.returnValue(self.json_result(request, results))
+            user = UserID.from_string(user_id)
+            if self.is_mine(user):
+                for device_id, algorithm in device_keys.items():
+                    local_query.append((user_id, device_id, algorithm))
+            else:
+                remote_queries.setdefault(user.domain, {})[user_id] = (
+                    device_keys
+                )
+        results = yield self.store.claim_e2e_one_time_keys(local_query)
 
-    def json_result(self, request, results):
         json_result = {}
         for user_id, device_keys in results.items():
             for device_id, keys in device_keys.items():
@@ -267,7 +298,16 @@ class OneTimeKeyServlet(RestServlet):
                     json_result.setdefault(user_id, {})[device_id] = {
                         key_id: json.loads(json_bytes)
                     }
-        return (200, {"one_time_keys": json_result})
+
+        for destination, device_keys in remote_queries.items():
+            remote_result = yield self.federation.claim_client_keys(
+                destination, {"one_time_keys": device_keys}
+            )
+            for user_id, keys in remote_result["one_time_keys"].items():
+                if user_id in device_keys:
+                    json_result[user_id] = keys
+
+        defer.returnValue((200, {"one_time_keys": json_result}))
 
 
 def register_servlets(hs, http_server):
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index b5926f9ca6..012c447e88 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -50,7 +50,6 @@ class RegisterRestServlet(RestServlet):
         self.auth_handler = hs.get_handlers().auth_handler
         self.registration_handler = hs.get_handlers().registration_handler
         self.identity_handler = hs.get_handlers().identity_handler
-        self.login_handler = hs.get_handlers().login_handler
 
     @defer.inlineCallbacks
     def on_POST(self, request):
@@ -143,7 +142,7 @@ class RegisterRestServlet(RestServlet):
                 if reqd not in threepid:
                     logger.info("Can't add incomplete 3pid")
                 else:
-                    yield self.login_handler.add_threepid(
+                    yield self.auth_handler.add_threepid(
                         user_id,
                         threepid['medium'],
                         threepid['address'],
diff --git a/synapse/state.py b/synapse/state.py
index 80da90a72c..1fe4d066bd 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -18,7 +18,7 @@ from twisted.internet import defer
 
 from synapse.util.logutils import log_function
 from synapse.util.async import run_on_reactor
-from synapse.util.expiringcache import ExpiringCache
+from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.api.constants import EventTypes
 from synapse.api.errors import AuthError
 from synapse.api.auth import AuthEventTypes
@@ -96,7 +96,7 @@ class StateHandler(object):
             cache.ts = self.clock.time_msec()
             state = cache.state
         else:
-            res = yield self.resolve_state_groups(event_ids)
+            res = yield self.resolve_state_groups(room_id, event_ids)
             state = res[1]
 
         if event_type:
@@ -155,13 +155,13 @@ class StateHandler(object):
 
         if event.is_state():
             ret = yield self.resolve_state_groups(
-                [e for e, _ in event.prev_events],
+                event.room_id, [e for e, _ in event.prev_events],
                 event_type=event.type,
                 state_key=event.state_key,
             )
         else:
             ret = yield self.resolve_state_groups(
-                [e for e, _ in event.prev_events],
+                event.room_id, [e for e, _ in event.prev_events],
             )
 
         group, curr_state, prev_state = ret
@@ -180,7 +180,7 @@ class StateHandler(object):
 
     @defer.inlineCallbacks
     @log_function
-    def resolve_state_groups(self, event_ids, event_type=None, state_key=""):
+    def resolve_state_groups(self, room_id, event_ids, event_type=None, state_key=""):
         """ Given a list of event_ids this method fetches the state at each
         event, resolves conflicts between them and returns them.
 
@@ -205,7 +205,7 @@ class StateHandler(object):
                 )
 
         state_groups = yield self.store.get_state_groups(
-            event_ids
+            room_id, event_ids
         )
 
         logger.debug(
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 73eea157a4..1444767a52 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -15,25 +15,22 @@
 import logging
 
 from synapse.api.errors import StoreError
-from synapse.util.async import ObservableDeferred
 from synapse.util.logutils import log_function
 from synapse.util.logcontext import preserve_context_over_fn, LoggingContext
-from synapse.util.lrucache import LruCache
+from synapse.util.caches.dictionary_cache import DictionaryCache
+from synapse.util.caches.descriptors import Cache
 import synapse.metrics
 
 from util.id_generators import IdGenerator, StreamIdGenerator
 
 from twisted.internet import defer
 
-from collections import namedtuple, OrderedDict
+from collections import namedtuple
 
-import functools
-import inspect
 import sys
 import time
 import threading
 
-DEBUG_CACHES = False
 
 logger = logging.getLogger(__name__)
 
@@ -49,208 +46,6 @@ sql_scheduling_timer = metrics.register_distribution("schedule_time")
 sql_query_timer = metrics.register_distribution("query_time", labels=["verb"])
 sql_txn_timer = metrics.register_distribution("transaction_time", labels=["desc"])
 
-caches_by_name = {}
-cache_counter = metrics.register_cache(
-    "cache",
-    lambda: {(name,): len(caches_by_name[name]) for name in caches_by_name.keys()},
-    labels=["name"],
-)
-
-
-_CacheSentinel = object()
-
-
-class Cache(object):
-
-    def __init__(self, name, max_entries=1000, keylen=1, lru=True):
-        if lru:
-            self.cache = LruCache(max_size=max_entries)
-            self.max_entries = None
-        else:
-            self.cache = OrderedDict()
-            self.max_entries = max_entries
-
-        self.name = name
-        self.keylen = keylen
-        self.sequence = 0
-        self.thread = None
-        caches_by_name[name] = self.cache
-
-    def check_thread(self):
-        expected_thread = self.thread
-        if expected_thread is None:
-            self.thread = threading.current_thread()
-        else:
-            if expected_thread is not threading.current_thread():
-                raise ValueError(
-                    "Cache objects can only be accessed from the main thread"
-                )
-
-    def get(self, key, default=_CacheSentinel):
-        val = self.cache.get(key, _CacheSentinel)
-        if val is not _CacheSentinel:
-            cache_counter.inc_hits(self.name)
-            return val
-
-        cache_counter.inc_misses(self.name)
-
-        if default is _CacheSentinel:
-            raise KeyError()
-        else:
-            return default
-
-    def update(self, sequence, key, value):
-        self.check_thread()
-        if self.sequence == sequence:
-            # Only update the cache if the caches sequence number matches the
-            # number that the cache had before the SELECT was started (SYN-369)
-            self.prefill(key, value)
-
-    def prefill(self, key, value):
-        if self.max_entries is not None:
-            while len(self.cache) >= self.max_entries:
-                self.cache.popitem(last=False)
-
-        self.cache[key] = value
-
-    def invalidate(self, key):
-        self.check_thread()
-        if not isinstance(key, tuple):
-            raise TypeError(
-                "The cache key must be a tuple not %r" % (type(key),)
-            )
-
-        # Increment the sequence number so that any SELECT statements that
-        # raced with the INSERT don't update the cache (SYN-369)
-        self.sequence += 1
-        self.cache.pop(key, None)
-
-    def invalidate_all(self):
-        self.check_thread()
-        self.sequence += 1
-        self.cache.clear()
-
-
-class CacheDescriptor(object):
-    """ A method decorator that applies a memoizing cache around the function.
-
-    This caches deferreds, rather than the results themselves. Deferreds that
-    fail are removed from the cache.
-
-    The function is presumed to take zero or more arguments, which are used in
-    a tuple as the key for the cache. Hits are served directly from the cache;
-    misses use the function body to generate the value.
-
-    The wrapped function has an additional member, a callable called
-    "invalidate". This can be used to remove individual entries from the cache.
-
-    The wrapped function has another additional callable, called "prefill",
-    which can be used to insert values into the cache specifically, without
-    calling the calculation function.
-    """
-    def __init__(self, orig, max_entries=1000, num_args=1, lru=True,
-                 inlineCallbacks=False):
-        self.orig = orig
-
-        if inlineCallbacks:
-            self.function_to_call = defer.inlineCallbacks(orig)
-        else:
-            self.function_to_call = orig
-
-        self.max_entries = max_entries
-        self.num_args = num_args
-        self.lru = lru
-
-        self.arg_names = inspect.getargspec(orig).args[1:num_args+1]
-
-        if len(self.arg_names) < self.num_args:
-            raise Exception(
-                "Not enough explicit positional arguments to key off of for %r."
-                " (@cached cannot key off of *args or **kwars)"
-                % (orig.__name__,)
-            )
-
-        self.cache = Cache(
-            name=self.orig.__name__,
-            max_entries=self.max_entries,
-            keylen=self.num_args,
-            lru=self.lru,
-        )
-
-    def __get__(self, obj, objtype=None):
-
-        @functools.wraps(self.orig)
-        def wrapped(*args, **kwargs):
-            arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
-            cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
-            try:
-                cached_result_d = self.cache.get(cache_key)
-
-                observer = cached_result_d.observe()
-                if DEBUG_CACHES:
-                    @defer.inlineCallbacks
-                    def check_result(cached_result):
-                        actual_result = yield self.function_to_call(obj, *args, **kwargs)
-                        if actual_result != cached_result:
-                            logger.error(
-                                "Stale cache entry %s%r: cached: %r, actual %r",
-                                self.orig.__name__, cache_key,
-                                cached_result, actual_result,
-                            )
-                            raise ValueError("Stale cache entry")
-                        defer.returnValue(cached_result)
-                    observer.addCallback(check_result)
-
-                return observer
-            except KeyError:
-                # Get the sequence number of the cache before reading from the
-                # database so that we can tell if the cache is invalidated
-                # while the SELECT is executing (SYN-369)
-                sequence = self.cache.sequence
-
-                ret = defer.maybeDeferred(
-                    self.function_to_call,
-                    obj, *args, **kwargs
-                )
-
-                def onErr(f):
-                    self.cache.invalidate(cache_key)
-                    return f
-
-                ret.addErrback(onErr)
-
-                ret = ObservableDeferred(ret, consumeErrors=True)
-                self.cache.update(sequence, cache_key, ret)
-
-                return ret.observe()
-
-        wrapped.invalidate = self.cache.invalidate
-        wrapped.invalidate_all = self.cache.invalidate_all
-        wrapped.prefill = self.cache.prefill
-
-        obj.__dict__[self.orig.__name__] = wrapped
-
-        return wrapped
-
-
-def cached(max_entries=1000, num_args=1, lru=True):
-    return lambda orig: CacheDescriptor(
-        orig,
-        max_entries=max_entries,
-        num_args=num_args,
-        lru=lru
-    )
-
-
-def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False):
-    return lambda orig: CacheDescriptor(
-        orig,
-        max_entries=max_entries,
-        num_args=num_args,
-        lru=lru,
-        inlineCallbacks=True,
-    )
-
 
 class LoggingTransaction(object):
     """An object that almost-transparently proxies for the 'txn' object
@@ -372,6 +167,8 @@ class SQLBaseStore(object):
         self._get_event_cache = Cache("*getEvent*", keylen=3, lru=True,
                                       max_entries=hs.config.event_cache_size)
 
+        self._state_group_cache = DictionaryCache("*stateGroupCache*", 100000)
+
         self._event_fetch_lock = threading.Condition()
         self._event_fetch_list = []
         self._event_fetch_ongoing = 0
diff --git a/synapse/storage/directory.py b/synapse/storage/directory.py
index f3947bbe89..d92028ea43 100644
--- a/synapse/storage/directory.py
+++ b/synapse/storage/directory.py
@@ -13,7 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ._base import SQLBaseStore, cached
+from ._base import SQLBaseStore
+from synapse.util.caches.descriptors import cached
 
 from synapse.api.errors import SynapseError
 
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
index 910b6598a7..25cc84eb95 100644
--- a/synapse/storage/event_federation.py
+++ b/synapse/storage/event_federation.py
@@ -15,7 +15,8 @@
 
 from twisted.internet import defer
 
-from ._base import SQLBaseStore, cached
+from ._base import SQLBaseStore
+from synapse.util.caches.descriptors import cached
 from syutil.base64util import encode_base64
 
 import logging
diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py
index 49b8e37cfd..ffd6daa880 100644
--- a/synapse/storage/keys.py
+++ b/synapse/storage/keys.py
@@ -13,7 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from _base import SQLBaseStore, cachedInlineCallbacks
+from _base import SQLBaseStore
+from synapse.util.caches.descriptors import cachedInlineCallbacks
 
 from twisted.internet import defer
 
diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py
index 576cf670cc..34ca3b9a54 100644
--- a/synapse/storage/presence.py
+++ b/synapse/storage/presence.py
@@ -13,19 +13,23 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ._base import SQLBaseStore, cached
+from ._base import SQLBaseStore
+from synapse.util.caches.descriptors import cached, cachedList
 
 from twisted.internet import defer
 
 
 class PresenceStore(SQLBaseStore):
     def create_presence(self, user_localpart):
-        return self._simple_insert(
+        res = self._simple_insert(
             table="presence",
             values={"user_id": user_localpart},
             desc="create_presence",
         )
 
+        self.get_presence_state.invalidate((user_localpart,))
+        return res
+
     def has_presence_state(self, user_localpart):
         return self._simple_select_one(
             table="presence",
@@ -35,6 +39,7 @@ class PresenceStore(SQLBaseStore):
             desc="has_presence_state",
         )
 
+    @cached(max_entries=2000)
     def get_presence_state(self, user_localpart):
         return self._simple_select_one(
             table="presence",
@@ -43,8 +48,27 @@ class PresenceStore(SQLBaseStore):
             desc="get_presence_state",
         )
 
+    @cachedList(get_presence_state.cache, list_name="user_localparts")
+    def get_presence_states(self, user_localparts):
+        def f(txn):
+            results = {}
+            for user_localpart in user_localparts:
+                res = self._simple_select_one_txn(
+                    txn,
+                    table="presence",
+                    keyvalues={"user_id": user_localpart},
+                    retcols=["state", "status_msg", "mtime"],
+                    allow_none=True,
+                )
+                if res:
+                    results[user_localpart] = res
+
+            return results
+
+        return self.runInteraction("get_presence_states", f)
+
     def set_presence_state(self, user_localpart, new_state):
-        return self._simple_update_one(
+        res = self._simple_update_one(
             table="presence",
             keyvalues={"user_id": user_localpart},
             updatevalues={"state": new_state["state"],
@@ -53,6 +77,9 @@ class PresenceStore(SQLBaseStore):
             desc="set_presence_state",
         )
 
+        self.get_presence_state.invalidate((user_localpart,))
+        return res
+
     def allow_presence_visible(self, observed_localpart, observer_userid):
         return self._simple_insert(
             table="presence_allow_inbound",
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index 9b88ca7b39..5305b7e122 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -13,7 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ._base import SQLBaseStore, cachedInlineCallbacks
+from ._base import SQLBaseStore
+from synapse.util.caches.descriptors import cachedInlineCallbacks
 from twisted.internet import defer
 
 import logging
diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py
index b79d6683ca..cac1a5657e 100644
--- a/synapse/storage/receipts.py
+++ b/synapse/storage/receipts.py
@@ -13,7 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ._base import SQLBaseStore, cachedInlineCallbacks
+from ._base import SQLBaseStore
+from synapse.util.caches.descriptors import cachedInlineCallbacks
 
 from twisted.internet import defer
 
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 4eaa088b36..bf803f2c6e 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -17,7 +17,8 @@ from twisted.internet import defer
 
 from synapse.api.errors import StoreError, Codes
 
-from ._base import SQLBaseStore, cached
+from ._base import SQLBaseStore
+from synapse.util.caches.descriptors import cached
 
 
 class RegistrationStore(SQLBaseStore):
@@ -111,16 +112,16 @@ class RegistrationStore(SQLBaseStore):
         })
 
     @defer.inlineCallbacks
-    def user_delete_access_tokens_apart_from(self, user_id, token_id):
+    def user_delete_access_tokens(self, user_id):
         yield self.runInteraction(
-            "user_delete_access_tokens_apart_from",
-            self._user_delete_access_tokens_apart_from, user_id, token_id
+            "user_delete_access_tokens",
+            self._user_delete_access_tokens, user_id
         )
 
-    def _user_delete_access_tokens_apart_from(self, txn, user_id, token_id):
+    def _user_delete_access_tokens(self, txn, user_id):
         txn.execute(
-            "DELETE FROM access_tokens WHERE user_id = ? AND id != ?",
-            (user_id, token_id)
+            "DELETE FROM access_tokens WHERE user_id = ?",
+            (user_id, )
         )
 
     @defer.inlineCallbacks
diff --git a/synapse/storage/room.py b/synapse/storage/room.py
index dd5bc2c8fb..5e07b7e0e5 100644
--- a/synapse/storage/room.py
+++ b/synapse/storage/room.py
@@ -17,7 +17,8 @@ from twisted.internet import defer
 
 from synapse.api.errors import StoreError
 
-from ._base import SQLBaseStore, cachedInlineCallbacks
+from ._base import SQLBaseStore
+from synapse.util.caches.descriptors import cachedInlineCallbacks
 
 import collections
 import logging
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 9f14f38f24..8eee2dfbcc 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -17,7 +17,8 @@ from twisted.internet import defer
 
 from collections import namedtuple
 
-from ._base import SQLBaseStore, cached
+from ._base import SQLBaseStore
+from synapse.util.caches.descriptors import cached
 
 from synapse.api.constants import Membership
 from synapse.types import UserID
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 7ce51b9bdc..ab3ad5a076 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -13,7 +13,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ._base import SQLBaseStore, cachedInlineCallbacks
+from ._base import SQLBaseStore
+from synapse.util.caches.descriptors import (
+    cached, cachedInlineCallbacks, cachedList
+)
 
 from twisted.internet import defer
 
@@ -44,59 +47,25 @@ class StateStore(SQLBaseStore):
     """
 
     @defer.inlineCallbacks
-    def get_state_groups(self, event_ids):
+    def get_state_groups(self, room_id, event_ids):
         """ Get the state groups for the given list of event_ids
 
         The return value is a dict mapping group names to lists of events.
         """
+        if not event_ids:
+            defer.returnValue({})
 
-        def f(txn):
-            groups = set()
-            for event_id in event_ids:
-                group = self._simple_select_one_onecol_txn(
-                    txn,
-                    table="event_to_state_groups",
-                    keyvalues={"event_id": event_id},
-                    retcol="state_group",
-                    allow_none=True,
-                )
-                if group:
-                    groups.add(group)
-
-            res = {}
-            for group in groups:
-                state_ids = self._simple_select_onecol_txn(
-                    txn,
-                    table="state_groups_state",
-                    keyvalues={"state_group": group},
-                    retcol="event_id",
-                )
-
-                res[group] = state_ids
-
-            return res
-
-        states = yield self.runInteraction(
-            "get_state_groups",
-            f,
+        event_to_groups = yield self._get_state_group_for_events(
+            room_id, event_ids,
         )
 
-        state_list = yield defer.gatherResults(
-            [
-                self._fetch_events_for_group(group, vals)
-                for group, vals in states.items()
-            ],
-            consumeErrors=True,
-        )
+        groups = set(event_to_groups.values())
+        group_to_state = yield self._get_state_for_groups(groups)
 
-        defer.returnValue(dict(state_list))
-
-    def _fetch_events_for_group(self, key, events):
-        return self._get_events(
-            events, get_prev_content=False
-        ).addCallback(
-            lambda evs: (key, evs)
-        )
+        defer.returnValue({
+            group: state_map.values()
+            for group, state_map in group_to_state.items()
+        })
 
     def _store_state_groups_txn(self, txn, event, context):
         return self._store_mult_state_groups_txn(txn, [(event, context)])
@@ -204,64 +173,251 @@ class StateStore(SQLBaseStore):
         events = yield self._get_events(event_ids, get_prev_content=False)
         defer.returnValue(events)
 
+    def _get_state_groups_from_groups(self, groups_and_types):
+        """Returns dictionary state_group -> state event ids
+
+        Args:
+            groups_and_types (list): list of 2-tuple (`group`, `types`)
+        """
+        def f(txn):
+            results = {}
+            for group, types in groups_and_types:
+                if types is not None:
+                    where_clause = "AND (%s)" % (
+                        " OR ".join(["(type = ? AND state_key = ?)"] * len(types)),
+                    )
+                else:
+                    where_clause = ""
+
+                sql = (
+                    "SELECT event_id FROM state_groups_state WHERE"
+                    " state_group = ? %s"
+                ) % (where_clause,)
+
+                args = [group]
+                if types is not None:
+                    args.extend([i for typ in types for i in typ])
+
+                txn.execute(sql, args)
+
+                results[group] = [r[0] for r in txn.fetchall()]
+
+            return results
+
+        return self.runInteraction(
+            "_get_state_groups_from_groups",
+            f,
+        )
+
     @defer.inlineCallbacks
-    def get_state_for_events(self, room_id, event_ids):
+    def get_state_for_events(self, room_id, event_ids, types):
+        """Given a list of event_ids and type tuples, return a dict mapping
+        each event_id to its state dict. The state dicts will only have the
+        type/state_keys that are in the `types` list.
+
+        Args:
+            room_id (str)
+            event_ids (list)
+            types (list): List of (type, state_key) tuples which are used to
+                filter the state fetched. `state_key` may be None, which matches
+                any `state_key`
+
+        Returns:
+            deferred: A dict mapping each given event_id to its state dict.
+            The state dicts are mappings from (type, state_key) -> state_events
+        """
+        event_to_groups = yield self._get_state_group_for_events(
+            room_id, event_ids,
+        )
+
+        groups = set(event_to_groups.values())
+        group_to_state = yield self._get_state_for_groups(groups, types)
+
+        event_to_state = {
+            event_id: group_to_state[group]
+            for event_id, group in event_to_groups.items()
+        }
+
+        defer.returnValue({event: event_to_state[event] for event in event_ids})
+
+    @cached(num_args=2, lru=True, max_entries=10000)
+    def _get_state_group_for_event(self, room_id, event_id):
+        return self._simple_select_one_onecol(
+            table="event_to_state_groups",
+            keyvalues={
+                "event_id": event_id,
+            },
+            retcol="state_group",
+            allow_none=True,
+            desc="_get_state_group_for_event",
+        )
+
+    @cachedList(cache=_get_state_group_for_event.cache, list_name="event_ids",
+                num_args=2)
+    def _get_state_group_for_events(self, room_id, event_ids):
+        """Returns mapping event_id -> state_group
+        """
         def f(txn):
-            groups = set()
-            event_to_group = {}
+            results = {}
             for event_id in event_ids:
-                # TODO: Remove this loop.
-                group = self._simple_select_one_onecol_txn(
+                results[event_id] = self._simple_select_one_onecol_txn(
                     txn,
                     table="event_to_state_groups",
-                    keyvalues={"event_id": event_id},
+                    keyvalues={
+                        "event_id": event_id,
+                    },
                     retcol="state_group",
                     allow_none=True,
                 )
-                if group:
-                    event_to_group[event_id] = group
-                    groups.add(group)
 
-            group_to_state_ids = {}
-            for group in groups:
-                state_ids = self._simple_select_onecol_txn(
-                    txn,
-                    table="state_groups_state",
-                    keyvalues={"state_group": group},
-                    retcol="event_id",
+            return results
+
+        return self.runInteraction("_get_state_group_for_events", f)
+
+    def _get_some_state_from_cache(self, group, types):
+        """Checks if group is in cache. See `_get_state_for_groups`
+
+        Returns 3-tuple (`state_dict`, `missing_types`, `got_all`).
+        `missing_types` is the list of types that aren't in the cache for that
+        group. `got_all` is a bool indicating if we successfully retrieved all
+        requested state from the cache; if False we need to query the DB for the
+        missing state.
+
+        Args:
+            group: The state group to lookup
+            types (list): List of 2-tuples of the form (`type`, `state_key`),
+                where a `state_key` of `None` matches all state_keys for the
+                `type`.
+        """
+        is_all, state_dict = self._state_group_cache.get(group)
+
+        type_to_key = {}
+        missing_types = set()
+        for typ, state_key in types:
+            if state_key is None:
+                type_to_key[typ] = None
+                missing_types.add((typ, state_key))
+            else:
+                if type_to_key.get(typ, object()) is not None:
+                    type_to_key.setdefault(typ, set()).add(state_key)
+
+                if (typ, state_key) not in state_dict:
+                    missing_types.add((typ, state_key))
+
+        sentinel = object()
+
+        def include(typ, state_key):
+            valid_state_keys = type_to_key.get(typ, sentinel)
+            if valid_state_keys is sentinel:
+                return False
+            if valid_state_keys is None:
+                return True
+            if state_key in valid_state_keys:
+                return True
+            return False
+
+        got_all = not (missing_types or types is None)
+
+        return {
+            k: v for k, v in state_dict.items()
+            if include(k[0], k[1])
+        }, missing_types, got_all
+
+    def _get_all_state_from_cache(self, group):
+        """Checks if group is in cache. See `_get_state_for_groups`
+
+        Returns 2-tuple (`state_dict`, `got_all`). `got_all` is a bool
+        indicating if we successfully retrieved all requested state from the
+        cache, if False we need to query the DB for the missing state.
+
+        Args:
+            group: The state group to lookup
+        """
+        is_all, state_dict = self._state_group_cache.get(group)
+        return state_dict, is_all
+
+    @defer.inlineCallbacks
+    def _get_state_for_groups(self, groups, types=None):
+        """Given list of groups returns dict of group -> dict of state events
+        (keyed by `(type, state_key)`) with matching types. `types` is a list
+        of `(type, state_key)`, where a `state_key` of None matches all
+        state_keys. If `types` is None then all events are returned.
+        """
+        results = {}
+        missing_groups_and_types = []
+        if types is not None:
+            for group in set(groups):
+                state_dict, missing_types, got_all = self._get_some_state_from_cache(
+                    group, types
+                )
+                results[group] = state_dict
+
+                if not got_all:
+                    missing_groups_and_types.append((group, missing_types))
+        else:
+            for group in set(groups):
+                state_dict, got_all = self._get_all_state_from_cache(
+                    group
                 )
+                results[group] = state_dict
 
-                group_to_state_ids[group] = state_ids
+                if not got_all:
+                    missing_groups_and_types.append((group, None))
 
-            return event_to_group, group_to_state_ids
+        if not missing_groups_and_types:
+            defer.returnValue({
+                group: {
+                    type_tuple: event
+                    for type_tuple, event in state.items()
+                    if event
+                }
+                for group, state in results.items()
+            })
 
-        res = yield self.runInteraction(
-            "annotate_events_with_state_groups",
-            f,
-        )
+        # Okay, so we have some missing_types, let's fetch them.
+        cache_seq_num = self._state_group_cache.sequence
 
-        event_to_group, group_to_state_ids = res
+        group_state_dict = yield self._get_state_groups_from_groups(
+            missing_groups_and_types
+        )
 
-        state_list = yield defer.gatherResults(
-            [
-                self._fetch_events_for_group(group, vals)
-                for group, vals in group_to_state_ids.items()
-            ],
-            consumeErrors=True,
+        state_events = yield self._get_events(
+            [e_id for l in group_state_dict.values() for e_id in l],
+            get_prev_content=False
         )
 
-        state_dict = {
-            group: {
-                (ev.type, ev.state_key): ev
-                for ev in state
+        state_events = {e.event_id: e for e in state_events}
+
+        # Now we want to update the cache with all the things we fetched
+        # from the database.
+        for group, state_ids in group_state_dict.items():
+            if types:
+                # We deliberately put key -> None mappings into the cache to
+                # cache absence of the key, on the assumption that if we've
+                # explicitly asked for some types then we will probably ask
+                # for them again.
+                state_dict = {key: None for key in types}
+                state_dict.update(results[group])
+            else:
+                state_dict = results[group]
+
+            for event_id in state_ids:
+                state_event = state_events[event_id]
+                state_dict[(state_event.type, state_event.state_key)] = state_event
+
+            self._state_group_cache.update(
+                cache_seq_num,
+                key=group,
+                value=state_dict,
+                full=(types is None),
+            )
+
+            # We replace here to remove all the entries with None values.
+            results[group] = {
+                key: value for key, value in state_dict.items() if value
             }
-            for group, state in state_list
-        }
 
-        defer.returnValue([
-            state_dict.get(event_to_group.get(event, None), None)
-            for event in event_ids
-        ])
+        defer.returnValue(results)
 
 
 def _make_group_id(clock):
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index af45fc5619..d7fe423f5a 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -36,6 +36,7 @@ what sort order was used:
 from twisted.internet import defer
 
 from ._base import SQLBaseStore
+from synapse.util.caches.descriptors import cachedInlineCallbacks
 from synapse.api.constants import EventTypes
 from synapse.types import RoomStreamToken
 from synapse.util.logutils import log_function
@@ -299,9 +300,8 @@ class StreamStore(SQLBaseStore):
 
         defer.returnValue((events, token))
 
-    @defer.inlineCallbacks
-    def get_recent_events_for_room(self, room_id, limit, end_token,
-                                   with_feedback=False, from_token=None):
+    @cachedInlineCallbacks(num_args=4)
+    def get_recent_events_for_room(self, room_id, limit, end_token, from_token=None):
         # TODO (erikj): Handle compressed feedback
 
         end_token = RoomStreamToken.parse_stream_token(end_token)
diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py
index 624da4a9dc..c8c7e6591a 100644
--- a/synapse/storage/transactions.py
+++ b/synapse/storage/transactions.py
@@ -13,7 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ._base import SQLBaseStore, cached
+from ._base import SQLBaseStore
+from synapse.util.caches.descriptors import cached
 
 from collections import namedtuple
 
diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py
new file mode 100644
index 0000000000..da0e06a468
--- /dev/null
+++ b/synapse/util/caches/__init__.py
@@ -0,0 +1,27 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import synapse.metrics
+
+DEBUG_CACHES = False
+
+metrics = synapse.metrics.get_metrics_for("synapse.util.caches")
+
+caches_by_name = {}
+cache_counter = metrics.register_cache(
+    "cache",
+    lambda: {(name,): len(caches_by_name[name]) for name in caches_by_name.keys()},
+    labels=["name"],
+)
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
new file mode 100644
index 0000000000..362944bc51
--- /dev/null
+++ b/synapse/util/caches/descriptors.py
@@ -0,0 +1,377 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from synapse.util.async import ObservableDeferred
+from synapse.util import unwrapFirstError
+from synapse.util.caches.lrucache import LruCache
+
+from . import caches_by_name, DEBUG_CACHES, cache_counter
+
+from twisted.internet import defer
+
+from collections import OrderedDict
+
+import functools
+import inspect
+import threading
+
+logger = logging.getLogger(__name__)
+
+
+_CacheSentinel = object()
+
+
+class Cache(object):
+
+    def __init__(self, name, max_entries=1000, keylen=1, lru=True):
+        if lru:
+            self.cache = LruCache(max_size=max_entries)
+            self.max_entries = None
+        else:
+            self.cache = OrderedDict()
+            self.max_entries = max_entries
+
+        self.name = name
+        self.keylen = keylen
+        self.sequence = 0
+        self.thread = None
+        caches_by_name[name] = self.cache
+
+    def check_thread(self):
+        expected_thread = self.thread
+        if expected_thread is None:
+            self.thread = threading.current_thread()
+        else:
+            if expected_thread is not threading.current_thread():
+                raise ValueError(
+                    "Cache objects can only be accessed from the main thread"
+                )
+
+    def get(self, key, default=_CacheSentinel):
+        val = self.cache.get(key, _CacheSentinel)
+        if val is not _CacheSentinel:
+            cache_counter.inc_hits(self.name)
+            return val
+
+        cache_counter.inc_misses(self.name)
+
+        if default is _CacheSentinel:
+            raise KeyError()
+        else:
+            return default
+
+    def update(self, sequence, key, value):
+        self.check_thread()
+        if self.sequence == sequence:
+            # Only update the cache if the cache's sequence number matches the
+            # number that the cache had before the SELECT was started (SYN-369)
+            self.prefill(key, value)
+
+    def prefill(self, key, value):
+        if self.max_entries is not None:
+            while len(self.cache) >= self.max_entries:
+                self.cache.popitem(last=False)
+
+        self.cache[key] = value
+
+    def invalidate(self, key):
+        self.check_thread()
+        if not isinstance(key, tuple):
+            raise TypeError(
+                "The cache key must be a tuple not %r" % (type(key),)
+            )
+
+        # Increment the sequence number so that any SELECT statements that
+        # raced with the INSERT don't update the cache (SYN-369)
+        self.sequence += 1
+        self.cache.pop(key, None)
+
+    def invalidate_all(self):
+        self.check_thread()
+        self.sequence += 1
+        self.cache.clear()
+
+
+class CacheDescriptor(object):
+    """ A method decorator that applies a memoizing cache around the function.
+
+    This caches deferreds, rather than the results themselves. Deferreds that
+    fail are removed from the cache.
+
+    The function is presumed to take zero or more arguments, which are used in
+    a tuple as the key for the cache. Hits are served directly from the cache;
+    misses use the function body to generate the value.
+
+    The wrapped function has an additional member, a callable called
+    "invalidate". This can be used to remove individual entries from the cache.
+
+    The wrapped function has another additional callable, called "prefill",
+    which can be used to insert values into the cache specifically, without
+    calling the calculation function.
+    """
+    def __init__(self, orig, max_entries=1000, num_args=1, lru=True,
+                 inlineCallbacks=False):
+        self.orig = orig
+
+        if inlineCallbacks:
+            self.function_to_call = defer.inlineCallbacks(orig)
+        else:
+            self.function_to_call = orig
+
+        self.max_entries = max_entries
+        self.num_args = num_args
+        self.lru = lru
+
+        self.arg_names = inspect.getargspec(orig).args[1:num_args+1]
+
+        if len(self.arg_names) < self.num_args:
+            raise Exception(
+                "Not enough explicit positional arguments to key off of for %r."
+                " (@cached cannot key off of *args or **kwars)"
+                % (orig.__name__,)
+            )
+
+        self.cache = Cache(
+            name=self.orig.__name__,
+            max_entries=self.max_entries,
+            keylen=self.num_args,
+            lru=self.lru,
+        )
+
+    def __get__(self, obj, objtype=None):
+
+        @functools.wraps(self.orig)
+        def wrapped(*args, **kwargs):
+            arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
+            cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
+            try:
+                cached_result_d = self.cache.get(cache_key)
+
+                observer = cached_result_d.observe()
+                if DEBUG_CACHES:
+                    @defer.inlineCallbacks
+                    def check_result(cached_result):
+                        actual_result = yield self.function_to_call(obj, *args, **kwargs)
+                        if actual_result != cached_result:
+                            logger.error(
+                                "Stale cache entry %s%r: cached: %r, actual %r",
+                                self.orig.__name__, cache_key,
+                                cached_result, actual_result,
+                            )
+                            raise ValueError("Stale cache entry")
+                        defer.returnValue(cached_result)
+                    observer.addCallback(check_result)
+
+                return observer
+            except KeyError:
+                # Get the sequence number of the cache before reading from the
+                # database so that we can tell if the cache is invalidated
+                # while the SELECT is executing (SYN-369)
+                sequence = self.cache.sequence
+
+                ret = defer.maybeDeferred(
+                    self.function_to_call,
+                    obj, *args, **kwargs
+                )
+
+                def onErr(f):
+                    self.cache.invalidate(cache_key)
+                    return f
+
+                ret.addErrback(onErr)
+
+                ret = ObservableDeferred(ret, consumeErrors=True)
+                self.cache.update(sequence, cache_key, ret)
+
+                return ret.observe()
+
+        wrapped.invalidate = self.cache.invalidate
+        wrapped.invalidate_all = self.cache.invalidate_all
+        wrapped.prefill = self.cache.prefill
+
+        obj.__dict__[self.orig.__name__] = wrapped
+
+        return wrapped
+
+
+class CacheListDescriptor(object):
+    """Wraps an existing cache to support bulk fetching of keys.
+
+    Given a list of keys it looks in the cache to find any hits, then passes
+    the list of missing keys to the wrapped function.
+    """
+
+    def __init__(self, orig, cache, list_name, num_args=1, inlineCallbacks=False):
+        """
+        Args:
+            orig (function)
+            cache (Cache)
+            list_name (str): Name of the argument which is the bulk lookup list
+            num_args (int)
+            inlineCallbacks (bool): Whether orig is a generator that should
+                be wrapped by defer.inlineCallbacks
+        """
+        self.orig = orig
+
+        if inlineCallbacks:
+            self.function_to_call = defer.inlineCallbacks(orig)
+        else:
+            self.function_to_call = orig
+
+        self.num_args = num_args
+        self.list_name = list_name
+
+        self.arg_names = inspect.getargspec(orig).args[1:num_args+1]
+        self.list_pos = self.arg_names.index(self.list_name)
+
+        self.cache = cache
+
+        self.sentinel = object()
+
+        if len(self.arg_names) < self.num_args:
+            raise Exception(
+                "Not enough explicit positional arguments to key off of for %r."
+                " (@cached cannot key off of *args or **kwars)"
+                % (orig.__name__,)
+            )
+
+        if self.list_name not in self.arg_names:
+            raise Exception(
+                "Couldn't see arguments %r for %r."
+                % (self.list_name, cache.name,)
+            )
+
+    def __get__(self, obj, objtype=None):
+
+        @functools.wraps(self.orig)
+        def wrapped(*args, **kwargs):
+            arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
+            keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
+            list_args = arg_dict[self.list_name]
+
+            # cached is a dict arg -> deferred, where deferred results in a
+            # 2-tuple (`arg`, `result`)
+            cached = {}
+            missing = []
+            for arg in list_args:
+                key = list(keyargs)
+                key[self.list_pos] = arg
+
+                try:
+                    res = self.cache.get(tuple(key)).observe()
+                    res.addCallback(lambda r, arg: (arg, r), arg)
+                    cached[arg] = res
+                except KeyError:
+                    missing.append(arg)
+
+            if missing:
+                sequence = self.cache.sequence
+                args_to_call = dict(arg_dict)
+                args_to_call[self.list_name] = missing
+
+                ret_d = defer.maybeDeferred(
+                    self.function_to_call,
+                    **args_to_call
+                )
+
+                ret_d = ObservableDeferred(ret_d)
+
+                # We need to create deferreds for each arg in the list so that
+                # we can insert the new deferred into the cache.
+                for arg in missing:
+                    observer = ret_d.observe()
+                    observer.addCallback(lambda r, arg: r.get(arg, None), arg)
+
+                    observer = ObservableDeferred(observer)
+
+                    key = list(keyargs)
+                    key[self.list_pos] = arg
+                    self.cache.update(sequence, tuple(key), observer)
+
+                    def invalidate(f, key):
+                        self.cache.invalidate(key)
+                        return f
+                    observer.addErrback(invalidate, tuple(key))
+
+                    res = observer.observe()
+                    res.addCallback(lambda r, arg: (arg, r), arg)
+
+                    cached[arg] = res
+
+            return defer.gatherResults(
+                cached.values(),
+                consumeErrors=True,
+            ).addErrback(unwrapFirstError).addCallback(lambda res: dict(res))
+
+        obj.__dict__[self.orig.__name__] = wrapped
+
+        return wrapped
+
+
+def cached(max_entries=1000, num_args=1, lru=True):
+    return lambda orig: CacheDescriptor(
+        orig,
+        max_entries=max_entries,
+        num_args=num_args,
+        lru=lru
+    )
+
+
+def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False):
+    return lambda orig: CacheDescriptor(
+        orig,
+        max_entries=max_entries,
+        num_args=num_args,
+        lru=lru,
+        inlineCallbacks=True,
+    )
+
+
+def cachedList(cache, list_name, num_args=1, inlineCallbacks=False):
+    """Creates a descriptor that wraps a function in a `CacheListDescriptor`.
+
+    Used to do batch lookups for an already created cache. A single argument
+    is specified as a list that is iterated through to lookup keys in the
+    original cache. A new list consisting of the keys that weren't in the cache
+    get passed to the original function, the result of which is stored in the
+    cache.
+
+    Args:
+        cache (Cache): The underlying cache to use.
+        list_name (str): The name of the argument that is the list to use to
+            do batch lookups in the cache.
+        num_args (int): Number of arguments to use as the key in the cache.
+        inlineCallbacks (bool): Should the function be wrapped in a
+            `defer.inlineCallbacks`?
+
+    Example:
+
+        class Example(object):
+            @cached(num_args=2)
+            def do_something(self, first_arg, second_arg):
+                ...
+
+            @cachedList(do_something.cache, list_name="second_args", num_args=2)
+            def batch_do_something(self, first_arg, second_args):
+                ...
+    """
+    return lambda orig: CacheListDescriptor(
+        orig,
+        cache=cache,
+        list_name=list_name,
+        num_args=num_args,
+        inlineCallbacks=inlineCallbacks,
+    )
diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py
new file mode 100644
index 0000000000..e69adf62fe
--- /dev/null
+++ b/synapse/util/caches/dictionary_cache.py
@@ -0,0 +1,103 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.util.caches.lrucache import LruCache
+from collections import namedtuple
+from . import caches_by_name, cache_counter
+import threading
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+
+DictionaryEntry = namedtuple("DictionaryEntry", ("full", "value"))
+
+
+class DictionaryCache(object):
+    """Caches key -> dictionary lookups, supporting caching partial dicts, i.e.
+    fetching a subset of dictionary keys for a particular key.
+    """
+
+    def __init__(self, name, max_entries=1000):
+        self.cache = LruCache(max_size=max_entries)
+
+        self.name = name
+        self.sequence = 0
+        self.thread = None
+        # caches_by_name[name] = self.cache
+
+        class Sentinel(object):
+            __slots__ = []
+
+        self.sentinel = Sentinel()
+        caches_by_name[name] = self.cache
+
+    def check_thread(self):
+        expected_thread = self.thread
+        if expected_thread is None:
+            self.thread = threading.current_thread()
+        else:
+            if expected_thread is not threading.current_thread():
+                raise ValueError(
+                    "Cache objects can only be accessed from the main thread"
+                )
+
+    def get(self, key, dict_keys=None):
+        entry = self.cache.get(key, self.sentinel)
+        if entry is not self.sentinel:
+            cache_counter.inc_hits(self.name)
+
+            if dict_keys is None:
+                return DictionaryEntry(entry.full, dict(entry.value))
+            else:
+                return DictionaryEntry(entry.full, {
+                    k: entry.value[k]
+                    for k in dict_keys
+                    if k in entry.value
+                })
+
+        cache_counter.inc_misses(self.name)
+        return DictionaryEntry(False, {})
+
+    def invalidate(self, key):
+        self.check_thread()
+
+        # Increment the sequence number so that any SELECT statements that
+        # raced with the INSERT don't update the cache (SYN-369)
+        self.sequence += 1
+        self.cache.pop(key, None)
+
+    def invalidate_all(self):
+        self.check_thread()
+        self.sequence += 1
+        self.cache.clear()
+
+    def update(self, sequence, key, value, full=False):
+        self.check_thread()
+        if self.sequence == sequence:
+            # Only update the cache if the cache's sequence number matches the
+            # number that the cache had before the SELECT was started (SYN-369)
+            if full:
+                self._insert(key, value)
+            else:
+                self._update_or_insert(key, value)
+
+    def _update_or_insert(self, key, value):
+        entry = self.cache.setdefault(key, DictionaryEntry(False, {}))
+        entry.value.update(value)
+
+    def _insert(self, key, value):
+        self.cache[key] = DictionaryEntry(True, value)
diff --git a/synapse/util/expiringcache.py b/synapse/util/caches/expiringcache.py
index 06d1eea01b..06d1eea01b 100644
--- a/synapse/util/expiringcache.py
+++ b/synapse/util/caches/expiringcache.py
diff --git a/synapse/util/lrucache.py b/synapse/util/caches/lrucache.py
index cacd7e45fa..cacd7e45fa 100644
--- a/synapse/util/lrucache.py
+++ b/synapse/util/caches/lrucache.py
diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index abee2f631d..e72cace8ff 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -19,7 +19,7 @@ from twisted.internet import defer
 
 from synapse.util.async import ObservableDeferred
 
-from synapse.storage._base import Cache, cached
+from synapse.util.caches.descriptors import Cache, cached
 
 
 class CacheTestCase(unittest.TestCase):
diff --git a/tests/test_state.py b/tests/test_state.py
index fea25f7021..5845358754 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -69,7 +69,7 @@ class StateGroupStore(object):
 
         self._next_group = 1
 
-    def get_state_groups(self, event_ids):
+    def get_state_groups(self, room_id, event_ids):
         groups = {}
         for event_id in event_ids:
             group = self._event_to_state_group.get(event_id)
diff --git a/tests/util/test_dict_cache.py b/tests/util/test_dict_cache.py
new file mode 100644
index 0000000000..54ff26cd97
--- /dev/null
+++ b/tests/util/test_dict_cache.py
@@ -0,0 +1,101 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from twisted.internet import defer
+from tests import unittest
+
+from synapse.util.caches.dictionary_cache import DictionaryCache
+
+
+class DictCacheTestCase(unittest.TestCase):
+
+    def setUp(self):
+        self.cache = DictionaryCache("foobar")
+
+    def test_simple_cache_hit_full(self):
+        key = "test_simple_cache_hit_full"
+
+        v = self.cache.get(key)
+        self.assertEqual((False, {}), v)
+
+        seq = self.cache.sequence
+        test_value = {"test": "test_simple_cache_hit_full"}
+        self.cache.update(seq, key, test_value, full=True)
+
+        c = self.cache.get(key)
+        self.assertEqual(test_value, c.value)
+
+    def test_simple_cache_hit_partial(self):
+        key = "test_simple_cache_hit_partial"
+
+        seq = self.cache.sequence
+        test_value = {
+            "test": "test_simple_cache_hit_partial"
+        }
+        self.cache.update(seq, key, test_value, full=True)
+
+        c = self.cache.get(key, ["test"])
+        self.assertEqual(test_value, c.value)
+
+    def test_simple_cache_miss_partial(self):
+        key = "test_simple_cache_miss_partial"
+
+        seq = self.cache.sequence
+        test_value = {
+            "test": "test_simple_cache_miss_partial"
+        }
+        self.cache.update(seq, key, test_value, full=True)
+
+        c = self.cache.get(key, ["test2"])
+        self.assertEqual({}, c.value)
+
+    def test_simple_cache_hit_miss_partial(self):
+        key = "test_simple_cache_hit_miss_partial"
+
+        seq = self.cache.sequence
+        test_value = {
+            "test": "test_simple_cache_hit_miss_partial",
+            "test2": "test_simple_cache_hit_miss_partial2",
+            "test3": "test_simple_cache_hit_miss_partial3",
+        }
+        self.cache.update(seq, key, test_value, full=True)
+
+        c = self.cache.get(key, ["test2"])
+        self.assertEqual({"test2": "test_simple_cache_hit_miss_partial2"}, c.value)
+
+    def test_multi_insert(self):
+        key = "test_simple_cache_hit_miss_partial"
+
+        seq = self.cache.sequence
+        test_value_1 = {
+            "test": "test_simple_cache_hit_miss_partial",
+        }
+        self.cache.update(seq, key, test_value_1, full=False)
+
+        seq = self.cache.sequence
+        test_value_2 = {
+            "test2": "test_simple_cache_hit_miss_partial2",
+        }
+        self.cache.update(seq, key, test_value_2, full=False)
+
+        c = self.cache.get(key)
+        self.assertEqual(
+            {
+                "test": "test_simple_cache_hit_miss_partial",
+                "test2": "test_simple_cache_hit_miss_partial2",
+            },
+            c.value
+        )
diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py
index ab934bf928..fc5a904323 100644
--- a/tests/util/test_lrucache.py
+++ b/tests/util/test_lrucache.py
@@ -16,7 +16,7 @@
 
 from .. import unittest
 
-from synapse.util.lrucache import LruCache
+from synapse.util.caches.lrucache import LruCache
 
 class LruCacheTestCase(unittest.TestCase):
 
@@ -52,5 +52,3 @@ class LruCacheTestCase(unittest.TestCase):
         cache["key"] = 1
         self.assertEquals(cache.pop("key"), 1)
         self.assertEquals(cache.pop("key"), None)
-
-