summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorRichard van der Hoff <richard@matrix.org>2018-02-13 12:16:01 +0000
committerRichard van der Hoff <richard@matrix.org>2018-02-13 12:16:01 +0000
commita9b712e9dce119c65fe24913d2a79212ae07a354 (patch)
tree2c98b0b3d7d31c59ce41db257c3a18c5b3a2df03 /synapse
parentFactor out common code for search insert (diff)
parentMerge pull request #2857 from matrix-org/erikj/upload_store (diff)
downloadsynapse-a9b712e9dce119c65fe24913d2a79212ae07a354.tar.xz
Merge branch 'develop' into matthew/gin_work_mem
Diffstat (limited to 'synapse')
-rw-r--r--synapse/events/snapshot.py4
-rw-r--r--synapse/handlers/directory.py7
-rw-r--r--synapse/handlers/federation.py54
-rw-r--r--synapse/handlers/message.py321
-rw-r--r--synapse/handlers/room.py11
-rw-r--r--synapse/handlers/room_member.py21
-rw-r--r--synapse/http/servlet.py18
-rw-r--r--synapse/metrics/metric.py11
-rw-r--r--synapse/replication/slave/storage/events.py4
-rw-r--r--synapse/rest/client/v1/admin.py38
-rw-r--r--synapse/rest/client/v1/room.py17
-rw-r--r--synapse/rest/media/v1/media_repository.py20
-rw-r--r--synapse/rest/media/v1/media_storage.py42
-rw-r--r--synapse/rest/media/v1/preview_url_resource.py55
-rw-r--r--synapse/rest/media/v1/thumbnail_resource.py3
-rw-r--r--synapse/server.py11
-rw-r--r--synapse/server.pyi3
-rw-r--r--synapse/state.py253
-rw-r--r--synapse/storage/__init__.py1
-rw-r--r--synapse/storage/engines/postgres.py6
-rw-r--r--synapse/storage/engines/sqlite3.py19
-rw-r--r--synapse/storage/events.py188
-rw-r--r--synapse/storage/room.py153
-rw-r--r--synapse/storage/schema/delta/47/state_group_seq.py37
-rw-r--r--synapse/storage/search.py13
-rw-r--r--synapse/storage/state.py196
-rw-r--r--synapse/util/caches/descriptors.py4
-rw-r--r--synapse/util/caches/expiringcache.py6
-rw-r--r--synapse/util/caches/lrucache.py28
29 files changed, 995 insertions, 549 deletions
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index e9a732ff03..87e3fe7b97 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -25,7 +25,9 @@ class EventContext(object):
             The current state map excluding the current event.
             (type, state_key) -> event_id
 
-        state_group (int): state group id
+        state_group (int|None): state group id, if the state has been stored
+            as a state group. This is usually only None if e.g. the event is
+            an outlier.
         rejected (bool|str): A rejection reason if the event was rejected, else
             False
 
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index a0464ae5c0..8580ada60a 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -34,6 +34,7 @@ class DirectoryHandler(BaseHandler):
 
         self.state = hs.get_state_handler()
         self.appservice_handler = hs.get_application_service_handler()
+        self.event_creation_handler = hs.get_event_creation_handler()
 
         self.federation = hs.get_replication_layer()
         self.federation.register_query_handler(
@@ -249,8 +250,7 @@ class DirectoryHandler(BaseHandler):
     def send_room_alias_update_event(self, requester, user_id, room_id):
         aliases = yield self.store.get_aliases_for_room(room_id)
 
-        msg_handler = self.hs.get_handlers().message_handler
-        yield msg_handler.create_and_send_nonmember_event(
+        yield self.event_creation_handler.create_and_send_nonmember_event(
             requester,
             {
                 "type": EventTypes.Aliases,
@@ -272,8 +272,7 @@ class DirectoryHandler(BaseHandler):
         if not alias_event or alias_event.content.get("alias", "") != alias_str:
             return
 
-        msg_handler = self.hs.get_handlers().message_handler
-        yield msg_handler.create_and_send_nonmember_event(
+        yield self.event_creation_handler.create_and_send_nonmember_event(
             requester,
             {
                 "type": EventTypes.CanonicalAlias,
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 677532c87b..46bcf8b081 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -75,6 +76,7 @@ class FederationHandler(BaseHandler):
         self.is_mine_id = hs.is_mine_id
         self.pusher_pool = hs.get_pusherpool()
         self.spam_checker = hs.get_spam_checker()
+        self.event_creation_handler = hs.get_event_creation_handler()
 
         self.replication_layer.set_handler(self)
 
@@ -808,13 +810,12 @@ class FederationHandler(BaseHandler):
         event_ids = list(extremities.keys())
 
         logger.debug("calling resolve_state_groups in _maybe_backfill")
+        resolve = logcontext.preserve_fn(
+            self.state_handler.resolve_state_groups_for_events
+        )
         states = yield logcontext.make_deferred_yieldable(defer.gatherResults(
-            [
-                logcontext.preserve_fn(self.state_handler.resolve_state_groups)(
-                    room_id, [e]
-                )
-                for e in event_ids
-            ], consumeErrors=True,
+            [resolve(room_id, [e]) for e in event_ids],
+            consumeErrors=True,
         ))
         states = dict(zip(event_ids, [s.state for s in states]))
 
@@ -1008,8 +1009,7 @@ class FederationHandler(BaseHandler):
         })
 
         try:
-            message_handler = self.hs.get_handlers().message_handler
-            event, context = yield message_handler._create_new_client_event(
+            event, context = yield self.event_creation_handler.create_new_client_event(
                 builder=builder,
             )
         except AuthError as e:
@@ -1249,8 +1249,7 @@ class FederationHandler(BaseHandler):
             "state_key": user_id,
         })
 
-        message_handler = self.hs.get_handlers().message_handler
-        event, context = yield message_handler._create_new_client_event(
+        event, context = yield self.event_creation_handler.create_new_client_event(
             builder=builder,
         )
 
@@ -1832,8 +1831,8 @@ class FederationHandler(BaseHandler):
                 current_state = set(e.event_id for e in auth_events.values())
                 different_auth = event_auth_events - current_state
 
-                self._update_context_for_auth_events(
-                    context, auth_events, event_key,
+                yield self._update_context_for_auth_events(
+                    event, context, auth_events, event_key,
                 )
 
         if different_auth and not event.internal_metadata.is_outlier():
@@ -1914,8 +1913,8 @@ class FederationHandler(BaseHandler):
                 # 4. Look at rejects and their proofs.
                 # TODO.
 
-                self._update_context_for_auth_events(
-                    context, auth_events, event_key,
+                yield self._update_context_for_auth_events(
+                    event, context, auth_events, event_key,
                 )
 
         try:
@@ -1924,11 +1923,15 @@ class FederationHandler(BaseHandler):
             logger.warn("Failed auth resolution for %r because %s", event, e)
             raise e
 
-    def _update_context_for_auth_events(self, context, auth_events,
+    @defer.inlineCallbacks
+    def _update_context_for_auth_events(self, event, context, auth_events,
                                         event_key):
-        """Update the state_ids in an event context after auth event resolution
+        """Update the state_ids in an event context after auth event resolution,
+        storing the changes as a new state group.
 
         Args:
+            event (Event): The event we're handling the context for
+
             context (synapse.events.snapshot.EventContext): event context
                 to be updated
 
@@ -1951,7 +1954,13 @@ class FederationHandler(BaseHandler):
         context.prev_state_ids.update({
             k: a.event_id for k, a in auth_events.iteritems()
         })
-        context.state_group = self.store.get_next_state_group()
+        context.state_group = yield self.store.store_state_group(
+            event.event_id,
+            event.room_id,
+            prev_group=context.prev_group,
+            delta_ids=context.delta_ids,
+            current_state_ids=context.current_state_ids,
+        )
 
     @defer.inlineCallbacks
     def construct_auth_difference(self, local_auth, remote_auth):
@@ -2121,8 +2130,7 @@ class FederationHandler(BaseHandler):
         if (yield self.auth.check_host_in_room(room_id, self.hs.hostname)):
             builder = self.event_builder_factory.new(event_dict)
             EventValidator().validate_new(builder)
-            message_handler = self.hs.get_handlers().message_handler
-            event, context = yield message_handler._create_new_client_event(
+            event, context = yield self.event_creation_handler.create_new_client_event(
                 builder=builder
             )
 
@@ -2160,8 +2168,7 @@ class FederationHandler(BaseHandler):
         """
         builder = self.event_builder_factory.new(event_dict)
 
-        message_handler = self.hs.get_handlers().message_handler
-        event, context = yield message_handler._create_new_client_event(
+        event, context = yield self.event_creation_handler.create_new_client_event(
             builder=builder,
         )
 
@@ -2211,8 +2218,9 @@ class FederationHandler(BaseHandler):
 
         builder = self.event_builder_factory.new(event_dict)
         EventValidator().validate_new(builder)
-        message_handler = self.hs.get_handlers().message_handler
-        event, context = yield message_handler._create_new_client_event(builder=builder)
+        event, context = yield self.event_creation_handler.create_new_client_event(
+            builder=builder,
+        )
         defer.returnValue((event, context))
 
     @defer.inlineCallbacks
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 21f1717dd2..4e9752ccbd 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -1,6 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014 - 2016 OpenMarket Ltd
-# Copyright 2017 New Vector Ltd
+# Copyright 2017 - 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -47,23 +47,11 @@ class MessageHandler(BaseHandler):
         self.hs = hs
         self.state = hs.get_state_handler()
         self.clock = hs.get_clock()
-        self.validator = EventValidator()
-        self.profile_handler = hs.get_profile_handler()
 
         self.pagination_lock = ReadWriteLock()
 
-        self.pusher_pool = hs.get_pusherpool()
-
-        # We arbitrarily limit concurrent event creation for a room to 5.
-        # This is to stop us from diverging history *too* much.
-        self.limiter = Limiter(max_count=5)
-
-        self.action_generator = hs.get_action_generator()
-
-        self.spam_checker = hs.get_spam_checker()
-
     @defer.inlineCallbacks
-    def purge_history(self, room_id, event_id):
+    def purge_history(self, room_id, event_id, delete_local_events=False):
         event = yield self.store.get_event(event_id)
 
         if event.room_id != room_id:
@@ -72,7 +60,7 @@ class MessageHandler(BaseHandler):
         depth = event.depth
 
         with (yield self.pagination_lock.write(room_id)):
-            yield self.store.delete_old_state(room_id, depth)
+            yield self.store.purge_history(room_id, depth, delete_local_events)
 
     @defer.inlineCallbacks
     def get_messages(self, requester, room_id=None, pagin_config=None,
@@ -183,6 +171,162 @@ class MessageHandler(BaseHandler):
         defer.returnValue(chunk)
 
     @defer.inlineCallbacks
+    def get_room_data(self, user_id=None, room_id=None,
+                      event_type=None, state_key="", is_guest=False):
+        """ Get data from a room.
+
+        Args:
+            event : The room path event
+        Returns:
+            The path data content.
+        Raises:
+            SynapseError if something went wrong.
+        """
+        membership, membership_event_id = yield self._check_in_room_or_world_readable(
+            room_id, user_id
+        )
+
+        if membership == Membership.JOIN:
+            data = yield self.state_handler.get_current_state(
+                room_id, event_type, state_key
+            )
+        elif membership == Membership.LEAVE:
+            key = (event_type, state_key)
+            room_state = yield self.store.get_state_for_events(
+                [membership_event_id], [key]
+            )
+            data = room_state[membership_event_id].get(key)
+
+        defer.returnValue(data)
+
+    @defer.inlineCallbacks
+    def _check_in_room_or_world_readable(self, room_id, user_id):
+        try:
+            # check_user_was_in_room will return the most recent membership
+            # event for the user if:
+            #  * The user is a non-guest user, and was ever in the room
+            #  * The user is a guest user, and has joined the room
+            # else it will throw.
+            member_event = yield self.auth.check_user_was_in_room(room_id, user_id)
+            defer.returnValue((member_event.membership, member_event.event_id))
+            return
+        except AuthError:
+            visibility = yield self.state_handler.get_current_state(
+                room_id, EventTypes.RoomHistoryVisibility, ""
+            )
+            if (
+                visibility and
+                visibility.content["history_visibility"] == "world_readable"
+            ):
+                defer.returnValue((Membership.JOIN, None))
+                return
+            raise AuthError(
+                403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN
+            )
+
+    @defer.inlineCallbacks
+    def get_state_events(self, user_id, room_id, is_guest=False):
+        """Retrieve all state events for a given room. If the user is
+        joined to the room then return the current state. If the user has
+        left the room return the state events from when they left.
+
+        Args:
+            user_id(str): The user requesting state events.
+            room_id(str): The room ID to get all state events from.
+        Returns:
+            A list of dicts representing state events. [{}, {}, {}]
+        """
+        membership, membership_event_id = yield self._check_in_room_or_world_readable(
+            room_id, user_id
+        )
+
+        if membership == Membership.JOIN:
+            room_state = yield self.state_handler.get_current_state(room_id)
+        elif membership == Membership.LEAVE:
+            room_state = yield self.store.get_state_for_events(
+                [membership_event_id], None
+            )
+            room_state = room_state[membership_event_id]
+
+        now = self.clock.time_msec()
+        defer.returnValue(
+            [serialize_event(c, now) for c in room_state.values()]
+        )
+
+    @defer.inlineCallbacks
+    def get_joined_members(self, requester, room_id):
+        """Get all the joined members in the room and their profile information.
+
+        If the user has left the room return the state events from when they left.
+
+        Args:
+            requester(Requester): The user requesting state events.
+            room_id(str): The room ID to get all state events from.
+        Returns:
+            A dict of user_id to profile info
+        """
+        user_id = requester.user.to_string()
+        if not requester.app_service:
+            # We check AS auth after fetching the room membership, as it
+            # requires us to pull out all joined members anyway.
+            membership, _ = yield self._check_in_room_or_world_readable(
+                room_id, user_id
+            )
+            if membership != Membership.JOIN:
+                raise NotImplementedError(
+                    "Getting joined members after leaving is not implemented"
+                )
+
+        users_with_profile = yield self.state.get_current_user_in_room(room_id)
+
+        # If this is an AS, double check that they are allowed to see the members.
+        # This can either be because the AS user is in the room or becuase there
+        # is a user in the room that the AS is "interested in"
+        if requester.app_service and user_id not in users_with_profile:
+            for uid in users_with_profile:
+                if requester.app_service.is_interested_in_user(uid):
+                    break
+            else:
+                # Loop fell through, AS has no interested users in room
+                raise AuthError(403, "Appservice not in room")
+
+        defer.returnValue({
+            user_id: {
+                "avatar_url": profile.avatar_url,
+                "display_name": profile.display_name,
+            }
+            for user_id, profile in users_with_profile.iteritems()
+        })
+
+
+class EventCreationHandler(object):
+    def __init__(self, hs):
+        self.hs = hs
+        self.auth = hs.get_auth()
+        self.store = hs.get_datastore()
+        self.state = hs.get_state_handler()
+        self.clock = hs.get_clock()
+        self.validator = EventValidator()
+        self.profile_handler = hs.get_profile_handler()
+        self.event_builder_factory = hs.get_event_builder_factory()
+        self.server_name = hs.hostname
+        self.ratelimiter = hs.get_ratelimiter()
+        self.notifier = hs.get_notifier()
+
+        # This is only used to get at ratelimit function, and maybe_kick_guest_users
+        self.base_handler = BaseHandler(hs)
+
+        self.pusher_pool = hs.get_pusherpool()
+
+        # We arbitrarily limit concurrent event creation for a room to 5.
+        # This is to stop us from diverging history *too* much.
+        self.limiter = Limiter(max_count=5)
+
+        self.action_generator = hs.get_action_generator()
+
+        self.spam_checker = hs.get_spam_checker()
+
+    @defer.inlineCallbacks
     def create_event(self, requester, event_dict, token_id=None, txn_id=None,
                      prev_event_ids=None):
         """
@@ -234,7 +378,7 @@ class MessageHandler(BaseHandler):
             if txn_id is not None:
                 builder.internal_metadata.txn_id = txn_id
 
-            event, context = yield self._create_new_client_event(
+            event, context = yield self.create_new_client_event(
                 builder=builder,
                 requester=requester,
                 prev_event_ids=prev_event_ids,
@@ -259,11 +403,6 @@ class MessageHandler(BaseHandler):
                 "Tried to send member event through non-member codepath"
             )
 
-        # We check here if we are currently being rate limited, so that we
-        # don't do unnecessary work. We check again just before we actually
-        # send the event.
-        yield self.ratelimit(requester, update=False)
-
         user = UserID.from_string(event.sender)
 
         assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
@@ -342,137 +481,9 @@ class MessageHandler(BaseHandler):
         )
         defer.returnValue(event)
 
+    @measure_func("create_new_client_event")
     @defer.inlineCallbacks
-    def get_room_data(self, user_id=None, room_id=None,
-                      event_type=None, state_key="", is_guest=False):
-        """ Get data from a room.
-
-        Args:
-            event : The room path event
-        Returns:
-            The path data content.
-        Raises:
-            SynapseError if something went wrong.
-        """
-        membership, membership_event_id = yield self._check_in_room_or_world_readable(
-            room_id, user_id
-        )
-
-        if membership == Membership.JOIN:
-            data = yield self.state_handler.get_current_state(
-                room_id, event_type, state_key
-            )
-        elif membership == Membership.LEAVE:
-            key = (event_type, state_key)
-            room_state = yield self.store.get_state_for_events(
-                [membership_event_id], [key]
-            )
-            data = room_state[membership_event_id].get(key)
-
-        defer.returnValue(data)
-
-    @defer.inlineCallbacks
-    def _check_in_room_or_world_readable(self, room_id, user_id):
-        try:
-            # check_user_was_in_room will return the most recent membership
-            # event for the user if:
-            #  * The user is a non-guest user, and was ever in the room
-            #  * The user is a guest user, and has joined the room
-            # else it will throw.
-            member_event = yield self.auth.check_user_was_in_room(room_id, user_id)
-            defer.returnValue((member_event.membership, member_event.event_id))
-            return
-        except AuthError:
-            visibility = yield self.state_handler.get_current_state(
-                room_id, EventTypes.RoomHistoryVisibility, ""
-            )
-            if (
-                visibility and
-                visibility.content["history_visibility"] == "world_readable"
-            ):
-                defer.returnValue((Membership.JOIN, None))
-                return
-            raise AuthError(
-                403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN
-            )
-
-    @defer.inlineCallbacks
-    def get_state_events(self, user_id, room_id, is_guest=False):
-        """Retrieve all state events for a given room. If the user is
-        joined to the room then return the current state. If the user has
-        left the room return the state events from when they left.
-
-        Args:
-            user_id(str): The user requesting state events.
-            room_id(str): The room ID to get all state events from.
-        Returns:
-            A list of dicts representing state events. [{}, {}, {}]
-        """
-        membership, membership_event_id = yield self._check_in_room_or_world_readable(
-            room_id, user_id
-        )
-
-        if membership == Membership.JOIN:
-            room_state = yield self.state_handler.get_current_state(room_id)
-        elif membership == Membership.LEAVE:
-            room_state = yield self.store.get_state_for_events(
-                [membership_event_id], None
-            )
-            room_state = room_state[membership_event_id]
-
-        now = self.clock.time_msec()
-        defer.returnValue(
-            [serialize_event(c, now) for c in room_state.values()]
-        )
-
-    @defer.inlineCallbacks
-    def get_joined_members(self, requester, room_id):
-        """Get all the joined members in the room and their profile information.
-
-        If the user has left the room return the state events from when they left.
-
-        Args:
-            requester(Requester): The user requesting state events.
-            room_id(str): The room ID to get all state events from.
-        Returns:
-            A dict of user_id to profile info
-        """
-        user_id = requester.user.to_string()
-        if not requester.app_service:
-            # We check AS auth after fetching the room membership, as it
-            # requires us to pull out all joined members anyway.
-            membership, _ = yield self._check_in_room_or_world_readable(
-                room_id, user_id
-            )
-            if membership != Membership.JOIN:
-                raise NotImplementedError(
-                    "Getting joined members after leaving is not implemented"
-                )
-
-        users_with_profile = yield self.state.get_current_user_in_room(room_id)
-
-        # If this is an AS, double check that they are allowed to see the members.
-        # This can either be because the AS user is in the room or becuase there
-        # is a user in the room that the AS is "interested in"
-        if requester.app_service and user_id not in users_with_profile:
-            for uid in users_with_profile:
-                if requester.app_service.is_interested_in_user(uid):
-                    break
-            else:
-                # Loop fell through, AS has no interested users in room
-                raise AuthError(403, "Appservice not in room")
-
-        defer.returnValue({
-            user_id: {
-                "avatar_url": profile.avatar_url,
-                "display_name": profile.display_name,
-            }
-            for user_id, profile in users_with_profile.iteritems()
-        })
-
-    @measure_func("_create_new_client_event")
-    @defer.inlineCallbacks
-    def _create_new_client_event(self, builder, requester=None, prev_event_ids=None):
+    def create_new_client_event(self, builder, requester=None, prev_event_ids=None):
         if prev_event_ids:
             prev_events = yield self.store.add_event_hashes(prev_event_ids)
             prev_max_depth = yield self.store.get_max_depth_of_events(prev_event_ids)
@@ -509,9 +520,7 @@ class MessageHandler(BaseHandler):
         builder.prev_events = prev_events
         builder.depth = depth
 
-        state_handler = self.state_handler
-
-        context = yield state_handler.compute_event_context(builder)
+        context = yield self.state.compute_event_context(builder)
         if requester:
             context.app_service = requester.app_service
 
@@ -551,7 +560,7 @@ class MessageHandler(BaseHandler):
         # We now need to go and hit out to wherever we need to hit out to.
 
         if ratelimit:
-            yield self.ratelimit(requester)
+            yield self.base_handler.ratelimit(requester)
 
         try:
             yield self.auth.check_from_context(event, context)
@@ -567,7 +576,7 @@ class MessageHandler(BaseHandler):
             logger.exception("Failed to encode content: %r", event.content)
             raise
 
-        yield self.maybe_kick_guest_users(event, context)
+        yield self.base_handler.maybe_kick_guest_users(event, context)
 
         if event.type == EventTypes.CanonicalAlias:
             # Check the alias is acually valid (at this time at least)
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index d1cc87a016..6ab020bf41 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014 - 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -64,6 +65,7 @@ class RoomCreationHandler(BaseHandler):
         super(RoomCreationHandler, self).__init__(hs)
 
         self.spam_checker = hs.get_spam_checker()
+        self.event_creation_handler = hs.get_event_creation_handler()
 
     @defer.inlineCallbacks
     def create_room(self, requester, config, ratelimit=True):
@@ -163,13 +165,11 @@ class RoomCreationHandler(BaseHandler):
 
         creation_content = config.get("creation_content", {})
 
-        msg_handler = self.hs.get_handlers().message_handler
         room_member_handler = self.hs.get_handlers().room_member_handler
 
         yield self._send_events_for_new_room(
             requester,
             room_id,
-            msg_handler,
             room_member_handler,
             preset_config=preset_config,
             invite_list=invite_list,
@@ -181,7 +181,7 @@ class RoomCreationHandler(BaseHandler):
 
         if "name" in config:
             name = config["name"]
-            yield msg_handler.create_and_send_nonmember_event(
+            yield self.event_creation_handler.create_and_send_nonmember_event(
                 requester,
                 {
                     "type": EventTypes.Name,
@@ -194,7 +194,7 @@ class RoomCreationHandler(BaseHandler):
 
         if "topic" in config:
             topic = config["topic"]
-            yield msg_handler.create_and_send_nonmember_event(
+            yield self.event_creation_handler.create_and_send_nonmember_event(
                 requester,
                 {
                     "type": EventTypes.Topic,
@@ -249,7 +249,6 @@ class RoomCreationHandler(BaseHandler):
             self,
             creator,  # A Requester object.
             room_id,
-            msg_handler,
             room_member_handler,
             preset_config,
             invite_list,
@@ -272,7 +271,7 @@ class RoomCreationHandler(BaseHandler):
         @defer.inlineCallbacks
         def send(etype, content, **kwargs):
             event = create(etype, content, **kwargs)
-            yield msg_handler.create_and_send_nonmember_event(
+            yield self.event_creation_handler.create_and_send_nonmember_event(
                 creator,
                 event,
                 ratelimit=False
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 7e6467cd1d..37dc5e99ab 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -46,6 +47,7 @@ class RoomMemberHandler(BaseHandler):
         super(RoomMemberHandler, self).__init__(hs)
 
         self.profile_handler = hs.get_profile_handler()
+        self.event_creation_hander = hs.get_event_creation_handler()
 
         self.member_linearizer = Linearizer(name="member")
 
@@ -66,13 +68,12 @@ class RoomMemberHandler(BaseHandler):
     ):
         if content is None:
             content = {}
-        msg_handler = self.hs.get_handlers().message_handler
 
         content["membership"] = membership
         if requester.is_guest:
             content["kind"] = "guest"
 
-        event, context = yield msg_handler.create_event(
+        event, context = yield self.event_creation_hander.create_event(
             requester,
             {
                 "type": EventTypes.Member,
@@ -90,12 +91,14 @@ class RoomMemberHandler(BaseHandler):
         )
 
         # Check if this event matches the previous membership event for the user.
-        duplicate = yield msg_handler.deduplicate_state_event(event, context)
+        duplicate = yield self.event_creation_hander.deduplicate_state_event(
+            event, context,
+        )
         if duplicate is not None:
             # Discard the new event since this membership change is a no-op.
             defer.returnValue(duplicate)
 
-        yield msg_handler.handle_new_client_event(
+        yield self.event_creation_hander.handle_new_client_event(
             requester,
             event,
             context,
@@ -394,8 +397,9 @@ class RoomMemberHandler(BaseHandler):
         else:
             requester = synapse.types.create_requester(target_user)
 
-        message_handler = self.hs.get_handlers().message_handler
-        prev_event = yield message_handler.deduplicate_state_event(event, context)
+        prev_event = yield self.event_creation_hander.deduplicate_state_event(
+            event, context,
+        )
         if prev_event is not None:
             return
 
@@ -412,7 +416,7 @@ class RoomMemberHandler(BaseHandler):
             if is_blocked:
                 raise SynapseError(403, "This room has been blocked on this server")
 
-        yield message_handler.handle_new_client_event(
+        yield self.event_creation_hander.handle_new_client_event(
             requester,
             event,
             context,
@@ -644,8 +648,7 @@ class RoomMemberHandler(BaseHandler):
             )
         )
 
-        msg_handler = self.hs.get_handlers().message_handler
-        yield msg_handler.create_and_send_nonmember_event(
+        yield self.event_creation_hander.create_and_send_nonmember_event(
             requester,
             {
                 "type": EventTypes.ThirdPartyInvite,
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index 71420e54db..ef8e62901b 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -148,11 +148,13 @@ def parse_string_from_args(args, name, default=None, required=False,
             return default
 
 
-def parse_json_value_from_request(request):
+def parse_json_value_from_request(request, allow_empty_body=False):
     """Parse a JSON value from the body of a twisted HTTP request.
 
     Args:
         request: the twisted HTTP request.
+        allow_empty_body (bool): if True, an empty body will be accepted and
+            turned into None
 
     Returns:
         The JSON value.
@@ -165,6 +167,9 @@ def parse_json_value_from_request(request):
     except Exception:
         raise SynapseError(400, "Error reading JSON content.")
 
+    if not content_bytes and allow_empty_body:
+        return None
+
     try:
         content = simplejson.loads(content_bytes)
     except Exception as e:
@@ -174,17 +179,24 @@ def parse_json_value_from_request(request):
     return content
 
 
-def parse_json_object_from_request(request):
+def parse_json_object_from_request(request, allow_empty_body=False):
     """Parse a JSON object from the body of a twisted HTTP request.
 
     Args:
         request: the twisted HTTP request.
+        allow_empty_body (bool): if True, an empty body will be accepted and
+            turned into an empty dict.
 
     Raises:
         SynapseError if the request body couldn't be decoded as JSON or
             if it wasn't a JSON object.
     """
-    content = parse_json_value_from_request(request)
+    content = parse_json_value_from_request(
+        request, allow_empty_body=allow_empty_body,
+    )
+
+    if allow_empty_body and content is None:
+        return {}
 
     if type(content) != dict:
         message = "Content must be a JSON object."
diff --git a/synapse/metrics/metric.py b/synapse/metrics/metric.py
index 1e783e5ff4..ff5aa8c0e1 100644
--- a/synapse/metrics/metric.py
+++ b/synapse/metrics/metric.py
@@ -193,7 +193,9 @@ class DistributionMetric(object):
 
 
 class CacheMetric(object):
-    __slots__ = ("name", "cache_name", "hits", "misses", "size_callback")
+    __slots__ = (
+        "name", "cache_name", "hits", "misses", "evicted_size", "size_callback",
+    )
 
     def __init__(self, name, size_callback, cache_name):
         self.name = name
@@ -201,6 +203,7 @@ class CacheMetric(object):
 
         self.hits = 0
         self.misses = 0
+        self.evicted_size = 0
 
         self.size_callback = size_callback
 
@@ -210,6 +213,9 @@ class CacheMetric(object):
     def inc_misses(self):
         self.misses += 1
 
+    def inc_evictions(self, size=1):
+        self.evicted_size += size
+
     def render(self):
         size = self.size_callback()
         hits = self.hits
@@ -219,6 +225,9 @@ class CacheMetric(object):
             """%s:hits{name="%s"} %d""" % (self.name, self.cache_name, hits),
             """%s:total{name="%s"} %d""" % (self.name, self.cache_name, total),
             """%s:size{name="%s"} %d""" % (self.name, self.cache_name, size),
+            """%s:evicted_size{name="%s"} %d""" % (
+                self.name, self.cache_name, self.evicted_size
+            ),
         ]
 
 
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index 29d7296b43..8acb5df0f3 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -19,7 +19,7 @@ from synapse.storage import DataStore
 from synapse.storage.event_federation import EventFederationStore
 from synapse.storage.event_push_actions import EventPushActionsStore
 from synapse.storage.roommember import RoomMemberStore
-from synapse.storage.state import StateGroupReadStore
+from synapse.storage.state import StateGroupWorkerStore
 from synapse.storage.stream import StreamStore
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 from ._base import BaseSlavedStore
@@ -37,7 +37,7 @@ logger = logging.getLogger(__name__)
 # the method descriptor on the DataStore and chuck them into our class.
 
 
-class SlavedEventStore(StateGroupReadStore, BaseSlavedStore):
+class SlavedEventStore(StateGroupWorkerStore, BaseSlavedStore):
 
     def __init__(self, db_conn, hs):
         super(SlavedEventStore, self).__init__(db_conn, hs)
diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py
index 5022808ea9..2ad486c67d 100644
--- a/synapse/rest/client/v1/admin.py
+++ b/synapse/rest/client/v1/admin.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -128,7 +129,16 @@ class PurgeHistoryRestServlet(ClientV1RestServlet):
         if not is_admin:
             raise AuthError(403, "You are not a server admin")
 
-        yield self.handlers.message_handler.purge_history(room_id, event_id)
+        body = parse_json_object_from_request(request, allow_empty_body=True)
+
+        delete_local_events = bool(
+            body.get("delete_local_history", False)
+        )
+
+        yield self.handlers.message_handler.purge_history(
+            room_id, event_id,
+            delete_local_events=delete_local_events,
+        )
 
         defer.returnValue((200, {}))
 
@@ -171,6 +181,7 @@ class ShutdownRoomRestServlet(ClientV1RestServlet):
         self.store = hs.get_datastore()
         self.handlers = hs.get_handlers()
         self.state = hs.get_state_handler()
+        self.event_creation_handler = hs.get_event_creation_handler()
 
     @defer.inlineCallbacks
     def on_POST(self, request, room_id):
@@ -203,8 +214,7 @@ class ShutdownRoomRestServlet(ClientV1RestServlet):
         )
         new_room_id = info["room_id"]
 
-        msg_handler = self.handlers.message_handler
-        yield msg_handler.create_and_send_nonmember_event(
+        yield self.event_creation_handler.create_and_send_nonmember_event(
             room_creator_requester,
             {
                 "type": "m.room.message",
@@ -289,6 +299,27 @@ class QuarantineMediaInRoom(ClientV1RestServlet):
         defer.returnValue((200, {"num_quarantined": num_quarantined}))
 
 
+class ListMediaInRoom(ClientV1RestServlet):
+    """Lists all of the media in a given room.
+    """
+    PATTERNS = client_path_patterns("/admin/room/(?P<room_id>[^/]+)/media")
+
+    def __init__(self, hs):
+        super(ListMediaInRoom, self).__init__(hs)
+        self.store = hs.get_datastore()
+
+    @defer.inlineCallbacks
+    def on_GET(self, request, room_id):
+        requester = yield self.auth.get_user_by_req(request)
+        is_admin = yield self.auth.is_server_admin(requester.user)
+        if not is_admin:
+            raise AuthError(403, "You are not a server admin")
+
+        local_mxcs, remote_mxcs = yield self.store.get_media_mxcs_in_room(room_id)
+
+        defer.returnValue((200, {"local": local_mxcs, "remote": remote_mxcs}))
+
+
 class ResetPasswordRestServlet(ClientV1RestServlet):
     """Post request to allow an administrator reset password for a user.
     This needs user to have administrator access in Synapse.
@@ -487,3 +518,4 @@ def register_servlets(hs, http_server):
     SearchUsersRestServlet(hs).register(http_server)
     ShutdownRoomRestServlet(hs).register(http_server)
     QuarantineMediaInRoom(hs).register(http_server)
+    ListMediaInRoom(hs).register(http_server)
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 867ec8602c..fbb2fc36e4 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -82,6 +83,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
     def __init__(self, hs):
         super(RoomStateEventRestServlet, self).__init__(hs)
         self.handlers = hs.get_handlers()
+        self.event_creation_hander = hs.get_event_creation_handler()
 
     def register(self, http_server):
         # /room/$roomid/state/$eventtype
@@ -162,15 +164,16 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
                 content=content,
             )
         else:
-            msg_handler = self.handlers.message_handler
-            event, context = yield msg_handler.create_event(
+            event, context = yield self.event_creation_hander.create_event(
                 requester,
                 event_dict,
                 token_id=requester.access_token_id,
                 txn_id=txn_id,
             )
 
-            yield msg_handler.send_nonmember_event(requester, event, context)
+            yield self.event_creation_hander.send_nonmember_event(
+                requester, event, context,
+            )
 
         ret = {}
         if event:
@@ -184,6 +187,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
     def __init__(self, hs):
         super(RoomSendEventRestServlet, self).__init__(hs)
         self.handlers = hs.get_handlers()
+        self.event_creation_hander = hs.get_event_creation_handler()
 
     def register(self, http_server):
         # /rooms/$roomid/send/$event_type[/$txn_id]
@@ -205,8 +209,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
         if 'ts' in request.args and requester.app_service:
             event_dict['origin_server_ts'] = parse_integer(request, "ts", 0)
 
-        msg_handler = self.handlers.message_handler
-        event = yield msg_handler.create_and_send_nonmember_event(
+        event = yield self.event_creation_hander.create_and_send_nonmember_event(
             requester,
             event_dict,
             txn_id=txn_id,
@@ -670,6 +673,7 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
     def __init__(self, hs):
         super(RoomRedactEventRestServlet, self).__init__(hs)
         self.handlers = hs.get_handlers()
+        self.event_creation_handler = hs.get_event_creation_handler()
 
     def register(self, http_server):
         PATTERNS = ("/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)")
@@ -680,8 +684,7 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
         requester = yield self.auth.get_user_by_req(request)
         content = parse_json_object_from_request(request)
 
-        msg_handler = self.handlers.message_handler
-        event = yield msg_handler.create_and_send_nonmember_event(
+        event = yield self.event_creation_handler.create_and_send_nonmember_event(
             requester,
             {
                 "type": EventTypes.Redaction,
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 485db8577a..bb79599379 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -472,8 +472,10 @@ class MediaRepository(object):
 
     @defer.inlineCallbacks
     def generate_local_exact_thumbnail(self, media_id, t_width, t_height,
-                                       t_method, t_type):
-        input_path = self.filepaths.local_media_filepath(media_id)
+                                       t_method, t_type, url_cache):
+        input_path = yield self.media_storage.ensure_media_is_in_local_cache(FileInfo(
+            None, media_id, url_cache=url_cache,
+        ))
 
         thumbnailer = Thumbnailer(input_path)
         t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
@@ -486,6 +488,7 @@ class MediaRepository(object):
                 file_info = FileInfo(
                     server_name=None,
                     file_id=media_id,
+                    url_cache=url_cache,
                     thumbnail=True,
                     thumbnail_width=t_width,
                     thumbnail_height=t_height,
@@ -512,7 +515,9 @@ class MediaRepository(object):
     @defer.inlineCallbacks
     def generate_remote_exact_thumbnail(self, server_name, file_id, media_id,
                                         t_width, t_height, t_method, t_type):
-        input_path = self.filepaths.remote_media_filepath(server_name, file_id)
+        input_path = yield self.media_storage.ensure_media_is_in_local_cache(FileInfo(
+            server_name, file_id, url_cache=False,
+        ))
 
         thumbnailer = Thumbnailer(input_path)
         t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
@@ -570,12 +575,9 @@ class MediaRepository(object):
         if not requirements:
             return
 
-        if server_name:
-            input_path = self.filepaths.remote_media_filepath(server_name, file_id)
-        elif url_cache:
-            input_path = self.filepaths.url_cache_filepath(media_id)
-        else:
-            input_path = self.filepaths.local_media_filepath(media_id)
+        input_path = yield self.media_storage.ensure_media_is_in_local_cache(FileInfo(
+            server_name, file_id, url_cache=url_cache,
+        ))
 
         thumbnailer = Thumbnailer(input_path)
         m_width = thumbnailer.width
diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
index 041ae396cd..3f8d4b9c22 100644
--- a/synapse/rest/media/v1/media_storage.py
+++ b/synapse/rest/media/v1/media_storage.py
@@ -18,6 +18,7 @@ from twisted.protocols.basic import FileSender
 
 from ._base import Responder
 
+from synapse.util.file_consumer import BackgroundFileConsumer
 from synapse.util.logcontext import make_deferred_yieldable
 
 import contextlib
@@ -26,6 +27,7 @@ import logging
 import shutil
 import sys
 
+
 logger = logging.getLogger(__name__)
 
 
@@ -68,6 +70,12 @@ class MediaStorage(object):
             _write_file_synchronously, source, fname,
         ))
 
+        # Tell the storage providers about the new file. They'll decide
+        # if they should upload it and whether to do so synchronously
+        # or not.
+        for provider in self.storage_providers:
+            yield provider.store_file(path, file_info)
+
         defer.returnValue(fname)
 
     @contextlib.contextmanager
@@ -151,6 +159,37 @@ class MediaStorage(object):
 
         defer.returnValue(None)
 
+    @defer.inlineCallbacks
+    def ensure_media_is_in_local_cache(self, file_info):
+        """Ensures that the given file is in the local cache. Attempts to
+        download it from storage providers if it isn't.
+
+        Args:
+            file_info (FileInfo)
+
+        Returns:
+            Deferred[str]: Full path to local file
+        """
+        path = self._file_info_to_path(file_info)
+        local_path = os.path.join(self.local_media_directory, path)
+        if os.path.exists(local_path):
+            defer.returnValue(local_path)
+
+        dirname = os.path.dirname(local_path)
+        if not os.path.exists(dirname):
+            os.makedirs(dirname)
+
+        for provider in self.storage_providers:
+            res = yield provider.fetch(path, file_info)
+            if res:
+                with res:
+                    consumer = BackgroundFileConsumer(open(local_path, "w"))
+                    yield res.write_to_consumer(consumer)
+                    yield consumer.wait()
+                defer.returnValue(local_path)
+
+        raise Exception("file could not be found")
+
     def _file_info_to_path(self, file_info):
         """Converts file_info into a relative path.
 
@@ -228,9 +267,8 @@ class FileResponder(Responder):
     def __init__(self, open_file):
         self.open_file = open_file
 
-    @defer.inlineCallbacks
     def write_to_consumer(self, consumer):
-        yield FileSender().beginFileTransfer(self.open_file, consumer)
+        return FileSender().beginFileTransfer(self.open_file, consumer)
 
     def __exit__(self, exc_type, exc_val, exc_tb):
         self.open_file.close()
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 981f01e417..31fe7aa75c 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -12,6 +12,19 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import cgi
+import datetime
+import errno
+import fnmatch
+import itertools
+import logging
+import os
+import re
+import shutil
+import sys
+import traceback
+import ujson as json
+import urlparse
 
 from twisted.web.server import NOT_DONE_YET
 from twisted.internet import defer
@@ -33,18 +46,6 @@ from synapse.http.server import (
 from synapse.util.async import ObservableDeferred
 from synapse.util.stringutils import is_ascii
 
-import os
-import re
-import fnmatch
-import cgi
-import ujson as json
-import urlparse
-import itertools
-import datetime
-import errno
-import shutil
-
-import logging
 logger = logging.getLogger(__name__)
 
 
@@ -286,17 +287,28 @@ class PreviewUrlResource(Resource):
             url_cache=True,
         )
 
-        try:
-            with self.media_storage.store_into_file(file_info) as (f, fname, finish):
+        with self.media_storage.store_into_file(file_info) as (f, fname, finish):
+            try:
                 logger.debug("Trying to get url '%s'" % url)
                 length, headers, uri, code = yield self.client.get_file(
                     url, output_stream=f, max_size=self.max_spider_size,
                 )
+            except Exception as e:
                 # FIXME: pass through 404s and other error messages nicely
+                logger.warn("Error downloading %s: %r", url, e)
+                raise SynapseError(
+                    500, "Failed to download content: %s" % (
+                        traceback.format_exception_only(sys.exc_type, e),
+                    ),
+                    Codes.UNKNOWN,
+                )
+            yield finish()
 
-                yield finish()
-
-            media_type = headers["Content-Type"][0]
+        try:
+            if "Content-Type" in headers:
+                media_type = headers["Content-Type"][0]
+            else:
+                media_type = "application/octet-stream"
             time_now_ms = self.clock.time_msec()
 
             content_disposition = headers.get("Content-Disposition", None)
@@ -336,10 +348,11 @@ class PreviewUrlResource(Resource):
             )
 
         except Exception as e:
-            raise SynapseError(
-                500, ("Failed to download content: %s" % e),
-                Codes.UNKNOWN
-            )
+            logger.error("Error handling downloaded %s: %r", url, e)
+            # TODO: we really ought to delete the downloaded file in this
+            # case, since we won't have recorded it in the db, and will
+            # therefore not expire it.
+            raise
 
         defer.returnValue({
             "media_type": media_type,
diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py
index 12e84a2b7c..58ada49711 100644
--- a/synapse/rest/media/v1/thumbnail_resource.py
+++ b/synapse/rest/media/v1/thumbnail_resource.py
@@ -164,7 +164,8 @@ class ThumbnailResource(Resource):
 
         # Okay, so we generate one.
         file_path = yield self.media_repo.generate_local_exact_thumbnail(
-            media_id, desired_width, desired_height, desired_method, desired_type
+            media_id, desired_width, desired_height, desired_method, desired_type,
+            url_cache=media_info["url_cache"],
         )
 
         if file_path:
diff --git a/synapse/server.py b/synapse/server.py
index ff8a8fbc46..fbd602d40e 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -55,6 +55,7 @@ from synapse.handlers.read_marker import ReadMarkerHandler
 from synapse.handlers.user_directory import UserDirectoryHandler
 from synapse.handlers.groups_local import GroupsLocalHandler
 from synapse.handlers.profile import ProfileHandler
+from synapse.handlers.message import EventCreationHandler
 from synapse.groups.groups_server import GroupsServerHandler
 from synapse.groups.attestations import GroupAttestionRenewer, GroupAttestationSigning
 from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory
@@ -66,7 +67,7 @@ from synapse.rest.media.v1.media_repository import (
     MediaRepository,
     MediaRepositoryResource,
 )
-from synapse.state import StateHandler
+from synapse.state import StateHandler, StateResolutionHandler
 from synapse.storage import DataStore
 from synapse.streams.events import EventSources
 from synapse.util import Clock
@@ -102,6 +103,7 @@ class HomeServer(object):
         'v1auth',
         'auth',
         'state_handler',
+        'state_resolution_handler',
         'presence_handler',
         'sync_handler',
         'typing_handler',
@@ -117,6 +119,7 @@ class HomeServer(object):
         'application_service_handler',
         'device_message_handler',
         'profile_handler',
+        'event_creation_handler',
         'deactivate_account_handler',
         'set_password_handler',
         'notifier',
@@ -224,6 +227,9 @@ class HomeServer(object):
     def build_state_handler(self):
         return StateHandler(self)
 
+    def build_state_resolution_handler(self):
+        return StateResolutionHandler(self)
+
     def build_presence_handler(self):
         return PresenceHandler(self)
 
@@ -272,6 +278,9 @@ class HomeServer(object):
     def build_profile_handler(self):
         return ProfileHandler(self)
 
+    def build_event_creation_handler(self):
+        return EventCreationHandler(self)
+
     def build_deactivate_account_handler(self):
         return DeactivateAccountHandler(self)
 
diff --git a/synapse/server.pyi b/synapse/server.pyi
index 41416ef252..c3a9a3847b 100644
--- a/synapse/server.pyi
+++ b/synapse/server.pyi
@@ -34,6 +34,9 @@ class HomeServer(object):
     def get_state_handler(self) -> synapse.state.StateHandler:
         pass
 
+    def get_state_resolution_handler(self) -> synapse.state.StateResolutionHandler:
+        pass
+
     def get_deactivate_account_handler(self) -> synapse.handlers.deactivate_account.DeactivateAccountHandler:
         pass
 
diff --git a/synapse/state.py b/synapse/state.py
index 4c8247e7c2..cc93bbcb6b 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -58,7 +58,11 @@ class _StateCacheEntry(object):
     __slots__ = ["state", "state_group", "state_id", "prev_group", "delta_ids"]
 
     def __init__(self, state, state_group, prev_group=None, delta_ids=None):
+        # dict[(str, str), str] map  from (type, state_key) to event_id
         self.state = frozendict(state)
+
+        # the ID of a state group if one and only one is involved.
+        # otherwise, None otherwise?
         self.state_group = state_group
 
         self.prev_group = prev_group
@@ -81,31 +85,19 @@ class _StateCacheEntry(object):
 
 
 class StateHandler(object):
-    """ Responsible for doing state conflict resolution.
+    """Fetches bits of state from the stores, and does state resolution
+    where necessary
     """
 
     def __init__(self, hs):
         self.clock = hs.get_clock()
         self.store = hs.get_datastore()
         self.hs = hs
-
-        # dict of set of event_ids -> _StateCacheEntry.
-        self._state_cache = None
-        self.resolve_linearizer = Linearizer(name="state_resolve_lock")
+        self._state_resolution_handler = hs.get_state_resolution_handler()
 
     def start_caching(self):
-        logger.debug("start_caching")
-
-        self._state_cache = ExpiringCache(
-            cache_name="state_cache",
-            clock=self.clock,
-            max_len=SIZE_OF_CACHE,
-            expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000,
-            iterable=True,
-            reset_expiry_on_get=True,
-        )
-
-        self._state_cache.start()
+        # TODO: remove this shim
+        self._state_resolution_handler.start_caching()
 
     @defer.inlineCallbacks
     def get_current_state(self, room_id, event_type=None, state_key="",
@@ -127,7 +119,7 @@ class StateHandler(object):
             latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
 
         logger.debug("calling resolve_state_groups from get_current_state")
-        ret = yield self.resolve_state_groups(room_id, latest_event_ids)
+        ret = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
         state = ret.state
 
         if event_type:
@@ -164,7 +156,7 @@ class StateHandler(object):
             latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
 
         logger.debug("calling resolve_state_groups from get_current_state_ids")
-        ret = yield self.resolve_state_groups(room_id, latest_event_ids)
+        ret = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
         state = ret.state
 
         defer.returnValue(state)
@@ -174,7 +166,7 @@ class StateHandler(object):
         if not latest_event_ids:
             latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
         logger.debug("calling resolve_state_groups from get_current_user_in_room")
-        entry = yield self.resolve_state_groups(room_id, latest_event_ids)
+        entry = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
         joined_users = yield self.store.get_joined_users_from_state(room_id, entry)
         defer.returnValue(joined_users)
 
@@ -183,7 +175,7 @@ class StateHandler(object):
         if not latest_event_ids:
             latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
         logger.debug("calling resolve_state_groups from get_current_hosts_in_room")
-        entry = yield self.resolve_state_groups(room_id, latest_event_ids)
+        entry = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
         joined_hosts = yield self.store.get_joined_hosts(room_id, entry)
         defer.returnValue(joined_hosts)
 
@@ -191,8 +183,15 @@ class StateHandler(object):
     def compute_event_context(self, event, old_state=None):
         """Build an EventContext structure for the event.
 
+        This works out what the current state should be for the event, and
+        generates a new state group if necessary.
+
         Args:
             event (synapse.events.EventBase):
+            old_state (dict|None): The state at the event if it can't be
+                calculated from existing events. This is normally only specified
+                when receiving an event from federation where we don't have the
+                prev events for, e.g. when backfilling.
         Returns:
             synapse.events.snapshot.EventContext:
         """
@@ -216,15 +215,22 @@ class StateHandler(object):
                 context.current_state_ids = {}
                 context.prev_state_ids = {}
             context.prev_state_events = []
-            context.state_group = self.store.get_next_state_group()
+
+            # We don't store state for outliers, so we don't generate a state
+            # froup for it.
+            context.state_group = None
+
             defer.returnValue(context)
 
         if old_state:
+            # We already have the state, so we don't need to calculate it.
+            # Let's just correctly fill out the context and create a
+            # new state group for it.
+
             context = EventContext()
             context.prev_state_ids = {
                 (s.type, s.state_key): s.event_id for s in old_state
             }
-            context.state_group = self.store.get_next_state_group()
 
             if event.is_state():
                 key = (event.type, event.state_key)
@@ -237,11 +243,19 @@ class StateHandler(object):
             else:
                 context.current_state_ids = context.prev_state_ids
 
+            context.state_group = yield self.store.store_state_group(
+                event.event_id,
+                event.room_id,
+                prev_group=None,
+                delta_ids=None,
+                current_state_ids=context.current_state_ids,
+            )
+
             context.prev_state_events = []
             defer.returnValue(context)
 
         logger.debug("calling resolve_state_groups from compute_event_context")
-        entry = yield self.resolve_state_groups(
+        entry = yield self.resolve_state_groups_for_events(
             event.room_id, [e for e, _ in event.prev_events],
         )
 
@@ -250,7 +264,8 @@ class StateHandler(object):
         context = EventContext()
         context.prev_state_ids = curr_state
         if event.is_state():
-            context.state_group = self.store.get_next_state_group()
+            # If this is a state event then we need to create a new state
+            # group for the state after this event.
 
             key = (event.type, event.state_key)
             if key in context.prev_state_ids:
@@ -261,38 +276,57 @@ class StateHandler(object):
             context.current_state_ids[key] = event.event_id
 
             if entry.state_group:
+                # If the state at the event has a state group assigned then
+                # we can use that as the prev group
                 context.prev_group = entry.state_group
                 context.delta_ids = {
                     key: event.event_id
                 }
             elif entry.prev_group:
+                # If the state at the event only has a prev group, then we can
+                # use that as a prev group too.
                 context.prev_group = entry.prev_group
                 context.delta_ids = dict(entry.delta_ids)
                 context.delta_ids[key] = event.event_id
+
+            context.state_group = yield self.store.store_state_group(
+                event.event_id,
+                event.room_id,
+                prev_group=context.prev_group,
+                delta_ids=context.delta_ids,
+                current_state_ids=context.current_state_ids,
+            )
         else:
+            context.current_state_ids = context.prev_state_ids
+            context.prev_group = entry.prev_group
+            context.delta_ids = entry.delta_ids
+
             if entry.state_group is None:
-                entry.state_group = self.store.get_next_state_group()
+                entry.state_group = yield self.store.store_state_group(
+                    event.event_id,
+                    event.room_id,
+                    prev_group=entry.prev_group,
+                    delta_ids=entry.delta_ids,
+                    current_state_ids=context.current_state_ids,
+                )
                 entry.state_id = entry.state_group
 
             context.state_group = entry.state_group
-            context.current_state_ids = context.prev_state_ids
-            context.prev_group = entry.prev_group
-            context.delta_ids = entry.delta_ids
 
         context.prev_state_events = []
         defer.returnValue(context)
 
     @defer.inlineCallbacks
-    @log_function
-    def resolve_state_groups(self, room_id, event_ids):
+    def resolve_state_groups_for_events(self, room_id, event_ids):
         """ Given a list of event_ids this method fetches the state at each
         event, resolves conflicts between them and returns them.
 
+        Args:
+            room_id (str):
+            event_ids (list[str]):
+
         Returns:
-            a Deferred tuple of (`state_group`, `state`, `prev_state`).
-            `state_group` is the name of a state group if one and only one is
-            involved. `state` is a map from (type, state_key) to event, and
-            `prev_state` is a list of event ids.
+            Deferred[_StateCacheEntry]: resolved state
         """
         logger.debug("resolve_state_groups event_ids %s", event_ids)
 
@@ -303,13 +337,7 @@ class StateHandler(object):
             room_id, event_ids
         )
 
-        logger.debug(
-            "resolve_state_groups state_groups %s",
-            state_groups_ids.keys()
-        )
-
-        group_names = frozenset(state_groups_ids.keys())
-        if len(group_names) == 1:
+        if len(state_groups_ids) == 1:
             name, state_list = state_groups_ids.items().pop()
 
             prev_group, delta_ids = yield self.store.get_state_group_delta(name)
@@ -321,6 +349,92 @@ class StateHandler(object):
                 delta_ids=delta_ids,
             ))
 
+        result = yield self._state_resolution_handler.resolve_state_groups(
+            room_id, state_groups_ids, self._state_map_factory,
+        )
+        defer.returnValue(result)
+
+    def _state_map_factory(self, ev_ids):
+        return self.store.get_events(
+            ev_ids, get_prev_content=False, check_redacted=False,
+        )
+
+    def resolve_events(self, state_sets, event):
+        logger.info(
+            "Resolving state for %s with %d groups", event.room_id, len(state_sets)
+        )
+        state_set_ids = [{
+            (ev.type, ev.state_key): ev.event_id
+            for ev in st
+        } for st in state_sets]
+
+        state_map = {
+            ev.event_id: ev
+            for st in state_sets
+            for ev in st
+        }
+
+        with Measure(self.clock, "state._resolve_events"):
+            new_state = resolve_events_with_state_map(state_set_ids, state_map)
+
+        new_state = {
+            key: state_map[ev_id] for key, ev_id in new_state.items()
+        }
+
+        return new_state
+
+
+class StateResolutionHandler(object):
+    """Responsible for doing state conflict resolution.
+
+    Note that the storage layer depends on this handler, so all functions must
+    be storage-independent.
+    """
+    def __init__(self, hs):
+        self.clock = hs.get_clock()
+
+        # dict of set of event_ids -> _StateCacheEntry.
+        self._state_cache = None
+        self.resolve_linearizer = Linearizer(name="state_resolve_lock")
+
+    def start_caching(self):
+        logger.debug("start_caching")
+
+        self._state_cache = ExpiringCache(
+            cache_name="state_cache",
+            clock=self.clock,
+            max_len=SIZE_OF_CACHE,
+            expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000,
+            iterable=True,
+            reset_expiry_on_get=True,
+        )
+
+        self._state_cache.start()
+
+    @defer.inlineCallbacks
+    @log_function
+    def resolve_state_groups(self, room_id, state_groups_ids, state_map_factory):
+        """Resolves conflicts between a set of state groups
+
+        Always generates a new state group (unless we hit the cache), so should
+        not be called for a single state group
+
+        Args:
+            room_id (str): room we are resolving for (used for logging)
+            state_groups_ids (dict[int, dict[(str, str), str]]):
+                 map from state group id to the state in that state group
+                (where 'state' is a map from state key to event id)
+
+        Returns:
+            Deferred[_StateCacheEntry]: resolved state
+        """
+        logger.debug(
+            "resolve_state_groups state_groups %s",
+            state_groups_ids.keys()
+        )
+
+        group_names = frozenset(state_groups_ids.keys())
+
         with (yield self.resolve_linearizer.queue(group_names)):
             if self._state_cache is not None:
                 cache = self._state_cache.get(group_names, None)
@@ -351,15 +465,17 @@ class StateHandler(object):
                 with Measure(self.clock, "state._resolve_events"):
                     new_state = yield resolve_events_with_factory(
                         state_groups_ids.values(),
-                        state_map_factory=lambda ev_ids: self.store.get_events(
-                            ev_ids, get_prev_content=False, check_redacted=False,
-                        ),
+                        state_map_factory=state_map_factory,
                     )
             else:
                 new_state = {
                     key: e_ids.pop() for key, e_ids in state.items()
                 }
 
+            # if the new state matches any of the input state groups, we can
+            # use that state group again. Otherwise we will generate a state_id
+            # which will be used as a cache key for future resolutions, but
+            # not get persisted.
             state_group = None
             new_state_event_ids = frozenset(new_state.values())
             for sg, events in state_groups_ids.items():
@@ -396,30 +512,6 @@ class StateHandler(object):
 
             defer.returnValue(cache)
 
-    def resolve_events(self, state_sets, event):
-        logger.info(
-            "Resolving state for %s with %d groups", event.room_id, len(state_sets)
-        )
-        state_set_ids = [{
-            (ev.type, ev.state_key): ev.event_id
-            for ev in st
-        } for st in state_sets]
-
-        state_map = {
-            ev.event_id: ev
-            for st in state_sets
-            for ev in st
-        }
-
-        with Measure(self.clock, "state._resolve_events"):
-            new_state = resolve_events_with_state_map(state_set_ids, state_map)
-
-        new_state = {
-            key: state_map[ev_id] for key, ev_id in new_state.items()
-        }
-
-        return new_state
-
 
 def _ordered_events(events):
     def key_func(e):
@@ -437,8 +529,8 @@ def resolve_events_with_state_map(state_sets, state_map):
             state_sets.
 
     Returns
-        dict[(str, str), synapse.events.FrozenEvent]:
-            a map from (type, state_key) to event.
+        dict[(str, str), str]:
+            a map from (type, state_key) to event_id.
     """
     if len(state_sets) == 1:
         return state_sets[0]
@@ -460,6 +552,21 @@ def _seperate(state_sets):
     """Takes the state_sets and figures out which keys are conflicted and
     which aren't. i.e., which have multiple different event_ids associated
     with them in different state sets.
+
+    Args:
+        state_sets(list[dict[(str, str), str]]):
+            List of dicts of (type, state_key) -> event_id, which are the
+            different state groups to resolve.
+
+    Returns:
+        (dict[(str, str), str], dict[(str, str), set[str]]):
+            A tuple of (unconflicted_state, conflicted_state), where:
+
+            unconflicted_state is a dict mapping (type, state_key)->event_id
+            for unconflicted state keys.
+
+            conflicted_state is a dict mapping (type, state_key) to a set of
+            event ids for conflicted state keys.
     """
     unconflicted_state = dict(state_sets[0])
     conflicted_state = {}
@@ -500,8 +607,8 @@ def resolve_events_with_factory(state_sets, state_map_factory):
             a Deferred of dict of event_id to event.
 
     Returns
-        Deferred[dict[(str, str), synapse.events.FrozenEvent]]:
-            a map from (type, state_key) to event.
+        Deferred[dict[(str, str), str]]:
+            a map from (type, state_key) to event_id.
     """
     if len(state_sets) == 1:
         defer.returnValue(state_sets[0])
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index d01d46338a..f8fbd02ceb 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -124,7 +124,6 @@ class DataStore(RoomMemberStore, RoomStore,
         )
 
         self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id")
-        self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id")
         self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
         self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
         self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index a6ae79dfad..8a0386c1a4 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -62,3 +62,9 @@ class PostgresEngine(object):
 
     def lock_table(self, txn, table):
         txn.execute("LOCK TABLE %s in EXCLUSIVE MODE" % (table,))
+
+    def get_next_state_group_id(self, txn):
+        """Returns an int that can be used as a new state_group ID
+        """
+        txn.execute("SELECT nextval('state_group_id_seq')")
+        return txn.fetchone()[0]
diff --git a/synapse/storage/engines/sqlite3.py b/synapse/storage/engines/sqlite3.py
index 755c9a1f07..60f0fa7fb3 100644
--- a/synapse/storage/engines/sqlite3.py
+++ b/synapse/storage/engines/sqlite3.py
@@ -16,6 +16,7 @@
 from synapse.storage.prepare_database import prepare_database
 
 import struct
+import threading
 
 
 class Sqlite3Engine(object):
@@ -24,6 +25,11 @@ class Sqlite3Engine(object):
     def __init__(self, database_module, database_config):
         self.module = database_module
 
+        # The current max state_group, or None if we haven't looked
+        # in the DB yet.
+        self._current_state_group_id = None
+        self._current_state_group_id_lock = threading.Lock()
+
     def check_database(self, txn):
         pass
 
@@ -43,6 +49,19 @@ class Sqlite3Engine(object):
     def lock_table(self, txn, table):
         return
 
+    def get_next_state_group_id(self, txn):
+        """Returns an int that can be used as a new state_group ID
+        """
+        # We do application locking here since if we're using sqlite then
+        # we are a single process synapse.
+        with self._current_state_group_id_lock:
+            if self._current_state_group_id is None:
+                txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups")
+                self._current_state_group_id = txn.fetchone()[0]
+
+            self._current_state_group_id += 1
+            return self._current_state_group_id
+
 
 # Following functions taken from: https://github.com/coleifer/peewee
 
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 33fccfa7a8..86a7c5920d 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -342,8 +342,20 @@ class EventsStore(SQLBaseStore):
 
                 # NB: Assumes that we are only persisting events for one room
                 # at a time.
+
+                # map room_id->list[event_ids] giving the new forward
+                # extremities in each room
                 new_forward_extremeties = {}
+
+                # map room_id->(type,state_key)->event_id tracking the full
+                # state in each room after adding these events
                 current_state_for_room = {}
+
+                # map room_id->(to_delete, to_insert) where each entry is
+                # a map (type,key)->event_id giving the state delta in each
+                # room
+                state_delta_for_room = {}
+
                 if not backfilled:
                     with Measure(self._clock, "_calculate_state_and_extrem"):
                         # Work out the new "current state" for each room.
@@ -386,11 +398,19 @@ class EventsStore(SQLBaseStore):
                                 if all_single_prev_not_state:
                                     continue
 
-                            state = yield self._calculate_state_delta(
-                                room_id, ev_ctx_rm, new_latest_event_ids
+                            logger.info(
+                                "Calculating state delta for room %s", room_id,
+                            )
+                            current_state = yield self._get_new_state_after_events(
+                                ev_ctx_rm, new_latest_event_ids,
                             )
-                            if state:
-                                current_state_for_room[room_id] = state
+                            if current_state is not None:
+                                current_state_for_room[room_id] = current_state
+                                delta = yield self._calculate_state_delta(
+                                    room_id, current_state,
+                                )
+                                if delta is not None:
+                                    state_delta_for_room[room_id] = delta
 
                 yield self.runInteraction(
                     "persist_events",
@@ -398,7 +418,7 @@ class EventsStore(SQLBaseStore):
                     events_and_contexts=chunk,
                     backfilled=backfilled,
                     delete_existing=delete_existing,
-                    current_state_for_room=current_state_for_room,
+                    state_delta_for_room=state_delta_for_room,
                     new_forward_extremeties=new_forward_extremeties,
                 )
                 persist_event_counter.inc_by(len(chunk))
@@ -415,7 +435,7 @@ class EventsStore(SQLBaseStore):
 
                     event_counter.inc(event.type, origin_type, origin_entity)
 
-                for room_id, (_, _, new_state) in current_state_for_room.iteritems():
+                for room_id, new_state in current_state_for_room.iteritems():
                     self.get_current_state_ids.prefill(
                         (room_id, ), new_state
                     )
@@ -467,20 +487,22 @@ class EventsStore(SQLBaseStore):
         defer.returnValue(new_latest_event_ids)
 
     @defer.inlineCallbacks
-    def _calculate_state_delta(self, room_id, events_context, new_latest_event_ids):
-        """Calculate the new state deltas for a room.
+    def _get_new_state_after_events(self, events_context, new_latest_event_ids):
+        """Calculate the current state dict after adding some new events to
+        a room
 
-        Assumes that we are only persisting events for one room at a time.
+        Args:
+            events_context (list[(EventBase, EventContext)]):
+                events and contexts which are being added to the room
+
+            new_latest_event_ids (iterable[str]):
+                the new forward extremities for the room.
 
         Returns:
-            3-tuple (to_delete, to_insert, new_state) where both are state dicts,
-            i.e. (type, state_key) -> event_id. `to_delete` are the entries to
-            first be deleted from current_state_events, `to_insert` are entries
-            to insert. `new_state` is the full set of state.
-            May return None if there are no changes to be applied.
+            Deferred[dict[(str,str), str]|None]:
+                None if there are no changes to the room state, or
+                a dict of (type, state_key) -> event_id].
         """
-        # Now we need to work out the different state sets for
-        # each state extremities
         state_sets = []
         state_groups = set()
         missing_event_ids = []
@@ -523,12 +545,12 @@ class EventsStore(SQLBaseStore):
                 state_sets.extend(group_to_state.itervalues())
 
         if not new_latest_event_ids:
-            current_state = {}
+            defer.returnValue({})
         elif was_updated:
             if len(state_sets) == 1:
                 # If there is only one state set, then we know what the current
                 # state is.
-                current_state = state_sets[0]
+                defer.returnValue(state_sets[0])
             else:
                 # We work out the current state by passing the state sets to the
                 # state resolution algorithm. It may ask for some events, including
@@ -537,8 +559,7 @@ class EventsStore(SQLBaseStore):
                 # up in the db.
 
                 logger.info(
-                    "Resolving state for %s with %i state sets",
-                    room_id, len(state_sets),
+                    "Resolving state with %i state sets", len(state_sets),
                 )
 
                 events_map = {ev.event_id: ev for ev, _ in events_context}
@@ -567,9 +588,22 @@ class EventsStore(SQLBaseStore):
                     state_sets,
                     state_map_factory=get_events,
                 )
+                defer.returnValue(current_state)
         else:
             return
 
+    @defer.inlineCallbacks
+    def _calculate_state_delta(self, room_id, current_state):
+        """Calculate the new state deltas for a room.
+
+        Assumes that we are only persisting events for one room at a time.
+
+        Returns:
+            2-tuple (to_delete, to_insert) where both are state dicts,
+            i.e. (type, state_key) -> event_id. `to_delete` are the entries to
+            first be deleted from current_state_events, `to_insert` are entries
+            to insert.
+        """
         existing_state = yield self.get_current_state_ids(room_id)
 
         existing_events = set(existing_state.itervalues())
@@ -589,7 +623,7 @@ class EventsStore(SQLBaseStore):
             if ev_id in events_to_insert
         }
 
-        defer.returnValue((to_delete, to_insert, current_state))
+        defer.returnValue((to_delete, to_insert))
 
     @defer.inlineCallbacks
     def get_event(self, event_id, check_redacted=True,
@@ -649,7 +683,7 @@ class EventsStore(SQLBaseStore):
 
     @log_function
     def _persist_events_txn(self, txn, events_and_contexts, backfilled,
-                            delete_existing=False, current_state_for_room={},
+                            delete_existing=False, state_delta_for_room={},
                             new_forward_extremeties={}):
         """Insert some number of room events into the necessary database tables.
 
@@ -665,7 +699,7 @@ class EventsStore(SQLBaseStore):
             delete_existing (bool): True to purge existing table rows for the
                 events from the database. This is useful when retrying due to
                 IntegrityError.
-            current_state_for_room (dict[str, (list[str], list[str])]):
+            state_delta_for_room (dict[str, (list[str], list[str])]):
                 The current-state delta for each room. For each room, a tuple
                 (to_delete, to_insert), being a list of event ids to be removed
                 from the current state, and a list of event ids to be added to
@@ -677,7 +711,7 @@ class EventsStore(SQLBaseStore):
         """
         max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering
 
-        self._update_current_state_txn(txn, current_state_for_room, max_stream_order)
+        self._update_current_state_txn(txn, state_delta_for_room, max_stream_order)
 
         self._update_forward_extremities_txn(
             txn,
@@ -721,9 +755,8 @@ class EventsStore(SQLBaseStore):
             events_and_contexts=events_and_contexts,
         )
 
-        # Insert into the state_groups, state_groups_state, and
-        # event_to_state_groups tables.
-        self._store_mult_state_groups_txn(txn, events_and_contexts)
+        # Insert into event_to_state_groups.
+        self._store_event_state_mappings_txn(txn, events_and_contexts)
 
         # _store_rejected_events_txn filters out any events which were
         # rejected, and returns the filtered list.
@@ -743,7 +776,7 @@ class EventsStore(SQLBaseStore):
 
     def _update_current_state_txn(self, txn, state_delta_by_room, max_stream_order):
         for room_id, current_state_tuple in state_delta_by_room.iteritems():
-                to_delete, to_insert, _ = current_state_tuple
+                to_delete, to_insert = current_state_tuple
                 txn.executemany(
                     "DELETE FROM current_state_events WHERE event_id = ?",
                     [(ev_id,) for ev_id in to_delete.itervalues()],
@@ -958,10 +991,9 @@ class EventsStore(SQLBaseStore):
                 # an outlier in the database. We now have some state at that
                 # so we need to update the state_groups table with that state.
 
-                # insert into the state_group, state_groups_state and
-                # event_to_state_groups tables.
+                # insert into event_to_state_groups.
                 try:
-                    self._store_mult_state_groups_txn(txn, ((event, context),))
+                    self._store_event_state_mappings_txn(txn, ((event, context),))
                 except Exception:
                     logger.exception("")
                     raise
@@ -2031,16 +2063,32 @@ class EventsStore(SQLBaseStore):
             )
         return self.runInteraction("get_all_new_events", get_all_new_events_txn)
 
-    def delete_old_state(self, room_id, topological_ordering):
-        return self.runInteraction(
-            "delete_old_state",
-            self._delete_old_state_txn, room_id, topological_ordering
-        )
+    def purge_history(
+        self, room_id, topological_ordering, delete_local_events,
+    ):
+        """Deletes room history before a certain point
+
+        Args:
+            room_id (str):
+
+            topological_ordering (int):
+                minimum topo ordering to preserve
 
-    def _delete_old_state_txn(self, txn, room_id, topological_ordering):
-        """Deletes old room state
+            delete_local_events (bool):
+                if True, we will delete local events as well as remote ones
+                (instead of just marking them as outliers and deleting their
+                state groups).
         """
 
+        return self.runInteraction(
+            "purge_history",
+            self._purge_history_txn, room_id, topological_ordering,
+            delete_local_events,
+        )
+
+    def _purge_history_txn(
+        self, txn, room_id, topological_ordering, delete_local_events,
+    ):
         # Tables that should be pruned:
         #     event_auth
         #     event_backward_extremities
@@ -2081,7 +2129,7 @@ class EventsStore(SQLBaseStore):
                 400, "topological_ordering is greater than forward extremeties"
             )
 
-        logger.debug("[purge] looking for events to delete")
+        logger.info("[purge] looking for events to delete")
 
         txn.execute(
             "SELECT event_id, state_key FROM events"
@@ -2093,16 +2141,16 @@ class EventsStore(SQLBaseStore):
 
         to_delete = [
             (event_id,) for event_id, state_key in event_rows
-            if state_key is None and not self.hs.is_mine_id(event_id)
+            if state_key is None and (
+                delete_local_events or not self.hs.is_mine_id(event_id)
+            )
         ]
         logger.info(
-            "[purge] found %i events before cutoff, of which %i are remote"
-            " non-state events to delete", len(event_rows), len(to_delete))
-
-        for event_id, state_key in event_rows:
-            txn.call_after(self._get_state_group_for_event.invalidate, (event_id,))
+            "[purge] found %i events before cutoff, of which %i can be deleted",
+            len(event_rows), len(to_delete),
+        )
 
-        logger.debug("[purge] Finding new backward extremities")
+        logger.info("[purge] Finding new backward extremities")
 
         # We calculate the new entries for the backward extremeties by finding
         # all events that point to events that are to be purged
@@ -2116,7 +2164,7 @@ class EventsStore(SQLBaseStore):
         )
         new_backwards_extrems = txn.fetchall()
 
-        logger.debug("[purge] replacing backward extremities: %r", new_backwards_extrems)
+        logger.info("[purge] replacing backward extremities: %r", new_backwards_extrems)
 
         txn.execute(
             "DELETE FROM event_backward_extremities WHERE room_id = ?",
@@ -2132,7 +2180,7 @@ class EventsStore(SQLBaseStore):
             ]
         )
 
-        logger.debug("[purge] finding redundant state groups")
+        logger.info("[purge] finding redundant state groups")
 
         # Get all state groups that are only referenced by events that are
         # to be deleted.
@@ -2149,15 +2197,15 @@ class EventsStore(SQLBaseStore):
         )
 
         state_rows = txn.fetchall()
-        logger.debug("[purge] found %i redundant state groups", len(state_rows))
+        logger.info("[purge] found %i redundant state groups", len(state_rows))
 
         # make a set of the redundant state groups, so that we can look them up
         # efficiently
         state_groups_to_delete = set([sg for sg, in state_rows])
 
         # Now we get all the state groups that rely on these state groups
-        logger.debug("[purge] finding state groups which depend on redundant"
-                     " state groups")
+        logger.info("[purge] finding state groups which depend on redundant"
+                    " state groups")
         remaining_state_groups = []
         for i in xrange(0, len(state_rows), 100):
             chunk = [sg for sg, in state_rows[i:i + 100]]
@@ -2182,7 +2230,7 @@ class EventsStore(SQLBaseStore):
         # Now we turn the state groups that reference to-be-deleted state
         # groups to non delta versions.
         for sg in remaining_state_groups:
-            logger.debug("[purge] de-delta-ing remaining state group %s", sg)
+            logger.info("[purge] de-delta-ing remaining state group %s", sg)
             curr_state = self._get_state_groups_from_groups_txn(
                 txn, [sg], types=None
             )
@@ -2219,7 +2267,7 @@ class EventsStore(SQLBaseStore):
                 ],
             )
 
-        logger.debug("[purge] removing redundant state groups")
+        logger.info("[purge] removing redundant state groups")
         txn.executemany(
             "DELETE FROM state_groups_state WHERE state_group = ?",
             state_rows
@@ -2229,18 +2277,15 @@ class EventsStore(SQLBaseStore):
             state_rows
         )
 
-        # Delete all non-state
-        logger.debug("[purge] removing events from event_to_state_groups")
+        logger.info("[purge] removing events from event_to_state_groups")
         txn.executemany(
             "DELETE FROM event_to_state_groups WHERE event_id = ?",
             [(event_id,) for event_id, _ in event_rows]
         )
-
-        logger.debug("[purge] updating room_depth")
-        txn.execute(
-            "UPDATE room_depth SET min_depth = ? WHERE room_id = ?",
-            (topological_ordering, room_id,)
-        )
+        for event_id, _ in event_rows:
+            txn.call_after(self._get_state_group_for_event.invalidate, (
+                event_id,
+            ))
 
         # Delete all remote non-state events
         for table in (
@@ -2258,7 +2303,8 @@ class EventsStore(SQLBaseStore):
             "event_signatures",
             "rejections",
         ):
-            logger.debug("[purge] removing remote non-state events from %s", table)
+            logger.info("[purge] removing remote non-state events from %s",
+                        table)
 
             txn.executemany(
                 "DELETE FROM %s WHERE event_id = ?" % (table,),
@@ -2266,16 +2312,30 @@ class EventsStore(SQLBaseStore):
             )
 
         # Mark all state and own events as outliers
-        logger.debug("[purge] marking remaining events as outliers")
+        logger.info("[purge] marking remaining events as outliers")
         txn.executemany(
             "UPDATE events SET outlier = ?"
             " WHERE event_id = ?",
             [
                 (True, event_id,) for event_id, state_key in event_rows
-                if state_key is not None or self.hs.is_mine_id(event_id)
+                if state_key is not None or (
+                    not delete_local_events and self.hs.is_mine_id(event_id)
+                )
             ]
         )
 
+        # synapse tries to take out an exclusive lock on room_depth whenever it
+        # persists events (because upsert), and once we run this update, we
+        # will block that for the rest of our transaction.
+        #
+        # So, let's stick it at the end so that we don't block event
+        # persistence.
+        logger.info("[purge] updating room_depth")
+        txn.execute(
+            "UPDATE room_depth SET min_depth = ? WHERE room_id = ?",
+            (topological_ordering, room_id,)
+        )
+
         logger.info("[purge] done")
 
     @defer.inlineCallbacks
diff --git a/synapse/storage/room.py b/synapse/storage/room.py
index 0fcfb7f86d..fff6652e05 100644
--- a/synapse/storage/room.py
+++ b/synapse/storage/room.py
@@ -506,73 +506,114 @@ class RoomStore(SearchStore):
         )
         self.is_room_blocked.invalidate((room_id,))
 
+    def get_media_mxcs_in_room(self, room_id):
+        """Retrieves all the local and remote media MXC URIs in a given room
+
+        Args:
+            room_id (str)
+
+        Returns:
+            The local and remote media as a lists of tuples where the key is
+            the hostname and the value is the media ID.
+        """
+        def _get_media_mxcs_in_room_txn(txn):
+            local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
+            local_media_mxcs = []
+            remote_media_mxcs = []
+
+            # Convert the IDs to MXC URIs
+            for media_id in local_mxcs:
+                local_media_mxcs.append("mxc://%s/%s" % (self.hostname, media_id))
+            for hostname, media_id in remote_mxcs:
+                remote_media_mxcs.append("mxc://%s/%s" % (hostname, media_id))
+
+            return local_media_mxcs, remote_media_mxcs
+        return self.runInteraction("get_media_ids_in_room", _get_media_mxcs_in_room_txn)
+
     def quarantine_media_ids_in_room(self, room_id, quarantined_by):
         """For a room loops through all events with media and quarantines
         the associated media
         """
-        def _get_media_ids_in_room(txn):
-            mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)")
+        def _quarantine_media_in_room_txn(txn):
+            local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
+            total_media_quarantined = 0
 
-            next_token = self.get_current_events_token() + 1
+            # Now update all the tables to set the quarantined_by flag
 
-            total_media_quarantined = 0
+            txn.executemany("""
+                UPDATE local_media_repository
+                SET quarantined_by = ?
+                WHERE media_id = ?
+            """, ((quarantined_by, media_id) for media_id in local_mxcs))
 
-            while next_token:
-                sql = """
-                    SELECT stream_ordering, content FROM events
-                    WHERE room_id = ?
-                        AND stream_ordering < ?
-                        AND contains_url = ? AND outlier = ?
-                    ORDER BY stream_ordering DESC
-                    LIMIT ?
+            txn.executemany(
                 """
-                txn.execute(sql, (room_id, next_token, True, False, 100))
-
-                next_token = None
-                local_media_mxcs = []
-                remote_media_mxcs = []
-                for stream_ordering, content_json in txn:
-                    next_token = stream_ordering
-                    content = json.loads(content_json)
-
-                    content_url = content.get("url")
-                    thumbnail_url = content.get("info", {}).get("thumbnail_url")
-
-                    for url in (content_url, thumbnail_url):
-                        if not url:
-                            continue
-                        matches = mxc_re.match(url)
-                        if matches:
-                            hostname = matches.group(1)
-                            media_id = matches.group(2)
-                            if hostname == self.hostname:
-                                local_media_mxcs.append(media_id)
-                            else:
-                                remote_media_mxcs.append((hostname, media_id))
-
-                # Now update all the tables to set the quarantined_by flag
-
-                txn.executemany("""
-                    UPDATE local_media_repository
+                    UPDATE remote_media_cache
                     SET quarantined_by = ?
-                    WHERE media_id = ?
-                """, ((quarantined_by, media_id) for media_id in local_media_mxcs))
-
-                txn.executemany(
-                    """
-                        UPDATE remote_media_cache
-                        SET quarantined_by = ?
-                        WHERE media_origin = ? AND media_id = ?
-                    """,
-                    (
-                        (quarantined_by, origin, media_id)
-                        for origin, media_id in remote_media_mxcs
-                    )
+                    WHERE media_origin = ? AND media_id = ?
+                """,
+                (
+                    (quarantined_by, origin, media_id)
+                    for origin, media_id in remote_mxcs
                 )
+            )
 
-                total_media_quarantined += len(local_media_mxcs)
-                total_media_quarantined += len(remote_media_mxcs)
+            total_media_quarantined += len(local_mxcs)
+            total_media_quarantined += len(remote_mxcs)
 
             return total_media_quarantined
 
-        return self.runInteraction("get_media_ids_in_room", _get_media_ids_in_room)
+        return self.runInteraction(
+            "quarantine_media_in_room",
+            _quarantine_media_in_room_txn,
+        )
+
+    def _get_media_mxcs_in_room_txn(self, txn, room_id):
+        """Retrieves all the local and remote media MXC URIs in a given room
+
+        Args:
+            txn (cursor)
+            room_id (str)
+
+        Returns:
+            The local and remote media as a lists of tuples where the key is
+            the hostname and the value is the media ID.
+        """
+        mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)")
+
+        next_token = self.get_current_events_token() + 1
+        local_media_mxcs = []
+        remote_media_mxcs = []
+
+        while next_token:
+            sql = """
+                SELECT stream_ordering, content FROM events
+                WHERE room_id = ?
+                    AND stream_ordering < ?
+                    AND contains_url = ? AND outlier = ?
+                ORDER BY stream_ordering DESC
+                LIMIT ?
+            """
+            txn.execute(sql, (room_id, next_token, True, False, 100))
+
+            next_token = None
+            for stream_ordering, content_json in txn:
+                next_token = stream_ordering
+                content = json.loads(content_json)
+
+                content_url = content.get("url")
+                thumbnail_url = content.get("info", {}).get("thumbnail_url")
+
+                for url in (content_url, thumbnail_url):
+                    if not url:
+                        continue
+                    matches = mxc_re.match(url)
+                    if matches:
+                        hostname = matches.group(1)
+                        media_id = matches.group(2)
+                        if hostname == self.hostname:
+                            local_media_mxcs.append(media_id)
+                        else:
+                            remote_media_mxcs.append((hostname, media_id))
+
+        return local_media_mxcs, remote_media_mxcs
diff --git a/synapse/storage/schema/delta/47/state_group_seq.py b/synapse/storage/schema/delta/47/state_group_seq.py
new file mode 100644
index 0000000000..f6766501d2
--- /dev/null
+++ b/synapse/storage/schema/delta/47/state_group_seq.py
@@ -0,0 +1,37 @@
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.storage.engines import PostgresEngine
+
+
+def run_create(cur, database_engine, *args, **kwargs):
+    if isinstance(database_engine, PostgresEngine):
+        # if we already have some state groups, we want to start making new
+        # ones with a higher id.
+        cur.execute("SELECT max(id) FROM state_groups")
+        row = cur.fetchone()
+
+        if row[0] is None:
+            start_val = 1
+        else:
+            start_val = row[0] + 1
+
+        cur.execute(
+            "CREATE SEQUENCE state_group_id_seq START WITH %s",
+            (start_val, ),
+        )
+
+
+def run_upgrade(*args, **kwargs):
+    pass
diff --git a/synapse/storage/search.py b/synapse/storage/search.py
index eecf778516..7c80fab277 100644
--- a/synapse/storage/search.py
+++ b/synapse/storage/search.py
@@ -12,18 +12,19 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 from collections import namedtuple
+import logging
+import re
 import sys
+import ujson as json
+
 from twisted.internet import defer
 
 from .background_updates import BackgroundUpdateStore
 from synapse.api.errors import SynapseError
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
 
-import logging
-import re
-import ujson as json
-
 
 logger = logging.getLogger(__name__)
 
@@ -280,10 +281,10 @@ class SearchStore(BackgroundUpdateStore):
         if isinstance(self.database_engine, PostgresEngine):
             sql = (
                 "INSERT INTO event_search"
-                " (event_id, room_id, key, vector, stream_ordering, "
-                "  origin_server_ts)"
+                " (event_id, room_id, key, vector, stream_ordering, origin_server_ts)"
                 " VALUES (?,?,?,to_tsvector('english', ?),?,?)"
             )
+
             args = ((
                 entry.event_id, entry.room_id, entry.key, entry.value,
                 entry.stream_ordering, entry.origin_server_ts,
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 360e3e4355..adb48df73e 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -42,11 +42,8 @@ class _GetStateGroupDelta(namedtuple("_GetStateGroupDelta", ("prev_group", "delt
         return len(self.delta_ids) if self.delta_ids else 0
 
 
-class StateGroupReadStore(SQLBaseStore):
-    """The read-only parts of StateGroupStore
-
-    None of these functions write to the state tables, so are suitable for
-    including in the SlavedStores.
+class StateGroupWorkerStore(SQLBaseStore):
+    """The parts of StateGroupStore that can be called from workers.
     """
 
     STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
@@ -54,7 +51,7 @@ class StateGroupReadStore(SQLBaseStore):
     CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
 
     def __init__(self, db_conn, hs):
-        super(StateGroupReadStore, self).__init__(db_conn, hs)
+        super(StateGroupWorkerStore, self).__init__(db_conn, hs)
 
         self._state_group_cache = DictionaryCache(
             "*stateGroupCache*", 100000 * CACHE_SIZE_FACTOR
@@ -549,116 +546,66 @@ class StateGroupReadStore(SQLBaseStore):
 
         defer.returnValue(results)
 
+    def store_state_group(self, event_id, room_id, prev_group, delta_ids,
+                          current_state_ids):
+        """Store a new set of state, returning a newly assigned state group.
 
-class StateStore(StateGroupReadStore, BackgroundUpdateStore):
-    """ Keeps track of the state at a given event.
-
-    This is done by the concept of `state groups`. Every event is a assigned
-    a state group (identified by an arbitrary string), which references a
-    collection of state events. The current state of an event is then the
-    collection of state events referenced by the event's state group.
-
-    Hence, every change in the current state causes a new state group to be
-    generated. However, if no change happens (e.g., if we get a message event
-    with only one parent it inherits the state group from its parent.)
-
-    There are three tables:
-      * `state_groups`: Stores group name, first event with in the group and
-        room id.
-      * `event_to_state_groups`: Maps events to state groups.
-      * `state_groups_state`: Maps state group to state events.
-    """
-
-    STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
-    STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
-    CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
-
-    def __init__(self, db_conn, hs):
-        super(StateStore, self).__init__(db_conn, hs)
-        self.register_background_update_handler(
-            self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
-            self._background_deduplicate_state,
-        )
-        self.register_background_update_handler(
-            self.STATE_GROUP_INDEX_UPDATE_NAME,
-            self._background_index_state,
-        )
-        self.register_background_index_update(
-            self.CURRENT_STATE_INDEX_UPDATE_NAME,
-            index_name="current_state_events_member_index",
-            table="current_state_events",
-            columns=["state_key"],
-            where_clause="type='m.room.member'",
-        )
-
-    def _have_persisted_state_group_txn(self, txn, state_group):
-        txn.execute(
-            "SELECT count(*) FROM state_groups WHERE id = ?",
-            (state_group,)
-        )
-        row = txn.fetchone()
-        return row and row[0]
-
-    def _store_mult_state_groups_txn(self, txn, events_and_contexts):
-        state_groups = {}
-        for event, context in events_and_contexts:
-            if event.internal_metadata.is_outlier():
-                continue
+        Args:
+            event_id (str): The event ID for which the state was calculated
+            room_id (str)
+            prev_group (int|None): A previous state group for the room, optional.
+            delta_ids (dict|None): The delta between state at `prev_group` and
+                `current_state_ids`, if `prev_group` was given. Same format as
+                `current_state_ids`.
+            current_state_ids (dict): The state to store. Map of (type, state_key)
+                to event_id.
 
-            if context.current_state_ids is None:
+        Returns:
+            Deferred[int]: The state group ID
+        """
+        def _store_state_group_txn(txn):
+            if current_state_ids is None:
                 # AFAIK, this can never happen
-                logger.error(
-                    "Non-outlier event %s had current_state_ids==None",
-                    event.event_id)
-                continue
+                raise Exception("current_state_ids cannot be None")
 
-            # if the event was rejected, just give it the same state as its
-            # predecessor.
-            if context.rejected:
-                state_groups[event.event_id] = context.prev_group
-                continue
-
-            state_groups[event.event_id] = context.state_group
-
-            if self._have_persisted_state_group_txn(txn, context.state_group):
-                continue
+            state_group = self.database_engine.get_next_state_group_id(txn)
 
             self._simple_insert_txn(
                 txn,
                 table="state_groups",
                 values={
-                    "id": context.state_group,
-                    "room_id": event.room_id,
-                    "event_id": event.event_id,
+                    "id": state_group,
+                    "room_id": room_id,
+                    "event_id": event_id,
                 },
             )
 
             # We persist as a delta if we can, while also ensuring the chain
             # of deltas isn't tooo long, as otherwise read performance degrades.
-            if context.prev_group:
+            if prev_group:
                 is_in_db = self._simple_select_one_onecol_txn(
                     txn,
                     table="state_groups",
-                    keyvalues={"id": context.prev_group},
+                    keyvalues={"id": prev_group},
                     retcol="id",
                     allow_none=True,
                 )
                 if not is_in_db:
                     raise Exception(
                         "Trying to persist state with unpersisted prev_group: %r"
-                        % (context.prev_group,)
+                        % (prev_group,)
                     )
 
                 potential_hops = self._count_state_group_hops_txn(
-                    txn, context.prev_group
+                    txn, prev_group
                 )
-            if context.prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
+            if prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
                 self._simple_insert_txn(
                     txn,
                     table="state_group_edges",
                     values={
-                        "state_group": context.state_group,
-                        "prev_state_group": context.prev_group,
+                        "state_group": state_group,
+                        "prev_state_group": prev_group,
                     },
                 )
 
@@ -667,13 +614,13 @@ class StateStore(StateGroupReadStore, BackgroundUpdateStore):
                     table="state_groups_state",
                     values=[
                         {
-                            "state_group": context.state_group,
-                            "room_id": event.room_id,
+                            "state_group": state_group,
+                            "room_id": room_id,
                             "type": key[0],
                             "state_key": key[1],
                             "event_id": state_id,
                         }
-                        for key, state_id in context.delta_ids.iteritems()
+                        for key, state_id in delta_ids.iteritems()
                     ],
                 )
             else:
@@ -682,13 +629,13 @@ class StateStore(StateGroupReadStore, BackgroundUpdateStore):
                     table="state_groups_state",
                     values=[
                         {
-                            "state_group": context.state_group,
-                            "room_id": event.room_id,
+                            "state_group": state_group,
+                            "room_id": room_id,
                             "type": key[0],
                             "state_key": key[1],
                             "event_id": state_id,
                         }
-                        for key, state_id in context.current_state_ids.iteritems()
+                        for key, state_id in current_state_ids.iteritems()
                     ],
                 )
 
@@ -699,11 +646,71 @@ class StateStore(StateGroupReadStore, BackgroundUpdateStore):
             txn.call_after(
                 self._state_group_cache.update,
                 self._state_group_cache.sequence,
-                key=context.state_group,
-                value=dict(context.current_state_ids),
+                key=state_group,
+                value=dict(current_state_ids),
                 full=True,
             )
 
+            return state_group
+
+        return self.runInteraction("store_state_group", _store_state_group_txn)
+
+
+class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
+    """ Keeps track of the state at a given event.
+
+    This is done by the concept of `state groups`. Every event is a assigned
+    a state group (identified by an arbitrary string), which references a
+    collection of state events. The current state of an event is then the
+    collection of state events referenced by the event's state group.
+
+    Hence, every change in the current state causes a new state group to be
+    generated. However, if no change happens (e.g., if we get a message event
+    with only one parent it inherits the state group from its parent.)
+
+    There are three tables:
+      * `state_groups`: Stores group name, first event with in the group and
+        room id.
+      * `event_to_state_groups`: Maps events to state groups.
+      * `state_groups_state`: Maps state group to state events.
+    """
+
+    STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
+    STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
+    CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
+
+    def __init__(self, db_conn, hs):
+        super(StateStore, self).__init__(db_conn, hs)
+        self.register_background_update_handler(
+            self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
+            self._background_deduplicate_state,
+        )
+        self.register_background_update_handler(
+            self.STATE_GROUP_INDEX_UPDATE_NAME,
+            self._background_index_state,
+        )
+        self.register_background_index_update(
+            self.CURRENT_STATE_INDEX_UPDATE_NAME,
+            index_name="current_state_events_member_index",
+            table="current_state_events",
+            columns=["state_key"],
+            where_clause="type='m.room.member'",
+        )
+
+    def _store_event_state_mappings_txn(self, txn, events_and_contexts):
+        state_groups = {}
+        for event, context in events_and_contexts:
+            if event.internal_metadata.is_outlier():
+                continue
+
+            # if the event was rejected, just give it the same state as its
+            # predecessor.
+            if context.rejected:
+                state_groups[event.event_id] = context.prev_group
+                continue
+
+            state_groups[event.event_id] = context.state_group
+
         self._simple_insert_many_txn(
             txn,
             table="event_to_state_groups",
@@ -763,9 +770,6 @@ class StateStore(StateGroupReadStore, BackgroundUpdateStore):
 
             return count
 
-    def get_next_state_group(self):
-        return self._state_groups_id_gen.get_next()
-
     @defer.inlineCallbacks
     def _background_deduplicate_state(self, progress, batch_size):
         """This background update will slowly deduplicate state by reencoding
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index af65bfe7b8..bf3a66eae4 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -75,6 +75,7 @@ class Cache(object):
         self.cache = LruCache(
             max_size=max_entries, keylen=keylen, cache_type=cache_type,
             size_callback=(lambda d: len(d)) if iterable else None,
+            evicted_callback=self._on_evicted,
         )
 
         self.name = name
@@ -83,6 +84,9 @@ class Cache(object):
         self.thread = None
         self.metrics = register_cache(name, self.cache)
 
+    def _on_evicted(self, evicted_count):
+        self.metrics.inc_evictions(evicted_count)
+
     def check_thread(self):
         expected_thread = self.thread
         if expected_thread is None:
diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py
index 6ad53a6390..0aa103eecb 100644
--- a/synapse/util/caches/expiringcache.py
+++ b/synapse/util/caches/expiringcache.py
@@ -79,7 +79,11 @@ class ExpiringCache(object):
         while self._max_len and len(self) > self._max_len:
             _key, value = self._cache.popitem(last=False)
             if self.iterable:
-                self._size_estimate -= len(value.value)
+                removed_len = len(value.value)
+                self.metrics.inc_evictions(removed_len)
+                self._size_estimate -= removed_len
+            else:
+                self.metrics.inc_evictions()
 
     def __getitem__(self, key):
         try:
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index cf5fbb679c..f088dd430e 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -49,7 +49,24 @@ class LruCache(object):
     Can also set callbacks on objects when getting/setting which are fired
     when that key gets invalidated/evicted.
     """
-    def __init__(self, max_size, keylen=1, cache_type=dict, size_callback=None):
+    def __init__(self, max_size, keylen=1, cache_type=dict, size_callback=None,
+                 evicted_callback=None):
+        """
+        Args:
+            max_size (int):
+
+            keylen (int):
+
+            cache_type (type):
+                type of underlying cache to be used. Typically one of dict
+                or TreeCache.
+
+            size_callback (func(V) -> int | None):
+
+            evicted_callback (func(int)|None):
+                if not None, called on eviction with the size of the evicted
+                entry
+        """
         cache = cache_type()
         self.cache = cache  # Used for introspection.
         list_root = _Node(None, None, None, None)
@@ -61,8 +78,10 @@ class LruCache(object):
         def evict():
             while cache_len() > max_size:
                 todelete = list_root.prev_node
-                delete_node(todelete)
+                evicted_len = delete_node(todelete)
                 cache.pop(todelete.key, None)
+                if evicted_callback:
+                    evicted_callback(evicted_len)
 
         def synchronized(f):
             @wraps(f)
@@ -111,12 +130,15 @@ class LruCache(object):
             prev_node.next_node = next_node
             next_node.prev_node = prev_node
 
+            deleted_len = 1
             if size_callback:
-                cached_cache_len[0] -= size_callback(node.value)
+                deleted_len = size_callback(node.value)
+                cached_cache_len[0] -= deleted_len
 
             for cb in node.callbacks:
                 cb()
             node.callbacks.clear()
+            return deleted_len
 
         @synchronized
         def cache_get(key, default=None, callbacks=[]):