summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/api/auth.py581
-rw-r--r--synapse/event_auth.py678
-rw-r--r--synapse/events/__init__.py8
-rw-r--r--synapse/events/builder.py6
-rw-r--r--synapse/federation/federation_client.py6
-rw-r--r--synapse/federation/transaction_queue.py2
-rw-r--r--synapse/handlers/federation.py6
-rw-r--r--synapse/push/push_tools.py2
-rw-r--r--synapse/rest/client/v1/login.py8
-rw-r--r--synapse/rest/client/v2_alpha/account.py5
-rw-r--r--synapse/state.py431
-rw-r--r--synapse/storage/_base.py2
-rw-r--r--synapse/storage/deviceinbox.py35
-rw-r--r--synapse/storage/events.py4
-rw-r--r--synapse/storage/prepare_database.py2
-rw-r--r--synapse/storage/roommember.py3
-rw-r--r--synapse/storage/schema/delta/40/device_inbox.sql21
-rw-r--r--synapse/storage/state.py2
-rw-r--r--synapse/util/caches/__init__.py11
-rw-r--r--synapse/util/caches/descriptors.py109
-rw-r--r--synapse/util/caches/dictionary_cache.py6
-rw-r--r--synapse/util/caches/expiringcache.py41
-rw-r--r--synapse/util/caches/lrucache.py61
-rw-r--r--synapse/util/caches/treecache.py17
-rw-r--r--tests/api/test_filtering.py5
-rw-r--r--tests/events/test_utils.py22
-rw-r--r--tests/storage/test__base.py6
-rw-r--r--tests/util/test_expiring_cache.py84
-rw-r--r--tests/util/test_lrucache.py55
29 files changed, 1381 insertions, 838 deletions
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index f93e45a744..03a215ab1b 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -16,18 +16,14 @@
 import logging
 
 import pymacaroons
-from canonicaljson import encode_canonical_json
-from signedjson.key import decode_verify_key_bytes
-from signedjson.sign import verify_signed_json, SignatureVerifyException
 from twisted.internet import defer
-from unpaddedbase64 import decode_base64
 
 import synapse.types
+from synapse import event_auth
 from synapse.api.constants import EventTypes, Membership, JoinRules
-from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError
-from synapse.types import UserID, get_domain_from_id
+from synapse.api.errors import AuthError, Codes
+from synapse.types import UserID
 from synapse.util.logcontext import preserve_context_over_fn
-from synapse.util.logutils import log_function
 from synapse.util.metrics import Measure
 
 logger = logging.getLogger(__name__)
@@ -78,147 +74,7 @@ class Auth(object):
             True if the auth checks pass.
         """
         with Measure(self.clock, "auth.check"):
-            self.check_size_limits(event)
-
-            if not hasattr(event, "room_id"):
-                raise AuthError(500, "Event has no room_id: %s" % event)
-
-            if do_sig_check:
-                sender_domain = get_domain_from_id(event.sender)
-                event_id_domain = get_domain_from_id(event.event_id)
-
-                is_invite_via_3pid = (
-                    event.type == EventTypes.Member
-                    and event.membership == Membership.INVITE
-                    and "third_party_invite" in event.content
-                )
-
-                # Check the sender's domain has signed the event
-                if not event.signatures.get(sender_domain):
-                    # We allow invites via 3pid to have a sender from a different
-                    # HS, as the sender must match the sender of the original
-                    # 3pid invite. This is checked further down with the
-                    # other dedicated membership checks.
-                    if not is_invite_via_3pid:
-                        raise AuthError(403, "Event not signed by sender's server")
-
-                # Check the event_id's domain has signed the event
-                if not event.signatures.get(event_id_domain):
-                    raise AuthError(403, "Event not signed by sending server")
-
-            if auth_events is None:
-                # Oh, we don't know what the state of the room was, so we
-                # are trusting that this is allowed (at least for now)
-                logger.warn("Trusting event: %s", event.event_id)
-                return True
-
-            if event.type == EventTypes.Create:
-                room_id_domain = get_domain_from_id(event.room_id)
-                if room_id_domain != sender_domain:
-                    raise AuthError(
-                        403,
-                        "Creation event's room_id domain does not match sender's"
-                    )
-                # FIXME
-                return True
-
-            creation_event = auth_events.get((EventTypes.Create, ""), None)
-
-            if not creation_event:
-                raise SynapseError(
-                    403,
-                    "Room %r does not exist" % (event.room_id,)
-                )
-
-            creating_domain = get_domain_from_id(event.room_id)
-            originating_domain = get_domain_from_id(event.sender)
-            if creating_domain != originating_domain:
-                if not self.can_federate(event, auth_events):
-                    raise AuthError(
-                        403,
-                        "This room has been marked as unfederatable."
-                    )
-
-            # FIXME: Temp hack
-            if event.type == EventTypes.Aliases:
-                if not event.is_state():
-                    raise AuthError(
-                        403,
-                        "Alias event must be a state event",
-                    )
-                if not event.state_key:
-                    raise AuthError(
-                        403,
-                        "Alias event must have non-empty state_key"
-                    )
-                sender_domain = get_domain_from_id(event.sender)
-                if event.state_key != sender_domain:
-                    raise AuthError(
-                        403,
-                        "Alias event's state_key does not match sender's domain"
-                    )
-                return True
-
-            logger.debug(
-                "Auth events: %s",
-                [a.event_id for a in auth_events.values()]
-            )
-
-            if event.type == EventTypes.Member:
-                allowed = self.is_membership_change_allowed(
-                    event, auth_events
-                )
-                if allowed:
-                    logger.debug("Allowing! %s", event)
-                else:
-                    logger.debug("Denying! %s", event)
-                return allowed
-
-            self.check_event_sender_in_room(event, auth_events)
-
-            # Special case to allow m.room.third_party_invite events wherever
-            # a user is allowed to issue invites.  Fixes
-            # https://github.com/vector-im/vector-web/issues/1208 hopefully
-            if event.type == EventTypes.ThirdPartyInvite:
-                user_level = self._get_user_power_level(event.user_id, auth_events)
-                invite_level = self._get_named_level(auth_events, "invite", 0)
-
-                if user_level < invite_level:
-                    raise AuthError(
-                        403, (
-                            "You cannot issue a third party invite for %s." %
-                            (event.content.display_name,)
-                        )
-                    )
-                else:
-                    return True
-
-            self._can_send_event(event, auth_events)
-
-            if event.type == EventTypes.PowerLevels:
-                self._check_power_levels(event, auth_events)
-
-            if event.type == EventTypes.Redaction:
-                self.check_redaction(event, auth_events)
-
-            logger.debug("Allowing! %s", event)
-
-    def check_size_limits(self, event):
-        def too_big(field):
-            raise EventSizeError("%s too large" % (field,))
-
-        if len(event.user_id) > 255:
-            too_big("user_id")
-        if len(event.room_id) > 255:
-            too_big("room_id")
-        if event.is_state() and len(event.state_key) > 255:
-            too_big("state_key")
-        if len(event.type) > 255:
-            too_big("type")
-        if len(event.event_id) > 255:
-            too_big("event_id")
-        if len(encode_canonical_json(event.get_pdu_json())) > 65536:
-            too_big("event")
+            event_auth.check(event, auth_events, do_sig_check=do_sig_check)
 
     @defer.inlineCallbacks
     def check_joined_room(self, room_id, user_id, current_state=None):
@@ -290,7 +146,7 @@ class Auth(object):
         with Measure(self.clock, "check_host_in_room"):
             latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
 
-            logger.info("calling resolve_state_groups from check_host_in_room")
+            logger.debug("calling resolve_state_groups from check_host_in_room")
             entry = yield self.state.resolve_state_groups(
                 room_id, latest_event_ids
             )
@@ -300,16 +156,6 @@ class Auth(object):
             )
             defer.returnValue(ret)
 
-    def check_event_sender_in_room(self, event, auth_events):
-        key = (EventTypes.Member, event.user_id, )
-        member_event = auth_events.get(key)
-
-        return self._check_joined_room(
-            member_event,
-            event.user_id,
-            event.room_id
-        )
-
     def _check_joined_room(self, member, user_id, room_id):
         if not member or member.membership != Membership.JOIN:
             raise AuthError(403, "User %s not in room %s (%s)" % (
@@ -321,267 +167,8 @@ class Auth(object):
 
         return creation_event.content.get("m.federate", True) is True
 
-    @log_function
-    def is_membership_change_allowed(self, event, auth_events):
-        membership = event.content["membership"]
-
-        # Check if this is the room creator joining:
-        if len(event.prev_events) == 1 and Membership.JOIN == membership:
-            # Get room creation event:
-            key = (EventTypes.Create, "", )
-            create = auth_events.get(key)
-            if create and event.prev_events[0][0] == create.event_id:
-                if create.content["creator"] == event.state_key:
-                    return True
-
-        target_user_id = event.state_key
-
-        creating_domain = get_domain_from_id(event.room_id)
-        target_domain = get_domain_from_id(target_user_id)
-        if creating_domain != target_domain:
-            if not self.can_federate(event, auth_events):
-                raise AuthError(
-                    403,
-                    "This room has been marked as unfederatable."
-                )
-
-        # get info about the caller
-        key = (EventTypes.Member, event.user_id, )
-        caller = auth_events.get(key)
-
-        caller_in_room = caller and caller.membership == Membership.JOIN
-        caller_invited = caller and caller.membership == Membership.INVITE
-
-        # get info about the target
-        key = (EventTypes.Member, target_user_id, )
-        target = auth_events.get(key)
-
-        target_in_room = target and target.membership == Membership.JOIN
-        target_banned = target and target.membership == Membership.BAN
-
-        key = (EventTypes.JoinRules, "", )
-        join_rule_event = auth_events.get(key)
-        if join_rule_event:
-            join_rule = join_rule_event.content.get(
-                "join_rule", JoinRules.INVITE
-            )
-        else:
-            join_rule = JoinRules.INVITE
-
-        user_level = self._get_user_power_level(event.user_id, auth_events)
-        target_level = self._get_user_power_level(
-            target_user_id, auth_events
-        )
-
-        # FIXME (erikj): What should we do here as the default?
-        ban_level = self._get_named_level(auth_events, "ban", 50)
-
-        logger.debug(
-            "is_membership_change_allowed: %s",
-            {
-                "caller_in_room": caller_in_room,
-                "caller_invited": caller_invited,
-                "target_banned": target_banned,
-                "target_in_room": target_in_room,
-                "membership": membership,
-                "join_rule": join_rule,
-                "target_user_id": target_user_id,
-                "event.user_id": event.user_id,
-            }
-        )
-
-        if Membership.INVITE == membership and "third_party_invite" in event.content:
-            if not self._verify_third_party_invite(event, auth_events):
-                raise AuthError(403, "You are not invited to this room.")
-            if target_banned:
-                raise AuthError(
-                    403, "%s is banned from the room" % (target_user_id,)
-                )
-            return True
-
-        if Membership.JOIN != membership:
-            if (caller_invited
-                    and Membership.LEAVE == membership
-                    and target_user_id == event.user_id):
-                return True
-
-            if not caller_in_room:  # caller isn't joined
-                raise AuthError(
-                    403,
-                    "%s not in room %s." % (event.user_id, event.room_id,)
-                )
-
-        if Membership.INVITE == membership:
-            # TODO (erikj): We should probably handle this more intelligently
-            # PRIVATE join rules.
-
-            # Invites are valid iff caller is in the room and target isn't.
-            if target_banned:
-                raise AuthError(
-                    403, "%s is banned from the room" % (target_user_id,)
-                )
-            elif target_in_room:  # the target is already in the room.
-                raise AuthError(403, "%s is already in the room." %
-                                     target_user_id)
-            else:
-                invite_level = self._get_named_level(auth_events, "invite", 0)
-
-                if user_level < invite_level:
-                    raise AuthError(
-                        403, "You cannot invite user %s." % target_user_id
-                    )
-        elif Membership.JOIN == membership:
-            # Joins are valid iff caller == target and they were:
-            # invited: They are accepting the invitation
-            # joined: It's a NOOP
-            if event.user_id != target_user_id:
-                raise AuthError(403, "Cannot force another user to join.")
-            elif target_banned:
-                raise AuthError(403, "You are banned from this room")
-            elif join_rule == JoinRules.PUBLIC:
-                pass
-            elif join_rule == JoinRules.INVITE:
-                if not caller_in_room and not caller_invited:
-                    raise AuthError(403, "You are not invited to this room.")
-            else:
-                # TODO (erikj): may_join list
-                # TODO (erikj): private rooms
-                raise AuthError(403, "You are not allowed to join this room")
-        elif Membership.LEAVE == membership:
-            # TODO (erikj): Implement kicks.
-            if target_banned and user_level < ban_level:
-                raise AuthError(
-                    403, "You cannot unban user &s." % (target_user_id,)
-                )
-            elif target_user_id != event.user_id:
-                kick_level = self._get_named_level(auth_events, "kick", 50)
-
-                if user_level < kick_level or user_level <= target_level:
-                    raise AuthError(
-                        403, "You cannot kick user %s." % target_user_id
-                    )
-        elif Membership.BAN == membership:
-            if user_level < ban_level or user_level <= target_level:
-                raise AuthError(403, "You don't have permission to ban")
-        else:
-            raise AuthError(500, "Unknown membership %s" % membership)
-
-        return True
-
-    def _verify_third_party_invite(self, event, auth_events):
-        """
-        Validates that the invite event is authorized by a previous third-party invite.
-
-        Checks that the public key, and keyserver, match those in the third party invite,
-        and that the invite event has a signature issued using that public key.
-
-        Args:
-            event: The m.room.member join event being validated.
-            auth_events: All relevant previous context events which may be used
-                for authorization decisions.
-
-        Return:
-            True if the event fulfills the expectations of a previous third party
-            invite event.
-        """
-        if "third_party_invite" not in event.content:
-            return False
-        if "signed" not in event.content["third_party_invite"]:
-            return False
-        signed = event.content["third_party_invite"]["signed"]
-        for key in {"mxid", "token"}:
-            if key not in signed:
-                return False
-
-        token = signed["token"]
-
-        invite_event = auth_events.get(
-            (EventTypes.ThirdPartyInvite, token,)
-        )
-        if not invite_event:
-            return False
-
-        if invite_event.sender != event.sender:
-            return False
-
-        if event.user_id != invite_event.user_id:
-            return False
-
-        if signed["mxid"] != event.state_key:
-            return False
-        if signed["token"] != token:
-            return False
-
-        for public_key_object in self.get_public_keys(invite_event):
-            public_key = public_key_object["public_key"]
-            try:
-                for server, signature_block in signed["signatures"].items():
-                    for key_name, encoded_signature in signature_block.items():
-                        if not key_name.startswith("ed25519:"):
-                            continue
-                        verify_key = decode_verify_key_bytes(
-                            key_name,
-                            decode_base64(public_key)
-                        )
-                        verify_signed_json(signed, server, verify_key)
-
-                        # We got the public key from the invite, so we know that the
-                        # correct server signed the signed bundle.
-                        # The caller is responsible for checking that the signing
-                        # server has not revoked that public key.
-                        return True
-            except (KeyError, SignatureVerifyException,):
-                continue
-        return False
-
     def get_public_keys(self, invite_event):
-        public_keys = []
-        if "public_key" in invite_event.content:
-            o = {
-                "public_key": invite_event.content["public_key"],
-            }
-            if "key_validity_url" in invite_event.content:
-                o["key_validity_url"] = invite_event.content["key_validity_url"]
-            public_keys.append(o)
-        public_keys.extend(invite_event.content.get("public_keys", []))
-        return public_keys
-
-    def _get_power_level_event(self, auth_events):
-        key = (EventTypes.PowerLevels, "", )
-        return auth_events.get(key)
-
-    def _get_user_power_level(self, user_id, auth_events):
-        power_level_event = self._get_power_level_event(auth_events)
-
-        if power_level_event:
-            level = power_level_event.content.get("users", {}).get(user_id)
-            if not level:
-                level = power_level_event.content.get("users_default", 0)
-
-            if level is None:
-                return 0
-            else:
-                return int(level)
-        else:
-            key = (EventTypes.Create, "", )
-            create_event = auth_events.get(key)
-            if (create_event is not None and
-                    create_event.content["creator"] == user_id):
-                return 100
-            else:
-                return 0
-
-    def _get_named_level(self, auth_events, name, default):
-        power_level_event = self._get_power_level_event(auth_events)
-
-        if not power_level_event:
-            return default
-
-        level = power_level_event.content.get(name, None)
-        if level is not None:
-            return int(level)
-        else:
-            return default
+        return event_auth.get_public_keys(invite_event)
 
     @defer.inlineCallbacks
     def get_user_by_req(self, request, allow_guest=False, rights="access"):
@@ -974,56 +561,6 @@ class Auth(object):
 
         defer.returnValue(auth_ids)
 
-    def _get_send_level(self, etype, state_key, auth_events):
-        key = (EventTypes.PowerLevels, "", )
-        send_level_event = auth_events.get(key)
-        send_level = None
-        if send_level_event:
-            send_level = send_level_event.content.get("events", {}).get(
-                etype
-            )
-            if send_level is None:
-                if state_key is not None:
-                    send_level = send_level_event.content.get(
-                        "state_default", 50
-                    )
-                else:
-                    send_level = send_level_event.content.get(
-                        "events_default", 0
-                    )
-
-        if send_level:
-            send_level = int(send_level)
-        else:
-            send_level = 0
-
-        return send_level
-
-    @log_function
-    def _can_send_event(self, event, auth_events):
-        send_level = self._get_send_level(
-            event.type, event.get("state_key", None), auth_events
-        )
-        user_level = self._get_user_power_level(event.user_id, auth_events)
-
-        if user_level < send_level:
-            raise AuthError(
-                403,
-                "You don't have permission to post that to the room. " +
-                "user_level (%d) < send_level (%d)" % (user_level, send_level)
-            )
-
-        # Check state_key
-        if hasattr(event, "state_key"):
-            if event.state_key.startswith("@"):
-                if event.state_key != event.user_id:
-                    raise AuthError(
-                        403,
-                        "You are not allowed to set others state"
-                    )
-
-        return True
-
     def check_redaction(self, event, auth_events):
         """Check whether the event sender is allowed to redact the target event.
 
@@ -1037,107 +574,7 @@ class Auth(object):
             AuthError if the event sender is definitely not allowed to redact
             the target event.
         """
-        user_level = self._get_user_power_level(event.user_id, auth_events)
-
-        redact_level = self._get_named_level(auth_events, "redact", 50)
-
-        if user_level >= redact_level:
-            return False
-
-        redacter_domain = get_domain_from_id(event.event_id)
-        redactee_domain = get_domain_from_id(event.redacts)
-        if redacter_domain == redactee_domain:
-            return True
-
-        raise AuthError(
-            403,
-            "You don't have permission to redact events"
-        )
-
-    def _check_power_levels(self, event, auth_events):
-        user_list = event.content.get("users", {})
-        # Validate users
-        for k, v in user_list.items():
-            try:
-                UserID.from_string(k)
-            except:
-                raise SynapseError(400, "Not a valid user_id: %s" % (k,))
-
-            try:
-                int(v)
-            except:
-                raise SynapseError(400, "Not a valid power level: %s" % (v,))
-
-        key = (event.type, event.state_key, )
-        current_state = auth_events.get(key)
-
-        if not current_state:
-            return
-
-        user_level = self._get_user_power_level(event.user_id, auth_events)
-
-        # Check other levels:
-        levels_to_check = [
-            ("users_default", None),
-            ("events_default", None),
-            ("state_default", None),
-            ("ban", None),
-            ("redact", None),
-            ("kick", None),
-            ("invite", None),
-        ]
-
-        old_list = current_state.content.get("users")
-        for user in set(old_list.keys() + user_list.keys()):
-            levels_to_check.append(
-                (user, "users")
-            )
-
-        old_list = current_state.content.get("events")
-        new_list = event.content.get("events")
-        for ev_id in set(old_list.keys() + new_list.keys()):
-            levels_to_check.append(
-                (ev_id, "events")
-            )
-
-        old_state = current_state.content
-        new_state = event.content
-
-        for level_to_check, dir in levels_to_check:
-            old_loc = old_state
-            new_loc = new_state
-            if dir:
-                old_loc = old_loc.get(dir, {})
-                new_loc = new_loc.get(dir, {})
-
-            if level_to_check in old_loc:
-                old_level = int(old_loc[level_to_check])
-            else:
-                old_level = None
-
-            if level_to_check in new_loc:
-                new_level = int(new_loc[level_to_check])
-            else:
-                new_level = None
-
-            if new_level is not None and old_level is not None:
-                if new_level == old_level:
-                    continue
-
-            if dir == "users" and level_to_check != event.user_id:
-                if old_level == user_level:
-                    raise AuthError(
-                        403,
-                        "You don't have permission to remove ops level equal "
-                        "to your own"
-                    )
-
-            if old_level > user_level or new_level > user_level:
-                raise AuthError(
-                    403,
-                    "You don't have permission to add ops level greater "
-                    "than your own"
-                )
+        return event_auth.check_redaction(event, auth_events)
 
     @defer.inlineCallbacks
     def check_can_change_room_list(self, room_id, user):
@@ -1167,10 +604,10 @@ class Auth(object):
         if power_level_event:
             auth_events[(EventTypes.PowerLevels, "")] = power_level_event
 
-        send_level = self._get_send_level(
+        send_level = event_auth.get_send_level(
             EventTypes.Aliases, "", auth_events
         )
-        user_level = self._get_user_power_level(user_id, auth_events)
+        user_level = event_auth.get_user_power_level(user_id, auth_events)
 
         if user_level < send_level:
             raise AuthError(
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
new file mode 100644
index 0000000000..4096c606f1
--- /dev/null
+++ b/synapse/event_auth.py
@@ -0,0 +1,678 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 - 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from canonicaljson import encode_canonical_json
+from signedjson.key import decode_verify_key_bytes
+from signedjson.sign import verify_signed_json, SignatureVerifyException
+from unpaddedbase64 import decode_base64
+
+from synapse.api.constants import EventTypes, Membership, JoinRules
+from synapse.api.errors import AuthError, SynapseError, EventSizeError
+from synapse.types import UserID, get_domain_from_id
+
+logger = logging.getLogger(__name__)
+
+
+def check(event, auth_events, do_sig_check=True, do_size_check=True):
+    """ Checks if this event is correctly authed.
+
+    Args:
+        event: the event being checked.
+        auth_events (dict: event-key -> event): the existing room state.
+
+
+    Returns:
+        True if the auth checks pass.
+    """
+    if do_size_check:
+        _check_size_limits(event)
+
+    if not hasattr(event, "room_id"):
+        raise AuthError(500, "Event has no room_id: %s" % event)
+
+    if do_sig_check:
+        sender_domain = get_domain_from_id(event.sender)
+        event_id_domain = get_domain_from_id(event.event_id)
+
+        is_invite_via_3pid = (
+            event.type == EventTypes.Member
+            and event.membership == Membership.INVITE
+            and "third_party_invite" in event.content
+        )
+
+        # Check the sender's domain has signed the event
+        if not event.signatures.get(sender_domain):
+            # We allow invites via 3pid to have a sender from a different
+            # HS, as the sender must match the sender of the original
+            # 3pid invite. This is checked further down with the
+            # other dedicated membership checks.
+            if not is_invite_via_3pid:
+                raise AuthError(403, "Event not signed by sender's server")
+
+        # Check the event_id's domain has signed the event
+        if not event.signatures.get(event_id_domain):
+            raise AuthError(403, "Event not signed by sending server")
+
+    if auth_events is None:
+        # Oh, we don't know what the state of the room was, so we
+        # are trusting that this is allowed (at least for now)
+        logger.warn("Trusting event: %s", event.event_id)
+        return True
+
+    if event.type == EventTypes.Create:
+        room_id_domain = get_domain_from_id(event.room_id)
+        if room_id_domain != sender_domain:
+            raise AuthError(
+                403,
+                "Creation event's room_id domain does not match sender's"
+            )
+        # FIXME
+        return True
+
+    creation_event = auth_events.get((EventTypes.Create, ""), None)
+
+    if not creation_event:
+        raise SynapseError(
+            403,
+            "Room %r does not exist" % (event.room_id,)
+        )
+
+    creating_domain = get_domain_from_id(event.room_id)
+    originating_domain = get_domain_from_id(event.sender)
+    if creating_domain != originating_domain:
+        if not _can_federate(event, auth_events):
+            raise AuthError(
+                403,
+                "This room has been marked as unfederatable."
+            )
+
+    # FIXME: Temp hack
+    if event.type == EventTypes.Aliases:
+        if not event.is_state():
+            raise AuthError(
+                403,
+                "Alias event must be a state event",
+            )
+        if not event.state_key:
+            raise AuthError(
+                403,
+                "Alias event must have non-empty state_key"
+            )
+        sender_domain = get_domain_from_id(event.sender)
+        if event.state_key != sender_domain:
+            raise AuthError(
+                403,
+                "Alias event's state_key does not match sender's domain"
+            )
+        return True
+
+    if logger.isEnabledFor(logging.DEBUG):
+        logger.debug(
+            "Auth events: %s",
+            [a.event_id for a in auth_events.values()]
+        )
+
+    if event.type == EventTypes.Member:
+        allowed = _is_membership_change_allowed(
+            event, auth_events
+        )
+        if allowed:
+            logger.debug("Allowing! %s", event)
+        else:
+            logger.debug("Denying! %s", event)
+        return allowed
+
+    _check_event_sender_in_room(event, auth_events)
+
+    # Special case to allow m.room.third_party_invite events wherever
+    # a user is allowed to issue invites.  Fixes
+    # https://github.com/vector-im/vector-web/issues/1208 hopefully
+    if event.type == EventTypes.ThirdPartyInvite:
+        user_level = get_user_power_level(event.user_id, auth_events)
+        invite_level = _get_named_level(auth_events, "invite", 0)
+
+        if user_level < invite_level:
+            raise AuthError(
+                403, (
+                    "You cannot issue a third party invite for %s." %
+                    (event.content.display_name,)
+                )
+            )
+        else:
+            return True
+
+    _can_send_event(event, auth_events)
+
+    if event.type == EventTypes.PowerLevels:
+        _check_power_levels(event, auth_events)
+
+    if event.type == EventTypes.Redaction:
+        check_redaction(event, auth_events)
+
+    logger.debug("Allowing! %s", event)
+
+
+def _check_size_limits(event):
+    def too_big(field):
+        raise EventSizeError("%s too large" % (field,))
+
+    if len(event.user_id) > 255:
+        too_big("user_id")
+    if len(event.room_id) > 255:
+        too_big("room_id")
+    if event.is_state() and len(event.state_key) > 255:
+        too_big("state_key")
+    if len(event.type) > 255:
+        too_big("type")
+    if len(event.event_id) > 255:
+        too_big("event_id")
+    if len(encode_canonical_json(event.get_pdu_json())) > 65536:
+        too_big("event")
+
+
+def _can_federate(event, auth_events):
+    creation_event = auth_events.get((EventTypes.Create, ""))
+
+    return creation_event.content.get("m.federate", True) is True
+
+
+def _is_membership_change_allowed(event, auth_events):
+    membership = event.content["membership"]
+
+    # Check if this is the room creator joining:
+    if len(event.prev_events) == 1 and Membership.JOIN == membership:
+        # Get room creation event:
+        key = (EventTypes.Create, "", )
+        create = auth_events.get(key)
+        if create and event.prev_events[0][0] == create.event_id:
+            if create.content["creator"] == event.state_key:
+                return True
+
+    target_user_id = event.state_key
+
+    creating_domain = get_domain_from_id(event.room_id)
+    target_domain = get_domain_from_id(target_user_id)
+    if creating_domain != target_domain:
+        if not _can_federate(event, auth_events):
+            raise AuthError(
+                403,
+                "This room has been marked as unfederatable."
+            )
+
+    # get info about the caller
+    key = (EventTypes.Member, event.user_id, )
+    caller = auth_events.get(key)
+
+    caller_in_room = caller and caller.membership == Membership.JOIN
+    caller_invited = caller and caller.membership == Membership.INVITE
+
+    # get info about the target
+    key = (EventTypes.Member, target_user_id, )
+    target = auth_events.get(key)
+
+    target_in_room = target and target.membership == Membership.JOIN
+    target_banned = target and target.membership == Membership.BAN
+
+    key = (EventTypes.JoinRules, "", )
+    join_rule_event = auth_events.get(key)
+    if join_rule_event:
+        join_rule = join_rule_event.content.get(
+            "join_rule", JoinRules.INVITE
+        )
+    else:
+        join_rule = JoinRules.INVITE
+
+    user_level = get_user_power_level(event.user_id, auth_events)
+    target_level = get_user_power_level(
+        target_user_id, auth_events
+    )
+
+    # FIXME (erikj): What should we do here as the default?
+    ban_level = _get_named_level(auth_events, "ban", 50)
+
+    logger.debug(
+        "_is_membership_change_allowed: %s",
+        {
+            "caller_in_room": caller_in_room,
+            "caller_invited": caller_invited,
+            "target_banned": target_banned,
+            "target_in_room": target_in_room,
+            "membership": membership,
+            "join_rule": join_rule,
+            "target_user_id": target_user_id,
+            "event.user_id": event.user_id,
+        }
+    )
+
+    if Membership.INVITE == membership and "third_party_invite" in event.content:
+        if not _verify_third_party_invite(event, auth_events):
+            raise AuthError(403, "You are not invited to this room.")
+        if target_banned:
+            raise AuthError(
+                403, "%s is banned from the room" % (target_user_id,)
+            )
+        return True
+
+    if Membership.JOIN != membership:
+        if (caller_invited
+                and Membership.LEAVE == membership
+                and target_user_id == event.user_id):
+            return True
+
+        if not caller_in_room:  # caller isn't joined
+            raise AuthError(
+                403,
+                "%s not in room %s." % (event.user_id, event.room_id,)
+            )
+
+    if Membership.INVITE == membership:
+        # TODO (erikj): We should probably handle this more intelligently
+        # PRIVATE join rules.
+
+        # Invites are valid iff caller is in the room and target isn't.
+        if target_banned:
+            raise AuthError(
+                403, "%s is banned from the room" % (target_user_id,)
+            )
+        elif target_in_room:  # the target is already in the room.
+            raise AuthError(403, "%s is already in the room." %
+                                 target_user_id)
+        else:
+            invite_level = _get_named_level(auth_events, "invite", 0)
+
+            if user_level < invite_level:
+                raise AuthError(
+                    403, "You cannot invite user %s." % target_user_id
+                )
+    elif Membership.JOIN == membership:
+        # Joins are valid iff caller == target and they were:
+        # invited: They are accepting the invitation
+        # joined: It's a NOOP
+        if event.user_id != target_user_id:
+            raise AuthError(403, "Cannot force another user to join.")
+        elif target_banned:
+            raise AuthError(403, "You are banned from this room")
+        elif join_rule == JoinRules.PUBLIC:
+            pass
+        elif join_rule == JoinRules.INVITE:
+            if not caller_in_room and not caller_invited:
+                raise AuthError(403, "You are not invited to this room.")
+        else:
+            # TODO (erikj): may_join list
+            # TODO (erikj): private rooms
+            raise AuthError(403, "You are not allowed to join this room")
+    elif Membership.LEAVE == membership:
+        # TODO (erikj): Implement kicks.
+        if target_banned and user_level < ban_level:
+            raise AuthError(
+                403, "You cannot unban user &s." % (target_user_id,)
+            )
+        elif target_user_id != event.user_id:
+            kick_level = _get_named_level(auth_events, "kick", 50)
+
+            if user_level < kick_level or user_level <= target_level:
+                raise AuthError(
+                    403, "You cannot kick user %s." % target_user_id
+                )
+    elif Membership.BAN == membership:
+        if user_level < ban_level or user_level <= target_level:
+            raise AuthError(403, "You don't have permission to ban")
+    else:
+        raise AuthError(500, "Unknown membership %s" % membership)
+
+    return True
+
+
+def _check_event_sender_in_room(event, auth_events):
+    key = (EventTypes.Member, event.user_id, )
+    member_event = auth_events.get(key)
+
+    return _check_joined_room(
+        member_event,
+        event.user_id,
+        event.room_id
+    )
+
+
+def _check_joined_room(member, user_id, room_id):
+    if not member or member.membership != Membership.JOIN:
+        raise AuthError(403, "User %s not in room %s (%s)" % (
+            user_id, room_id, repr(member)
+        ))
+
+
+def get_send_level(etype, state_key, auth_events):
+    key = (EventTypes.PowerLevels, "", )
+    send_level_event = auth_events.get(key)
+    send_level = None
+    if send_level_event:
+        send_level = send_level_event.content.get("events", {}).get(
+            etype
+        )
+        if send_level is None:
+            if state_key is not None:
+                send_level = send_level_event.content.get(
+                    "state_default", 50
+                )
+            else:
+                send_level = send_level_event.content.get(
+                    "events_default", 0
+                )
+
+    if send_level:
+        send_level = int(send_level)
+    else:
+        send_level = 0
+
+    return send_level
+
+
+def _can_send_event(event, auth_events):
+    send_level = get_send_level(
+        event.type, event.get("state_key", None), auth_events
+    )
+    user_level = get_user_power_level(event.user_id, auth_events)
+
+    if user_level < send_level:
+        raise AuthError(
+            403,
+            "You don't have permission to post that to the room. " +
+            "user_level (%d) < send_level (%d)" % (user_level, send_level)
+        )
+
+    # Check state_key
+    if hasattr(event, "state_key"):
+        if event.state_key.startswith("@"):
+            if event.state_key != event.user_id:
+                raise AuthError(
+                    403,
+                    "You are not allowed to set others state"
+                )
+
+    return True
+
+
+def check_redaction(event, auth_events):
+    """Check whether the event sender is allowed to redact the target event.
+
+    Returns:
+        True if the the sender is allowed to redact the target event if the
+        target event was created by them.
+        False if the sender is allowed to redact the target event with no
+        further checks.
+
+    Raises:
+        AuthError if the event sender is definitely not allowed to redact
+        the target event.
+    """
+    user_level = get_user_power_level(event.user_id, auth_events)
+
+    redact_level = _get_named_level(auth_events, "redact", 50)
+
+    if user_level >= redact_level:
+        return False
+
+    redacter_domain = get_domain_from_id(event.event_id)
+    redactee_domain = get_domain_from_id(event.redacts)
+    if redacter_domain == redactee_domain:
+        return True
+
+    raise AuthError(
+        403,
+        "You don't have permission to redact events"
+    )
+
+
+def _check_power_levels(event, auth_events):
+    user_list = event.content.get("users", {})
+    # Validate users
+    for k, v in user_list.items():
+        try:
+            UserID.from_string(k)
+        except:
+            raise SynapseError(400, "Not a valid user_id: %s" % (k,))
+
+        try:
+            int(v)
+        except:
+            raise SynapseError(400, "Not a valid power level: %s" % (v,))
+
+    key = (event.type, event.state_key, )
+    current_state = auth_events.get(key)
+
+    if not current_state:
+        return
+
+    user_level = get_user_power_level(event.user_id, auth_events)
+
+    # Check other levels:
+    levels_to_check = [
+        ("users_default", None),
+        ("events_default", None),
+        ("state_default", None),
+        ("ban", None),
+        ("redact", None),
+        ("kick", None),
+        ("invite", None),
+    ]
+
+    old_list = current_state.content.get("users")
+    for user in set(old_list.keys() + user_list.keys()):
+        levels_to_check.append(
+            (user, "users")
+        )
+
+    old_list = current_state.content.get("events")
+    new_list = event.content.get("events")
+    for ev_id in set(old_list.keys() + new_list.keys()):
+        levels_to_check.append(
+            (ev_id, "events")
+        )
+
+    old_state = current_state.content
+    new_state = event.content
+
+    for level_to_check, dir in levels_to_check:
+        old_loc = old_state
+        new_loc = new_state
+        if dir:
+            old_loc = old_loc.get(dir, {})
+            new_loc = new_loc.get(dir, {})
+
+        if level_to_check in old_loc:
+            old_level = int(old_loc[level_to_check])
+        else:
+            old_level = None
+
+        if level_to_check in new_loc:
+            new_level = int(new_loc[level_to_check])
+        else:
+            new_level = None
+
+        if new_level is not None and old_level is not None:
+            if new_level == old_level:
+                continue
+
+        if dir == "users" and level_to_check != event.user_id:
+            if old_level == user_level:
+                raise AuthError(
+                    403,
+                    "You don't have permission to remove ops level equal "
+                    "to your own"
+                )
+
+        if old_level > user_level or new_level > user_level:
+            raise AuthError(
+                403,
+                "You don't have permission to add ops level greater "
+                "than your own"
+            )
+
+
+def _get_power_level_event(auth_events):
+    key = (EventTypes.PowerLevels, "", )
+    return auth_events.get(key)
+
+
+def get_user_power_level(user_id, auth_events):
+    power_level_event = _get_power_level_event(auth_events)
+
+    if power_level_event:
+        level = power_level_event.content.get("users", {}).get(user_id)
+        if not level:
+            level = power_level_event.content.get("users_default", 0)
+
+        if level is None:
+            return 0
+        else:
+            return int(level)
+    else:
+        key = (EventTypes.Create, "", )
+        create_event = auth_events.get(key)
+        if (create_event is not None and
+                create_event.content["creator"] == user_id):
+            return 100
+        else:
+            return 0
+
+
+def _get_named_level(auth_events, name, default):
+    power_level_event = _get_power_level_event(auth_events)
+
+    if not power_level_event:
+        return default
+
+    level = power_level_event.content.get(name, None)
+    if level is not None:
+        return int(level)
+    else:
+        return default
+
+
+def _verify_third_party_invite(event, auth_events):
+    """
+    Validates that the invite event is authorized by a previous third-party invite.
+
+    Checks that the public key, and keyserver, match those in the third party invite,
+    and that the invite event has a signature issued using that public key.
+
+    Args:
+        event: The m.room.member join event being validated.
+        auth_events: All relevant previous context events which may be used
+            for authorization decisions.
+
+    Return:
+        True if the event fulfills the expectations of a previous third party
+        invite event.
+    """
+    if "third_party_invite" not in event.content:
+        return False
+    if "signed" not in event.content["third_party_invite"]:
+        return False
+    signed = event.content["third_party_invite"]["signed"]
+    for key in {"mxid", "token"}:
+        if key not in signed:
+            return False
+
+    token = signed["token"]
+
+    invite_event = auth_events.get(
+        (EventTypes.ThirdPartyInvite, token,)
+    )
+    if not invite_event:
+        return False
+
+    if invite_event.sender != event.sender:
+        return False
+
+    if event.user_id != invite_event.user_id:
+        return False
+
+    if signed["mxid"] != event.state_key:
+        return False
+    if signed["token"] != token:
+        return False
+
+    for public_key_object in get_public_keys(invite_event):
+        public_key = public_key_object["public_key"]
+        try:
+            for server, signature_block in signed["signatures"].items():
+                for key_name, encoded_signature in signature_block.items():
+                    if not key_name.startswith("ed25519:"):
+                        continue
+                    verify_key = decode_verify_key_bytes(
+                        key_name,
+                        decode_base64(public_key)
+                    )
+                    verify_signed_json(signed, server, verify_key)
+
+                    # We got the public key from the invite, so we know that the
+                    # correct server signed the signed bundle.
+                    # The caller is responsible for checking that the signing
+                    # server has not revoked that public key.
+                    return True
+        except (KeyError, SignatureVerifyException,):
+            continue
+    return False
+
+
+def get_public_keys(invite_event):
+    public_keys = []
+    if "public_key" in invite_event.content:
+        o = {
+            "public_key": invite_event.content["public_key"],
+        }
+        if "key_validity_url" in invite_event.content:
+            o["key_validity_url"] = invite_event.content["key_validity_url"]
+        public_keys.append(o)
+    public_keys.extend(invite_event.content.get("public_keys", []))
+    return public_keys
+
+
+def auth_types_for_event(event):
+    """Given an event, return a list of (EventType, StateKey) that may be
+    needed to auth the event. The returned list may be a superset of what
+    would actually be required depending on the full state of the room.
+
+    Used to limit the number of events to fetch from the database to
+    actually auth the event.
+    """
+    if event.type == EventTypes.Create:
+        return []
+
+    auth_types = []
+
+    auth_types.append((EventTypes.PowerLevels, "", ))
+    auth_types.append((EventTypes.Member, event.user_id, ))
+    auth_types.append((EventTypes.Create, "", ))
+
+    if event.type == EventTypes.Member:
+        membership = event.content["membership"]
+        if membership in [Membership.JOIN, Membership.INVITE]:
+            auth_types.append((EventTypes.JoinRules, "", ))
+
+        auth_types.append((EventTypes.Member, event.state_key, ))
+
+        if membership == Membership.INVITE:
+            if "third_party_invite" in event.content:
+                key = (
+                    EventTypes.ThirdPartyInvite,
+                    event.content["third_party_invite"]["signed"]["token"]
+                )
+                auth_types.append(key)
+
+    return auth_types
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index da9f3ad436..e673e96cc0 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -79,7 +79,6 @@ class EventBase(object):
     auth_events = _event_dict_property("auth_events")
     depth = _event_dict_property("depth")
     content = _event_dict_property("content")
-    event_id = _event_dict_property("event_id")
     hashes = _event_dict_property("hashes")
     origin = _event_dict_property("origin")
     origin_server_ts = _event_dict_property("origin_server_ts")
@@ -88,8 +87,6 @@ class EventBase(object):
     redacts = _event_dict_property("redacts")
     room_id = _event_dict_property("room_id")
     sender = _event_dict_property("sender")
-    state_key = _event_dict_property("state_key")
-    type = _event_dict_property("type")
     user_id = _event_dict_property("sender")
 
     @property
@@ -162,6 +159,11 @@ class FrozenEvent(EventBase):
         else:
             frozen_dict = event_dict
 
+        self.event_id = event_dict["event_id"]
+        self.type = event_dict["type"]
+        if "state_key" in event_dict:
+            self.state_key = event_dict["state_key"]
+
         super(FrozenEvent, self).__init__(
             frozen_dict,
             signatures=signatures,
diff --git a/synapse/events/builder.py b/synapse/events/builder.py
index 7369d70980..365fd96bd2 100644
--- a/synapse/events/builder.py
+++ b/synapse/events/builder.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from . import EventBase, FrozenEvent
+from . import EventBase, FrozenEvent, _event_dict_property
 
 from synapse.types import EventID
 
@@ -34,6 +34,10 @@ class EventBuilder(EventBase):
             internal_metadata_dict=internal_metadata_dict,
         )
 
+    event_id = _event_dict_property("event_id")
+    state_key = _event_dict_property("state_key")
+    type = _event_dict_property("type")
+
     def build(self):
         return FrozenEvent.from_event(self)
 
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index b4bcec77ed..c9175bb33d 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -26,7 +26,7 @@ from synapse.util import unwrapFirstError
 from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.util.logutils import log_function
 from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
-from synapse.events import FrozenEvent
+from synapse.events import FrozenEvent, builder
 import synapse.metrics
 
 from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination
@@ -499,8 +499,10 @@ class FederationClient(FederationBase):
                 if "prev_state" not in pdu_dict:
                     pdu_dict["prev_state"] = []
 
+                ev = builder.EventBuilder(pdu_dict)
+
                 defer.returnValue(
-                    (destination, self.event_from_pdu_json(pdu_dict))
+                    (destination, ev)
                 )
                 break
             except CodeMessageException as e:
diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py
index 7db7b806dc..6b3a7abb9e 100644
--- a/synapse/federation/transaction_queue.py
+++ b/synapse/federation/transaction_queue.py
@@ -362,7 +362,7 @@ class TransactionQueue(object):
                     if not success:
                         break
         except NotRetryingDestination:
-            logger.info(
+            logger.debug(
                 "TX [%s] not ready for retry yet - "
                 "dropping transaction for now",
                 destination,
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 1021bcc405..d3f5892376 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -591,12 +591,12 @@ class FederationHandler(BaseHandler):
 
         event_ids = list(extremities.keys())
 
-        logger.info("calling resolve_state_groups in _maybe_backfill")
+        logger.debug("calling resolve_state_groups in _maybe_backfill")
         states = yield preserve_context_over_deferred(defer.gatherResults([
             preserve_fn(self.state_handler.resolve_state_groups)(room_id, [e])
             for e in event_ids
         ]))
-        states = dict(zip(event_ids, [s[1] for s in states]))
+        states = dict(zip(event_ids, [s.state for s in states]))
 
         state_map = yield self.store.get_events(
             [e_id for ids in states.values() for e_id in ids],
@@ -1530,7 +1530,7 @@ class FederationHandler(BaseHandler):
                     (d.type, d.state_key): d for d in different_events if d
                 })
 
-                new_state, prev_state = self.state_handler.resolve_events(
+                new_state = self.state_handler.resolve_events(
                     [local_view.values(), remote_view.values()],
                     event
                 )
diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py
index b47bf1f92b..a27476bbad 100644
--- a/synapse/push/push_tools.py
+++ b/synapse/push/push_tools.py
@@ -52,7 +52,7 @@ def get_badge_count(store, user_id):
 def get_context_for_event(store, state_handler, ev, user_id):
     ctx = {}
 
-    room_state_ids = yield state_handler.get_current_state_ids(ev.room_id)
+    room_state_ids = yield store.get_state_ids_for_event(ev.event_id)
 
     # we no longer bother setting room_alias, and make room_name the
     # human-readable name instead, be that m.room.name, an alias or
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 093bc072f4..0c9cdff3b8 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -118,8 +118,14 @@ class LoginRestServlet(ClientV1RestServlet):
     @defer.inlineCallbacks
     def do_password_login(self, login_submission):
         if 'medium' in login_submission and 'address' in login_submission:
+            address = login_submission['address']
+            if login_submission['medium'] == 'email':
+                # For emails, transform the address to lowercase.
+                # We store all email addreses as lowercase in the DB.
+                # (See add_threepid in synapse/handlers/auth.py)
+                address = address.lower()
             user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
-                login_submission['medium'], login_submission['address']
+                login_submission['medium'], address
             )
             if not user_id:
                 raise LoginError(403, "", errcode=Codes.FORBIDDEN)
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index e74e5e0123..398e7f5eb0 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -96,6 +96,11 @@ class PasswordRestServlet(RestServlet):
             threepid = result[LoginType.EMAIL_IDENTITY]
             if 'medium' not in threepid or 'address' not in threepid:
                 raise SynapseError(500, "Malformed threepid")
+            if threepid['medium'] == 'email':
+                # For emails, transform the address to lowercase.
+                # We store all email addreses as lowercase in the DB.
+                # (See add_threepid in synapse/handlers/auth.py)
+                threepid['address'] = threepid['address'].lower()
             # if using email, we must know about the email they're authing with!
             threepid_user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
                 threepid['medium'], threepid['address']
diff --git a/synapse/state.py b/synapse/state.py
index b9d5627a82..20aaacf40f 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -16,12 +16,12 @@
 
 from twisted.internet import defer
 
+from synapse import event_auth
 from synapse.util.logutils import log_function
 from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.util.metrics import Measure
 from synapse.api.constants import EventTypes
 from synapse.api.errors import AuthError
-from synapse.api.auth import AuthEventTypes
 from synapse.events.snapshot import EventContext
 from synapse.util.async import Linearizer
 
@@ -41,12 +41,14 @@ KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key"))
 CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))
 
 
-SIZE_OF_CACHE = int(1000 * CACHE_SIZE_FACTOR)
+SIZE_OF_CACHE = int(100000 * CACHE_SIZE_FACTOR)
 EVICTION_TIMEOUT_SECONDS = 60 * 60
 
 
 _NEXT_STATE_ID = 1
 
+POWER_KEY = (EventTypes.PowerLevels, "")
+
 
 def _gen_state_id():
     global _NEXT_STATE_ID
@@ -77,6 +79,9 @@ class _StateCacheEntry(object):
         else:
             self.state_id = _gen_state_id()
 
+    def __len__(self):
+        return len(self.state)
+
 
 class StateHandler(object):
     """ Responsible for doing state conflict resolution.
@@ -99,6 +104,7 @@ class StateHandler(object):
             clock=self.clock,
             max_len=SIZE_OF_CACHE,
             expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000,
+            iterable=True,
             reset_expiry_on_get=True,
         )
 
@@ -123,7 +129,7 @@ class StateHandler(object):
         if not latest_event_ids:
             latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
 
-        logger.info("calling resolve_state_groups from get_current_state")
+        logger.debug("calling resolve_state_groups from get_current_state")
         ret = yield self.resolve_state_groups(room_id, latest_event_ids)
         state = ret.state
 
@@ -148,7 +154,7 @@ class StateHandler(object):
         if not latest_event_ids:
             latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
 
-        logger.info("calling resolve_state_groups from get_current_state_ids")
+        logger.debug("calling resolve_state_groups from get_current_state_ids")
         ret = yield self.resolve_state_groups(room_id, latest_event_ids)
         state = ret.state
 
@@ -162,7 +168,7 @@ class StateHandler(object):
     def get_current_user_in_room(self, room_id, latest_event_ids=None):
         if not latest_event_ids:
             latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
-        logger.info("calling resolve_state_groups from get_current_user_in_room")
+        logger.debug("calling resolve_state_groups from get_current_user_in_room")
         entry = yield self.resolve_state_groups(room_id, latest_event_ids)
         joined_users = yield self.store.get_joined_users_from_state(
             room_id, entry.state_id, entry.state
@@ -226,7 +232,7 @@ class StateHandler(object):
             context.prev_state_events = []
             defer.returnValue(context)
 
-        logger.info("calling resolve_state_groups from compute_event_context")
+        logger.debug("calling resolve_state_groups from compute_event_context")
         if event.is_state():
             entry = yield self.resolve_state_groups(
                 event.room_id, [e for e, _ in event.prev_events],
@@ -327,20 +333,13 @@ class StateHandler(object):
 
             if conflicted_state:
                 logger.info("Resolving conflicted state for %r", room_id)
-                state_map = yield self.store.get_events(
-                    [e_id for st in state_groups_ids.values() for e_id in st.values()],
-                    get_prev_content=False
-                )
-                state_sets = [
-                    [state_map[e_id] for key, e_id in st.items() if e_id in state_map]
-                    for st in state_groups_ids.values()
-                ]
-                new_state, _ = self._resolve_events(
-                    state_sets, event_type, state_key
-                )
-                new_state = {
-                    key: e.event_id for key, e in new_state.items()
-                }
+                with Measure(self.clock, "state._resolve_events"):
+                    new_state = yield resolve_events(
+                        state_groups_ids.values(),
+                        state_map_factory=lambda ev_ids: self.store.get_events(
+                            ev_ids, get_prev_content=False, check_redacted=False,
+                        ),
+                    )
             else:
                 new_state = {
                     key: e_ids.pop() for key, e_ids in state.items()
@@ -388,152 +387,264 @@ class StateHandler(object):
         logger.info(
             "Resolving state for %s with %d groups", event.room_id, len(state_sets)
         )
-        if event.is_state():
-            return self._resolve_events(
-                state_sets, event.type, event.state_key
-            )
-        else:
-            return self._resolve_events(state_sets)
+        state_set_ids = [{
+            (ev.type, ev.state_key): ev.event_id
+            for ev in st
+        } for st in state_sets]
+
+        state_map = {
+            ev.event_id: ev
+            for st in state_sets
+            for ev in st
+        }
 
-    def _resolve_events(self, state_sets, event_type=None, state_key=""):
-        """
-        Returns
-            (dict[(str, str), synapse.events.FrozenEvent], list[str]): a tuple
-            (new_state, prev_states). new_state is a map from (type, state_key)
-            to event. prev_states is a list of event_ids.
-        """
         with Measure(self.clock, "state._resolve_events"):
-            state = {}
-            for st in state_sets:
-                for e in st:
-                    state.setdefault(
-                        (e.type, e.state_key),
-                        {}
-                    )[e.event_id] = e
-
-            unconflicted_state = {
-                k: v.values()[0] for k, v in state.items()
-                if len(v.values()) == 1
-            }
+            new_state = resolve_events(state_set_ids, state_map)
 
-            conflicted_state = {
-                k: v.values()
-                for k, v in state.items()
-                if len(v.values()) > 1
-            }
+        new_state = {
+            key: state_map[ev_id] for key, ev_id in new_state.items()
+        }
 
-            if event_type:
-                prev_states_events = conflicted_state.get(
-                    (event_type, state_key), []
-                )
-                prev_states = [s.event_id for s in prev_states_events]
-            else:
-                prev_states = []
+        return new_state
 
-            auth_events = {
-                k: e for k, e in unconflicted_state.items()
-                if k[0] in AuthEventTypes
-            }
 
-            try:
-                resolved_state = self._resolve_state_events(
-                    conflicted_state, auth_events
-                )
-            except:
-                logger.exception("Failed to resolve state")
-                raise
+def _ordered_events(events):
+    def key_func(e):
+        return -int(e.depth), hashlib.sha1(e.event_id).hexdigest()
 
-            new_state = unconflicted_state
-            new_state.update(resolved_state)
+    return sorted(events, key=key_func)
 
-        return new_state, prev_states
 
-    @log_function
-    def _resolve_state_events(self, conflicted_state, auth_events):
-        """ This is where we actually decide which of the conflicted state to
-        use.
-
-        We resolve conflicts in the following order:
-            1. power levels
-            2. join rules
-            3. memberships
-            4. other events.
-        """
-        resolved_state = {}
-        power_key = (EventTypes.PowerLevels, "")
-        if power_key in conflicted_state:
-            events = conflicted_state[power_key]
-            logger.debug("Resolving conflicted power levels %r", events)
-            resolved_state[power_key] = self._resolve_auth_events(
-                events, auth_events)
-
-        auth_events.update(resolved_state)
-
-        for key, events in conflicted_state.items():
-            if key[0] == EventTypes.JoinRules:
-                logger.debug("Resolving conflicted join rules %r", events)
-                resolved_state[key] = self._resolve_auth_events(
-                    events,
-                    auth_events
-                )
-
-        auth_events.update(resolved_state)
-
-        for key, events in conflicted_state.items():
-            if key[0] == EventTypes.Member:
-                logger.debug("Resolving conflicted member lists %r", events)
-                resolved_state[key] = self._resolve_auth_events(
-                    events,
-                    auth_events
-                )
-
-        auth_events.update(resolved_state)
-
-        for key, events in conflicted_state.items():
-            if key not in resolved_state:
-                logger.debug("Resolving conflicted state %r:%r", key, events)
-                resolved_state[key] = self._resolve_normal_events(
-                    events, auth_events
-                )
-
-        return resolved_state
-
-    def _resolve_auth_events(self, events, auth_events):
-        reverse = [i for i in reversed(self._ordered_events(events))]
-
-        auth_events = dict(auth_events)
-
-        prev_event = reverse[0]
-        for event in reverse[1:]:
-            auth_events[(prev_event.type, prev_event.state_key)] = prev_event
-            try:
-                # FIXME: hs.get_auth() is bad style, but we need to do it to
-                # get around circular deps.
-                # The signatures have already been checked at this point
-                self.hs.get_auth().check(event, auth_events, do_sig_check=False)
-                prev_event = event
-            except AuthError:
-                return prev_event
-
-        return event
-
-    def _resolve_normal_events(self, events, auth_events):
-        for event in self._ordered_events(events):
-            try:
-                # FIXME: hs.get_auth() is bad style, but we need to do it to
-                # get around circular deps.
-                # The signatures have already been checked at this point
-                self.hs.get_auth().check(event, auth_events, do_sig_check=False)
-                return event
-            except AuthError:
-                pass
-
-        # Use the last event (the one with the least depth) if they all fail
-        # the auth check.
-        return event
-
-    def _ordered_events(self, events):
-        def key_func(e):
-            return -int(e.depth), hashlib.sha1(e.event_id).hexdigest()
-
-        return sorted(events, key=key_func)
+def resolve_events(state_sets, state_map_factory):
+    """
+    Args:
+        state_sets(list): List of dicts of (type, state_key) -> event_id,
+            which are the different state groups to resolve.
+        state_map_factory(dict|callable): If callable, then will be called
+            with a list of event_ids that are needed, and should return with
+            a Deferred of dict of event_id to event. Otherwise, should be
+            a dict from event_id to event of all events in state_sets.
+
+    Returns
+        dict[(str, str), synapse.events.FrozenEvent] is a map from
+        (type, state_key) to event.
+    """
+    unconflicted_state, conflicted_state = _seperate(
+        state_sets,
+    )
+
+    if callable(state_map_factory):
+        return _resolve_with_state_fac(
+            unconflicted_state, conflicted_state, state_map_factory
+        )
+
+    state_map = state_map_factory
+
+    auth_events = _create_auth_events_from_maps(
+        unconflicted_state, conflicted_state, state_map
+    )
+
+    return _resolve_with_state(
+        unconflicted_state, conflicted_state, auth_events, state_map
+    )
+
+
+def _seperate(state_sets):
+    """Takes the state_sets and figures out which keys are conflicted and
+    which aren't. i.e., which have multiple different event_ids associated
+    with them in different state sets.
+    """
+    unconflicted_state = dict(state_sets[0])
+    conflicted_state = {}
+
+    for state_set in state_sets[1:]:
+        for key, value in state_set.iteritems():
+            # Check if there is an unconflicted entry for the state key.
+            unconflicted_value = unconflicted_state.get(key)
+            if unconflicted_value is None:
+                # There isn't an unconflicted entry so check if there is a
+                # conflicted entry.
+                ls = conflicted_state.get(key)
+                if ls is None:
+                    # There wasn't a conflicted entry so haven't seen this key before.
+                    # Therefore it isn't conflicted yet.
+                    unconflicted_state[key] = value
+                else:
+                    # This key is already conflicted, add our value to the conflict set.
+                    ls.add(value)
+            elif unconflicted_value != value:
+                # If the unconflicted value is not the same as our value then we
+                # have a new conflict. So move the key from the unconflicted_state
+                # to the conflicted state.
+                conflicted_state[key] = {value, unconflicted_value}
+                unconflicted_state.pop(key, None)
+
+    return unconflicted_state, conflicted_state
+
+
+@defer.inlineCallbacks
+def _resolve_with_state_fac(unconflicted_state, conflicted_state,
+                            state_map_factory):
+    needed_events = set(
+        event_id
+        for event_ids in conflicted_state.itervalues()
+        for event_id in event_ids
+    )
+
+    logger.info("Asking for %d conflicted events", len(needed_events))
+
+    state_map = yield state_map_factory(needed_events)
+
+    auth_events = _create_auth_events_from_maps(
+        unconflicted_state, conflicted_state, state_map
+    )
+
+    new_needed_events = set(auth_events.itervalues())
+    new_needed_events -= needed_events
+
+    logger.info("Asking for %d auth events", len(new_needed_events))
+
+    state_map_new = yield state_map_factory(new_needed_events)
+    state_map.update(state_map_new)
+
+    defer.returnValue(_resolve_with_state(
+        unconflicted_state, conflicted_state, auth_events, state_map
+    ))
+
+
+def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_map):
+    auth_events = {}
+    for event_ids in conflicted_state.itervalues():
+        for event_id in event_ids:
+            if event_id in state_map:
+                keys = event_auth.auth_types_for_event(state_map[event_id])
+                for key in keys:
+                    if key not in auth_events:
+                        event_id = unconflicted_state.get(key, None)
+                        if event_id:
+                            auth_events[key] = event_id
+    return auth_events
+
+
+def _resolve_with_state(unconflicted_state_ids, conflicted_state_ds, auth_event_ids,
+                        state_map):
+    conflicted_state = {}
+    for key, event_ids in conflicted_state_ds.iteritems():
+        events = [state_map[ev_id] for ev_id in event_ids if ev_id in state_map]
+        if len(events) > 1:
+            conflicted_state[key] = events
+        elif len(events) == 1:
+            unconflicted_state_ids[key] = events[0].event_id
+
+    auth_events = {
+        key: state_map[ev_id]
+        for key, ev_id in auth_event_ids.items()
+        if ev_id in state_map
+    }
+
+    try:
+        resolved_state = _resolve_state_events(
+            conflicted_state, auth_events
+        )
+    except:
+        logger.exception("Failed to resolve state")
+        raise
+
+    new_state = unconflicted_state_ids
+    for key, event in resolved_state.iteritems():
+        new_state[key] = event.event_id
+
+    return new_state
+
+
+def _resolve_state_events(conflicted_state, auth_events):
+    """ This is where we actually decide which of the conflicted state to
+    use.
+
+    We resolve conflicts in the following order:
+        1. power levels
+        2. join rules
+        3. memberships
+        4. other events.
+    """
+    resolved_state = {}
+    if POWER_KEY in conflicted_state:
+        events = conflicted_state[POWER_KEY]
+        logger.debug("Resolving conflicted power levels %r", events)
+        resolved_state[POWER_KEY] = _resolve_auth_events(
+            events, auth_events)
+
+    auth_events.update(resolved_state)
+
+    for key, events in conflicted_state.items():
+        if key[0] == EventTypes.JoinRules:
+            logger.debug("Resolving conflicted join rules %r", events)
+            resolved_state[key] = _resolve_auth_events(
+                events,
+                auth_events
+            )
+
+    auth_events.update(resolved_state)
+
+    for key, events in conflicted_state.items():
+        if key[0] == EventTypes.Member:
+            logger.debug("Resolving conflicted member lists %r", events)
+            resolved_state[key] = _resolve_auth_events(
+                events,
+                auth_events
+            )
+
+    auth_events.update(resolved_state)
+
+    for key, events in conflicted_state.items():
+        if key not in resolved_state:
+            logger.debug("Resolving conflicted state %r:%r", key, events)
+            resolved_state[key] = _resolve_normal_events(
+                events, auth_events
+            )
+
+    return resolved_state
+
+
+def _resolve_auth_events(events, auth_events):
+    reverse = [i for i in reversed(_ordered_events(events))]
+
+    auth_keys = set(
+        key
+        for event in events
+        for key in event_auth.auth_types_for_event(event)
+    )
+
+    new_auth_events = {}
+    for key in auth_keys:
+        auth_event = auth_events.get(key, None)
+        if auth_event:
+            new_auth_events[key] = auth_event
+
+    auth_events = new_auth_events
+
+    prev_event = reverse[0]
+    for event in reverse[1:]:
+        auth_events[(prev_event.type, prev_event.state_key)] = prev_event
+        try:
+            # The signatures have already been checked at this point
+            event_auth.check(event, auth_events, do_sig_check=False, do_size_check=False)
+            prev_event = event
+        except AuthError:
+            return prev_event
+
+    return event
+
+
+def _resolve_normal_events(events, auth_events):
+    for event in _ordered_events(events):
+        try:
+            # The signatures have already been checked at this point
+            event_auth.check(event, auth_events, do_sig_check=False, do_size_check=False)
+            return event
+        except AuthError:
+            pass
+
+    # Use the last event (the one with the least depth) if they all fail
+    # the auth check.
+    return event
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 5620a655eb..963ef999d5 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -169,7 +169,7 @@ class SQLBaseStore(object):
                                       max_entries=hs.config.event_cache_size)
 
         self._state_group_cache = DictionaryCache(
-            "*stateGroupCache*", 2000 * CACHE_SIZE_FACTOR
+            "*stateGroupCache*", 100000 * CACHE_SIZE_FACTOR
         )
 
         self._event_fetch_lock = threading.Condition()
diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py
index 2821eb89c9..bde3b5cbbc 100644
--- a/synapse/storage/deviceinbox.py
+++ b/synapse/storage/deviceinbox.py
@@ -18,13 +18,29 @@ import ujson
 
 from twisted.internet import defer
 
-from ._base import SQLBaseStore
+from .background_updates import BackgroundUpdateStore
 
 
 logger = logging.getLogger(__name__)
 
 
-class DeviceInboxStore(SQLBaseStore):
+class DeviceInboxStore(BackgroundUpdateStore):
+    DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
+
+    def __init__(self, hs):
+        super(DeviceInboxStore, self).__init__(hs)
+
+        self.register_background_index_update(
+            "device_inbox_stream_index",
+            index_name="device_inbox_stream_id_user_id",
+            table="device_inbox",
+            columns=["stream_id", "user_id"],
+        )
+
+        self.register_background_update_handler(
+            self.DEVICE_INBOX_STREAM_ID,
+            self._background_drop_index_device_inbox,
+        )
 
     @defer.inlineCallbacks
     def add_messages_to_device_inbox(self, local_messages_by_user_then_device,
@@ -368,3 +384,18 @@ class DeviceInboxStore(SQLBaseStore):
             "delete_device_msgs_for_remote",
             delete_messages_for_remote_destination_txn
         )
+
+    @defer.inlineCallbacks
+    def _background_drop_index_device_inbox(self, progress, batch_size):
+        def reindex_txn(conn):
+            txn = conn.cursor()
+            txn.execute(
+                "DROP INDEX IF EXISTS device_inbox_stream_id"
+            )
+            txn.close()
+
+        yield self.runWithConnection(reindex_txn)
+
+        yield self._end_background_update(self.DEVICE_INBOX_STREAM_ID)
+
+        defer.returnValue(1)
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 04dbdac3f8..ca501932f3 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -1084,10 +1084,10 @@ class EventsStore(SQLBaseStore):
                     self._do_fetch
                 )
 
-        logger.info("Loading %d events", len(events))
+        logger.debug("Loading %d events", len(events))
         with PreserveLoggingContext():
             rows = yield events_d
-        logger.info("Loaded %d events (%d rows)", len(events), len(rows))
+        logger.debug("Loaded %d events (%d rows)", len(events), len(rows))
 
         if not allow_rejected:
             rows[:] = [r for r in rows if not r["rejects"]]
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index e46ae6502e..b357f22be7 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
 
 # Remember to update this number every time a change is made to database
 # schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 39
+SCHEMA_VERSION = 40
 
 dir_path = os.path.abspath(os.path.dirname(__file__))
 
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 5d18037c7c..768e0a4451 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -390,7 +390,8 @@ class RoomMemberStore(SQLBaseStore):
             room_id, state_group, state_ids,
         )
 
-    @cachedInlineCallbacks(num_args=2, cache_context=True)
+    @cachedInlineCallbacks(num_args=2, cache_context=True, iterable=True,
+                           max_entries=100000)
     def _get_joined_users_from_context(self, room_id, state_group, current_state_ids,
                                        cache_context, event=None):
         # We don't use `state_group`, it's there so that we can cache based
diff --git a/synapse/storage/schema/delta/40/device_inbox.sql b/synapse/storage/schema/delta/40/device_inbox.sql
new file mode 100644
index 0000000000..b9fe1f0480
--- /dev/null
+++ b/synapse/storage/schema/delta/40/device_inbox.sql
@@ -0,0 +1,21 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- turn the pre-fill startup query into a index-only scan on postgresql.
+INSERT into background_updates (update_name, progress_json)
+    VALUES ('device_inbox_stream_index', '{}');
+
+INSERT into background_updates (update_name, progress_json, depends_on)
+    VALUES ('device_inbox_stream_drop', '{}', 'device_inbox_stream_index');
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 7f466c40ac..7d34dd03bf 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -284,7 +284,7 @@ class StateStore(SQLBaseStore):
             return [r[0] for r in results]
         return self.runInteraction("get_current_state_for_key", f)
 
-    @cached(num_args=2, max_entries=1000)
+    @cached(num_args=2, max_entries=100000, iterable=True)
     def _get_state_group_from_group(self, group, types):
         raise NotImplementedError()
 
diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py
index ebd715c5dc..8a7774a88e 100644
--- a/synapse/util/caches/__init__.py
+++ b/synapse/util/caches/__init__.py
@@ -40,8 +40,8 @@ def register_cache(name, cache):
     )
 
 
-_string_cache = LruCache(int(5000 * CACHE_SIZE_FACTOR))
-caches_by_name["string_cache"] = _string_cache
+_string_cache = LruCache(int(100000 * CACHE_SIZE_FACTOR))
+_stirng_cache_metrics = register_cache("string_cache", _string_cache)
 
 
 KNOWN_KEYS = {
@@ -69,7 +69,12 @@ KNOWN_KEYS = {
 def intern_string(string):
     """Takes a (potentially) unicode string and interns using custom cache
     """
-    return _string_cache.setdefault(string, string)
+    new_str = _string_cache.setdefault(string, string)
+    if new_str is string:
+        _stirng_cache_metrics.inc_hits()
+    else:
+        _stirng_cache_metrics.inc_misses()
+    return new_str
 
 
 def intern_dict(dictionary):
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 8dba61d49f..675bfd5feb 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -17,7 +17,7 @@ import logging
 from synapse.util.async import ObservableDeferred
 from synapse.util import unwrapFirstError
 from synapse.util.caches.lrucache import LruCache
-from synapse.util.caches.treecache import TreeCache
+from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
 from synapse.util.logcontext import (
     PreserveLoggingContext, preserve_context_over_deferred, preserve_context_over_fn
 )
@@ -42,6 +42,25 @@ _CacheSentinel = object()
 CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))
 
 
+class CacheEntry(object):
+    __slots__ = [
+        "deferred", "sequence", "callbacks", "invalidated"
+    ]
+
+    def __init__(self, deferred, sequence, callbacks):
+        self.deferred = deferred
+        self.sequence = sequence
+        self.callbacks = set(callbacks)
+        self.invalidated = False
+
+    def invalidate(self):
+        if not self.invalidated:
+            self.invalidated = True
+            for callback in self.callbacks:
+                callback()
+            self.callbacks.clear()
+
+
 class Cache(object):
     __slots__ = (
         "cache",
@@ -51,12 +70,16 @@ class Cache(object):
         "sequence",
         "thread",
         "metrics",
+        "_pending_deferred_cache",
     )
 
-    def __init__(self, name, max_entries=1000, keylen=1, tree=False):
+    def __init__(self, name, max_entries=1000, keylen=1, tree=False, iterable=False):
         cache_type = TreeCache if tree else dict
+        self._pending_deferred_cache = cache_type()
+
         self.cache = LruCache(
-            max_size=max_entries, keylen=keylen, cache_type=cache_type
+            max_size=max_entries, keylen=keylen, cache_type=cache_type,
+            size_callback=(lambda d: len(d.result)) if iterable else None,
         )
 
         self.name = name
@@ -76,7 +99,15 @@ class Cache(object):
                 )
 
     def get(self, key, default=_CacheSentinel, callback=None):
-        val = self.cache.get(key, _CacheSentinel, callback=callback)
+        callbacks = [callback] if callback else []
+        val = self._pending_deferred_cache.get(key, _CacheSentinel)
+        if val is not _CacheSentinel:
+            if val.sequence == self.sequence:
+                val.callbacks.update(callbacks)
+                self.metrics.inc_hits()
+                return val.deferred
+
+        val = self.cache.get(key, _CacheSentinel, callbacks=callbacks)
         if val is not _CacheSentinel:
             self.metrics.inc_hits()
             return val
@@ -88,15 +119,39 @@ class Cache(object):
         else:
             return default
 
-    def update(self, sequence, key, value, callback=None):
+    def set(self, key, value, callback=None):
+        callbacks = [callback] if callback else []
         self.check_thread()
-        if self.sequence == sequence:
-            # Only update the cache if the caches sequence number matches the
-            # number that the cache had before the SELECT was started (SYN-369)
-            self.prefill(key, value, callback=callback)
+        entry = CacheEntry(
+            deferred=value,
+            sequence=self.sequence,
+            callbacks=callbacks,
+        )
+
+        entry.callbacks.update(callbacks)
+
+        existing_entry = self._pending_deferred_cache.pop(key, None)
+        if existing_entry:
+            existing_entry.invalidate()
+
+        self._pending_deferred_cache[key] = entry
+
+        def shuffle(result):
+            if self.sequence == entry.sequence:
+                existing_entry = self._pending_deferred_cache.pop(key, None)
+                if existing_entry is entry:
+                    self.cache.set(key, entry.deferred, entry.callbacks)
+                else:
+                    entry.invalidate()
+            else:
+                entry.invalidate()
+            return result
+
+        entry.deferred.addCallback(shuffle)
 
     def prefill(self, key, value, callback=None):
-        self.cache.set(key, value, callback=callback)
+        callbacks = [callback] if callback else []
+        self.cache.set(key, value, callbacks=callbacks)
 
     def invalidate(self, key):
         self.check_thread()
@@ -108,6 +163,10 @@ class Cache(object):
         # Increment the sequence number so that any SELECT statements that
         # raced with the INSERT don't update the cache (SYN-369)
         self.sequence += 1
+        entry = self._pending_deferred_cache.pop(key, None)
+        if entry:
+            entry.invalidate()
+
         self.cache.pop(key, None)
 
     def invalidate_many(self, key):
@@ -119,6 +178,11 @@ class Cache(object):
         self.sequence += 1
         self.cache.del_multi(key)
 
+        entry_dict = self._pending_deferred_cache.pop(key, None)
+        if entry_dict is not None:
+            for entry in iterate_tree_cache_entry(entry_dict):
+                entry.invalidate()
+
     def invalidate_all(self):
         self.check_thread()
         self.sequence += 1
@@ -155,7 +219,7 @@ class CacheDescriptor(object):
 
     """
     def __init__(self, orig, max_entries=1000, num_args=1, tree=False,
-                 inlineCallbacks=False, cache_context=False):
+                 inlineCallbacks=False, cache_context=False, iterable=False):
         max_entries = int(max_entries * CACHE_SIZE_FACTOR)
 
         self.orig = orig
@@ -169,6 +233,8 @@ class CacheDescriptor(object):
         self.num_args = num_args
         self.tree = tree
 
+        self.iterable = iterable
+
         all_args = inspect.getargspec(orig)
         self.arg_names = all_args.args[1:num_args + 1]
 
@@ -203,6 +269,7 @@ class CacheDescriptor(object):
             max_entries=self.max_entries,
             keylen=self.num_args,
             tree=self.tree,
+            iterable=self.iterable,
         )
 
         @functools.wraps(self.orig)
@@ -243,11 +310,6 @@ class CacheDescriptor(object):
 
                 return preserve_context_over_deferred(observer)
             except KeyError:
-                # Get the sequence number of the cache before reading from the
-                # database so that we can tell if the cache is invalidated
-                # while the SELECT is executing (SYN-369)
-                sequence = cache.sequence
-
                 ret = defer.maybeDeferred(
                     preserve_context_over_fn,
                     self.function_to_call,
@@ -261,7 +323,7 @@ class CacheDescriptor(object):
                 ret.addErrback(onErr)
 
                 ret = ObservableDeferred(ret, consumeErrors=True)
-                cache.update(sequence, cache_key, ret, callback=invalidate_callback)
+                cache.set(cache_key, ret, callback=invalidate_callback)
 
                 return preserve_context_over_deferred(ret.observe())
 
@@ -359,7 +421,6 @@ class CacheListDescriptor(object):
                     missing.append(arg)
 
             if missing:
-                sequence = cache.sequence
                 args_to_call = dict(arg_dict)
                 args_to_call[self.list_name] = missing
 
@@ -382,8 +443,8 @@ class CacheListDescriptor(object):
 
                     key = list(keyargs)
                     key[self.list_pos] = arg
-                    cache.update(
-                        sequence, tuple(key), observer,
+                    cache.set(
+                        tuple(key), observer,
                         callback=invalidate_callback
                     )
 
@@ -421,17 +482,20 @@ class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))):
         self.cache.invalidate(self.key)
 
 
-def cached(max_entries=1000, num_args=1, tree=False, cache_context=False):
+def cached(max_entries=1000, num_args=1, tree=False, cache_context=False,
+           iterable=False):
     return lambda orig: CacheDescriptor(
         orig,
         max_entries=max_entries,
         num_args=num_args,
         tree=tree,
         cache_context=cache_context,
+        iterable=iterable,
     )
 
 
-def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_context=False):
+def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_context=False,
+                          iterable=False):
     return lambda orig: CacheDescriptor(
         orig,
         max_entries=max_entries,
@@ -439,6 +503,7 @@ def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_contex
         tree=tree,
         inlineCallbacks=True,
         cache_context=cache_context,
+        iterable=iterable,
     )
 
 
diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py
index b0ca1bb79d..cb6933c61c 100644
--- a/synapse/util/caches/dictionary_cache.py
+++ b/synapse/util/caches/dictionary_cache.py
@@ -23,7 +23,9 @@ import logging
 logger = logging.getLogger(__name__)
 
 
-DictionaryEntry = namedtuple("DictionaryEntry", ("full", "value"))
+class DictionaryEntry(namedtuple("DictionaryEntry", ("full", "value"))):
+    def __len__(self):
+        return len(self.value)
 
 
 class DictionaryCache(object):
@@ -32,7 +34,7 @@ class DictionaryCache(object):
     """
 
     def __init__(self, name, max_entries=1000):
-        self.cache = LruCache(max_size=max_entries)
+        self.cache = LruCache(max_size=max_entries, size_callback=len)
 
         self.name = name
         self.sequence = 0
diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py
index 080388958f..2987c38a2d 100644
--- a/synapse/util/caches/expiringcache.py
+++ b/synapse/util/caches/expiringcache.py
@@ -15,6 +15,7 @@
 
 from synapse.util.caches import register_cache
 
+from collections import OrderedDict
 import logging
 
 
@@ -23,7 +24,7 @@ logger = logging.getLogger(__name__)
 
 class ExpiringCache(object):
     def __init__(self, cache_name, clock, max_len=0, expiry_ms=0,
-                 reset_expiry_on_get=False):
+                 reset_expiry_on_get=False, iterable=False):
         """
         Args:
             cache_name (str): Name of this cache, used for logging.
@@ -36,6 +37,8 @@ class ExpiringCache(object):
                 evicted based on time.
             reset_expiry_on_get (bool): If true, will reset the expiry time for
                 an item on access. Defaults to False.
+            iterable (bool): If true, the size is calculated by summing the
+                sizes of all entries, rather than the number of entries.
 
         """
         self._cache_name = cache_name
@@ -47,9 +50,13 @@ class ExpiringCache(object):
 
         self._reset_expiry_on_get = reset_expiry_on_get
 
-        self._cache = {}
+        self._cache = OrderedDict()
 
-        self.metrics = register_cache(cache_name, self._cache)
+        self.metrics = register_cache(cache_name, self)
+
+        self.iterable = iterable
+
+        self._size_estimate = 0
 
     def start(self):
         if not self._expiry_ms:
@@ -65,15 +72,14 @@ class ExpiringCache(object):
         now = self._clock.time_msec()
         self._cache[key] = _CacheEntry(now, value)
 
-        # Evict if there are now too many items
-        if self._max_len and len(self._cache.keys()) > self._max_len:
-            sorted_entries = sorted(
-                self._cache.items(),
-                key=lambda item: item[1].time,
-            )
+        if self.iterable:
+            self._size_estimate += len(value)
 
-            for k, _ in sorted_entries[self._max_len:]:
-                self._cache.pop(k)
+        # Evict if there are now too many items
+        while self._max_len and len(self) > self._max_len:
+            _key, value = self._cache.popitem(last=False)
+            if self.iterable:
+                self._size_estimate -= len(value.value)
 
     def __getitem__(self, key):
         try:
@@ -99,7 +105,7 @@ class ExpiringCache(object):
             # zero expiry time means don't expire. This should never get called
             # since we have this check in start too.
             return
-        begin_length = len(self._cache)
+        begin_length = len(self)
 
         now = self._clock.time_msec()
 
@@ -110,15 +116,20 @@ class ExpiringCache(object):
                 keys_to_delete.add(key)
 
         for k in keys_to_delete:
-            self._cache.pop(k)
+            value = self._cache.pop(k)
+            if self.iterable:
+                self._size_estimate -= len(value.value)
 
         logger.debug(
             "[%s] _prune_cache before: %d, after len: %d",
-            self._cache_name, begin_length, len(self._cache)
+            self._cache_name, begin_length, len(self)
         )
 
     def __len__(self):
-        return len(self._cache)
+        if self.iterable:
+            return self._size_estimate
+        else:
+            return len(self._cache)
 
 
 class _CacheEntry(object):
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index 9c4c679175..072f9a9d19 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -49,7 +49,7 @@ class LruCache(object):
     Can also set callbacks on objects when getting/setting which are fired
     when that key gets invalidated/evicted.
     """
-    def __init__(self, max_size, keylen=1, cache_type=dict):
+    def __init__(self, max_size, keylen=1, cache_type=dict, size_callback=None):
         cache = cache_type()
         self.cache = cache  # Used for introspection.
         list_root = _Node(None, None, None, None)
@@ -58,6 +58,12 @@ class LruCache(object):
 
         lock = threading.Lock()
 
+        def evict():
+            while cache_len() > max_size:
+                todelete = list_root.prev_node
+                delete_node(todelete)
+                cache.pop(todelete.key, None)
+
         def synchronized(f):
             @wraps(f)
             def inner(*args, **kwargs):
@@ -66,6 +72,16 @@ class LruCache(object):
 
             return inner
 
+        cached_cache_len = [0]
+        if size_callback is not None:
+            def cache_len():
+                return cached_cache_len[0]
+        else:
+            def cache_len():
+                return len(cache)
+
+        self.len = synchronized(cache_len)
+
         def add_node(key, value, callbacks=set()):
             prev_node = list_root
             next_node = prev_node.next_node
@@ -74,6 +90,9 @@ class LruCache(object):
             next_node.prev_node = node
             cache[key] = node
 
+            if size_callback:
+                cached_cache_len[0] += size_callback(node.value)
+
         def move_node_to_front(node):
             prev_node = node.prev_node
             next_node = node.next_node
@@ -92,23 +111,25 @@ class LruCache(object):
             prev_node.next_node = next_node
             next_node.prev_node = prev_node
 
+            if size_callback:
+                cached_cache_len[0] -= size_callback(node.value)
+
             for cb in node.callbacks:
                 cb()
             node.callbacks.clear()
 
         @synchronized
-        def cache_get(key, default=None, callback=None):
+        def cache_get(key, default=None, callbacks=[]):
             node = cache.get(key, None)
             if node is not None:
                 move_node_to_front(node)
-                if callback:
-                    node.callbacks.add(callback)
+                node.callbacks.update(callbacks)
                 return node.value
             else:
                 return default
 
         @synchronized
-        def cache_set(key, value, callback=None):
+        def cache_set(key, value, callbacks=[]):
             node = cache.get(key, None)
             if node is not None:
                 if value != node.value:
@@ -116,21 +137,18 @@ class LruCache(object):
                         cb()
                     node.callbacks.clear()
 
-                if callback:
-                    node.callbacks.add(callback)
+                    if size_callback:
+                        cached_cache_len[0] -= size_callback(node.value)
+                        cached_cache_len[0] += size_callback(value)
+
+                node.callbacks.update(callbacks)
 
                 move_node_to_front(node)
                 node.value = value
             else:
-                if callback:
-                    callbacks = set([callback])
-                else:
-                    callbacks = set()
-                add_node(key, value, callbacks)
-                if len(cache) > max_size:
-                    todelete = list_root.prev_node
-                    delete_node(todelete)
-                    cache.pop(todelete.key, None)
+                add_node(key, value, set(callbacks))
+
+            evict()
 
         @synchronized
         def cache_set_default(key, value):
@@ -139,10 +157,7 @@ class LruCache(object):
                 return node.value
             else:
                 add_node(key, value)
-                if len(cache) > max_size:
-                    todelete = list_root.prev_node
-                    delete_node(todelete)
-                    cache.pop(todelete.key, None)
+                evict()
                 return value
 
         @synchronized
@@ -176,10 +191,6 @@ class LruCache(object):
             cache.clear()
 
         @synchronized
-        def cache_len():
-            return len(cache)
-
-        @synchronized
         def cache_contains(key):
             return key in cache
 
@@ -190,7 +201,7 @@ class LruCache(object):
         self.pop = cache_pop
         if cache_type is TreeCache:
             self.del_multi = cache_del_multi
-        self.len = cache_len
+        self.len = synchronized(cache_len)
         self.contains = cache_contains
         self.clear = cache_clear
 
diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py
index c31585aea3..fcc341a6b7 100644
--- a/synapse/util/caches/treecache.py
+++ b/synapse/util/caches/treecache.py
@@ -65,12 +65,27 @@ class TreeCache(object):
         return popped
 
     def values(self):
-        return [e.value for e in self.root.values()]
+        return list(iterate_tree_cache_entry(self.root))
 
     def __len__(self):
         return self.size
 
 
+def iterate_tree_cache_entry(d):
+    """Helper function to iterate over the leaves of a tree, i.e. a dict of that
+    can contain dicts.
+    """
+    if isinstance(d, dict):
+        for value_d in d.itervalues():
+            for value in iterate_tree_cache_entry(value_d):
+                yield value
+    else:
+        if isinstance(d, _Entry):
+            yield d.value
+        else:
+            yield d
+
+
 class _Entry(object):
     __slots__ = ["value"]
 
diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py
index dcb6c5bc31..50e8607c14 100644
--- a/tests/api/test_filtering.py
+++ b/tests/api/test_filtering.py
@@ -25,10 +25,13 @@ from synapse.api.filtering import Filter
 from synapse.events import FrozenEvent
 
 user_localpart = "test_user"
-# MockEvent = namedtuple("MockEvent", "sender type room_id")
 
 
 def MockEvent(**kwargs):
+    if "event_id" not in kwargs:
+        kwargs["event_id"] = "fake_event_id"
+    if "type" not in kwargs:
+        kwargs["type"] = "fake_type"
     return FrozenEvent(kwargs)
 
 
diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py
index 29f068d1f1..dfc870066e 100644
--- a/tests/events/test_utils.py
+++ b/tests/events/test_utils.py
@@ -21,6 +21,10 @@ from synapse.events.utils import prune_event, serialize_event
 
 
 def MockEvent(**kwargs):
+    if "event_id" not in kwargs:
+        kwargs["event_id"] = "fake_event_id"
+    if "type" not in kwargs:
+        kwargs["type"] = "fake_type"
     return FrozenEvent(kwargs)
 
 
@@ -35,9 +39,13 @@ class PruneEventTestCase(unittest.TestCase):
 
     def test_minimal(self):
         self.run_test(
-            {'type': 'A'},
             {
                 'type': 'A',
+                'event_id': '$test:domain',
+            },
+            {
+                'type': 'A',
+                'event_id': '$test:domain',
                 'content': {},
                 'signatures': {},
                 'unsigned': {},
@@ -69,10 +77,12 @@ class PruneEventTestCase(unittest.TestCase):
         self.run_test(
             {
                 'type': 'B',
+                'event_id': '$test:domain',
                 'unsigned': {'age_ts': 20},
             },
             {
                 'type': 'B',
+                'event_id': '$test:domain',
                 'content': {},
                 'signatures': {},
                 'unsigned': {'age_ts': 20},
@@ -82,10 +92,12 @@ class PruneEventTestCase(unittest.TestCase):
         self.run_test(
             {
                 'type': 'B',
+                'event_id': '$test:domain',
                 'unsigned': {'other_key': 'here'},
             },
             {
                 'type': 'B',
+                'event_id': '$test:domain',
                 'content': {},
                 'signatures': {},
                 'unsigned': {},
@@ -96,10 +108,12 @@ class PruneEventTestCase(unittest.TestCase):
         self.run_test(
             {
                 'type': 'C',
+                'event_id': '$test:domain',
                 'content': {'things': 'here'},
             },
             {
                 'type': 'C',
+                'event_id': '$test:domain',
                 'content': {},
                 'signatures': {},
                 'unsigned': {},
@@ -109,10 +123,12 @@ class PruneEventTestCase(unittest.TestCase):
         self.run_test(
             {
                 'type': 'm.room.create',
+                'event_id': '$test:domain',
                 'content': {'creator': '@2:domain', 'other_field': 'here'},
             },
             {
                 'type': 'm.room.create',
+                'event_id': '$test:domain',
                 'content': {'creator': '@2:domain'},
                 'signatures': {},
                 'unsigned': {},
@@ -255,6 +271,8 @@ class SerializeEventTestCase(unittest.TestCase):
         self.assertEquals(
             self.serialize(
                 MockEvent(
+                    type="foo",
+                    event_id="test",
                     room_id="!foo:bar",
                     content={
                         "foo": "bar",
@@ -263,6 +281,8 @@ class SerializeEventTestCase(unittest.TestCase):
                 []
             ),
             {
+                "type": "foo",
+                "event_id": "test",
                 "room_id": "!foo:bar",
                 "content": {
                     "foo": "bar",
diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index ab6095564a..8361dd8cee 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -241,7 +241,7 @@ class CacheDecoratorTestCase(unittest.TestCase):
         callcount2 = [0]
 
         class A(object):
-            @cached(max_entries=2)
+            @cached(max_entries=20)  # HACK: This makes it 2 due to cache factor
             def func(self, key):
                 callcount[0] += 1
                 return key
@@ -258,6 +258,10 @@ class CacheDecoratorTestCase(unittest.TestCase):
         self.assertEquals(callcount[0], 2)
         self.assertEquals(callcount2[0], 2)
 
+        yield a.func2("foo")
+        self.assertEquals(callcount[0], 2)
+        self.assertEquals(callcount2[0], 2)
+
         yield a.func("foo3")
 
         self.assertEquals(callcount[0], 3)
diff --git a/tests/util/test_expiring_cache.py b/tests/util/test_expiring_cache.py
new file mode 100644
index 0000000000..31d24adb8b
--- /dev/null
+++ b/tests/util/test_expiring_cache.py
@@ -0,0 +1,84 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from .. import unittest
+
+from synapse.util.caches.expiringcache import ExpiringCache
+
+from tests.utils import MockClock
+
+
+class ExpiringCacheTestCase(unittest.TestCase):
+
+    def test_get_set(self):
+        clock = MockClock()
+        cache = ExpiringCache("test", clock, max_len=1)
+
+        cache["key"] = "value"
+        self.assertEquals(cache.get("key"), "value")
+        self.assertEquals(cache["key"], "value")
+
+    def test_eviction(self):
+        clock = MockClock()
+        cache = ExpiringCache("test", clock, max_len=2)
+
+        cache["key"] = "value"
+        cache["key2"] = "value2"
+        self.assertEquals(cache.get("key"), "value")
+        self.assertEquals(cache.get("key2"), "value2")
+
+        cache["key3"] = "value3"
+        self.assertEquals(cache.get("key"), None)
+        self.assertEquals(cache.get("key2"), "value2")
+        self.assertEquals(cache.get("key3"), "value3")
+
+    def test_iterable_eviction(self):
+        clock = MockClock()
+        cache = ExpiringCache("test", clock, max_len=5, iterable=True)
+
+        cache["key"] = [1]
+        cache["key2"] = [2, 3]
+        cache["key3"] = [4, 5]
+
+        self.assertEquals(cache.get("key"), [1])
+        self.assertEquals(cache.get("key2"), [2, 3])
+        self.assertEquals(cache.get("key3"), [4, 5])
+
+        cache["key4"] = [6, 7]
+        self.assertEquals(cache.get("key"), None)
+        self.assertEquals(cache.get("key2"), None)
+        self.assertEquals(cache.get("key3"), [4, 5])
+        self.assertEquals(cache.get("key4"), [6, 7])
+
+    def test_time_eviction(self):
+        clock = MockClock()
+        cache = ExpiringCache("test", clock, expiry_ms=1000)
+        cache.start()
+
+        cache["key"] = 1
+        clock.advance_time(0.5)
+        cache["key2"] = 2
+
+        self.assertEquals(cache.get("key"), 1)
+        self.assertEquals(cache.get("key2"), 2)
+
+        clock.advance_time(0.9)
+        self.assertEquals(cache.get("key"), None)
+        self.assertEquals(cache.get("key2"), 2)
+
+        clock.advance_time(1)
+        self.assertEquals(cache.get("key"), None)
+        self.assertEquals(cache.get("key2"), None)
diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py
index 1eba5b535e..dfb78cb8bd 100644
--- a/tests/util/test_lrucache.py
+++ b/tests/util/test_lrucache.py
@@ -93,7 +93,7 @@ class LruCacheCallbacksTestCase(unittest.TestCase):
         cache.set("key", "value")
         self.assertFalse(m.called)
 
-        cache.get("key", callback=m)
+        cache.get("key", callbacks=[m])
         self.assertFalse(m.called)
 
         cache.get("key", "value")
@@ -112,10 +112,10 @@ class LruCacheCallbacksTestCase(unittest.TestCase):
         cache.set("key", "value")
         self.assertFalse(m.called)
 
-        cache.get("key", callback=m)
+        cache.get("key", callbacks=[m])
         self.assertFalse(m.called)
 
-        cache.get("key", callback=m)
+        cache.get("key", callbacks=[m])
         self.assertFalse(m.called)
 
         cache.set("key", "value2")
@@ -128,7 +128,7 @@ class LruCacheCallbacksTestCase(unittest.TestCase):
         m = Mock()
         cache = LruCache(1)
 
-        cache.set("key", "value", m)
+        cache.set("key", "value", callbacks=[m])
         self.assertFalse(m.called)
 
         cache.set("key", "value")
@@ -144,7 +144,7 @@ class LruCacheCallbacksTestCase(unittest.TestCase):
         m = Mock()
         cache = LruCache(1)
 
-        cache.set("key", "value", m)
+        cache.set("key", "value", callbacks=[m])
         self.assertFalse(m.called)
 
         cache.pop("key")
@@ -163,10 +163,10 @@ class LruCacheCallbacksTestCase(unittest.TestCase):
         m4 = Mock()
         cache = LruCache(4, 2, cache_type=TreeCache)
 
-        cache.set(("a", "1"), "value", m1)
-        cache.set(("a", "2"), "value", m2)
-        cache.set(("b", "1"), "value", m3)
-        cache.set(("b", "2"), "value", m4)
+        cache.set(("a", "1"), "value", callbacks=[m1])
+        cache.set(("a", "2"), "value", callbacks=[m2])
+        cache.set(("b", "1"), "value", callbacks=[m3])
+        cache.set(("b", "2"), "value", callbacks=[m4])
 
         self.assertEquals(m1.call_count, 0)
         self.assertEquals(m2.call_count, 0)
@@ -185,8 +185,8 @@ class LruCacheCallbacksTestCase(unittest.TestCase):
         m2 = Mock()
         cache = LruCache(5)
 
-        cache.set("key1", "value", m1)
-        cache.set("key2", "value", m2)
+        cache.set("key1", "value", callbacks=[m1])
+        cache.set("key2", "value", callbacks=[m2])
 
         self.assertEquals(m1.call_count, 0)
         self.assertEquals(m2.call_count, 0)
@@ -202,14 +202,14 @@ class LruCacheCallbacksTestCase(unittest.TestCase):
         m3 = Mock(name="m3")
         cache = LruCache(2)
 
-        cache.set("key1", "value", m1)
-        cache.set("key2", "value", m2)
+        cache.set("key1", "value", callbacks=[m1])
+        cache.set("key2", "value", callbacks=[m2])
 
         self.assertEquals(m1.call_count, 0)
         self.assertEquals(m2.call_count, 0)
         self.assertEquals(m3.call_count, 0)
 
-        cache.set("key3", "value", m3)
+        cache.set("key3", "value", callbacks=[m3])
 
         self.assertEquals(m1.call_count, 1)
         self.assertEquals(m2.call_count, 0)
@@ -227,8 +227,33 @@ class LruCacheCallbacksTestCase(unittest.TestCase):
         self.assertEquals(m2.call_count, 0)
         self.assertEquals(m3.call_count, 0)
 
-        cache.set("key1", "value", m1)
+        cache.set("key1", "value", callbacks=[m1])
 
         self.assertEquals(m1.call_count, 1)
         self.assertEquals(m2.call_count, 0)
         self.assertEquals(m3.call_count, 1)
+
+
+class LruCacheSizedTestCase(unittest.TestCase):
+
+    def test_evict(self):
+        cache = LruCache(5, size_callback=len)
+        cache["key1"] = [0]
+        cache["key2"] = [1, 2]
+        cache["key3"] = [3]
+        cache["key4"] = [4]
+
+        self.assertEquals(cache["key1"], [0])
+        self.assertEquals(cache["key2"], [1, 2])
+        self.assertEquals(cache["key3"], [3])
+        self.assertEquals(cache["key4"], [4])
+        self.assertEquals(len(cache), 5)
+
+        cache["key5"] = [5, 6]
+
+        self.assertEquals(len(cache), 4)
+        self.assertEquals(cache.get("key1"), None)
+        self.assertEquals(cache.get("key2"), None)
+        self.assertEquals(cache["key3"], [3])
+        self.assertEquals(cache["key4"], [4])
+        self.assertEquals(cache["key5"], [5, 6])