summary refs log tree commit diff
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2016-08-25 17:32:22 +0100
committerErik Johnston <erik@matrix.org>2016-08-25 17:32:22 +0100
commita3dc1e9cbe491aa981b8bbaeb2414b4ec8e5b9ca (patch)
treed7414c7b98aac2aeb1486285cb2774c2273fba1e
parentPull out event ids rather than full events for state (diff)
downloadsynapse-a3dc1e9cbe491aa981b8bbaeb2414b4ec8e5b9ca.tar.xz
Replace context.current_state with context.current_state_ids
Diffstat (limited to '')
-rw-r--r--synapse/api/auth.py68
-rw-r--r--synapse/events/snapshot.py13
-rw-r--r--synapse/handlers/_base.py30
-rw-r--r--synapse/handlers/federation.py112
-rw-r--r--synapse/handlers/message.py91
-rw-r--r--synapse/handlers/room_member.py124
-rw-r--r--synapse/push/action_generator.py4
-rw-r--r--synapse/push/bulk_push_rule_evaluator.py32
-rw-r--r--synapse/state.py48
-rw-r--r--synapse/storage/push_rule.py21
-rw-r--r--synapse/storage/roommember.py45
-rw-r--r--synapse/storage/state.py16
-rw-r--r--synapse/visibility.py19
-rw-r--r--tests/replication/slave/storage/test_events.py9
-rw-r--r--tests/test_state.py73
15 files changed, 435 insertions, 270 deletions
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 0db26fcfd7..40c3e9db0d 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -52,7 +52,7 @@ class Auth(object):
         self.state = hs.get_state_handler()
         self.TOKEN_NOT_FOUND_HTTP_STATUS = 401
         # Docs for these currently lives at
-        # https://github.com/matrix-org/matrix-doc/blob/master/drafts/macaroons_caveats.rst
+        # github.com/matrix-org/matrix-doc/blob/master/drafts/macaroons_caveats.rst
         # In addition, we have type == delete_pusher which grants access only to
         # delete pushers.
         self._KNOWN_CAVEAT_PREFIXES = set([
@@ -63,6 +63,17 @@ class Auth(object):
             "user_id = ",
         ])
 
+    @defer.inlineCallbacks
+    def check_from_context(self, event, context, do_sig_check=True):
+        auth_events_ids = yield self.compute_auth_events(
+            event, context.current_state_ids, for_verification=True,
+        )
+        auth_events = yield self.store.get_events(auth_events_ids)
+        auth_events = {
+            (e.type, e.state_key): e for e in auth_events.values()
+        }
+        self.check(event, auth_events=auth_events, do_sig_check=False)
+
     def check(self, event, auth_events, do_sig_check=True):
         """ Checks if this event is correctly authed.
 
@@ -847,7 +858,7 @@ class Auth(object):
 
     @defer.inlineCallbacks
     def add_auth_events(self, builder, context):
-        auth_ids = self.compute_auth_events(builder, context.current_state)
+        auth_ids = yield self.compute_auth_events(builder, context.current_state_ids)
 
         auth_events_entries = yield self.store.add_event_hashes(
             auth_ids
@@ -855,30 +866,32 @@ class Auth(object):
 
         builder.auth_events = auth_events_entries
 
-    def compute_auth_events(self, event, current_state):
+    @defer.inlineCallbacks
+    def compute_auth_events(self, event, current_state_ids, for_verification=False):
         if event.type == EventTypes.Create:
-            return []
+            defer.returnValue([])
 
         auth_ids = []
 
         key = (EventTypes.PowerLevels, "", )
-        power_level_event = current_state.get(key)
+        power_level_event_id = current_state_ids.get(key)
 
-        if power_level_event:
-            auth_ids.append(power_level_event.event_id)
+        if power_level_event_id:
+            auth_ids.append(power_level_event_id)
 
         key = (EventTypes.JoinRules, "", )
-        join_rule_event = current_state.get(key)
+        join_rule_event_id = current_state_ids.get(key)
 
         key = (EventTypes.Member, event.user_id, )
-        member_event = current_state.get(key)
+        member_event_id = current_state_ids.get(key)
 
         key = (EventTypes.Create, "", )
-        create_event = current_state.get(key)
-        if create_event:
-            auth_ids.append(create_event.event_id)
+        create_event_id = current_state_ids.get(key)
+        if create_event_id:
+            auth_ids.append(create_event_id)
 
-        if join_rule_event:
+        if join_rule_event_id:
+            join_rule_event = yield self.store.get_event(join_rule_event_id)
             join_rule = join_rule_event.content.get("join_rule")
             is_public = join_rule == JoinRules.PUBLIC if join_rule else False
         else:
@@ -887,15 +900,21 @@ class Auth(object):
         if event.type == EventTypes.Member:
             e_type = event.content["membership"]
             if e_type in [Membership.JOIN, Membership.INVITE]:
-                if join_rule_event:
-                    auth_ids.append(join_rule_event.event_id)
+                if join_rule_event_id:
+                    auth_ids.append(join_rule_event_id)
 
             if e_type == Membership.JOIN:
-                if member_event and not is_public:
-                    auth_ids.append(member_event.event_id)
+                if member_event_id and not is_public:
+                    auth_ids.append(member_event_id)
             else:
-                if member_event:
-                    auth_ids.append(member_event.event_id)
+                if member_event_id:
+                    auth_ids.append(member_event_id)
+
+                if for_verification:
+                    key = (EventTypes.Member, event.state_key, )
+                    existing_event_id = current_state_ids.get(key)
+                    if existing_event_id:
+                        auth_ids.append(existing_event_id)
 
             if e_type == Membership.INVITE:
                 if "third_party_invite" in event.content:
@@ -903,14 +922,15 @@ class Auth(object):
                         EventTypes.ThirdPartyInvite,
                         event.content["third_party_invite"]["signed"]["token"]
                     )
-                    third_party_invite = current_state.get(key)
-                    if third_party_invite:
-                        auth_ids.append(third_party_invite.event_id)
-        elif member_event:
+                    third_party_invite_id = current_state_ids.get(key)
+                    if third_party_invite_id:
+                        auth_ids.append(third_party_invite_id)
+        elif member_event_id:
+            member_event = yield self.store.get_event(member_event_id)
             if member_event.content["membership"] == Membership.JOIN:
                 auth_ids.append(member_event.event_id)
 
-        return auth_ids
+        defer.returnValue(auth_ids)
 
     def _get_send_level(self, etype, state_key, auth_events):
         key = (EventTypes.PowerLevels, "", )
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index cf11b4aa2e..c75afd02d8 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -15,17 +15,8 @@
 
 
 class EventContext(object):
-    def _set_current_state(self, current_state):
-        if current_state is not None:
-            self.current_state_ids = {k: e.event_id for k, e in current_state.items()}
-        else:
-            self.current_state_ids = None
-        self._current_state = current_state
-
-    current_state = property(lambda self: self._current_state, _set_current_state)
-
-    def __init__(self, current_state=None):
-        self.current_state = current_state
+    def __init__(self, current_state_ids=None):
+        self.current_state_ids = current_state_ids
         self.state_group = None
         self.rejected = False
         self.push_actions = []
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index 11081a0cd5..e58735294e 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -65,33 +65,21 @@ class BaseHandler(object):
                 retry_after_ms=int(1000 * (time_allowed - time_now)),
             )
 
-    def is_host_in_room(self, current_state):
-        room_members = [
-            (state_key, event.membership)
-            for ((event_type, state_key), event) in current_state.items()
-            if event_type == EventTypes.Member
-        ]
-        if len(room_members) == 0:
-            # Have we just created the room, and is this about to be the very
-            # first member event?
-            create_event = current_state.get(("m.room.create", ""))
-            if create_event:
-                return True
-        for (state_key, membership) in room_members:
-            if (
-                self.hs.is_mine_id(state_key)
-                and membership == Membership.JOIN
-            ):
-                return True
-        return False
-
     @defer.inlineCallbacks
-    def maybe_kick_guest_users(self, event, current_state):
+    def maybe_kick_guest_users(self, event, context=None):
         # Technically this function invalidates current_state by changing it.
         # Hopefully this isn't that important to the caller.
         if event.type == EventTypes.GuestAccess:
             guest_access = event.content.get("guest_access", "forbidden")
             if guest_access != "can_join":
+                if context:
+                    current_state = yield self.store.get_events(
+                        context.current_state_ids.values()
+                    )
+                    current_state = current_state.values()
+                else:
+                    current_state = yield self.store.get_current_state(event.room_id)
+                logger.info("maybe_kick_guest_users %r", current_state)
                 yield self.kick_guest_users(current_state)
 
     @defer.inlineCallbacks
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 92679532b9..2b88e6550e 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -217,11 +217,21 @@ class FederationHandler(BaseHandler):
 
         if event.type == EventTypes.Member:
             if event.membership == Membership.JOIN:
-                prev_state = context.current_state.get((event.type, event.state_key))
-                if not prev_state or prev_state.membership != Membership.JOIN:
-                    # Only fire user_joined_room if the user has acutally
-                    # joined the room. Don't bother if the user is just
-                    # changing their profile info.
+                # Only fire user_joined_room if the user has acutally
+                # joined the room. Don't bother if the user is just
+                # changing their profile info.
+                newly_joined = True
+                prev_state_id = context.current_state_ids.get(
+                    (event.type, event.state_key)
+                )
+                if prev_state_id:
+                    prev_state = yield self.store.get_event(
+                        prev_state_id, allow_none=True,
+                    )
+                    if prev_state and prev_state.membership == Membership.JOIN:
+                        newly_joined = False
+
+                if newly_joined:
                     user = UserID.from_string(event.state_key)
                     yield user_joined_room(self.distributor, user, event.room_id)
 
@@ -734,7 +744,7 @@ class FederationHandler(BaseHandler):
 
         # The remote hasn't signed it yet, obviously. We'll do the full checks
         # when we get the event back in `on_send_join_request`
-        self.auth.check(event, auth_events=context.current_state, do_sig_check=False)
+        yield self.auth.check_from_context(event, context, do_sig_check=False)
 
         defer.returnValue(event)
 
@@ -782,18 +792,11 @@ class FederationHandler(BaseHandler):
 
         new_pdu = event
 
-        destinations = set()
-
-        for k, s in context.current_state.items():
-            try:
-                if k[0] == EventTypes.Member:
-                    if s.content["membership"] == Membership.JOIN:
-                        destinations.add(get_domain_from_id(s.state_key))
-            except:
-                logger.warn(
-                    "Failed to get destination from event %s", s.event_id
-                )
-
+        message_handler = self.hs.get_handlers().message_handler
+        destinations = yield message_handler.get_joined_hosts_for_room_from_state(
+            context
+        )
+        destinations = set(destinations)
         destinations.discard(origin)
 
         logger.debug(
@@ -804,13 +807,15 @@ class FederationHandler(BaseHandler):
 
         self.replication_layer.send_pdu(new_pdu, destinations)
 
-        state_ids = [e.event_id for e in context.current_state.values()]
+        state_ids = context.current_state_ids.values()
         auth_chain = yield self.store.get_auth_chain(set(
             [event.event_id] + state_ids
         ))
 
+        state = yield self.store.get_events(context.current_state_ids.values())
+
         defer.returnValue({
-            "state": context.current_state.values(),
+            "state": state.values(),
             "auth_chain": auth_chain,
         })
 
@@ -966,7 +971,7 @@ class FederationHandler(BaseHandler):
         try:
             # The remote hasn't signed it yet, obviously. We'll do the full checks
             # when we get the event back in `on_send_leave_request`
-            self.auth.check(event, auth_events=context.current_state, do_sig_check=False)
+            yield self.auth.check_from_context(event, context, do_sig_check=False)
         except AuthError as e:
             logger.warn("Failed to create new leave %r because %s", event, e)
             raise e
@@ -1010,18 +1015,11 @@ class FederationHandler(BaseHandler):
 
         new_pdu = event
 
-        destinations = set()
-
-        for k, s in context.current_state.items():
-            try:
-                if k[0] == EventTypes.Member:
-                    if s.content["membership"] == Membership.LEAVE:
-                        destinations.add(get_domain_from_id(s.state_key))
-            except:
-                logger.warn(
-                    "Failed to get destination from event %s", s.event_id
-                )
-
+        message_handler = self.hs.get_handlers().message_handler
+        destinations = yield message_handler.get_joined_hosts_for_room_from_state(
+            context
+        )
+        destinations = set(destinations)
         destinations.discard(origin)
 
         logger.debug(
@@ -1306,7 +1304,13 @@ class FederationHandler(BaseHandler):
         )
 
         if not auth_events:
-            auth_events = context.current_state
+            auth_events_ids = yield self.auth.compute_auth_events(
+                event, context.current_state_ids, for_verification=True,
+            )
+            auth_events = yield self.store.get_events(auth_events_ids)
+            auth_events = {
+                (e.type, e.state_key): e for e in auth_events.values()
+            }
 
         # This is a hack to fix some old rooms where the initial join event
         # didn't reference the create event in its auth events.
@@ -1332,8 +1336,7 @@ class FederationHandler(BaseHandler):
             context.rejected = RejectedReason.AUTH_ERROR
 
         if event.type == EventTypes.GuestAccess:
-            full_context = yield self.store.get_current_state(room_id=event.room_id)
-            yield self.maybe_kick_guest_users(event, full_context)
+            yield self.maybe_kick_guest_users(event)
 
         defer.returnValue(context)
 
@@ -1504,7 +1507,9 @@ class FederationHandler(BaseHandler):
                 current_state = set(e.event_id for e in auth_events.values())
                 different_auth = event_auth_events - current_state
 
-                context.current_state.update(auth_events)
+                context.current_state_ids.update({
+                    k: a.event_id for k, a in auth_events.items()
+                })
                 context.state_group = None
 
         if different_auth and not event.internal_metadata.is_outlier():
@@ -1526,8 +1531,8 @@ class FederationHandler(BaseHandler):
 
             if do_resolution:
                 # 1. Get what we think is the auth chain.
-                auth_ids = self.auth.compute_auth_events(
-                    event, context.current_state
+                auth_ids = yield self.auth.compute_auth_events(
+                    event, context.current_state_ids
                 )
                 local_auth_chain = yield self.store.get_auth_chain(auth_ids)
 
@@ -1583,7 +1588,9 @@ class FederationHandler(BaseHandler):
                 # 4. Look at rejects and their proofs.
                 # TODO.
 
-                context.current_state.update(auth_events)
+                context.current_state_ids.update({
+                    k: a.event_id for k, a in auth_events.items()
+                })
                 context.state_group = None
 
         try:
@@ -1770,12 +1777,12 @@ class FederationHandler(BaseHandler):
             )
 
             try:
-                self.auth.check(event, context.current_state)
+                yield self.auth.check_from_context(event, context)
             except AuthError as e:
                 logger.warn("Denying new third party invite %r because %s", event, e)
                 raise e
 
-            yield self._check_signature(event, auth_events=context.current_state)
+            yield self._check_signature(event, context)
             member_handler = self.hs.get_handlers().room_member_handler
             yield member_handler.send_membership_event(None, event, context)
         else:
@@ -1801,11 +1808,11 @@ class FederationHandler(BaseHandler):
         )
 
         try:
-            self.auth.check(event, auth_events=context.current_state)
+            self.auth.check_from_context(event, context)
         except AuthError as e:
             logger.warn("Denying third party invite %r because %s", event, e)
             raise e
-        yield self._check_signature(event, auth_events=context.current_state)
+        yield self._check_signature(event, context)
 
         returned_invite = yield self.send_invite(origin, event)
         # TODO: Make sure the signatures actually are correct.
@@ -1819,7 +1826,12 @@ class FederationHandler(BaseHandler):
             EventTypes.ThirdPartyInvite,
             event.content["third_party_invite"]["signed"]["token"]
         )
-        original_invite = context.current_state.get(key)
+        original_invite = None
+        original_invite_id = context.current_state_ids.get(key)
+        if original_invite_id:
+            original_invite = yield self.store.get_event(
+                original_invite_id, allow_none=True
+            )
         if not original_invite:
             logger.info(
                 "Could not find invite event for third_party_invite - "
@@ -1836,13 +1848,13 @@ class FederationHandler(BaseHandler):
         defer.returnValue((event, context))
 
     @defer.inlineCallbacks
-    def _check_signature(self, event, auth_events):
+    def _check_signature(self, event, context):
         """
         Checks that the signature in the event is consistent with its invite.
 
         Args:
             event (Event): The m.room.member event to check
-            auth_events (dict<(event type, state_key), event>):
+            context (EventContext):
 
         Raises:
             AuthError: if signature didn't match any keys, or key has been
@@ -1853,10 +1865,14 @@ class FederationHandler(BaseHandler):
         signed = event.content["third_party_invite"]["signed"]
         token = signed["token"]
 
-        invite_event = auth_events.get(
+        invite_event_id = context.current_state_ids.get(
             (EventTypes.ThirdPartyInvite, token,)
         )
 
+        invite_event = None
+        if invite_event_id:
+            invite_event = yield self.store.get_event(invite_event_id, allow_none=True)
+
         if not invite_event:
             raise AuthError(403, "Could not find invite")
 
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 4c3cd9d12e..e2f4387f60 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -30,6 +30,7 @@ from synapse.util.async import concurrently_execute, run_on_reactor, ReadWriteLo
 from synapse.util.caches.snapshot_cache import SnapshotCache
 from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
 from synapse.util.metrics import measure_func
+from synapse.util.caches.descriptors import cachedInlineCallbacks
 from synapse.visibility import filter_events_for_client
 
 from ._base import BaseHandler
@@ -248,7 +249,7 @@ class MessageHandler(BaseHandler):
         assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
 
         if event.is_state():
-            prev_state = self.deduplicate_state_event(event, context)
+            prev_state = yield self.deduplicate_state_event(event, context)
             if prev_state is not None:
                 defer.returnValue(prev_state)
 
@@ -263,6 +264,7 @@ class MessageHandler(BaseHandler):
             presence = self.hs.get_presence_handler()
             yield presence.bump_presence_active_time(user)
 
+    @defer.inlineCallbacks
     def deduplicate_state_event(self, event, context):
         """
         Checks whether event is in the latest resolved state in context.
@@ -270,13 +272,17 @@ class MessageHandler(BaseHandler):
         If so, returns the version of the event in context.
         Otherwise, returns None.
         """
-        prev_event = context.current_state.get((event.type, event.state_key))
+        prev_event_id = context.current_state_ids.get((event.type, event.state_key))
+        prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
+        if not prev_event:
+            return
+
         if prev_event and event.user_id == prev_event.user_id:
             prev_content = encode_canonical_json(prev_event.content)
             next_content = encode_canonical_json(event.content)
             if prev_content == next_content:
-                return prev_event
-        return None
+                defer.returnValue(prev_event)
+        return
 
     @defer.inlineCallbacks
     def create_and_send_nonmember_event(
@@ -803,7 +809,7 @@ class MessageHandler(BaseHandler):
 
         logger.debug(
             "Created event %s with current state: %s",
-            event.event_id, context.current_state,
+            event.event_id, context.current_state_ids,
         )
 
         defer.returnValue(
@@ -826,12 +832,12 @@ class MessageHandler(BaseHandler):
             self.ratelimit(requester)
 
         try:
-            self.auth.check(event, auth_events=context.current_state)
+            yield self.auth.check_from_context(event, context)
         except AuthError as err:
             logger.warn("Denying new event %r because %s", event, err)
             raise err
 
-        yield self.maybe_kick_guest_users(event, context.current_state.values())
+        yield self.maybe_kick_guest_users(event, context)
 
         if event.type == EventTypes.CanonicalAlias:
             # Check the alias is acually valid (at this time at least)
@@ -859,6 +865,15 @@ class MessageHandler(BaseHandler):
                         e.sender == event.sender
                     )
 
+                state_to_include_ids = [
+                    e_id
+                    for k, e_id in context.current_state_ids.items()
+                    if k[0] in self.hs.config.room_invite_state_types
+                    or k[0] == EventTypes.Member and k[1] == event.sender
+                ]
+
+                state_to_include = yield self.store.get_events(state_to_include_ids)
+
                 event.unsigned["invite_room_state"] = [
                     {
                         "type": e.type,
@@ -866,9 +881,7 @@ class MessageHandler(BaseHandler):
                         "content": e.content,
                         "sender": e.sender,
                     }
-                    for k, e in context.current_state.items()
-                    if e.type in self.hs.config.room_invite_state_types
-                    or is_inviter_member_event(e)
+                    for e in state_to_include.values()
                 ]
 
                 invitee = UserID.from_string(event.state_key)
@@ -890,7 +903,14 @@ class MessageHandler(BaseHandler):
                     )
 
         if event.type == EventTypes.Redaction:
-            if self.auth.check_redaction(event, auth_events=context.current_state):
+            auth_events_ids = yield self.auth.compute_auth_events(
+                event, context.current_state_ids, for_verification=True,
+            )
+            auth_events = yield self.store.get_events(auth_events_ids)
+            auth_events = {
+                (e.type, e.state_key): e for e in auth_events.values()
+            }
+            if self.auth.check_redaction(event, auth_events=auth_events):
                 original_event = yield self.store.get_event(
                     event.redacts,
                     check_redacted=False,
@@ -904,7 +924,7 @@ class MessageHandler(BaseHandler):
                         "You don't have permission to redact events"
                     )
 
-        if event.type == EventTypes.Create and context.current_state:
+        if event.type == EventTypes.Create and context.current_state_ids:
             raise AuthError(
                 403,
                 "Changing the room create event is forbidden",
@@ -925,16 +945,7 @@ class MessageHandler(BaseHandler):
             event_stream_id, max_stream_id
         )
 
-        destinations = set()
-        for k, s in context.current_state.items():
-            try:
-                if k[0] == EventTypes.Member:
-                    if s.content["membership"] == Membership.JOIN:
-                        destinations.add(get_domain_from_id(s.state_key))
-            except SynapseError:
-                logger.warn(
-                    "Failed to get destination from event %s", s.event_id
-                )
+        destinations = yield self.get_joined_hosts_for_room_from_state(context)
 
         @defer.inlineCallbacks
         def _notify():
@@ -952,3 +963,39 @@ class MessageHandler(BaseHandler):
         preserve_fn(federation_handler.handle_new_event)(
             event, destinations=destinations,
         )
+
+    def get_joined_hosts_for_room_from_state(self, context):
+        state_group = context.state_group
+        if not state_group:
+            # If state_group is None it means it has yet to be assigned a
+            # state group, i.e. we need to make sure that calls with a state_group
+            # of None don't hit previous cached calls with a None state_group.
+            # To do this we set the state_group to a new object as object() != object()
+            state_group = object()
+
+        return self._get_joined_hosts_for_room_from_state(
+            state_group, context.current_state_ids
+        )
+
+    @cachedInlineCallbacks(num_args=1, cache_context=True)
+    def _get_joined_hosts_for_room_from_state(self, state_group, current_state_ids,
+                                              cache_context):
+
+        # Don't bother getting state for people on the same HS
+        current_state = yield self.store.get_events([
+            e_id for key, e_id in current_state_ids.items()
+            if key[0] == EventTypes.Member and not self.hs.is_mine_id(key[1])
+        ])
+
+        destinations = set()
+        for e in current_state.itervalues():
+            try:
+                if e.type == EventTypes.Member:
+                    if e.content["membership"] == Membership.JOIN:
+                        destinations.add(get_domain_from_id(e.state_key))
+            except SynapseError:
+                logger.warn(
+                    "Failed to get destination from event %s", e.event_id
+                )
+
+        defer.returnValue(destinations)
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 8b17632fdc..dd4b90ee24 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -93,20 +93,26 @@ class RoomMemberHandler(BaseHandler):
             ratelimit=ratelimit,
         )
 
-        prev_member_event = context.current_state.get(
+        prev_member_event_id = context.current_state_ids.get(
             (EventTypes.Member, target.to_string()),
             None
         )
 
         if event.membership == Membership.JOIN:
-            if not prev_member_event or prev_member_event.membership != Membership.JOIN:
-                # Only fire user_joined_room if the user has acutally joined the
-                # room. Don't bother if the user is just changing their profile
-                # info.
+            # Only fire user_joined_room if the user has acutally joined the
+            # room. Don't bother if the user is just changing their profile
+            # info.
+            newly_joined = True
+            if prev_member_event_id:
+                prev_member_event = yield self.store.get_event(prev_member_event_id)
+                newly_joined = prev_member_event.membership != Membership.JOIN
+            if newly_joined:
                 yield user_joined_room(self.distributor, target, room_id)
         elif event.membership == Membership.LEAVE:
-            if prev_member_event and prev_member_event.membership == Membership.JOIN:
-                user_left_room(self.distributor, target, room_id)
+            if prev_member_event_id:
+                prev_member_event = yield self.store.get_event(prev_member_event_id)
+                if prev_member_event.membership == Membership.JOIN:
+                    user_left_room(self.distributor, target, room_id)
 
     @defer.inlineCallbacks
     def remote_join(self, remote_room_hosts, room_id, user, content):
@@ -195,29 +201,32 @@ class RoomMemberHandler(BaseHandler):
             remote_room_hosts = []
 
         latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
-        current_state = yield self.state_handler.get_current_state(
+        current_state_ids = yield self.state_handler.get_current_state_ids(
             room_id, latest_event_ids=latest_event_ids,
         )
 
-        old_state = current_state.get((EventTypes.Member, target.to_string()))
-        old_membership = old_state.content.get("membership") if old_state else None
-        if action == "unban" and old_membership != "ban":
-            raise SynapseError(
-                403,
-                "Cannot unban user who was not banned (membership=%s)" % old_membership,
-                errcode=Codes.BAD_STATE
-            )
-        if old_membership == "ban" and action != "unban":
-            raise SynapseError(
-                403,
-                "Cannot %s user who was banned" % (action,),
-                errcode=Codes.BAD_STATE
-            )
+        old_state_id = current_state_ids.get((EventTypes.Member, target.to_string()))
+        if old_state_id:
+            old_state = yield self.store.get_event(old_state_id, allow_none=True)
+            old_membership = old_state.content.get("membership") if old_state else None
+            if action == "unban" and old_membership != "ban":
+                raise SynapseError(
+                    403,
+                    "Cannot unban user who was not banned"
+                    " (membership=%s)" % old_membership,
+                    errcode=Codes.BAD_STATE
+                )
+            if old_membership == "ban" and action != "unban":
+                raise SynapseError(
+                    403,
+                    "Cannot %s user who was banned" % (action,),
+                    errcode=Codes.BAD_STATE
+                )
 
-        is_host_in_room = self.is_host_in_room(current_state)
+        is_host_in_room = yield self._is_host_in_room(current_state_ids)
 
         if effective_membership_state == Membership.JOIN:
-            if requester.is_guest and not self._can_guest_join(current_state):
+            if requester.is_guest and not self._can_guest_join(current_state_ids):
                 # This should be an auth check, but guests are a local concept,
                 # so don't really fit into the general auth process.
                 raise AuthError(403, "Guest access not allowed")
@@ -326,15 +335,17 @@ class RoomMemberHandler(BaseHandler):
             requester = synapse.types.create_requester(target_user)
 
         message_handler = self.hs.get_handlers().message_handler
-        prev_event = message_handler.deduplicate_state_event(event, context)
+        prev_event = yield message_handler.deduplicate_state_event(event, context)
         if prev_event is not None:
             return
 
         if event.membership == Membership.JOIN:
-            if requester.is_guest and not self._can_guest_join(context.current_state):
-                # This should be an auth check, but guests are a local concept,
-                # so don't really fit into the general auth process.
-                raise AuthError(403, "Guest access not allowed")
+            if requester.is_guest:
+                guest_can_join = yield self._can_guest_join(context.current_state_ids)
+                if not guest_can_join:
+                    # This should be an auth check, but guests are a local concept,
+                    # so don't really fit into the general auth process.
+                    raise AuthError(403, "Guest access not allowed")
 
         yield message_handler.handle_new_client_event(
             requester,
@@ -344,27 +355,39 @@ class RoomMemberHandler(BaseHandler):
             ratelimit=ratelimit,
         )
 
-        prev_member_event = context.current_state.get(
-            (EventTypes.Member, target_user.to_string()),
+        prev_member_event_id = context.current_state_ids.get(
+            (EventTypes.Member, event.state_key),
             None
         )
 
         if event.membership == Membership.JOIN:
-            if not prev_member_event or prev_member_event.membership != Membership.JOIN:
-                # Only fire user_joined_room if the user has acutally joined the
-                # room. Don't bother if the user is just changing their profile
-                # info.
+            # Only fire user_joined_room if the user has acutally joined the
+            # room. Don't bother if the user is just changing their profile
+            # info.
+            newly_joined = True
+            if prev_member_event_id:
+                prev_member_event = yield self.store.get_event(prev_member_event_id)
+                newly_joined = prev_member_event.membership != Membership.JOIN
+            if newly_joined:
                 yield user_joined_room(self.distributor, target_user, room_id)
         elif event.membership == Membership.LEAVE:
-            if prev_member_event and prev_member_event.membership == Membership.JOIN:
-                user_left_room(self.distributor, target_user, room_id)
+            if prev_member_event_id:
+                prev_member_event = yield self.store.get_event(prev_member_event_id)
+                if prev_member_event.membership == Membership.JOIN:
+                    user_left_room(self.distributor, target_user, room_id)
 
-    def _can_guest_join(self, current_state):
+    @defer.inlineCallbacks
+    def _can_guest_join(self, current_state_ids):
         """
         Returns whether a guest can join a room based on its current state.
         """
-        guest_access = current_state.get((EventTypes.GuestAccess, ""), None)
-        return (
+        guest_access_id = current_state_ids.get((EventTypes.GuestAccess, ""), None)
+        if not guest_access_id:
+            defer.returnValue(False)
+
+        guest_access = yield self.store.get_event(guest_access_id)
+
+        defer.returnValue(
             guest_access
             and guest_access.content
             and "guest_access" in guest_access.content
@@ -683,3 +706,24 @@ class RoomMemberHandler(BaseHandler):
 
         if membership:
             yield self.store.forget(user_id, room_id)
+
+    @defer.inlineCallbacks
+    def _is_host_in_room(self, current_state_ids):
+        # Have we just created the room, and is this about to be the very
+        # first member event?
+        create_event_id = current_state_ids.get(("m.room.create", ""))
+        if len(current_state_ids) == 1 and create_event_id:
+            defer.returnValue(self.hs.is_mine_id(create_event_id))
+
+        for (etype, state_key), event_id in current_state_ids.items():
+            if etype != EventTypes.Member or not self.hs.is_mine_id(state_key):
+                continue
+
+            event = yield self.store.get_event(event_id, allow_none=True)
+            if not event:
+                continue
+
+            if event.membership == Membership.JOIN:
+                defer.returnValue(True)
+
+        defer.returnValue(False)
diff --git a/synapse/push/action_generator.py b/synapse/push/action_generator.py
index ed2ccc4dfb..3f75d3f921 100644
--- a/synapse/push/action_generator.py
+++ b/synapse/push/action_generator.py
@@ -40,12 +40,12 @@ class ActionGenerator:
     def handle_push_actions_for_event(self, event, context):
         with Measure(self.clock, "evaluator_for_event"):
             bulk_evaluator = yield evaluator_for_event(
-                event, self.hs, self.store, context.state_group, context.current_state
+                event, self.hs, self.store, context
             )
 
         with Measure(self.clock, "action_for_event_by_user"):
             actions_by_user = yield bulk_evaluator.action_for_event_by_user(
-                event, context.current_state
+                event, context
             )
 
         context.push_actions = [
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 004eded61f..8d49beaec5 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -19,8 +19,8 @@ from twisted.internet import defer
 
 from .push_rule_evaluator import PushRuleEvaluatorForEvent
 
-from synapse.api.constants import EventTypes, Membership
-from synapse.visibility import filter_events_for_clients
+from synapse.api.constants import EventTypes
+from synapse.visibility import filter_events_for_clients_context
 
 
 logger = logging.getLogger(__name__)
@@ -36,9 +36,9 @@ def _get_rules(room_id, user_ids, store):
 
 
 @defer.inlineCallbacks
-def evaluator_for_event(event, hs, store, state_group, current_state):
+def evaluator_for_event(event, hs, store, context):
     rules_by_user = yield store.bulk_get_push_rules_for_room(
-        event.room_id, state_group, current_state
+        event.room_id, context
     )
 
     # if this event is an invite event, we may need to run rules for the user
@@ -72,7 +72,7 @@ class BulkPushRuleEvaluator:
         self.store = store
 
     @defer.inlineCallbacks
-    def action_for_event_by_user(self, event, current_state):
+    def action_for_event_by_user(self, event, context):
         actions_by_user = {}
 
         # None of these users can be peeking since this list of users comes
@@ -82,27 +82,25 @@ class BulkPushRuleEvaluator:
             (u, False) for u in self.rules_by_user.keys()
         ]
 
-        filtered_by_user = yield filter_events_for_clients(
-            self.store, user_tuples, [event], {event.event_id: current_state}
+        filtered_by_user = yield filter_events_for_clients_context(
+            self.store, user_tuples, [event], {event.event_id: context}
         )
 
-        room_members = set(
-            e.state_key for e in current_state.values()
-            if e.type == EventTypes.Member and e.membership == Membership.JOIN
+        room_members = yield self.store.get_joined_users_from_context(
+            event.room_id, context,
         )
 
         evaluator = PushRuleEvaluatorForEvent(event, len(room_members))
 
         condition_cache = {}
 
-        display_names = {}
-        for ev in current_state.values():
-            nm = ev.content.get("displayname", None)
-            if nm and ev.type == EventTypes.Member:
-                display_names[ev.state_key] = nm
-
         for uid, rules in self.rules_by_user.items():
-            display_name = display_names.get(uid, None)
+            display_name = None
+            member_ev_id = context.current_state_ids.get((EventTypes.Member, uid))
+            if member_ev_id:
+                member_ev = yield self.store.get_event(member_ev_id, allow_none=True)
+                if member_ev:
+                    display_name = member_ev.content.get("displayname", None)
 
             filtered = filtered_by_user[uid]
             if len(filtered) == 0:
diff --git a/synapse/state.py b/synapse/state.py
index 2249b7fffb..2a01887a67 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -107,6 +107,20 @@ class StateHandler(object):
         defer.returnValue(state)
 
     @defer.inlineCallbacks
+    def get_current_state_ids(self, room_id, event_type=None, state_key="",
+                              latest_event_ids=None):
+        if not latest_event_ids:
+            latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
+
+        _, state = yield self.resolve_state_groups(room_id, latest_event_ids)
+
+        if event_type:
+            defer.returnValue(state.get((event_type, state_key)))
+            return
+
+        defer.returnValue(state)
+
+    @defer.inlineCallbacks
     def compute_event_context(self, event, old_state=None):
         """ Fills out the context with the `current state` of the graph. The
         `current state` here is defined to be the state of the event graph
@@ -127,27 +141,27 @@ class StateHandler(object):
             # state. Certainly store.get_current_state won't return any, and
             # persisting the event won't store the state group.
             if old_state:
-                context.current_state = {
-                    (s.type, s.state_key): s for s in old_state
+                context.current_state_ids = {
+                    (s.type, s.state_key): s.event_id for s in old_state
                 }
             else:
-                context.current_state = {}
+                context.current_state_ids = {}
             context.prev_state_events = []
             context.state_group = None
             defer.returnValue(context)
 
         if old_state:
-            context.current_state = {
-                (s.type, s.state_key): s for s in old_state
+            context.current_state_ids = {
+                (s.type, s.state_key): s.event_id for s in old_state
             }
             context.state_group = None
 
             if event.is_state():
                 key = (event.type, event.state_key)
-                if key in context.current_state:
-                    replaces = context.current_state[key]
-                    if replaces.event_id != event.event_id:  # Paranoia check
-                        event.unsigned["replaces_state"] = replaces.event_id
+                if key in context.current_state_ids:
+                    replaces = context.current_state_ids[key]
+                    if replaces != event.event_id:  # Paranoia check
+                        event.unsigned["replaces_state"] = replaces
 
             context.prev_state_events = []
             defer.returnValue(context)
@@ -165,22 +179,14 @@ class StateHandler(object):
 
         group, curr_state = ret
 
-        state_map = yield self.store.get_events(
-            curr_state.values(),
-            get_prev_content=False
-        )
-        curr_state = {
-            key: state_map[e_id] for key, e_id in curr_state.items() if e_id in state_map
-        }
-
-        context.current_state = curr_state
+        context.current_state_ids = curr_state
         context.state_group = group if not event.is_state() else None
 
         if event.is_state():
             key = (event.type, event.state_key)
-            if key in context.current_state:
-                replaces = context.current_state[key]
-                event.unsigned["replaces_state"] = replaces.event_id
+            if key in context.current_state_ids:
+                replaces = context.current_state_ids[key]
+                event.unsigned["replaces_state"] = replaces
 
         context.prev_state_events = []
         defer.returnValue(context)
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index 78334a98cf..7e6ec411cd 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -124,7 +124,8 @@ class PushRuleStore(SQLBaseStore):
 
         defer.returnValue(results)
 
-    def bulk_get_push_rules_for_room(self, room_id, state_group, current_state):
+    def bulk_get_push_rules_for_room(self, room_id, context):
+        state_group = context.state_group
         if not state_group:
             # If state_group is None it means it has yet to be assigned a
             # state group, i.e. we need to make sure that calls with a state_group
@@ -132,10 +133,12 @@ class PushRuleStore(SQLBaseStore):
             # To do this we set the state_group to a new object as object() != object()
             state_group = object()
 
-        return self._bulk_get_push_rules_for_room(room_id, state_group, current_state)
+        return self._bulk_get_push_rules_for_room(
+            room_id, state_group, context.current_state_ids
+        )
 
     @cachedInlineCallbacks(num_args=2, cache_context=True)
-    def _bulk_get_push_rules_for_room(self, room_id, state_group, current_state,
+    def _bulk_get_push_rules_for_room(self, room_id, state_group, current_state_ids,
                                       cache_context):
         # We don't use `state_group`, its there so that we can cache based
         # on it. However, its important that its never None, since two current_state's
@@ -147,10 +150,16 @@ class PushRuleStore(SQLBaseStore):
         # their unread countss are correct in the event stream, but to avoid
         # generating them for bot / AS users etc, we only do so for people who've
         # sent a read receipt into the room.
+        local_user_member_ids = [
+            e_id for (etype, state_key), e_id in current_state_ids.iteritems()
+            if etype == EventTypes.Member and self.hs.is_mine_id(state_key)
+        ]
+
+        local_member_events = yield self._get_events(local_user_member_ids)
+
         local_users_in_room = set(
-            e.state_key for e in current_state.values()
-            if e.type == EventTypes.Member and e.membership == Membership.JOIN
-            and self.hs.is_mine_id(e.state_key)
+            member_event.state_key for member_event in local_member_events
+            if member_event.membership == Membership.JOIN
         )
 
         # users in the room who have pushers need to get push rules run because
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index a422ddf633..3ffad672a7 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -20,7 +20,7 @@ from collections import namedtuple
 from ._base import SQLBaseStore
 from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
 
-from synapse.api.constants import Membership
+from synapse.api.constants import Membership, EventTypes
 from synapse.types import get_domain_from_id
 
 import logging
@@ -325,7 +325,8 @@ class RoomMemberStore(SQLBaseStore):
 
     @cachedInlineCallbacks(num_args=3)
     def was_forgotten_at(self, user_id, room_id, event_id):
-        """Returns whether user_id has elected to discard history for room_id at event_id.
+        """Returns whether user_id has elected to discard history for room_id at
+        event_id.
 
         event_id must be a membership event."""
         def f(txn):
@@ -358,3 +359,43 @@ class RoomMemberStore(SQLBaseStore):
             },
             desc="who_forgot"
         )
+
+    def get_joined_users_from_context(self, room_id, context):
+        state_group = context.state_group
+        if not state_group:
+            # If state_group is None it means it has yet to be assigned a
+            # state group, i.e. we need to make sure that calls with a state_group
+            # of None don't hit previous cached calls with a None state_group.
+            # To do this we set the state_group to a new object as object() != object()
+            state_group = object()
+
+        return self._get_joined_users_from_context(
+            room_id, state_group, context.current_state_ids
+        )
+
+    @cachedInlineCallbacks(num_args=2, cache_context=True)
+    def _get_joined_users_from_context(self, room_id, state_group, current_state_ids,
+                                       cache_context):
+        # We don't use `state_group`, its there so that we can cache based
+        # on it. However, its important that its never None, since two current_state's
+        # with a state_group of None are likely to be different.
+        # See bulk_get_push_rules_for_room for how we work around this.
+        assert state_group is not None
+
+        member_event_ids = [
+            e_id
+            for key, e_id in current_state_ids.iteritems()
+            if key[0] == EventTypes.Member
+        ]
+
+        rows = yield self._simple_select_many_batch(
+            table="room_memberships",
+            column="event_id",
+            iterable=member_event_ids,
+            retcols=['user_id'],
+            keyvalues={
+                "membership": Membership.JOIN,
+            }
+        )
+
+        defer.returnValue(set(row["user_id"] for row in rows))
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index fa40af6933..22f7fb1aa1 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -89,17 +89,17 @@ class StateStore(SQLBaseStore):
             if event.internal_metadata.is_outlier():
                 continue
 
-            if context.current_state is None:
+            if context.current_state_ids is None:
                 continue
 
             if context.state_group is not None:
                 state_groups[event.event_id] = context.state_group
                 continue
 
-            state_events = dict(context.current_state)
+            state_event_ids = dict(context.current_state_ids)
 
             if event.is_state():
-                state_events[(event.type, event.state_key)] = event
+                state_event_ids[(event.type, event.state_key)] = event.event_id
 
             state_group = context.new_state_group_id
 
@@ -119,12 +119,12 @@ class StateStore(SQLBaseStore):
                 values=[
                     {
                         "state_group": state_group,
-                        "room_id": state.room_id,
-                        "type": state.type,
-                        "state_key": state.state_key,
-                        "event_id": state.event_id,
+                        "room_id": event.room_id,
+                        "type": key[0],
+                        "state_key": key[1],
+                        "event_id": state_id,
                     }
-                    for state in state_events.values()
+                    for key, state_id in state_event_ids.items()
                 ],
             )
             state_groups[event.event_id] = state_group
diff --git a/synapse/visibility.py b/synapse/visibility.py
index cc12c0a23d..199b16d827 100644
--- a/synapse/visibility.py
+++ b/synapse/visibility.py
@@ -181,6 +181,25 @@ def filter_events_for_clients(store, user_tuples, events, event_id_to_state):
 
 
 @defer.inlineCallbacks
+def filter_events_for_clients_context(store, user_tuples, events, event_id_to_context):
+    user_ids = set(u[0] for u in user_tuples)
+    event_id_to_state = {}
+    for event_id, context in event_id_to_context.items():
+        state = yield store.get_events([
+            e_id
+            for key, e_id in context.current_state_ids.iteritems()
+            if key == (EventTypes.RoomHistoryVisibility, "")
+            or (key[0] == EventTypes.Member and key[1] in user_ids)
+        ])
+        event_id_to_state[event_id] = state
+
+    res = yield filter_events_for_clients(
+        store, user_tuples, events, event_id_to_state
+    )
+    defer.returnValue(res)
+
+
+@defer.inlineCallbacks
 def filter_events_for_client(store, user_id, events, is_peeking=False):
     """
     Check which events a user is allowed to see
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index f33e6f60fb..218cb24889 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -305,7 +305,14 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
 
         self.event_id += 1
 
-        context = EventContext(current_state=state)
+        if state is not None:
+            state_ids = {
+                key: e.event_id for key, e in state.items()
+            }
+        else:
+            state_ids = None
+
+        context = EventContext(current_state_ids=state_ids)
         context.push_actions = push_actions
 
         ordering = None
diff --git a/tests/test_state.py b/tests/test_state.py
index 1a11bbcee0..df9362c985 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -69,7 +69,7 @@ class StateGroupStore(object):
 
         self._next_group = 1
 
-    def get_state_groups(self, room_id, event_ids):
+    def get_state_groups_ids(self, room_id, event_ids):
         groups = {}
         for event_id in event_ids:
             group = self._event_to_state_group.get(event_id)
@@ -79,20 +79,20 @@ class StateGroupStore(object):
         return defer.succeed(groups)
 
     def store_state_groups(self, event, context):
-        if context.current_state is None:
+        if context.current_state_ids is None:
             return
 
-        state_events = context.current_state
+        state_events = dict(context.current_state_ids)
 
         if event.is_state():
-            state_events[(event.type, event.state_key)] = event
+            state_events[(event.type, event.state_key)] = event.event_id
 
         state_group = context.state_group
         if not state_group:
             state_group = self._next_group
             self._next_group += 1
 
-            self._group_to_state[state_group] = state_events.values()
+            self._group_to_state[state_group] = state_events
 
         self._event_to_state_group[event.event_id] = state_group
 
@@ -136,7 +136,7 @@ class StateTestCase(unittest.TestCase):
     def setUp(self):
         self.store = Mock(
             spec_set=[
-                "get_state_groups",
+                "get_state_groups_ids",
                 "add_event_hashes",
             ]
         )
@@ -187,7 +187,7 @@ class StateTestCase(unittest.TestCase):
         )
 
         store = StateGroupStore()
-        self.store.get_state_groups.side_effect = store.get_state_groups
+        self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
 
         context_store = {}
 
@@ -196,7 +196,7 @@ class StateTestCase(unittest.TestCase):
             store.store_state_groups(event, context)
             context_store[event.event_id] = context
 
-        self.assertEqual(2, len(context_store["D"].current_state))
+        self.assertEqual(2, len(context_store["D"].current_state_ids))
 
     @defer.inlineCallbacks
     def test_branch_basic_conflict(self):
@@ -239,7 +239,7 @@ class StateTestCase(unittest.TestCase):
         )
 
         store = StateGroupStore()
-        self.store.get_state_groups.side_effect = store.get_state_groups
+        self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
 
         context_store = {}
 
@@ -303,7 +303,7 @@ class StateTestCase(unittest.TestCase):
         )
 
         store = StateGroupStore()
-        self.store.get_state_groups.side_effect = store.get_state_groups
+        self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
 
         context_store = {}
 
@@ -384,7 +384,7 @@ class StateTestCase(unittest.TestCase):
         graph = Graph(nodes, edges)
 
         store = StateGroupStore()
-        self.store.get_state_groups.side_effect = store.get_state_groups
+        self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
 
         context_store = {}
 
@@ -424,13 +424,8 @@ class StateTestCase(unittest.TestCase):
             event, old_state=old_state
         )
 
-        for k, v in context.current_state.items():
-            type, state_key = k
-            self.assertEqual(type, v.type)
-            self.assertEqual(state_key, v.state_key)
-
         self.assertEqual(
-            set(old_state), set(context.current_state.values())
+            set(e.event_id for e in old_state), set(context.current_state_ids.values())
         )
 
         self.assertIsNone(context.state_group)
@@ -449,14 +444,8 @@ class StateTestCase(unittest.TestCase):
             event, old_state=old_state
         )
 
-        for k, v in context.current_state.items():
-            type, state_key = k
-            self.assertEqual(type, v.type)
-            self.assertEqual(state_key, v.state_key)
-
         self.assertEqual(
-            set(old_state),
-            set(context.current_state.values())
+            set(e.event_id for e in old_state), set(context.current_state_ids.values())
         )
 
         self.assertIsNone(context.state_group)
@@ -473,20 +462,15 @@ class StateTestCase(unittest.TestCase):
 
         group_name = "group_name_1"
 
-        self.store.get_state_groups.return_value = {
-            group_name: old_state,
+        self.store.get_state_groups_ids.return_value = {
+            group_name: {(e.type, e.state_key): e.event_id for e in old_state},
         }
 
         context = yield self.state.compute_event_context(event)
 
-        for k, v in context.current_state.items():
-            type, state_key = k
-            self.assertEqual(type, v.type)
-            self.assertEqual(state_key, v.state_key)
-
         self.assertEqual(
             set([e.event_id for e in old_state]),
-            set([e.event_id for e in context.current_state.values()])
+            set(context.current_state_ids.values())
         )
 
         self.assertEqual(group_name, context.state_group)
@@ -503,20 +487,15 @@ class StateTestCase(unittest.TestCase):
 
         group_name = "group_name_1"
 
-        self.store.get_state_groups.return_value = {
-            group_name: old_state,
+        self.store.get_state_groups_ids.return_value = {
+            group_name: {(e.type, e.state_key): e.event_id for e in old_state},
         }
 
         context = yield self.state.compute_event_context(event)
 
-        for k, v in context.current_state.items():
-            type, state_key = k
-            self.assertEqual(type, v.type)
-            self.assertEqual(state_key, v.state_key)
-
         self.assertEqual(
             set([e.event_id for e in old_state]),
-            set([e.event_id for e in context.current_state.values()])
+            set(context.current_state_ids.values())
         )
 
         self.assertIsNone(context.state_group)
@@ -545,7 +524,7 @@ class StateTestCase(unittest.TestCase):
 
         context = yield self._get_context(event, old_state_1, old_state_2)
 
-        self.assertEqual(len(context.current_state), 6)
+        self.assertEqual(len(context.current_state_ids), 6)
 
         self.assertIsNone(context.state_group)
 
@@ -573,7 +552,7 @@ class StateTestCase(unittest.TestCase):
 
         context = yield self._get_context(event, old_state_1, old_state_2)
 
-        self.assertEqual(len(context.current_state), 6)
+        self.assertEqual(len(context.current_state_ids), 6)
 
         self.assertIsNone(context.state_group)
 
@@ -608,7 +587,7 @@ class StateTestCase(unittest.TestCase):
 
         context = yield self._get_context(event, old_state_1, old_state_2)
 
-        self.assertEqual(old_state_2[2], context.current_state[("test1", "1")])
+        self.assertEqual(old_state_2[2].event.id, context.current_state_ids[("test1", "1")])
 
         # Reverse the depth to make sure we are actually using the depths
         # during state resolution.
@@ -627,15 +606,15 @@ class StateTestCase(unittest.TestCase):
 
         context = yield self._get_context(event, old_state_1, old_state_2)
 
-        self.assertEqual(old_state_1[2], context.current_state[("test1", "1")])
+        self.assertEqual(old_state_1[2].event_id, context.current_state_ids[("test1", "1")])
 
     def _get_context(self, event, old_state_1, old_state_2):
         group_name_1 = "group_name_1"
         group_name_2 = "group_name_2"
 
-        self.store.get_state_groups.return_value = {
-            group_name_1: old_state_1,
-            group_name_2: old_state_2,
+        self.store.get_state_groups_ids.return_value = {
+            group_name_1: {(e.type, e.state_key): e.event_id for e in old_state_1},
+            group_name_2: {(e.type, e.state_key): e.event_id for e in old_state_2},
         }
 
         return self.state.compute_event_context(event)