summary refs log tree commit diff
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2016-09-08 15:26:26 +0100
committerErik Johnston <erik@matrix.org>2016-09-08 15:26:26 +0100
commit5834c6178cd8ab8121b0ad87bfe0d3b69a48c21b (patch)
tree3d1fe71fc6a53ea65981f814facd80d79d99bf68
parentMerge branch 'release-v0.17.1' of github.com:matrix-org/synapse (diff)
parentBump version and changelog (diff)
downloadsynapse-5834c6178cd8ab8121b0ad87bfe0d3b69a48c21b.tar.xz
Merge branch 'release-v0.17.2' of github.com:matrix-org/synapse v0.17.2
-rw-r--r--CHANGES.rst37
-rw-r--r--README.rst24
-rw-r--r--synapse/__init__.py2
-rw-r--r--synapse/api/auth.py90
-rw-r--r--synapse/api/constants.py5
-rw-r--r--synapse/api/urls.py1
-rw-r--r--synapse/app/synchrotron.py5
-rw-r--r--synapse/appservice/__init__.py2
-rw-r--r--synapse/appservice/api.py51
-rw-r--r--synapse/config/appservice.py17
-rw-r--r--synapse/events/__init__.py2
-rw-r--r--synapse/events/snapshot.py6
-rw-r--r--synapse/federation/federation_client.py15
-rw-r--r--synapse/federation/federation_server.py10
-rw-r--r--synapse/handlers/_base.py30
-rw-r--r--synapse/handlers/appservice.py10
-rw-r--r--synapse/handlers/directory.py8
-rw-r--r--synapse/handlers/events.py3
-rw-r--r--synapse/handlers/federation.py227
-rw-r--r--synapse/handlers/message.py93
-rw-r--r--synapse/handlers/presence.py48
-rw-r--r--synapse/handlers/receipts.py5
-rw-r--r--synapse/handlers/room_member.py130
-rw-r--r--synapse/handlers/sync.py126
-rw-r--r--synapse/handlers/typing.py9
-rw-r--r--synapse/notifier.py3
-rw-r--r--synapse/push/action_generator.py4
-rw-r--r--synapse/push/bulk_push_rule_evaluator.py32
-rw-r--r--synapse/push/httppusher.py2
-rw-r--r--synapse/push/mailer.py73
-rw-r--r--synapse/push/presentable_names.py (renamed from synapse/util/presentable_names.py)96
-rw-r--r--synapse/push/push_tools.py13
-rw-r--r--synapse/replication/resource.py43
-rw-r--r--synapse/replication/slave/storage/deviceinbox.py42
-rw-r--r--synapse/replication/slave/storage/events.py13
-rw-r--r--synapse/rest/__init__.py2
-rw-r--r--synapse/rest/client/v2_alpha/sendtodevice.py90
-rw-r--r--synapse/rest/client/v2_alpha/sync.py9
-rw-r--r--synapse/rest/client/v2_alpha/thirdparty.py24
-rw-r--r--synapse/state.py232
-rw-r--r--synapse/storage/__init__.py7
-rw-r--r--synapse/storage/deviceinbox.py184
-rw-r--r--synapse/storage/events.py78
-rw-r--r--synapse/storage/push_rule.py23
-rw-r--r--synapse/storage/receipts.py2
-rw-r--r--synapse/storage/roommember.py106
-rw-r--r--synapse/storage/schema/delta/34/device_inbox.sql24
-rw-r--r--synapse/storage/schema/delta/34/sent_txn_purge.py32
-rw-r--r--synapse/storage/state.py128
-rw-r--r--synapse/storage/transactions.py4
-rw-r--r--synapse/streams/events.py2
-rw-r--r--synapse/types.py9
-rw-r--r--synapse/visibility.py19
-rw-r--r--tests/handlers/test_presence.py47
-rw-r--r--tests/handlers/test_typing.py6
-rw-r--r--tests/replication/slave/storage/test_events.py11
-rw-r--r--tests/replication/test_resource.py10
-rw-r--r--tests/rest/client/v1/test_rooms.py4
-rw-r--r--tests/storage/test_roommember.py41
-rw-r--r--tests/test_state.py138
60 files changed, 1836 insertions, 673 deletions
diff --git a/CHANGES.rst b/CHANGES.rst
index 49673ccce4..c40a32abd6 100644
--- a/CHANGES.rst
+++ b/CHANGES.rst
@@ -1,3 +1,40 @@
+Changes in synapse v0.17.2 (2016-09-08)
+=======================================
+
+This release contains security bug fixes. Please upgrade.
+
+
+No changes since v0.17.2
+
+
+Changes in synapse v0.17.2-rc1 (2016-09-05)
+===========================================
+
+Features:
+
+* Start adding store-and-forward direct-to-device messaging (PR #1046, #1050,
+  #1062, #1066)
+
+
+Changes:
+
+* Avoid pulling the full state of a room out so often (PR #1047, #1049, #1063,
+  #1068)
+* Don't notify for online to online presence transitions. (PR #1054)
+* Occasionally persist unpersisted presence updates (PR #1055)
+* Allow application services to have an optional 'url' (PR #1056)
+* Clean up old sent transactions from DB (PR #1059)
+
+
+Bug fixes:
+
+* Fix None check in backfill (PR #1043)
+* Fix membership changes to be idempotent (PR #1067)
+* Fix bug in get_pdu where it would sometimes return events with incorrect
+  signature
+
+
+
 Changes in synapse v0.17.1 (2016-08-24)
 =======================================
 
diff --git a/README.rst b/README.rst
index 172dd4dfa0..f1ccc8dc45 100644
--- a/README.rst
+++ b/README.rst
@@ -134,6 +134,12 @@ Installing prerequisites on Raspbian::
     sudo pip install --upgrade ndg-httpsclient
     sudo pip install --upgrade virtualenv
 
+Installing prerequisites on openSUSE::
+
+    sudo zypper in -t pattern devel_basis
+    sudo zypper in python-pip python-setuptools sqlite3 python-virtualenv \
+                   python-devel libffi-devel libopenssl-devel libjpeg62-devel
+
 To install the synapse homeserver run::
 
     virtualenv -p python2.7 ~/.synapse
@@ -199,6 +205,21 @@ run (e.g. ``~/.synapse``), and::
     source ./bin/activate
     synctl start
 
+Security Note
+=============
+
+Matrix serves raw user generated data in some APIs - specifically the content
+repository endpoints: http://matrix.org/docs/spec/client_server/r0.2.0.html#get-matrix-media-r0-download-servername-mediaid
+Whilst we have tried to mitigate against possible XSS attacks (e.g.
+https://github.com/matrix-org/synapse/pull/1021) we recommend running
+matrix homeservers on a dedicated domain name, to limit any malicious user generated
+content served to web browsers a matrix API from being able to attack webapps hosted
+on the same domain.  This is particularly true of sharing a matrix webclient and
+server on the same domain.
+
+See https://github.com/vector-im/vector-web/issues/1977 and
+https://developer.github.com/changes/2014-04-25-user-content-security for more details.
+
 Using PostgreSQL
 ================
 
@@ -215,9 +236,6 @@ The advantages of Postgres include:
   pointing at the same DB master, as well as enabling DB replication in
   synapse itself.
 
-The only disadvantage is that the code is relatively new as of April 2015 and
-may have a few regressions relative to SQLite.
-
 For information on how to install and use PostgreSQL, please see
 `docs/postgres.rst <docs/postgres.rst>`_.
 
diff --git a/synapse/__init__.py b/synapse/__init__.py
index 43bf78f885..523deaa5ff 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -16,4 +16,4 @@
 """ This is a reference implementation of a Matrix home server.
 """
 
-__version__ = "0.17.1"
+__version__ = "0.17.2"
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 0db26fcfd7..dcda40863f 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -52,7 +52,7 @@ class Auth(object):
         self.state = hs.get_state_handler()
         self.TOKEN_NOT_FOUND_HTTP_STATUS = 401
         # Docs for these currently lives at
-        # https://github.com/matrix-org/matrix-doc/blob/master/drafts/macaroons_caveats.rst
+        # github.com/matrix-org/matrix-doc/blob/master/drafts/macaroons_caveats.rst
         # In addition, we have type == delete_pusher which grants access only to
         # delete pushers.
         self._KNOWN_CAVEAT_PREFIXES = set([
@@ -63,6 +63,17 @@ class Auth(object):
             "user_id = ",
         ])
 
+    @defer.inlineCallbacks
+    def check_from_context(self, event, context, do_sig_check=True):
+        auth_events_ids = yield self.compute_auth_events(
+            event, context.prev_state_ids, for_verification=True,
+        )
+        auth_events = yield self.store.get_events(auth_events_ids)
+        auth_events = {
+            (e.type, e.state_key): e for e in auth_events.values()
+        }
+        self.check(event, auth_events=auth_events, do_sig_check=False)
+
     def check(self, event, auth_events, do_sig_check=True):
         """ Checks if this event is correctly authed.
 
@@ -267,21 +278,17 @@ class Auth(object):
 
     @defer.inlineCallbacks
     def check_host_in_room(self, room_id, host):
-        curr_state = yield self.state.get_current_state(room_id)
-
-        for event in curr_state.values():
-            if event.type == EventTypes.Member:
-                try:
-                    if get_domain_from_id(event.state_key) != host:
-                        continue
-                except:
-                    logger.warn("state_key not user_id: %s", event.state_key)
-                    continue
+        with Measure(self.clock, "check_host_in_room"):
+            latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
 
-                if event.content["membership"] == Membership.JOIN:
-                    defer.returnValue(True)
+            entry = yield self.state.resolve_state_groups(
+                room_id, latest_event_ids
+            )
 
-        defer.returnValue(False)
+            ret = yield self.store.is_host_joined(
+                room_id, host, entry.state_group, entry.state
+            )
+            defer.returnValue(ret)
 
     def check_event_sender_in_room(self, event, auth_events):
         key = (EventTypes.Member, event.user_id, )
@@ -847,7 +854,7 @@ class Auth(object):
 
     @defer.inlineCallbacks
     def add_auth_events(self, builder, context):
-        auth_ids = self.compute_auth_events(builder, context.current_state)
+        auth_ids = yield self.compute_auth_events(builder, context.prev_state_ids)
 
         auth_events_entries = yield self.store.add_event_hashes(
             auth_ids
@@ -855,30 +862,32 @@ class Auth(object):
 
         builder.auth_events = auth_events_entries
 
-    def compute_auth_events(self, event, current_state):
+    @defer.inlineCallbacks
+    def compute_auth_events(self, event, current_state_ids, for_verification=False):
         if event.type == EventTypes.Create:
-            return []
+            defer.returnValue([])
 
         auth_ids = []
 
         key = (EventTypes.PowerLevels, "", )
-        power_level_event = current_state.get(key)
+        power_level_event_id = current_state_ids.get(key)
 
-        if power_level_event:
-            auth_ids.append(power_level_event.event_id)
+        if power_level_event_id:
+            auth_ids.append(power_level_event_id)
 
         key = (EventTypes.JoinRules, "", )
-        join_rule_event = current_state.get(key)
+        join_rule_event_id = current_state_ids.get(key)
 
         key = (EventTypes.Member, event.user_id, )
-        member_event = current_state.get(key)
+        member_event_id = current_state_ids.get(key)
 
         key = (EventTypes.Create, "", )
-        create_event = current_state.get(key)
-        if create_event:
-            auth_ids.append(create_event.event_id)
+        create_event_id = current_state_ids.get(key)
+        if create_event_id:
+            auth_ids.append(create_event_id)
 
-        if join_rule_event:
+        if join_rule_event_id:
+            join_rule_event = yield self.store.get_event(join_rule_event_id)
             join_rule = join_rule_event.content.get("join_rule")
             is_public = join_rule == JoinRules.PUBLIC if join_rule else False
         else:
@@ -887,15 +896,21 @@ class Auth(object):
         if event.type == EventTypes.Member:
             e_type = event.content["membership"]
             if e_type in [Membership.JOIN, Membership.INVITE]:
-                if join_rule_event:
-                    auth_ids.append(join_rule_event.event_id)
+                if join_rule_event_id:
+                    auth_ids.append(join_rule_event_id)
 
             if e_type == Membership.JOIN:
-                if member_event and not is_public:
-                    auth_ids.append(member_event.event_id)
+                if member_event_id and not is_public:
+                    auth_ids.append(member_event_id)
             else:
-                if member_event:
-                    auth_ids.append(member_event.event_id)
+                if member_event_id:
+                    auth_ids.append(member_event_id)
+
+                if for_verification:
+                    key = (EventTypes.Member, event.state_key, )
+                    existing_event_id = current_state_ids.get(key)
+                    if existing_event_id:
+                        auth_ids.append(existing_event_id)
 
             if e_type == Membership.INVITE:
                 if "third_party_invite" in event.content:
@@ -903,14 +918,15 @@ class Auth(object):
                         EventTypes.ThirdPartyInvite,
                         event.content["third_party_invite"]["signed"]["token"]
                     )
-                    third_party_invite = current_state.get(key)
-                    if third_party_invite:
-                        auth_ids.append(third_party_invite.event_id)
-        elif member_event:
+                    third_party_invite_id = current_state_ids.get(key)
+                    if third_party_invite_id:
+                        auth_ids.append(third_party_invite_id)
+        elif member_event_id:
+            member_event = yield self.store.get_event(member_event_id)
             if member_event.content["membership"] == Membership.JOIN:
                 auth_ids.append(member_event.event_id)
 
-        return auth_ids
+        defer.returnValue(auth_ids)
 
     def _get_send_level(self, etype, state_key, auth_events):
         key = (EventTypes.PowerLevels, "", )
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index 8cf4d6169c..a8123cddcb 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -85,3 +85,8 @@ class RoomCreationPreset(object):
     PRIVATE_CHAT = "private_chat"
     PUBLIC_CHAT = "public_chat"
     TRUSTED_PRIVATE_CHAT = "trusted_private_chat"
+
+
+class ThirdPartyEntityKind(object):
+    USER = "user"
+    LOCATION = "location"
diff --git a/synapse/api/urls.py b/synapse/api/urls.py
index 0fd9b7f244..91a33a3402 100644
--- a/synapse/api/urls.py
+++ b/synapse/api/urls.py
@@ -25,4 +25,3 @@ SERVER_KEY_PREFIX = "/_matrix/key/v1"
 SERVER_KEY_V2_PREFIX = "/_matrix/key/v2"
 MEDIA_PREFIX = "/_matrix/media/r0"
 LEGACY_MEDIA_PREFIX = "/_matrix/media/v1"
-APP_SERVICE_PREFIX = "/_matrix/appservice/v1"
diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py
index e3173533e2..07d3d047c6 100644
--- a/synapse/app/synchrotron.py
+++ b/synapse/app/synchrotron.py
@@ -36,6 +36,7 @@ from synapse.replication.slave.storage.registration import SlavedRegistrationSto
 from synapse.replication.slave.storage.filtering import SlavedFilteringStore
 from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
 from synapse.replication.slave.storage.presence import SlavedPresenceStore
+from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore
 from synapse.server import HomeServer
 from synapse.storage.client_ips import ClientIpStore
 from synapse.storage.engines import create_engine
@@ -72,6 +73,7 @@ class SynchrotronSlavedStore(
     SlavedRegistrationStore,
     SlavedFilteringStore,
     SlavedPresenceStore,
+    SlavedDeviceInboxStore,
     BaseSlavedStore,
     ClientIpStore,  # After BaseSlavedStore because the constructor is different
 ):
@@ -397,6 +399,9 @@ class SynchrotronServer(HomeServer):
             notify_from_stream(
                 result, "typing", "typing_key", room="room_id"
             )
+            notify_from_stream(
+                result, "to_device", "to_device_key", user="user_id"
+            )
 
         while True:
             try:
diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py
index bde9b51b2e..126a10efb7 100644
--- a/synapse/appservice/__init__.py
+++ b/synapse/appservice/__init__.py
@@ -88,6 +88,8 @@ class ApplicationService(object):
         self.sender = sender
         self.namespaces = self._check_namespaces(namespaces)
         self.id = id
+
+        # .protocols is a publicly visible field
         if protocols:
             self.protocols = set(protocols)
         else:
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index 066127b666..cc4af23962 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -14,10 +14,11 @@
 # limitations under the License.
 from twisted.internet import defer
 
+from synapse.api.constants import ThirdPartyEntityKind
 from synapse.api.errors import CodeMessageException
 from synapse.http.client import SimpleHttpClient
 from synapse.events.utils import serialize_event
-from synapse.types import ThirdPartyEntityKind
+from synapse.util.caches.response_cache import ResponseCache
 
 import logging
 import urllib
@@ -25,6 +26,12 @@ import urllib
 logger = logging.getLogger(__name__)
 
 
+HOUR_IN_MS = 60 * 60 * 1000
+
+
+APP_SERVICE_PREFIX = "/_matrix/app/unstable"
+
+
 def _is_valid_3pe_result(r, field):
     if not isinstance(r, dict):
         return False
@@ -56,8 +63,12 @@ class ApplicationServiceApi(SimpleHttpClient):
         super(ApplicationServiceApi, self).__init__(hs)
         self.clock = hs.get_clock()
 
+        self.protocol_meta_cache = ResponseCache(hs, timeout_ms=HOUR_IN_MS)
+
     @defer.inlineCallbacks
     def query_user(self, service, user_id):
+        if service.url is None:
+            defer.returnValue(False)
         uri = service.url + ("/users/%s" % urllib.quote(user_id))
         response = None
         try:
@@ -77,6 +88,8 @@ class ApplicationServiceApi(SimpleHttpClient):
 
     @defer.inlineCallbacks
     def query_alias(self, service, alias):
+        if service.url is None:
+            defer.returnValue(False)
         uri = service.url + ("/rooms/%s" % urllib.quote(alias))
         response = None
         try:
@@ -97,16 +110,22 @@ class ApplicationServiceApi(SimpleHttpClient):
     @defer.inlineCallbacks
     def query_3pe(self, service, kind, protocol, fields):
         if kind == ThirdPartyEntityKind.USER:
-            uri = "%s/3pu/%s" % (service.url, urllib.quote(protocol))
             required_field = "userid"
         elif kind == ThirdPartyEntityKind.LOCATION:
-            uri = "%s/3pl/%s" % (service.url, urllib.quote(protocol))
             required_field = "alias"
         else:
             raise ValueError(
                 "Unrecognised 'kind' argument %r to query_3pe()", kind
             )
+        if service.url is None:
+            defer.returnValue([])
 
+        uri = "%s%s/thirdparty/%s/%s" % (
+            service.url,
+            APP_SERVICE_PREFIX,
+            kind,
+            urllib.quote(protocol)
+        )
         try:
             response = yield self.get_json(uri, fields)
             if not isinstance(response, list):
@@ -131,8 +150,34 @@ class ApplicationServiceApi(SimpleHttpClient):
             logger.warning("query_3pe to %s threw exception %s", uri, ex)
             defer.returnValue([])
 
+    def get_3pe_protocol(self, service, protocol):
+        if service.url is None:
+            defer.returnValue({})
+
+        @defer.inlineCallbacks
+        def _get():
+            uri = "%s%s/thirdparty/protocol/%s" % (
+                service.url,
+                APP_SERVICE_PREFIX,
+                urllib.quote(protocol)
+            )
+            try:
+                defer.returnValue((yield self.get_json(uri, {})))
+            except Exception as ex:
+                logger.warning("query_3pe_protocol to %s threw exception %s",
+                               uri, ex)
+                defer.returnValue({})
+
+        key = (service.id, protocol)
+        return self.protocol_meta_cache.get(key) or (
+            self.protocol_meta_cache.set(key, _get())
+        )
+
     @defer.inlineCallbacks
     def push_bulk(self, service, events, txn_id=None):
+        if service.url is None:
+            defer.returnValue(True)
+
         events = self._serialize(events)
 
         if txn_id is None:
diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py
index dfe43b0b4c..d7537e8d44 100644
--- a/synapse/config/appservice.py
+++ b/synapse/config/appservice.py
@@ -86,7 +86,7 @@ def load_appservices(hostname, config_files):
 
 def _load_appservice(hostname, as_info, config_filename):
     required_string_fields = [
-        "id", "url", "as_token", "hs_token", "sender_localpart"
+        "id", "as_token", "hs_token", "sender_localpart"
     ]
     for field in required_string_fields:
         if not isinstance(as_info.get(field), basestring):
@@ -94,6 +94,14 @@ def _load_appservice(hostname, as_info, config_filename):
                 field, config_filename,
             ))
 
+    # 'url' must either be a string or explicitly null, not missing
+    # to avoid accidentally turning off push for ASes.
+    if (not isinstance(as_info.get("url"), basestring) and
+            as_info.get("url", "") is not None):
+        raise KeyError(
+            "Required string field or explicit null: 'url' (%s)" % (config_filename,)
+        )
+
     localpart = as_info["sender_localpart"]
     if urllib.quote(localpart) != localpart:
         raise ValueError(
@@ -132,6 +140,13 @@ def _load_appservice(hostname, as_info, config_filename):
         for p in protocols:
             if not isinstance(p, str):
                 raise KeyError("Bad value for 'protocols' item")
+
+    if as_info["url"] is None:
+        logger.info(
+            "(%s) Explicitly empty 'url' provided. This application service"
+            " will not receive events or queries.",
+            config_filename,
+        )
     return ApplicationService(
         token=as_info["as_token"],
         url=as_info["url"],
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index 13154b1723..bcb8f33a58 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -99,7 +99,7 @@ class EventBase(object):
 
         return d
 
-    def get(self, key, default):
+    def get(self, key, default=None):
         return self._event_dict.get(key, default)
 
     def get_internal_metadata_dict(self):
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index 8a475417a6..e895b1c450 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -15,9 +15,9 @@
 
 
 class EventContext(object):
-
-    def __init__(self, current_state=None):
-        self.current_state = current_state
+    def __init__(self):
+        self.current_state_ids = None
+        self.prev_state_ids = None
         self.state_group = None
         self.rejected = False
         self.push_actions = []
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index f2b3aceb49..627acc6a4f 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -29,6 +29,7 @@ from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.util.logutils import log_function
 from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
 from synapse.events import FrozenEvent
+from synapse.types import get_domain_from_id
 import synapse.metrics
 
 from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination
@@ -63,6 +64,7 @@ class FederationClient(FederationBase):
         self._clock.looping_call(
             self._clear_tried_cache, 60 * 1000,
         )
+        self.state = hs.get_state_handler()
 
     def _clear_tried_cache(self):
         """Clear pdu_destination_tried cache"""
@@ -267,7 +269,7 @@ class FederationClient(FederationBase):
 
         pdu_attempts = self.pdu_destination_tried.setdefault(event_id, {})
 
-        pdu = None
+        signed_pdu = None
         for destination in destinations:
             now = self._clock.time_msec()
             last_attempt = pdu_attempts.get(destination, 0)
@@ -297,7 +299,7 @@ class FederationClient(FederationBase):
                         pdu = pdu_list[0]
 
                         # Check signatures are correct.
-                        pdu = yield self._check_sigs_and_hashes([pdu])[0]
+                        signed_pdu = yield self._check_sigs_and_hashes([pdu])[0]
 
                         break
 
@@ -320,10 +322,10 @@ class FederationClient(FederationBase):
                 )
                 continue
 
-        if self._get_pdu_cache is not None and pdu:
-            self._get_pdu_cache[event_id] = pdu
+        if self._get_pdu_cache is not None and signed_pdu:
+            self._get_pdu_cache[event_id] = signed_pdu
 
-        defer.returnValue(pdu)
+        defer.returnValue(signed_pdu)
 
     @defer.inlineCallbacks
     @log_function
@@ -811,7 +813,8 @@ class FederationClient(FederationBase):
         if len(signed_events) >= limit:
             defer.returnValue(signed_events)
 
-        servers = yield self.store.get_joined_hosts_for_room(room_id)
+        users = yield self.state.get_current_user_in_room(room_id)
+        servers = set(get_domain_from_id(u) for u in users)
 
         servers = set(servers)
         servers.discard(self.server_name)
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index aba19639c7..5621655098 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -223,16 +223,14 @@ class FederationServer(FederationBase):
         if not in_room:
             raise AuthError(403, "Host not in room.")
 
-        pdus = yield self.handler.get_state_for_pdu(
+        state_ids = yield self.handler.get_state_ids_for_pdu(
             room_id, event_id,
         )
-        auth_chain = yield self.store.get_auth_chain(
-            [pdu.event_id for pdu in pdus]
-        )
+        auth_chain_ids = yield self.store.get_auth_chain_ids(state_ids)
 
         defer.returnValue((200, {
-            "pdu_ids": [pdu.event_id for pdu in pdus],
-            "auth_chain_ids": [pdu.event_id for pdu in auth_chain],
+            "pdu_ids": state_ids,
+            "auth_chain_ids": auth_chain_ids,
         }))
 
     @defer.inlineCallbacks
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index 11081a0cd5..e58735294e 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -65,33 +65,21 @@ class BaseHandler(object):
                 retry_after_ms=int(1000 * (time_allowed - time_now)),
             )
 
-    def is_host_in_room(self, current_state):
-        room_members = [
-            (state_key, event.membership)
-            for ((event_type, state_key), event) in current_state.items()
-            if event_type == EventTypes.Member
-        ]
-        if len(room_members) == 0:
-            # Have we just created the room, and is this about to be the very
-            # first member event?
-            create_event = current_state.get(("m.room.create", ""))
-            if create_event:
-                return True
-        for (state_key, membership) in room_members:
-            if (
-                self.hs.is_mine_id(state_key)
-                and membership == Membership.JOIN
-            ):
-                return True
-        return False
-
     @defer.inlineCallbacks
-    def maybe_kick_guest_users(self, event, current_state):
+    def maybe_kick_guest_users(self, event, context=None):
         # Technically this function invalidates current_state by changing it.
         # Hopefully this isn't that important to the caller.
         if event.type == EventTypes.GuestAccess:
             guest_access = event.content.get("guest_access", "forbidden")
             if guest_access != "can_join":
+                if context:
+                    current_state = yield self.store.get_events(
+                        context.current_state_ids.values()
+                    )
+                    current_state = current_state.values()
+                else:
+                    current_state = yield self.store.get_current_state(event.room_id)
+                logger.info("maybe_kick_guest_users %r", current_state)
                 yield self.kick_guest_users(current_state)
 
     @defer.inlineCallbacks
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 306686a384..b440280b74 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -176,6 +176,16 @@ class ApplicationServicesHandler(object):
         defer.returnValue(ret)
 
     @defer.inlineCallbacks
+    def get_3pe_protocols(self):
+        services = yield self.store.get_app_services()
+        protocols = {}
+        for s in services:
+            for p in s.protocols:
+                protocols[p] = yield self.appservice_api.get_3pe_protocol(s, p)
+
+        defer.returnValue(protocols)
+
+    @defer.inlineCallbacks
     def _get_services_for_event(self, event):
         """Retrieve a list of application services interested in this event.
 
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index 4bea7f2b19..14352985e2 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -19,7 +19,7 @@ from ._base import BaseHandler
 
 from synapse.api.errors import SynapseError, Codes, CodeMessageException, AuthError
 from synapse.api.constants import EventTypes
-from synapse.types import RoomAlias, UserID
+from synapse.types import RoomAlias, UserID, get_domain_from_id
 
 import logging
 import string
@@ -55,7 +55,8 @@ class DirectoryHandler(BaseHandler):
         # TODO(erikj): Add transactions.
         # TODO(erikj): Check if there is a current association.
         if not servers:
-            servers = yield self.store.get_joined_hosts_for_room(room_id)
+            users = yield self.state.get_current_user_in_room(room_id)
+            servers = set(get_domain_from_id(u) for u in users)
 
         if not servers:
             raise SynapseError(400, "Failed to get server list")
@@ -193,7 +194,8 @@ class DirectoryHandler(BaseHandler):
                 Codes.NOT_FOUND
             )
 
-        extra_servers = yield self.store.get_joined_hosts_for_room(room_id)
+        users = yield self.state.get_current_user_in_room(room_id)
+        extra_servers = set(get_domain_from_id(u) for u in users)
         servers = set(extra_servers) | set(servers)
 
         # If this server is in the list of servers, return it first.
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index 3a3a1257d3..d3685fb12a 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -47,6 +47,7 @@ class EventStreamHandler(BaseHandler):
         self.clock = hs.get_clock()
 
         self.notifier = hs.get_notifier()
+        self.state = hs.get_state_handler()
 
     @defer.inlineCallbacks
     @log_function
@@ -90,7 +91,7 @@ class EventStreamHandler(BaseHandler):
                     # Send down presence.
                     if event.state_key == auth_user_id:
                         # Send down presence for everyone in the room.
-                        users = yield self.store.get_users_in_room(event.room_id)
+                        users = yield self.state.get_current_user_in_room(event.room_id)
                         states = yield presence_handler.get_states(
                             users,
                             as_event=True,
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 01a761715b..dc90a5dde4 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -29,6 +29,7 @@ from synapse.util import unwrapFirstError
 from synapse.util.logcontext import (
     PreserveLoggingContext, preserve_fn, preserve_context_over_deferred
 )
+from synapse.util.metrics import measure_func
 from synapse.util.logutils import log_function
 from synapse.util.async import run_on_reactor
 from synapse.util.frozenutils import unfreeze
@@ -100,6 +101,9 @@ class FederationHandler(BaseHandler):
     def on_receive_pdu(self, origin, pdu, state=None, auth_chain=None):
         """ Called by the ReplicationLayer when we have a new pdu. We need to
         do auth checks and put it through the StateHandler.
+
+        auth_chain and state are None if we already have the necessary state
+        and prev_events in the db
         """
         event = pdu
 
@@ -117,12 +121,21 @@ class FederationHandler(BaseHandler):
 
         # FIXME (erikj): Awful hack to make the case where we are not currently
         # in the room work
-        is_in_room = yield self.auth.check_host_in_room(
-            event.room_id,
-            self.server_name
-        )
-        if not is_in_room and not event.internal_metadata.is_outlier():
-            logger.debug("Got event for room we're not in.")
+        # If state and auth_chain are None, then we don't need to do this check
+        # as we already know we have enough state in the DB to handle this
+        # event.
+        if state and auth_chain and not event.internal_metadata.is_outlier():
+            is_in_room = yield self.auth.check_host_in_room(
+                event.room_id,
+                self.server_name
+            )
+        else:
+            is_in_room = True
+        if not is_in_room:
+            logger.info(
+                "Got event for room we're not in: %r %r",
+                event.room_id, event.event_id
+            )
 
             try:
                 event_stream_id, max_stream_id = yield self._persist_auth_tree(
@@ -217,17 +230,28 @@ class FederationHandler(BaseHandler):
 
         if event.type == EventTypes.Member:
             if event.membership == Membership.JOIN:
-                prev_state = context.current_state.get((event.type, event.state_key))
-                if not prev_state or prev_state.membership != Membership.JOIN:
-                    # Only fire user_joined_room if the user has acutally
-                    # joined the room. Don't bother if the user is just
-                    # changing their profile info.
+                # Only fire user_joined_room if the user has acutally
+                # joined the room. Don't bother if the user is just
+                # changing their profile info.
+                newly_joined = True
+                prev_state_id = context.prev_state_ids.get(
+                    (event.type, event.state_key)
+                )
+                if prev_state_id:
+                    prev_state = yield self.store.get_event(
+                        prev_state_id, allow_none=True,
+                    )
+                    if prev_state and prev_state.membership == Membership.JOIN:
+                        newly_joined = False
+
+                if newly_joined:
                     user = UserID.from_string(event.state_key)
                     yield user_joined_room(self.distributor, user, event.room_id)
 
+    @measure_func("_filter_events_for_server")
     @defer.inlineCallbacks
     def _filter_events_for_server(self, server_name, room_id, events):
-        event_to_state = yield self.store.get_state_for_events(
+        event_to_state_ids = yield self.store.get_state_ids_for_events(
             frozenset(e.event_id for e in events),
             types=(
                 (EventTypes.RoomHistoryVisibility, ""),
@@ -235,6 +259,30 @@ class FederationHandler(BaseHandler):
             )
         )
 
+        # We only want to pull out member events that correspond to the
+        # server's domain.
+
+        def check_match(id):
+            try:
+                return server_name == get_domain_from_id(id)
+            except:
+                return False
+
+        event_map = yield self.store.get_events([
+            e_id for key_to_eid in event_to_state_ids.values()
+            for key, e_id in key_to_eid
+            if key[0] != EventTypes.Member or check_match(key[1])
+        ])
+
+        event_to_state = {
+            e_id: {
+                key: event_map[inner_e_id]
+                for key, inner_e_id in key_to_eid.items()
+                if inner_e_id in event_map
+            }
+            for e_id, key_to_eid in event_to_state_ids.items()
+        }
+
         def redact_disallowed(event, state):
             if not state:
                 return event
@@ -377,7 +425,9 @@ class FederationHandler(BaseHandler):
                 )).addErrback(unwrapFirstError)
                 auth_events.update({a.event_id: a for a in results if a})
                 required_auth.update(
-                    a_id for event in results for a_id, _ in event.auth_events if event
+                    a_id
+                    for event in results if event
+                    for a_id, _ in event.auth_events
                 )
                 missing_auth = required_auth - set(auth_events)
 
@@ -560,6 +610,18 @@ class FederationHandler(BaseHandler):
         ]))
         states = dict(zip(event_ids, [s[1] for s in states]))
 
+        state_map = yield self.store.get_events(
+            [e_id for ids in states.values() for e_id in ids],
+            get_prev_content=False
+        )
+        states = {
+            key: {
+                k: state_map[e_id]
+                for k, e_id in state_dict.items()
+                if e_id in state_map
+            } for key, state_dict in states.items()
+        }
+
         for e_id, _ in sorted_extremeties_tuple:
             likely_domains = get_domains_from_state(states[e_id])
 
@@ -722,7 +784,7 @@ class FederationHandler(BaseHandler):
 
         # The remote hasn't signed it yet, obviously. We'll do the full checks
         # when we get the event back in `on_send_join_request`
-        self.auth.check(event, auth_events=context.current_state, do_sig_check=False)
+        yield self.auth.check_from_context(event, context, do_sig_check=False)
 
         defer.returnValue(event)
 
@@ -770,18 +832,11 @@ class FederationHandler(BaseHandler):
 
         new_pdu = event
 
-        destinations = set()
-
-        for k, s in context.current_state.items():
-            try:
-                if k[0] == EventTypes.Member:
-                    if s.content["membership"] == Membership.JOIN:
-                        destinations.add(get_domain_from_id(s.state_key))
-            except:
-                logger.warn(
-                    "Failed to get destination from event %s", s.event_id
-                )
-
+        message_handler = self.hs.get_handlers().message_handler
+        destinations = yield message_handler.get_joined_hosts_for_room_from_state(
+            context
+        )
+        destinations = set(destinations)
         destinations.discard(origin)
 
         logger.debug(
@@ -792,13 +847,15 @@ class FederationHandler(BaseHandler):
 
         self.replication_layer.send_pdu(new_pdu, destinations)
 
-        state_ids = [e.event_id for e in context.current_state.values()]
+        state_ids = context.prev_state_ids.values()
         auth_chain = yield self.store.get_auth_chain(set(
             [event.event_id] + state_ids
         ))
 
+        state = yield self.store.get_events(context.prev_state_ids.values())
+
         defer.returnValue({
-            "state": context.current_state.values(),
+            "state": state.values(),
             "auth_chain": auth_chain,
         })
 
@@ -954,7 +1011,7 @@ class FederationHandler(BaseHandler):
         try:
             # The remote hasn't signed it yet, obviously. We'll do the full checks
             # when we get the event back in `on_send_leave_request`
-            self.auth.check(event, auth_events=context.current_state, do_sig_check=False)
+            yield self.auth.check_from_context(event, context, do_sig_check=False)
         except AuthError as e:
             logger.warn("Failed to create new leave %r because %s", event, e)
             raise e
@@ -998,18 +1055,11 @@ class FederationHandler(BaseHandler):
 
         new_pdu = event
 
-        destinations = set()
-
-        for k, s in context.current_state.items():
-            try:
-                if k[0] == EventTypes.Member:
-                    if s.content["membership"] == Membership.LEAVE:
-                        destinations.add(get_domain_from_id(s.state_key))
-            except:
-                logger.warn(
-                    "Failed to get destination from event %s", s.event_id
-                )
-
+        message_handler = self.hs.get_handlers().message_handler
+        destinations = yield message_handler.get_joined_hosts_for_room_from_state(
+            context
+        )
+        destinations = set(destinations)
         destinations.discard(origin)
 
         logger.debug(
@@ -1024,6 +1074,8 @@ class FederationHandler(BaseHandler):
 
     @defer.inlineCallbacks
     def get_state_for_pdu(self, room_id, event_id):
+        """Returns the state at the event. i.e. not including said event.
+        """
         yield run_on_reactor()
 
         state_groups = yield self.store.get_state_groups(
@@ -1065,6 +1117,34 @@ class FederationHandler(BaseHandler):
             defer.returnValue([])
 
     @defer.inlineCallbacks
+    def get_state_ids_for_pdu(self, room_id, event_id):
+        """Returns the state at the event. i.e. not including said event.
+        """
+        yield run_on_reactor()
+
+        state_groups = yield self.store.get_state_groups_ids(
+            room_id, [event_id]
+        )
+
+        if state_groups:
+            _, state = state_groups.items().pop()
+            results = state
+
+            event = yield self.store.get_event(event_id)
+            if event and event.is_state():
+                # Get previous state
+                if "replaces_state" in event.unsigned:
+                    prev_id = event.unsigned["replaces_state"]
+                    if prev_id != event.event_id:
+                        results[(event.type, event.state_key)] = prev_id
+                else:
+                    del results[(event.type, event.state_key)]
+
+            defer.returnValue(results.values())
+        else:
+            defer.returnValue([])
+
+    @defer.inlineCallbacks
     @log_function
     def on_backfill_request(self, origin, room_id, pdu_list, limit):
         in_room = yield self.auth.check_host_in_room(room_id, origin)
@@ -1294,7 +1374,13 @@ class FederationHandler(BaseHandler):
         )
 
         if not auth_events:
-            auth_events = context.current_state
+            auth_events_ids = yield self.auth.compute_auth_events(
+                event, context.prev_state_ids, for_verification=True,
+            )
+            auth_events = yield self.store.get_events(auth_events_ids)
+            auth_events = {
+                (e.type, e.state_key): e for e in auth_events.values()
+            }
 
         # This is a hack to fix some old rooms where the initial join event
         # didn't reference the create event in its auth events.
@@ -1320,8 +1406,7 @@ class FederationHandler(BaseHandler):
             context.rejected = RejectedReason.AUTH_ERROR
 
         if event.type == EventTypes.GuestAccess:
-            full_context = yield self.store.get_current_state(room_id=event.room_id)
-            yield self.maybe_kick_guest_users(event, full_context)
+            yield self.maybe_kick_guest_users(event)
 
         defer.returnValue(context)
 
@@ -1389,6 +1474,11 @@ class FederationHandler(BaseHandler):
         current_state = set(e.event_id for e in auth_events.values())
         event_auth_events = set(e_id for e_id, _ in event.auth_events)
 
+        if event.is_state():
+            event_key = (event.type, event.state_key)
+        else:
+            event_key = None
+
         if event_auth_events - current_state:
             have_events = yield self.store.have_events(
                 event_auth_events - current_state
@@ -1492,8 +1582,14 @@ class FederationHandler(BaseHandler):
                 current_state = set(e.event_id for e in auth_events.values())
                 different_auth = event_auth_events - current_state
 
-                context.current_state.update(auth_events)
-                context.state_group = None
+                context.current_state_ids.update({
+                    k: a.event_id for k, a in auth_events.items()
+                    if k != event_key
+                })
+                context.prev_state_ids.update({
+                    k: a.event_id for k, a in auth_events.items()
+                })
+                context.state_group = self.store.get_next_state_group()
 
         if different_auth and not event.internal_metadata.is_outlier():
             logger.info("Different auth after resolution: %s", different_auth)
@@ -1514,8 +1610,8 @@ class FederationHandler(BaseHandler):
 
             if do_resolution:
                 # 1. Get what we think is the auth chain.
-                auth_ids = self.auth.compute_auth_events(
-                    event, context.current_state
+                auth_ids = yield self.auth.compute_auth_events(
+                    event, context.prev_state_ids
                 )
                 local_auth_chain = yield self.store.get_auth_chain(auth_ids)
 
@@ -1571,8 +1667,14 @@ class FederationHandler(BaseHandler):
                 # 4. Look at rejects and their proofs.
                 # TODO.
 
-                context.current_state.update(auth_events)
-                context.state_group = None
+                context.current_state_ids.update({
+                    k: a.event_id for k, a in auth_events.items()
+                    if k != event_key
+                })
+                context.prev_state_ids.update({
+                    k: a.event_id for k, a in auth_events.items()
+                })
+                context.state_group = self.store.get_next_state_group()
 
         try:
             self.auth.check(event, auth_events=auth_events)
@@ -1758,12 +1860,12 @@ class FederationHandler(BaseHandler):
             )
 
             try:
-                self.auth.check(event, context.current_state)
+                yield self.auth.check_from_context(event, context)
             except AuthError as e:
                 logger.warn("Denying new third party invite %r because %s", event, e)
                 raise e
 
-            yield self._check_signature(event, auth_events=context.current_state)
+            yield self._check_signature(event, context)
             member_handler = self.hs.get_handlers().room_member_handler
             yield member_handler.send_membership_event(None, event, context)
         else:
@@ -1789,11 +1891,11 @@ class FederationHandler(BaseHandler):
         )
 
         try:
-            self.auth.check(event, auth_events=context.current_state)
+            self.auth.check_from_context(event, context)
         except AuthError as e:
             logger.warn("Denying third party invite %r because %s", event, e)
             raise e
-        yield self._check_signature(event, auth_events=context.current_state)
+        yield self._check_signature(event, context)
 
         returned_invite = yield self.send_invite(origin, event)
         # TODO: Make sure the signatures actually are correct.
@@ -1807,7 +1909,12 @@ class FederationHandler(BaseHandler):
             EventTypes.ThirdPartyInvite,
             event.content["third_party_invite"]["signed"]["token"]
         )
-        original_invite = context.current_state.get(key)
+        original_invite = None
+        original_invite_id = context.prev_state_ids.get(key)
+        if original_invite_id:
+            original_invite = yield self.store.get_event(
+                original_invite_id, allow_none=True
+            )
         if not original_invite:
             logger.info(
                 "Could not find invite event for third_party_invite - "
@@ -1824,13 +1931,13 @@ class FederationHandler(BaseHandler):
         defer.returnValue((event, context))
 
     @defer.inlineCallbacks
-    def _check_signature(self, event, auth_events):
+    def _check_signature(self, event, context):
         """
         Checks that the signature in the event is consistent with its invite.
 
         Args:
             event (Event): The m.room.member event to check
-            auth_events (dict<(event type, state_key), event>):
+            context (EventContext):
 
         Raises:
             AuthError: if signature didn't match any keys, or key has been
@@ -1841,10 +1948,14 @@ class FederationHandler(BaseHandler):
         signed = event.content["third_party_invite"]["signed"]
         token = signed["token"]
 
-        invite_event = auth_events.get(
+        invite_event_id = context.prev_state_ids.get(
             (EventTypes.ThirdPartyInvite, token,)
         )
 
+        invite_event = None
+        if invite_event_id:
+            invite_event = yield self.store.get_event(invite_event_id, allow_none=True)
+
         if not invite_event:
             raise AuthError(403, "Could not find invite")
 
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 4c3cd9d12e..3577db0595 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -30,6 +30,7 @@ from synapse.util.async import concurrently_execute, run_on_reactor, ReadWriteLo
 from synapse.util.caches.snapshot_cache import SnapshotCache
 from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
 from synapse.util.metrics import measure_func
+from synapse.util.caches.descriptors import cachedInlineCallbacks
 from synapse.visibility import filter_events_for_client
 
 from ._base import BaseHandler
@@ -248,7 +249,7 @@ class MessageHandler(BaseHandler):
         assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
 
         if event.is_state():
-            prev_state = self.deduplicate_state_event(event, context)
+            prev_state = yield self.deduplicate_state_event(event, context)
             if prev_state is not None:
                 defer.returnValue(prev_state)
 
@@ -263,6 +264,7 @@ class MessageHandler(BaseHandler):
             presence = self.hs.get_presence_handler()
             yield presence.bump_presence_active_time(user)
 
+    @defer.inlineCallbacks
     def deduplicate_state_event(self, event, context):
         """
         Checks whether event is in the latest resolved state in context.
@@ -270,13 +272,17 @@ class MessageHandler(BaseHandler):
         If so, returns the version of the event in context.
         Otherwise, returns None.
         """
-        prev_event = context.current_state.get((event.type, event.state_key))
+        prev_event_id = context.prev_state_ids.get((event.type, event.state_key))
+        prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
+        if not prev_event:
+            return
+
         if prev_event and event.user_id == prev_event.user_id:
             prev_content = encode_canonical_json(prev_event.content)
             next_content = encode_canonical_json(event.content)
             if prev_content == next_content:
-                return prev_event
-        return None
+                defer.returnValue(prev_event)
+        return
 
     @defer.inlineCallbacks
     def create_and_send_nonmember_event(
@@ -802,8 +808,8 @@ class MessageHandler(BaseHandler):
         event = builder.build()
 
         logger.debug(
-            "Created event %s with current state: %s",
-            event.event_id, context.current_state,
+            "Created event %s with state: %s",
+            event.event_id, context.prev_state_ids,
         )
 
         defer.returnValue(
@@ -826,12 +832,12 @@ class MessageHandler(BaseHandler):
             self.ratelimit(requester)
 
         try:
-            self.auth.check(event, auth_events=context.current_state)
+            yield self.auth.check_from_context(event, context)
         except AuthError as err:
             logger.warn("Denying new event %r because %s", event, err)
             raise err
 
-        yield self.maybe_kick_guest_users(event, context.current_state.values())
+        yield self.maybe_kick_guest_users(event, context)
 
         if event.type == EventTypes.CanonicalAlias:
             # Check the alias is acually valid (at this time at least)
@@ -859,6 +865,15 @@ class MessageHandler(BaseHandler):
                         e.sender == event.sender
                     )
 
+                state_to_include_ids = [
+                    e_id
+                    for k, e_id in context.current_state_ids.items()
+                    if k[0] in self.hs.config.room_invite_state_types
+                    or k[0] == EventTypes.Member and k[1] == event.sender
+                ]
+
+                state_to_include = yield self.store.get_events(state_to_include_ids)
+
                 event.unsigned["invite_room_state"] = [
                     {
                         "type": e.type,
@@ -866,9 +881,7 @@ class MessageHandler(BaseHandler):
                         "content": e.content,
                         "sender": e.sender,
                     }
-                    for k, e in context.current_state.items()
-                    if e.type in self.hs.config.room_invite_state_types
-                    or is_inviter_member_event(e)
+                    for e in state_to_include.values()
                 ]
 
                 invitee = UserID.from_string(event.state_key)
@@ -890,7 +903,14 @@ class MessageHandler(BaseHandler):
                     )
 
         if event.type == EventTypes.Redaction:
-            if self.auth.check_redaction(event, auth_events=context.current_state):
+            auth_events_ids = yield self.auth.compute_auth_events(
+                event, context.prev_state_ids, for_verification=True,
+            )
+            auth_events = yield self.store.get_events(auth_events_ids)
+            auth_events = {
+                (e.type, e.state_key): e for e in auth_events.values()
+            }
+            if self.auth.check_redaction(event, auth_events=auth_events):
                 original_event = yield self.store.get_event(
                     event.redacts,
                     check_redacted=False,
@@ -904,7 +924,7 @@ class MessageHandler(BaseHandler):
                         "You don't have permission to redact events"
                     )
 
-        if event.type == EventTypes.Create and context.current_state:
+        if event.type == EventTypes.Create and context.prev_state_ids:
             raise AuthError(
                 403,
                 "Changing the room create event is forbidden",
@@ -925,16 +945,7 @@ class MessageHandler(BaseHandler):
             event_stream_id, max_stream_id
         )
 
-        destinations = set()
-        for k, s in context.current_state.items():
-            try:
-                if k[0] == EventTypes.Member:
-                    if s.content["membership"] == Membership.JOIN:
-                        destinations.add(get_domain_from_id(s.state_key))
-            except SynapseError:
-                logger.warn(
-                    "Failed to get destination from event %s", s.event_id
-                )
+        destinations = yield self.get_joined_hosts_for_room_from_state(context)
 
         @defer.inlineCallbacks
         def _notify():
@@ -952,3 +963,39 @@ class MessageHandler(BaseHandler):
         preserve_fn(federation_handler.handle_new_event)(
             event, destinations=destinations,
         )
+
+    def get_joined_hosts_for_room_from_state(self, context):
+        state_group = context.state_group
+        if not state_group:
+            # If state_group is None it means it has yet to be assigned a
+            # state group, i.e. we need to make sure that calls with a state_group
+            # of None don't hit previous cached calls with a None state_group.
+            # To do this we set the state_group to a new object as object() != object()
+            state_group = object()
+
+        return self._get_joined_hosts_for_room_from_state(
+            state_group, context.current_state_ids
+        )
+
+    @cachedInlineCallbacks(num_args=1, cache_context=True)
+    def _get_joined_hosts_for_room_from_state(self, state_group, current_state_ids,
+                                              cache_context):
+
+        # Don't bother getting state for people on the same HS
+        current_state = yield self.store.get_events([
+            e_id for key, e_id in current_state_ids.items()
+            if key[0] == EventTypes.Member and not self.hs.is_mine_id(key[1])
+        ])
+
+        destinations = set()
+        for e in current_state.itervalues():
+            try:
+                if e.type == EventTypes.Member:
+                    if e.content["membership"] == Membership.JOIN:
+                        destinations.add(get_domain_from_id(e.state_key))
+            except SynapseError:
+                logger.warn(
+                    "Failed to get destination from event %s", e.event_id
+                )
+
+        defer.returnValue(destinations)
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 6a1fe76c88..cf82a2336e 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -88,6 +88,8 @@ class PresenceHandler(object):
         self.notifier = hs.get_notifier()
         self.federation = hs.get_replication_layer()
 
+        self.state = hs.get_state_handler()
+
         self.federation.register_edu_handler(
             "m.presence", self.incoming_presence
         )
@@ -189,6 +191,13 @@ class PresenceHandler(object):
             5000,
         )
 
+        self.clock.call_later(
+            60,
+            self.clock.looping_call,
+            self._persist_unpersisted_changes,
+            60 * 1000,
+        )
+
         metrics.register_callback("wheel_timer_size", lambda: len(self.wheel_timer))
 
     @defer.inlineCallbacks
@@ -215,6 +224,27 @@ class PresenceHandler(object):
         logger.info("Finished _on_shutdown")
 
     @defer.inlineCallbacks
+    def _persist_unpersisted_changes(self):
+        """We periodically persist the unpersisted changes, as otherwise they
+        may stack up and slow down shutdown times.
+        """
+        logger.info(
+            "Performing _persist_unpersisted_changes. Persiting %d unpersisted changes",
+            len(self.unpersisted_users_changes)
+        )
+
+        unpersisted = self.unpersisted_users_changes
+        self.unpersisted_users_changes = set()
+
+        if unpersisted:
+            yield self.store.update_presence([
+                self.user_to_current_state[user_id]
+                for user_id in unpersisted
+            ])
+
+        logger.info("Finished _persist_unpersisted_changes")
+
+    @defer.inlineCallbacks
     def _update_states(self, new_states):
         """Updates presence of users. Sets the appropriate timeouts. Pokes
         the notifier and federation if and only if the changed presence state
@@ -532,7 +562,9 @@ class PresenceHandler(object):
                 if not local_states:
                     continue
 
-                hosts = yield self.store.get_joined_hosts_for_room(room_id)
+                users = yield self.state.get_current_user_in_room(room_id)
+                hosts = set(get_domain_from_id(u) for u in users)
+
                 for host in hosts:
                     hosts_to_states.setdefault(host, []).extend(local_states)
 
@@ -725,13 +757,13 @@ class PresenceHandler(object):
         # don't need to send to local clients here, as that is done as part
         # of the event stream/sync.
         # TODO: Only send to servers not already in the room.
+        user_ids = yield self.state.get_current_user_in_room(room_id)
         if self.is_mine(user):
             state = yield self.current_state_for_user(user.to_string())
 
-            hosts = yield self.store.get_joined_hosts_for_room(room_id)
+            hosts = set(get_domain_from_id(u) for u in user_ids)
             self._push_to_remotes({host: (state,) for host in hosts})
         else:
-            user_ids = yield self.store.get_users_in_room(room_id)
             user_ids = filter(self.is_mine_id, user_ids)
 
             states = yield self.current_state_for_users(user_ids)
@@ -918,7 +950,12 @@ def should_notify(old_state, new_state):
         if new_state.currently_active != old_state.currently_active:
             return True
 
-    if new_state.last_active_ts - old_state.last_active_ts > LAST_ACTIVE_GRANULARITY:
+        if new_state.last_active_ts - old_state.last_active_ts > LAST_ACTIVE_GRANULARITY:
+            # Only notify about last active bumps if we're not currently acive
+            if not (old_state.currently_active and new_state.currently_active):
+                return True
+
+    elif new_state.last_active_ts - old_state.last_active_ts > LAST_ACTIVE_GRANULARITY:
         # Always notify for a transition where last active gets bumped.
         return True
 
@@ -955,6 +992,7 @@ class PresenceEventSource(object):
         self.get_presence_handler = hs.get_presence_handler
         self.clock = hs.get_clock()
         self.store = hs.get_datastore()
+        self.state = hs.get_state_handler()
 
     @defer.inlineCallbacks
     @log_function
@@ -1017,7 +1055,7 @@ class PresenceEventSource(object):
 
                 user_ids_to_check = set()
                 for room_id in room_ids:
-                    users = yield self.store.get_users_in_room(room_id)
+                    users = yield self.state.get_current_user_in_room(room_id)
                     user_ids_to_check.update(users)
 
                 user_ids_to_check.update(friends)
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index e62722d78d..726f7308d2 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -18,6 +18,7 @@ from ._base import BaseHandler
 from twisted.internet import defer
 
 from synapse.util.logcontext import PreserveLoggingContext
+from synapse.types import get_domain_from_id
 
 import logging
 
@@ -37,6 +38,7 @@ class ReceiptsHandler(BaseHandler):
             "m.receipt", self._received_remote_receipt
         )
         self.clock = self.hs.get_clock()
+        self.state = hs.get_state_handler()
 
     @defer.inlineCallbacks
     def received_client_receipt(self, room_id, receipt_type, user_id,
@@ -133,7 +135,8 @@ class ReceiptsHandler(BaseHandler):
             event_ids = receipt["event_ids"]
             data = receipt["data"]
 
-            remotedomains = yield self.store.get_joined_hosts_for_room(room_id)
+            users = yield self.state.get_current_user_in_room(room_id)
+            remotedomains = set(get_domain_from_id(u) for u in users)
             remotedomains = remotedomains.copy()
             remotedomains.discard(self.server_name)
 
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 8b17632fdc..ba49075a20 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -85,6 +85,12 @@ class RoomMemberHandler(BaseHandler):
             prev_event_ids=prev_event_ids,
         )
 
+        # Check if this event matches the previous membership event for the user.
+        duplicate = yield msg_handler.deduplicate_state_event(event, context)
+        if duplicate is not None:
+            # Discard the new event since this membership change is a no-op.
+            return
+
         yield msg_handler.handle_new_client_event(
             requester,
             event,
@@ -93,20 +99,26 @@ class RoomMemberHandler(BaseHandler):
             ratelimit=ratelimit,
         )
 
-        prev_member_event = context.current_state.get(
+        prev_member_event_id = context.prev_state_ids.get(
             (EventTypes.Member, target.to_string()),
             None
         )
 
         if event.membership == Membership.JOIN:
-            if not prev_member_event or prev_member_event.membership != Membership.JOIN:
-                # Only fire user_joined_room if the user has acutally joined the
-                # room. Don't bother if the user is just changing their profile
-                # info.
+            # Only fire user_joined_room if the user has acutally joined the
+            # room. Don't bother if the user is just changing their profile
+            # info.
+            newly_joined = True
+            if prev_member_event_id:
+                prev_member_event = yield self.store.get_event(prev_member_event_id)
+                newly_joined = prev_member_event.membership != Membership.JOIN
+            if newly_joined:
                 yield user_joined_room(self.distributor, target, room_id)
         elif event.membership == Membership.LEAVE:
-            if prev_member_event and prev_member_event.membership == Membership.JOIN:
-                user_left_room(self.distributor, target, room_id)
+            if prev_member_event_id:
+                prev_member_event = yield self.store.get_event(prev_member_event_id)
+                if prev_member_event.membership == Membership.JOIN:
+                    user_left_room(self.distributor, target, room_id)
 
     @defer.inlineCallbacks
     def remote_join(self, remote_room_hosts, room_id, user, content):
@@ -195,29 +207,32 @@ class RoomMemberHandler(BaseHandler):
             remote_room_hosts = []
 
         latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
-        current_state = yield self.state_handler.get_current_state(
+        current_state_ids = yield self.state_handler.get_current_state_ids(
             room_id, latest_event_ids=latest_event_ids,
         )
 
-        old_state = current_state.get((EventTypes.Member, target.to_string()))
-        old_membership = old_state.content.get("membership") if old_state else None
-        if action == "unban" and old_membership != "ban":
-            raise SynapseError(
-                403,
-                "Cannot unban user who was not banned (membership=%s)" % old_membership,
-                errcode=Codes.BAD_STATE
-            )
-        if old_membership == "ban" and action != "unban":
-            raise SynapseError(
-                403,
-                "Cannot %s user who was banned" % (action,),
-                errcode=Codes.BAD_STATE
-            )
+        old_state_id = current_state_ids.get((EventTypes.Member, target.to_string()))
+        if old_state_id:
+            old_state = yield self.store.get_event(old_state_id, allow_none=True)
+            old_membership = old_state.content.get("membership") if old_state else None
+            if action == "unban" and old_membership != "ban":
+                raise SynapseError(
+                    403,
+                    "Cannot unban user who was not banned"
+                    " (membership=%s)" % old_membership,
+                    errcode=Codes.BAD_STATE
+                )
+            if old_membership == "ban" and action != "unban":
+                raise SynapseError(
+                    403,
+                    "Cannot %s user who was banned" % (action,),
+                    errcode=Codes.BAD_STATE
+                )
 
-        is_host_in_room = self.is_host_in_room(current_state)
+        is_host_in_room = yield self._is_host_in_room(current_state_ids)
 
         if effective_membership_state == Membership.JOIN:
-            if requester.is_guest and not self._can_guest_join(current_state):
+            if requester.is_guest and not self._can_guest_join(current_state_ids):
                 # This should be an auth check, but guests are a local concept,
                 # so don't really fit into the general auth process.
                 raise AuthError(403, "Guest access not allowed")
@@ -326,15 +341,17 @@ class RoomMemberHandler(BaseHandler):
             requester = synapse.types.create_requester(target_user)
 
         message_handler = self.hs.get_handlers().message_handler
-        prev_event = message_handler.deduplicate_state_event(event, context)
+        prev_event = yield message_handler.deduplicate_state_event(event, context)
         if prev_event is not None:
             return
 
         if event.membership == Membership.JOIN:
-            if requester.is_guest and not self._can_guest_join(context.current_state):
-                # This should be an auth check, but guests are a local concept,
-                # so don't really fit into the general auth process.
-                raise AuthError(403, "Guest access not allowed")
+            if requester.is_guest:
+                guest_can_join = yield self._can_guest_join(context.prev_state_ids)
+                if not guest_can_join:
+                    # This should be an auth check, but guests are a local concept,
+                    # so don't really fit into the general auth process.
+                    raise AuthError(403, "Guest access not allowed")
 
         yield message_handler.handle_new_client_event(
             requester,
@@ -344,27 +361,39 @@ class RoomMemberHandler(BaseHandler):
             ratelimit=ratelimit,
         )
 
-        prev_member_event = context.current_state.get(
-            (EventTypes.Member, target_user.to_string()),
+        prev_member_event_id = context.prev_state_ids.get(
+            (EventTypes.Member, event.state_key),
             None
         )
 
         if event.membership == Membership.JOIN:
-            if not prev_member_event or prev_member_event.membership != Membership.JOIN:
-                # Only fire user_joined_room if the user has acutally joined the
-                # room. Don't bother if the user is just changing their profile
-                # info.
+            # Only fire user_joined_room if the user has acutally joined the
+            # room. Don't bother if the user is just changing their profile
+            # info.
+            newly_joined = True
+            if prev_member_event_id:
+                prev_member_event = yield self.store.get_event(prev_member_event_id)
+                newly_joined = prev_member_event.membership != Membership.JOIN
+            if newly_joined:
                 yield user_joined_room(self.distributor, target_user, room_id)
         elif event.membership == Membership.LEAVE:
-            if prev_member_event and prev_member_event.membership == Membership.JOIN:
-                user_left_room(self.distributor, target_user, room_id)
+            if prev_member_event_id:
+                prev_member_event = yield self.store.get_event(prev_member_event_id)
+                if prev_member_event.membership == Membership.JOIN:
+                    user_left_room(self.distributor, target_user, room_id)
 
-    def _can_guest_join(self, current_state):
+    @defer.inlineCallbacks
+    def _can_guest_join(self, current_state_ids):
         """
         Returns whether a guest can join a room based on its current state.
         """
-        guest_access = current_state.get((EventTypes.GuestAccess, ""), None)
-        return (
+        guest_access_id = current_state_ids.get((EventTypes.GuestAccess, ""), None)
+        if not guest_access_id:
+            defer.returnValue(False)
+
+        guest_access = yield self.store.get_event(guest_access_id)
+
+        defer.returnValue(
             guest_access
             and guest_access.content
             and "guest_access" in guest_access.content
@@ -683,3 +712,24 @@ class RoomMemberHandler(BaseHandler):
 
         if membership:
             yield self.store.forget(user_id, room_id)
+
+    @defer.inlineCallbacks
+    def _is_host_in_room(self, current_state_ids):
+        # Have we just created the room, and is this about to be the very
+        # first member event?
+        create_event_id = current_state_ids.get(("m.room.create", ""))
+        if len(current_state_ids) == 1 and create_event_id:
+            defer.returnValue(self.hs.is_mine_id(create_event_id))
+
+        for (etype, state_key), event_id in current_state_ids.items():
+            if etype != EventTypes.Member or not self.hs.is_mine_id(state_key):
+                continue
+
+            event = yield self.store.get_event(event_id, allow_none=True)
+            if not event:
+                continue
+
+            if event.membership == Membership.JOIN:
+                defer.returnValue(True)
+
+        defer.returnValue(False)
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index c8dfd02e7b..b5962f4f5a 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -35,6 +35,7 @@ SyncConfig = collections.namedtuple("SyncConfig", [
     "filter_collection",
     "is_guest",
     "request_key",
+    "device_id",
 ])
 
 
@@ -113,6 +114,7 @@ class SyncResult(collections.namedtuple("SyncResult", [
     "joined",  # JoinedSyncResult for each joined room.
     "invited",  # InvitedSyncResult for each invited room.
     "archived",  # ArchivedSyncResult for each archived room.
+    "to_device",  # List of direct messages for the device.
 ])):
     __slots__ = []
 
@@ -126,7 +128,8 @@ class SyncResult(collections.namedtuple("SyncResult", [
             self.joined or
             self.invited or
             self.archived or
-            self.account_data
+            self.account_data or
+            self.to_device
         )
 
 
@@ -139,6 +142,7 @@ class SyncHandler(object):
         self.event_sources = hs.get_event_sources()
         self.clock = hs.get_clock()
         self.response_cache = ResponseCache(hs)
+        self.state = hs.get_state_handler()
 
     def wait_for_sync_for_user(self, sync_config, since_token=None, timeout=0,
                                full_state=False):
@@ -355,11 +359,11 @@ class SyncHandler(object):
         Returns:
             A Deferred map from ((type, state_key)->Event)
         """
-        state = yield self.store.get_state_for_event(event.event_id)
+        state_ids = yield self.store.get_state_ids_for_event(event.event_id)
         if event.is_state():
-            state = state.copy()
-            state[(event.type, event.state_key)] = event
-        defer.returnValue(state)
+            state_ids = state_ids.copy()
+            state_ids[(event.type, event.state_key)] = event.event_id
+        defer.returnValue(state_ids)
 
     @defer.inlineCallbacks
     def get_state_at(self, room_id, stream_position):
@@ -412,57 +416,61 @@ class SyncHandler(object):
         with Measure(self.clock, "compute_state_delta"):
             if full_state:
                 if batch:
-                    current_state = yield self.store.get_state_for_event(
+                    current_state_ids = yield self.store.get_state_ids_for_event(
                         batch.events[-1].event_id
                     )
 
-                    state = yield self.store.get_state_for_event(
+                    state_ids = yield self.store.get_state_ids_for_event(
                         batch.events[0].event_id
                     )
                 else:
-                    current_state = yield self.get_state_at(
+                    current_state_ids = yield self.get_state_at(
                         room_id, stream_position=now_token
                     )
 
-                    state = current_state
+                    state_ids = current_state_ids
 
                 timeline_state = {
-                    (event.type, event.state_key): event
+                    (event.type, event.state_key): event.event_id
                     for event in batch.events if event.is_state()
                 }
 
-                state = _calculate_state(
+                state_ids = _calculate_state(
                     timeline_contains=timeline_state,
-                    timeline_start=state,
+                    timeline_start=state_ids,
                     previous={},
-                    current=current_state,
+                    current=current_state_ids,
                 )
             elif batch.limited:
                 state_at_previous_sync = yield self.get_state_at(
                     room_id, stream_position=since_token
                 )
 
-                current_state = yield self.store.get_state_for_event(
+                current_state_ids = yield self.store.get_state_ids_for_event(
                     batch.events[-1].event_id
                 )
 
-                state_at_timeline_start = yield self.store.get_state_for_event(
+                state_at_timeline_start = yield self.store.get_state_ids_for_event(
                     batch.events[0].event_id
                 )
 
                 timeline_state = {
-                    (event.type, event.state_key): event
+                    (event.type, event.state_key): event.event_id
                     for event in batch.events if event.is_state()
                 }
 
-                state = _calculate_state(
+                state_ids = _calculate_state(
                     timeline_contains=timeline_state,
                     timeline_start=state_at_timeline_start,
                     previous=state_at_previous_sync,
-                    current=current_state,
+                    current=current_state_ids,
                 )
             else:
-                state = {}
+                state_ids = {}
+
+        state = {}
+        if state_ids:
+            state = yield self.store.get_events(state_ids.values())
 
         defer.returnValue({
             (e.type, e.state_key): e
@@ -527,16 +535,58 @@ class SyncHandler(object):
             sync_result_builder, newly_joined_rooms, newly_joined_users
         )
 
+        yield self._generate_sync_entry_for_to_device(sync_result_builder)
+
         defer.returnValue(SyncResult(
             presence=sync_result_builder.presence,
             account_data=sync_result_builder.account_data,
             joined=sync_result_builder.joined,
             invited=sync_result_builder.invited,
             archived=sync_result_builder.archived,
+            to_device=sync_result_builder.to_device,
             next_batch=sync_result_builder.now_token,
         ))
 
     @defer.inlineCallbacks
+    def _generate_sync_entry_for_to_device(self, sync_result_builder):
+        """Generates the portion of the sync response. Populates
+        `sync_result_builder` with the result.
+
+        Args:
+            sync_result_builder(SyncResultBuilder)
+
+        Returns:
+            Deferred(dict): A dictionary containing the per room account data.
+        """
+        user_id = sync_result_builder.sync_config.user.to_string()
+        device_id = sync_result_builder.sync_config.device_id
+        now_token = sync_result_builder.now_token
+        since_stream_id = 0
+        if sync_result_builder.since_token is not None:
+            since_stream_id = int(sync_result_builder.since_token.to_device_key)
+
+        if since_stream_id != int(now_token.to_device_key):
+            # We only delete messages when a new message comes in, but that's
+            # fine so long as we delete them at some point.
+
+            logger.debug("Deleting messages up to %d", since_stream_id)
+            yield self.store.delete_messages_for_device(
+                user_id, device_id, since_stream_id
+            )
+
+            logger.debug("Getting messages up to %d", now_token.to_device_key)
+            messages, stream_id = yield self.store.get_new_messages_for_device(
+                user_id, device_id, since_stream_id, now_token.to_device_key
+            )
+            logger.debug("Got messages up to %d: %r", stream_id, messages)
+            sync_result_builder.now_token = now_token.copy_and_replace(
+                "to_device_key", stream_id
+            )
+            sync_result_builder.to_device = messages
+        else:
+            sync_result_builder.to_device = []
+
+    @defer.inlineCallbacks
     def _generate_sync_entry_for_account_data(self, sync_result_builder):
         """Generates the account data portion of the sync response. Populates
         `sync_result_builder` with the result.
@@ -626,7 +676,7 @@ class SyncHandler(object):
 
         extra_users_ids = set(newly_joined_users)
         for room_id in newly_joined_rooms:
-            users = yield self.store.get_users_in_room(room_id)
+            users = yield self.state.get_current_user_in_room(room_id)
             extra_users_ids.update(users)
         extra_users_ids.discard(user.to_string())
 
@@ -766,8 +816,13 @@ class SyncHandler(object):
             # the last sync (even if we have since left). This is to make sure
             # we do send down the room, and with full state, where necessary
             if room_id in joined_room_ids or has_join:
-                old_state = yield self.get_state_at(room_id, since_token)
-                old_mem_ev = old_state.get((EventTypes.Member, user_id), None)
+                old_state_ids = yield self.get_state_at(room_id, since_token)
+                old_mem_ev_id = old_state_ids.get((EventTypes.Member, user_id), None)
+                old_mem_ev = None
+                if old_mem_ev_id:
+                    old_mem_ev = yield self.store.get_event(
+                        old_mem_ev_id, allow_none=True
+                    )
                 if not old_mem_ev or old_mem_ev.membership != Membership.JOIN:
                     newly_joined_rooms.append(room_id)
 
@@ -1059,27 +1114,25 @@ def _calculate_state(timeline_contains, timeline_start, previous, current):
     Returns:
         dict
     """
-    event_id_to_state = {
-        e.event_id: e
-        for e in itertools.chain(
-            timeline_contains.values(),
-            previous.values(),
-            timeline_start.values(),
-            current.values(),
+    event_id_to_key = {
+        e: key
+        for key, e in itertools.chain(
+            timeline_contains.items(),
+            previous.items(),
+            timeline_start.items(),
+            current.items(),
         )
     }
 
-    c_ids = set(e.event_id for e in current.values())
-    tc_ids = set(e.event_id for e in timeline_contains.values())
-    p_ids = set(e.event_id for e in previous.values())
-    ts_ids = set(e.event_id for e in timeline_start.values())
+    c_ids = set(e for e in current.values())
+    tc_ids = set(e for e in timeline_contains.values())
+    p_ids = set(e for e in previous.values())
+    ts_ids = set(e for e in timeline_start.values())
 
     state_ids = ((c_ids | ts_ids) - p_ids) - tc_ids
 
-    evs = (event_id_to_state[e] for e in state_ids)
     return {
-        (e.type, e.state_key): e
-        for e in evs
+        event_id_to_key[e]: e for e in state_ids
     }
 
 
@@ -1103,6 +1156,7 @@ class SyncResultBuilder(object):
         self.joined = []
         self.invited = []
         self.archived = []
+        self.device = []
 
 
 class RoomSyncResultBuilder(object):
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 46181984c0..0b530b9034 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -20,7 +20,7 @@ from synapse.util.logcontext import (
     PreserveLoggingContext, preserve_fn, preserve_context_over_deferred,
 )
 from synapse.util.metrics import Measure
-from synapse.types import UserID
+from synapse.types import UserID, get_domain_from_id
 
 import logging
 
@@ -42,6 +42,7 @@ class TypingHandler(object):
         self.auth = hs.get_auth()
         self.is_mine_id = hs.is_mine_id
         self.notifier = hs.get_notifier()
+        self.state = hs.get_state_handler()
 
         self.clock = hs.get_clock()
 
@@ -166,7 +167,8 @@ class TypingHandler(object):
 
     @defer.inlineCallbacks
     def _push_update(self, room_id, user_id, typing):
-        domains = yield self.store.get_joined_hosts_for_room(room_id)
+        users = yield self.state.get_current_user_in_room(room_id)
+        domains = set(get_domain_from_id(u) for u in users)
 
         deferreds = []
         for domain in domains:
@@ -199,7 +201,8 @@ class TypingHandler(object):
         # Check that the string is a valid user id
         UserID.from_string(user_id)
 
-        domains = yield self.store.get_joined_hosts_for_room(room_id)
+        users = yield self.state.get_current_user_in_room(room_id)
+        domains = set(get_domain_from_id(u) for u in users)
 
         if self.server_name in domains:
             self._push_update_local(
diff --git a/synapse/notifier.py b/synapse/notifier.py
index b86648f5e4..48653ae843 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -423,7 +423,8 @@ class Notifier(object):
     def _is_world_readable(self, room_id):
         state = yield self.state_handler.get_current_state(
             room_id,
-            EventTypes.RoomHistoryVisibility
+            EventTypes.RoomHistoryVisibility,
+            "",
         )
         if state and "history_visibility" in state.content:
             defer.returnValue(state.content["history_visibility"] == "world_readable")
diff --git a/synapse/push/action_generator.py b/synapse/push/action_generator.py
index ed2ccc4dfb..3f75d3f921 100644
--- a/synapse/push/action_generator.py
+++ b/synapse/push/action_generator.py
@@ -40,12 +40,12 @@ class ActionGenerator:
     def handle_push_actions_for_event(self, event, context):
         with Measure(self.clock, "evaluator_for_event"):
             bulk_evaluator = yield evaluator_for_event(
-                event, self.hs, self.store, context.state_group, context.current_state
+                event, self.hs, self.store, context
             )
 
         with Measure(self.clock, "action_for_event_by_user"):
             actions_by_user = yield bulk_evaluator.action_for_event_by_user(
-                event, context.current_state
+                event, context
             )
 
         context.push_actions = [
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 004eded61f..f1bbe57dcb 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -19,8 +19,8 @@ from twisted.internet import defer
 
 from .push_rule_evaluator import PushRuleEvaluatorForEvent
 
-from synapse.api.constants import EventTypes, Membership
-from synapse.visibility import filter_events_for_clients
+from synapse.api.constants import EventTypes
+from synapse.visibility import filter_events_for_clients_context
 
 
 logger = logging.getLogger(__name__)
@@ -36,9 +36,9 @@ def _get_rules(room_id, user_ids, store):
 
 
 @defer.inlineCallbacks
-def evaluator_for_event(event, hs, store, state_group, current_state):
+def evaluator_for_event(event, hs, store, context):
     rules_by_user = yield store.bulk_get_push_rules_for_room(
-        event.room_id, state_group, current_state
+        event, context
     )
 
     # if this event is an invite event, we may need to run rules for the user
@@ -72,7 +72,7 @@ class BulkPushRuleEvaluator:
         self.store = store
 
     @defer.inlineCallbacks
-    def action_for_event_by_user(self, event, current_state):
+    def action_for_event_by_user(self, event, context):
         actions_by_user = {}
 
         # None of these users can be peeking since this list of users comes
@@ -82,27 +82,25 @@ class BulkPushRuleEvaluator:
             (u, False) for u in self.rules_by_user.keys()
         ]
 
-        filtered_by_user = yield filter_events_for_clients(
-            self.store, user_tuples, [event], {event.event_id: current_state}
+        filtered_by_user = yield filter_events_for_clients_context(
+            self.store, user_tuples, [event], {event.event_id: context}
         )
 
-        room_members = set(
-            e.state_key for e in current_state.values()
-            if e.type == EventTypes.Member and e.membership == Membership.JOIN
+        room_members = yield self.store.get_joined_users_from_context(
+            event, context
         )
 
         evaluator = PushRuleEvaluatorForEvent(event, len(room_members))
 
         condition_cache = {}
 
-        display_names = {}
-        for ev in current_state.values():
-            nm = ev.content.get("displayname", None)
-            if nm and ev.type == EventTypes.Member:
-                display_names[ev.state_key] = nm
-
         for uid, rules in self.rules_by_user.items():
-            display_name = display_names.get(uid, None)
+            display_name = None
+            member_ev_id = context.current_state_ids.get((EventTypes.Member, uid))
+            if member_ev_id:
+                member_ev = yield self.store.get_event(member_ev_id, allow_none=True)
+                if member_ev:
+                    display_name = member_ev.content.get("displayname", None)
 
             filtered = filtered_by_user[uid]
             if len(filtered) == 0:
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index feedb075e2..c0f8176e3d 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -245,7 +245,7 @@ class HttpPusher(object):
     @defer.inlineCallbacks
     def _build_notification_dict(self, event, tweaks, badge):
         ctx = yield push_tools.get_context_for_event(
-            self.state_handler, event, self.user_id
+            self.store, self.state_handler, event, self.user_id
         )
 
         d = {
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index 1028731bc9..2cafcfd8f5 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -22,7 +22,7 @@ from email.mime.text import MIMEText
 from email.mime.multipart import MIMEMultipart
 
 from synapse.util.async import concurrently_execute
-from synapse.util.presentable_names import (
+from synapse.push.presentable_names import (
     calculate_room_name, name_from_member_event, descriptor_from_member_events
 )
 from synapse.types import UserID
@@ -139,7 +139,7 @@ class Mailer(object):
 
         @defer.inlineCallbacks
         def _fetch_room_state(room_id):
-            room_state = yield self.state_handler.get_current_state(room_id)
+            room_state = yield self.state_handler.get_current_state_ids(room_id)
             state_by_room[room_id] = room_state
 
         # Run at most 3 of these at once: sync does 10 at a time but email
@@ -159,11 +159,12 @@ class Mailer(object):
             )
             rooms.append(roomvars)
 
-        reason['room_name'] = calculate_room_name(
-            state_by_room[reason['room_id']], user_id, fallback_to_members=True
+        reason['room_name'] = yield calculate_room_name(
+            self.store, state_by_room[reason['room_id']], user_id,
+            fallback_to_members=True
         )
 
-        summary_text = self.make_summary_text(
+        summary_text = yield self.make_summary_text(
             notifs_by_room, state_by_room, notif_events, user_id, reason
         )
 
@@ -203,12 +204,15 @@ class Mailer(object):
         )
 
     @defer.inlineCallbacks
-    def get_room_vars(self, room_id, user_id, notifs, notif_events, room_state):
-        my_member_event = room_state[("m.room.member", user_id)]
+    def get_room_vars(self, room_id, user_id, notifs, notif_events, room_state_ids):
+        my_member_event_id = room_state_ids[("m.room.member", user_id)]
+        my_member_event = yield self.store.get_event(my_member_event_id)
         is_invite = my_member_event.content["membership"] == "invite"
 
+        room_name = yield calculate_room_name(self.store, room_state_ids, user_id)
+
         room_vars = {
-            "title": calculate_room_name(room_state, user_id),
+            "title": room_name,
             "hash": string_ordinal_total(room_id),  # See sender avatar hash
             "notifs": [],
             "invite": is_invite,
@@ -218,7 +222,7 @@ class Mailer(object):
         if not is_invite:
             for n in notifs:
                 notifvars = yield self.get_notif_vars(
-                    n, user_id, notif_events[n['event_id']], room_state
+                    n, user_id, notif_events[n['event_id']], room_state_ids
                 )
 
                 # merge overlapping notifs together.
@@ -243,7 +247,7 @@ class Mailer(object):
         defer.returnValue(room_vars)
 
     @defer.inlineCallbacks
-    def get_notif_vars(self, notif, user_id, notif_event, room_state):
+    def get_notif_vars(self, notif, user_id, notif_event, room_state_ids):
         results = yield self.store.get_events_around(
             notif['room_id'], notif['event_id'],
             before_limit=CONTEXT_BEFORE, after_limit=CONTEXT_AFTER
@@ -261,17 +265,19 @@ class Mailer(object):
         the_events.append(notif_event)
 
         for event in the_events:
-            messagevars = self.get_message_vars(notif, event, room_state)
+            messagevars = yield self.get_message_vars(notif, event, room_state_ids)
             if messagevars is not None:
                 ret['messages'].append(messagevars)
 
         defer.returnValue(ret)
 
-    def get_message_vars(self, notif, event, room_state):
+    @defer.inlineCallbacks
+    def get_message_vars(self, notif, event, room_state_ids):
         if event.type != EventTypes.Message:
-            return None
+            return
 
-        sender_state_event = room_state[("m.room.member", event.sender)]
+        sender_state_event_id = room_state_ids[("m.room.member", event.sender)]
+        sender_state_event = yield self.store.get_event(sender_state_event_id)
         sender_name = name_from_member_event(sender_state_event)
         sender_avatar_url = sender_state_event.content.get("avatar_url")
 
@@ -299,7 +305,7 @@ class Mailer(object):
         if "body" in event.content:
             ret["body_text_plain"] = event.content["body"]
 
-        return ret
+        defer.returnValue(ret)
 
     def add_text_message_vars(self, messagevars, event):
         msgformat = event.content.get("format")
@@ -321,6 +327,7 @@ class Mailer(object):
 
         return messagevars
 
+    @defer.inlineCallbacks
     def make_summary_text(self, notifs_by_room, state_by_room,
                           notif_events, user_id, reason):
         if len(notifs_by_room) == 1:
@@ -330,8 +337,8 @@ class Mailer(object):
             # If the room has some kind of name, use it, but we don't
             # want the generated-from-names one here otherwise we'll
             # end up with, "new message from Bob in the Bob room"
-            room_name = calculate_room_name(
-                state_by_room[room_id], user_id, fallback_to_members=False
+            room_name = yield calculate_room_name(
+                self.store, state_by_room[room_id], user_id, fallback_to_members=False
             )
 
             my_member_event = state_by_room[room_id][("m.room.member", user_id)]
@@ -342,16 +349,16 @@ class Mailer(object):
                 inviter_name = name_from_member_event(inviter_member_event)
 
                 if room_name is None:
-                    return INVITE_FROM_PERSON % {
+                    defer.returnValue(INVITE_FROM_PERSON % {
                         "person": inviter_name,
                         "app": self.app_name
-                    }
+                    })
                 else:
-                    return INVITE_FROM_PERSON_TO_ROOM % {
+                    defer.returnValue(INVITE_FROM_PERSON_TO_ROOM % {
                         "person": inviter_name,
                         "room": room_name,
                         "app": self.app_name,
-                    }
+                    })
 
             sender_name = None
             if len(notifs_by_room[room_id]) == 1:
@@ -362,24 +369,24 @@ class Mailer(object):
                     sender_name = name_from_member_event(state_event)
 
                 if sender_name is not None and room_name is not None:
-                    return MESSAGE_FROM_PERSON_IN_ROOM % {
+                    defer.returnValue(MESSAGE_FROM_PERSON_IN_ROOM % {
                         "person": sender_name,
                         "room": room_name,
                         "app": self.app_name,
-                    }
+                    })
                 elif sender_name is not None:
-                    return MESSAGE_FROM_PERSON % {
+                    defer.returnValue(MESSAGE_FROM_PERSON % {
                         "person": sender_name,
                         "app": self.app_name,
-                    }
+                    })
             else:
                 # There's more than one notification for this room, so just
                 # say there are several
                 if room_name is not None:
-                    return MESSAGES_IN_ROOM % {
+                    defer.returnValue(MESSAGES_IN_ROOM % {
                         "room": room_name,
                         "app": self.app_name,
-                    }
+                    })
                 else:
                     # If the room doesn't have a name, say who the messages
                     # are from explicitly to avoid, "messages in the Bob room"
@@ -388,22 +395,22 @@ class Mailer(object):
                         for n in notifs_by_room[room_id]
                     ]))
 
-                    return MESSAGES_FROM_PERSON % {
+                    defer.returnValue(MESSAGES_FROM_PERSON % {
                         "person": descriptor_from_member_events([
                             state_by_room[room_id][("m.room.member", s)]
                             for s in sender_ids
                         ]),
                         "app": self.app_name,
-                    }
+                    })
         else:
             # Stuff's happened in multiple different rooms
 
             # ...but we still refer to the 'reason' room which triggered the mail
             if reason['room_name'] is not None:
-                return MESSAGES_IN_ROOM_AND_OTHERS % {
+                defer.returnValue(MESSAGES_IN_ROOM_AND_OTHERS % {
                     "room": reason['room_name'],
                     "app": self.app_name,
-                }
+                })
             else:
                 # If the reason room doesn't have a name, say who the messages
                 # are from explicitly to avoid, "messages in the Bob room"
@@ -412,13 +419,13 @@ class Mailer(object):
                     for n in notifs_by_room[reason['room_id']]
                 ]))
 
-                return MESSAGES_FROM_PERSON_AND_OTHERS % {
+                defer.returnValue(MESSAGES_FROM_PERSON_AND_OTHERS % {
                     "person": descriptor_from_member_events([
                         state_by_room[reason['room_id']][("m.room.member", s)]
                         for s in sender_ids
                     ]),
                     "app": self.app_name,
-                }
+                })
 
     def make_room_link(self, room_id):
         # need /beta for Universal Links to work on iOS
diff --git a/synapse/util/presentable_names.py b/synapse/push/presentable_names.py
index f68676e9e7..277da3cd35 100644
--- a/synapse/util/presentable_names.py
+++ b/synapse/push/presentable_names.py
@@ -13,6 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from twisted.internet import defer
+
 import re
 import logging
 
@@ -25,7 +27,8 @@ ALIAS_RE = re.compile(r"^#.*:.+$")
 ALL_ALONE = "Empty Room"
 
 
-def calculate_room_name(room_state, user_id, fallback_to_members=True,
+@defer.inlineCallbacks
+def calculate_room_name(store, room_state_ids, user_id, fallback_to_members=True,
                         fallback_to_single_member=True):
     """
     Works out a user-facing name for the given room as per Matrix
@@ -42,59 +45,78 @@ def calculate_room_name(room_state, user_id, fallback_to_members=True,
         (string or None) A human readable name for the room.
     """
     # does it have a name?
-    if ("m.room.name", "") in room_state:
-        m_room_name = room_state[("m.room.name", "")]
-        if m_room_name.content and m_room_name.content["name"]:
-            return m_room_name.content["name"]
+    if ("m.room.name", "") in room_state_ids:
+        m_room_name = yield store.get_event(
+            room_state_ids[("m.room.name", "")], allow_none=True
+        )
+        if m_room_name and m_room_name.content and m_room_name.content["name"]:
+            defer.returnValue(m_room_name.content["name"])
 
     # does it have a canonical alias?
-    if ("m.room.canonical_alias", "") in room_state:
-        canon_alias = room_state[("m.room.canonical_alias", "")]
+    if ("m.room.canonical_alias", "") in room_state_ids:
+        canon_alias = yield store.get_event(
+            room_state_ids[("m.room.canonical_alias", "")], allow_none=True
+        )
         if (
-            canon_alias.content and canon_alias.content["alias"] and
+            canon_alias and canon_alias.content and canon_alias.content["alias"] and
             _looks_like_an_alias(canon_alias.content["alias"])
         ):
-            return canon_alias.content["alias"]
+            defer.returnValue(canon_alias.content["alias"])
 
     # at this point we're going to need to search the state by all state keys
     # for an event type, so rearrange the data structure
-    room_state_bytype = _state_as_two_level_dict(room_state)
+    room_state_bytype_ids = _state_as_two_level_dict(room_state_ids)
 
     # right then, any aliases at all?
-    if "m.room.aliases" in room_state_bytype:
-        m_room_aliases = room_state_bytype["m.room.aliases"]
-        if len(m_room_aliases.values()) > 0:
-            first_alias_event = m_room_aliases.values()[0]
-            if first_alias_event.content and first_alias_event.content["aliases"]:
-                the_aliases = first_alias_event.content["aliases"]
+    if "m.room.aliases" in room_state_bytype_ids:
+        m_room_aliases = room_state_bytype_ids["m.room.aliases"]
+        for alias_id in m_room_aliases.values():
+            alias_event = yield store.get_event(
+                alias_id, allow_none=True
+            )
+            if alias_event and alias_event.content.get("aliases"):
+                the_aliases = alias_event.content["aliases"]
                 if len(the_aliases) > 0 and _looks_like_an_alias(the_aliases[0]):
-                    return the_aliases[0]
+                    defer.returnValue(the_aliases[0])
 
     if not fallback_to_members:
-        return None
+        defer.returnValue(None)
 
     my_member_event = None
-    if ("m.room.member", user_id) in room_state:
-        my_member_event = room_state[("m.room.member", user_id)]
+    if ("m.room.member", user_id) in room_state_ids:
+        my_member_event = yield store.get_event(
+            room_state_ids[("m.room.member", user_id)], allow_none=True
+        )
 
     if (
         my_member_event is not None and
         my_member_event.content['membership'] == "invite"
     ):
-        if ("m.room.member", my_member_event.sender) in room_state:
-            inviter_member_event = room_state[("m.room.member", my_member_event.sender)]
-            if fallback_to_single_member:
-                return "Invite from %s" % (name_from_member_event(inviter_member_event),)
-            else:
-                return None
+        if ("m.room.member", my_member_event.sender) in room_state_ids:
+            inviter_member_event = yield store.get_event(
+                room_state_ids[("m.room.member", my_member_event.sender)],
+                allow_none=True,
+            )
+            if inviter_member_event:
+                if fallback_to_single_member:
+                    defer.returnValue(
+                        "Invite from %s" % (
+                            name_from_member_event(inviter_member_event),
+                        )
+                    )
+                else:
+                    return
         else:
-            return "Room Invite"
+            defer.returnValue("Room Invite")
 
     # we're going to have to generate a name based on who's in the room,
     # so find out who is in the room that isn't the user.
-    if "m.room.member" in room_state_bytype:
+    if "m.room.member" in room_state_bytype_ids:
+        member_events = yield store.get_events(
+            room_state_bytype_ids["m.room.member"].values()
+        )
         all_members = [
-            ev for ev in room_state_bytype["m.room.member"].values()
+            ev for ev in member_events.values()
             if ev.content['membership'] == "join" or ev.content['membership'] == "invite"
         ]
         # Sort the member events oldest-first so the we name people in the
@@ -111,9 +133,9 @@ def calculate_room_name(room_state, user_id, fallback_to_members=True,
             # self-chat, peeked room with 1 participant,
             # or inbound invite, or outbound 3PID invite.
             if all_members[0].sender == user_id:
-                if "m.room.third_party_invite" in room_state_bytype:
+                if "m.room.third_party_invite" in room_state_bytype_ids:
                     third_party_invites = (
-                        room_state_bytype["m.room.third_party_invite"].values()
+                        room_state_bytype_ids["m.room.third_party_invite"].values()
                     )
 
                     if len(third_party_invites) > 0:
@@ -126,17 +148,17 @@ def calculate_room_name(room_state, user_id, fallback_to_members=True,
                         # return "Inviting %s" % (
                         #     descriptor_from_member_events(third_party_invites)
                         # )
-                        return "Inviting email address"
+                        defer.returnValue("Inviting email address")
                     else:
-                        return ALL_ALONE
+                        defer.returnValue(ALL_ALONE)
             else:
-                return name_from_member_event(all_members[0])
+                defer.returnValue(name_from_member_event(all_members[0]))
         else:
-            return ALL_ALONE
+            defer.returnValue(ALL_ALONE)
     elif len(other_members) == 1 and not fallback_to_single_member:
-        return None
+        return
     else:
-        return descriptor_from_member_events(other_members)
+        defer.returnValue(descriptor_from_member_events(other_members))
 
 
 def descriptor_from_member_events(member_events):
diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py
index becb8ef1ae..b47bf1f92b 100644
--- a/synapse/push/push_tools.py
+++ b/synapse/push/push_tools.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 from twisted.internet import defer
-from synapse.util.presentable_names import (
+from synapse.push.presentable_names import (
     calculate_room_name, name_from_member_event
 )
 from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
@@ -49,21 +49,22 @@ def get_badge_count(store, user_id):
 
 
 @defer.inlineCallbacks
-def get_context_for_event(state_handler, ev, user_id):
+def get_context_for_event(store, state_handler, ev, user_id):
     ctx = {}
 
-    room_state = yield state_handler.get_current_state(ev.room_id)
+    room_state_ids = yield state_handler.get_current_state_ids(ev.room_id)
 
     # we no longer bother setting room_alias, and make room_name the
     # human-readable name instead, be that m.room.name, an alias or
     # a list of people in the room
-    name = calculate_room_name(
-        room_state, user_id, fallback_to_single_member=False
+    name = yield calculate_room_name(
+        store, room_state_ids, user_id, fallback_to_single_member=False
     )
     if name:
         ctx['name'] = name
 
-    sender_state_event = room_state[("m.room.member", ev.sender)]
+    sender_state_event_id = room_state_ids[("m.room.member", ev.sender)]
+    sender_state_event = yield store.get_event(sender_state_event_id)
     ctx['sender_display_name'] = name_from_member_event(sender_state_event)
 
     defer.returnValue(ctx)
diff --git a/synapse/replication/resource.py b/synapse/replication/resource.py
index 84993b33b3..1ed9034bcb 100644
--- a/synapse/replication/resource.py
+++ b/synapse/replication/resource.py
@@ -40,8 +40,8 @@ STREAM_NAMES = (
     ("backfill",),
     ("push_rules",),
     ("pushers",),
-    ("state",),
     ("caches",),
+    ("to_device",),
 )
 
 
@@ -130,7 +130,6 @@ class ReplicationResource(Resource):
         backfill_token = yield self.store.get_current_backfill_token()
         push_rules_token, room_stream_token = self.store.get_push_rules_stream_token()
         pushers_token = self.store.get_pushers_stream_token()
-        state_token = self.store.get_state_stream_token()
         caches_token = self.store.get_cache_stream_token()
 
         defer.returnValue(_ReplicationToken(
@@ -142,8 +141,9 @@ class ReplicationResource(Resource):
             backfill_token,
             push_rules_token,
             pushers_token,
-            state_token,
+            0,  # State stream is no longer a thing
             caches_token,
+            int(stream_token.to_device_key),
         ))
 
     @request_handler()
@@ -191,8 +191,8 @@ class ReplicationResource(Resource):
         yield self.receipts(writer, current_token, limit, request_streams)
         yield self.push_rules(writer, current_token, limit, request_streams)
         yield self.pushers(writer, current_token, limit, request_streams)
-        yield self.state(writer, current_token, limit, request_streams)
         yield self.caches(writer, current_token, limit, request_streams)
+        yield self.to_device(writer, current_token, limit, request_streams)
         self.streams(writer, current_token, request_streams)
 
         logger.info("Replicated %d rows", writer.total)
@@ -366,25 +366,6 @@ class ReplicationResource(Resource):
             ))
 
     @defer.inlineCallbacks
-    def state(self, writer, current_token, limit, request_streams):
-        current_position = current_token.state
-
-        state = request_streams.get("state")
-
-        if state is not None:
-            state_groups, state_group_state = (
-                yield self.store.get_all_new_state_groups(
-                    state, current_position, limit
-                )
-            )
-            writer.write_header_and_rows("state_groups", state_groups, (
-                "position", "room_id", "event_id"
-            ))
-            writer.write_header_and_rows("state_group_state", state_group_state, (
-                "position", "type", "state_key", "event_id"
-            ))
-
-    @defer.inlineCallbacks
     def caches(self, writer, current_token, limit, request_streams):
         current_position = current_token.caches
 
@@ -398,6 +379,20 @@ class ReplicationResource(Resource):
                 "position", "cache_func", "keys", "invalidation_ts"
             ))
 
+    @defer.inlineCallbacks
+    def to_device(self, writer, current_token, limit, request_streams):
+        current_position = current_token.to_device
+
+        to_device = request_streams.get("to_device")
+
+        if to_device is not None:
+            to_device_rows = yield self.store.get_all_new_device_messages(
+                to_device, current_position, limit
+            )
+            writer.write_header_and_rows("to_device", to_device_rows, (
+                "position", "user_id", "device_id", "message_json"
+            ))
+
 
 class _Writer(object):
     """Writes the streams as a JSON object as the response to the request"""
@@ -426,7 +421,7 @@ class _Writer(object):
 
 class _ReplicationToken(collections.namedtuple("_ReplicationToken", (
     "events", "presence", "typing", "receipts", "account_data", "backfill",
-    "push_rules", "pushers", "state", "caches",
+    "push_rules", "pushers", "state", "caches", "to_device",
 ))):
     __slots__ = []
 
diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py
new file mode 100644
index 0000000000..64d8eb2af1
--- /dev/null
+++ b/synapse/replication/slave/storage/deviceinbox.py
@@ -0,0 +1,42 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import BaseSlavedStore
+from ._slaved_id_tracker import SlavedIdTracker
+from synapse.storage import DataStore
+
+
+class SlavedDeviceInboxStore(BaseSlavedStore):
+    def __init__(self, db_conn, hs):
+        super(SlavedDeviceInboxStore, self).__init__(db_conn, hs)
+        self._device_inbox_id_gen = SlavedIdTracker(
+            db_conn, "device_inbox", "stream_id",
+        )
+
+    get_to_device_stream_token = DataStore.get_to_device_stream_token.__func__
+    get_new_messages_for_device = DataStore.get_new_messages_for_device.__func__
+    delete_messages_for_device = DataStore.delete_messages_for_device.__func__
+
+    def stream_positions(self):
+        result = super(SlavedDeviceInboxStore, self).stream_positions()
+        result["to_device"] = self._device_inbox_id_gen.get_current_token()
+        return result
+
+    def process_replication(self, result):
+        stream = result.get("to_device")
+        if stream:
+            self._device_inbox_id_gen.advance(int(stream["position"]))
+
+        return super(SlavedDeviceInboxStore, self).process_replication(result)
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index f4f31f2d27..cbebd5b2f7 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -120,10 +120,21 @@ class SlavedEventStore(BaseSlavedStore):
     get_state_for_event = DataStore.get_state_for_event.__func__
     get_state_for_events = DataStore.get_state_for_events.__func__
     get_state_groups = DataStore.get_state_groups.__func__
+    get_state_groups_ids = DataStore.get_state_groups_ids.__func__
+    get_state_ids_for_event = DataStore.get_state_ids_for_event.__func__
+    get_state_ids_for_events = DataStore.get_state_ids_for_events.__func__
+    get_joined_users_from_state = DataStore.get_joined_users_from_state.__func__
+    get_joined_users_from_context = DataStore.get_joined_users_from_context.__func__
+    _get_joined_users_from_context = (
+        RoomMemberStore.__dict__["_get_joined_users_from_context"]
+    )
+
     get_recent_events_for_room = DataStore.get_recent_events_for_room.__func__
     get_room_events_stream_for_rooms = (
         DataStore.get_room_events_stream_for_rooms.__func__
     )
+    is_host_joined = DataStore.is_host_joined.__func__
+    _is_host_joined = RoomMemberStore.__dict__["_is_host_joined"]
     get_stream_token_for_event = DataStore.get_stream_token_for_event.__func__
 
     _set_before_and_after = staticmethod(DataStore._set_before_and_after)
@@ -211,7 +222,6 @@ class SlavedEventStore(BaseSlavedStore):
             self._get_current_state_for_key.invalidate_all()
             self.get_rooms_for_user.invalidate_all()
             self.get_users_in_room.invalidate((event.room_id,))
-            # self.get_joined_hosts_for_room.invalidate((event.room_id,))
 
         self._invalidate_get_event_cache(event.event_id)
 
@@ -235,7 +245,6 @@ class SlavedEventStore(BaseSlavedStore):
 
         if event.type == EventTypes.Member:
             self.get_rooms_for_user.invalidate((event.state_key,))
-            # self.get_joined_hosts_for_room.invalidate((event.room_id,))
             self.get_users_in_room.invalidate((event.room_id,))
             self._membership_stream_cache.entity_has_changed(
                 event.state_key, event.internal_metadata.stream_ordering
diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py
index 326780405e..f9f5a3e077 100644
--- a/synapse/rest/__init__.py
+++ b/synapse/rest/__init__.py
@@ -49,6 +49,7 @@ from synapse.rest.client.v2_alpha import (
     notifications,
     devices,
     thirdparty,
+    sendtodevice,
 )
 
 from synapse.http.server import JsonResource
@@ -96,3 +97,4 @@ class ClientRestResource(JsonResource):
         notifications.register_servlets(hs, client_resource)
         devices.register_servlets(hs, client_resource)
         thirdparty.register_servlets(hs, client_resource)
+        sendtodevice.register_servlets(hs, client_resource)
diff --git a/synapse/rest/client/v2_alpha/sendtodevice.py b/synapse/rest/client/v2_alpha/sendtodevice.py
new file mode 100644
index 0000000000..9c10a99acf
--- /dev/null
+++ b/synapse/rest/client/v2_alpha/sendtodevice.py
@@ -0,0 +1,90 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from twisted.internet import defer
+from synapse.http.servlet import parse_json_object_from_request
+
+from synapse.http import servlet
+from synapse.rest.client.v1.transactions import HttpTransactionStore
+from ._base import client_v2_patterns
+
+logger = logging.getLogger(__name__)
+
+
+class SendToDeviceRestServlet(servlet.RestServlet):
+    PATTERNS = client_v2_patterns(
+        "/sendToDevice/(?P<message_type>[^/]*)/(?P<txn_id>[^/]*)$",
+        releases=[], v2_alpha=False
+    )
+
+    def __init__(self, hs):
+        """
+        Args:
+            hs (synapse.server.HomeServer): server
+        """
+        super(SendToDeviceRestServlet, self).__init__()
+        self.hs = hs
+        self.auth = hs.get_auth()
+        self.store = hs.get_datastore()
+        self.notifier = hs.get_notifier()
+        self.is_mine_id = hs.is_mine_id
+        self.txns = HttpTransactionStore()
+
+    @defer.inlineCallbacks
+    def on_PUT(self, request, message_type, txn_id):
+        try:
+            defer.returnValue(
+                self.txns.get_client_transaction(request, txn_id)
+            )
+        except KeyError:
+            pass
+
+        requester = yield self.auth.get_user_by_req(request)
+
+        content = parse_json_object_from_request(request)
+
+        # TODO: Prod the notifier to wake up sync streams.
+        # TODO: Implement replication for the messages.
+        # TODO: Send the messages to remote servers if needed.
+
+        local_messages = {}
+        for user_id, by_device in content["messages"].items():
+            if self.is_mine_id(user_id):
+                messages_by_device = {
+                    device_id: {
+                        "content": message_content,
+                        "type": message_type,
+                        "sender": requester.user.to_string(),
+                    }
+                    for device_id, message_content in by_device.items()
+                }
+                if messages_by_device:
+                    local_messages[user_id] = messages_by_device
+
+        stream_id = yield self.store.add_messages_to_device_inbox(local_messages)
+
+        self.notifier.on_new_event(
+            "to_device_key", stream_id, users=local_messages.keys()
+        )
+
+        response = (200, {})
+        self.txns.store_client_transaction(request, txn_id, response)
+        defer.returnValue(response)
+
+
+def register_servlets(hs, http_server):
+    SendToDeviceRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index b11acdbea7..6fc63715aa 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -97,6 +97,7 @@ class SyncRestServlet(RestServlet):
             request, allow_guest=True
         )
         user = requester.user
+        device_id = requester.device_id
 
         timeout = parse_integer(request, "timeout", default=0)
         since = parse_string(request, "since")
@@ -109,12 +110,12 @@ class SyncRestServlet(RestServlet):
 
         logger.info(
             "/sync: user=%r, timeout=%r, since=%r,"
-            " set_presence=%r, filter_id=%r" % (
-                user, timeout, since, set_presence, filter_id
+            " set_presence=%r, filter_id=%r, device_id=%r" % (
+                user, timeout, since, set_presence, filter_id, device_id
             )
         )
 
-        request_key = (user, timeout, since, filter_id, full_state)
+        request_key = (user, timeout, since, filter_id, full_state, device_id)
 
         if filter_id:
             if filter_id.startswith('{'):
@@ -136,6 +137,7 @@ class SyncRestServlet(RestServlet):
             filter_collection=filter,
             is_guest=requester.is_guest,
             request_key=request_key,
+            device_id=device_id,
         )
 
         if since is not None:
@@ -173,6 +175,7 @@ class SyncRestServlet(RestServlet):
 
         response_content = {
             "account_data": {"events": sync_result.account_data},
+            "to_device": {"events": sync_result.to_device},
             "presence": self.encode_presence(
                 sync_result.presence, time_now
             ),
diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/v2_alpha/thirdparty.py
index 9abca3a8ad..4f6f1a7e17 100644
--- a/synapse/rest/client/v2_alpha/thirdparty.py
+++ b/synapse/rest/client/v2_alpha/thirdparty.py
@@ -18,15 +18,32 @@ import logging
 
 from twisted.internet import defer
 
+from synapse.api.constants import ThirdPartyEntityKind
 from synapse.http.servlet import RestServlet
-from synapse.types import ThirdPartyEntityKind
 from ._base import client_v2_patterns
 
 logger = logging.getLogger(__name__)
 
 
+class ThirdPartyProtocolsServlet(RestServlet):
+    PATTERNS = client_v2_patterns("/thirdparty/protocols", releases=())
+
+    def __init__(self, hs):
+        super(ThirdPartyProtocolsServlet, self).__init__()
+
+        self.auth = hs.get_auth()
+        self.appservice_handler = hs.get_application_service_handler()
+
+    @defer.inlineCallbacks
+    def on_GET(self, request):
+        yield self.auth.get_user_by_req(request)
+
+        protocols = yield self.appservice_handler.get_3pe_protocols()
+        defer.returnValue((200, protocols))
+
+
 class ThirdPartyUserServlet(RestServlet):
-    PATTERNS = client_v2_patterns("/3pu(/(?P<protocol>[^/]+))?$",
+    PATTERNS = client_v2_patterns("/thirdparty/user(/(?P<protocol>[^/]+))?$",
                                   releases=())
 
     def __init__(self, hs):
@@ -50,7 +67,7 @@ class ThirdPartyUserServlet(RestServlet):
 
 
 class ThirdPartyLocationServlet(RestServlet):
-    PATTERNS = client_v2_patterns("/3pl(/(?P<protocol>[^/]+))?$",
+    PATTERNS = client_v2_patterns("/thirdparty/location(/(?P<protocol>[^/]+))?$",
                                   releases=())
 
     def __init__(self, hs):
@@ -74,5 +91,6 @@ class ThirdPartyLocationServlet(RestServlet):
 
 
 def register_servlets(hs, http_server):
+    ThirdPartyProtocolsServlet(hs).register(http_server)
     ThirdPartyUserServlet(hs).register(http_server)
     ThirdPartyLocationServlet(hs).register(http_server)
diff --git a/synapse/state.py b/synapse/state.py
index ef1bc470be..cd792afed1 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -23,6 +23,7 @@ from synapse.api.constants import EventTypes
 from synapse.api.errors import AuthError
 from synapse.api.auth import AuthEventTypes
 from synapse.events.snapshot import EventContext
+from synapse.util.async import Linearizer
 
 from collections import namedtuple
 
@@ -43,11 +44,35 @@ SIZE_OF_CACHE = int(1000 * CACHE_SIZE_FACTOR)
 EVICTION_TIMEOUT_SECONDS = 60 * 60
 
 
+_NEXT_STATE_ID = 1
+
+
+def _gen_state_id():
+    global _NEXT_STATE_ID
+    s = "X%d" % (_NEXT_STATE_ID,)
+    _NEXT_STATE_ID += 1
+    return s
+
+
 class _StateCacheEntry(object):
-    def __init__(self, state, state_group, ts):
+    __slots__ = ["state", "state_group", "state_id"]
+
+    def __init__(self, state, state_group):
         self.state = state
         self.state_group = state_group
 
+        # The `state_id` is a unique ID we generate that can be used as ID for
+        # this collection of state. Usually this would be the same as the
+        # state group, but on worker instances we can't generate a new state
+        # group each time we resolve state, so we generate a separate one that
+        # isn't persisted and is used solely for caches.
+        # `state_id` is either a state_group (and so an int) or a string. This
+        # ensures we don't accidentally persist a state_id as a stateg_group
+        if state_group:
+            self.state_id = state_group
+        else:
+            self.state_id = _gen_state_id()
+
 
 class StateHandler(object):
     """ Responsible for doing state conflict resolution.
@@ -60,6 +85,7 @@ class StateHandler(object):
 
         # dict of set of event_ids -> _StateCacheEntry.
         self._state_cache = None
+        self.resolve_linearizer = Linearizer()
 
     def start_caching(self):
         logger.debug("start_caching")
@@ -93,8 +119,32 @@ class StateHandler(object):
         if not latest_event_ids:
             latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
 
-        res = yield self.resolve_state_groups(room_id, latest_event_ids)
-        state = res[1]
+        ret = yield self.resolve_state_groups(room_id, latest_event_ids)
+        state = ret.state
+
+        if event_type:
+            event_id = state.get((event_type, state_key))
+            event = None
+            if event_id:
+                event = yield self.store.get_event(event_id, allow_none=True)
+            defer.returnValue(event)
+            return
+
+        state_map = yield self.store.get_events(state.values(), get_prev_content=False)
+        state = {
+            key: state_map[e_id] for key, e_id in state.items() if e_id in state_map
+        }
+
+        defer.returnValue(state)
+
+    @defer.inlineCallbacks
+    def get_current_state_ids(self, room_id, event_type=None, state_key="",
+                              latest_event_ids=None):
+        if not latest_event_ids:
+            latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
+
+        ret = yield self.resolve_state_groups(room_id, latest_event_ids)
+        state = ret.state
 
         if event_type:
             defer.returnValue(state.get((event_type, state_key)))
@@ -103,6 +153,15 @@ class StateHandler(object):
         defer.returnValue(state)
 
     @defer.inlineCallbacks
+    def get_current_user_in_room(self, room_id):
+        latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
+        entry = yield self.resolve_state_groups(room_id, latest_event_ids)
+        joined_users = yield self.store.get_joined_users_from_state(
+            room_id, entry.state_id, entry.state
+        )
+        defer.returnValue(joined_users)
+
+    @defer.inlineCallbacks
     def compute_event_context(self, event, old_state=None):
         """ Fills out the context with the `current state` of the graph. The
         `current state` here is defined to be the state of the event graph
@@ -123,54 +182,75 @@ class StateHandler(object):
             # state. Certainly store.get_current_state won't return any, and
             # persisting the event won't store the state group.
             if old_state:
-                context.current_state = {
-                    (s.type, s.state_key): s for s in old_state
+                context.prev_state_ids = {
+                    (s.type, s.state_key): s.event_id for s in old_state
                 }
+                if event.is_state():
+                    context.current_state_events = dict(context.prev_state_ids)
+                    key = (event.type, event.state_key)
+                    context.current_state_events[key] = event.event_id
+                else:
+                    context.current_state_events = context.prev_state_ids
             else:
-                context.current_state = {}
+                context.current_state_ids = {}
+                context.prev_state_ids = {}
             context.prev_state_events = []
-            context.state_group = None
+            context.state_group = self.store.get_next_state_group()
             defer.returnValue(context)
 
         if old_state:
-            context.current_state = {
-                (s.type, s.state_key): s for s in old_state
+            context.prev_state_ids = {
+                (s.type, s.state_key): s.event_id for s in old_state
             }
-            context.state_group = None
+            context.state_group = self.store.get_next_state_group()
 
             if event.is_state():
                 key = (event.type, event.state_key)
-                if key in context.current_state:
-                    replaces = context.current_state[key]
-                    if replaces.event_id != event.event_id:  # Paranoia check
-                        event.unsigned["replaces_state"] = replaces.event_id
+                if key in context.prev_state_ids:
+                    replaces = context.prev_state_ids[key]
+                    if replaces != event.event_id:  # Paranoia check
+                        event.unsigned["replaces_state"] = replaces
+                context.current_state_ids = dict(context.prev_state_ids)
+                context.current_state_ids[key] = event.event_id
+            else:
+                context.current_state_ids = context.prev_state_ids
 
             context.prev_state_events = []
             defer.returnValue(context)
 
         if event.is_state():
-            ret = yield self.resolve_state_groups(
+            entry = yield self.resolve_state_groups(
                 event.room_id, [e for e, _ in event.prev_events],
                 event_type=event.type,
                 state_key=event.state_key,
             )
         else:
-            ret = yield self.resolve_state_groups(
+            entry = yield self.resolve_state_groups(
                 event.room_id, [e for e, _ in event.prev_events],
             )
 
-        group, curr_state, prev_state = ret
+        curr_state = entry.state
 
-        context.current_state = curr_state
-        context.state_group = group if not event.is_state() else None
+        context.prev_state_ids = curr_state
+        if event.is_state():
+            context.state_group = self.store.get_next_state_group()
+        else:
+            if entry.state_group is None:
+                entry.state_group = self.store.get_next_state_group()
+                entry.state_id = entry.state_group
+            context.state_group = entry.state_group
 
         if event.is_state():
             key = (event.type, event.state_key)
-            if key in context.current_state:
-                replaces = context.current_state[key]
-                event.unsigned["replaces_state"] = replaces.event_id
+            if key in context.prev_state_ids:
+                replaces = context.prev_state_ids[key]
+                event.unsigned["replaces_state"] = replaces
+            context.current_state_ids = dict(context.prev_state_ids)
+            context.current_state_ids[key] = event.event_id
+        else:
+            context.current_state_ids = context.prev_state_ids
 
-        context.prev_state_events = prev_state
+        context.prev_state_events = []
         defer.returnValue(context)
 
     @defer.inlineCallbacks
@@ -187,72 +267,88 @@ class StateHandler(object):
         """
         logger.debug("resolve_state_groups event_ids %s", event_ids)
 
-        state_groups = yield self.store.get_state_groups(
+        state_groups_ids = yield self.store.get_state_groups_ids(
             room_id, event_ids
         )
 
         logger.debug(
             "resolve_state_groups state_groups %s",
-            state_groups.keys()
+            state_groups_ids.keys()
         )
 
-        group_names = frozenset(state_groups.keys())
+        group_names = frozenset(state_groups_ids.keys())
         if len(group_names) == 1:
-            name, state_list = state_groups.items().pop()
-            state = {
-                (e.type, e.state_key): e
-                for e in state_list
-            }
-            prev_state = state.get((event_type, state_key), None)
-            if prev_state:
-                prev_state = prev_state.event_id
-                prev_states = [prev_state]
-            else:
-                prev_states = []
+            name, state_list = state_groups_ids.items().pop()
 
-            defer.returnValue((name, state, prev_states))
+            defer.returnValue(_StateCacheEntry(
+                state=state_list,
+                state_group=name,
+            ))
 
-        if self._state_cache is not None:
-            cache = self._state_cache.get(group_names, None)
-            if cache:
-                cache.ts = self.clock.time_msec()
+        with (yield self.resolve_linearizer.queue(group_names)):
+            if self._state_cache is not None:
+                cache = self._state_cache.get(group_names, None)
+                if cache:
+                    defer.returnValue(cache)
 
-                event_dict = yield self.store.get_events(cache.state.values())
-                state = {(e.type, e.state_key): e for e in event_dict.values()}
+            logger.info(
+                "Resolving state for %s with %d groups", room_id, len(state_groups_ids)
+            )
 
-                prev_state = state.get((event_type, state_key), None)
-                if prev_state:
-                    prev_state = prev_state.event_id
-                    prev_states = [prev_state]
-                else:
-                    prev_states = []
-                defer.returnValue(
-                    (cache.state_group, state, prev_states)
-                )
+            state = {}
+            for st in state_groups_ids.values():
+                for key, e_id in st.items():
+                    state.setdefault(key, set()).add(e_id)
 
-        logger.info("Resolving state for %s with %d groups", room_id, len(state_groups))
+            conflicted_state = {
+                k: list(v)
+                for k, v in state.items()
+                if len(v) > 1
+            }
 
-        new_state, prev_states = self._resolve_events(
-            state_groups.values(), event_type, state_key
-        )
+            if conflicted_state:
+                logger.info("Resolving conflicted state for %r", room_id)
+                state_map = yield self.store.get_events(
+                    [e_id for st in state_groups_ids.values() for e_id in st.values()],
+                    get_prev_content=False
+                )
+                state_sets = [
+                    [state_map[e_id] for key, e_id in st.items() if e_id in state_map]
+                    for st in state_groups_ids.values()
+                ]
+                new_state, _ = self._resolve_events(
+                    state_sets, event_type, state_key
+                )
+                new_state = {
+                    key: e.event_id for key, e in new_state.items()
+                }
+            else:
+                new_state = {
+                    key: e_ids.pop() for key, e_ids in state.items()
+                }
 
-        state_group = None
-        new_state_event_ids = frozenset(e.event_id for e in new_state.values())
-        for sg, events in state_groups.items():
-            if new_state_event_ids == frozenset(e.event_id for e in events):
-                state_group = sg
-                break
+            state_group = None
+            new_state_event_ids = frozenset(new_state.values())
+            for sg, events in state_groups_ids.items():
+                if new_state_event_ids == frozenset(e_id for e_id in events):
+                    state_group = sg
+                    break
+            if state_group is None:
+                # Worker instances don't have access to this method, but we want
+                # to set the state_group on the main instance to increase cache
+                # hits.
+                if hasattr(self.store, "get_next_state_group"):
+                    state_group = self.store.get_next_state_group()
 
-        if self._state_cache is not None:
             cache = _StateCacheEntry(
-                state={key: event.event_id for key, event in new_state.items()},
+                state=new_state,
                 state_group=state_group,
-                ts=self.clock.time_msec()
             )
 
-            self._state_cache[group_names] = cache
+            if self._state_cache is not None:
+                self._state_cache[group_names] = cache
 
-        defer.returnValue((state_group, new_state, prev_states))
+            defer.returnValue(cache)
 
     def resolve_events(self, state_sets, event):
         logger.info(
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 7efc5bfeef..6c32773f25 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -36,6 +36,7 @@ from .push_rule import PushRuleStore
 from .media_repository import MediaRepositoryStore
 from .rejections import RejectionsStore
 from .event_push_actions import EventPushActionsStore
+from .deviceinbox import DeviceInboxStore
 
 from .state import StateStore
 from .signatures import SignatureStore
@@ -84,6 +85,7 @@ class DataStore(RoomMemberStore, RoomStore,
                 OpenIdStore,
                 ClientIpStore,
                 DeviceStore,
+                DeviceInboxStore,
                 ):
 
     def __init__(self, db_conn, hs):
@@ -108,9 +110,12 @@ class DataStore(RoomMemberStore, RoomStore,
         self._presence_id_gen = StreamIdGenerator(
             db_conn, "presence_stream", "stream_id"
         )
+        self._device_inbox_id_gen = StreamIdGenerator(
+            db_conn, "device_inbox", "stream_id"
+        )
 
         self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id")
-        self._state_groups_id_gen = StreamIdGenerator(db_conn, "state_groups", "id")
+        self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id")
         self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
         self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id")
         self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py
new file mode 100644
index 0000000000..68116b0394
--- /dev/null
+++ b/synapse/storage/deviceinbox.py
@@ -0,0 +1,184 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import ujson
+
+from twisted.internet import defer
+
+from ._base import SQLBaseStore
+
+
+logger = logging.getLogger(__name__)
+
+
+class DeviceInboxStore(SQLBaseStore):
+
+    @defer.inlineCallbacks
+    def add_messages_to_device_inbox(self, messages_by_user_then_device):
+        """
+        Args:
+            messages_by_user_and_device(dict):
+                Dictionary of user_id to device_id to message.
+        Returns:
+            A deferred stream_id that resolves when the messages have been
+            inserted.
+        """
+
+        def select_devices_txn(txn, user_id, devices):
+            if not devices:
+                return []
+            sql = (
+                "SELECT user_id, device_id FROM devices"
+                " WHERE user_id = ? AND device_id IN ("
+                + ",".join("?" * len(devices))
+                + ")"
+            )
+            # TODO: Maybe this needs to be done in batches if there are
+            # too many local devices for a given user.
+            args = [user_id] + devices
+            txn.execute(sql, args)
+            return [tuple(row) for row in txn.fetchall()]
+
+        def add_messages_to_device_inbox_txn(txn, stream_id):
+            local_users_and_devices = set()
+            for user_id, messages_by_device in messages_by_user_then_device.items():
+                local_users_and_devices.update(
+                    select_devices_txn(txn, user_id, messages_by_device.keys())
+                )
+
+            sql = (
+                "INSERT INTO device_inbox"
+                " (user_id, device_id, stream_id, message_json)"
+                " VALUES (?,?,?,?)"
+            )
+            rows = []
+            for user_id, messages_by_device in messages_by_user_then_device.items():
+                for device_id, message in messages_by_device.items():
+                    message_json = ujson.dumps(message)
+                    # Only insert into the local inbox if the device exists on
+                    # this server
+                    if (user_id, device_id) in local_users_and_devices:
+                        rows.append((user_id, device_id, stream_id, message_json))
+
+            txn.executemany(sql, rows)
+
+        with self._device_inbox_id_gen.get_next() as stream_id:
+            yield self.runInteraction(
+                "add_messages_to_device_inbox",
+                add_messages_to_device_inbox_txn,
+                stream_id
+            )
+
+        defer.returnValue(self._device_inbox_id_gen.get_current_token())
+
+    def get_new_messages_for_device(
+        self, user_id, device_id, last_stream_id, current_stream_id, limit=100
+    ):
+        """
+        Args:
+            user_id(str): The recipient user_id.
+            device_id(str): The recipient device_id.
+            current_stream_id(int): The current position of the to device
+                message stream.
+        Returns:
+            Deferred ([dict], int): List of messages for the device and where
+                in the stream the messages got to.
+        """
+        def get_new_messages_for_device_txn(txn):
+            sql = (
+                "SELECT stream_id, message_json FROM device_inbox"
+                " WHERE user_id = ? AND device_id = ?"
+                " AND ? < stream_id AND stream_id <= ?"
+                " ORDER BY stream_id ASC"
+                " LIMIT ?"
+            )
+            txn.execute(sql, (
+                user_id, device_id, last_stream_id, current_stream_id, limit
+            ))
+            messages = []
+            for row in txn.fetchall():
+                stream_pos = row[0]
+                messages.append(ujson.loads(row[1]))
+            if len(messages) < limit:
+                stream_pos = current_stream_id
+            return (messages, stream_pos)
+
+        return self.runInteraction(
+            "get_new_messages_for_device", get_new_messages_for_device_txn,
+        )
+
+    def delete_messages_for_device(self, user_id, device_id, up_to_stream_id):
+        """
+        Args:
+            user_id(str): The recipient user_id.
+            device_id(str): The recipient device_id.
+            up_to_stream_id(int): Where to delete messages up to.
+        Returns:
+            A deferred that resolves when the messages have been deleted.
+        """
+        def delete_messages_for_device_txn(txn):
+            sql = (
+                "DELETE FROM device_inbox"
+                " WHERE user_id = ? AND device_id = ?"
+                " AND stream_id <= ?"
+            )
+            txn.execute(sql, (user_id, device_id, up_to_stream_id))
+
+        return self.runInteraction(
+            "delete_messages_for_device", delete_messages_for_device_txn
+        )
+
+    def get_all_new_device_messages(self, last_pos, current_pos, limit):
+        """
+        Args:
+            last_pos(int):
+            current_pos(int):
+            limit(int):
+        Returns:
+            A deferred list of rows from the device inbox
+        """
+        if last_pos == current_pos:
+            return defer.succeed([])
+
+        def get_all_new_device_messages_txn(txn):
+            sql = (
+                "SELECT stream_id FROM device_inbox"
+                " WHERE ? < stream_id AND stream_id <= ?"
+                " GROUP BY stream_id"
+                " ORDER BY stream_id ASC"
+                " LIMIT ?"
+            )
+            txn.execute(sql, (last_pos, current_pos, limit))
+            stream_ids = txn.fetchall()
+            if not stream_ids:
+                return []
+            max_stream_id_in_limit = stream_ids[-1]
+
+            sql = (
+                "SELECT stream_id, user_id, device_id, message_json"
+                " FROM device_inbox"
+                " WHERE ? < stream_id AND stream_id <= ?"
+                " ORDER BY stream_id ASC"
+            )
+            txn.execute(sql, (last_pos, max_stream_id_in_limit))
+            return txn.fetchall()
+
+        return self.runInteraction(
+            "get_all_new_device_messages", get_all_new_device_messages_txn
+        )
+
+    def get_to_device_stream_token(self):
+        return self._device_inbox_id_gen.get_current_token()
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 57e5005285..1a7d4c5199 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -271,39 +271,28 @@ class EventsStore(SQLBaseStore):
                 len(events_and_contexts)
             )
 
-        state_group_id_manager = self._state_groups_id_gen.get_next_mult(
-            len(events_and_contexts)
-        )
         with stream_ordering_manager as stream_orderings:
-            with state_group_id_manager as state_group_ids:
-                for (event, context), stream, state_group_id in zip(
-                    events_and_contexts, stream_orderings, state_group_ids
-                ):
-                    event.internal_metadata.stream_ordering = stream
-                    # Assign a state group_id in case a new id is needed for
-                    # this context. In theory we only need to assign this
-                    # for contexts that have current_state and aren't outliers
-                    # but that make the code more complicated. Assigning an ID
-                    # per event only causes the state_group_ids to grow as fast
-                    # as the stream_ordering so in practise shouldn't be a problem.
-                    context.new_state_group_id = state_group_id
-
-                chunks = [
-                    events_and_contexts[x:x + 100]
-                    for x in xrange(0, len(events_and_contexts), 100)
-                ]
+            for (event, context), stream, in zip(
+                events_and_contexts, stream_orderings
+            ):
+                event.internal_metadata.stream_ordering = stream
 
-                for chunk in chunks:
-                    # We can't easily parallelize these since different chunks
-                    # might contain the same event. :(
-                    yield self.runInteraction(
-                        "persist_events",
-                        self._persist_events_txn,
-                        events_and_contexts=chunk,
-                        backfilled=backfilled,
-                        delete_existing=delete_existing,
-                    )
-                    persist_event_counter.inc_by(len(chunk))
+            chunks = [
+                events_and_contexts[x:x + 100]
+                for x in xrange(0, len(events_and_contexts), 100)
+            ]
+
+            for chunk in chunks:
+                # We can't easily parallelize these since different chunks
+                # might contain the same event. :(
+                yield self.runInteraction(
+                    "persist_events",
+                    self._persist_events_txn,
+                    events_and_contexts=chunk,
+                    backfilled=backfilled,
+                    delete_existing=delete_existing,
+                )
+                persist_event_counter.inc_by(len(chunk))
 
     @_retry_on_integrity_error
     @defer.inlineCallbacks
@@ -312,19 +301,17 @@ class EventsStore(SQLBaseStore):
                        delete_existing=False):
         try:
             with self._stream_id_gen.get_next() as stream_ordering:
-                with self._state_groups_id_gen.get_next() as state_group_id:
-                    event.internal_metadata.stream_ordering = stream_ordering
-                    context.new_state_group_id = state_group_id
-                    yield self.runInteraction(
-                        "persist_event",
-                        self._persist_event_txn,
-                        event=event,
-                        context=context,
-                        current_state=current_state,
-                        backfilled=backfilled,
-                        delete_existing=delete_existing,
-                    )
-                    persist_event_counter.inc()
+                event.internal_metadata.stream_ordering = stream_ordering
+                yield self.runInteraction(
+                    "persist_event",
+                    self._persist_event_txn,
+                    event=event,
+                    context=context,
+                    current_state=current_state,
+                    backfilled=backfilled,
+                    delete_existing=delete_existing,
+                )
+                persist_event_counter.inc()
         except _RollbackButIsFineException:
             pass
 
@@ -393,7 +380,6 @@ class EventsStore(SQLBaseStore):
             txn.call_after(self._get_current_state_for_key.invalidate_all)
             txn.call_after(self.get_rooms_for_user.invalidate_all)
             txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
-            txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,))
 
             # Add an entry to the current_state_resets table to record the point
             # where we clobbered the current state
@@ -529,7 +515,7 @@ class EventsStore(SQLBaseStore):
                 # Add an entry to the ex_outlier_stream table to replicate the
                 # change in outlier status to our workers.
                 stream_order = event.internal_metadata.stream_ordering
-                state_group_id = context.state_group or context.new_state_group_id
+                state_group_id = context.state_group
                 self._simple_insert_txn(
                     txn,
                     table="ex_outlier_stream",
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index 78334a98cf..49721656b6 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -16,7 +16,6 @@
 from ._base import SQLBaseStore
 from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
 from synapse.push.baserules import list_with_base_rules
-from synapse.api.constants import EventTypes, Membership
 from twisted.internet import defer
 
 import logging
@@ -124,7 +123,8 @@ class PushRuleStore(SQLBaseStore):
 
         defer.returnValue(results)
 
-    def bulk_get_push_rules_for_room(self, room_id, state_group, current_state):
+    def bulk_get_push_rules_for_room(self, event, context):
+        state_group = context.state_group
         if not state_group:
             # If state_group is None it means it has yet to be assigned a
             # state group, i.e. we need to make sure that calls with a state_group
@@ -132,11 +132,13 @@ class PushRuleStore(SQLBaseStore):
             # To do this we set the state_group to a new object as object() != object()
             state_group = object()
 
-        return self._bulk_get_push_rules_for_room(room_id, state_group, current_state)
+        return self._bulk_get_push_rules_for_room(
+            event.room_id, state_group, context.current_state_ids, event=event
+        )
 
     @cachedInlineCallbacks(num_args=2, cache_context=True)
-    def _bulk_get_push_rules_for_room(self, room_id, state_group, current_state,
-                                      cache_context):
+    def _bulk_get_push_rules_for_room(self, room_id, state_group, current_state_ids,
+                                      cache_context, event=None):
         # We don't use `state_group`, its there so that we can cache based
         # on it. However, its important that its never None, since two current_state's
         # with a state_group of None are likely to be different.
@@ -147,12 +149,15 @@ class PushRuleStore(SQLBaseStore):
         # their unread countss are correct in the event stream, but to avoid
         # generating them for bot / AS users etc, we only do so for people who've
         # sent a read receipt into the room.
-        local_users_in_room = set(
-            e.state_key for e in current_state.values()
-            if e.type == EventTypes.Member and e.membership == Membership.JOIN
-            and self.hs.is_mine_id(e.state_key)
+
+        users_in_room = yield self._get_joined_users_from_context(
+            room_id, state_group, current_state_ids,
+            on_invalidate=cache_context.invalidate,
+            event=event,
         )
 
+        local_users_in_room = set(u for u in users_in_room if self.hs.is_mine_id(u))
+
         # users in the room who have pushers need to get push rules run because
         # that's how their pushers work
         if_users_with_pushers = yield self.get_if_users_have_pushers(
diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py
index ccc3811e84..9747a04a9a 100644
--- a/synapse/storage/receipts.py
+++ b/synapse/storage/receipts.py
@@ -145,7 +145,7 @@ class ReceiptsStore(SQLBaseStore):
 
         defer.returnValue([ev for res in results.values() for ev in res])
 
-    @cachedInlineCallbacks(num_args=3, max_entries=5000, tree=True)
+    @cachedInlineCallbacks(num_args=3, tree=True)
     def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
         """Get receipts for a single room for sending to clients.
 
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index a422ddf633..6ab10db328 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -20,7 +20,7 @@ from collections import namedtuple
 from ._base import SQLBaseStore
 from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
 
-from synapse.api.constants import Membership
+from synapse.api.constants import Membership, EventTypes
 from synapse.types import get_domain_from_id
 
 import logging
@@ -56,7 +56,6 @@ class RoomMemberStore(SQLBaseStore):
 
         for event in events:
             txn.call_after(self.get_rooms_for_user.invalidate, (event.state_key,))
-            txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,))
             txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
             txn.call_after(
                 self._membership_stream_cache.entity_has_changed,
@@ -238,11 +237,6 @@ class RoomMemberStore(SQLBaseStore):
 
         return results
 
-    @cachedInlineCallbacks(max_entries=5000)
-    def get_joined_hosts_for_room(self, room_id):
-        user_ids = yield self.get_users_in_room(room_id)
-        defer.returnValue(set(get_domain_from_id(uid) for uid in user_ids))
-
     def _get_members_rows_txn(self, txn, room_id, membership=None, user_id=None):
         where_clause = "c.room_id = ?"
         where_values = [room_id]
@@ -325,7 +319,8 @@ class RoomMemberStore(SQLBaseStore):
 
     @cachedInlineCallbacks(num_args=3)
     def was_forgotten_at(self, user_id, room_id, event_id):
-        """Returns whether user_id has elected to discard history for room_id at event_id.
+        """Returns whether user_id has elected to discard history for room_id at
+        event_id.
 
         event_id must be a membership event."""
         def f(txn):
@@ -358,3 +353,98 @@ class RoomMemberStore(SQLBaseStore):
             },
             desc="who_forgot"
         )
+
+    def get_joined_users_from_context(self, event, context):
+        state_group = context.state_group
+        if not state_group:
+            # If state_group is None it means it has yet to be assigned a
+            # state group, i.e. we need to make sure that calls with a state_group
+            # of None don't hit previous cached calls with a None state_group.
+            # To do this we set the state_group to a new object as object() != object()
+            state_group = object()
+
+        return self._get_joined_users_from_context(
+            event.room_id, state_group, context.current_state_ids, event=event,
+        )
+
+    def get_joined_users_from_state(self, room_id, state_group, state_ids):
+        if not state_group:
+            # If state_group is None it means it has yet to be assigned a
+            # state group, i.e. we need to make sure that calls with a state_group
+            # of None don't hit previous cached calls with a None state_group.
+            # To do this we set the state_group to a new object as object() != object()
+            state_group = object()
+
+        return self._get_joined_users_from_context(
+            room_id, state_group, state_ids,
+        )
+
+    @cachedInlineCallbacks(num_args=2, cache_context=True)
+    def _get_joined_users_from_context(self, room_id, state_group, current_state_ids,
+                                       cache_context, event=None):
+        # We don't use `state_group`, its there so that we can cache based
+        # on it. However, its important that its never None, since two current_state's
+        # with a state_group of None are likely to be different.
+        # See bulk_get_push_rules_for_room for how we work around this.
+        assert state_group is not None
+
+        member_event_ids = [
+            e_id
+            for key, e_id in current_state_ids.iteritems()
+            if key[0] == EventTypes.Member
+        ]
+
+        rows = yield self._simple_select_many_batch(
+            table="room_memberships",
+            column="event_id",
+            iterable=member_event_ids,
+            retcols=['user_id'],
+            keyvalues={
+                "membership": Membership.JOIN,
+            },
+            batch_size=1000,
+            desc="_get_joined_users_from_context",
+        )
+
+        users_in_room = set(row["user_id"] for row in rows)
+        if event is not None and event.type == EventTypes.Member:
+            if event.membership == Membership.JOIN:
+                if event.event_id in member_event_ids:
+                    users_in_room.add(event.state_key)
+
+        defer.returnValue(users_in_room)
+
+    def is_host_joined(self, room_id, host, state_group, state_ids):
+        if not state_group:
+            # If state_group is None it means it has yet to be assigned a
+            # state group, i.e. we need to make sure that calls with a state_group
+            # of None don't hit previous cached calls with a None state_group.
+            # To do this we set the state_group to a new object as object() != object()
+            state_group = object()
+
+        return self._is_host_joined(
+            room_id, host, state_group, state_ids
+        )
+
+    @cachedInlineCallbacks(num_args=3)
+    def _is_host_joined(self, room_id, host, state_group, current_state_ids):
+        # We don't use `state_group`, its there so that we can cache based
+        # on it. However, its important that its never None, since two current_state's
+        # with a state_group of None are likely to be different.
+        # See bulk_get_push_rules_for_room for how we work around this.
+        assert state_group is not None
+
+        for (etype, state_key), event_id in current_state_ids.items():
+            if etype == EventTypes.Member:
+                try:
+                    if get_domain_from_id(state_key) != host:
+                        continue
+                except:
+                    logger.warn("state_key not user_id: %s", state_key)
+                    continue
+
+                event = yield self.get_event(event_id, allow_none=True)
+                if event and event.content["membership"] == Membership.JOIN:
+                    defer.returnValue(True)
+
+        defer.returnValue(False)
diff --git a/synapse/storage/schema/delta/34/device_inbox.sql b/synapse/storage/schema/delta/34/device_inbox.sql
new file mode 100644
index 0000000000..e68844c74a
--- /dev/null
+++ b/synapse/storage/schema/delta/34/device_inbox.sql
@@ -0,0 +1,24 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE device_inbox (
+    user_id TEXT NOT NULL,
+    device_id TEXT NOT NULL,
+    stream_id BIGINT NOT NULL,
+    message_json TEXT NOT NULL -- {"type":, "sender":, "content",}
+);
+
+CREATE INDEX device_inbox_user_stream_id ON device_inbox(user_id, device_id, stream_id);
+CREATE INDEX device_inbox_stream_id ON device_inbox(stream_id);
diff --git a/synapse/storage/schema/delta/34/sent_txn_purge.py b/synapse/storage/schema/delta/34/sent_txn_purge.py
new file mode 100644
index 0000000000..81948e3431
--- /dev/null
+++ b/synapse/storage/schema/delta/34/sent_txn_purge.py
@@ -0,0 +1,32 @@
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.storage.engines import PostgresEngine
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+def run_create(cur, database_engine, *args, **kwargs):
+    if isinstance(database_engine, PostgresEngine):
+        cur.execute("TRUNCATE sent_transactions")
+    else:
+        cur.execute("DELETE FROM sent_transactions")
+
+    cur.execute("CREATE INDEX sent_transactions_ts ON sent_transactions(ts)")
+
+
+def run_upgrade(cur, database_engine, *args, **kwargs):
+    pass
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 0e8fa93e1f..ec551b0b4f 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -44,11 +44,7 @@ class StateStore(SQLBaseStore):
     """
 
     @defer.inlineCallbacks
-    def get_state_groups(self, room_id, event_ids):
-        """ Get the state groups for the given list of event_ids
-
-        The return value is a dict mapping group names to lists of events.
-        """
+    def get_state_groups_ids(self, room_id, event_ids):
         if not event_ids:
             defer.returnValue({})
 
@@ -59,36 +55,64 @@ class StateStore(SQLBaseStore):
         groups = set(event_to_groups.values())
         group_to_state = yield self._get_state_for_groups(groups)
 
+        defer.returnValue(group_to_state)
+
+    @defer.inlineCallbacks
+    def get_state_groups(self, room_id, event_ids):
+        """ Get the state groups for the given list of event_ids
+
+        The return value is a dict mapping group names to lists of events.
+        """
+        if not event_ids:
+            defer.returnValue({})
+
+        group_to_ids = yield self.get_state_groups_ids(room_id, event_ids)
+
+        state_event_map = yield self.get_events(
+            [
+                ev_id for group_ids in group_to_ids.values()
+                for ev_id in group_ids.values()
+            ],
+            get_prev_content=False
+        )
+
         defer.returnValue({
-            group: state_map.values()
-            for group, state_map in group_to_state.items()
+            group: [
+                state_event_map[v] for v in event_id_map.values() if v in state_event_map
+            ]
+            for group, event_id_map in group_to_ids.items()
         })
 
+    def _have_persisted_state_group_txn(self, txn, state_group):
+        txn.execute(
+            "SELECT count(*) FROM state_groups WHERE id = ?",
+            (state_group,)
+        )
+        row = txn.fetchone()
+        return row and row[0]
+
     def _store_mult_state_groups_txn(self, txn, events_and_contexts):
         state_groups = {}
         for event, context in events_and_contexts:
             if event.internal_metadata.is_outlier():
                 continue
 
-            if context.current_state is None:
-                continue
-
-            if context.state_group is not None:
-                state_groups[event.event_id] = context.state_group
+            if context.current_state_ids is None:
                 continue
 
-            state_events = dict(context.current_state)
+            state_groups[event.event_id] = context.state_group
 
-            if event.is_state():
-                state_events[(event.type, event.state_key)] = event
+            if self._have_persisted_state_group_txn(txn, context.state_group):
+                logger.info("Already persisted state_group: %r", context.state_group)
+                continue
 
-            state_group = context.new_state_group_id
+            state_event_ids = dict(context.current_state_ids)
 
             self._simple_insert_txn(
                 txn,
                 table="state_groups",
                 values={
-                    "id": state_group,
+                    "id": context.state_group,
                     "room_id": event.room_id,
                     "event_id": event.event_id,
                 },
@@ -99,16 +123,15 @@ class StateStore(SQLBaseStore):
                 table="state_groups_state",
                 values=[
                     {
-                        "state_group": state_group,
-                        "room_id": state.room_id,
-                        "type": state.type,
-                        "state_key": state.state_key,
-                        "event_id": state.event_id,
+                        "state_group": context.state_group,
+                        "room_id": event.room_id,
+                        "type": key[0],
+                        "state_key": key[1],
+                        "event_id": state_id,
                     }
-                    for state in state_events.values()
+                    for key, state_id in state_event_ids.items()
                 ],
             )
-            state_groups[event.event_id] = state_group
 
         self._simple_insert_many_txn(
             txn,
@@ -248,6 +271,31 @@ class StateStore(SQLBaseStore):
         groups = set(event_to_groups.values())
         group_to_state = yield self._get_state_for_groups(groups, types)
 
+        state_event_map = yield self.get_events(
+            [ev_id for sd in group_to_state.values() for ev_id in sd.values()],
+            get_prev_content=False
+        )
+
+        event_to_state = {
+            event_id: {
+                k: state_event_map[v]
+                for k, v in group_to_state[group].items()
+                if v in state_event_map
+            }
+            for event_id, group in event_to_groups.items()
+        }
+
+        defer.returnValue({event: event_to_state[event] for event in event_ids})
+
+    @defer.inlineCallbacks
+    def get_state_ids_for_events(self, event_ids, types):
+        event_to_groups = yield self._get_state_group_for_events(
+            event_ids,
+        )
+
+        groups = set(event_to_groups.values())
+        group_to_state = yield self._get_state_for_groups(groups, types)
+
         event_to_state = {
             event_id: group_to_state[group]
             for event_id, group in event_to_groups.items()
@@ -272,6 +320,23 @@ class StateStore(SQLBaseStore):
         state_map = yield self.get_state_for_events([event_id], types)
         defer.returnValue(state_map[event_id])
 
+    @defer.inlineCallbacks
+    def get_state_ids_for_event(self, event_id, types=None):
+        """
+        Get the state dict corresponding to a particular event
+
+        Args:
+            event_id(str): event whose state should be returned
+            types(list[(str, str)]|None): List of (type, state_key) tuples
+                which are used to filter the state fetched. May be None, which
+                matches any key
+
+        Returns:
+            A deferred dict from (type, state_key) -> state_event
+        """
+        state_map = yield self.get_state_ids_for_events([event_id], types)
+        defer.returnValue(state_map[event_id])
+
     @cached(num_args=2, max_entries=10000)
     def _get_state_group_for_event(self, room_id, event_id):
         return self._simple_select_one_onecol(
@@ -428,20 +493,13 @@ class StateStore(SQLBaseStore):
                     full=(types is None),
                 )
 
-        state_events = yield self._get_events(
-            [ev_id for sd in results.values() for ev_id in sd.values()],
-            get_prev_content=False
-        )
-
-        state_events = {e.event_id: e for e in state_events}
-
         # Remove all the entries with None values. The None values were just
         # used for bookkeeping in the cache.
         for group, state_dict in results.items():
             results[group] = {
-                key: state_events[event_id]
+                key: event_id
                 for key, event_id in state_dict.items()
-                if event_id and event_id in state_events
+                if event_id
             }
 
         defer.returnValue(results)
@@ -473,5 +531,5 @@ class StateStore(SQLBaseStore):
             "get_all_new_state_groups", get_all_new_state_groups_txn
         )
 
-    def get_state_stream_token(self):
-        return self._state_groups_id_gen.get_current_token()
+    def get_next_state_group(self):
+        return self._state_groups_id_gen.get_next()
diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py
index 58d4de4f1d..5055c04b24 100644
--- a/synapse/storage/transactions.py
+++ b/synapse/storage/transactions.py
@@ -245,7 +245,7 @@ class TransactionStore(SQLBaseStore):
 
         return self.cursor_to_dict(txn)
 
-    @cached()
+    @cached(max_entries=10000)
     def get_destination_retry_timings(self, destination):
         """Gets the current retry timings (if any) for a given destination.
 
@@ -387,8 +387,10 @@ class TransactionStore(SQLBaseStore):
     def _cleanup_transactions(self):
         now = self._clock.time_msec()
         month_ago = now - 30 * 24 * 60 * 60 * 1000
+        six_hours_ago = now - 6 * 60 * 60 * 1000
 
         def _cleanup_transactions_txn(txn):
             txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,))
+            txn.execute("DELETE FROM sent_transactions WHERE ts < ?", (six_hours_ago,))
 
         return self.runInteraction("_persist_in_mem_txns", _cleanup_transactions_txn)
diff --git a/synapse/streams/events.py b/synapse/streams/events.py
index d4c0bb6732..6bf21d6f5e 100644
--- a/synapse/streams/events.py
+++ b/synapse/streams/events.py
@@ -43,6 +43,7 @@ class EventSources(object):
     @defer.inlineCallbacks
     def get_current_token(self, direction='f'):
         push_rules_key, _ = self.store.get_push_rules_stream_token()
+        to_device_key = self.store.get_to_device_stream_token()
 
         token = StreamToken(
             room_key=(
@@ -61,5 +62,6 @@ class EventSources(object):
                 yield self.sources["account_data"].get_current_key()
             ),
             push_rules_key=push_rules_key,
+            to_device_key=to_device_key,
         )
         defer.returnValue(token)
diff --git a/synapse/types.py b/synapse/types.py
index fd17ecbbe0..9d64e8c4de 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -154,6 +154,7 @@ class StreamToken(
         "receipt_key",
         "account_data_key",
         "push_rules_key",
+        "to_device_key",
     ))
 ):
     _SEPARATOR = "_"
@@ -190,6 +191,7 @@ class StreamToken(
             or (int(other.receipt_key) < int(self.receipt_key))
             or (int(other.account_data_key) < int(self.account_data_key))
             or (int(other.push_rules_key) < int(self.push_rules_key))
+            or (int(other.to_device_key) < int(self.to_device_key))
         )
 
     def copy_and_advance(self, key, new_value):
@@ -269,10 +271,3 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
             return "t%d-%d" % (self.topological, self.stream)
         else:
             return "s%d" % (self.stream,)
-
-
-# Some arbitrary constants used for internal API enumerations. Don't rely on
-# exact values; always pass or compare symbolically
-class ThirdPartyEntityKind(object):
-    USER = 'user'
-    LOCATION = 'location'
diff --git a/synapse/visibility.py b/synapse/visibility.py
index cc12c0a23d..199b16d827 100644
--- a/synapse/visibility.py
+++ b/synapse/visibility.py
@@ -181,6 +181,25 @@ def filter_events_for_clients(store, user_tuples, events, event_id_to_state):
 
 
 @defer.inlineCallbacks
+def filter_events_for_clients_context(store, user_tuples, events, event_id_to_context):
+    user_ids = set(u[0] for u in user_tuples)
+    event_id_to_state = {}
+    for event_id, context in event_id_to_context.items():
+        state = yield store.get_events([
+            e_id
+            for key, e_id in context.current_state_ids.iteritems()
+            if key == (EventTypes.RoomHistoryVisibility, "")
+            or (key[0] == EventTypes.Member and key[1] in user_ids)
+        ])
+        event_id_to_state[event_id] = state
+
+    res = yield filter_events_for_clients(
+        store, user_tuples, events, event_id_to_state
+    )
+    defer.returnValue(res)
+
+
+@defer.inlineCallbacks
 def filter_events_for_client(store, user_id, events, is_peeking=False):
     """
     Check which events a user is allowed to see
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index b531ba8540..d9e8f634ae 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -115,6 +115,53 @@ class PresenceUpdateTestCase(unittest.TestCase):
             ),
         ], any_order=True)
 
+    def test_online_to_online_last_active_noop(self):
+        wheel_timer = Mock()
+        user_id = "@foo:bar"
+        now = 5000000
+
+        prev_state = UserPresenceState.default(user_id)
+        prev_state = prev_state.copy_and_replace(
+            state=PresenceState.ONLINE,
+            last_active_ts=now - LAST_ACTIVE_GRANULARITY - 10,
+            currently_active=True,
+        )
+
+        new_state = prev_state.copy_and_replace(
+            state=PresenceState.ONLINE,
+            last_active_ts=now,
+        )
+
+        state, persist_and_notify, federation_ping = handle_update(
+            prev_state, new_state, is_mine=True, wheel_timer=wheel_timer, now=now
+        )
+
+        self.assertFalse(persist_and_notify)
+        self.assertTrue(federation_ping)
+        self.assertTrue(state.currently_active)
+        self.assertEquals(new_state.state, state.state)
+        self.assertEquals(new_state.status_msg, state.status_msg)
+        self.assertEquals(state.last_federation_update_ts, now)
+
+        self.assertEquals(wheel_timer.insert.call_count, 3)
+        wheel_timer.insert.assert_has_calls([
+            call(
+                now=now,
+                obj=user_id,
+                then=new_state.last_active_ts + IDLE_TIMER
+            ),
+            call(
+                now=now,
+                obj=user_id,
+                then=new_state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT
+            ),
+            call(
+                now=now,
+                obj=user_id,
+                then=new_state.last_active_ts + LAST_ACTIVE_GRANULARITY
+            ),
+        ], any_order=True)
+
     def test_online_to_online_last_active(self):
         wheel_timer = Mock()
         user_id = "@foo:bar"
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index ab9899b7d5..b2957eef9f 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -62,6 +62,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
         self.on_new_event = mock_notifier.on_new_event
 
         self.auth = Mock(spec=[])
+        self.state_handler = Mock()
 
         hs = yield setup_test_homeserver(
             "test",
@@ -75,6 +76,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
                 "set_received_txn_response",
                 "get_destination_retry_timings",
             ]),
+            state_handler=self.state_handler,
             handlers=None,
             notifier=mock_notifier,
             resource_for_client=Mock(),
@@ -113,6 +115,10 @@ class TypingNotificationsTestCase(unittest.TestCase):
             return set(member.domain for member in self.room_members)
         self.datastore.get_joined_hosts_for_room = get_joined_hosts_for_room
 
+        def get_current_user_in_room(room_id):
+            return set(str(u) for u in self.room_members)
+        self.state_handler.get_current_user_in_room = get_current_user_in_room
+
         self.auth.check_joined_room = check_joined_room
 
         # Some local users to test with
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index f33e6f60fb..44e859b5d1 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -305,7 +305,16 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
 
         self.event_id += 1
 
-        context = EventContext(current_state=state)
+        if state is not None:
+            state_ids = {
+                key: e.event_id for key, e in state.items()
+            }
+        else:
+            state_ids = None
+
+        context = EventContext()
+        context.current_state_ids = state_ids
+        context.prev_state_ids = state_ids
         context.push_actions = push_actions
 
         ordering = None
diff --git a/tests/replication/test_resource.py b/tests/replication/test_resource.py
index e70ac6f14d..b69832cc1b 100644
--- a/tests/replication/test_resource.py
+++ b/tests/replication/test_resource.py
@@ -60,8 +60,8 @@ class ReplicationResourceCase(unittest.TestCase):
         self.assertEquals(body, {})
 
     @defer.inlineCallbacks
-    def test_events_and_state(self):
-        get = self.get(events="-1", state="-1", timeout="0")
+    def test_events(self):
+        get = self.get(events="-1", timeout="0")
         yield self.hs.get_handlers().room_creation_handler.create_room(
             synapse.types.create_requester(self.user), {}
         )
@@ -70,12 +70,6 @@ class ReplicationResourceCase(unittest.TestCase):
         self.assertEquals(body["events"]["field_names"], [
             "position", "internal", "json", "state_group"
         ])
-        self.assertEquals(body["state_groups"]["field_names"], [
-            "position", "room_id", "event_id"
-        ])
-        self.assertEquals(body["state_group_state"]["field_names"], [
-            "position", "type", "state_key", "event_id"
-        ])
 
     @defer.inlineCallbacks
     def test_presence(self):
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index 8853cbb5fc..4fe99ebc0b 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -1032,7 +1032,7 @@ class RoomMessageListTestCase(RestTestCase):
 
     @defer.inlineCallbacks
     def test_topo_token_is_accepted(self):
-        token = "t1-0_0_0_0_0_0"
+        token = "t1-0_0_0_0_0_0_0"
         (code, response) = yield self.mock_resource.trigger_get(
             "/rooms/%s/messages?access_token=x&from=%s" %
             (self.room_id, token))
@@ -1044,7 +1044,7 @@ class RoomMessageListTestCase(RestTestCase):
 
     @defer.inlineCallbacks
     def test_stream_token_is_accepted_for_fwd_pagianation(self):
-        token = "s0_0_0_0_0_0"
+        token = "s0_0_0_0_0_0_0"
         (code, response) = yield self.mock_resource.trigger_get(
             "/rooms/%s/messages?access_token=x&from=%s" %
             (self.room_id, token))
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index 27b2b3d123..1be7d932f6 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -78,44 +78,3 @@ class RoomMemberStoreTestCase(unittest.TestCase):
                 )
             )]
         )
-
-    @defer.inlineCallbacks
-    def test_room_hosts(self):
-        yield self.inject_room_member(self.room, self.u_alice, Membership.JOIN)
-
-        self.assertEquals(
-            {"test"},
-            (yield self.store.get_joined_hosts_for_room(self.room.to_string()))
-        )
-
-        # Should still have just one host after second join from it
-        yield self.inject_room_member(self.room, self.u_bob, Membership.JOIN)
-
-        self.assertEquals(
-            {"test"},
-            (yield self.store.get_joined_hosts_for_room(self.room.to_string()))
-        )
-
-        # Should now have two hosts after join from other host
-        yield self.inject_room_member(self.room, self.u_charlie, Membership.JOIN)
-
-        self.assertEquals(
-            {"test", "elsewhere"},
-            (yield self.store.get_joined_hosts_for_room(self.room.to_string()))
-        )
-
-        # Should still have both hosts
-        yield self.inject_room_member(self.room, self.u_alice, Membership.LEAVE)
-
-        self.assertEquals(
-            {"test", "elsewhere"},
-            (yield self.store.get_joined_hosts_for_room(self.room.to_string()))
-        )
-
-        # Should have only one host after other leaves
-        yield self.inject_room_member(self.room, self.u_charlie, Membership.LEAVE)
-
-        self.assertEquals(
-            {"test"},
-            (yield self.store.get_joined_hosts_for_room(self.room.to_string()))
-        )
diff --git a/tests/test_state.py b/tests/test_state.py
index 1a11bbcee0..6454f994e3 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -67,9 +67,11 @@ class StateGroupStore(object):
         self._event_to_state_group = {}
         self._group_to_state = {}
 
+        self._event_id_to_event = {}
+
         self._next_group = 1
 
-    def get_state_groups(self, room_id, event_ids):
+    def get_state_groups_ids(self, room_id, event_ids):
         groups = {}
         for event_id in event_ids:
             group = self._event_to_state_group.get(event_id)
@@ -79,22 +81,23 @@ class StateGroupStore(object):
         return defer.succeed(groups)
 
     def store_state_groups(self, event, context):
-        if context.current_state is None:
+        if context.current_state_ids is None:
             return
 
-        state_events = context.current_state
-
-        if event.is_state():
-            state_events[(event.type, event.state_key)] = event
+        state_events = dict(context.current_state_ids)
 
-        state_group = context.state_group
-        if not state_group:
-            state_group = self._next_group
-            self._next_group += 1
+        self._group_to_state[context.state_group] = state_events
+        self._event_to_state_group[event.event_id] = context.state_group
 
-            self._group_to_state[state_group] = state_events.values()
+    def get_events(self, event_ids, **kwargs):
+        return {
+            e_id: self._event_id_to_event[e_id] for e_id in event_ids
+            if e_id in self._event_id_to_event
+        }
 
-        self._event_to_state_group[event.event_id] = state_group
+    def register_events(self, events):
+        for e in events:
+            self._event_id_to_event[e.event_id] = e
 
 
 class DictObj(dict):
@@ -136,8 +139,10 @@ class StateTestCase(unittest.TestCase):
     def setUp(self):
         self.store = Mock(
             spec_set=[
-                "get_state_groups",
+                "get_state_groups_ids",
                 "add_event_hashes",
+                "get_events",
+                "get_next_state_group",
             ]
         )
         hs = Mock(spec_set=[
@@ -148,6 +153,8 @@ class StateTestCase(unittest.TestCase):
         hs.get_clock.return_value = MockClock()
         hs.get_auth.return_value = Auth(hs)
 
+        self.store.get_next_state_group.side_effect = Mock
+
         self.state = StateHandler(hs)
         self.event_id = 0
 
@@ -187,7 +194,7 @@ class StateTestCase(unittest.TestCase):
         )
 
         store = StateGroupStore()
-        self.store.get_state_groups.side_effect = store.get_state_groups
+        self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
 
         context_store = {}
 
@@ -196,7 +203,7 @@ class StateTestCase(unittest.TestCase):
             store.store_state_groups(event, context)
             context_store[event.event_id] = context
 
-        self.assertEqual(2, len(context_store["D"].current_state))
+        self.assertEqual(2, len(context_store["D"].prev_state_ids))
 
     @defer.inlineCallbacks
     def test_branch_basic_conflict(self):
@@ -239,7 +246,9 @@ class StateTestCase(unittest.TestCase):
         )
 
         store = StateGroupStore()
-        self.store.get_state_groups.side_effect = store.get_state_groups
+        self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
+        self.store.get_events = store.get_events
+        store.register_events(graph.walk())
 
         context_store = {}
 
@@ -250,7 +259,7 @@ class StateTestCase(unittest.TestCase):
 
         self.assertSetEqual(
             {"START", "A", "C"},
-            {e.event_id for e in context_store["D"].current_state.values()}
+            {e_id for e_id in context_store["D"].prev_state_ids.values()}
         )
 
     @defer.inlineCallbacks
@@ -303,7 +312,9 @@ class StateTestCase(unittest.TestCase):
         )
 
         store = StateGroupStore()
-        self.store.get_state_groups.side_effect = store.get_state_groups
+        self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
+        self.store.get_events = store.get_events
+        store.register_events(graph.walk())
 
         context_store = {}
 
@@ -314,7 +325,7 @@ class StateTestCase(unittest.TestCase):
 
         self.assertSetEqual(
             {"START", "A", "B", "C"},
-            {e.event_id for e in context_store["E"].current_state.values()}
+            {e for e in context_store["E"].prev_state_ids.values()}
         )
 
     @defer.inlineCallbacks
@@ -384,7 +395,9 @@ class StateTestCase(unittest.TestCase):
         graph = Graph(nodes, edges)
 
         store = StateGroupStore()
-        self.store.get_state_groups.side_effect = store.get_state_groups
+        self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
+        self.store.get_events = store.get_events
+        store.register_events(graph.walk())
 
         context_store = {}
 
@@ -395,7 +408,7 @@ class StateTestCase(unittest.TestCase):
 
         self.assertSetEqual(
             {"A1", "A2", "A3", "A5", "B"},
-            {e.event_id for e in context_store["D"].current_state.values()}
+            {e for e in context_store["D"].prev_state_ids.values()}
         )
 
     def _add_depths(self, nodes, edges):
@@ -424,16 +437,11 @@ class StateTestCase(unittest.TestCase):
             event, old_state=old_state
         )
 
-        for k, v in context.current_state.items():
-            type, state_key = k
-            self.assertEqual(type, v.type)
-            self.assertEqual(state_key, v.state_key)
-
         self.assertEqual(
-            set(old_state), set(context.current_state.values())
+            set(e.event_id for e in old_state), set(context.current_state_ids.values())
         )
 
-        self.assertIsNone(context.state_group)
+        self.assertIsNotNone(context.state_group)
 
     @defer.inlineCallbacks
     def test_annotate_with_old_state(self):
@@ -449,18 +457,10 @@ class StateTestCase(unittest.TestCase):
             event, old_state=old_state
         )
 
-        for k, v in context.current_state.items():
-            type, state_key = k
-            self.assertEqual(type, v.type)
-            self.assertEqual(state_key, v.state_key)
-
         self.assertEqual(
-            set(old_state),
-            set(context.current_state.values())
+            set(e.event_id for e in old_state), set(context.prev_state_ids.values())
         )
 
-        self.assertIsNone(context.state_group)
-
     @defer.inlineCallbacks
     def test_trivial_annotate_message(self):
         event = create_event(type="test_message", name="event")
@@ -473,20 +473,15 @@ class StateTestCase(unittest.TestCase):
 
         group_name = "group_name_1"
 
-        self.store.get_state_groups.return_value = {
-            group_name: old_state,
+        self.store.get_state_groups_ids.return_value = {
+            group_name: {(e.type, e.state_key): e.event_id for e in old_state},
         }
 
         context = yield self.state.compute_event_context(event)
 
-        for k, v in context.current_state.items():
-            type, state_key = k
-            self.assertEqual(type, v.type)
-            self.assertEqual(state_key, v.state_key)
-
         self.assertEqual(
             set([e.event_id for e in old_state]),
-            set([e.event_id for e in context.current_state.values()])
+            set(context.current_state_ids.values())
         )
 
         self.assertEqual(group_name, context.state_group)
@@ -503,23 +498,18 @@ class StateTestCase(unittest.TestCase):
 
         group_name = "group_name_1"
 
-        self.store.get_state_groups.return_value = {
-            group_name: old_state,
+        self.store.get_state_groups_ids.return_value = {
+            group_name: {(e.type, e.state_key): e.event_id for e in old_state},
         }
 
         context = yield self.state.compute_event_context(event)
 
-        for k, v in context.current_state.items():
-            type, state_key = k
-            self.assertEqual(type, v.type)
-            self.assertEqual(state_key, v.state_key)
-
         self.assertEqual(
             set([e.event_id for e in old_state]),
-            set([e.event_id for e in context.current_state.values()])
+            set(context.prev_state_ids.values())
         )
 
-        self.assertIsNone(context.state_group)
+        self.assertIsNotNone(context.state_group)
 
     @defer.inlineCallbacks
     def test_resolve_message_conflict(self):
@@ -543,11 +533,16 @@ class StateTestCase(unittest.TestCase):
             create_event(type="test4", state_key=""),
         ]
 
+        store = StateGroupStore()
+        store.register_events(old_state_1)
+        store.register_events(old_state_2)
+        self.store.get_events = store.get_events
+
         context = yield self._get_context(event, old_state_1, old_state_2)
 
-        self.assertEqual(len(context.current_state), 6)
+        self.assertEqual(len(context.current_state_ids), 6)
 
-        self.assertIsNone(context.state_group)
+        self.assertIsNotNone(context.state_group)
 
     @defer.inlineCallbacks
     def test_resolve_state_conflict(self):
@@ -571,11 +566,16 @@ class StateTestCase(unittest.TestCase):
             create_event(type="test4", state_key=""),
         ]
 
+        store = StateGroupStore()
+        store.register_events(old_state_1)
+        store.register_events(old_state_2)
+        self.store.get_events = store.get_events
+
         context = yield self._get_context(event, old_state_1, old_state_2)
 
-        self.assertEqual(len(context.current_state), 6)
+        self.assertEqual(len(context.current_state_ids), 6)
 
-        self.assertIsNone(context.state_group)
+        self.assertIsNotNone(context.state_group)
 
     @defer.inlineCallbacks
     def test_standard_depth_conflict(self):
@@ -606,9 +606,16 @@ class StateTestCase(unittest.TestCase):
             create_event(type="test1", state_key="1", depth=2),
         ]
 
+        store = StateGroupStore()
+        store.register_events(old_state_1)
+        store.register_events(old_state_2)
+        self.store.get_events = store.get_events
+
         context = yield self._get_context(event, old_state_1, old_state_2)
 
-        self.assertEqual(old_state_2[2], context.current_state[("test1", "1")])
+        self.assertEqual(
+            old_state_2[2].event_id, context.current_state_ids[("test1", "1")]
+        )
 
         # Reverse the depth to make sure we are actually using the depths
         # during state resolution.
@@ -625,17 +632,22 @@ class StateTestCase(unittest.TestCase):
             create_event(type="test1", state_key="1", depth=1),
         ]
 
+        store.register_events(old_state_1)
+        store.register_events(old_state_2)
+
         context = yield self._get_context(event, old_state_1, old_state_2)
 
-        self.assertEqual(old_state_1[2], context.current_state[("test1", "1")])
+        self.assertEqual(
+            old_state_1[2].event_id, context.current_state_ids[("test1", "1")]
+        )
 
     def _get_context(self, event, old_state_1, old_state_2):
         group_name_1 = "group_name_1"
         group_name_2 = "group_name_2"
 
-        self.store.get_state_groups.return_value = {
-            group_name_1: old_state_1,
-            group_name_2: old_state_2,
+        self.store.get_state_groups_ids.return_value = {
+            group_name_1: {(e.type, e.state_key): e.event_id for e in old_state_1},
+            group_name_2: {(e.type, e.state_key): e.event_id for e in old_state_2},
         }
 
         return self.state.compute_event_context(event)