summary refs log tree commit diff
diff options
context:
space:
mode:
authorMark Haines <mjark@negativecurvature.net>2014-11-11 16:40:50 +0000
committerMark Haines <mjark@negativecurvature.net>2014-11-11 16:40:50 +0000
commita8ceeec0fd512e287cbf71efff42015787517a5d (patch)
tree45643674a31b637799e347f2251c72417e685616
parentno evil horizontal textarea resizing (diff)
parentFix bugs which broke federation due to changes in function signatures. (diff)
downloadsynapse-a8ceeec0fd512e287cbf71efff42015787517a5d.tar.xz
Merge pull request #12 from matrix-org/federation_authorization
Federation authorization
-rwxr-xr-xdemo/start.sh2
-rw-r--r--docs/server-server/signing.rst16
-rw-r--r--scripts/check_event_hash.py47
-rw-r--r--scripts/check_signature.py73
-rw-r--r--scripts/hash_history.py69
-rw-r--r--synapse/api/auth.py457
-rw-r--r--synapse/api/errors.py34
-rw-r--r--synapse/api/events/__init__.py77
-rw-r--r--synapse/api/events/factory.py29
-rw-r--r--synapse/api/events/room.py21
-rw-r--r--synapse/api/events/utils.py65
-rw-r--r--synapse/api/events/validator.py87
-rwxr-xr-xsynapse/app/homeserver.py5
-rw-r--r--synapse/config/server.py3
-rw-r--r--synapse/crypto/event_signing.py98
-rw-r--r--synapse/federation/pdu_codec.py60
-rw-r--r--synapse/federation/persistence.py73
-rw-r--r--synapse/federation/replication.py319
-rw-r--r--synapse/federation/transport.py320
-rw-r--r--synapse/federation/units.py84
-rw-r--r--synapse/handlers/_base.py64
-rw-r--r--synapse/handlers/directory.py10
-rw-r--r--synapse/handlers/federation.py497
-rw-r--r--synapse/handlers/message.py30
-rw-r--r--synapse/handlers/profile.py22
-rw-r--r--synapse/handlers/room.py118
-rw-r--r--synapse/rest/base.py7
-rw-r--r--synapse/rest/events.py34
-rw-r--r--synapse/rest/room.py15
-rw-r--r--synapse/server.py12
-rw-r--r--synapse/state.py315
-rw-r--r--synapse/storage/__init__.py282
-rw-r--r--synapse/storage/_base.py249
-rw-r--r--synapse/storage/directory.py1
-rw-r--r--synapse/storage/event_federation.py377
-rw-r--r--synapse/storage/pdu.py915
-rw-r--r--synapse/storage/registration.py13
-rw-r--r--synapse/storage/room.py220
-rw-r--r--synapse/storage/schema/edge_pdus.sql31
-rw-r--r--synapse/storage/schema/event_edges.sql75
-rw-r--r--synapse/storage/schema/event_signatures.sql65
-rw-r--r--synapse/storage/schema/im.sql69
-rw-r--r--synapse/storage/schema/pdu.sql106
-rw-r--r--synapse/storage/schema/state.sql33
-rw-r--r--synapse/storage/signatures.py177
-rw-r--r--synapse/storage/state.py96
-rw-r--r--synapse/storage/stream.py8
-rw-r--r--synapse/storage/transactions.py74
-rw-r--r--synapse/types.py10
-rw-r--r--synapse/util/async.py7
-rw-r--r--synapse/util/jsonobject.py2
-rw-r--r--tests/events/test_events.py55
-rw-r--r--tests/federation/test_federation.py236
-rw-r--r--tests/federation/test_pdu_codec.py160
-rw-r--r--tests/handlers/test_directory.py11
-rw-r--r--tests/handlers/test_federation.py84
-rw-r--r--tests/handlers/test_presence.py1
-rw-r--r--tests/handlers/test_presencelike.py6
-rw-r--r--tests/handlers/test_profile.py6
-rw-r--r--tests/handlers/test_room.py195
-rw-r--r--tests/handlers/test_typing.py1
-rw-r--r--tests/rest/test_events.py57
-rw-r--r--tests/rest/test_profile.py8
-rw-r--r--tests/rest/test_rooms.py209
-rw-r--r--tests/storage/test_base.py2
-rw-r--r--tests/storage/test_redaction.py16
-rw-r--r--tests/storage/test_room.py4
-rw-r--r--tests/storage/test_roommember.py26
-rw-r--r--tests/storage/test_stream.py41
-rw-r--r--tests/test_state.py693
-rw-r--r--tests/utils.py3
71 files changed, 3774 insertions, 3913 deletions
diff --git a/demo/start.sh b/demo/start.sh
index 8b0cc84fe6..886d21cfa8 100755
--- a/demo/start.sh
+++ b/demo/start.sh
@@ -32,7 +32,7 @@ for port in 8080 8081 8082; do
         -D --pid-file "$DIR/$port.pid" \
         --manhole $((port + 1000)) \
         --tls-dh-params-path "demo/demo.tls.dh" \
-		$PARAMS
+		$PARAMS $SYNAPSE_PARAMS
 
     python -m synapse.app.homeserver \
         --config-path "demo/etc/$port.config" \
diff --git a/docs/server-server/signing.rst b/docs/server-server/signing.rst
index dae10f121b..60c701ca91 100644
--- a/docs/server-server/signing.rst
+++ b/docs/server-server/signing.rst
@@ -1,13 +1,13 @@
 Signing JSON
 ============
 
-JSON is signed by encoding the JSON object without ``signatures`` or ``meta``
+JSON is signed by encoding the JSON object without ``signatures`` or ``unsigned``
 keys using a canonical encoding. The JSON bytes are then signed using the
 signature algorithm and the signature encoded using base64 with the padding
 stripped. The resulting base64 signature is added to an object under the
 *signing key identifier* which is added to the ``signatures`` object under the
 name of the server signing it which is added back to the original JSON object
-along with the ``meta`` object.
+along with the ``unsigned`` object.
 
 The *signing key identifier* is the concatenation of the *signing algorithm*
 and a *key version*. The *signing algorithm* identifies the algorithm used to
@@ -15,8 +15,8 @@ sign the JSON. The currently support value for *signing algorithm* is
 ``ed25519`` as implemented by NACL (http://nacl.cr.yp.to/). The *key version*
 is used to distinguish between different signing keys used by the same entity.
 
-The ``meta`` object and the ``signatures`` object are not covered by the
-signature. Therefore intermediate servers can add metadata such as time stamps
+The ``unsigned`` object and the ``signatures`` object are not covered by the
+signature. Therefore intermediate servers can add unsigneddata such as time stamps
 and additional signatures.
 
 
@@ -27,7 +27,7 @@ and additional signatures.
      "signing_keys": {
        "ed25519:1": "XSl0kuyvrXNj6A+7/tkrB9sxSbRi08Of5uRhxOqZtEQ"
      },
-     "meta": {
+     "unsigned": {
         "retrieved_ts_ms": 922834800000
      },
      "signatures": {
@@ -41,7 +41,7 @@ and additional signatures.
 
   def sign_json(json_object, signing_key, signing_name):
       signatures = json_object.pop("signatures", {})
-      meta = json_object.pop("meta", None)
+      unsigned = json_object.pop("unsigned", None)
 
       signed = signing_key.sign(encode_canonical_json(json_object))
       signature_base64 = encode_base64(signed.signature)
@@ -50,8 +50,8 @@ and additional signatures.
       signatures.setdefault(sigature_name, {})[key_id] = signature_base64
 
       json_object["signatures"] = signatures
-      if meta is not None:
-          json_object["meta"] = meta
+      if unsigned is not None:
+          json_object["unsigned"] = unsigned
 
       return json_object
 
diff --git a/scripts/check_event_hash.py b/scripts/check_event_hash.py
new file mode 100644
index 0000000000..7c32f8102a
--- /dev/null
+++ b/scripts/check_event_hash.py
@@ -0,0 +1,47 @@
+from synapse.crypto.event_signing import *
+from syutil.base64util import encode_base64
+
+import argparse
+import hashlib
+import sys
+import json
+
+
+class dictobj(dict):
+    def __init__(self, *args, **kargs):
+        dict.__init__(self, *args, **kargs)
+        self.__dict__ = self
+
+    def get_dict(self):
+        return dict(self)
+
+    def get_full_dict(self):
+        return dict(self)
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("input_json", nargs="?", type=argparse.FileType('r'),
+                        default=sys.stdin)
+    args = parser.parse_args()
+    logging.basicConfig()
+
+    event_json = dictobj(json.load(args.input_json))
+
+    algorithms = {
+        "sha256": hashlib.sha256,
+    }
+
+    for alg_name in event_json.hashes:
+        if check_event_content_hash(event_json, algorithms[alg_name]):
+            print "PASS content hash %s" % (alg_name,)
+        else:
+            print "FAIL content hash %s" % (alg_name,)
+
+    for algorithm in algorithms.values():
+        name, h_bytes = compute_event_reference_hash(event_json, algorithm)
+        print "Reference hash %s: %s" % (name, encode_base64(h_bytes))
+
+if __name__=="__main__":
+    main()
+
diff --git a/scripts/check_signature.py b/scripts/check_signature.py
new file mode 100644
index 0000000000..e146e18e24
--- /dev/null
+++ b/scripts/check_signature.py
@@ -0,0 +1,73 @@
+
+from syutil.crypto.jsonsign import verify_signed_json
+from syutil.crypto.signing_key import (
+    decode_verify_key_bytes, write_signing_keys
+)
+from syutil.base64util import decode_base64
+
+import urllib2
+import json
+import sys
+import dns.resolver
+import pprint
+import argparse
+import logging
+
+def get_targets(server_name):
+    if ":" in server_name:
+        target, port = server_name.split(":")
+        yield (target, int(port))
+        return
+    try:
+        answers = dns.resolver.query("_matrix._tcp." + server_name, "SRV")
+        for srv in answers:
+            yield (srv.target, srv.port)
+    except dns.resolver.NXDOMAIN:
+        yield (server_name, 8480)
+
+def get_server_keys(server_name, target, port):
+    url = "https://%s:%i/_matrix/key/v1" % (target, port)
+    keys = json.load(urllib2.urlopen(url))
+    verify_keys = {}
+    for key_id, key_base64 in keys["verify_keys"].items():
+        verify_key = decode_verify_key_bytes(key_id, decode_base64(key_base64))
+        verify_signed_json(keys, server_name, verify_key)
+        verify_keys[key_id] = verify_key
+    return verify_keys
+
+def main():
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("signature_name")
+    parser.add_argument("input_json", nargs="?", type=argparse.FileType('r'),
+                        default=sys.stdin)
+
+    args = parser.parse_args()
+    logging.basicConfig()
+
+    server_name = args.signature_name
+    keys = {}
+    for target, port in get_targets(server_name):
+        try:
+            keys = get_server_keys(server_name, target, port)
+            print "Using keys from https://%s:%s/_matrix/key/v1" % (target, port)
+            write_signing_keys(sys.stdout, keys.values())
+            break
+        except:
+            logging.exception("Error talking to %s:%s", target, port)
+
+    json_to_check = json.load(args.input_json)
+    print "Checking JSON:"
+    for key_id in json_to_check["signatures"][args.signature_name]:
+        try:
+            key = keys[key_id]
+            verify_signed_json(json_to_check, args.signature_name, key)
+            print "PASS %s" % (key_id,)
+        except:
+            logging.exception("Check for key %s failed" % (key_id,))
+            print "FAIL %s" % (key_id,)
+
+
+if __name__ == '__main__':
+    main()
+
diff --git a/scripts/hash_history.py b/scripts/hash_history.py
new file mode 100644
index 0000000000..bdad530af8
--- /dev/null
+++ b/scripts/hash_history.py
@@ -0,0 +1,69 @@
+from synapse.storage.pdu import PduStore
+from synapse.storage.signatures import SignatureStore
+from synapse.storage._base import SQLBaseStore
+from synapse.federation.units import Pdu
+from synapse.crypto.event_signing import (
+    add_event_pdu_content_hash, compute_pdu_event_reference_hash
+)
+from synapse.api.events.utils import prune_pdu
+from syutil.base64util import encode_base64, decode_base64
+from syutil.jsonutil import encode_canonical_json
+import sqlite3
+import sys
+
+class Store(object):
+    _get_pdu_tuples = PduStore.__dict__["_get_pdu_tuples"]
+    _get_pdu_content_hashes_txn = SignatureStore.__dict__["_get_pdu_content_hashes_txn"]
+    _get_prev_pdu_hashes_txn = SignatureStore.__dict__["_get_prev_pdu_hashes_txn"]
+    _get_pdu_origin_signatures_txn = SignatureStore.__dict__["_get_pdu_origin_signatures_txn"]
+    _store_pdu_content_hash_txn = SignatureStore.__dict__["_store_pdu_content_hash_txn"]
+    _store_pdu_reference_hash_txn = SignatureStore.__dict__["_store_pdu_reference_hash_txn"]
+    _store_prev_pdu_hash_txn = SignatureStore.__dict__["_store_prev_pdu_hash_txn"]
+    _simple_insert_txn = SQLBaseStore.__dict__["_simple_insert_txn"]
+
+
+store = Store()
+
+
+def select_pdus(cursor):
+    cursor.execute(
+        "SELECT pdu_id, origin FROM pdus ORDER BY depth ASC"
+    )
+
+    ids = cursor.fetchall()
+
+    pdu_tuples = store._get_pdu_tuples(cursor, ids)
+
+    pdus = [Pdu.from_pdu_tuple(p) for p in pdu_tuples]
+
+    reference_hashes = {}
+
+    for pdu in pdus:
+        try:
+            if pdu.prev_pdus:
+                print "PROCESS", pdu.pdu_id, pdu.origin, pdu.prev_pdus
+                for pdu_id, origin, hashes in pdu.prev_pdus:
+                    ref_alg, ref_hsh = reference_hashes[(pdu_id, origin)]
+                    hashes[ref_alg] = encode_base64(ref_hsh)
+                    store._store_prev_pdu_hash_txn(cursor,  pdu.pdu_id, pdu.origin, pdu_id, origin, ref_alg, ref_hsh)
+                print "SUCCESS", pdu.pdu_id, pdu.origin, pdu.prev_pdus
+            pdu = add_event_pdu_content_hash(pdu)
+            ref_alg, ref_hsh = compute_pdu_event_reference_hash(pdu)
+            reference_hashes[(pdu.pdu_id, pdu.origin)] = (ref_alg, ref_hsh)
+            store._store_pdu_reference_hash_txn(cursor, pdu.pdu_id, pdu.origin, ref_alg, ref_hsh)
+
+            for alg, hsh_base64 in pdu.hashes.items():
+                print alg, hsh_base64
+                store._store_pdu_content_hash_txn(cursor, pdu.pdu_id, pdu.origin, alg, decode_base64(hsh_base64))
+
+        except:
+            print "FAILED_", pdu.pdu_id, pdu.origin, pdu.prev_pdus
+
+def main():
+    conn = sqlite3.connect(sys.argv[1])
+    cursor = conn.cursor()
+    select_pdus(cursor)
+    conn.commit()
+
+if __name__=='__main__':
+    main()
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index e1b1823cd7..6c2d3db26e 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -21,8 +21,10 @@ from synapse.api.constants import Membership, JoinRules
 from synapse.api.errors import AuthError, StoreError, Codes, SynapseError
 from synapse.api.events.room import (
     RoomMemberEvent, RoomPowerLevelsEvent, RoomRedactionEvent,
+    RoomJoinRulesEvent, RoomCreateEvent,
 )
 from synapse.util.logutils import log_function
+from syutil.base64util import encode_base64
 
 import logging
 
@@ -35,8 +37,7 @@ class Auth(object):
         self.hs = hs
         self.store = hs.get_datastore()
 
-    @defer.inlineCallbacks
-    def check(self, event, snapshot, raises=False):
+    def check(self, event, raises=False):
         """ Checks if this event is correctly authed.
 
         Returns:
@@ -47,43 +48,51 @@ class Auth(object):
         """
         try:
             if hasattr(event, "room_id"):
-                is_state = hasattr(event, "state_key")
+                if event.old_state_events is None:
+                    # Oh, we don't know what the state of the room was, so we
+                    # are trusting that this is allowed (at least for now)
+                    logger.warn("Trusting event: %s", event.event_id)
+                    return True
+
+                if hasattr(event, "outlier") and event.outlier is True:
+                    # TODO (erikj): Auth for outliers is done differently.
+                    return True
+
+                if event.type == RoomCreateEvent.TYPE:
+                    # FIXME
+                    return True
 
                 if event.type == RoomMemberEvent.TYPE:
-                    yield self._can_replace_state(event)
-                    allowed = yield self.is_membership_change_allowed(event)
-                    defer.returnValue(allowed)
-                    return
-
-                self._check_joined_room(
-                    member=snapshot.membership_state,
-                    user_id=snapshot.user_id,
-                    room_id=snapshot.room_id,
-                )
+                    allowed = self.is_membership_change_allowed(event)
+                    if allowed:
+                        logger.debug("Allowing! %s", event)
+                    else:
+                        logger.debug("Denying! %s", event)
+                    return allowed
 
-                if is_state:
-                    # TODO (erikj): This really only should be called for *new*
-                    # state
-                    yield self._can_add_state(event)
-                    yield self._can_replace_state(event)
-                else:
-                    yield self._can_send_event(event)
+                self.check_event_sender_in_room(event)
+                self._can_send_event(event)
 
                 if event.type == RoomPowerLevelsEvent.TYPE:
-                    yield self._check_power_levels(event)
+                    self._check_power_levels(event)
 
                 if event.type == RoomRedactionEvent.TYPE:
-                    yield self._check_redaction(event)
+                    self._check_redaction(event)
 
-                defer.returnValue(True)
+                logger.debug("Allowing! %s", event)
+                return True
             else:
                 raise AuthError(500, "Unknown event: %s" % event)
         except AuthError as e:
-            logger.info("Event auth check failed on event %s with msg: %s",
-                        event, e.msg)
+            logger.info(
+                "Event auth check failed on event %s with msg: %s",
+                event, e.msg
+            )
+            logger.info("Denying! %s", event)
             if raises:
                 raise e
-        defer.returnValue(False)
+
+        return False
 
     @defer.inlineCallbacks
     def check_joined_room(self, room_id, user_id):
@@ -98,45 +107,80 @@ class Auth(object):
             pass
         defer.returnValue(None)
 
+    @defer.inlineCallbacks
+    def check_host_in_room(self, room_id, host):
+        joined_hosts = yield self.store.get_joined_hosts_for_room(room_id)
+
+        defer.returnValue(host in joined_hosts)
+
+    def check_event_sender_in_room(self, event):
+        key = (RoomMemberEvent.TYPE, event.user_id, )
+        member_event = event.state_events.get(key)
+
+        return self._check_joined_room(
+            member_event,
+            event.user_id,
+            event.room_id
+        )
+
     def _check_joined_room(self, member, user_id, room_id):
         if not member or member.membership != Membership.JOIN:
             raise AuthError(403, "User %s not in room %s (%s)" % (
                 user_id, room_id, repr(member)
             ))
 
-    @defer.inlineCallbacks
+    @log_function
     def is_membership_change_allowed(self, event):
         target_user_id = event.state_key
 
-        # does this room even exist
-        room = yield self.store.get_room(event.room_id)
-        if not room:
-            raise AuthError(403, "Room does not exist")
-
         # get info about the caller
-        try:
-            caller = yield self.store.get_room_member(
-                user_id=event.user_id,
-                room_id=event.room_id)
-        except:
-            caller = None
-        caller_in_room = caller and caller.membership == "join"
+        key = (RoomMemberEvent.TYPE, event.user_id, )
+        caller = event.old_state_events.get(key)
+
+        caller_in_room = caller and caller.membership == Membership.JOIN
+        caller_invited = caller and caller.membership == Membership.INVITE
 
         # get info about the target
-        try:
-            target = yield self.store.get_room_member(
-                user_id=target_user_id,
-                room_id=event.room_id)
-        except:
-            target = None
-        target_in_room = target and target.membership == "join"
+        key = (RoomMemberEvent.TYPE, target_user_id, )
+        target = event.old_state_events.get(key)
+
+        target_in_room = target and target.membership == Membership.JOIN
 
         membership = event.content["membership"]
 
-        join_rule = yield self.store.get_room_join_rule(event.room_id)
-        if not join_rule:
+        key = (RoomJoinRulesEvent.TYPE, "", )
+        join_rule_event = event.old_state_events.get(key)
+        if join_rule_event:
+            join_rule = join_rule_event.content.get(
+                "join_rule", JoinRules.INVITE
+            )
+        else:
             join_rule = JoinRules.INVITE
 
+        user_level = self._get_power_level_from_event_state(
+            event,
+            event.user_id,
+        )
+
+        ban_level, kick_level, redact_level = (
+            self._get_ops_level_from_event_state(
+                event
+            )
+        )
+
+        logger.debug(
+            "is_membership_change_allowed: %s",
+            {
+                "caller_in_room": caller_in_room,
+                "caller_invited": caller_invited,
+                "target_in_room": target_in_room,
+                "membership": membership,
+                "join_rule": join_rule,
+                "target_user_id": target_user_id,
+                "event.user_id": event.user_id,
+            }
+        )
+
         if Membership.INVITE == membership:
             # TODO (erikj): We should probably handle this more intelligently
             # PRIVATE join rules.
@@ -153,13 +197,10 @@ class Auth(object):
             # joined: It's a NOOP
             if event.user_id != target_user_id:
                 raise AuthError(403, "Cannot force another user to join.")
-            elif join_rule == JoinRules.PUBLIC or room.is_public:
+            elif join_rule == JoinRules.PUBLIC:
                 pass
             elif join_rule == JoinRules.INVITE:
-                if (
-                    not caller or caller.membership not in
-                    [Membership.INVITE, Membership.JOIN]
-                ):
+                if not caller_in_room and not caller_invited:
                     raise AuthError(403, "You are not invited to this room.")
             else:
                 # TODO (erikj): may_join list
@@ -171,29 +212,16 @@ class Auth(object):
             if not caller_in_room:  # trying to leave a room you aren't joined
                 raise AuthError(403, "You are not in room %s." % event.room_id)
             elif target_user_id != event.user_id:
-                user_level = yield self.store.get_power_level(
-                    event.room_id,
-                    event.user_id,
-                )
-                _, kick_level, _ = yield self.store.get_ops_levels(event.room_id)
-
                 if kick_level:
                     kick_level = int(kick_level)
                 else:
-                    kick_level = 50
+                    kick_level = 50  # FIXME (erikj): What should we do here?
 
                 if user_level < kick_level:
                     raise AuthError(
                         403, "You cannot kick user %s." % target_user_id
                     )
         elif Membership.BAN == membership:
-            user_level = yield self.store.get_power_level(
-                event.room_id,
-                event.user_id,
-            )
-
-            ban_level, _, _  = yield self.store.get_ops_levels(event.room_id)
-
             if ban_level:
                 ban_level = int(ban_level)
             else:
@@ -204,7 +232,30 @@ class Auth(object):
         else:
             raise AuthError(500, "Unknown membership %s" % membership)
 
-        defer.returnValue(True)
+        return True
+
+    def _get_power_level_from_event_state(self, event, user_id):
+        key = (RoomPowerLevelsEvent.TYPE, "", )
+        power_level_event = event.old_state_events.get(key)
+        level = None
+        if power_level_event:
+            level = power_level_event.content.get("users", {}).get(user_id)
+            if not level:
+                level = power_level_event.content.get("users_default", 0)
+
+        return level
+
+    def _get_ops_level_from_event_state(self, event):
+        key = (RoomPowerLevelsEvent.TYPE, "", )
+        power_level_event = event.old_state_events.get(key)
+
+        if power_level_event:
+            return (
+                power_level_event.content.get("ban", 50),
+                power_level_event.content.get("kick", 50),
+                power_level_event.content.get("redact", 50),
+            )
+        return None, None, None,
 
     @defer.inlineCallbacks
     def get_user_by_req(self, request):
@@ -229,7 +280,7 @@ class Auth(object):
                 default=[""]
             )[0]
             if user and access_token and ip_addr:
-                self.store.insert_client_ip(
+                yield self.store.insert_client_ip(
                     user=user,
                     access_token=access_token,
                     device_id=user_info["device_id"],
@@ -273,68 +324,81 @@ class Auth(object):
         return self.store.is_server_admin(user)
 
     @defer.inlineCallbacks
-    @log_function
-    def _can_send_event(self, event):
-        send_level = yield self.store.get_send_event_level(event.room_id)
-
-        if send_level:
-            send_level = int(send_level)
-        else:
-            send_level = 0
-
-        user_level = yield self.store.get_power_level(
-            event.room_id,
-            event.user_id,
-        )
-
-        if user_level:
-            user_level = int(user_level)
-        else:
-            user_level = 0
+    def add_auth_events(self, event):
+        if event.type == RoomCreateEvent.TYPE:
+            event.auth_events = []
+            return
 
-        if user_level < send_level:
-            raise AuthError(
-                403, "You don't have permission to post to the room"
-            )
+        auth_events = []
 
-        defer.returnValue(True)
+        key = (RoomPowerLevelsEvent.TYPE, "", )
+        power_level_event = event.old_state_events.get(key)
 
-    @defer.inlineCallbacks
-    def _can_add_state(self, event):
-        add_level = yield self.store.get_add_state_level(event.room_id)
+        if power_level_event:
+            auth_events.append(power_level_event.event_id)
 
-        if not add_level:
-            defer.returnValue(True)
+        key = (RoomJoinRulesEvent.TYPE, "", )
+        join_rule_event = event.old_state_events.get(key)
 
-        add_level = int(add_level)
+        key = (RoomMemberEvent.TYPE, event.user_id, )
+        member_event = event.old_state_events.get(key)
 
-        user_level = yield self.store.get_power_level(
-            event.room_id,
-            event.user_id,
+        if join_rule_event:
+            join_rule = join_rule_event.content.get("join_rule")
+            is_public = join_rule == JoinRules.PUBLIC if join_rule else False
+        else:
+            is_public = False
+
+        if event.type == RoomMemberEvent.TYPE:
+            e_type = event.content["membership"]
+            if e_type in [Membership.JOIN, Membership.INVITE]:
+                if join_rule_event:
+                    auth_events.append(join_rule_event.event_id)
+
+                if member_event and not is_public:
+                    auth_events.append(member_event.event_id)
+        elif member_event:
+            if member_event.content["membership"] == Membership.JOIN:
+                auth_events.append(member_event.event_id)
+
+        hashes = yield self.store.get_event_reference_hashes(
+            auth_events
         )
+        hashes = [
+            {
+                k: encode_base64(v) for k, v in h.items()
+                if k == "sha256"
+            }
+            for h in hashes
+        ]
+        event.auth_events = zip(auth_events, hashes)
 
-        user_level = int(user_level)
-
-        if user_level < add_level:
-            raise AuthError(
-                403, "You don't have permission to add state to the room"
+    @log_function
+    def _can_send_event(self, event):
+        key = (RoomPowerLevelsEvent.TYPE, "", )
+        send_level_event = event.old_state_events.get(key)
+        send_level = None
+        if send_level_event:
+            send_level = send_level_event.content.get("events", {}).get(
+                event.type
             )
+            if not send_level:
+                if hasattr(event, "state_key"):
+                    send_level = send_level_event.content.get(
+                        "state_default", 50
+                    )
+                else:
+                    send_level = send_level_event.content.get(
+                        "events_default", 0
+                    )
 
-        defer.returnValue(True)
-
-    @defer.inlineCallbacks
-    def _can_replace_state(self, event):
-        current_state = yield self.store.get_current_state(
-            event.room_id,
-            event.type,
-            event.state_key,
-        )
-
-        if current_state:
-            current_state = current_state[0]
+        if send_level:
+            send_level = int(send_level)
+        else:
+            send_level = 0
 
-        user_level = yield self.store.get_power_level(
-            event.room_id,
+        user_level = self._get_power_level_from_event_state(
+            event,
             event.user_id,
         )
 
@@ -343,35 +407,24 @@ class Auth(object):
         else:
             user_level = 0
 
-        logger.debug(
-            "Checking power level for %s, %s", event.user_id, user_level
-        )
-        if current_state and hasattr(current_state, "required_power_level"):
-            req = current_state.required_power_level
+        if user_level < send_level:
+            raise AuthError(
+                403,
+                "You don't have permission to post that to the room. " +
+                "user_level (%d) < send_level (%d)" % (user_level, send_level)
+            )
 
-            logger.debug("Checked power level for %s, %s", event.user_id, req)
-            if user_level < req:
-                raise AuthError(
-                    403,
-                    "You don't have permission to change that state"
-                )
+        return True
 
-    @defer.inlineCallbacks
     def _check_redaction(self, event):
-        user_level = yield self.store.get_power_level(
-            event.room_id,
+        user_level = self._get_power_level_from_event_state(
+            event,
             event.user_id,
         )
 
-        if user_level:
-            user_level = int(user_level)
-        else:
-            user_level = 0
-
-        _, _, redact_level  = yield self.store.get_ops_levels(event.room_id)
-
-        if not redact_level:
-            redact_level = 50
+        _, _, redact_level = self._get_ops_level_from_event_state(
+            event
+        )
 
         if user_level < redact_level:
             raise AuthError(
@@ -379,16 +432,10 @@ class Auth(object):
                 "You don't have permission to redact events"
             )
 
-    @defer.inlineCallbacks
     def _check_power_levels(self, event):
-        for k, v in event.content.items():
-            if k == "default":
-                continue
-
-            # FIXME (erikj): We don't want hsob_Ts in content.
-            if k == "hsob_ts":
-                continue
-
+        user_list = event.content.get("users", {})
+        # Validate users
+        for k, v in user_list.items():
             try:
                 self.hs.parse_userid(k)
             except:
@@ -399,80 +446,68 @@ class Auth(object):
             except:
                 raise SynapseError(400, "Not a valid power level: %s" % (v,))
 
-        current_state = yield self.store.get_current_state(
-            event.room_id,
-            event.type,
-            event.state_key,
-        )
+        key = (event.type, event.state_key, )
+        current_state = event.old_state_events.get(key)
 
         if not current_state:
             return
-        else:
-            current_state = current_state[0]
 
-        user_level = yield self.store.get_power_level(
-            event.room_id,
+        user_level = self._get_power_level_from_event_state(
+            event,
             event.user_id,
         )
 
-        if user_level:
-            user_level = int(user_level)
-        else:
-            user_level = 0
+        # Check other levels:
+        levels_to_check = [
+            ("users_default", []),
+            ("events_default", []),
+            ("ban", []),
+            ("redact", []),
+            ("kick", []),
+        ]
+
+        old_list = current_state.content.get("users")
+        for user in set(old_list.keys() + user_list.keys()):
+            levels_to_check.append(
+                (user, ["users"])
+            )
 
-        old_list = current_state.content
+        old_list = current_state.content.get("events")
+        new_list = event.content.get("events")
+        for ev_id in set(old_list.keys() + new_list.keys()):
+            levels_to_check.append(
+                (ev_id, ["events"])
+            )
 
-        # FIXME (erikj)
-        old_people = {k: v for k, v in old_list.items() if k.startswith("@")}
-        new_people = {
-            k: v for k, v in event.content.items()
-            if k.startswith("@")
-        }
+        old_state = current_state.content
+        new_state = event.content
 
-        removed = set(old_people.keys()) - set(new_people.keys())
-        added = set(new_people.keys()) - set(old_people.keys())
-        same = set(old_people.keys()) & set(new_people.keys())
+        for level_to_check, dir in levels_to_check:
+            old_loc = old_state
+            for d in dir:
+                old_loc = old_loc.get(d, {})
 
-        for r in removed:
-            if int(old_list[r]) > user_level:
-                raise AuthError(
-                    403,
-                    "You don't have permission to remove user: %s" % (r, )
-                )
+            new_loc = new_state
+            for d in dir:
+                new_loc = new_loc.get(d, {})
 
-        for n in added:
-            if int(event.content[n]) > user_level:
-                raise AuthError(
-                    403,
-                    "You don't have permission to add ops level greater "
-                    "than your own"
-                )
+            if level_to_check in old_loc:
+                old_level = int(old_loc[level_to_check])
+            else:
+                old_level = None
 
-        for s in same:
-            if int(event.content[s]) != int(old_list[s]):
-                if int(event.content[s]) > user_level:
-                    raise AuthError(
-                        403,
-                        "You don't have permission to add ops level greater "
-                        "than your own"
-                    )
+            if level_to_check in new_loc:
+                new_level = int(new_loc[level_to_check])
+            else:
+                new_level = None
 
-        if "default" in old_list:
-            old_default = int(old_list["default"])
+            if new_level is not None and old_level is not None:
+                if new_level == old_level:
+                    continue
 
-            if old_default > user_level:
+            if old_level > user_level or new_level > user_level:
                 raise AuthError(
                     403,
-                    "You don't have permission to add ops level greater than "
-                    "your own"
+                    "You don't have permission to add ops level greater "
+                    "than your own"
                 )
-
-            if "default" in event.content:
-                new_default = int(event.content["default"])
-
-                if new_default > user_level:
-                    raise AuthError(
-                        403,
-                        "You don't have permission to add ops level greater "
-                        "than your own"
-                    )
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index 38ccb4f9d1..33d15072af 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -158,3 +158,37 @@ def cs_error(msg, code=Codes.UNKNOWN, **kwargs):
     for key, value in kwargs.iteritems():
         err[key] = value
     return err
+
+
+class FederationError(RuntimeError):
+    """  This class is used to inform remote home servers about erroneous
+    PDUs they sent us.
+
+    FATAL: The remote server could not interpret the source event.
+        (e.g., it was missing a required field)
+    ERROR: The remote server interpreted the event, but it failed some other
+        check (e.g. auth)
+    WARN: The remote server accepted the event, but believes some part of it
+        is wrong (e.g., it referred to an invalid event)
+    """
+
+    def __init__(self, level, code, reason, affected, source=None):
+        if level not in ["FATAL", "ERROR", "WARN"]:
+            raise ValueError("Level is not valid: %s" % (level,))
+        self.level = level
+        self.code = code
+        self.reason = reason
+        self.affected = affected
+        self.source = source
+
+        msg = "%s %s: %s" % (level, code, reason,)
+        super(FederationError, self).__init__(msg)
+
+    def get_dict(self):
+        return {
+            "level": self.level,
+            "code": self.code,
+            "reason": self.reason,
+            "affected": self.affected,
+            "source": self.source if self.source else self.affected,
+        }
diff --git a/synapse/api/events/__init__.py b/synapse/api/events/__init__.py
index f66fea2904..1d8bed2906 100644
--- a/synapse/api/events/__init__.py
+++ b/synapse/api/events/__init__.py
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.api.errors import SynapseError, Codes
 from synapse.util.jsonobject import JsonEncodedObject
 
 
@@ -56,22 +55,26 @@ class SynapseEvent(JsonEncodedObject):
         "user_id",  # sender/initiator
         "content",  # HTTP body, JSON
         "state_key",
-        "required_power_level",
         "age_ts",
         "prev_content",
-        "prev_state",
+        "replaces_state",
         "redacted_because",
+        "origin_server_ts",
     ]
 
     internal_keys = [
         "is_state",
-        "prev_events",
         "depth",
         "destinations",
         "origin",
         "outlier",
-        "power_level",
         "redacted",
+        "prev_events",
+        "hashes",
+        "signatures",
+        "prev_state",
+        "auth_events",
+        "state_hash",
     ]
 
     required_keys = [
@@ -82,8 +85,8 @@ class SynapseEvent(JsonEncodedObject):
 
     def __init__(self, raises=True, **kwargs):
         super(SynapseEvent, self).__init__(**kwargs)
-        if "content" in kwargs:
-            self.check_json(self.content, raises=raises)
+        # if "content" in kwargs:
+        #     self.check_json(self.content, raises=raises)
 
     def get_content_template(self):
         """ Retrieve the JSON template for this event as a dict.
@@ -114,66 +117,6 @@ class SynapseEvent(JsonEncodedObject):
         """
         raise NotImplementedError("get_content_template not implemented.")
 
-    def check_json(self, content, raises=True):
-        """Checks the given JSON content abides by the rules of the template.
-
-        Args:
-            content : A JSON object to check.
-            raises: True to raise a SynapseError if the check fails.
-        Returns:
-            True if the content passes the template. Returns False if the check
-            fails and raises=False.
-        Raises:
-            SynapseError if the check fails and raises=True.
-        """
-        # recursively call to inspect each layer
-        err_msg = self._check_json(content, self.get_content_template())
-        if err_msg:
-            if raises:
-                raise SynapseError(400, err_msg, Codes.BAD_JSON)
-            else:
-                return False
-        else:
-            return True
-
-    def _check_json(self, content, template):
-        """Check content and template matches.
-
-        If the template is a dict, each key in the dict will be validated with
-        the content, else it will just compare the types of content and
-        template. This basic type check is required because this function will
-        be recursively called and could be called with just strs or ints.
-
-        Args:
-            content: The content to validate.
-            template: The validation template.
-        Returns:
-            str: An error message if the validation fails, else None.
-        """
-        if type(content) != type(template):
-            return "Mismatched types: %s" % template
-
-        if type(template) == dict:
-            for key in template:
-                if key not in content:
-                    return "Missing %s key" % key
-
-                if type(content[key]) != type(template[key]):
-                    return "Key %s is of the wrong type (got %s, want %s)" % (
-                        key, type(content[key]), type(template[key]))
-
-                if type(content[key]) == dict:
-                    # we must go deeper
-                    msg = self._check_json(content[key], template[key])
-                    if msg:
-                        return msg
-                elif type(content[key]) == list:
-                    # make sure each item type in content matches the template
-                    for entry in content[key]:
-                        msg = self._check_json(entry, template[key][0])
-                        if msg:
-                            return msg
-
 
 class SynapseStateEvent(SynapseEvent):
 
diff --git a/synapse/api/events/factory.py b/synapse/api/events/factory.py
index 74d0ef77f4..a1ec708a81 100644
--- a/synapse/api/events/factory.py
+++ b/synapse/api/events/factory.py
@@ -16,11 +16,13 @@
 from synapse.api.events.room import (
     RoomTopicEvent, MessageEvent, RoomMemberEvent, FeedbackEvent,
     InviteJoinEvent, RoomConfigEvent, RoomNameEvent, GenericEvent,
-    RoomPowerLevelsEvent, RoomJoinRulesEvent, RoomOpsPowerLevelsEvent,
-    RoomCreateEvent, RoomAddStateLevelEvent, RoomSendEventLevelEvent,
+    RoomPowerLevelsEvent, RoomJoinRulesEvent,
+    RoomCreateEvent,
     RoomRedactionEvent,
 )
 
+from synapse.types import EventID
+
 from synapse.util.stringutils import random_string
 
 
@@ -37,9 +39,6 @@ class EventFactory(object):
         RoomPowerLevelsEvent,
         RoomJoinRulesEvent,
         RoomCreateEvent,
-        RoomAddStateLevelEvent,
-        RoomSendEventLevelEvent,
-        RoomOpsPowerLevelsEvent,
         RoomRedactionEvent,
     ]
 
@@ -51,12 +50,26 @@ class EventFactory(object):
         self.clock = hs.get_clock()
         self.hs = hs
 
+        self.event_id_count = 0
+
+    def create_event_id(self):
+        i = str(self.event_id_count)
+        self.event_id_count += 1
+
+        local_part = str(int(self.clock.time())) + i + random_string(5)
+
+        e_id = EventID.create_local(local_part, self.hs)
+
+        return e_id.to_string()
+
     def create_event(self, etype=None, **kwargs):
         kwargs["type"] = etype
         if "event_id" not in kwargs:
-            kwargs["event_id"] = "%s@%s" % (
-                random_string(10), self.hs.hostname
-            )
+            kwargs["event_id"] = self.create_event_id()
+            kwargs["origin"] = self.hs.hostname
+        else:
+            ev_id = self.hs.parse_eventid(kwargs["event_id"])
+            kwargs["origin"] = ev_id.domain
 
         if "origin_server_ts" not in kwargs:
             kwargs["origin_server_ts"] = int(self.clock.time_msec())
diff --git a/synapse/api/events/room.py b/synapse/api/events/room.py
index cd936074fc..8c4ac45d02 100644
--- a/synapse/api/events/room.py
+++ b/synapse/api/events/room.py
@@ -154,27 +154,6 @@ class RoomPowerLevelsEvent(SynapseStateEvent):
         return {}
 
 
-class RoomAddStateLevelEvent(SynapseStateEvent):
-    TYPE = "m.room.add_state_level"
-
-    def get_content_template(self):
-        return {}
-
-
-class RoomSendEventLevelEvent(SynapseStateEvent):
-    TYPE = "m.room.send_event_level"
-
-    def get_content_template(self):
-        return {}
-
-
-class RoomOpsPowerLevelsEvent(SynapseStateEvent):
-    TYPE = "m.room.ops_levels"
-
-    def get_content_template(self):
-        return {}
-
-
 class RoomAliasesEvent(SynapseStateEvent):
     TYPE = "m.room.aliases"
 
diff --git a/synapse/api/events/utils.py b/synapse/api/events/utils.py
index c3a32be8c1..802648f8f7 100644
--- a/synapse/api/events/utils.py
+++ b/synapse/api/events/utils.py
@@ -15,21 +15,34 @@
 
 from .room import (
     RoomMemberEvent, RoomJoinRulesEvent, RoomPowerLevelsEvent,
-    RoomAddStateLevelEvent, RoomSendEventLevelEvent, RoomOpsPowerLevelsEvent,
     RoomAliasesEvent, RoomCreateEvent,
 )
 
+
 def prune_event(event):
-    """ Prunes the given event of all keys we don't know about or think could
-    potentially be dodgy.
+    """ Returns a pruned version of the given event, which removes all keys we
+    don't know about or think could potentially be dodgy.
 
     This is used when we "redact" an event. We want to remove all fields that
     the user has specified, but we do want to keep necessary information like
     type, state_key etc.
     """
+    event_type = event.type
 
-    # Remove all extraneous fields.
-    event.unrecognized_keys = {}
+    allowed_keys = [
+        "event_id",
+        "user_id",
+        "room_id",
+        "hashes",
+        "signatures",
+        "content",
+        "type",
+        "state_key",
+        "depth",
+        "prev_events",
+        "prev_state",
+        "auth_events",
+    ]
 
     new_content = {}
 
@@ -38,27 +51,33 @@ def prune_event(event):
             if field in event.content:
                 new_content[field] = event.content[field]
 
-    if event.type == RoomMemberEvent.TYPE:
+    if event_type == RoomMemberEvent.TYPE:
         add_fields("membership")
-    elif event.type == RoomCreateEvent.TYPE:
+    elif event_type == RoomCreateEvent.TYPE:
         add_fields("creator")
-    elif event.type == RoomJoinRulesEvent.TYPE:
+    elif event_type == RoomJoinRulesEvent.TYPE:
         add_fields("join_rule")
-    elif event.type == RoomPowerLevelsEvent.TYPE:
-        # TODO: Actually check these are valid user_ids etc.
-        add_fields("default")
-        for k, v in event.content.items():
-            if k.startswith("@") and isinstance(v, (int, long)):
-                new_content[k] = v
-    elif event.type == RoomAddStateLevelEvent.TYPE:
-        add_fields("level")
-    elif event.type == RoomSendEventLevelEvent.TYPE:
-        add_fields("level")
-    elif event.type == RoomOpsPowerLevelsEvent.TYPE:
-        add_fields("kick_level", "ban_level", "redact_level")
-    elif event.type == RoomAliasesEvent.TYPE:
+    elif event_type == RoomPowerLevelsEvent.TYPE:
+        add_fields(
+            "users",
+            "users_default",
+            "events",
+            "events_default",
+            "events_default",
+            "state_default",
+            "ban",
+            "kick",
+            "redact",
+        )
+    elif event_type == RoomAliasesEvent.TYPE:
         add_fields("aliases")
 
-    event.content = new_content
+    allowed_fields = {
+        k: v
+        for k, v in event.get_full_dict().items()
+        if k in allowed_keys
+    }
+
+    allowed_fields["content"] = new_content
 
-    return event
+    return type(event)(**allowed_fields)
diff --git a/synapse/api/events/validator.py b/synapse/api/events/validator.py
new file mode 100644
index 0000000000..2d4f2a3aa7
--- /dev/null
+++ b/synapse/api/events/validator.py
@@ -0,0 +1,87 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.api.errors import SynapseError, Codes
+
+
+class EventValidator(object):
+    def __init__(self, hs):
+        pass
+
+    def validate(self, event):
+        """Checks the given JSON content abides by the rules of the template.
+
+        Args:
+            content : A JSON object to check.
+            raises: True to raise a SynapseError if the check fails.
+        Returns:
+            True if the content passes the template. Returns False if the check
+            fails and raises=False.
+        Raises:
+            SynapseError if the check fails and raises=True.
+        """
+        # recursively call to inspect each layer
+        err_msg = self._check_json_template(
+            event.content,
+            event.get_content_template()
+        )
+        if err_msg:
+            raise SynapseError(400, err_msg, Codes.BAD_JSON)
+        else:
+            return True
+
+    def _check_json_template(self, content, template):
+        """Check content and template matches.
+
+        If the template is a dict, each key in the dict will be validated with
+        the content, else it will just compare the types of content and
+        template. This basic type check is required because this function will
+        be recursively called and could be called with just strs or ints.
+
+        Args:
+            content: The content to validate.
+            template: The validation template.
+        Returns:
+            str: An error message if the validation fails, else None.
+        """
+        if type(content) != type(template):
+            return "Mismatched types: %s" % template
+
+        if type(template) == dict:
+            for key in template:
+                if key not in content:
+                    return "Missing %s key" % key
+
+                if type(content[key]) != type(template[key]):
+                    return "Key %s is of the wrong type (got %s, want %s)" % (
+                        key, type(content[key]), type(template[key]))
+
+                if type(content[key]) == dict:
+                    # we must go deeper
+                    msg = self._check_json_template(
+                        content[key],
+                        template[key]
+                    )
+                    if msg:
+                        return msg
+                elif type(content[key]) == list:
+                    # make sure each item type in content matches the template
+                    for entry in content[key]:
+                        msg = self._check_json_template(
+                            entry,
+                            template[key][0]
+                        )
+                        if msg:
+                            return msg
\ No newline at end of file
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index b3dae5da64..43164c8d67 100755
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -236,7 +236,10 @@ def setup():
         f.namespace['hs'] = hs
         reactor.listenTCP(config.manhole, f, interface='127.0.0.1')
 
-    hs.start_listening(config.bind_port, config.unsecure_port)
+    bind_port = config.bind_port
+    if config.no_tls:
+        bind_port = None
+    hs.start_listening(bind_port, config.unsecure_port)
 
     if config.daemonize:
         print config.pid_file
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 3afda12d5a..814a4c349b 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -30,6 +30,7 @@ class ServerConfig(Config):
         self.pid_file = self.abspath(args.pid_file)
         self.webclient = True
         self.manhole = args.manhole
+        self.no_tls = args.no_tls
 
         if not args.content_addr:
             host = args.server_name
@@ -67,6 +68,8 @@ class ServerConfig(Config):
         server_group.add_argument("--content-addr", default=None,
                                   help="The host and scheme to use for the "
                                   "content repository")
+        server_group.add_argument("--no-tls", action='store_true',
+                                  help="Don't bind to the https port.")
 
     def read_signing_key(self, signing_key_path):
         signing_keys = self.read_file(signing_key_path, "signing_key")
diff --git a/synapse/crypto/event_signing.py b/synapse/crypto/event_signing.py
new file mode 100644
index 0000000000..baa93b0ee4
--- /dev/null
+++ b/synapse/crypto/event_signing.py
@@ -0,0 +1,98 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2014 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from synapse.api.events.utils import prune_event
+from syutil.jsonutil import encode_canonical_json
+from syutil.base64util import encode_base64, decode_base64
+from syutil.crypto.jsonsign import sign_json
+
+import hashlib
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+def check_event_content_hash(event, hash_algorithm=hashlib.sha256):
+    """Check whether the hash for this PDU matches the contents"""
+    computed_hash = _compute_content_hash(event, hash_algorithm)
+    if computed_hash.name not in event.hashes:
+        raise Exception("Algorithm %s not in hashes %s" % (
+            computed_hash.name, list(event.hashes)
+        ))
+    message_hash_base64 = event.hashes[computed_hash.name]
+    try:
+        message_hash_bytes = decode_base64(message_hash_base64)
+    except:
+        raise Exception("Invalid base64: %s" % (message_hash_base64,))
+    return message_hash_bytes == computed_hash.digest()
+
+
+def _compute_content_hash(event, hash_algorithm):
+    event_json = event.get_full_dict()
+    # TODO: We need to sign the JSON that is going out via fedaration.
+    event_json.pop("age_ts", None)
+    event_json.pop("unsigned", None)
+    event_json.pop("signatures", None)
+    event_json.pop("hashes", None)
+    event_json_bytes = encode_canonical_json(event_json)
+    return hash_algorithm(event_json_bytes)
+
+
+def compute_event_reference_hash(event, hash_algorithm=hashlib.sha256):
+    tmp_event = prune_event(event)
+    event_json = tmp_event.get_dict()
+    event_json.pop("signatures", None)
+    event_json.pop("age_ts", None)
+    event_json.pop("unsigned", None)
+    event_json_bytes = encode_canonical_json(event_json)
+    hashed = hash_algorithm(event_json_bytes)
+    return (hashed.name, hashed.digest())
+
+
+def compute_event_signature(event, signature_name, signing_key):
+    tmp_event = prune_event(event)
+    redact_json = tmp_event.get_full_dict()
+    redact_json.pop("signatures", None)
+    redact_json.pop("age_ts", None)
+    redact_json.pop("unsigned", None)
+    logger.debug("Signing event: %s", redact_json)
+    redact_json = sign_json(redact_json, signature_name, signing_key)
+    return redact_json["signatures"]
+
+
+def add_hashes_and_signatures(event, signature_name, signing_key,
+                              hash_algorithm=hashlib.sha256):
+    if hasattr(event, "old_state_events"):
+        state_json_bytes = encode_canonical_json(
+            [e.event_id for e in event.old_state_events.values()]
+        )
+        hashed = hash_algorithm(state_json_bytes)
+        event.state_hash = {
+            hashed.name: encode_base64(hashed.digest())
+        }
+
+    hashed = _compute_content_hash(event, hash_algorithm=hash_algorithm)
+
+    if not hasattr(event, "hashes"):
+        event.hashes = {}
+    event.hashes[hashed.name] = encode_base64(hashed.digest())
+
+    event.signatures = compute_event_signature(
+        event,
+        signature_name=signature_name,
+        signing_key=signing_key,
+    )
diff --git a/synapse/federation/pdu_codec.py b/synapse/federation/pdu_codec.py
index e8180d94fd..52c84efb5b 100644
--- a/synapse/federation/pdu_codec.py
+++ b/synapse/federation/pdu_codec.py
@@ -18,50 +18,25 @@ from .units import Pdu
 import copy
 
 
-def decode_event_id(event_id, server_name):
-    parts = event_id.split("@")
-    if len(parts) < 2:
-        return (event_id, server_name)
-    else:
-        return (parts[0], "".join(parts[1:]))
-
-
-def encode_event_id(pdu_id, origin):
-    return "%s@%s" % (pdu_id, origin)
-
-
 class PduCodec(object):
 
     def __init__(self, hs):
+        self.signing_key = hs.config.signing_key[0]
         self.server_name = hs.hostname
         self.event_factory = hs.get_event_factory()
         self.clock = hs.get_clock()
+        self.hs = hs
 
     def event_from_pdu(self, pdu):
         kwargs = {}
 
-        kwargs["event_id"] = encode_event_id(pdu.pdu_id, pdu.origin)
-        kwargs["room_id"] = pdu.context
-        kwargs["etype"] = pdu.pdu_type
-        kwargs["prev_events"] = [
-            encode_event_id(p[0], p[1]) for p in pdu.prev_pdus
-        ]
-
-        if hasattr(pdu, "prev_state_id") and hasattr(pdu, "prev_state_origin"):
-            kwargs["prev_state"] = encode_event_id(
-                pdu.prev_state_id, pdu.prev_state_origin
-            )
+        kwargs["etype"] = pdu.type
 
         kwargs.update({
             k: v
             for k, v in pdu.get_full_dict().items()
             if k not in [
-                "pdu_id",
-                "context",
-                "pdu_type",
-                "prev_pdus",
-                "prev_state_id",
-                "prev_state_origin",
+                "type",
             ]
         })
 
@@ -70,33 +45,10 @@ class PduCodec(object):
     def pdu_from_event(self, event):
         d = event.get_full_dict()
 
-        d["pdu_id"], d["origin"] = decode_event_id(
-            event.event_id, self.server_name
-        )
-        d["context"] = event.room_id
-        d["pdu_type"] = event.type
-
-        if hasattr(event, "prev_events"):
-            d["prev_pdus"] = [
-                decode_event_id(e, self.server_name)
-                for e in event.prev_events
-            ]
-
-        if hasattr(event, "prev_state"):
-            d["prev_state_id"], d["prev_state_origin"] = (
-                decode_event_id(event.prev_state, self.server_name)
-            )
-
-        if hasattr(event, "state_key"):
-            d["is_state"] = True
-
         kwargs = copy.deepcopy(event.unrecognized_keys)
         kwargs.update({
             k: v for k, v in d.items()
-            if k not in ["event_id", "room_id", "type", "prev_events"]
         })
 
-        if "origin_server_ts" not in kwargs:
-            kwargs["origin_server_ts"] = int(self.clock.time_msec())
-
-        return Pdu(**kwargs)
+        pdu = Pdu(**kwargs)
+        return pdu
diff --git a/synapse/federation/persistence.py b/synapse/federation/persistence.py
index 7043fcc504..73dc844d59 100644
--- a/synapse/federation/persistence.py
+++ b/synapse/federation/persistence.py
@@ -21,8 +21,6 @@ These actions are mostly only used by the :py:mod:`.replication` module.
 
 from twisted.internet import defer
 
-from .units import Pdu
-
 from synapse.util.logutils import log_function
 
 import json
@@ -32,76 +30,6 @@ import logging
 logger = logging.getLogger(__name__)
 
 
-class PduActions(object):
-    """ Defines persistence actions that relate to handling PDUs.
-    """
-
-    def __init__(self, datastore):
-        self.store = datastore
-
-    @log_function
-    def mark_as_processed(self, pdu):
-        """ Persist the fact that we have fully processed the given `Pdu`
-
-        Returns:
-            Deferred
-        """
-        return self.store.mark_pdu_as_processed(pdu.pdu_id, pdu.origin)
-
-    @defer.inlineCallbacks
-    @log_function
-    def after_transaction(self, transaction_id, destination, origin):
-        """ Returns all `Pdu`s that we sent to the given remote home server
-        after a given transaction id.
-
-        Returns:
-            Deferred: Results in a list of `Pdu`s
-        """
-        results = yield self.store.get_pdus_after_transaction(
-            transaction_id,
-            destination
-        )
-
-        defer.returnValue([Pdu.from_pdu_tuple(p) for p in results])
-
-    @defer.inlineCallbacks
-    @log_function
-    def get_all_pdus_from_context(self, context):
-        results = yield self.store.get_all_pdus_from_context(context)
-        defer.returnValue([Pdu.from_pdu_tuple(p) for p in results])
-
-    @defer.inlineCallbacks
-    @log_function
-    def backfill(self, context, pdu_list, limit):
-        """ For a given list of PDU id and origins return the proceeding
-        `limit` `Pdu`s in the given `context`.
-
-        Returns:
-            Deferred: Results in a list of `Pdu`s.
-        """
-        results = yield self.store.get_backfill(
-            context, pdu_list, limit
-        )
-
-        defer.returnValue([Pdu.from_pdu_tuple(p) for p in results])
-
-    @log_function
-    def is_new(self, pdu):
-        """ When we receive a `Pdu` from a remote home server, we want to
-        figure out whether it is `new`, i.e. it is not some historic PDU that
-        we haven't seen simply because we haven't backfilled back that far.
-
-        Returns:
-            Deferred: Results in a `bool`
-        """
-        return self.store.is_pdu_new(
-            pdu_id=pdu.pdu_id,
-            origin=pdu.origin,
-            context=pdu.context,
-            depth=pdu.depth
-        )
-
-
 class TransactionActions(object):
     """ Defines persistence actions that relate to handling Transactions.
     """
@@ -158,7 +86,6 @@ class TransactionActions(object):
             transaction.transaction_id,
             transaction.destination,
             transaction.origin_server_ts,
-            [(p["pdu_id"], p["origin"]) for p in transaction.pdus]
         )
 
     @log_function
diff --git a/synapse/federation/replication.py b/synapse/federation/replication.py
index 092411eaf9..5c625ddabf 100644
--- a/synapse/federation/replication.py
+++ b/synapse/federation/replication.py
@@ -21,7 +21,7 @@ from twisted.internet import defer
 
 from .units import Transaction, Pdu, Edu
 
-from .persistence import PduActions, TransactionActions
+from .persistence import TransactionActions
 
 from synapse.util.logutils import log_function
 
@@ -57,7 +57,7 @@ class ReplicationLayer(object):
         self.transport_layer.register_request_handler(self)
 
         self.store = hs.get_datastore()
-        self.pdu_actions = PduActions(self.store)
+        # self.pdu_actions = PduActions(self.store)
         self.transaction_actions = TransactionActions(self.store)
 
         self._transaction_queue = _TransactionQueue(
@@ -81,7 +81,7 @@ class ReplicationLayer(object):
 
     def register_edu_handler(self, edu_type, handler):
         if edu_type in self.edu_handlers:
-            raise KeyError("Already have an EDU handler for %s" % (edu_type))
+            raise KeyError("Already have an EDU handler for %s" % (edu_type,))
 
         self.edu_handlers[edu_type] = handler
 
@@ -102,24 +102,17 @@ class ReplicationLayer(object):
           object to encode as JSON.
         """
         if query_type in self.query_handlers:
-            raise KeyError("Already have a Query handler for %s" % (query_type))
+            raise KeyError(
+                "Already have a Query handler for %s" % (query_type,)
+            )
 
         self.query_handlers[query_type] = handler
 
-    @defer.inlineCallbacks
     @log_function
     def send_pdu(self, pdu):
         """Informs the replication layer about a new PDU generated within the
         home server that should be transmitted to others.
 
-        This will fill out various attributes on the PDU object, e.g. the
-        `prev_pdus` key.
-
-        *Note:* The home server should always call `send_pdu` even if it knows
-        that it does not need to be replicated to other home servers. This is
-        in case e.g. someone else joins via a remote home server and then
-        backfills.
-
         TODO: Figure out when we should actually resolve the deferred.
 
         Args:
@@ -132,18 +125,15 @@ class ReplicationLayer(object):
         order = self._order
         self._order += 1
 
-        logger.debug("[%s] Persisting PDU", pdu.pdu_id)
-
-        # Save *before* trying to send
-        yield self.store.persist_event(pdu=pdu)
-
-        logger.debug("[%s] Persisted PDU", pdu.pdu_id)
-        logger.debug("[%s] transaction_layer.enqueue_pdu... ", pdu.pdu_id)
+        logger.debug("[%s] transaction_layer.enqueue_pdu... ", pdu.event_id)
 
         # TODO, add errback, etc.
         self._transaction_queue.enqueue_pdu(pdu, order)
 
-        logger.debug("[%s] transaction_layer.enqueue_pdu... done", pdu.pdu_id)
+        logger.debug(
+            "[%s] transaction_layer.enqueue_pdu... done",
+            pdu.event_id
+        )
 
     @log_function
     def send_edu(self, destination, edu_type, content):
@@ -159,6 +149,11 @@ class ReplicationLayer(object):
         return defer.succeed(None)
 
     @log_function
+    def send_failure(self, failure, destination):
+        self._transaction_queue.enqueue_failure(failure, destination)
+        return defer.succeed(None)
+
+    @log_function
     def make_query(self, destination, query_type, args,
                    retry_on_dns_fail=True):
         """Sends a federation Query to a remote homeserver of the given type
@@ -181,7 +176,7 @@ class ReplicationLayer(object):
 
     @defer.inlineCallbacks
     @log_function
-    def backfill(self, dest, context, limit):
+    def backfill(self, dest, context, limit, extremities):
         """Requests some more historic PDUs for the given context from the
         given destination server.
 
@@ -189,12 +184,12 @@ class ReplicationLayer(object):
             dest (str): The remote home server to ask.
             context (str): The context to backfill.
             limit (int): The maximum number of PDUs to return.
+            extremities (list): List of PDU id and origins of the first pdus
+                we have seen from the context
 
         Returns:
             Deferred: Results in the received PDUs.
         """
-        extremities = yield self.store.get_oldest_pdus_in_context(context)
-
         logger.debug("backfill extrem=%s", extremities)
 
         # If there are no extremeties then we've (probably) reached the start.
@@ -210,13 +205,13 @@ class ReplicationLayer(object):
 
         pdus = [Pdu(outlier=False, **p) for p in transaction.pdus]
         for pdu in pdus:
-            yield self._handle_new_pdu(pdu, backfilled=True)
+            yield self._handle_new_pdu(dest, pdu, backfilled=True)
 
         defer.returnValue(pdus)
 
     @defer.inlineCallbacks
     @log_function
-    def get_pdu(self, destination, pdu_origin, pdu_id, outlier=False):
+    def get_pdu(self, destination, event_id, outlier=False):
         """Requests the PDU with given origin and ID from the remote home
         server.
 
@@ -225,7 +220,7 @@ class ReplicationLayer(object):
         Args:
             destination (str): Which home server to query
             pdu_origin (str): The home server that originally sent the pdu.
-            pdu_id (str)
+            event_id (str)
             outlier (bool): Indicates whether the PDU is an `outlier`, i.e. if
                 it's from an arbitary point in the context as opposed to part
                 of the current block of PDUs. Defaults to `False`
@@ -234,8 +229,9 @@ class ReplicationLayer(object):
             Deferred: Results in the requested PDU.
         """
 
-        transaction_data = yield self.transport_layer.get_pdu(
-            destination, pdu_origin, pdu_id)
+        transaction_data = yield self.transport_layer.get_event(
+            destination, event_id
+        )
 
         transaction = Transaction(**transaction_data)
 
@@ -244,13 +240,13 @@ class ReplicationLayer(object):
         pdu = None
         if pdu_list:
             pdu = pdu_list[0]
-            yield self._handle_new_pdu(pdu)
+            yield self._handle_new_pdu(destination, pdu)
 
         defer.returnValue(pdu)
 
     @defer.inlineCallbacks
     @log_function
-    def get_state_for_context(self, destination, context):
+    def get_state_for_context(self, destination, context, event_id=None):
         """Requests all of the `current` state PDUs for a given context from
         a remote home server.
 
@@ -263,29 +259,25 @@ class ReplicationLayer(object):
         """
 
         transaction_data = yield self.transport_layer.get_context_state(
-            destination, context)
+            destination,
+            context,
+            event_id=event_id,
+        )
 
         transaction = Transaction(**transaction_data)
 
         pdus = [Pdu(outlier=True, **p) for p in transaction.pdus]
         for pdu in pdus:
-            yield self._handle_new_pdu(pdu)
+            yield self._handle_new_pdu(destination, pdu)
 
         defer.returnValue(pdus)
 
     @defer.inlineCallbacks
     @log_function
-    def on_context_pdus_request(self, context):
-        pdus = yield self.pdu_actions.get_all_pdus_from_context(
-            context
+    def on_backfill_request(self, origin, context, versions, limit):
+        pdus = yield self.handler.on_backfill_request(
+            origin, context, versions, limit
         )
-        defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict()))
-
-    @defer.inlineCallbacks
-    @log_function
-    def on_backfill_request(self, context, versions, limit):
-
-        pdus = yield self.pdu_actions.backfill(context, versions, limit)
 
         defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict()))
 
@@ -295,6 +287,10 @@ class ReplicationLayer(object):
         transaction = Transaction(**transaction_data)
 
         for p in transaction.pdus:
+            if "unsigned" in p:
+                unsigned = p["unsigned"]
+                if "age" in unsigned:
+                    p["age"] = unsigned["age"]
             if "age" in p:
                 p["age_ts"] = int(self._clock.time_msec()) - int(p["age"])
                 del p["age"]
@@ -315,11 +311,15 @@ class ReplicationLayer(object):
 
         dl = []
         for pdu in pdu_list:
-            dl.append(self._handle_new_pdu(pdu))
+            dl.append(self._handle_new_pdu(transaction.origin, pdu))
 
         if hasattr(transaction, "edus"):
             for edu in [Edu(**x) for x in transaction.edus]:
-                self.received_edu(transaction.origin, edu.edu_type, edu.content)
+                self.received_edu(
+                    transaction.origin,
+                    edu.edu_type,
+                    edu.content
+                )
 
         results = yield defer.DeferredList(dl)
 
@@ -347,20 +347,22 @@ class ReplicationLayer(object):
 
     @defer.inlineCallbacks
     @log_function
-    def on_context_state_request(self, context):
-        results = yield self.store.get_current_state_for_context(
-            context
-        )
-
-        logger.debug("Context returning %d results", len(results))
+    def on_context_state_request(self, origin, context, event_id):
+        if event_id:
+            pdus = yield self.handler.get_state_for_pdu(
+                origin,
+                context,
+                event_id,
+            )
+        else:
+            raise NotImplementedError("Specify an event")
 
-        pdus = [Pdu.from_pdu_tuple(p) for p in results]
         defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict()))
 
     @defer.inlineCallbacks
     @log_function
-    def on_pdu_request(self, pdu_origin, pdu_id):
-        pdu = yield self._get_persisted_pdu(pdu_id, pdu_origin)
+    def on_pdu_request(self, origin, event_id):
+        pdu = yield self._get_persisted_pdu(origin, event_id)
 
         if pdu:
             defer.returnValue(
@@ -372,103 +374,191 @@ class ReplicationLayer(object):
     @defer.inlineCallbacks
     @log_function
     def on_pull_request(self, origin, versions):
-        transaction_id = max([int(v) for v in versions])
+        raise NotImplementedError("Pull transacions not implemented")
+
+    @defer.inlineCallbacks
+    def on_query_request(self, query_type, args):
+        if query_type in self.query_handlers:
+            response = yield self.query_handlers[query_type](args)
+            defer.returnValue((200, response))
+        else:
+            defer.returnValue(
+                (404, "No handler for Query type '%s'" % (query_type, ))
+            )
+
+    @defer.inlineCallbacks
+    def on_make_join_request(self, context, user_id):
+        pdu = yield self.handler.on_make_join_request(context, user_id)
+        defer.returnValue({
+            "event": pdu.get_dict(),
+        })
 
-        response = yield self.pdu_actions.after_transaction(
-            transaction_id,
-            origin,
-            self.server_name
+    @defer.inlineCallbacks
+    def on_invite_request(self, origin, content):
+        pdu = Pdu(**content)
+        ret_pdu = yield self.handler.on_invite_request(origin, pdu)
+        defer.returnValue(
+            (
+                200,
+                {
+                    "event": ret_pdu.get_dict(),
+                }
+            )
         )
 
-        if not response:
-            response = []
+    @defer.inlineCallbacks
+    def on_send_join_request(self, origin, content):
+        pdu = Pdu(**content)
+        res_pdus = yield self.handler.on_send_join_request(origin, pdu)
+
+        defer.returnValue((200, {
+            "state": [p.get_dict() for p in res_pdus["state"]],
+            "auth_chain": [p.get_dict() for p in res_pdus["auth_chain"]],
+        }))
 
+    @defer.inlineCallbacks
+    def on_event_auth(self, origin, context, event_id):
+        auth_pdus = yield self.handler.on_event_auth(event_id)
         defer.returnValue(
-            (200, self._transaction_from_pdus(response).get_dict())
+            (
+                200,
+                {
+                    "auth_chain": [a.get_dict() for a in auth_pdus],
+                }
+            )
         )
 
     @defer.inlineCallbacks
-    def on_query_request(self, query_type, args):
-        if query_type in self.query_handlers:
-            response = yield self.query_handlers[query_type](args)
-            defer.returnValue((200, response))
-        else:
-            defer.returnValue((404, "No handler for Query type '%s'"
-                % (query_type)
-            ))
+    def make_join(self, destination, context, user_id):
+        ret = yield self.transport_layer.make_join(
+            destination=destination,
+            context=context,
+            user_id=user_id,
+        )
+
+        pdu_dict = ret["event"]
+
+        logger.debug("Got response to make_join: %s", pdu_dict)
+
+        defer.returnValue(Pdu(**pdu_dict))
 
     @defer.inlineCallbacks
+    def send_join(self, destination, pdu):
+        _, content = yield self.transport_layer.send_join(
+            destination,
+            pdu.room_id,
+            pdu.event_id,
+            pdu.get_dict(),
+        )
+
+        logger.debug("Got content: %s", content)
+        state = [Pdu(outlier=True, **p) for p in content.get("state", [])]
+        for pdu in state:
+            yield self._handle_new_pdu(destination, pdu)
+
+        auth_chain = [
+            Pdu(outlier=True, **p) for p in content.get("auth_chain", [])
+        ]
+        for pdu in auth_chain:
+            yield self._handle_new_pdu(destination, pdu)
+
+        defer.returnValue(state)
+
+    @defer.inlineCallbacks
+    def send_invite(self, destination, context, event_id, pdu):
+        code, content = yield self.transport_layer.send_invite(
+            destination=destination,
+            context=context,
+            event_id=event_id,
+            content=pdu.get_dict(),
+        )
+
+        pdu_dict = content["event"]
+
+        logger.debug("Got response to send_invite: %s", pdu_dict)
+
+        defer.returnValue(Pdu(**pdu_dict))
+
     @log_function
-    def _get_persisted_pdu(self, pdu_id, pdu_origin):
+    def _get_persisted_pdu(self, origin, event_id):
         """ Get a PDU from the database with given origin and id.
 
         Returns:
             Deferred: Results in a `Pdu`.
         """
-        pdu_tuple = yield self.store.get_pdu(pdu_id, pdu_origin)
-
-        defer.returnValue(Pdu.from_pdu_tuple(pdu_tuple))
+        return self.handler.get_persisted_pdu(origin, event_id)
 
     def _transaction_from_pdus(self, pdu_list):
         """Returns a new Transaction containing the given PDUs suitable for
         transmission.
         """
         pdus = [p.get_dict() for p in pdu_list]
+        time_now = self._clock.time_msec()
         for p in pdus:
-            if "age_ts" in pdus:
-                p["age"] = int(self.clock.time_msec()) - p["age_ts"]
-
+            if "age_ts" in p:
+                age = time_now - p["age_ts"]
+                p.setdefault("unsigned", {})["age"] = int(age)
+                del p["age_ts"]
         return Transaction(
             origin=self.server_name,
             pdus=pdus,
-            origin_server_ts=int(self._clock.time_msec()),
+            origin_server_ts=int(time_now),
             destination=None,
         )
 
     @defer.inlineCallbacks
     @log_function
-    def _handle_new_pdu(self, pdu, backfilled=False):
+    def _handle_new_pdu(self, origin, pdu, backfilled=False):
         # We reprocess pdus when we have seen them only as outliers
-        existing = yield self._get_persisted_pdu(pdu.pdu_id, pdu.origin)
+        existing = yield self._get_persisted_pdu(origin, pdu.event_id)
 
         if existing and (not existing.outlier or pdu.outlier):
-            logger.debug("Already seen pdu %s %s", pdu.pdu_id, pdu.origin)
+            logger.debug("Already seen pdu %s", pdu.event_id)
             defer.returnValue({})
             return
 
+        state = None
+
         # Get missing pdus if necessary.
-        is_new = yield self.pdu_actions.is_new(pdu)
-        if is_new and not pdu.outlier:
+        if not pdu.outlier:
             # We only backfill backwards to the min depth.
-            min_depth = yield self.store.get_min_depth_for_context(pdu.context)
+            min_depth = yield self.handler.get_min_depth_for_context(
+                pdu.room_id
+            )
 
             if min_depth and pdu.depth > min_depth:
-                for pdu_id, origin in pdu.prev_pdus:
-                    exists = yield self._get_persisted_pdu(pdu_id, origin)
+                for event_id, hashes in pdu.prev_events:
+                    exists = yield self._get_persisted_pdu(origin, event_id)
 
                     if not exists:
-                        logger.debug("Requesting pdu %s %s", pdu_id, origin)
+                        logger.debug("Requesting pdu %s", event_id)
 
                         try:
                             yield self.get_pdu(
                                 pdu.origin,
-                                pdu_id=pdu_id,
-                                pdu_origin=origin
+                                event_id=event_id,
                             )
-                            logger.debug("Processed pdu %s %s", pdu_id, origin)
+                            logger.debug("Processed pdu %s", event_id)
                         except:
                             # TODO(erikj): Do some more intelligent retries.
                             logger.exception("Failed to get PDU")
-
-        # Persist the Pdu, but don't mark it as processed yet.
-        yield self.store.persist_event(pdu=pdu)
+            else:
+                # We need to get the state at this event, since we have reached
+                # a backward extremity edge.
+                state = yield self.get_state_for_context(
+                    origin, pdu.room_id, pdu.event_id,
+                )
 
         if not backfilled:
-            ret = yield self.handler.on_receive_pdu(pdu, backfilled=backfilled)
+            ret = yield self.handler.on_receive_pdu(
+                pdu,
+                backfilled=backfilled,
+                state=state,
+            )
         else:
             ret = None
 
-        yield self.pdu_actions.mark_as_processed(pdu)
+        # yield self.pdu_actions.mark_as_processed(pdu)
 
         defer.returnValue(ret)
 
@@ -476,14 +566,6 @@ class ReplicationLayer(object):
         return "<ReplicationLayer(%s)>" % self.server_name
 
 
-class ReplicationHandler(object):
-    """This defines the methods that the :py:class:`.ReplicationLayer` will
-    use to communicate with the rest of the home server.
-    """
-    def on_receive_pdu(self, pdu):
-        raise NotImplementedError("on_receive_pdu")
-
-
 class _TransactionQueue(object):
     """This class makes sure we only have one transaction in flight at
     a time for a given destination.
@@ -509,6 +591,9 @@ class _TransactionQueue(object):
         # destination -> list of tuple(edu, deferred)
         self.pending_edus_by_dest = {}
 
+        # destination -> list of tuple(failure, deferred)
+        self.pending_failures_by_dest = {}
+
         # HACK to get unique tx id
         self._next_txn_id = int(self._clock.time_msec())
 
@@ -562,6 +647,18 @@ class _TransactionQueue(object):
         return deferred
 
     @defer.inlineCallbacks
+    def enqueue_failure(self, failure, destination):
+        deferred = defer.Deferred()
+
+        self.pending_failures_by_dest.setdefault(
+            destination, []
+        ).append(
+            (failure, deferred)
+        )
+
+        yield deferred
+
+    @defer.inlineCallbacks
     @log_function
     def _attempt_new_transaction(self, destination):
         if destination in self.pending_transactions:
@@ -570,8 +667,9 @@ class _TransactionQueue(object):
         #  list of (pending_pdu, deferred, order)
         pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
         pending_edus = self.pending_edus_by_dest.pop(destination, [])
+        pending_failures = self.pending_failures_by_dest.pop(destination, [])
 
-        if not pending_pdus and not pending_edus:
+        if not pending_pdus and not pending_edus and not pending_failures:
             return
 
         logger.debug("TX [%s] Attempting new transaction", destination)
@@ -581,7 +679,11 @@ class _TransactionQueue(object):
 
         pdus = [x[0] for x in pending_pdus]
         edus = [x[0] for x in pending_edus]
-        deferreds = [x[1] for x in pending_pdus + pending_edus]
+        failures = [x[0].get_dict() for x in pending_failures]
+        deferreds = [
+            x[1]
+            for x in pending_pdus + pending_edus + pending_failures
+        ]
 
         try:
             self.pending_transactions[destination] = 1
@@ -589,12 +691,13 @@ class _TransactionQueue(object):
             logger.debug("TX [%s] Persisting transaction...", destination)
 
             transaction = Transaction.create_new(
-                origin_server_ts=self._clock.time_msec(),
+                origin_server_ts=int(self._clock.time_msec()),
                 transaction_id=str(self._next_txn_id),
                 origin=self.server_name,
                 destination=destination,
                 pdus=pdus,
                 edus=edus,
+                pdu_failures=failures,
             )
 
             self._next_txn_id += 1
@@ -614,7 +717,9 @@ class _TransactionQueue(object):
                 if "pdus" in data:
                     for p in data["pdus"]:
                         if "age_ts" in p:
-                            p["age"] = now - int(p["age_ts"])
+                            unsigned = p.setdefault("unsigned", {})
+                            unsigned["age"] = now - int(p["age_ts"])
+                            del p["age_ts"]
                 return data
 
             code, response = yield self.transport_layer.send_transaction(
diff --git a/synapse/federation/transport.py b/synapse/federation/transport.py
index e7517cac4d..95c40c6c1b 100644
--- a/synapse/federation/transport.py
+++ b/synapse/federation/transport.py
@@ -72,7 +72,7 @@ class TransportLayer(object):
         self.received_handler = None
 
     @log_function
-    def get_context_state(self, destination, context):
+    def get_context_state(self, destination, context, event_id=None):
         """ Requests all state for a given context (i.e. room) from the
         given server.
 
@@ -89,54 +89,62 @@ class TransportLayer(object):
 
         subpath = "/state/%s/" % context
 
-        return self._do_request_for_transaction(destination, subpath)
+        args = {}
+        if event_id:
+            args["event_id"] = event_id
+
+        return self._do_request_for_transaction(
+            destination, subpath, args=args
+        )
 
     @log_function
-    def get_pdu(self, destination, pdu_origin, pdu_id):
+    def get_event(self, destination, event_id):
         """ Requests the pdu with give id and origin from the given server.
 
         Args:
             destination (str): The host name of the remote home server we want
                 to get the state from.
-            pdu_origin (str): The home server which created the PDU.
-            pdu_id (str): The id of the PDU being requested.
+            event_id (str): The id of the event being requested.
 
         Returns:
             Deferred: Results in a dict received from the remote homeserver.
         """
-        logger.debug("get_pdu dest=%s, pdu_origin=%s, pdu_id=%s",
-                     destination, pdu_origin, pdu_id)
+        logger.debug("get_pdu dest=%s, event_id=%s",
+                     destination, event_id)
 
-        subpath = "/pdu/%s/%s/" % (pdu_origin, pdu_id)
+        subpath = "/event/%s/" % (event_id, )
 
         return self._do_request_for_transaction(destination, subpath)
 
     @log_function
-    def backfill(self, dest, context, pdu_tuples, limit):
+    def backfill(self, dest, context, event_tuples, limit):
         """ Requests `limit` previous PDUs in a given context before list of
         PDUs.
 
         Args:
             dest (str)
             context (str)
-            pdu_tuples (list)
+            event_tuples (list)
             limt (int)
 
         Returns:
             Deferred: Results in a dict received from the remote homeserver.
         """
         logger.debug(
-            "backfill dest=%s, context=%s, pdu_tuples=%s, limit=%s",
-            dest, context, repr(pdu_tuples), str(limit)
+            "backfill dest=%s, context=%s, event_tuples=%s, limit=%s",
+            dest, context, repr(event_tuples), str(limit)
         )
 
-        if not pdu_tuples:
+        if not event_tuples:
+            # TODO: raise?
             return
 
-        subpath = "/backfill/%s/" % context
+        subpath = "/backfill/%s/" % (context,)
 
-        args = {"v": ["%s,%s" % (i, o) for i, o in pdu_tuples]}
-        args["limit"] = limit
+        args = {
+            "v": event_tuples,
+            "limit": limit,
+        }
 
         return self._do_request_for_transaction(
             dest,
@@ -198,6 +206,72 @@ class TransportLayer(object):
         defer.returnValue(response)
 
     @defer.inlineCallbacks
+    @log_function
+    def make_join(self, destination, context, user_id, retry_on_dns_fail=True):
+        path = PREFIX + "/make_join/%s/%s" % (context, user_id,)
+
+        response = yield self.client.get_json(
+            destination=destination,
+            path=path,
+            retry_on_dns_fail=retry_on_dns_fail,
+        )
+
+        defer.returnValue(response)
+
+    @defer.inlineCallbacks
+    @log_function
+    def send_join(self, destination, context, event_id, content):
+        path = PREFIX + "/send_join/%s/%s" % (
+            context,
+            event_id,
+        )
+
+        code, content = yield self.client.put_json(
+            destination=destination,
+            path=path,
+            data=content,
+        )
+
+        if not 200 <= code < 300:
+            raise RuntimeError("Got %d from send_join", code)
+
+        defer.returnValue(json.loads(content))
+
+    @defer.inlineCallbacks
+    @log_function
+    def send_invite(self, destination, context, event_id, content):
+        path = PREFIX + "/invite/%s/%s" % (
+            context,
+            event_id,
+        )
+
+        code, content = yield self.client.put_json(
+            destination=destination,
+            path=path,
+            data=content,
+        )
+
+        if not 200 <= code < 300:
+            raise RuntimeError("Got %d from send_invite", code)
+
+        defer.returnValue(json.loads(content))
+
+    @defer.inlineCallbacks
+    @log_function
+    def get_event_auth(self, destination, context, event_id):
+        path = PREFIX + "/event_auth/%s/%s" % (
+            context,
+            event_id,
+        )
+
+        response = yield self.client.get_json(
+            destination=destination,
+            path=path,
+        )
+
+        defer.returnValue(response)
+
+    @defer.inlineCallbacks
     def _authenticate_request(self, request):
         json_request = {
             "method": request.method,
@@ -210,7 +284,7 @@ class TransportLayer(object):
         origin = None
 
         if request.method == "PUT":
-            #TODO: Handle other method types? other content types?
+            # TODO: Handle other method types? other content types?
             try:
                 content_bytes = request.content.read()
                 content = json.loads(content_bytes)
@@ -222,11 +296,13 @@ class TransportLayer(object):
             try:
                 params = auth.split(" ")[1].split(",")
                 param_dict = dict(kv.split("=") for kv in params)
+
                 def strip_quotes(value):
                     if value.startswith("\""):
                         return value[1:-1]
                     else:
                         return value
+
                 origin = strip_quotes(param_dict["origin"])
                 key = strip_quotes(param_dict["key"])
                 sig = strip_quotes(param_dict["sig"])
@@ -247,7 +323,7 @@ class TransportLayer(object):
             if auth.startswith("X-Matrix"):
                 (origin, key, sig) = parse_auth_header(auth)
                 json_request["origin"] = origin
-                json_request["signatures"].setdefault(origin,{})[key] = sig
+                json_request["signatures"].setdefault(origin, {})[key] = sig
 
         if not json_request["signatures"]:
             raise SynapseError(
@@ -313,10 +389,10 @@ class TransportLayer(object):
         # data_id pair.
         self.server.register_path(
             "GET",
-            re.compile("^" + PREFIX + "/pdu/([^/]*)/([^/]*)/$"),
+            re.compile("^" + PREFIX + "/event/([^/]*)/$"),
             self._with_authentication(
-                lambda origin, content, query, pdu_origin, pdu_id:
-                handler.on_pdu_request(pdu_origin, pdu_id)
+                lambda origin, content, query, event_id:
+                handler.on_pdu_request(origin, event_id)
             )
         )
 
@@ -326,7 +402,11 @@ class TransportLayer(object):
             re.compile("^" + PREFIX + "/state/([^/]*)/$"),
             self._with_authentication(
                 lambda origin, content, query, context:
-                handler.on_context_state_request(context)
+                handler.on_context_state_request(
+                    origin,
+                    context,
+                    query.get("event_id", [None])[0],
+                )
             )
         )
 
@@ -336,28 +416,63 @@ class TransportLayer(object):
             self._with_authentication(
                 lambda origin, content, query, context:
                 self._on_backfill_request(
-                    context, query["v"], query["limit"]
+                    origin, context, query["v"], query["limit"]
                 )
             )
         )
 
+        # This is when we receive a server-server Query
         self.server.register_path(
             "GET",
-            re.compile("^" + PREFIX + "/context/([^/]*)/$"),
+            re.compile("^" + PREFIX + "/query/([^/]*)$"),
             self._with_authentication(
-                lambda origin, content, query, context:
-                handler.on_context_pdus_request(context)
+                lambda origin, content, query, query_type:
+                handler.on_query_request(
+                    query_type, {k: v[0] for k, v in query.items()}
+                )
             )
         )
 
-        # This is when we receive a server-server Query
         self.server.register_path(
             "GET",
-            re.compile("^" + PREFIX + "/query/([^/]*)$"),
+            re.compile("^" + PREFIX + "/make_join/([^/]*)/([^/]*)$"),
             self._with_authentication(
-                lambda origin, content, query, query_type:
-                handler.on_query_request(
-                    query_type, {k: v[0] for k, v in query.items()}
+                lambda origin, content, query, context, user_id:
+                self._on_make_join_request(
+                    origin, content, query, context, user_id
+                )
+            )
+        )
+
+        self.server.register_path(
+            "GET",
+            re.compile("^" + PREFIX + "/event_auth/([^/]*)/([^/]*)$"),
+            self._with_authentication(
+                lambda origin, content, query, context, event_id:
+                handler.on_event_auth(
+                    origin, context, event_id,
+                )
+            )
+        )
+
+        self.server.register_path(
+            "PUT",
+            re.compile("^" + PREFIX + "/send_join/([^/]*)/([^/]*)$"),
+            self._with_authentication(
+                lambda origin, content, query, context, event_id:
+                self._on_send_join_request(
+                    origin, content, query,
+                )
+            )
+        )
+
+        self.server.register_path(
+            "PUT",
+            re.compile("^" + PREFIX + "/invite/([^/]*)/([^/]*)$"),
+            self._with_authentication(
+                lambda origin, content, query, context, event_id:
+                self._on_invite_request(
+                    origin, content, query,
                 )
             )
         )
@@ -402,7 +517,8 @@ class TransportLayer(object):
             return
 
         try:
-            code, response = yield self.received_handler.on_incoming_transaction(
+            handler = self.received_handler
+            code, response = yield handler.on_incoming_transaction(
                 transaction_data
             )
         except:
@@ -440,7 +556,7 @@ class TransportLayer(object):
         defer.returnValue(data)
 
     @log_function
-    def _on_backfill_request(self, context, v_list, limits):
+    def _on_backfill_request(self, origin, context, v_list, limits):
         if not limits:
             return defer.succeed(
                 (400, {"error": "Did not include limit param"})
@@ -448,124 +564,34 @@ class TransportLayer(object):
 
         limit = int(limits[-1])
 
-        versions = [v.split(",", 1) for v in v_list]
+        versions = v_list
 
         return self.request_handler.on_backfill_request(
-            context, versions, limit)
-
-
-class TransportReceivedHandler(object):
-    """ Callbacks used when we receive a transaction
-    """
-    def on_incoming_transaction(self, transaction):
-        """ Called on PUT /send/<transaction_id>, or on response to a request
-        that we sent (e.g. a backfill request)
-
-        Args:
-            transaction (synapse.transaction.Transaction): The transaction that
-                was sent to us.
-
-        Returns:
-            twisted.internet.defer.Deferred: A deferred that gets fired when
-            the transaction has finished being processed.
-
-            The result should be a tuple in the form of
-            `(response_code, respond_body)`, where `response_body` is a python
-            dict that will get serialized to JSON.
-
-            On errors, the dict should have an `error` key with a brief message
-            of what went wrong.
-        """
-        pass
-
-
-class TransportRequestHandler(object):
-    """ Handlers used when someone want's data from us
-    """
-    def on_pull_request(self, versions):
-        """ Called on GET /pull/?v=...
-
-        This is hit when a remote home server wants to get all data
-        after a given transaction. Mainly used when a home server comes back
-        online and wants to get everything it has missed.
-
-        Args:
-            versions (list): A list of transaction_ids that should be used to
-                determine what PDUs the remote side have not yet seen.
-
-        Returns:
-            Deferred: Resultsin a tuple in the form of
-            `(response_code, respond_body)`, where `response_body` is a python
-            dict that will get serialized to JSON.
-
-            On errors, the dict should have an `error` key with a brief message
-            of what went wrong.
-        """
-        pass
-
-    def on_pdu_request(self, pdu_origin, pdu_id):
-        """ Called on GET /pdu/<pdu_origin>/<pdu_id>/
-
-        Someone wants a particular PDU. This PDU may or may not have originated
-        from us.
-
-        Args:
-            pdu_origin (str)
-            pdu_id (str)
-
-        Returns:
-            Deferred: Resultsin a tuple in the form of
-            `(response_code, respond_body)`, where `response_body` is a python
-            dict that will get serialized to JSON.
-
-            On errors, the dict should have an `error` key with a brief message
-            of what went wrong.
-        """
-        pass
-
-    def on_context_state_request(self, context):
-        """ Called on GET /state/<context>/
-
-        Gets hit when someone wants all the *current* state for a given
-        contexts.
-
-        Args:
-            context (str): The name of the context that we're interested in.
-
-        Returns:
-            twisted.internet.defer.Deferred: A deferred that gets fired when
-            the transaction has finished being processed.
-
-            The result should be a tuple in the form of
-            `(response_code, respond_body)`, where `response_body` is a python
-            dict that will get serialized to JSON.
-
-            On errors, the dict should have an `error` key with a brief message
-            of what went wrong.
-        """
-        pass
-
-    def on_backfill_request(self, context, versions, limit):
-        """ Called on GET /backfill/<context>/?v=...&limit=...
+            origin, context, versions, limit
+        )
 
-        Gets hit when we want to backfill backwards on a given context from
-        the given point.
+    @defer.inlineCallbacks
+    @log_function
+    def _on_make_join_request(self, origin, content, query, context, user_id):
+        content = yield self.request_handler.on_make_join_request(
+            context, user_id,
+        )
+        defer.returnValue((200, content))
 
-        Args:
-            context (str): The context to backfill
-            versions (list): A list of 2-tuples representing where to backfill
-                from, in the form `(pdu_id, origin)`
-            limit (int): How many pdus to return.
+    @defer.inlineCallbacks
+    @log_function
+    def _on_send_join_request(self, origin, content, query):
+        content = yield self.request_handler.on_send_join_request(
+            origin, content,
+        )
 
-        Returns:
-            Deferred: Results in a tuple in the form of
-            `(response_code, respond_body)`, where `response_body` is a python
-            dict that will get serialized to JSON.
+        defer.returnValue((200, content))
 
-            On errors, the dict should have an `error` key with a brief message
-            of what went wrong.
-        """
-        pass
+    @defer.inlineCallbacks
+    @log_function
+    def _on_invite_request(self, origin, content, query):
+        content = yield self.request_handler.on_invite_request(
+            origin, content,
+        )
 
-    def on_query_request(self):
-        """ Called on a GET /query/<query_type> request. """
+        defer.returnValue((200, content))
diff --git a/synapse/federation/units.py b/synapse/federation/units.py
index b2fb964180..f4e7b62bd9 100644
--- a/synapse/federation/units.py
+++ b/synapse/federation/units.py
@@ -20,8 +20,6 @@ server protocol.
 from synapse.util.jsonobject import JsonEncodedObject
 
 import logging
-import json
-import copy
 
 
 logger = logging.getLogger(__name__)
@@ -33,13 +31,13 @@ class Pdu(JsonEncodedObject):
 
     A Pdu can be classified as "state". For a given context, we can efficiently
     retrieve all state pdu's that haven't been clobbered. Clobbering is done
-    via a unique constraint on the tuple (context, pdu_type, state_key). A pdu
+    via a unique constraint on the tuple (context, type, state_key). A pdu
     is a state pdu if `is_state` is True.
 
     Example pdu::
 
         {
-            "pdu_id": "78c",
+            "event_id": "$78c:example.com",
             "origin_server_ts": 1404835423000,
             "origin": "bar",
             "prev_ids": [
@@ -52,24 +50,21 @@ class Pdu(JsonEncodedObject):
     """
 
     valid_keys = [
-        "pdu_id",
-        "context",
+        "event_id",
+        "room_id",
         "origin",
         "origin_server_ts",
-        "pdu_type",
+        "type",
         "destinations",
-        "transaction_id",
-        "prev_pdus",
+        "prev_events",
         "depth",
         "content",
-        "outlier",
-        "is_state",  # Below this are keys valid only for State Pdus.
-        "state_key",
-        "power_level",
-        "prev_state_id",
-        "prev_state_origin",
-        "required_power_level",
+        "hashes",
         "user_id",
+        "auth_events",
+        "signatures",  # Below this are keys valid only for State Pdus.
+        "state_key",
+        "prev_state",
     ]
 
     internal_keys = [
@@ -79,61 +74,28 @@ class Pdu(JsonEncodedObject):
     ]
 
     required_keys = [
-        "pdu_id",
-        "context",
+        "event_id",
+        "room_id",
         "origin",
         "origin_server_ts",
-        "pdu_type",
+        "type",
         "content",
     ]
 
     # TODO: We need to make this properly load content rather than
     # just leaving it as a dict. (OR DO WE?!)
 
-    def __init__(self, destinations=[], is_state=False, prev_pdus=[],
-                 outlier=False, **kwargs):
-        if is_state:
-            for required_key in ["state_key"]:
-                if required_key not in kwargs:
-                    raise RuntimeError("Key %s is required" % required_key)
-
+    def __init__(self, destinations=[], prev_events=[],
+                 outlier=False, hashes={}, signatures={}, **kwargs):
         super(Pdu, self).__init__(
             destinations=destinations,
-            is_state=is_state,
-            prev_pdus=prev_pdus,
+            prev_events=prev_events,
             outlier=outlier,
+            hashes=hashes,
+            signatures=signatures,
             **kwargs
         )
 
-    @classmethod
-    def from_pdu_tuple(cls, pdu_tuple):
-        """ Converts a PduTuple to a Pdu
-
-        Args:
-            pdu_tuple (synapse.persistence.transactions.PduTuple): The tuple to
-                convert
-
-        Returns:
-            Pdu
-        """
-        if pdu_tuple:
-            d = copy.copy(pdu_tuple.pdu_entry._asdict())
-            d["origin_server_ts"] = d.pop("ts")
-
-            d["content"] = json.loads(d["content_json"])
-            del d["content_json"]
-
-            args = {f: d[f] for f in cls.valid_keys if f in d}
-            if "unrecognized_keys" in d and d["unrecognized_keys"]:
-                args.update(json.loads(d["unrecognized_keys"]))
-
-            return Pdu(
-                prev_pdus=pdu_tuple.prev_pdu_list,
-                **args
-            )
-        else:
-            return None
-
     def __str__(self):
         return "(%s, %s)" % (self.__class__.__name__, repr(self.__dict__))
 
@@ -193,6 +155,7 @@ class Transaction(JsonEncodedObject):
         "edus",
         "transaction_id",
         "destination",
+        "pdu_failures",
     ]
 
     internal_keys = [
@@ -229,7 +192,9 @@ class Transaction(JsonEncodedObject):
         transaction_id and origin_server_ts keys.
         """
         if "origin_server_ts" not in kwargs:
-            raise KeyError("Require 'origin_server_ts' to construct a Transaction")
+            raise KeyError(
+                "Require 'origin_server_ts' to construct a Transaction"
+            )
         if "transaction_id" not in kwargs:
             raise KeyError(
                 "Require 'transaction_id' to construct a Transaction"
@@ -241,6 +206,3 @@ class Transaction(JsonEncodedObject):
         kwargs["pdus"] = [p.get_dict() for p in pdus]
 
         return Transaction(**kwargs)
-
-
-
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index de4d23bbb3..07a8464107 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -14,7 +14,18 @@
 # limitations under the License.
 
 from twisted.internet import defer
+
 from synapse.api.errors import LimitExceededError
+from synapse.util.async import run_on_reactor
+from synapse.crypto.event_signing import add_hashes_and_signatures
+from synapse.api.events.room import RoomMemberEvent
+from synapse.api.constants import Membership
+
+import logging
+
+
+logger = logging.getLogger(__name__)
+
 
 class BaseHandler(object):
 
@@ -30,6 +41,9 @@ class BaseHandler(object):
         self.clock = hs.get_clock()
         self.hs = hs
 
+        self.signing_key = hs.config.signing_key[0]
+        self.server_name = hs.hostname
+
     def ratelimit(self, user_id):
         time_now = self.clock.time()
         allowed, time_allowed = self.ratelimiter.send_message(
@@ -44,16 +58,58 @@ class BaseHandler(object):
 
     @defer.inlineCallbacks
     def _on_new_room_event(self, event, snapshot, extra_destinations=[],
-                           extra_users=[]):
+                           extra_users=[], suppress_auth=False,
+                           do_invite_host=None):
+        yield run_on_reactor()
+
         snapshot.fill_out_prev_events(event)
 
+        yield self.state_handler.annotate_state_groups(event)
+
+        yield self.auth.add_auth_events(event)
+
+        logger.debug("Signing event...")
+
+        add_hashes_and_signatures(
+            event, self.server_name, self.signing_key
+        )
+
+        logger.debug("Signed event.")
+
+        if not suppress_auth:
+            logger.debug("Authing...")
+            self.auth.check(event, raises=True)
+            logger.debug("Authed")
+        else:
+            logger.debug("Suppressed auth.")
+
+        if do_invite_host:
+            federation_handler = self.hs.get_handlers().federation_handler
+            invite_event = yield federation_handler.send_invite(
+                do_invite_host,
+                event
+            )
+
+            # FIXME: We need to check if the remote changed anything else
+            event.signatures = invite_event.signatures
+
         yield self.store.persist_event(event)
 
         destinations = set(extra_destinations)
         # Send a PDU to all hosts who have joined the room.
-        destinations.update((yield self.store.get_joined_hosts_for_room(
-            event.room_id
-        )))
+
+        for k, s in event.state_events.items():
+            try:
+                if k[0] == RoomMemberEvent.TYPE:
+                    if s.content["membership"] == Membership.JOIN:
+                        destinations.add(
+                            self.hs.parse_userid(s.state_key).domain
+                        )
+            except:
+                logger.warn(
+                    "Failed to get destination from event %s", s.event_id
+                )
+
         event.destinations = list(destinations)
 
         self.notifier.on_new_room_event(event, extra_users=extra_users)
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index a56830d520..164363cdc5 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -147,10 +147,8 @@ class DirectoryHandler(BaseHandler):
             content={"aliases": aliases},
         )
 
-        snapshot = yield self.store.snapshot_room(
-            room_id=room_id,
-            user_id=user_id,
-        )
+        snapshot = yield self.store.snapshot_room(event)
 
-        yield self.state_handler.handle_new_event(event, snapshot)
-        yield self._on_new_room_event(event, snapshot, extra_users=[user_id])
+        yield self._on_new_room_event(
+            event, snapshot, extra_users=[user_id], suppress_auth=True
+        )
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index f52591d2a3..c2cd91bb39 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -17,13 +17,15 @@
 
 from ._base import BaseHandler
 
-from synapse.api.events.room import InviteJoinEvent, RoomMemberEvent
+from synapse.api.errors import AuthError, FederationError
+from synapse.api.events.room import RoomMemberEvent
 from synapse.api.constants import Membership
 from synapse.util.logutils import log_function
 from synapse.federation.pdu_codec import PduCodec
-from synapse.api.errors import SynapseError
+from synapse.util.async import run_on_reactor
+from synapse.crypto.event_signing import compute_event_signature
 
-from twisted.internet import defer, reactor
+from twisted.internet import defer
 
 import logging
 
@@ -62,6 +64,9 @@ class FederationHandler(BaseHandler):
 
         self.pdu_codec = PduCodec(hs)
 
+        # When joining a room we need to queue any events for that room up
+        self.room_queues = {}
+
     @log_function
     @defer.inlineCallbacks
     def handle_new_event(self, event, snapshot):
@@ -78,6 +83,8 @@ class FederationHandler(BaseHandler):
             processing.
         """
 
+        yield run_on_reactor()
+
         pdu = self.pdu_codec.pdu_from_event(event)
 
         if not hasattr(pdu, "destinations") or not pdu.destinations:
@@ -87,97 +94,88 @@ class FederationHandler(BaseHandler):
 
     @log_function
     @defer.inlineCallbacks
-    def on_receive_pdu(self, pdu, backfilled):
+    def on_receive_pdu(self, pdu, backfilled, state=None):
         """ Called by the ReplicationLayer when we have a new pdu. We need to
-        do auth checks and put it throught the StateHandler.
+        do auth checks and put it through the StateHandler.
         """
         event = self.pdu_codec.event_from_pdu(pdu)
 
         logger.debug("Got event: %s", event.event_id)
 
-        with (yield self.lock_manager.lock(pdu.context)):
-            if event.is_state and not backfilled:
-                is_new_state = yield self.state_handler.handle_new_state(
-                    pdu
-                )
-            else:
-                is_new_state = False
-        # TODO: Implement something in federation that allows us to
-        # respond to PDU.
+        if event.room_id in self.room_queues:
+            self.room_queues[event.room_id].append(pdu)
+            return
 
-        target_is_mine = False
-        if hasattr(event, "target_host"):
-            target_is_mine = event.target_host == self.hs.hostname
-
-        if event.type == InviteJoinEvent.TYPE:
-            if not target_is_mine:
-                logger.debug("Ignoring invite/join event %s", event)
-                return
-
-            # If we receive an invite/join event then we need to join the
-            # sender to the given room.
-            # TODO: We should probably auth this or some such
-            content = event.content
-            content.update({"membership": Membership.JOIN})
-            new_event = self.event_factory.create_event(
-                etype=RoomMemberEvent.TYPE,
-                state_key=event.user_id,
-                room_id=event.room_id,
-                user_id=event.user_id,
-                membership=Membership.JOIN,
-                content=content
-            )
+        logger.debug("Processing event: %s", event.event_id)
+
+        if state:
+            state = [self.pdu_codec.event_from_pdu(p) for p in state]
+
+        is_new_state = yield self.state_handler.annotate_state_groups(
+            event,
+            old_state=state
+        )
 
-            yield self.hs.get_handlers().room_member_handler.change_membership(
-                new_event,
-                do_auth=False,
+        logger.debug("Event: %s", event)
+
+        try:
+            self.auth.check(event, raises=True)
+        except AuthError as e:
+            raise FederationError(
+                "ERROR",
+                e.code,
+                e.msg,
+                affected=event.event_id,
             )
 
-        else:
-            with (yield self.room_lock.lock(event.room_id)):
-                yield self.store.persist_event(
-                    event,
-                    backfilled,
-                    is_new_state=is_new_state
-                )
+        is_new_state = is_new_state and not backfilled
 
-            room = yield self.store.get_room(event.room_id)
+        # TODO: Implement something in federation that allows us to
+        # respond to PDU.
 
-            if not room:
-                # Huh, let's try and get the current state
-                try:
-                    yield self.replication_layer.get_state_for_context(
-                        event.origin, event.room_id
-                    )
+        yield self.store.persist_event(
+            event,
+            backfilled,
+            is_new_state=is_new_state
+        )
 
-                    hosts = yield self.store.get_joined_hosts_for_room(
-                        event.room_id
-                    )
-                    if self.hs.hostname in hosts:
-                        try:
-                            yield self.store.store_room(
-                                room_id=event.room_id,
-                                room_creator_user_id="",
-                                is_public=False,
-                            )
-                        except:
-                            pass
-                except:
-                    logger.exception(
-                        "Failed to get current state for room %s",
-                        event.room_id
-                    )
+        room = yield self.store.get_room(event.room_id)
 
-            if not backfilled:
-                extra_users = []
-                if event.type == RoomMemberEvent.TYPE:
-                    target_user_id = event.state_key
-                    target_user = self.hs.parse_userid(target_user_id)
-                    extra_users.append(target_user)
+        if not room:
+            # Huh, let's try and get the current state
+            try:
+                yield self.replication_layer.get_state_for_context(
+                    event.origin, event.room_id, event.event_id,
+                )
 
-                yield self.notifier.on_new_room_event(
-                    event, extra_users=extra_users
+                hosts = yield self.store.get_joined_hosts_for_room(
+                    event.room_id
                 )
+                if self.hs.hostname in hosts:
+                    try:
+                        yield self.store.store_room(
+                            room_id=event.room_id,
+                            room_creator_user_id="",
+                            is_public=False,
+                        )
+                    except:
+                        pass
+            except:
+                logger.exception(
+                    "Failed to get current state for room %s",
+                    event.room_id
+                )
+
+        if not backfilled:
+            extra_users = []
+            if event.type == RoomMemberEvent.TYPE:
+                target_user_id = event.state_key
+                target_user = self.hs.parse_userid(target_user_id)
+                extra_users.append(target_user)
+
+            yield self.notifier.on_new_room_event(
+                event, extra_users=extra_users
+            )
 
         if event.type == RoomMemberEvent.TYPE:
             if event.membership == Membership.JOIN:
@@ -189,79 +187,344 @@ class FederationHandler(BaseHandler):
     @log_function
     @defer.inlineCallbacks
     def backfill(self, dest, room_id, limit):
-        pdus = yield self.replication_layer.backfill(dest, room_id, limit)
+        extremities = yield self.store.get_oldest_events_in_room(room_id)
+
+        pdus = yield self.replication_layer.backfill(
+            dest,
+            room_id,
+            limit,
+            extremities=extremities,
+        )
 
         events = []
 
         for pdu in pdus:
             event = self.pdu_codec.event_from_pdu(pdu)
+
+            # FIXME (erikj): Not sure this actually works :/
+            yield self.state_handler.annotate_state_groups(event)
+
             events.append(event)
+
             yield self.store.persist_event(event, backfilled=True)
 
         defer.returnValue(events)
 
+    @defer.inlineCallbacks
+    def send_invite(self, target_host, event):
+        pdu = yield self.replication_layer.send_invite(
+            destination=target_host,
+            context=event.room_id,
+            event_id=event.event_id,
+            pdu=self.pdu_codec.pdu_from_event(event)
+        )
+
+        defer.returnValue(self.pdu_codec.event_from_pdu(pdu))
+
+    @defer.inlineCallbacks
+    def on_event_auth(self, event_id):
+        auth = yield self.store.get_auth_chain(event_id)
+        defer.returnValue([self.pdu_codec.pdu_from_event(e) for e in auth])
+
     @log_function
     @defer.inlineCallbacks
     def do_invite_join(self, target_host, room_id, joinee, content, snapshot):
-
         hosts = yield self.store.get_joined_hosts_for_room(room_id)
         if self.hs.hostname in hosts:
             # We are already in the room.
             logger.debug("We're already in the room apparently")
             defer.returnValue(False)
 
-        # First get current state to see if we are already joined.
+        pdu = yield self.replication_layer.make_join(
+            target_host,
+            room_id,
+            joinee
+        )
+
+        logger.debug("Got response to make_join: %s", pdu)
+
+        event = self.pdu_codec.event_from_pdu(pdu)
+
+        # We should assert some things.
+        assert(event.type == RoomMemberEvent.TYPE)
+        assert(event.user_id == joinee)
+        assert(event.state_key == joinee)
+        assert(event.room_id == room_id)
+
+        event.outlier = False
+
+        self.room_queues[room_id] = []
+
         try:
-            yield self.replication_layer.get_state_for_context(
-                target_host, room_id
+            event.event_id = self.event_factory.create_event_id()
+            event.content = content
+
+            state = yield self.replication_layer.send_join(
+                target_host,
+                self.pdu_codec.pdu_from_event(event)
             )
 
-            hosts = yield self.store.get_joined_hosts_for_room(room_id)
-            if self.hs.hostname in hosts:
-                # Oh, we were actually in the room already.
-                logger.debug("We're already in the room apparently")
-                defer.returnValue(False)
-        except Exception:
-            logger.exception("Failed to get current state")
-
-        new_event = self.event_factory.create_event(
-            etype=InviteJoinEvent.TYPE,
-            target_host=target_host,
-            room_id=room_id,
-            user_id=joinee,
-            content=content
-        )
+            state = [self.pdu_codec.event_from_pdu(p) for p in state]
 
-        new_event.destinations = [target_host]
+            logger.debug("do_invite_join state: %s", state)
 
-        snapshot.fill_out_prev_events(new_event)
-        yield self.handle_new_event(new_event, snapshot)
+            is_new_state = yield self.state_handler.annotate_state_groups(
+                event,
+                old_state=state
+            )
 
-        # TODO (erikj): Time out here.
-        d = defer.Deferred()
-        self.waiting_for_join_list.setdefault((joinee, room_id), []).append(d)
-        reactor.callLater(10, d.cancel)
+            logger.debug("do_invite_join event: %s", event)
 
-        try:
-            yield d
-        except defer.CancelledError:
-            raise SynapseError(500, "Unable to join remote room")
+            try:
+                yield self.store.store_room(
+                    room_id=room_id,
+                    room_creator_user_id="",
+                    is_public=False
+                )
+            except:
+                # FIXME
+                pass
 
-        try:
-            yield self.store.store_room(
-                room_id=room_id,
-                room_creator_user_id="",
-                is_public=False
+            for e in state:
+                # FIXME: Auth these.
+                e.outlier = True
+
+                yield self.state_handler.annotate_state_groups(
+                    e,
+                )
+
+                yield self.store.persist_event(
+                    e,
+                    backfilled=False,
+                    is_new_state=False
+                )
+
+            yield self.store.persist_event(
+                event,
+                backfilled=False,
+                is_new_state=is_new_state
             )
-        except:
-            pass
+        finally:
+            room_queue = self.room_queues[room_id]
+            del self.room_queues[room_id]
 
+            for p in room_queue:
+                try:
+                    yield self.on_receive_pdu(p, backfilled=False)
+                except:
+                    pass
 
         defer.returnValue(True)
 
+    @defer.inlineCallbacks
+    @log_function
+    def on_make_join_request(self, context, user_id):
+        event = self.event_factory.create_event(
+            etype=RoomMemberEvent.TYPE,
+            content={"membership": Membership.JOIN},
+            room_id=context,
+            user_id=user_id,
+            state_key=user_id,
+        )
+
+        snapshot = yield self.store.snapshot_room(event)
+        snapshot.fill_out_prev_events(event)
+
+        yield self.state_handler.annotate_state_groups(event)
+        yield self.auth.add_auth_events(event)
+        self.auth.check(event, raises=True)
+
+        pdu = self.pdu_codec.pdu_from_event(event)
+
+        defer.returnValue(pdu)
+
+    @defer.inlineCallbacks
+    @log_function
+    def on_send_join_request(self, origin, pdu):
+        event = self.pdu_codec.event_from_pdu(pdu)
+
+        event.outlier = False
+
+        is_new_state = yield self.state_handler.annotate_state_groups(event)
+        self.auth.check(event, raises=True)
+
+        # FIXME (erikj):  All this is duplicated above :(
+
+        yield self.store.persist_event(
+            event,
+            backfilled=False,
+            is_new_state=is_new_state
+        )
+
+        extra_users = []
+        if event.type == RoomMemberEvent.TYPE:
+            target_user_id = event.state_key
+            target_user = self.hs.parse_userid(target_user_id)
+            extra_users.append(target_user)
+
+        yield self.notifier.on_new_room_event(
+            event, extra_users=extra_users
+        )
+
+        if event.type == RoomMemberEvent.TYPE:
+            if event.membership == Membership.JOIN:
+                user = self.hs.parse_userid(event.state_key)
+                self.distributor.fire(
+                    "user_joined_room", user=user, room_id=event.room_id
+                )
+
+        new_pdu = self.pdu_codec.pdu_from_event(event)
+
+        destinations = set()
+
+        for k, s in event.state_events.items():
+            try:
+                if k[0] == RoomMemberEvent.TYPE:
+                    if s.content["membership"] == Membership.JOIN:
+                        destinations.add(
+                            self.hs.parse_userid(s.state_key).domain
+                        )
+            except:
+                logger.warn(
+                    "Failed to get destination from event %s", s.event_id
+                )
+
+        new_pdu.destinations = list(destinations)
+
+        yield self.replication_layer.send_pdu(new_pdu)
+
+        auth_chain = yield self.store.get_auth_chain(event.event_id)
+        pdu_auth_chain = [
+            self.pdu_codec.pdu_from_event(e)
+            for e in auth_chain
+        ]
+
+        defer.returnValue({
+            "state": [
+                self.pdu_codec.pdu_from_event(e)
+                for e in event.state_events.values()
+            ],
+            "auth_chain": pdu_auth_chain,
+        })
+
+    @defer.inlineCallbacks
+    def on_invite_request(self, origin, pdu):
+        event = self.pdu_codec.event_from_pdu(pdu)
+
+        event.outlier = True
+
+        event.signatures.update(
+            compute_event_signature(
+                event,
+                self.hs.hostname,
+                self.hs.config.signing_key[0]
+            )
+        )
+
+        yield self.state_handler.annotate_state_groups(event)
+
+        yield self.store.persist_event(
+            event,
+            backfilled=False,
+        )
+
+        target_user = self.hs.parse_userid(event.state_key)
+        yield self.notifier.on_new_room_event(
+            event, extra_users=[target_user],
+        )
+
+        defer.returnValue(self.pdu_codec.pdu_from_event(event))
+
+    @defer.inlineCallbacks
+    def get_state_for_pdu(self, origin, room_id, event_id):
+        yield run_on_reactor()
+
+        in_room = yield self.auth.check_host_in_room(room_id, origin)
+        if not in_room:
+            raise AuthError(403, "Host not in room.")
+
+        state_groups = yield self.store.get_state_groups(
+            [event_id]
+        )
+
+        if state_groups:
+            _, state = state_groups.items().pop()
+            results = {
+                (e.type, e.state_key): e for e in state
+            }
+
+            event = yield self.store.get_event(event_id)
+            if hasattr(event, "state_key"):
+                # Get previous state
+                if hasattr(event, "replaces_state") and event.replaces_state:
+                    prev_event = yield self.store.get_event(
+                        event.replaces_state
+                    )
+                    results[(event.type, event.state_key)] = prev_event
+                else:
+                    del results[(event.type, event.state_key)]
+
+            defer.returnValue(
+                [
+                    self.pdu_codec.pdu_from_event(s)
+                    for s in results.values()
+                ]
+            )
+        else:
+            defer.returnValue([])
+
+    @defer.inlineCallbacks
+    @log_function
+    def on_backfill_request(self, origin, context, pdu_list, limit):
+        in_room = yield self.auth.check_host_in_room(context, origin)
+        if not in_room:
+            raise AuthError(403, "Host not in room.")
+
+        events = yield self.store.get_backfill_events(
+            context,
+            pdu_list,
+            limit
+        )
+
+        defer.returnValue([
+            self.pdu_codec.pdu_from_event(e)
+            for e in events
+        ])
+
+    @defer.inlineCallbacks
+    @log_function
+    def get_persisted_pdu(self, origin, event_id):
+        """ Get a PDU from the database with given origin and id.
+
+        Returns:
+            Deferred: Results in a `Pdu`.
+        """
+        event = yield self.store.get_event(
+            event_id,
+            allow_none=True,
+        )
+
+        if event:
+            in_room = yield self.auth.check_host_in_room(
+                event.room_id,
+                origin
+            )
+            if not in_room:
+                raise AuthError(403, "Host not in room.")
+
+            defer.returnValue(self.pdu_codec.pdu_from_event(event))
+        else:
+            defer.returnValue(None)
+
+    @log_function
+    def get_min_depth_for_context(self, context):
+        return self.store.get_min_depth(context)
 
     @log_function
     def _on_user_joined(self, user, room_id):
-        waiters = self.waiting_for_join_list.get((user.to_string(), room_id), [])
+        waiters = self.waiting_for_join_list.get(
+            (user.to_string(), room_id),
+            []
+        )
         while waiters:
             waiters.pop().callback(None)
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 72894869ea..8394013df3 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -81,12 +81,11 @@ class MessageHandler(BaseHandler):
         user = self.hs.parse_userid(event.user_id)
         assert user.is_mine, "User must be our own: %s" % (user,)
 
-        snapshot = yield self.store.snapshot_room(event.room_id, event.user_id)
+        snapshot = yield self.store.snapshot_room(event)
 
-        if not suppress_auth:
-            yield self.auth.check(event, snapshot, raises=True)
-
-        yield self._on_new_room_event(event, snapshot)
+        yield self._on_new_room_event(
+            event, snapshot, suppress_auth=suppress_auth
+        )
 
         self.hs.get_handlers().presence_handler.bump_presence_active_time(
             user
@@ -142,16 +141,7 @@ class MessageHandler(BaseHandler):
             SynapseError if something went wrong.
         """
 
-        snapshot = yield self.store.snapshot_room(
-            event.room_id,
-            event.user_id,
-            state_type=event.type,
-            state_key=event.state_key,
-        )
-
-        yield self.auth.check(event, snapshot, raises=True)
-
-        yield self.state_handler.handle_new_event(event, snapshot)
+        snapshot = yield self.store.snapshot_room(event)
 
         yield self._on_new_room_event(event, snapshot)
 
@@ -201,7 +191,7 @@ class MessageHandler(BaseHandler):
                 raise RoomError(
                     403, "Member does not meet private room rules.")
 
-        data = yield self.store.get_current_state(
+        data = yield self.state_handler.get_current_state(
             room_id, event_type, state_key
         )
         defer.returnValue(data)
@@ -219,9 +209,7 @@ class MessageHandler(BaseHandler):
 
     @defer.inlineCallbacks
     def send_feedback(self, event):
-        snapshot = yield self.store.snapshot_room(event.room_id, event.user_id)
-
-        yield self.auth.check(event, snapshot, raises=True)
+        snapshot = yield self.store.snapshot_room(event)
 
         # store message in db
         yield self._on_new_room_event(event, snapshot)
@@ -239,7 +227,7 @@ class MessageHandler(BaseHandler):
         yield self.auth.check_joined_room(room_id, user_id)
 
         # TODO: This is duplicating logic from snapshot_all_rooms
-        current_state = yield self.store.get_current_state(room_id)
+        current_state = yield self.state_handler.get_current_state(room_id)
         defer.returnValue([self.hs.serialize_event(c) for c in current_state])
 
     @defer.inlineCallbacks
@@ -316,7 +304,7 @@ class MessageHandler(BaseHandler):
                     "end": end_token.to_string(),
                 }
 
-                current_state = yield self.store.get_current_state(
+                current_state = yield self.state_handler.get_current_state(
                     event.room_id
                 )
                 d["state"] = [self.hs.serialize_event(c) for c in current_state]
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index dab9b03f04..834b37f5f3 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -17,7 +17,6 @@ from twisted.internet import defer
 
 from synapse.api.errors import SynapseError, AuthError, CodeMessageException
 from synapse.api.constants import Membership
-from synapse.api.events.room import RoomMemberEvent
 
 from ._base import BaseHandler
 
@@ -153,10 +152,13 @@ class ProfileHandler(BaseHandler):
         if not user.is_mine:
             defer.returnValue(None)
 
-        (displayname, avatar_url) = yield defer.gatherResults([
-            self.store.get_profile_displayname(user.localpart),
-            self.store.get_profile_avatar_url(user.localpart),
-        ])
+        (displayname, avatar_url) = yield defer.gatherResults(
+            [
+                self.store.get_profile_displayname(user.localpart),
+                self.store.get_profile_avatar_url(user.localpart),
+            ],
+            consumeErrors=True
+        )
 
         state["displayname"] = displayname
         state["avatar_url"] = avatar_url
@@ -196,10 +198,7 @@ class ProfileHandler(BaseHandler):
         )
 
         for j in joins:
-            snapshot = yield self.store.snapshot_room(
-                j.room_id, j.state_key, RoomMemberEvent.TYPE,
-                j.state_key
-            )
+            snapshot = yield self.store.snapshot_room(j)
 
             content = {
                 "membership": j.content["membership"],
@@ -218,5 +217,6 @@ class ProfileHandler(BaseHandler):
                 user_id=j.state_key,
             )
 
-            yield self.state_handler.handle_new_event(new_event, snapshot)
-            yield self._on_new_room_event(new_event, snapshot)
+            yield self._on_new_room_event(
+                new_event, snapshot, suppress_auth=True
+            )
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 81ce1a5907..3642fcfc6d 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -21,8 +21,7 @@ from synapse.api.constants import Membership, JoinRules
 from synapse.api.errors import StoreError, SynapseError
 from synapse.api.events.room import (
     RoomMemberEvent, RoomCreateEvent, RoomPowerLevelsEvent,
-    RoomJoinRulesEvent, RoomAddStateLevelEvent, RoomTopicEvent,
-    RoomSendEventLevelEvent, RoomOpsPowerLevelsEvent, RoomNameEvent,
+    RoomTopicEvent, RoomNameEvent, RoomJoinRulesEvent,
 )
 from synapse.util import stringutils
 from ._base import BaseHandler
@@ -122,15 +121,13 @@ class RoomCreationHandler(BaseHandler):
 
         @defer.inlineCallbacks
         def handle_event(event):
-            snapshot = yield self.store.snapshot_room(
-                room_id=room_id,
-                user_id=user_id,
-            )
+            snapshot = yield self.store.snapshot_room(event)
 
             logger.debug("Event: %s", event)
 
-            yield self.state_handler.handle_new_event(event, snapshot)
-            yield self._on_new_room_event(event, snapshot, extra_users=[user])
+            yield self._on_new_room_event(
+                event, snapshot, extra_users=[user], suppress_auth=True
+            )
 
         for event in creation_events:
             yield handle_event(event)
@@ -141,7 +138,6 @@ class RoomCreationHandler(BaseHandler):
                 etype=RoomNameEvent.TYPE,
                 room_id=room_id,
                 user_id=user_id,
-                required_power_level=50,
                 content={"name": name},
             )
 
@@ -153,7 +149,6 @@ class RoomCreationHandler(BaseHandler):
                 etype=RoomTopicEvent.TYPE,
                 room_id=room_id,
                 user_id=user_id,
-                required_power_level=50,
                 content={"topic": topic},
             )
 
@@ -198,7 +193,6 @@ class RoomCreationHandler(BaseHandler):
         event_keys = {
             "room_id": room_id,
             "user_id": creator.to_string(),
-            "required_power_level": 100,
         }
 
         def create(etype, **content):
@@ -215,7 +209,21 @@ class RoomCreationHandler(BaseHandler):
 
         power_levels_event = self.event_factory.create_event(
             etype=RoomPowerLevelsEvent.TYPE,
-            content={creator.to_string(): 100, "default": 0},
+            content={
+                "users": {
+                    creator.to_string(): 100,
+                },
+                "users_default": 0,
+                "events": {
+                    RoomNameEvent.TYPE: 100,
+                    RoomPowerLevelsEvent.TYPE: 100,
+                },
+                "events_default": 0,
+                "state_default": 50,
+                "ban": 50,
+                "kick": 50,
+                "redact": 50
+            },
             **event_keys
         )
 
@@ -225,30 +233,10 @@ class RoomCreationHandler(BaseHandler):
             join_rule=join_rule,
         )
 
-        add_state_event = create(
-            etype=RoomAddStateLevelEvent.TYPE,
-            level=100,
-        )
-
-        send_event = create(
-            etype=RoomSendEventLevelEvent.TYPE,
-            level=0,
-        )
-
-        ops = create(
-            etype=RoomOpsPowerLevelsEvent.TYPE,
-            ban_level=50,
-            kick_level=50,
-            redact_level=50,
-        )
-
         return [
             creation_event,
             power_levels_event,
             join_rules_event,
-            add_state_event,
-            send_event,
-            ops,
         ]
 
 
@@ -363,10 +351,8 @@ class RoomMemberHandler(BaseHandler):
         """
         target_user_id = event.state_key
 
-        snapshot = yield self.store.snapshot_room(
-            event.room_id, event.user_id,
-            RoomMemberEvent.TYPE, target_user_id
-        )
+        snapshot = yield self.store.snapshot_room(event)
+
         ## TODO(markjh): get prev state from snapshot.
         prev_state = yield self.store.get_room_member(
             target_user_id, event.room_id
@@ -375,13 +361,6 @@ class RoomMemberHandler(BaseHandler):
         if prev_state:
             event.content["prev"] = prev_state.membership
 
-#        if prev_state and prev_state.membership == event.membership:
-#            # treat this event as a NOOP.
-#            if do_auth:  # This is mainly to fix a unit test.
-#                yield self.auth.check(event, raises=True)
-#            defer.returnValue({})
-#            return
-
         room_id = event.room_id
 
         # If we're trying to join a room then we have to do this differently
@@ -391,29 +370,17 @@ class RoomMemberHandler(BaseHandler):
             yield self._do_join(event, snapshot, do_auth=do_auth)
         else:
             # This is not a JOIN, so we can handle it normally.
-            if do_auth:
-                yield self.auth.check(event, snapshot, raises=True)
-
-            # If we're banning someone, set a req power level
-            if event.membership == Membership.BAN:
-                if not hasattr(event, "required_power_level") or event.required_power_level is None:
-                    # Add some default required_power_level
-                    user_level = yield self.store.get_power_level(
-                        event.room_id,
-                        event.user_id,
-                    )
-                    event.required_power_level = user_level
 
             if prev_state and prev_state.membership == event.membership:
                 # double same action, treat this event as a NOOP.
                 defer.returnValue({})
                 return
 
-            yield self.state_handler.handle_new_event(event, snapshot)
             yield self._do_local_membership_update(
                 event,
                 membership=event.content["membership"],
                 snapshot=snapshot,
+                do_auth=do_auth,
             )
 
         defer.returnValue({"room_id": room_id})
@@ -443,10 +410,7 @@ class RoomMemberHandler(BaseHandler):
             content=content,
         )
 
-        snapshot = yield self.store.snapshot_room(
-            room_id, joinee.to_string(), RoomMemberEvent.TYPE,
-            joinee.to_string()
-        )
+        snapshot = yield self.store.snapshot_room(new_event)
 
         yield self._do_join(new_event, snapshot, room_host=host, do_auth=True)
 
@@ -502,14 +466,11 @@ class RoomMemberHandler(BaseHandler):
         if not have_joined:
             logger.debug("Doing normal join")
 
-            if do_auth:
-                yield self.auth.check(event, snapshot, raises=True)
-
-            yield self.state_handler.handle_new_event(event, snapshot)
             yield self._do_local_membership_update(
                 event,
                 membership=event.content["membership"],
                 snapshot=snapshot,
+                do_auth=do_auth,
             )
 
         user = self.hs.parse_userid(event.user_id)
@@ -553,26 +514,27 @@ class RoomMemberHandler(BaseHandler):
 
         defer.returnValue([r.room_id for r in rooms])
 
-    def _do_local_membership_update(self, event, membership, snapshot):
-        destinations = []
-
+    @defer.inlineCallbacks
+    def _do_local_membership_update(self, event, membership, snapshot,
+                                    do_auth):
         # If we're inviting someone, then we should also send it to that
         # HS.
         target_user_id = event.state_key
         target_user = self.hs.parse_userid(target_user_id)
-        if membership == Membership.INVITE:
-            host = target_user.domain
-            destinations.append(host)
-
-        # Always include target domain
-        host = target_user.domain
-        destinations.append(host)
-
-        return self._on_new_room_event(
-            event, snapshot, extra_destinations=destinations,
-            extra_users=[target_user]
+        if membership == Membership.INVITE and not target_user.is_mine:
+            do_invite_host = target_user.domain
+        else:
+            do_invite_host = None
+
+        yield self._on_new_room_event(
+            event,
+            snapshot,
+            extra_users=[target_user],
+            suppress_auth=(not do_auth),
+            do_invite_host=do_invite_host,
         )
 
+
 class RoomListHandler(BaseHandler):
 
     @defer.inlineCallbacks
diff --git a/synapse/rest/base.py b/synapse/rest/base.py
index 2e8e3fa7d4..79fc4dfb84 100644
--- a/synapse/rest/base.py
+++ b/synapse/rest/base.py
@@ -18,6 +18,11 @@ from synapse.api.urls import CLIENT_PREFIX
 from synapse.rest.transactions import HttpTransactionStore
 import re
 
+import logging
+
+
+logger = logging.getLogger(__name__)
+
 
 def client_path_pattern(path_regex):
     """Creates a regex compiled client path with the correct client path
@@ -62,6 +67,8 @@ class RestServlet(object):
         self.auth = hs.get_auth()
         self.txns = HttpTransactionStore()
 
+        self.validator = hs.get_event_validator()
+
     def register(self, http_server):
         """ Register this servlet with the given HTTP server. """
         if hasattr(self, "PATTERN"):
diff --git a/synapse/rest/events.py b/synapse/rest/events.py
index 097195d7cc..92ff5e5ca7 100644
--- a/synapse/rest/events.py
+++ b/synapse/rest/events.py
@@ -20,6 +20,12 @@ from synapse.api.errors import SynapseError
 from synapse.streams.config import PaginationConfig
 from synapse.rest.base import RestServlet, client_path_pattern
 
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+
 
 class EventStreamRestServlet(RestServlet):
     PATTERN = client_path_pattern("/events$")
@@ -29,18 +35,22 @@ class EventStreamRestServlet(RestServlet):
     @defer.inlineCallbacks
     def on_GET(self, request):
         auth_user = yield self.auth.get_user_by_req(request)
-
-        handler = self.handlers.event_stream_handler
-        pagin_config = PaginationConfig.from_request(request)
-        timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS
-        if "timeout" in request.args:
-            try:
-                timeout = int(request.args["timeout"][0])
-            except ValueError:
-                raise SynapseError(400, "timeout must be in milliseconds.")
-
-        chunk = yield handler.get_stream(auth_user.to_string(), pagin_config,
-                                         timeout=timeout)
+        try:
+            handler = self.handlers.event_stream_handler
+            pagin_config = PaginationConfig.from_request(request)
+            timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS
+            if "timeout" in request.args:
+                try:
+                    timeout = int(request.args["timeout"][0])
+                except ValueError:
+                    raise SynapseError(400, "timeout must be in milliseconds.")
+
+            chunk = yield handler.get_stream(
+                auth_user.to_string(), pagin_config, timeout=timeout
+            )
+        except:
+            logger.exception("Event stream failed")
+            raise
 
         defer.returnValue((200, chunk))
 
diff --git a/synapse/rest/room.py b/synapse/rest/room.py
index 7724967061..05da0be090 100644
--- a/synapse/rest/room.py
+++ b/synapse/rest/room.py
@@ -138,7 +138,7 @@ class RoomStateEventRestServlet(RestServlet):
             raise SynapseError(
                 404, "Event not found.", errcode=Codes.NOT_FOUND
             )
-        defer.returnValue((200, data[0].get_dict()["content"]))
+        defer.returnValue((200, data.get_dict()["content"]))
 
     @defer.inlineCallbacks
     def on_PUT(self, request, room_id, event_type, state_key):
@@ -154,6 +154,9 @@ class RoomStateEventRestServlet(RestServlet):
             user_id=user.to_string(),
             state_key=urllib.unquote(state_key)
             )
+
+        self.validator.validate(event)
+
         if event_type == RoomMemberEvent.TYPE:
             # membership events are special
             handler = self.handlers.room_member_handler
@@ -188,6 +191,8 @@ class RoomSendEventRestServlet(RestServlet):
             content=content
         )
 
+        self.validator.validate(event)
+
         msg_handler = self.handlers.message_handler
         yield msg_handler.send_message(event)
 
@@ -253,6 +258,9 @@ class JoinRoomAliasServlet(RestServlet):
                 user_id=user.to_string(),
                 state_key=user.to_string()
             )
+
+            self.validator.validate(event)
+
             handler = self.handlers.room_member_handler
             yield handler.change_membership(event)
             defer.returnValue((200, {}))
@@ -424,6 +432,9 @@ class RoomMembershipRestServlet(RestServlet):
             user_id=user.to_string(),
             state_key=state_key
         )
+
+        self.validator.validate(event)
+
         handler = self.handlers.room_member_handler
         yield handler.change_membership(event)
         defer.returnValue((200, {}))
@@ -461,6 +472,8 @@ class RoomRedactEventRestServlet(RestServlet):
             redacts=urllib.unquote(event_id),
         )
 
+        self.validator.validate(event)
+
         msg_handler = self.handlers.message_handler
         yield msg_handler.send_message(event)
 
diff --git a/synapse/server.py b/synapse/server.py
index a4d2d4aba5..da0a44433a 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -22,13 +22,14 @@
 from synapse.federation import initialize_http_replication
 from synapse.api.events import serialize_event
 from synapse.api.events.factory import EventFactory
+from synapse.api.events.validator import EventValidator
 from synapse.notifier import Notifier
 from synapse.api.auth import Auth
 from synapse.handlers import Handlers
 from synapse.rest import RestServletFactory
 from synapse.state import StateHandler
 from synapse.storage import DataStore
-from synapse.types import UserID, RoomAlias, RoomID
+from synapse.types import UserID, RoomAlias, RoomID, EventID
 from synapse.util import Clock
 from synapse.util.distributor import Distributor
 from synapse.util.lockutils import LockManager
@@ -80,6 +81,7 @@ class BaseHomeServer(object):
         'event_sources',
         'ratelimiter',
         'keyring',
+        'event_validator',
     ]
 
     def __init__(self, hostname, **kwargs):
@@ -143,6 +145,11 @@ class BaseHomeServer(object):
         object."""
         return RoomID.from_string(s, hs=self)
 
+    def parse_eventid(self, s):
+        """Parse the string given by 's' as a Event ID and return a EventID
+        object."""
+        return EventID.from_string(s, hs=self)
+
     def serialize_event(self, e):
         return serialize_event(self, e)
 
@@ -218,6 +225,9 @@ class HomeServer(BaseHomeServer):
     def build_keyring(self):
         return Keyring(self)
 
+    def build_event_validator(self):
+        return EventValidator(self)
+
     def register_servlets(self):
         """ Register all servlets associated with this HomeServer.
         """
diff --git a/synapse/state.py b/synapse/state.py
index 9db84c9b5c..11c54fd38c 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -16,11 +16,13 @@
 
 from twisted.internet import defer
 
-from synapse.federation.pdu_codec import encode_event_id, decode_event_id
 from synapse.util.logutils import log_function
+from synapse.util.async import run_on_reactor
+from synapse.api.events.room import RoomPowerLevelsEvent
 
 from collections import namedtuple
 
+import copy
 import logging
 import hashlib
 
@@ -35,230 +37,169 @@ KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key"))
 
 
 class StateHandler(object):
-    """ Repsonsible for doing state conflict resolution.
+    """ Responsible for doing state conflict resolution.
     """
 
     def __init__(self, hs):
         self.store = hs.get_datastore()
-        self._replication = hs.get_replication_layer()
-        self.server_name = hs.hostname
 
     @defer.inlineCallbacks
     @log_function
-    def handle_new_event(self, event, snapshot):
-        """ Given an event this works out if a) we have sufficient power level
-        to update the state and b) works out what the prev_state should be.
+    def annotate_state_groups(self, event, old_state=None):
+        yield run_on_reactor()
 
-        Returns:
-            Deferred: Resolved with a boolean indicating if we succesfully
-            updated the state.
+        if old_state:
+            event.state_group = None
+            event.old_state_events = {
+                (s.type, s.state_key): s for s in old_state
+            }
+            event.state_events = event.old_state_events
 
-        Raised:
-            AuthError
-        """
-        # This needs to be done in a transaction.
+            if hasattr(event, "state_key"):
+                event.state_events[(event.type, event.state_key)] = event
 
-        if not hasattr(event, "state_key"):
+            defer.returnValue(False)
             return
 
-        key = KeyStateTuple(
-            event.room_id,
-            event.type,
-            _get_state_key_from_event(event)
-        )
-
-        # Now I need to fill out the prev state and work out if it has auth
-        # (w.r.t. to power levels)
-
-        snapshot.fill_out_prev_events(event)
-
-        event.prev_events = [
-            e for e in event.prev_events if e != event.event_id
-        ]
+        if hasattr(event, "outlier") and event.outlier:
+            event.state_group = None
+            event.old_state_events = None
+            event.state_events = {}
+            defer.returnValue(False)
+            return
 
-        current_state = snapshot.prev_state_pdu
+        ids = [e for e, _ in event.prev_events]
 
-        if current_state:
-            event.prev_state = encode_event_id(
-                current_state.pdu_id, current_state.origin
-            )
+        ret = yield self.resolve_state_groups(ids)
+        state_group, new_state = ret
 
-        # TODO check current_state to see if the min power level is less
-        # than the power level of the user
-        # power_level = self._get_power_level_for_event(event)
+        event.old_state_events = copy.deepcopy(new_state)
 
-        pdu_id, origin = decode_event_id(event.event_id, self.server_name)
+        if hasattr(event, "state_key"):
+            key = (event.type, event.state_key)
+            if key in new_state:
+                event.replaces_state = new_state[key].event_id
+            new_state[key] = event
+        elif state_group:
+            event.state_group = state_group
+            event.state_events = new_state
+            defer.returnValue(False)
 
-        yield self.store.update_current_state(
-            pdu_id=pdu_id,
-            origin=origin,
-            context=key.context,
-            pdu_type=key.type,
-            state_key=key.state_key
-        )
+        event.state_group = None
+        event.state_events = new_state
 
-        defer.returnValue(True)
+        defer.returnValue(hasattr(event, "state_key"))
 
     @defer.inlineCallbacks
-    @log_function
-    def handle_new_state(self, new_pdu):
-        """ Apply conflict resolution to `new_pdu`.
-
-        This should be called on every new state pdu, regardless of whether or
-        not there is a conflict.
-
-        This function is safe against the race of it getting called with two
-        `PDU`s trying to update the same state.
-        """
-
-        # This needs to be done in a transaction.
-
-        is_new = yield self._handle_new_state(new_pdu)
+    def get_current_state(self, room_id, event_type=None, state_key=""):
+        events = yield self.store.get_latest_events_in_room(room_id)
 
-        logger.debug("is_new: %s %s %s", is_new, new_pdu.pdu_id, new_pdu.origin)
+        event_ids = [
+            e_id
+            for e_id, _, _ in events
+        ]
 
-        if is_new:
-            yield self.store.update_current_state(
-                pdu_id=new_pdu.pdu_id,
-                origin=new_pdu.origin,
-                context=new_pdu.context,
-                pdu_type=new_pdu.pdu_type,
-                state_key=new_pdu.state_key
-            )
+        res = yield self.resolve_state_groups(event_ids)
 
-        defer.returnValue(is_new)
+        if event_type:
+            defer.returnValue(res[1].get((event_type, state_key)))
+            return
 
-    def _get_power_level_for_event(self, event):
-        # return self._persistence.get_power_level_for_user(event.room_id,
-            # event.sender)
-        return event.power_level
+        defer.returnValue(res[1].values())
 
     @defer.inlineCallbacks
     @log_function
-    def _handle_new_state(self, new_pdu):
-        tree, missing_branch = yield self.store.get_unresolved_state_tree(
-            new_pdu
-        )
-        new_branch, current_branch = tree
-
-        logger.debug(
-            "_handle_new_state new=%s, current=%s",
-            new_branch, current_branch
+    def resolve_state_groups(self, event_ids):
+        state_groups = yield self.store.get_state_groups(
+            event_ids
         )
 
-        if missing_branch is not None:
-            # We're missing some PDUs. Fetch them.
-            # TODO (erikj): Limit this.
-            missing_prev = tree[missing_branch][-1]
-
-            pdu_id = missing_prev.prev_state_id
-            origin = missing_prev.prev_state_origin
-
-            is_missing = yield self.store.get_pdu(pdu_id, origin) is None
-            if not is_missing:
-                raise Exception("Conflict resolution failed")
-
-            yield self._replication.get_pdu(
-                destination=missing_prev.origin,
-                pdu_origin=origin,
-                pdu_id=pdu_id,
-                outlier=True
-            )
-
-            updated_current = yield self._handle_new_state(new_pdu)
-            defer.returnValue(updated_current)
-
-        if not current_branch:
-            # There is no current state
-            defer.returnValue(True)
-            return
-
-        n = new_branch[-1]
-        c = current_branch[-1]
-
-        common_ancestor = n.pdu_id == c.pdu_id and n.origin == c.origin
-
-        if common_ancestor:
-            # We found a common ancestor!
-
-            if len(current_branch) == 1:
-                # This is a direct clobber so we can just...
-                defer.returnValue(True)
+        group_names = set(state_groups.keys())
+        if len(group_names) == 1:
+            name, state_list = state_groups.items().pop()
+            state = {
+                (e.type, e.state_key): e
+                for e in state_list
+            }
+            defer.returnValue((name, state))
+
+        state = {}
+        for group, g_state in state_groups.items():
+            for s in g_state:
+                state.setdefault(
+                    (s.type, s.state_key),
+                    {}
+                )[s.event_id] = s
+
+        unconflicted_state = {
+            k: v.values()[0] for k, v in state.items()
+            if len(v.values()) == 1
+        }
+
+        conflicted_state = {
+            k: v.values()
+            for k, v in state.items()
+            if len(v.values()) > 1
+        }
+
+        try:
+            new_state = {}
+            new_state.update(unconflicted_state)
+            for key, events in conflicted_state.items():
+                new_state[key] = self._resolve_state_events(events)
+        except:
+            logger.exception("Failed to resolve state")
+            raise
+
+        defer.returnValue((None, new_state))
+
+    def _get_power_level_from_event_state(self, event, user_id):
+        if hasattr(event, "old_state_events") and event.old_state_events:
+            key = (RoomPowerLevelsEvent.TYPE, "", )
+            power_level_event = event.old_state_events.get(key)
+            level = None
+            if power_level_event:
+                level = power_level_event.content.get("users", {}).get(
+                    user_id
+                )
+                if not level:
+                    level = power_level_event.content.get("users_default", 0)
 
+            return level
         else:
-            # We didn't find a common ancestor. This is probably fine.
-            pass
+            return 0
 
-        result = yield self._do_conflict_res(
-            new_branch, current_branch, common_ancestor
-        )
-        defer.returnValue(result)
+    @log_function
+    def _resolve_state_events(self, events):
+        curr_events = events
 
-    @defer.inlineCallbacks
-    def _do_conflict_res(self, new_branch, current_branch, common_ancestor):
-        conflict_res = [
-            self._do_power_level_conflict_res,
-            self._do_chain_length_conflict_res,
-            self._do_hash_conflict_res,
+        new_powers = [
+            self._get_power_level_from_event_state(e, e.user_id)
+            for e in curr_events
         ]
 
-        for algo in conflict_res:
-            new_res, curr_res = yield defer.maybeDeferred(
-                algo,
-                new_branch, current_branch, common_ancestor
-            )
-
-            if new_res < curr_res:
-                defer.returnValue(False)
-            elif new_res > curr_res:
-                defer.returnValue(True)
-
-        raise Exception("Conflict resolution failed.")
-
-    @defer.inlineCallbacks
-    def _do_power_level_conflict_res(self, new_branch, current_branch,
-                                     common_ancestor):
-        new_powers_deferreds = []
-        for e in new_branch[:-1] if common_ancestor else new_branch:
-            if hasattr(e, "user_id"):
-                new_powers_deferreds.append(
-                    self.store.get_power_level(e.context, e.user_id)
-                )
-
-        current_powers_deferreds = []
-        for e in current_branch[:-1] if common_ancestor else current_branch:
-            if hasattr(e, "user_id"):
-                current_powers_deferreds.append(
-                    self.store.get_power_level(e.context, e.user_id)
-                )
-
-        new_powers = yield defer.gatherResults(
-            new_powers_deferreds,
-            consumeErrors=True
-        )
-
-        current_powers = yield defer.gatherResults(
-            current_powers_deferreds,
-            consumeErrors=True
-        )
+        new_powers = [
+            int(p) if p else 0 for p in new_powers
+        ]
 
-        max_power_new = max(new_powers)
-        max_power_current = max(current_powers)
+        max_power = max(new_powers)
 
-        defer.returnValue(
-            (max_power_new, max_power_current)
-        )
-
-    def _do_chain_length_conflict_res(self, new_branch, current_branch,
-                                      common_ancestor):
-        return (len(new_branch), len(current_branch))
+        curr_events = [
+            z[0] for z in zip(curr_events, new_powers)
+            if z[1] == max_power
+        ]
 
-    def _do_hash_conflict_res(self, new_branch, current_branch,
-                              common_ancestor):
-        new_str = "".join([p.pdu_id + p.origin for p in new_branch])
-        c_str = "".join([p.pdu_id + p.origin for p in current_branch])
+        if not curr_events:
+            raise RuntimeError("Max didn't get a max?")
+        elif len(curr_events) == 1:
+            return curr_events[0]
 
+        # TODO: For now, just choose the one with the largest event_id.
         return (
-            hashlib.sha1(new_str).hexdigest(),
-            hashlib.sha1(c_str).hexdigest()
+            sorted(
+                curr_events,
+                key=lambda e: hashlib.sha1(
+                    e.event_id + e.user_id + e.room_id + e.type
+                ).hexdigest()
+            )[0]
         )
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 4e9291fdff..4034437f6b 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -16,14 +16,7 @@
 from twisted.internet import defer
 
 from synapse.api.events.room import (
-    RoomMemberEvent, RoomTopicEvent, FeedbackEvent,
-#   RoomConfigEvent,
-    RoomNameEvent,
-    RoomJoinRulesEvent,
-    RoomPowerLevelsEvent,
-    RoomAddStateLevelEvent,
-    RoomSendEventLevelEvent,
-    RoomOpsPowerLevelsEvent,
+    RoomMemberEvent, RoomTopicEvent, FeedbackEvent, RoomNameEvent,
     RoomRedactionEvent,
 )
 
@@ -37,9 +30,17 @@ from .registration import RegistrationStore
 from .room import RoomStore
 from .roommember import RoomMemberStore
 from .stream import StreamStore
-from .pdu import StatePduStore, PduStore, PdusTable
 from .transactions import TransactionStore
 from .keys import KeyStore
+from .event_federation import EventFederationStore
+
+from .state import StateStore
+from .signatures import SignatureStore
+
+from syutil.base64util import decode_base64
+
+from synapse.crypto.event_signing import compute_event_reference_hash
+
 
 import json
 import logging
@@ -51,7 +52,6 @@ logger = logging.getLogger(__name__)
 
 SCHEMAS = [
     "transactions",
-    "pdu",
     "users",
     "profiles",
     "presence",
@@ -59,6 +59,9 @@ SCHEMAS = [
     "room_aliases",
     "keys",
     "redactions",
+    "state",
+    "event_edges",
+    "event_signatures",
 ]
 
 
@@ -73,10 +76,12 @@ class _RollbackButIsFineException(Exception):
     """
     pass
 
+
 class DataStore(RoomMemberStore, RoomStore,
                 RegistrationStore, StreamStore, ProfileStore, FeedbackStore,
-                PresenceStore, PduStore, StatePduStore, TransactionStore,
-                DirectoryStore, KeyStore):
+                PresenceStore, TransactionStore,
+                DirectoryStore, KeyStore, StateStore, SignatureStore,
+                EventFederationStore, ):
 
     def __init__(self, hs):
         super(DataStore, self).__init__(hs)
@@ -88,8 +93,7 @@ class DataStore(RoomMemberStore, RoomStore,
 
     @defer.inlineCallbacks
     @log_function
-    def persist_event(self, event=None, backfilled=False, pdu=None,
-                      is_new_state=True):
+    def persist_event(self, event, backfilled=False, is_new_state=True):
         stream_ordering = None
         if backfilled:
             if not self.min_token_deferred.called:
@@ -99,8 +103,8 @@ class DataStore(RoomMemberStore, RoomStore,
 
         try:
             yield self.runInteraction(
-                self._persist_pdu_event_txn,
-                pdu=pdu,
+                "persist_event",
+                self._persist_event_txn,
                 event=event,
                 backfilled=backfilled,
                 stream_ordering=stream_ordering,
@@ -119,7 +123,8 @@ class DataStore(RoomMemberStore, RoomStore,
                 "type",
                 "room_id",
                 "content",
-                "unrecognized_keys"
+                "unrecognized_keys",
+                "depth",
             ],
             allow_none=allow_none,
         )
@@ -130,42 +135,6 @@ class DataStore(RoomMemberStore, RoomStore,
         event = self._parse_event_from_row(events_dict)
         defer.returnValue(event)
 
-    def _persist_pdu_event_txn(self, txn, pdu=None, event=None,
-                               backfilled=False, stream_ordering=None,
-                               is_new_state=True):
-        if pdu is not None:
-            self._persist_event_pdu_txn(txn, pdu)
-        if event is not None:
-            return self._persist_event_txn(
-                txn, event, backfilled, stream_ordering,
-                is_new_state=is_new_state,
-            )
-
-    def _persist_event_pdu_txn(self, txn, pdu):
-        cols = dict(pdu.__dict__)
-        unrec_keys = dict(pdu.unrecognized_keys)
-        del cols["content"]
-        del cols["prev_pdus"]
-        cols["content_json"] = json.dumps(pdu.content)
-
-        unrec_keys.update({
-            k: v for k, v in cols.items()
-            if k not in PdusTable.fields
-        })
-
-        cols["unrecognized_keys"] = json.dumps(unrec_keys)
-
-        cols["ts"] = cols.pop("origin_server_ts")
-
-        logger.debug("Persisting: %s", repr(cols))
-
-        if pdu.is_state:
-            self._persist_state_txn(txn, pdu.prev_pdus, cols)
-        else:
-            self._persist_pdu_txn(txn, pdu.prev_pdus, cols)
-
-        self._update_min_depth_for_context_txn(txn, pdu.context, pdu.depth)
-
     @log_function
     def _persist_event_txn(self, txn, event, backfilled, stream_ordering=None,
                            is_new_state=True):
@@ -177,19 +146,13 @@ class DataStore(RoomMemberStore, RoomStore,
             self._store_room_name_txn(txn, event)
         elif event.type == RoomTopicEvent.TYPE:
             self._store_room_topic_txn(txn, event)
-        elif event.type == RoomJoinRulesEvent.TYPE:
-            self._store_join_rule(txn, event)
-        elif event.type == RoomPowerLevelsEvent.TYPE:
-            self._store_power_levels(txn, event)
-        elif event.type == RoomAddStateLevelEvent.TYPE:
-            self._store_add_state_level(txn, event)
-        elif event.type == RoomSendEventLevelEvent.TYPE:
-            self._store_send_event_level(txn, event)
-        elif event.type == RoomOpsPowerLevelsEvent.TYPE:
-            self._store_ops_level(txn, event)
         elif event.type == RoomRedactionEvent.TYPE:
             self._store_redaction(txn, event)
 
+        outlier = False
+        if hasattr(event, "outlier"):
+            outlier = event.outlier
+
         vals = {
             "topological_ordering": event.depth,
             "event_id": event.event_id,
@@ -197,25 +160,33 @@ class DataStore(RoomMemberStore, RoomStore,
             "room_id": event.room_id,
             "content": json.dumps(event.content),
             "processed": True,
+            "outlier": outlier,
+            "depth": event.depth,
         }
 
         if stream_ordering is not None:
             vals["stream_ordering"] = stream_ordering
 
-        if hasattr(event, "outlier"):
-            vals["outlier"] = event.outlier
-        else:
-            vals["outlier"] = False
-
         unrec = {
             k: v
             for k, v in event.get_full_dict().items()
-            if k not in vals.keys() and k not in ["redacted", "redacted_because"]
+            if k not in vals.keys() and k not in [
+                "redacted",
+                "redacted_because",
+                "signatures",
+                "hashes",
+                "prev_events",
+            ]
         }
         vals["unrecognized_keys"] = json.dumps(unrec)
 
         try:
-            self._simple_insert_txn(txn, "events", vals)
+            self._simple_insert_txn(
+                txn,
+                "events",
+                vals,
+                or_replace=(not outlier),
+            )
         except:
             logger.warn(
                 "Failed to persist, probably duplicate: %s",
@@ -224,6 +195,16 @@ class DataStore(RoomMemberStore, RoomStore,
             )
             raise _RollbackButIsFineException("_persist_event")
 
+        self._handle_prev_events(
+            txn,
+            outlier=outlier,
+            event_id=event.event_id,
+            prev_events=event.prev_events,
+            room_id=event.room_id,
+        )
+
+        self._store_state_groups_txn(txn, event)
+
         is_state = hasattr(event, "state_key") and event.state_key is not None
         if is_new_state and is_state:
             vals = {
@@ -233,8 +214,8 @@ class DataStore(RoomMemberStore, RoomStore,
                 "state_key": event.state_key,
             }
 
-            if hasattr(event, "prev_state"):
-                vals["prev_state"] = event.prev_state
+            if hasattr(event, "replaces_state"):
+                vals["prev_state"] = event.replaces_state
 
             self._simple_insert_txn(txn, "state_events", vals)
 
@@ -249,6 +230,81 @@ class DataStore(RoomMemberStore, RoomStore,
                 }
             )
 
+            for e_id, h in event.prev_state:
+                self._simple_insert_txn(
+                    txn,
+                    table="event_edges",
+                    values={
+                        "event_id": event.event_id,
+                        "prev_event_id": e_id,
+                        "room_id": event.room_id,
+                        "is_state": 1,
+                    },
+                    or_ignore=True,
+                )
+
+            if not backfilled:
+                self._simple_insert_txn(
+                    txn,
+                    table="state_forward_extremities",
+                    values={
+                        "event_id": event.event_id,
+                        "room_id": event.room_id,
+                        "type": event.type,
+                        "state_key": event.state_key,
+                    }
+                )
+
+                for prev_state_id, _ in event.prev_state:
+                    self._simple_delete_txn(
+                        txn,
+                        table="state_forward_extremities",
+                        keyvalues={
+                            "event_id": prev_state_id,
+                        }
+                    )
+
+        for hash_alg, hash_base64 in event.hashes.items():
+            hash_bytes = decode_base64(hash_base64)
+            self._store_event_content_hash_txn(
+                txn, event.event_id, hash_alg, hash_bytes,
+            )
+
+        if hasattr(event, "signatures"):
+            signatures = event.signatures.get(event.origin, {})
+
+            for key_id, signature_base64 in signatures.items():
+                signature_bytes = decode_base64(signature_base64)
+                self._store_event_origin_signature_txn(
+                    txn, event.event_id, event.origin, key_id, signature_bytes,
+                )
+
+        for prev_event_id, prev_hashes in event.prev_events:
+            for alg, hash_base64 in prev_hashes.items():
+                hash_bytes = decode_base64(hash_base64)
+                self._store_prev_event_hash_txn(
+                    txn, event.event_id, prev_event_id, alg, hash_bytes
+                )
+
+        for auth_id, _ in event.auth_events:
+            self._simple_insert_txn(
+                txn,
+                table="event_auth",
+                values={
+                    "event_id": event.event_id,
+                    "room_id": event.room_id,
+                    "auth_id": auth_id,
+                },
+                or_ignore=True,
+            )
+
+        (ref_alg, ref_hash_bytes) = compute_event_reference_hash(event)
+        self._store_event_reference_hash_txn(
+            txn, event.event_id, ref_alg, ref_hash_bytes
+        )
+
+        self._update_min_depth_for_room_txn(txn, event.room_id, event.depth)
+
     def _store_redaction(self, txn, event):
         txn.execute(
             "INSERT OR IGNORE INTO redactions "
@@ -319,7 +375,7 @@ class DataStore(RoomMemberStore, RoomStore,
             ],
         )
 
-    def snapshot_room(self, room_id, user_id, state_type=None, state_key=None):
+    def snapshot_room(self, event):
         """Snapshot the room for an update by a user
         Args:
             room_id (synapse.types.RoomId): The room to snapshot.
@@ -330,29 +386,33 @@ class DataStore(RoomMemberStore, RoomStore,
             synapse.storage.Snapshot: A snapshot of the state of the room.
         """
         def _snapshot(txn):
-            membership_state = self._get_room_member(txn, user_id, room_id)
-            prev_pdus = self._get_latest_pdus_in_context(
-                txn, room_id
+            prev_events = self._get_latest_events_in_room(
+                txn,
+                event.room_id
             )
-            if state_type is not None and state_key is not None:
-                prev_state_pdu = self._get_current_state_pdu(
-                    txn, room_id, state_type, state_key
+
+            prev_state = None
+            state_key = None
+            if hasattr(event, "state_key"):
+                state_key = event.state_key
+                prev_state = self._get_latest_state_in_room(
+                    txn,
+                    event.room_id,
+                    type=event.type,
+                    state_key=state_key,
                 )
-            else:
-                prev_state_pdu = None
 
             return Snapshot(
                 store=self,
-                room_id=room_id,
-                user_id=user_id,
-                prev_pdus=prev_pdus,
-                membership_state=membership_state,
-                state_type=state_type,
+                room_id=event.room_id,
+                user_id=event.user_id,
+                prev_events=prev_events,
+                prev_state=prev_state,
+                state_type=event.type,
                 state_key=state_key,
-                prev_state_pdu=prev_state_pdu,
             )
 
-        return self.runInteraction(_snapshot)
+        return self.runInteraction("snapshot_room", _snapshot)
 
 
 class Snapshot(object):
@@ -361,7 +421,7 @@ class Snapshot(object):
         store (DataStore): The datastore.
         room_id (RoomId): The room of the snapshot.
         user_id (UserId): The user this snapshot is for.
-        prev_pdus (list): The list of PDU ids this snapshot is after.
+        prev_events (list): The list of event ids this snapshot is after.
         membership_state (RoomMemberEvent): The current state of the user in
             the room.
         state_type (str, optional): State type captured by the snapshot
@@ -370,32 +430,30 @@ class Snapshot(object):
             the previous value of the state type and key in the room.
     """
 
-    def __init__(self, store, room_id, user_id, prev_pdus,
-                 membership_state, state_type=None, state_key=None,
-                 prev_state_pdu=None):
+    def __init__(self, store, room_id, user_id, prev_events,
+                 prev_state, state_type=None, state_key=None):
         self.store = store
         self.room_id = room_id
         self.user_id = user_id
-        self.prev_pdus = prev_pdus
-        self.membership_state = membership_state
+        self.prev_events = prev_events
+        self.prev_state = prev_state
         self.state_type = state_type
         self.state_key = state_key
-        self.prev_state_pdu = prev_state_pdu
 
     def fill_out_prev_events(self, event):
-        if hasattr(event, "prev_events"):
-            return
-
-        es = [
-            "%s@%s" % (p_id, origin) for p_id, origin, _ in self.prev_pdus
-        ]
-
-        event.prev_events = [e for e in es if e != event.event_id]
+        if not hasattr(event, "prev_events"):
+            event.prev_events = [
+                (event_id, hashes)
+                for event_id, hashes, _ in self.prev_events
+            ]
+
+            if self.prev_events:
+                event.depth = max([int(v) for _, _, v in self.prev_events]) + 1
+            else:
+                event.depth = 0
 
-        if self.prev_pdus:
-            event.depth = max([int(v) for _, _, v in self.prev_pdus]) + 1
-        else:
-            event.depth = 0
+        if not hasattr(event, "prev_state") and self.prev_state is not None:
+            event.prev_state = self.prev_state
 
 
 def schema_path(schema):
@@ -436,11 +494,13 @@ def prepare_database(db_conn):
         user_version = row[0]
 
         if user_version > SCHEMA_VERSION:
-            raise ValueError("Cannot use this database as it is too " +
+            raise ValueError(
+                "Cannot use this database as it is too " +
                 "new for the server to understand"
             )
         elif user_version < SCHEMA_VERSION:
-            logging.info("Upgrading database from version %d",
+            logging.info(
+                "Upgrading database from version %d",
                 user_version
             )
 
@@ -452,13 +512,13 @@ def prepare_database(db_conn):
             db_conn.commit()
 
     else:
-        sql_script = "BEGIN TRANSACTION;"
+        sql_script = "BEGIN TRANSACTION;\n"
         for sql_loc in SCHEMAS:
             sql_script += read_schema(sql_loc)
+            sql_script += "\n"
         sql_script += "COMMIT TRANSACTION;"
         c.executescript(sql_script)
         db_conn.commit()
         c.execute("PRAGMA user_version = %d" % SCHEMA_VERSION)
 
     c.close()
-
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 65a86e9056..a1ee0318f6 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -14,59 +14,69 @@
 # limitations under the License.
 import logging
 
-from twisted.internet import defer
-
 from synapse.api.errors import StoreError
 from synapse.api.events.utils import prune_event
 from synapse.util.logutils import log_function
+from syutil.base64util import encode_base64
 
 import collections
 import copy
 import json
+import sys
+import time
 
 
 logger = logging.getLogger(__name__)
 
 sql_logger = logging.getLogger("synapse.storage.SQL")
+transaction_logger = logging.getLogger("synapse.storage.txn")
 
 
 class LoggingTransaction(object):
     """An object that almost-transparently proxies for the 'txn' object
     passed to the constructor. Adds logging to the .execute() method."""
-    __slots__ = ["txn"]
+    __slots__ = ["txn", "name"]
 
-    def __init__(self, txn):
+    def __init__(self, txn, name):
         object.__setattr__(self, "txn", txn)
+        object.__setattr__(self, "name", name)
 
-    def __getattribute__(self, name):
-        if name == "execute":
-            return object.__getattribute__(self, "execute")
-
-        return getattr(object.__getattribute__(self, "txn"), name)
+    def __getattr__(self, name):
+        return getattr(self.txn, name)
 
     def __setattr__(self, name, value):
-        setattr(object.__getattribute__(self, "txn"), name, value)
+        setattr(self.txn, name, value)
 
     def execute(self, sql, *args, **kwargs):
         # TODO(paul): Maybe use 'info' and 'debug' for values?
-        sql_logger.debug("[SQL] %s", sql)
+        sql_logger.debug("[SQL] {%s} %s", self.name, sql)
         try:
             if args and args[0]:
                 values = args[0]
-                sql_logger.debug("[SQL values] " +
-                    ", ".join(("<%s>",) * len(values)), *values)
+                sql_logger.debug(
+                    "[SQL values] {%s} " + ", ".join(("<%s>",) * len(values)),
+                    self.name,
+                    *values
+                )
         except:
             # Don't let logging failures stop SQL from working
             pass
 
-        # TODO(paul): Here would be an excellent place to put some timing
-        #   measurements, and log (warning?) slow queries.
-        return object.__getattribute__(self, "txn").execute(
-            sql, *args, **kwargs
-        )
+        start = time.clock() * 1000
+        try:
+            return self.txn.execute(
+                sql, *args, **kwargs
+            )
+        except:
+                logger.exception("[SQL FAIL] {%s}", self.name)
+                raise
+        finally:
+            end = time.clock() * 1000
+            sql_logger.debug("[SQL time] {%s} %f", self.name, end - start)
 
 
 class SQLBaseStore(object):
+    _TXN_ID = 0
 
     def __init__(self, hs):
         self.hs = hs
@@ -74,10 +84,30 @@ class SQLBaseStore(object):
         self.event_factory = hs.get_event_factory()
         self._clock = hs.get_clock()
 
-    def runInteraction(self, func, *args, **kwargs):
+    def runInteraction(self, desc, func, *args, **kwargs):
         """Wraps the .runInteraction() method on the underlying db_pool."""
         def inner_func(txn, *args, **kwargs):
-            return func(LoggingTransaction(txn), *args, **kwargs)
+            start = time.clock() * 1000
+            txn_id = SQLBaseStore._TXN_ID
+
+            # We don't really need these to be unique, so lets stop it from
+            # growing really large.
+            self._TXN_ID = (self._TXN_ID + 1) % (sys.maxint - 1)
+
+            name = "%s-%x" % (desc, txn_id, )
+
+            transaction_logger.debug("[TXN START] {%s}", name)
+            try:
+                return func(LoggingTransaction(txn, name), *args, **kwargs)
+            except:
+                logger.exception("[TXN FAIL] {%s}", name)
+                raise
+            finally:
+                end = time.clock() * 1000
+                transaction_logger.debug(
+                    "[TXN END] {%s} %f",
+                    name, end - start
+                )
 
         return self._db_pool.runInteraction(inner_func, *args, **kwargs)
 
@@ -113,7 +143,7 @@ class SQLBaseStore(object):
             else:
                 return cursor.fetchall()
 
-        return self.runInteraction(interaction)
+        return self.runInteraction("_execute", interaction)
 
     def _execute_and_decode(self, query, *args):
         return self._execute(self.cursor_to_dict, query, *args)
@@ -130,6 +160,7 @@ class SQLBaseStore(object):
             or_replace : bool; if True performs an INSERT OR REPLACE
         """
         return self.runInteraction(
+            "_simple_insert",
             self._simple_insert_txn, table, values, or_replace=or_replace,
             or_ignore=or_ignore,
         )
@@ -170,7 +201,6 @@ class SQLBaseStore(object):
             table, keyvalues, retcols=retcols, allow_none=allow_none
         )
 
-    @defer.inlineCallbacks
     def _simple_select_one_onecol(self, table, keyvalues, retcol,
                                   allow_none=False):
         """Executes a SELECT query on the named table, which is expected to
@@ -181,19 +211,40 @@ class SQLBaseStore(object):
             keyvalues : dict of column names and values to select the row with
             retcol : string giving the name of the column to return
         """
-        ret = yield self._simple_select_one(
+        return self.runInteraction(
+            "_simple_select_one_onecol",
+            self._simple_select_one_onecol_txn,
+            table, keyvalues, retcol, allow_none=allow_none,
+        )
+
+    def _simple_select_one_onecol_txn(self, txn, table, keyvalues, retcol,
+                                      allow_none=False):
+        ret = self._simple_select_onecol_txn(
+            txn,
             table=table,
             keyvalues=keyvalues,
-            retcols=[retcol],
-            allow_none=allow_none
+            retcol=retcol,
         )
 
         if ret:
-            defer.returnValue(ret[retcol])
+            return ret[0]
         else:
-            defer.returnValue(None)
+            if allow_none:
+                return None
+            else:
+                raise StoreError(404, "No row found")
+
+    def _simple_select_onecol_txn(self, txn, table, keyvalues, retcol):
+        sql = "SELECT %(retcol)s FROM %(table)s WHERE %(where)s" % {
+            "retcol": retcol,
+            "table": table,
+            "where": " AND ".join("%s = ?" % k for k in keyvalues.keys()),
+        }
+
+        txn.execute(sql, keyvalues.values())
+
+        return [r[0] for r in txn.fetchall()]
 
-    @defer.inlineCallbacks
     def _simple_select_onecol(self, table, keyvalues, retcol):
         """Executes a SELECT query on the named table, which returns a list
         comprising of the values of the named column from the selected rows.
@@ -206,25 +257,33 @@ class SQLBaseStore(object):
         Returns:
             Deferred: Results in a list
         """
-        sql = "SELECT %(retcol)s FROM %(table)s WHERE %(where)s" % {
-            "retcol": retcol,
-            "table": table,
-            "where": " AND ".join("%s = ?" % k for k in keyvalues.keys()),
-        }
-
-        def func(txn):
-            txn.execute(sql, keyvalues.values())
-            return txn.fetchall()
+        return self.runInteraction(
+            "_simple_select_onecol",
+            self._simple_select_onecol_txn,
+            table, keyvalues, retcol
+        )
 
-        res = yield self.runInteraction(func)
+    def _simple_select_list(self, table, keyvalues, retcols):
+        """Executes a SELECT query on the named table, which may return zero or
+        more rows, returning the result as a list of dicts.
 
-        defer.returnValue([r[0] for r in res])
+        Args:
+            table : string giving the table name
+            keyvalues : dict of column names and values to select the rows with
+            retcols : list of strings giving the names of the columns to return
+        """
+        return self.runInteraction(
+            "_simple_select_list",
+            self._simple_select_list_txn,
+            table, keyvalues, retcols
+        )
 
-    def _simple_select_list(self, table, keyvalues, retcols):
+    def _simple_select_list_txn(self, txn, table, keyvalues, retcols):
         """Executes a SELECT query on the named table, which may return zero or
         more rows, returning the result as a list of dicts.
 
         Args:
+            txn : Transaction object
             table : string giving the table name
             keyvalues : dict of column names and values to select the rows with
             retcols : list of strings giving the names of the columns to return
@@ -232,14 +291,11 @@ class SQLBaseStore(object):
         sql = "SELECT %s FROM %s WHERE %s" % (
             ", ".join(retcols),
             table,
-            " AND ".join("%s = ?" % (k) for k in keyvalues)
+            " AND ".join("%s = ?" % (k, ) for k in keyvalues)
         )
 
-        def func(txn):
-            txn.execute(sql, keyvalues.values())
-            return self.cursor_to_dict(txn)
-
-        return self.runInteraction(func)
+        txn.execute(sql, keyvalues.values())
+        return self.cursor_to_dict(txn)
 
     def _simple_update_one(self, table, keyvalues, updatevalues,
                            retcols=None):
@@ -307,7 +363,7 @@ class SQLBaseStore(object):
                     raise StoreError(500, "More than one row matched")
 
             return ret
-        return self.runInteraction(func)
+        return self.runInteraction("_simple_selectupdate_one", func)
 
     def _simple_delete_one(self, table, keyvalues):
         """Executes a DELETE query on the named table, expecting to delete a
@@ -319,7 +375,7 @@ class SQLBaseStore(object):
         """
         sql = "DELETE FROM %s WHERE %s" % (
             table,
-            " AND ".join("%s = ?" % (k) for k in keyvalues)
+            " AND ".join("%s = ?" % (k, ) for k in keyvalues)
         )
 
         def func(txn):
@@ -328,7 +384,25 @@ class SQLBaseStore(object):
                 raise StoreError(404, "No row found")
             if txn.rowcount > 1:
                 raise StoreError(500, "more than one row matched")
-        return self.runInteraction(func)
+        return self.runInteraction("_simple_delete_one", func)
+
+    def _simple_delete(self, table, keyvalues):
+        """Executes a DELETE query on the named table.
+
+        Args:
+            table : string giving the table name
+            keyvalues : dict of column names and values to select the row with
+        """
+
+        return self.runInteraction("_simple_delete", self._simple_delete_txn)
+
+    def _simple_delete_txn(self, txn, table, keyvalues):
+        sql = "DELETE FROM %s WHERE %s" % (
+            table,
+            " AND ".join("%s = ?" % (k, ) for k in keyvalues)
+        )
+
+        return txn.execute(sql, keyvalues.values())
 
     def _simple_max_id(self, table):
         """Executes a SELECT query on the named table, expecting to return the
@@ -346,7 +420,7 @@ class SQLBaseStore(object):
                 return 0
             return max_id
 
-        return self.runInteraction(func)
+        return self.runInteraction("_simple_max_id", func)
 
     def _parse_event_from_row(self, row_dict):
         d = copy.deepcopy({k: v for k, v in row_dict.items()})
@@ -355,6 +429,10 @@ class SQLBaseStore(object):
         d.pop("topological_ordering", None)
         d.pop("processed", None)
         d["origin_server_ts"] = d.pop("ts", 0)
+        replaces_state = d.pop("prev_state", None)
+
+        if replaces_state:
+            d["replaces_state"] = replaces_state
 
         d.update(json.loads(row_dict["unrecognized_keys"]))
         d["content"] = json.loads(d["content"])
@@ -369,23 +447,65 @@ class SQLBaseStore(object):
             **d
         )
 
+    def _get_events_txn(self, txn, event_ids):
+        # FIXME (erikj): This should be batched?
+
+        sql = "SELECT * FROM events WHERE event_id = ?"
+
+        event_rows = []
+        for e_id in event_ids:
+            c = txn.execute(sql, (e_id,))
+            event_rows.extend(self.cursor_to_dict(c))
+
+        return self._parse_events_txn(txn, event_rows)
+
     def _parse_events(self, rows):
-        return self.runInteraction(self._parse_events_txn, rows)
+        return self.runInteraction(
+            "_parse_events", self._parse_events_txn, rows
+        )
 
     def _parse_events_txn(self, txn, rows):
         events = [self._parse_event_from_row(r) for r in rows]
 
-        sql = "SELECT * FROM events WHERE event_id = ?"
+        select_event_sql = "SELECT * FROM events WHERE event_id = ?"
+
+        for i, ev in enumerate(events):
+            signatures = self._get_event_origin_signatures_txn(
+                txn, ev.event_id,
+            )
 
-        for ev in events:
-            if hasattr(ev, "prev_state"):
-                # Load previous state_content.
-                # TODO: Should we be pulling this out above?
-                cursor = txn.execute(sql, (ev.prev_state,))
-                prevs = self.cursor_to_dict(cursor)
-                if prevs:
-                    prev = self._parse_event_from_row(prevs[0])
-                    ev.prev_content = prev.content
+            ev.signatures = {
+                k: encode_base64(v) for k, v in signatures.items()
+            }
+
+            prevs = self._get_prev_events_and_state(txn, ev.event_id)
+
+            ev.prev_events = [
+                (e_id, h)
+                for e_id, h, is_state in prevs
+                if is_state == 0
+            ]
+
+            ev.auth_events = self._get_auth_events(txn, ev.event_id)
+
+            if hasattr(ev, "state_key"):
+                ev.prev_state = [
+                    (e_id, h)
+                    for e_id, h, is_state in prevs
+                    if is_state == 1
+                ]
+
+                if hasattr(ev, "replaces_state"):
+                    # Load previous state_content.
+                    # FIXME (erikj): Handle multiple prev_states.
+                    cursor = txn.execute(
+                        select_event_sql,
+                        (ev.replaces_state,)
+                    )
+                    prevs = self.cursor_to_dict(cursor)
+                    if prevs:
+                        prev = self._parse_event_from_row(prevs[0])
+                        ev.prev_content = prev.content
 
             if not hasattr(ev, "redacted"):
                 logger.debug("Doesn't have redacted key: %s", ev)
@@ -393,15 +513,16 @@ class SQLBaseStore(object):
 
             if ev.redacted:
                 # Get the redaction event.
-                sql = "SELECT * FROM events WHERE event_id = ?"
-                txn.execute(sql, (ev.redacted,))
+                select_event_sql = "SELECT * FROM events WHERE event_id = ?"
+                txn.execute(select_event_sql, (ev.redacted,))
 
                 del_evs = self._parse_events_txn(
                     txn, self.cursor_to_dict(txn)
                 )
 
                 if del_evs:
-                    prune_event(ev)
+                    ev = prune_event(ev)
+                    events[i] = ev
                     ev.redacted_because = del_evs[0]
 
         return events
diff --git a/synapse/storage/directory.py b/synapse/storage/directory.py
index 52373a28a6..d6a7113b9c 100644
--- a/synapse/storage/directory.py
+++ b/synapse/storage/directory.py
@@ -95,6 +95,7 @@ class DirectoryStore(SQLBaseStore):
 
     def delete_room_alias(self, room_alias):
         return self.runInteraction(
+            "delete_room_alias",
             self._delete_room_alias_txn,
             room_alias,
         )
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
new file mode 100644
index 0000000000..a027db3868
--- /dev/null
+++ b/synapse/storage/event_federation.py
@@ -0,0 +1,377 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import SQLBaseStore
+from syutil.base64util import encode_base64
+
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+
+class EventFederationStore(SQLBaseStore):
+
+    def get_auth_chain(self, event_id):
+        return self.runInteraction(
+            "get_auth_chain",
+            self._get_auth_chain_txn,
+            event_id
+        )
+
+    def _get_auth_chain_txn(self, txn, event_id):
+        results = self._get_auth_chain_ids_txn(txn, event_id)
+
+        sql = "SELECT * FROM events WHERE event_id = ?"
+        rows = []
+        for ev_id in results:
+            c = txn.execute(sql, (ev_id,))
+            rows.extend(self.cursor_to_dict(c))
+
+        return self._parse_events_txn(txn, rows)
+
+    def get_auth_chain_ids(self, event_id):
+        return self.runInteraction(
+            "get_auth_chain_ids",
+            self._get_auth_chain_ids_txn,
+            event_id
+        )
+
+    def _get_auth_chain_ids_txn(self, txn, event_id):
+        results = set()
+
+        base_sql = (
+            "SELECT auth_id FROM event_auth WHERE %s"
+        )
+
+        front = set([event_id])
+        while front:
+            sql = base_sql % (
+                " OR ".join(["event_id=?"] * len(front)),
+            )
+
+            txn.execute(sql, list(front))
+            front = [r[0] for r in txn.fetchall()]
+            results.update(front)
+
+        return list(results)
+
+    def get_oldest_events_in_room(self, room_id):
+        return self.runInteraction(
+            "get_oldest_events_in_room",
+            self._get_oldest_events_in_room_txn,
+            room_id,
+        )
+
+    def _get_oldest_events_in_room_txn(self, txn, room_id):
+        return self._simple_select_onecol_txn(
+            txn,
+            table="event_backward_extremities",
+            keyvalues={
+                "room_id": room_id,
+            },
+            retcol="event_id",
+        )
+
+    def get_latest_events_in_room(self, room_id):
+        return self.runInteraction(
+            "get_latest_events_in_room",
+            self._get_latest_events_in_room,
+            room_id,
+        )
+
+    def _get_latest_events_in_room(self, txn, room_id):
+        sql = (
+            "SELECT e.event_id, e.depth FROM events as e "
+            "INNER JOIN event_forward_extremities as f "
+            "ON e.event_id = f.event_id "
+            "WHERE f.room_id = ?"
+        )
+
+        txn.execute(sql, (room_id, ))
+
+        results = []
+        for event_id, depth in txn.fetchall():
+            hashes = self._get_event_reference_hashes_txn(txn, event_id)
+            prev_hashes = {
+                k: encode_base64(v) for k, v in hashes.items()
+                if k == "sha256"
+            }
+            results.append((event_id, prev_hashes, depth))
+
+        return results
+
+    def _get_latest_state_in_room(self, txn, room_id, type, state_key):
+        event_ids = self._simple_select_onecol_txn(
+            txn,
+            table="state_forward_extremities",
+            keyvalues={
+                "room_id": room_id,
+                "type": type,
+                "state_key": state_key,
+            },
+            retcol="event_id",
+        )
+
+        results = []
+        for event_id in event_ids:
+            hashes = self._get_event_reference_hashes_txn(txn, event_id)
+            prev_hashes = {
+                k: encode_base64(v) for k, v in hashes.items()
+                if k == "sha256"
+            }
+            results.append((event_id, prev_hashes))
+
+        return results
+
+    def _get_prev_events(self, txn, event_id):
+        results = self._get_prev_events_and_state(
+            txn,
+            event_id,
+            is_state=0,
+        )
+
+        return [(e_id, h, ) for e_id, h, _ in results]
+
+    def _get_prev_state(self, txn, event_id):
+        results = self._get_prev_events_and_state(
+            txn,
+            event_id,
+            is_state=1,
+        )
+
+        return [(e_id, h, ) for e_id, h, _ in results]
+
+    def _get_prev_events_and_state(self, txn, event_id, is_state=None):
+        keyvalues = {
+            "event_id": event_id,
+        }
+
+        if is_state is not None:
+            keyvalues["is_state"] = is_state
+
+        res = self._simple_select_list_txn(
+            txn,
+            table="event_edges",
+            keyvalues=keyvalues,
+            retcols=["prev_event_id", "is_state"],
+        )
+
+        results = []
+        for d in res:
+            hashes = self._get_event_reference_hashes_txn(
+                txn,
+                d["prev_event_id"]
+            )
+            prev_hashes = {
+                k: encode_base64(v) for k, v in hashes.items()
+                if k == "sha256"
+            }
+            results.append((d["prev_event_id"], prev_hashes, d["is_state"]))
+
+        return results
+
+    def _get_auth_events(self, txn, event_id):
+        auth_ids = self._simple_select_onecol_txn(
+            txn,
+            table="event_auth",
+            keyvalues={
+                "event_id": event_id,
+            },
+            retcol="auth_id",
+        )
+
+        results = []
+        for auth_id in auth_ids:
+            hashes = self._get_event_reference_hashes_txn(txn, auth_id)
+            prev_hashes = {
+                k: encode_base64(v) for k, v in hashes.items()
+                if k == "sha256"
+            }
+            results.append((auth_id, prev_hashes))
+
+        return results
+
+    def get_min_depth(self, room_id):
+        return self.runInteraction(
+            "get_min_depth",
+            self._get_min_depth_interaction,
+            room_id,
+        )
+
+    def _get_min_depth_interaction(self, txn, room_id):
+        min_depth = self._simple_select_one_onecol_txn(
+            txn,
+            table="room_depth",
+            keyvalues={"room_id": room_id},
+            retcol="min_depth",
+            allow_none=True,
+        )
+
+        return int(min_depth) if min_depth is not None else None
+
+    def _update_min_depth_for_room_txn(self, txn, room_id, depth):
+        min_depth = self._get_min_depth_interaction(txn, room_id)
+
+        do_insert = depth < min_depth if min_depth else True
+
+        if do_insert:
+            self._simple_insert_txn(
+                txn,
+                table="room_depth",
+                values={
+                    "room_id": room_id,
+                    "min_depth": depth,
+                },
+                or_replace=True,
+            )
+
+    def _handle_prev_events(self, txn, outlier, event_id, prev_events,
+                            room_id):
+        for e_id, _ in prev_events:
+            # TODO (erikj): This could be done as a bulk insert
+            self._simple_insert_txn(
+                txn,
+                table="event_edges",
+                values={
+                    "event_id": event_id,
+                    "prev_event_id": e_id,
+                    "room_id": room_id,
+                    "is_state": 0,
+                },
+                or_ignore=True,
+            )
+
+        # Update the extremities table if this is not an outlier.
+        if not outlier:
+            for e_id, _ in prev_events:
+                # TODO (erikj): This could be done as a bulk insert
+                self._simple_delete_txn(
+                    txn,
+                    table="event_forward_extremities",
+                    keyvalues={
+                        "event_id": e_id,
+                        "room_id": room_id,
+                    }
+                )
+
+            # We only insert as a forward extremity the new pdu if there are
+            # no other pdus that reference it as a prev pdu
+            query = (
+                "INSERT OR IGNORE INTO %(table)s (event_id, room_id) "
+                "SELECT ?, ? WHERE NOT EXISTS ("
+                "SELECT 1 FROM %(event_edges)s WHERE "
+                "prev_event_id = ? "
+                ")"
+            ) % {
+                "table": "event_forward_extremities",
+                "event_edges": "event_edges",
+            }
+
+            logger.debug("query: %s", query)
+
+            txn.execute(query, (event_id, room_id, event_id))
+
+            # Insert all the prev_pdus as a backwards thing, they'll get
+            # deleted in a second if they're incorrect anyway.
+            for e_id, _ in prev_events:
+                # TODO (erikj): This could be done as a bulk insert
+                self._simple_insert_txn(
+                    txn,
+                    table="event_backward_extremities",
+                    values={
+                        "event_id": e_id,
+                        "room_id": room_id,
+                    },
+                    or_ignore=True,
+                )
+
+            # Also delete from the backwards extremities table all ones that
+            # reference pdus that we have already seen
+            query = (
+                "DELETE FROM event_backward_extremities WHERE EXISTS ("
+                "SELECT 1 FROM events "
+                "WHERE "
+                "event_backward_extremities.event_id = events.event_id "
+                "AND not events.outlier "
+                ")"
+            )
+            txn.execute(query)
+
+    def get_backfill_events(self, room_id, event_list, limit):
+        """Get a list of Events for a given topic that occured before (and
+        including) the pdus in pdu_list. Return a list of max size `limit`.
+
+        Args:
+            txn
+            room_id (str)
+            event_list (list)
+            limit (int)
+
+        Return:
+            list: A list of PduTuples
+        """
+        return self.runInteraction(
+            "get_backfill_events",
+            self._get_backfill_events, room_id, event_list, limit
+        )
+
+    def _get_backfill_events(self, txn, room_id, event_list, limit):
+        logger.debug(
+            "_get_backfill_events: %s, %s, %s",
+            room_id, repr(event_list), limit
+        )
+
+        # We seed the pdu_results with the things from the pdu_list.
+        event_results = event_list
+
+        front = event_list
+
+        query = (
+            "SELECT prev_event_id FROM event_edges "
+            "WHERE room_id = ? AND event_id = ? "
+            "LIMIT ?"
+        )
+
+        # We iterate through all event_ids in `front` to select their previous
+        # events. These are dumped in `new_front`.
+        # We continue until we reach the limit *or* new_front is empty (i.e.,
+        # we've run out of things to select
+        while front and len(event_results) < limit:
+
+            new_front = []
+            for event_id in front:
+                logger.debug(
+                    "_backfill_interaction: id=%s",
+                    event_id
+                )
+
+                txn.execute(
+                    query,
+                    (room_id, event_id, limit - len(event_results))
+                )
+
+                for row in txn.fetchall():
+                    logger.debug(
+                        "_backfill_interaction: got id=%s",
+                        *row
+                    )
+                    new_front.append(row[0])
+
+            front = new_front
+            event_results += new_front
+
+        # We also want to update the `prev_pdus` attributes before returning.
+        return self._get_events_txn(txn, event_results)
diff --git a/synapse/storage/pdu.py b/synapse/storage/pdu.py
deleted file mode 100644
index d70467dcd6..0000000000
--- a/synapse/storage/pdu.py
+++ /dev/null
@@ -1,915 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2014 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from twisted.internet import defer
-
-from ._base import SQLBaseStore, Table, JoinHelper
-
-from synapse.federation.units import Pdu
-from synapse.util.logutils import log_function
-
-from collections import namedtuple
-
-import logging
-
-logger = logging.getLogger(__name__)
-
-
-class PduStore(SQLBaseStore):
-    """A collection of queries for handling PDUs.
-    """
-
-    def get_pdu(self, pdu_id, origin):
-        """Given a pdu_id and origin, get a PDU.
-
-        Args:
-            txn
-            pdu_id (str)
-            origin (str)
-
-        Returns:
-            PduTuple: If the pdu does not exist in the database, returns None
-        """
-
-        return self.runInteraction(
-            self._get_pdu_tuple, pdu_id, origin
-        )
-
-    def _get_pdu_tuple(self, txn, pdu_id, origin):
-        res = self._get_pdu_tuples(txn, [(pdu_id, origin)])
-        return res[0] if res else None
-
-    def _get_pdu_tuples(self, txn, pdu_id_tuples):
-        results = []
-        for pdu_id, origin in pdu_id_tuples:
-            txn.execute(
-                PduEdgesTable.select_statement("pdu_id = ? AND origin = ?"),
-                (pdu_id, origin)
-            )
-
-            edges = [
-                (r.prev_pdu_id, r.prev_origin)
-                for r in PduEdgesTable.decode_results(txn.fetchall())
-            ]
-
-            query = (
-                "SELECT %(fields)s FROM %(pdus)s as p "
-                "LEFT JOIN %(state)s as s "
-                "ON p.pdu_id = s.pdu_id AND p.origin = s.origin "
-                "WHERE p.pdu_id = ? AND p.origin = ? "
-            ) % {
-                "fields": _pdu_state_joiner.get_fields(
-                    PdusTable="p", StatePdusTable="s"),
-                "pdus": PdusTable.table_name,
-                "state": StatePdusTable.table_name,
-            }
-
-            txn.execute(query, (pdu_id, origin))
-
-            row = txn.fetchone()
-            if row:
-                results.append(PduTuple(PduEntry(*row), edges))
-
-        return results
-
-    def get_current_state_for_context(self, context):
-        """Get a list of PDUs that represent the current state for a given
-        context
-
-        Args:
-            context (str)
-
-        Returns:
-            list: A list of PduTuples
-        """
-
-        return self.runInteraction(
-            self._get_current_state_for_context,
-            context
-        )
-
-    def _get_current_state_for_context(self, txn, context):
-        query = (
-            "SELECT pdu_id, origin FROM %s WHERE context = ?"
-            % CurrentStateTable.table_name
-        )
-
-        logger.debug("get_current_state %s, Args=%s", query, context)
-        txn.execute(query, (context,))
-
-        res = txn.fetchall()
-
-        logger.debug("get_current_state %d results", len(res))
-
-        return self._get_pdu_tuples(txn, res)
-
-    def _persist_pdu_txn(self, txn, prev_pdus, cols):
-        """Inserts a (non-state) PDU into the database.
-
-        Args:
-            txn,
-            prev_pdus (list)
-            **cols: The columns to insert into the PdusTable.
-        """
-        entry = PdusTable.EntryType(
-            **{k: cols.get(k, None) for k in PdusTable.fields}
-        )
-
-        txn.execute(PdusTable.insert_statement(), entry)
-
-        self._handle_prev_pdus(
-            txn, entry.outlier, entry.pdu_id, entry.origin,
-            prev_pdus, entry.context
-        )
-
-    def mark_pdu_as_processed(self, pdu_id, pdu_origin):
-        """Mark a received PDU as processed.
-
-        Args:
-            txn
-            pdu_id (str)
-            pdu_origin (str)
-        """
-
-        return self.runInteraction(
-            self._mark_as_processed, pdu_id, pdu_origin
-        )
-
-    def _mark_as_processed(self, txn, pdu_id, pdu_origin):
-        txn.execute("UPDATE %s SET have_processed = 1" % PdusTable.table_name)
-
-    def get_all_pdus_from_context(self, context):
-        """Get a list of all PDUs for a given context."""
-        return self.runInteraction(
-            self._get_all_pdus_from_context, context,
-        )
-
-    def _get_all_pdus_from_context(self, txn, context):
-        query = (
-            "SELECT pdu_id, origin FROM %s "
-            "WHERE context = ?"
-        ) % PdusTable.table_name
-
-        txn.execute(query, (context,))
-
-        return self._get_pdu_tuples(txn, txn.fetchall())
-
-    def get_backfill(self, context, pdu_list, limit):
-        """Get a list of Pdus for a given topic that occured before (and
-        including) the pdus in pdu_list. Return a list of max size `limit`.
-
-        Args:
-            txn
-            context (str)
-            pdu_list (list)
-            limit (int)
-
-        Return:
-            list: A list of PduTuples
-        """
-        return self.runInteraction(
-            self._get_backfill, context, pdu_list, limit
-        )
-
-    def _get_backfill(self, txn, context, pdu_list, limit):
-        logger.debug(
-            "backfill: %s, %s, %s",
-            context, repr(pdu_list), limit
-        )
-
-        # We seed the pdu_results with the things from the pdu_list.
-        pdu_results = pdu_list
-
-        front = pdu_list
-
-        query = (
-            "SELECT prev_pdu_id, prev_origin FROM %(edges_table)s "
-            "WHERE context = ? AND pdu_id = ? AND origin = ? "
-            "LIMIT ?"
-        ) % {
-            "edges_table": PduEdgesTable.table_name,
-        }
-
-        # We iterate through all pdu_ids in `front` to select their previous
-        # pdus. These are dumped in `new_front`. We continue until we reach the
-        # limit *or* new_front is empty (i.e., we've run out of things to
-        # select
-        while front and len(pdu_results) < limit:
-
-            new_front = []
-            for pdu_id, origin in front:
-                logger.debug(
-                    "_backfill_interaction: i=%s, o=%s",
-                    pdu_id, origin
-                )
-
-                txn.execute(
-                    query,
-                    (context, pdu_id, origin, limit - len(pdu_results))
-                )
-
-                for row in txn.fetchall():
-                    logger.debug(
-                        "_backfill_interaction: got i=%s, o=%s",
-                        *row
-                    )
-                    new_front.append(row)
-
-            front = new_front
-            pdu_results += new_front
-
-        # We also want to update the `prev_pdus` attributes before returning.
-        return self._get_pdu_tuples(txn, pdu_results)
-
-    def get_min_depth_for_context(self, context):
-        """Get the current minimum depth for a context
-
-        Args:
-            txn
-            context (str)
-        """
-        return self.runInteraction(
-            self._get_min_depth_for_context, context
-        )
-
-    def _get_min_depth_for_context(self, txn, context):
-        return self._get_min_depth_interaction(txn, context)
-
-    def _get_min_depth_interaction(self, txn, context):
-        txn.execute(
-            "SELECT min_depth FROM %s WHERE context = ?"
-            % ContextDepthTable.table_name,
-            (context,)
-        )
-
-        row = txn.fetchone()
-
-        return row[0] if row else None
-
-    def _update_min_depth_for_context_txn(self, txn, context, depth):
-        """Update the minimum `depth` of the given context, which is the line
-        on which we stop backfilling backwards.
-
-        Args:
-            context (str)
-            depth (int)
-        """
-        min_depth = self._get_min_depth_interaction(txn, context)
-
-        do_insert = depth < min_depth if min_depth else True
-
-        if do_insert:
-            txn.execute(
-                "INSERT OR REPLACE INTO %s (context, min_depth) "
-                "VALUES (?,?)" % ContextDepthTable.table_name,
-                (context, depth)
-            )
-
-    def _get_latest_pdus_in_context(self, txn, context):
-        """Get's a list of the most current pdus for a given context. This is
-        used when we are sending a Pdu and need to fill out the `prev_pdus`
-        key
-
-        Args:
-            txn
-            context
-        """
-        query = (
-            "SELECT p.pdu_id, p.origin, p.depth FROM %(pdus)s as p "
-            "INNER JOIN %(forward)s as f ON p.pdu_id = f.pdu_id "
-            "AND f.origin = p.origin "
-            "WHERE f.context = ?"
-        ) % {
-            "pdus": PdusTable.table_name,
-            "forward": PduForwardExtremitiesTable.table_name,
-        }
-
-        logger.debug("get_prev query: %s", query)
-
-        txn.execute(
-            query,
-            (context, )
-        )
-
-        results = txn.fetchall()
-
-        return [(row[0], row[1], row[2]) for row in results]
-
-    @defer.inlineCallbacks
-    def get_oldest_pdus_in_context(self, context):
-        """Get a list of Pdus that we haven't backfilled beyond yet (and havent
-        seen). This list is used when we want to backfill backwards and is the
-        list we send to the remote server.
-
-        Args:
-            txn
-            context (str)
-
-        Returns:
-            list: A list of PduIdTuple.
-        """
-        results = yield self._execute(
-            None,
-            "SELECT pdu_id, origin FROM %(back)s WHERE context = ?"
-            % {"back": PduBackwardExtremitiesTable.table_name, },
-            context
-        )
-
-        defer.returnValue([PduIdTuple(i, o) for i, o in results])
-
-    def is_pdu_new(self, pdu_id, origin, context, depth):
-        """For a given Pdu, try and figure out if it's 'new', i.e., if it's
-        not something we got randomly from the past, for example when we
-        request the current state of the room that will probably return a bunch
-        of pdus from before we joined.
-
-        Args:
-            txn
-            pdu_id (str)
-            origin (str)
-            context (str)
-            depth (int)
-
-        Returns:
-            bool
-        """
-
-        return self.runInteraction(
-            self._is_pdu_new,
-            pdu_id=pdu_id,
-            origin=origin,
-            context=context,
-            depth=depth
-        )
-
-    def _is_pdu_new(self, txn, pdu_id, origin, context, depth):
-        # If depth > min depth in back table, then we classify it as new.
-        # OR if there is nothing in the back table, then it kinda needs to
-        # be a new thing.
-        query = (
-            "SELECT min(p.depth) FROM %(edges)s as e "
-            "INNER JOIN %(back)s as b "
-            "ON e.prev_pdu_id = b.pdu_id AND e.prev_origin = b.origin "
-            "INNER JOIN %(pdus)s as p "
-            "ON e.pdu_id = p.pdu_id AND p.origin = e.origin "
-            "WHERE p.context = ?"
-        ) % {
-            "pdus": PdusTable.table_name,
-            "edges": PduEdgesTable.table_name,
-            "back": PduBackwardExtremitiesTable.table_name,
-        }
-
-        txn.execute(query, (context,))
-
-        min_depth, = txn.fetchone()
-
-        if not min_depth or depth > int(min_depth):
-            logger.debug(
-                "is_new true: id=%s, o=%s, d=%s min_depth=%s",
-                pdu_id, origin, depth, min_depth
-            )
-            return True
-
-        # If this pdu is in the forwards table, then it also is a new one
-        query = (
-            "SELECT * FROM %(forward)s WHERE pdu_id = ? AND origin = ?"
-        ) % {
-            "forward": PduForwardExtremitiesTable.table_name,
-        }
-
-        txn.execute(query, (pdu_id, origin))
-
-        # Did we get anything?
-        if txn.fetchall():
-            logger.debug(
-                "is_new true: id=%s, o=%s, d=%s was forward",
-                pdu_id, origin, depth
-            )
-            return True
-
-        logger.debug(
-            "is_new false: id=%s, o=%s, d=%s",
-            pdu_id, origin, depth
-        )
-
-        # FINE THEN. It's probably old.
-        return False
-
-    @staticmethod
-    @log_function
-    def _handle_prev_pdus(txn, outlier, pdu_id, origin, prev_pdus,
-                          context):
-        txn.executemany(
-            PduEdgesTable.insert_statement(),
-            [(pdu_id, origin, p[0], p[1], context) for p in prev_pdus]
-        )
-
-        # Update the extremities table if this is not an outlier.
-        if not outlier:
-
-            # First, we delete the new one from the forwards extremities table.
-            query = (
-                "DELETE FROM %s WHERE pdu_id = ? AND origin = ?"
-                % PduForwardExtremitiesTable.table_name
-            )
-            txn.executemany(query, prev_pdus)
-
-            # We only insert as a forward extremety the new pdu if there are no
-            # other pdus that reference it as a prev pdu
-            query = (
-                "INSERT INTO %(table)s (pdu_id, origin, context) "
-                "SELECT ?, ?, ? WHERE NOT EXISTS ("
-                "SELECT 1 FROM %(pdu_edges)s WHERE "
-                "prev_pdu_id = ? AND prev_origin = ?"
-                ")"
-            ) % {
-                "table": PduForwardExtremitiesTable.table_name,
-                "pdu_edges": PduEdgesTable.table_name
-            }
-
-            logger.debug("query: %s", query)
-
-            txn.execute(query, (pdu_id, origin, context, pdu_id, origin))
-
-            # Insert all the prev_pdus as a backwards thing, they'll get
-            # deleted in a second if they're incorrect anyway.
-            txn.executemany(
-                PduBackwardExtremitiesTable.insert_statement(),
-                [(i, o, context) for i, o in prev_pdus]
-            )
-
-            # Also delete from the backwards extremities table all ones that
-            # reference pdus that we have already seen
-            query = (
-                "DELETE FROM %(pdu_back)s WHERE EXISTS ("
-                "SELECT 1 FROM %(pdus)s AS pdus "
-                "WHERE "
-                "%(pdu_back)s.pdu_id = pdus.pdu_id "
-                "AND %(pdu_back)s.origin = pdus.origin "
-                "AND not pdus.outlier "
-                ")"
-            ) % {
-                "pdu_back": PduBackwardExtremitiesTable.table_name,
-                "pdus": PdusTable.table_name,
-            }
-            txn.execute(query)
-
-
-class StatePduStore(SQLBaseStore):
-    """A collection of queries for handling state PDUs.
-    """
-
-    def _persist_state_txn(self, txn, prev_pdus, cols):
-        """Inserts a state PDU into the database
-
-        Args:
-            txn,
-            prev_pdus (list)
-            **cols: The columns to insert into the PdusTable and StatePdusTable
-        """
-        pdu_entry = PdusTable.EntryType(
-            **{k: cols.get(k, None) for k in PdusTable.fields}
-        )
-        state_entry = StatePdusTable.EntryType(
-            **{k: cols.get(k, None) for k in StatePdusTable.fields}
-        )
-
-        logger.debug("Inserting pdu: %s", repr(pdu_entry))
-        logger.debug("Inserting state: %s", repr(state_entry))
-
-        txn.execute(PdusTable.insert_statement(), pdu_entry)
-        txn.execute(StatePdusTable.insert_statement(), state_entry)
-
-        self._handle_prev_pdus(
-            txn,
-            pdu_entry.outlier, pdu_entry.pdu_id, pdu_entry.origin, prev_pdus,
-            pdu_entry.context
-        )
-
-    def get_unresolved_state_tree(self, new_state_pdu):
-        return self.runInteraction(
-            self._get_unresolved_state_tree, new_state_pdu
-        )
-
-    @log_function
-    def _get_unresolved_state_tree(self, txn, new_pdu):
-        current = self._get_current_interaction(
-            txn,
-            new_pdu.context, new_pdu.pdu_type, new_pdu.state_key
-        )
-
-        ReturnType = namedtuple(
-            "StateReturnType", ["new_branch", "current_branch"]
-        )
-        return_value = ReturnType([new_pdu], [])
-
-        if not current:
-            logger.debug("get_unresolved_state_tree No current state.")
-            return (return_value, None)
-
-        return_value.current_branch.append(current)
-
-        enum_branches = self._enumerate_state_branches(
-            txn, new_pdu, current
-        )
-
-        missing_branch = None
-        for branch, prev_state, state in enum_branches:
-            if state:
-                return_value[branch].append(state)
-            else:
-                # We don't have prev_state :(
-                missing_branch = branch
-                break
-
-        return (return_value, missing_branch)
-
-    def update_current_state(self, pdu_id, origin, context, pdu_type,
-                             state_key):
-        return self.runInteraction(
-            self._update_current_state,
-            pdu_id, origin, context, pdu_type, state_key
-        )
-
-    def _update_current_state(self, txn, pdu_id, origin, context, pdu_type,
-                              state_key):
-        query = (
-            "INSERT OR REPLACE INTO %(curr)s (%(fields)s) VALUES (%(qs)s)"
-        ) % {
-            "curr": CurrentStateTable.table_name,
-            "fields": CurrentStateTable.get_fields_string(),
-            "qs": ", ".join(["?"] * len(CurrentStateTable.fields))
-        }
-
-        query_args = CurrentStateTable.EntryType(
-            pdu_id=pdu_id,
-            origin=origin,
-            context=context,
-            pdu_type=pdu_type,
-            state_key=state_key
-        )
-
-        txn.execute(query, query_args)
-
-    def get_current_state_pdu(self, context, pdu_type, state_key):
-        """For a given context, pdu_type, state_key 3-tuple, return what is
-        currently considered the current state.
-
-        Args:
-            txn
-            context (str)
-            pdu_type (str)
-            state_key (str)
-
-        Returns:
-            PduEntry
-        """
-
-        return self.runInteraction(
-            self._get_current_state_pdu, context, pdu_type, state_key
-        )
-
-    def _get_current_state_pdu(self, txn, context, pdu_type, state_key):
-        return self._get_current_interaction(txn, context, pdu_type, state_key)
-
-    def _get_current_interaction(self, txn, context, pdu_type, state_key):
-        logger.debug(
-            "_get_current_interaction %s %s %s",
-            context, pdu_type, state_key
-        )
-
-        fields = _pdu_state_joiner.get_fields(
-            PdusTable="p", StatePdusTable="s")
-
-        current_query = (
-            "SELECT %(fields)s FROM %(state)s as s "
-            "INNER JOIN %(pdus)s as p "
-            "ON s.pdu_id = p.pdu_id AND s.origin = p.origin "
-            "INNER JOIN %(curr)s as c "
-            "ON s.pdu_id = c.pdu_id AND s.origin = c.origin "
-            "WHERE s.context = ? AND s.pdu_type = ? AND s.state_key = ? "
-        ) % {
-            "fields": fields,
-            "curr": CurrentStateTable.table_name,
-            "state": StatePdusTable.table_name,
-            "pdus": PdusTable.table_name,
-        }
-
-        txn.execute(
-            current_query,
-            (context, pdu_type, state_key)
-        )
-
-        row = txn.fetchone()
-
-        result = PduEntry(*row) if row else None
-
-        if not result:
-            logger.debug("_get_current_interaction not found")
-        else:
-            logger.debug(
-                "_get_current_interaction found %s %s",
-                result.pdu_id, result.origin
-            )
-
-        return result
-
-    def handle_new_state(self, new_pdu):
-        """Actually perform conflict resolution on the new_pdu on the
-        assumption we have all the pdus required to perform it.
-
-        Args:
-            new_pdu
-
-        Returns:
-            bool: True if the new_pdu clobbered the current state, False if not
-        """
-        return self.runInteraction(
-            self._handle_new_state, new_pdu
-        )
-
-    def _handle_new_state(self, txn, new_pdu):
-        logger.debug(
-            "handle_new_state %s %s",
-            new_pdu.pdu_id, new_pdu.origin
-        )
-
-        current = self._get_current_interaction(
-            txn,
-            new_pdu.context, new_pdu.pdu_type, new_pdu.state_key
-        )
-
-        is_current = False
-
-        if (not current or not current.prev_state_id
-                or not current.prev_state_origin):
-            # Oh, we don't have any state for this yet.
-            is_current = True
-        elif (current.pdu_id == new_pdu.prev_state_id
-                and current.origin == new_pdu.prev_state_origin):
-            # Oh! A direct clobber. Just do it.
-            is_current = True
-        else:
-            ##
-            # Ok, now loop through until we get to a common ancestor.
-            max_new = int(new_pdu.power_level)
-            max_current = int(current.power_level)
-
-            enum_branches = self._enumerate_state_branches(
-                txn, new_pdu, current
-            )
-            for branch, prev_state, state in enum_branches:
-                if not state:
-                    raise RuntimeError(
-                        "Could not find state_pdu %s %s" %
-                        (
-                            prev_state.prev_state_id,
-                            prev_state.prev_state_origin
-                        )
-                    )
-
-                if branch == 0:
-                    max_new = max(int(state.depth), max_new)
-                else:
-                    max_current = max(int(state.depth), max_current)
-
-            is_current = max_new > max_current
-
-        if is_current:
-            logger.debug("handle_new_state make current")
-
-            # Right, this is a new thing, so woo, just insert it.
-            txn.execute(
-                "INSERT OR REPLACE INTO %(curr)s (%(fields)s) VALUES (%(qs)s)"
-                % {
-                    "curr": CurrentStateTable.table_name,
-                    "fields": CurrentStateTable.get_fields_string(),
-                    "qs": ", ".join(["?"] * len(CurrentStateTable.fields))
-                },
-                CurrentStateTable.EntryType(
-                    *(new_pdu.__dict__[k] for k in CurrentStateTable.fields)
-                )
-            )
-        else:
-            logger.debug("handle_new_state not current")
-
-        logger.debug("handle_new_state done")
-
-        return is_current
-
-    @log_function
-    def _enumerate_state_branches(self, txn, pdu_a, pdu_b):
-        branch_a = pdu_a
-        branch_b = pdu_b
-
-        while True:
-            if (branch_a.pdu_id == branch_b.pdu_id
-                    and branch_a.origin == branch_b.origin):
-                # Woo! We found a common ancestor
-                logger.debug("_enumerate_state_branches Found common ancestor")
-                break
-
-            do_branch_a = (
-                hasattr(branch_a, "prev_state_id") and
-                branch_a.prev_state_id
-            )
-
-            do_branch_b = (
-                hasattr(branch_b, "prev_state_id") and
-                branch_b.prev_state_id
-            )
-
-            logger.debug(
-                "do_branch_a=%s, do_branch_b=%s",
-                do_branch_a, do_branch_b
-            )
-
-            if do_branch_a and do_branch_b:
-                do_branch_a = int(branch_a.depth) > int(branch_b.depth)
-
-            if do_branch_a:
-                pdu_tuple = PduIdTuple(
-                    branch_a.prev_state_id,
-                    branch_a.prev_state_origin
-                )
-
-                prev_branch = branch_a
-
-                logger.debug("getting branch_a prev %s", pdu_tuple)
-                branch_a = self._get_pdu_tuple(txn, *pdu_tuple)
-                if branch_a:
-                    branch_a = Pdu.from_pdu_tuple(branch_a)
-
-                logger.debug("branch_a=%s", branch_a)
-
-                yield (0, prev_branch, branch_a)
-
-                if not branch_a:
-                    break
-            elif do_branch_b:
-                pdu_tuple = PduIdTuple(
-                    branch_b.prev_state_id,
-                    branch_b.prev_state_origin
-                )
-
-                prev_branch = branch_b
-
-                logger.debug("getting branch_b prev %s", pdu_tuple)
-                branch_b = self._get_pdu_tuple(txn, *pdu_tuple)
-                if branch_b:
-                    branch_b = Pdu.from_pdu_tuple(branch_b)
-
-                logger.debug("branch_b=%s", branch_b)
-
-                yield (1, prev_branch, branch_b)
-
-                if not branch_b:
-                    break
-            else:
-                break
-
-
-class PdusTable(Table):
-    table_name = "pdus"
-
-    fields = [
-        "pdu_id",
-        "origin",
-        "context",
-        "pdu_type",
-        "ts",
-        "depth",
-        "is_state",
-        "content_json",
-        "unrecognized_keys",
-        "outlier",
-        "have_processed",
-    ]
-
-    EntryType = namedtuple("PdusEntry", fields)
-
-
-class PduDestinationsTable(Table):
-    table_name = "pdu_destinations"
-
-    fields = [
-        "pdu_id",
-        "origin",
-        "destination",
-        "delivered_ts",
-    ]
-
-    EntryType = namedtuple("PduDestinationsEntry", fields)
-
-
-class PduEdgesTable(Table):
-    table_name = "pdu_edges"
-
-    fields = [
-        "pdu_id",
-        "origin",
-        "prev_pdu_id",
-        "prev_origin",
-        "context"
-    ]
-
-    EntryType = namedtuple("PduEdgesEntry", fields)
-
-
-class PduForwardExtremitiesTable(Table):
-    table_name = "pdu_forward_extremities"
-
-    fields = [
-        "pdu_id",
-        "origin",
-        "context",
-    ]
-
-    EntryType = namedtuple("PduForwardExtremitiesEntry", fields)
-
-
-class PduBackwardExtremitiesTable(Table):
-    table_name = "pdu_backward_extremities"
-
-    fields = [
-        "pdu_id",
-        "origin",
-        "context",
-    ]
-
-    EntryType = namedtuple("PduBackwardExtremitiesEntry", fields)
-
-
-class ContextDepthTable(Table):
-    table_name = "context_depth"
-
-    fields = [
-        "context",
-        "min_depth",
-    ]
-
-    EntryType = namedtuple("ContextDepthEntry", fields)
-
-
-class StatePdusTable(Table):
-    table_name = "state_pdus"
-
-    fields = [
-        "pdu_id",
-        "origin",
-        "context",
-        "pdu_type",
-        "state_key",
-        "power_level",
-        "prev_state_id",
-        "prev_state_origin",
-    ]
-
-    EntryType = namedtuple("StatePdusEntry", fields)
-
-
-class CurrentStateTable(Table):
-    table_name = "current_state"
-
-    fields = [
-        "pdu_id",
-        "origin",
-        "context",
-        "pdu_type",
-        "state_key",
-    ]
-
-    EntryType = namedtuple("CurrentStateEntry", fields)
-
-_pdu_state_joiner = JoinHelper(PdusTable, StatePdusTable)
-
-
-# TODO: These should probably be put somewhere more sensible
-PduIdTuple = namedtuple("PduIdTuple", ("pdu_id", "origin"))
-
-PduEntry = _pdu_state_joiner.EntryType
-""" We are always interested in the join of the PdusTable and StatePdusTable,
-rather than just the PdusTable.
-
-This does not include a prev_pdus key.
-"""
-
-PduTuple = namedtuple(
-    "PduTuple",
-    ("pdu_entry", "prev_pdu_list")
-)
-""" This is a tuple of a `PduEntry` and a list of `PduIdTuple` that represent
-the `prev_pdus` key of a PDU.
-"""
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 719806f82b..1f89d77344 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -62,8 +62,10 @@ class RegistrationStore(SQLBaseStore):
         Raises:
             StoreError if the user_id could not be registered.
         """
-        yield self.runInteraction(self._register, user_id, token,
-                                           password_hash)
+        yield self.runInteraction(
+            "register",
+            self._register, user_id, token, password_hash
+        )
 
     def _register(self, txn, user_id, token, password_hash):
         now = int(self.clock.time())
@@ -100,17 +102,22 @@ class RegistrationStore(SQLBaseStore):
             StoreError if no user was found.
         """
         return self.runInteraction(
+            "get_user_by_token",
             self._query_for_auth,
             token
         )
 
+    @defer.inlineCallbacks
     def is_server_admin(self, user):
-        return self._simple_select_one_onecol(
+        res = yield self._simple_select_one_onecol(
             table="users",
             keyvalues={"name": user.to_string()},
             retcol="admin",
+            allow_none=True,
         )
 
+        defer.returnValue(res if res else False)
+
     def _query_for_auth(self, txn, token):
         sql = (
             "SELECT users.name, users.admin, access_tokens.device_id "
diff --git a/synapse/storage/room.py b/synapse/storage/room.py
index 8cd46334cf..cc0513b8d2 100644
--- a/synapse/storage/room.py
+++ b/synapse/storage/room.py
@@ -132,209 +132,29 @@ class RoomStore(SQLBaseStore):
 
         defer.returnValue(ret)
 
-    @defer.inlineCallbacks
-    def get_room_join_rule(self, room_id):
-        sql = (
-            "SELECT join_rule FROM room_join_rules as r "
-            "INNER JOIN current_state_events as c "
-            "ON r.event_id = c.event_id "
-            "WHERE c.room_id = ? "
-        )
-
-        rows = yield self._execute(None, sql, room_id)
-
-        if len(rows) == 1:
-            defer.returnValue(rows[0][0])
-        else:
-            defer.returnValue(None)
-
-    def get_power_level(self, room_id, user_id):
-        return self.runInteraction(
-            self._get_power_level,
-            room_id, user_id,
-        )
-
-    def _get_power_level(self, txn, room_id, user_id):
-        sql = (
-            "SELECT level FROM room_power_levels as r "
-            "INNER JOIN current_state_events as c "
-            "ON r.event_id = c.event_id "
-            "WHERE c.room_id = ? AND r.user_id = ? "
-        )
-
-        rows = txn.execute(sql, (room_id, user_id,)).fetchall()
-
-        if len(rows) == 1:
-            return rows[0][0]
-
-        sql = (
-            "SELECT level FROM room_default_levels as r "
-            "INNER JOIN current_state_events as c "
-            "ON r.event_id = c.event_id "
-            "WHERE c.room_id = ? "
-        )
-
-        rows = txn.execute(sql, (room_id,)).fetchall()
-
-        if len(rows) == 1:
-            return rows[0][0]
-        else:
-            return None
-
-    def get_ops_levels(self, room_id):
-        return self.runInteraction(
-            self._get_ops_levels,
-            room_id,
-        )
-
-    def _get_ops_levels(self, txn, room_id):
-        sql = (
-            "SELECT ban_level, kick_level, redact_level "
-            "FROM room_ops_levels as r "
-            "INNER JOIN current_state_events as c "
-            "ON r.event_id = c.event_id "
-            "WHERE c.room_id = ? "
-        )
-
-        rows = txn.execute(sql, (room_id,)).fetchall()
-
-        if len(rows) == 1:
-            return OpsLevel(rows[0][0], rows[0][1], rows[0][2])
-        else:
-            return OpsLevel(None, None)
-
-    def get_add_state_level(self, room_id):
-        return self._get_level_from_table("room_add_state_levels", room_id)
-
-    def get_send_event_level(self, room_id):
-        return self._get_level_from_table("room_send_event_levels", room_id)
-
-    @defer.inlineCallbacks
-    def _get_level_from_table(self, table, room_id):
-        sql = (
-            "SELECT level FROM %(table)s as r "
-            "INNER JOIN current_state_events as c "
-            "ON r.event_id = c.event_id "
-            "WHERE c.room_id = ? "
-        ) % {"table": table}
-
-        rows = yield self._execute(None, sql, room_id)
-
-        if len(rows) == 1:
-            defer.returnValue(rows[0][0])
-        else:
-            defer.returnValue(None)
-
     def _store_room_topic_txn(self, txn, event):
-        self._simple_insert_txn(
-            txn,
-            "topics",
-            {
-                "event_id": event.event_id,
-                "room_id": event.room_id,
-                "topic": event.topic,
-            }
-        )
+        if hasattr(event, "topic"):
+            self._simple_insert_txn(
+                txn,
+                "topics",
+                {
+                    "event_id": event.event_id,
+                    "room_id": event.room_id,
+                    "topic": event.topic,
+                }
+            )
 
     def _store_room_name_txn(self, txn, event):
-        self._simple_insert_txn(
-            txn,
-            "room_names",
-            {
-                "event_id": event.event_id,
-                "room_id": event.room_id,
-                "name": event.name,
-            }
-        )
-
-    def _store_join_rule(self, txn, event):
-        self._simple_insert_txn(
-            txn,
-            "room_join_rules",
-            {
-                "event_id": event.event_id,
-                "room_id": event.room_id,
-                "join_rule": event.content["join_rule"],
-            },
-        )
-
-    def _store_power_levels(self, txn, event):
-        for user_id, level in event.content.items():
-            if user_id == "default":
-                self._simple_insert_txn(
-                    txn,
-                    "room_default_levels",
-                    {
-                        "event_id": event.event_id,
-                        "room_id": event.room_id,
-                        "level": level,
-                    },
-                )
-            else:
-                self._simple_insert_txn(
-                    txn,
-                    "room_power_levels",
-                    {
-                        "event_id": event.event_id,
-                        "room_id": event.room_id,
-                        "user_id": user_id,
-                        "level": level
-                    },
-                )
-
-    def _store_default_level(self, txn, event):
-        self._simple_insert_txn(
-            txn,
-            "room_default_levels",
-            {
-                "event_id": event.event_id,
-                "room_id": event.room_id,
-                "level": event.content["default_level"],
-            },
-        )
-
-    def _store_add_state_level(self, txn, event):
-        self._simple_insert_txn(
-            txn,
-            "room_add_state_levels",
-            {
-                "event_id": event.event_id,
-                "room_id": event.room_id,
-                "level": event.content["level"],
-            },
-        )
-
-    def _store_send_event_level(self, txn, event):
-        self._simple_insert_txn(
-            txn,
-            "room_send_event_levels",
-            {
-                "event_id": event.event_id,
-                "room_id": event.room_id,
-                "level": event.content["level"],
-            },
-        )
-
-    def _store_ops_level(self, txn, event):
-        content = {
-            "event_id": event.event_id,
-            "room_id": event.room_id,
-        }
-
-        if "kick_level" in event.content:
-            content["kick_level"] = event.content["kick_level"]
-
-        if "ban_level" in event.content:
-            content["ban_level"] = event.content["ban_level"]
-
-        if "redact_level" in event.content:
-            content["redact_level"] = event.content["redact_level"]
-
-        self._simple_insert_txn(
-            txn,
-            "room_ops_levels",
-            content,
-        )
+        if hasattr(event, "name"):
+            self._simple_insert_txn(
+                txn,
+                "room_names",
+                {
+                    "event_id": event.event_id,
+                    "room_id": event.room_id,
+                    "name": event.name,
+                }
+            )
 
 
 class RoomsTable(Table):
diff --git a/synapse/storage/schema/edge_pdus.sql b/synapse/storage/schema/edge_pdus.sql
deleted file mode 100644
index 8a00868065..0000000000
--- a/synapse/storage/schema/edge_pdus.sql
+++ /dev/null
@@ -1,31 +0,0 @@
-/* Copyright 2014 OpenMarket Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-CREATE TABLE IF NOT EXISTS context_edge_pdus(
-    id INTEGER PRIMARY KEY AUTOINCREMENT, -- twistar requires this
-    pdu_id TEXT, 
-    origin TEXT,
-    context TEXT, 
-    CONSTRAINT context_edge_pdu_id_origin UNIQUE (pdu_id, origin)
-);
-
-CREATE TABLE IF NOT EXISTS origin_edge_pdus(
-    id INTEGER PRIMARY KEY AUTOINCREMENT, -- twistar requires this
-    pdu_id TEXT,
-    origin TEXT,
-    CONSTRAINT origin_edge_pdu_id_origin UNIQUE (pdu_id, origin)
-);
-
-CREATE INDEX IF NOT EXISTS context_edge_pdu_id ON context_edge_pdus(pdu_id, origin); 
-CREATE INDEX IF NOT EXISTS origin_edge_pdu_id ON origin_edge_pdus(pdu_id, origin);
diff --git a/synapse/storage/schema/event_edges.sql b/synapse/storage/schema/event_edges.sql
new file mode 100644
index 0000000000..be1c72a775
--- /dev/null
+++ b/synapse/storage/schema/event_edges.sql
@@ -0,0 +1,75 @@
+
+CREATE TABLE IF NOT EXISTS event_forward_extremities(
+    event_id TEXT NOT NULL,
+    room_id TEXT NOT NULL,
+    CONSTRAINT uniqueness UNIQUE (event_id, room_id) ON CONFLICT REPLACE
+);
+
+CREATE INDEX IF NOT EXISTS ev_extrem_room ON event_forward_extremities(room_id);
+CREATE INDEX IF NOT EXISTS ev_extrem_id ON event_forward_extremities(event_id);
+
+
+CREATE TABLE IF NOT EXISTS event_backward_extremities(
+    event_id TEXT NOT NULL,
+    room_id TEXT NOT NULL,
+    CONSTRAINT uniqueness UNIQUE (event_id, room_id) ON CONFLICT REPLACE
+);
+
+CREATE INDEX IF NOT EXISTS ev_b_extrem_room ON event_backward_extremities(room_id);
+CREATE INDEX IF NOT EXISTS ev_b_extrem_id ON event_backward_extremities(event_id);
+
+
+CREATE TABLE IF NOT EXISTS event_edges(
+    event_id TEXT NOT NULL,
+    prev_event_id TEXT NOT NULL,
+    room_id TEXT NOT NULL,
+    is_state INTEGER NOT NULL,
+    CONSTRAINT uniqueness UNIQUE (event_id, prev_event_id, room_id, is_state)
+);
+
+CREATE INDEX IF NOT EXISTS ev_edges_id ON event_edges(event_id);
+CREATE INDEX IF NOT EXISTS ev_edges_prev_id ON event_edges(prev_event_id);
+
+
+CREATE TABLE IF NOT EXISTS room_depth(
+    room_id TEXT NOT NULL,
+    min_depth INTEGER NOT NULL,
+    CONSTRAINT uniqueness UNIQUE (room_id)
+);
+
+CREATE INDEX IF NOT EXISTS room_depth_room ON room_depth(room_id);
+
+
+create TABLE IF NOT EXISTS event_destinations(
+    event_id TEXT NOT NULL,
+    destination TEXT NOT NULL,
+    delivered_ts INTEGER DEFAULT 0, -- or 0 if not delivered
+    CONSTRAINT uniqueness UNIQUE (event_id, destination) ON CONFLICT REPLACE
+);
+
+CREATE INDEX IF NOT EXISTS event_destinations_id ON event_destinations(event_id);
+
+
+CREATE TABLE IF NOT EXISTS state_forward_extremities(
+    event_id TEXT NOT NULL,
+    room_id TEXT NOT NULL,
+    type TEXT NOT NULL,
+    state_key TEXT NOT NULL,
+    CONSTRAINT uniqueness UNIQUE (event_id, room_id) ON CONFLICT REPLACE
+);
+
+CREATE INDEX IF NOT EXISTS st_extrem_keys ON state_forward_extremities(
+    room_id, type, state_key
+);
+CREATE INDEX IF NOT EXISTS st_extrem_id ON state_forward_extremities(event_id);
+
+
+CREATE TABLE IF NOT EXISTS event_auth(
+    event_id TEXT NOT NULL,
+    auth_id TEXT NOT NULL,
+    room_id TEXT NOT NULL,
+    CONSTRAINT uniqueness UNIQUE (event_id, auth_id, room_id)
+);
+
+CREATE INDEX IF NOT EXISTS evauth_edges_id ON event_auth(event_id);
+CREATE INDEX IF NOT EXISTS evauth_edges_auth_id ON event_auth(auth_id);
\ No newline at end of file
diff --git a/synapse/storage/schema/event_signatures.sql b/synapse/storage/schema/event_signatures.sql
new file mode 100644
index 0000000000..5491c7ecec
--- /dev/null
+++ b/synapse/storage/schema/event_signatures.sql
@@ -0,0 +1,65 @@
+/* Copyright 2014 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE IF NOT EXISTS event_content_hashes (
+    event_id TEXT,
+    algorithm TEXT,
+    hash BLOB,
+    CONSTRAINT uniqueness UNIQUE (event_id, algorithm)
+);
+
+CREATE INDEX IF NOT EXISTS event_content_hashes_id ON event_content_hashes(
+    event_id
+);
+
+
+CREATE TABLE IF NOT EXISTS event_reference_hashes (
+    event_id TEXT,
+    algorithm TEXT,
+    hash BLOB,
+    CONSTRAINT uniqueness UNIQUE (event_id, algorithm)
+);
+
+CREATE INDEX IF NOT EXISTS event_reference_hashes_id ON event_reference_hashes (
+    event_id
+);
+
+
+CREATE TABLE IF NOT EXISTS event_origin_signatures (
+    event_id TEXT,
+    origin TEXT,
+    key_id TEXT,
+    signature BLOB,
+    CONSTRAINT uniqueness UNIQUE (event_id, key_id)
+);
+
+CREATE INDEX IF NOT EXISTS event_origin_signatures_id ON event_origin_signatures (
+    event_id
+);
+
+
+CREATE TABLE IF NOT EXISTS event_edge_hashes(
+    event_id TEXT,
+    prev_event_id TEXT,
+    algorithm TEXT,
+    hash BLOB,
+    CONSTRAINT uniqueness UNIQUE (
+        event_id, prev_event_id, algorithm
+    )
+);
+
+CREATE INDEX IF NOT EXISTS event_edge_hashes_id ON event_edge_hashes(
+    event_id
+);
diff --git a/synapse/storage/schema/im.sql b/synapse/storage/schema/im.sql
index 3aa83f5c8c..8ba732a23b 100644
--- a/synapse/storage/schema/im.sql
+++ b/synapse/storage/schema/im.sql
@@ -23,6 +23,7 @@ CREATE TABLE IF NOT EXISTS events(
     unrecognized_keys TEXT,
     processed BOOL NOT NULL,
     outlier BOOL NOT NULL,
+    depth INTEGER DEFAULT 0 NOT NULL,
     CONSTRAINT ev_uniq UNIQUE (event_id)
 );
 
@@ -84,80 +85,24 @@ CREATE TABLE IF NOT EXISTS topics(
     topic TEXT NOT NULL
 );
 
+CREATE INDEX IF NOT EXISTS topics_event_id ON topics(event_id);
+CREATE INDEX IF NOT EXISTS topics_room_id ON topics(room_id);
+
 CREATE TABLE IF NOT EXISTS room_names(
     event_id TEXT NOT NULL,
     room_id TEXT NOT NULL,
     name TEXT NOT NULL
 );
 
+CREATE INDEX IF NOT EXISTS room_names_event_id ON room_names(event_id);
+CREATE INDEX IF NOT EXISTS room_names_room_id ON room_names(room_id);
+
 CREATE TABLE IF NOT EXISTS rooms(
     room_id TEXT PRIMARY KEY NOT NULL,
     is_public INTEGER,
     creator TEXT
 );
 
-CREATE TABLE IF NOT EXISTS room_join_rules(
-    event_id TEXT NOT NULL,
-    room_id TEXT NOT NULL,
-    join_rule TEXT NOT NULL
-);
-CREATE INDEX IF NOT EXISTS room_join_rules_event_id ON room_join_rules(event_id);
-CREATE INDEX IF NOT EXISTS room_join_rules_room_id ON room_join_rules(room_id);
-
-
-CREATE TABLE IF NOT EXISTS room_power_levels(
-    event_id TEXT NOT NULL,
-    room_id TEXT NOT NULL,
-    user_id TEXT NOT NULL,
-    level INTEGER NOT NULL
-);
-CREATE INDEX IF NOT EXISTS room_power_levels_event_id ON room_power_levels(event_id);
-CREATE INDEX IF NOT EXISTS room_power_levels_room_id ON room_power_levels(room_id);
-CREATE INDEX IF NOT EXISTS room_power_levels_room_user ON room_power_levels(room_id, user_id);
-
-
-CREATE TABLE IF NOT EXISTS room_default_levels(
-    event_id TEXT NOT NULL,
-    room_id TEXT NOT NULL,
-    level INTEGER NOT NULL
-);
-
-CREATE INDEX IF NOT EXISTS room_default_levels_event_id ON room_default_levels(event_id);
-CREATE INDEX IF NOT EXISTS room_default_levels_room_id ON room_default_levels(room_id);
-
-
-CREATE TABLE IF NOT EXISTS room_add_state_levels(
-    event_id TEXT NOT NULL,
-    room_id TEXT NOT NULL,
-    level INTEGER NOT NULL
-);
-
-CREATE INDEX IF NOT EXISTS room_add_state_levels_event_id ON room_add_state_levels(event_id);
-CREATE INDEX IF NOT EXISTS room_add_state_levels_room_id ON room_add_state_levels(room_id);
-
-
-CREATE TABLE IF NOT EXISTS room_send_event_levels(
-    event_id TEXT NOT NULL,
-    room_id TEXT NOT NULL,
-    level INTEGER NOT NULL
-);
-
-CREATE INDEX IF NOT EXISTS room_send_event_levels_event_id ON room_send_event_levels(event_id);
-CREATE INDEX IF NOT EXISTS room_send_event_levels_room_id ON room_send_event_levels(room_id);
-
-
-CREATE TABLE IF NOT EXISTS room_ops_levels(
-    event_id TEXT NOT NULL,
-    room_id TEXT NOT NULL,
-    ban_level INTEGER,
-    kick_level INTEGER,
-    redact_level INTEGER
-);
-
-CREATE INDEX IF NOT EXISTS room_ops_levels_event_id ON room_ops_levels(event_id);
-CREATE INDEX IF NOT EXISTS room_ops_levels_room_id ON room_ops_levels(room_id);
-
-
 CREATE TABLE IF NOT EXISTS room_hosts(
     room_id TEXT NOT NULL,
     host TEXT NOT NULL,
diff --git a/synapse/storage/schema/pdu.sql b/synapse/storage/schema/pdu.sql
deleted file mode 100644
index 16e111a56c..0000000000
--- a/synapse/storage/schema/pdu.sql
+++ /dev/null
@@ -1,106 +0,0 @@
-/* Copyright 2014 OpenMarket Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
--- Stores pdus and their content
-CREATE TABLE IF NOT EXISTS pdus(
-    pdu_id TEXT, 
-    origin TEXT, 
-    context TEXT,
-    pdu_type TEXT,
-    ts INTEGER,
-    depth INTEGER DEFAULT 0 NOT NULL,
-    is_state BOOL, 
-    content_json TEXT,
-    unrecognized_keys TEXT,
-    outlier BOOL NOT NULL,
-    have_processed BOOL, 
-    CONSTRAINT pdu_id_origin UNIQUE (pdu_id, origin)
-);
-
--- Stores what the current state pdu is for a given (context, pdu_type, key) tuple
-CREATE TABLE IF NOT EXISTS state_pdus(
-    pdu_id TEXT,
-    origin TEXT,
-    context TEXT,
-    pdu_type TEXT,
-    state_key TEXT,
-    power_level TEXT,
-    prev_state_id TEXT,
-    prev_state_origin TEXT,
-    CONSTRAINT pdu_id_origin UNIQUE (pdu_id, origin)
-    CONSTRAINT prev_pdu_id_origin UNIQUE (prev_state_id, prev_state_origin)
-);
-
-CREATE TABLE IF NOT EXISTS current_state(
-    pdu_id TEXT,
-    origin TEXT,
-    context TEXT,
-    pdu_type TEXT,
-    state_key TEXT,
-    CONSTRAINT pdu_id_origin UNIQUE (pdu_id, origin)
-    CONSTRAINT uniqueness UNIQUE (context, pdu_type, state_key) ON CONFLICT REPLACE
-);
-
--- Stores where each pdu we want to send should be sent and the delivery status.
-create TABLE IF NOT EXISTS pdu_destinations(
-    pdu_id TEXT,
-    origin TEXT,
-    destination TEXT,
-    delivered_ts INTEGER DEFAULT 0, -- or 0 if not delivered
-    CONSTRAINT uniqueness UNIQUE (pdu_id, origin, destination) ON CONFLICT REPLACE
-);
-
-CREATE TABLE IF NOT EXISTS pdu_forward_extremities(
-    pdu_id TEXT,
-    origin TEXT,
-    context TEXT,
-    CONSTRAINT uniqueness UNIQUE (pdu_id, origin, context) ON CONFLICT REPLACE
-);
-
-CREATE TABLE IF NOT EXISTS pdu_backward_extremities(
-    pdu_id TEXT,
-    origin TEXT,
-    context TEXT,
-    CONSTRAINT uniqueness UNIQUE (pdu_id, origin, context) ON CONFLICT REPLACE
-);
-
-CREATE TABLE IF NOT EXISTS pdu_edges(
-    pdu_id TEXT,
-    origin TEXT,
-    prev_pdu_id TEXT,
-    prev_origin TEXT,
-    context TEXT,
-    CONSTRAINT uniqueness UNIQUE (pdu_id, origin, prev_pdu_id, prev_origin, context)
-);
-
-CREATE TABLE IF NOT EXISTS context_depth(
-    context TEXT,
-    min_depth INTEGER,
-    CONSTRAINT uniqueness UNIQUE (context)
-);
-
-CREATE INDEX IF NOT EXISTS context_depth_context ON context_depth(context);
-
-
-CREATE INDEX IF NOT EXISTS pdu_id ON pdus(pdu_id, origin);
-
-CREATE INDEX IF NOT EXISTS dests_id ON pdu_destinations (pdu_id, origin);
--- CREATE INDEX IF NOT EXISTS dests ON pdu_destinations (destination);
-
-CREATE INDEX IF NOT EXISTS pdu_extrem_context ON pdu_forward_extremities(context);
-CREATE INDEX IF NOT EXISTS pdu_extrem_id ON pdu_forward_extremities(pdu_id, origin);
-
-CREATE INDEX IF NOT EXISTS pdu_edges_id ON pdu_edges(pdu_id, origin);
-
-CREATE INDEX IF NOT EXISTS pdu_b_extrem_context ON pdu_backward_extremities(context);
diff --git a/synapse/storage/schema/state.sql b/synapse/storage/schema/state.sql
new file mode 100644
index 0000000000..b44c56b519
--- /dev/null
+++ b/synapse/storage/schema/state.sql
@@ -0,0 +1,33 @@
+/* Copyright 2014 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE IF NOT EXISTS state_groups(
+    id INTEGER PRIMARY KEY,
+    room_id TEXT NOT NULL,
+    event_id TEXT NOT NULL
+);
+
+CREATE TABLE IF NOT EXISTS state_groups_state(
+    state_group INTEGER NOT NULL,
+    room_id TEXT NOT NULL,
+    type TEXT NOT NULL,
+    state_key TEXT NOT NULL,
+    event_id TEXT NOT NULL
+);
+
+CREATE TABLE IF NOT EXISTS event_to_state_groups(
+    event_id TEXT NOT NULL,
+    state_group INTEGER NOT NULL
+);
\ No newline at end of file
diff --git a/synapse/storage/signatures.py b/synapse/storage/signatures.py
new file mode 100644
index 0000000000..84a49088a2
--- /dev/null
+++ b/synapse/storage/signatures.py
@@ -0,0 +1,177 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from _base import SQLBaseStore
+
+
+class SignatureStore(SQLBaseStore):
+    """Persistence for event signatures and hashes"""
+
+    def _get_event_content_hashes_txn(self, txn, event_id):
+        """Get all the hashes for a given Event.
+        Args:
+            txn (cursor):
+            event_id (str): Id for the Event.
+        Returns:
+            A dict of algorithm -> hash.
+        """
+        query = (
+            "SELECT algorithm, hash"
+            " FROM event_content_hashes"
+            " WHERE event_id = ?"
+        )
+        txn.execute(query, (event_id, ))
+        return dict(txn.fetchall())
+
+    def _store_event_content_hash_txn(self, txn, event_id, algorithm,
+                                    hash_bytes):
+        """Store a hash for a Event
+        Args:
+            txn (cursor):
+            event_id (str): Id for the Event.
+            algorithm (str): Hashing algorithm.
+            hash_bytes (bytes): Hash function output bytes.
+        """
+        self._simple_insert_txn(
+            txn,
+            "event_content_hashes",
+            {
+                "event_id": event_id,
+                "algorithm": algorithm,
+                "hash": buffer(hash_bytes),
+            },
+            or_ignore=True,
+        )
+
+    def get_event_reference_hashes(self, event_ids):
+        def f(txn):
+            return [
+                self._get_event_reference_hashes_txn(txn, ev)
+                for ev in event_ids
+            ]
+
+        return self.runInteraction(
+            "get_event_reference_hashes",
+            f
+        )
+
+    def _get_event_reference_hashes_txn(self, txn, event_id):
+        """Get all the hashes for a given PDU.
+        Args:
+            txn (cursor):
+            event_id (str): Id for the Event.
+        Returns:
+            A dict of algorithm -> hash.
+        """
+        query = (
+            "SELECT algorithm, hash"
+            " FROM event_reference_hashes"
+            " WHERE event_id = ?"
+        )
+        txn.execute(query, (event_id, ))
+        return dict(txn.fetchall())
+
+    def _store_event_reference_hash_txn(self, txn, event_id, algorithm,
+                                      hash_bytes):
+        """Store a hash for a PDU
+        Args:
+            txn (cursor):
+            event_id (str): Id for the Event.
+            algorithm (str): Hashing algorithm.
+            hash_bytes (bytes): Hash function output bytes.
+        """
+        self._simple_insert_txn(
+            txn,
+            "event_reference_hashes",
+            {
+                "event_id": event_id,
+                "algorithm": algorithm,
+                "hash": buffer(hash_bytes),
+            },
+            or_ignore=True,
+        )
+
+
+    def _get_event_origin_signatures_txn(self, txn, event_id):
+        """Get all the signatures for a given PDU.
+        Args:
+            txn (cursor):
+            event_id (str): Id for the Event.
+        Returns:
+            A dict of key_id -> signature_bytes.
+        """
+        query = (
+            "SELECT key_id, signature"
+            " FROM event_origin_signatures"
+            " WHERE event_id = ? "
+        )
+        txn.execute(query, (event_id, ))
+        return dict(txn.fetchall())
+
+    def _store_event_origin_signature_txn(self, txn, event_id, origin, key_id,
+                                          signature_bytes):
+        """Store a signature from the origin server for a PDU.
+        Args:
+            txn (cursor):
+            event_id (str): Id for the Event.
+            origin (str): origin of the Event.
+            key_id (str): Id for the signing key.
+            signature (bytes): The signature.
+        """
+        self._simple_insert_txn(
+            txn,
+            "event_origin_signatures",
+            {
+                "event_id": event_id,
+                "origin": origin,
+                "key_id": key_id,
+                "signature": buffer(signature_bytes),
+            },
+            or_ignore=True,
+        )
+
+    def _get_prev_event_hashes_txn(self, txn, event_id):
+        """Get all the hashes for previous PDUs of a PDU
+        Args:
+            txn (cursor):
+            event_id (str): Id for the Event.
+        Returns:
+            dict of (pdu_id, origin) -> dict of algorithm -> hash_bytes.
+        """
+        query = (
+            "SELECT prev_event_id, algorithm, hash"
+            " FROM event_edge_hashes"
+            " WHERE event_id = ?"
+        )
+        txn.execute(query, (event_id, ))
+        results = {}
+        for prev_event_id, algorithm, hash_bytes in txn.fetchall():
+            hashes = results.setdefault(prev_event_id, {})
+            hashes[algorithm] = hash_bytes
+        return results
+
+    def _store_prev_event_hash_txn(self, txn, event_id, prev_event_id,
+                                 algorithm, hash_bytes):
+        self._simple_insert_txn(
+            txn,
+            "event_edge_hashes",
+            {
+                "event_id": event_id,
+                "prev_event_id": prev_event_id,
+                "algorithm": algorithm,
+                "hash": buffer(hash_bytes),
+            },
+            or_ignore=True,
+        )
\ No newline at end of file
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
new file mode 100644
index 0000000000..68975969f5
--- /dev/null
+++ b/synapse/storage/state.py
@@ -0,0 +1,96 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import SQLBaseStore
+from twisted.internet import defer
+
+
+class StateStore(SQLBaseStore):
+
+    @defer.inlineCallbacks
+    def get_state_groups(self, event_ids):
+        groups = set()
+        for event_id in event_ids:
+            group = yield self._simple_select_one_onecol(
+                table="event_to_state_groups",
+                keyvalues={"event_id": event_id},
+                retcol="state_group",
+                allow_none=True,
+            )
+            if group:
+                groups.add(group)
+
+        res = {}
+        for group in groups:
+            state_ids = yield self._simple_select_onecol(
+                table="state_groups_state",
+                keyvalues={"state_group": group},
+                retcol="event_id",
+            )
+            state = []
+            for state_id in state_ids:
+                s = yield self.get_event(
+                    state_id,
+                    allow_none=True,
+                )
+                if s:
+                    state.append(s)
+
+            res[group] = state
+
+        defer.returnValue(res)
+
+    def store_state_groups(self, event):
+        return self.runInteraction(
+            "store_state_groups",
+            self._store_state_groups_txn, event
+        )
+
+    def _store_state_groups_txn(self, txn, event):
+        if not event.state_events:
+            return
+
+        state_group = event.state_group
+        if not state_group:
+            state_group = self._simple_insert_txn(
+                txn,
+                table="state_groups",
+                values={
+                    "room_id": event.room_id,
+                    "event_id": event.event_id,
+                }
+            )
+
+            for state in event.state_events.values():
+                self._simple_insert_txn(
+                    txn,
+                    table="state_groups_state",
+                    values={
+                        "state_group": state_group,
+                        "room_id": state.room_id,
+                        "type": state.type,
+                        "state_key": state.state_key,
+                        "event_id": state.event_id,
+                    }
+                )
+
+        self._simple_insert_txn(
+            txn,
+            table="event_to_state_groups",
+            values={
+                "state_group": state_group,
+                "event_id": event.event_id,
+            }
+        )
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index d61f909939..475e7f20a1 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -177,10 +177,9 @@ class StreamStore(SQLBaseStore):
 
         sql = (
             "SELECT *, (%(redacted)s) AS redacted FROM events AS e WHERE "
-            "((room_id IN (%(current)s)) OR "
+            "(e.outlier = 0 AND (room_id IN (%(current)s)) OR "
             "(event_id IN (%(invites)s))) "
             "AND e.stream_ordering > ? AND e.stream_ordering <= ? "
-            "AND e.outlier = 0 "
             "ORDER BY stream_ordering ASC LIMIT %(limit)d "
         ) % {
             "redacted": del_sql,
@@ -309,7 +308,10 @@ class StreamStore(SQLBaseStore):
         defer.returnValue(ret)
 
     def get_room_events_max_id(self):
-        return self.runInteraction(self._get_room_events_max_id_txn)
+        return self.runInteraction(
+            "get_room_events_max_id",
+            self._get_room_events_max_id_txn
+        )
 
     def _get_room_events_max_id_txn(self, txn):
         txn.execute(
diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py
index 2ba8e30efe..00d0f48082 100644
--- a/synapse/storage/transactions.py
+++ b/synapse/storage/transactions.py
@@ -14,7 +14,6 @@
 # limitations under the License.
 
 from ._base import SQLBaseStore, Table
-from .pdu import PdusTable
 
 from collections import namedtuple
 
@@ -42,6 +41,7 @@ class TransactionStore(SQLBaseStore):
         """
 
         return self.runInteraction(
+            "get_received_txn_response",
             self._get_received_txn_response, transaction_id, origin
         )
 
@@ -73,6 +73,7 @@ class TransactionStore(SQLBaseStore):
         """
 
         return self.runInteraction(
+            "set_received_txn_response",
             self._set_received_txn_response,
             transaction_id, origin, code, response_dict
         )
@@ -88,7 +89,7 @@ class TransactionStore(SQLBaseStore):
         txn.execute(query, (code, response_json, transaction_id, origin))
 
     def prep_send_transaction(self, transaction_id, destination,
-                              origin_server_ts, pdu_list):
+                              origin_server_ts):
         """Persists an outgoing transaction and calculates the values for the
         previous transaction id list.
 
@@ -99,19 +100,19 @@ class TransactionStore(SQLBaseStore):
             transaction_id (str)
             destination (str)
             origin_server_ts (int)
-            pdu_list (list)
 
         Returns:
             list: A list of previous transaction ids.
         """
 
         return self.runInteraction(
+            "prep_send_transaction",
             self._prep_send_transaction,
-            transaction_id, destination, origin_server_ts, pdu_list
+            transaction_id, destination, origin_server_ts
         )
 
     def _prep_send_transaction(self, txn, transaction_id, destination,
-                               origin_server_ts, pdu_list):
+                               origin_server_ts):
 
         # First we find out what the prev_txs should be.
         # Since we know that we are only sending one transaction at a time,
@@ -139,15 +140,15 @@ class TransactionStore(SQLBaseStore):
 
         # Update the tx id -> pdu id mapping
 
-        values = [
-            (transaction_id, destination, pdu[0], pdu[1])
-            for pdu in pdu_list
-        ]
-
-        logger.debug("Inserting: %s", repr(values))
-
-        query = TransactionsToPduTable.insert_statement()
-        txn.executemany(query, values)
+        # values = [
+        #     (transaction_id, destination, pdu[0], pdu[1])
+        #     for pdu in pdu_list
+        # ]
+        #
+        # logger.debug("Inserting: %s", repr(values))
+        #
+        # query = TransactionsToPduTable.insert_statement()
+        # txn.executemany(query, values)
 
         return prev_txns
 
@@ -161,6 +162,7 @@ class TransactionStore(SQLBaseStore):
             response_json (str)
         """
         return self.runInteraction(
+            "delivered_txn",
             self._delivered_txn,
             transaction_id, destination, code, response_dict
         )
@@ -186,6 +188,7 @@ class TransactionStore(SQLBaseStore):
             list: A list of `ReceivedTransactionsTable.EntryType`
         """
         return self.runInteraction(
+            "get_transactions_after",
             self._get_transactions_after, transaction_id, destination
         )
 
@@ -202,49 +205,6 @@ class TransactionStore(SQLBaseStore):
 
         return ReceivedTransactionsTable.decode_results(txn.fetchall())
 
-    def get_pdus_after_transaction(self, transaction_id, destination):
-        """For a given local transaction_id that we sent to a given destination
-        home server, return a list of PDUs that were sent to that destination
-        after it.
-
-        Args:
-            txn
-            transaction_id (str)
-            destination (str)
-
-        Returns
-            list: A list of PduTuple
-        """
-        return self.runInteraction(
-            self._get_pdus_after_transaction,
-            transaction_id, destination
-        )
-
-    def _get_pdus_after_transaction(self, txn, transaction_id, destination):
-
-        # Query that first get's all transaction_ids with an id greater than
-        # the one given from the `sent_transactions` table. Then JOIN on this
-        # from the `tx->pdu` table to get a list of (pdu_id, origin) that
-        # specify the pdus that were sent in those transactions.
-        query = (
-            "SELECT pdu_id, pdu_origin FROM %(tx_pdu)s as tp "
-            "INNER JOIN %(sent_tx)s as st "
-            "ON tp.transaction_id = st.transaction_id "
-            "AND tp.destination = st.destination "
-            "WHERE st.id > ("
-            "SELECT id FROM %(sent_tx)s "
-            "WHERE transaction_id = ? AND destination = ?"
-        ) % {
-            "tx_pdu": TransactionsToPduTable.table_name,
-            "sent_tx": SentTransactions.table_name,
-        }
-
-        txn.execute(query, (transaction_id, destination))
-
-        pdus = PdusTable.decode_results(txn.fetchall())
-
-        return self._get_pdu_tuples(txn, pdus)
-
 
 class ReceivedTransactionsTable(Table):
     table_name = "received_transactions"
diff --git a/synapse/types.py b/synapse/types.py
index c51bc8e4f2..649ff2f7d7 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -78,6 +78,11 @@ class DomainSpecificString(
         """Create a structure on the local domain"""
         return cls(localpart=localpart, domain=hs.hostname, is_mine=True)
 
+    @classmethod
+    def create(cls, localpart, domain, hs):
+        is_mine = domain == hs.hostname
+        return cls(localpart=localpart, domain=domain, is_mine=is_mine)
+
 
 class UserID(DomainSpecificString):
     """Structure representing a user ID."""
@@ -94,6 +99,11 @@ class RoomID(DomainSpecificString):
     SIGIL = "!"
 
 
+class EventID(DomainSpecificString):
+    """Structure representing an event id. """
+    SIGIL = "$"
+
+
 class StreamToken(
     namedtuple(
         "Token",
diff --git a/synapse/util/async.py b/synapse/util/async.py
index 647ea6142c..bf578f8bfb 100644
--- a/synapse/util/async.py
+++ b/synapse/util/async.py
@@ -21,3 +21,10 @@ def sleep(seconds):
     d = defer.Deferred()
     reactor.callLater(seconds, d.callback, seconds)
     return d
+
+
+def run_on_reactor():
+    """ This will cause the rest of the function to be invoked upon the next
+    iteration of the main loop
+    """
+    return sleep(0)
\ No newline at end of file
diff --git a/synapse/util/jsonobject.py b/synapse/util/jsonobject.py
index c91eb897a8..e79b68f661 100644
--- a/synapse/util/jsonobject.py
+++ b/synapse/util/jsonobject.py
@@ -80,7 +80,7 @@ class JsonEncodedObject(object):
 
     def get_full_dict(self):
         d = {
-            k: v for (k, v) in self.__dict__.items()
+            k: _encode(v) for (k, v) in self.__dict__.items()
             if k in self.valid_keys or k in self.internal_keys
         }
         d.update(self.unrecognized_keys)
diff --git a/tests/events/test_events.py b/tests/events/test_events.py
index a4b6cb3afd..91d1d44fee 100644
--- a/tests/events/test_events.py
+++ b/tests/events/test_events.py
@@ -14,6 +14,8 @@
 # limitations under the License.
 
 from synapse.api.events import SynapseEvent
+from synapse.api.events.validator import EventValidator
+from synapse.api.errors import SynapseError
 
 from tests import unittest
 
@@ -21,7 +23,7 @@ from tests import unittest
 class SynapseTemplateCheckTestCase(unittest.TestCase):
 
     def setUp(self):
-        pass
+        self.validator = EventValidator(None)
 
     def tearDown(self):
         pass
@@ -38,22 +40,28 @@ class SynapseTemplateCheckTestCase(unittest.TestCase):
         }
 
         event = MockSynapseEvent(template)
-        self.assertTrue(event.check_json(content, raises=False))
+        event.content = content
+        self.assertTrue(self.validator.validate(event))
 
         content = {
             "person": {"name": "bob"},
             "friends": ["jill"],
             "enemies": ["mike"]
         }
-        event = MockSynapseEvent(template)
-        self.assertTrue(event.check_json(content, raises=False))
+        event.content = content
+        self.assertTrue(self.validator.validate(event))
 
         content = {
             "person": {"name": "bob"},
             # missing friends
             "enemies": ["mike", "jill"]
         }
-        self.assertFalse(event.check_json(content, raises=False))
+        event.content = content
+        self.assertRaises(
+            SynapseError,
+            self.validator.validate,
+            event
+        )
 
     def test_lists(self):
         template = {
@@ -67,13 +75,19 @@ class SynapseTemplateCheckTestCase(unittest.TestCase):
         }
 
         event = MockSynapseEvent(template)
-        self.assertFalse(event.check_json(content, raises=False))
+        event.content = content
+        self.assertRaises(
+            SynapseError,
+            self.validator.validate,
+            event
+        )
 
         content = {
             "person": {"name": "bob"},
             "friends": [{"name": "jill"}, {"name": "mike"}]
         }
-        self.assertTrue(event.check_json(content, raises=False))
+        event.content = content
+        self.assertTrue(self.validator.validate(event))
 
     def test_nested_lists(self):
         template = {
@@ -103,7 +117,12 @@ class SynapseTemplateCheckTestCase(unittest.TestCase):
         }
 
         event = MockSynapseEvent(template)
-        self.assertFalse(event.check_json(content, raises=False))
+        event.content = content
+        self.assertRaises(
+            SynapseError,
+            self.validator.validate,
+            event
+        )
 
         content = {
             "results": {
@@ -117,7 +136,8 @@ class SynapseTemplateCheckTestCase(unittest.TestCase):
                 ]
             }
         }
-        self.assertTrue(event.check_json(content, raises=False))
+        event.content = content
+        self.assertTrue(self.validator.validate(event))
 
     def test_nested_keys(self):
         template = {
@@ -145,7 +165,8 @@ class SynapseTemplateCheckTestCase(unittest.TestCase):
             }
         }
 
-        self.assertTrue(event.check_json(content, raises=False))
+        event.content = content
+        self.assertTrue(self.validator.validate(event))
 
         content = {
             "person": {
@@ -159,7 +180,12 @@ class SynapseTemplateCheckTestCase(unittest.TestCase):
             }
         }
 
-        self.assertFalse(event.check_json(content, raises=False))
+        event.content = content
+        self.assertRaises(
+            SynapseError,
+            self.validator.validate,
+            event
+        )
 
         content = {
             "person": {
@@ -173,7 +199,12 @@ class SynapseTemplateCheckTestCase(unittest.TestCase):
             }
         }
 
-        self.assertFalse(event.check_json(content, raises=False))
+        event.content = content
+        self.assertRaises(
+            SynapseError,
+            self.validator.validate,
+            event
+        )
 
 
 class MockSynapseEvent(SynapseEvent):
diff --git a/tests/federation/test_federation.py b/tests/federation/test_federation.py
index 933aa61c77..eb329eec50 100644
--- a/tests/federation/test_federation.py
+++ b/tests/federation/test_federation.py
@@ -24,7 +24,6 @@ from ..utils import MockHttpResource, MockClock, MockKey
 from synapse.server import HomeServer
 from synapse.federation import initialize_http_replication
 from synapse.federation.units import Pdu
-from synapse.storage.pdu import PduTuple, PduEntry
 
 
 def make_pdu(prev_pdus=[], **kwargs):
@@ -41,7 +40,7 @@ def make_pdu(prev_pdus=[], **kwargs):
     }
     pdu_fields.update(kwargs)
 
-    return PduTuple(PduEntry(**pdu_fields), prev_pdus)
+    return Pdu(prev_pdus=prev_pdus, **pdu_fields)
 
 
 class FederationTestCase(unittest.TestCase):
@@ -52,177 +51,185 @@ class FederationTestCase(unittest.TestCase):
             "put_json",
         ])
         self.mock_persistence = Mock(spec=[
-            "get_current_state_for_context",
-            "get_pdu",
-            "persist_event",
-            "update_min_depth_for_context",
             "prep_send_transaction",
             "delivered_txn",
             "get_received_txn_response",
             "set_received_txn_response",
         ])
         self.mock_persistence.get_received_txn_response.return_value = (
-                defer.succeed(None)
+            defer.succeed(None)
         )
         self.mock_config = Mock()
         self.mock_config.signing_key = [MockKey()]
         self.clock = MockClock()
-        hs = HomeServer("test",
-                resource_for_federation=self.mock_resource,
-                http_client=self.mock_http_client,
-                db_pool=None,
-                datastore=self.mock_persistence,
-                clock=self.clock,
-                config=self.mock_config,
-                keyring=Mock(),
+        hs = HomeServer(
+            "test",
+            resource_for_federation=self.mock_resource,
+            http_client=self.mock_http_client,
+            db_pool=None,
+            datastore=self.mock_persistence,
+            clock=self.clock,
+            config=self.mock_config,
+            keyring=Mock(),
         )
         self.federation = initialize_http_replication(hs)
         self.distributor = hs.get_distributor()
 
     @defer.inlineCallbacks
     def test_get_state(self):
-        self.mock_persistence.get_current_state_for_context.return_value = (
-            defer.succeed([])
-        )
+        mock_handler = Mock(spec=[
+            "get_state_for_pdu",
+        ])
+
+        self.federation.set_handler(mock_handler)
+
+        mock_handler.get_state_for_pdu.return_value = defer.succeed([])
 
         # Empty context initially
-        (code, response) = yield self.mock_resource.trigger("GET",
-                "/_matrix/federation/v1/state/my-context/", None)
+        (code, response) = yield self.mock_resource.trigger(
+            "GET",
+            "/_matrix/federation/v1/state/my-context/",
+            None
+        )
         self.assertEquals(200, code)
         self.assertFalse(response["pdus"])
 
         # Now lets give the context some state
-        self.mock_persistence.get_current_state_for_context.return_value = (
+        mock_handler.get_state_for_pdu.return_value = (
             defer.succeed([
                 make_pdu(
-                    pdu_id="the-pdu-id",
+                    event_id="the-pdu-id",
                     origin="red",
-                    context="my-context",
-                    pdu_type="m.topic",
-                    ts=123456789000,
+                    room_id="my-context",
+                    type="m.topic",
+                    origin_server_ts=123456789000,
                     depth=1,
-                    is_state=True,
-                    content_json='{"topic":"The topic"}',
+                    content={"topic": "The topic"},
                     state_key="",
                     power_level=1000,
-                    prev_state_id="last-pdu-id",
-                    prev_state_origin="blue",
+                    prev_state="last-pdu-id",
                 ),
             ])
         )
 
-        (code, response) = yield self.mock_resource.trigger("GET",
-                "/_matrix/federation/v1/state/my-context/", None)
+        (code, response) = yield self.mock_resource.trigger(
+            "GET",
+            "/_matrix/federation/v1/state/my-context/",
+            None
+        )
         self.assertEquals(200, code)
         self.assertEquals(1, len(response["pdus"]))
 
     @defer.inlineCallbacks
     def test_get_pdu(self):
-        self.mock_persistence.get_pdu.return_value = (
+        mock_handler = Mock(spec=[
+            "get_persisted_pdu",
+        ])
+
+        self.federation.set_handler(mock_handler)
+
+        mock_handler.get_persisted_pdu.return_value = (
             defer.succeed(None)
         )
 
-        (code, response) = yield self.mock_resource.trigger("GET",
-                "/_matrix/federation/v1/pdu/red/abc123def456/", None)
+        (code, response) = yield self.mock_resource.trigger(
+            "GET",
+            "/_matrix/federation/v1/event/abc123def456/",
+            None
+        )
         self.assertEquals(404, code)
 
         # Now insert such a PDU
-        self.mock_persistence.get_pdu.return_value = (
+        mock_handler.get_persisted_pdu.return_value = (
             defer.succeed(
                 make_pdu(
-                    pdu_id="abc123def456",
+                    event_id="abc123def456",
                     origin="red",
-                    context="my-context",
-                    pdu_type="m.text",
-                    ts=123456789001,
+                    room_id="my-context",
+                    type="m.text",
+                    origin_server_ts=123456789001,
                     depth=1,
-                    content_json='{"text":"Here is the message"}',
+                    content={"text": "Here is the message"},
                 )
             )
         )
 
-        (code, response) = yield self.mock_resource.trigger("GET",
-                "/_matrix/federation/v1/pdu/red/abc123def456/", None)
+        (code, response) = yield self.mock_resource.trigger(
+            "GET",
+            "/_matrix/federation/v1/event/abc123def456/",
+            None
+        )
         self.assertEquals(200, code)
         self.assertEquals(1, len(response["pdus"]))
-        self.assertEquals("m.text", response["pdus"][0]["pdu_type"])
+        self.assertEquals("m.text", response["pdus"][0]["type"])
 
     @defer.inlineCallbacks
     def test_send_pdu(self):
         self.mock_http_client.put_json.return_value = defer.succeed(
-                (200, "OK")
+            (200, "OK")
         )
 
         pdu = Pdu(
-                pdu_id="abc123def456",
-                origin="red",
-                destinations=["remote"],
-                context="my-context",
-                origin_server_ts=123456789002,
-                pdu_type="m.test",
-                content={"testing": "content here"},
-                depth=1,
+            event_id="abc123def456",
+            origin="red",
+            room_id="my-context",
+            type="m.text",
+            origin_server_ts=123456789001,
+            depth=1,
+            content={"text": "Here is the message"},
+            destinations=["remote"],
         )
 
         yield self.federation.send_pdu(pdu)
 
         self.mock_http_client.put_json.assert_called_with(
-                "remote",
-                path="/_matrix/federation/v1/send/1000000/",
-                data={
-                    "origin_server_ts": 1000000,
-                    "origin": "test",
-                    "pdus": [
-                        {
-                            "origin": "red",
-                            "pdu_id": "abc123def456",
-                            "prev_pdus": [],
-                            "origin_server_ts": 123456789002,
-                            "context": "my-context",
-                            "pdu_type": "m.test",
-                            "is_state": False,
-                            "content": {"testing": "content here"},
-                            "depth": 1,
-                        },
-                    ]
-                },
-                json_data_callback=ANY,
+            "remote",
+            path="/_matrix/federation/v1/send/1000000/",
+            data={
+                "origin_server_ts": 1000000,
+                "origin": "test",
+                "pdus": [
+                    pdu.get_dict(),
+                ],
+                'pdu_failures': [],
+            },
+            json_data_callback=ANY,
         )
 
     @defer.inlineCallbacks
     def test_send_edu(self):
         self.mock_http_client.put_json.return_value = defer.succeed(
-                (200, "OK")
+            (200, "OK")
         )
 
         yield self.federation.send_edu(
-                destination="remote",
-                edu_type="m.test",
-                content={"testing": "content here"},
+            destination="remote",
+            edu_type="m.test",
+            content={"testing": "content here"},
         )
 
         # MockClock ensures we can guess these timestamps
         self.mock_http_client.put_json.assert_called_with(
-                "remote",
-                path="/_matrix/federation/v1/send/1000000/",
-                data={
-                    "origin": "test",
-                    "origin_server_ts": 1000000,
-                    "pdus": [],
-                    "edus": [
-                        {
-                            # TODO: SYN-103: Remove "origin" and "destination"
-                            "origin": "test",
-                            "destination": "remote",
-                            "edu_type": "m.test",
-                            "content": {"testing": "content here"},
-                        }
-                    ],
-                },
-                json_data_callback=ANY,
+            "remote",
+            path="/_matrix/federation/v1/send/1000000/",
+            data={
+                "origin": "test",
+                "origin_server_ts": 1000000,
+                "pdus": [],
+                "edus": [
+                    {
+                        # TODO: SYN-103: Remove "origin" and "destination"
+                        "origin": "test",
+                        "destination": "remote",
+                        "edu_type": "m.test",
+                        "content": {"testing": "content here"},
+                    }
+                ],
+                'pdu_failures': [],
+            },
+            json_data_callback=ANY,
         )
 
-
     @defer.inlineCallbacks
     def test_recv_edu(self):
         recv_observer = Mock()
@@ -230,24 +237,26 @@ class FederationTestCase(unittest.TestCase):
 
         self.federation.register_edu_handler("m.test", recv_observer)
 
-        yield self.mock_resource.trigger("PUT",
-                "/_matrix/federation/v1/send/1001000/",
-                """{
-                    "origin": "remote",
-                    "origin_server_ts": 1001000,
-                    "pdus": [],
-                    "edus": [
-                        {
-                            "origin": "remote",
-                            "destination": "test",
-                            "edu_type": "m.test",
-                            "content": {"testing": "reply here"}
-                        }
-                    ]
-                }""")
+        yield self.mock_resource.trigger(
+            "PUT",
+            "/_matrix/federation/v1/send/1001000/",
+            """{
+                "origin": "remote",
+                "origin_server_ts": 1001000,
+                "pdus": [],
+                "edus": [
+                    {
+                        "origin": "remote",
+                        "destination": "test",
+                        "edu_type": "m.test",
+                        "content": {"testing": "reply here"}
+                    }
+                ]
+            }"""
+        )
 
         recv_observer.assert_called_with(
-                "remote", {"testing": "reply here"}
+            "remote", {"testing": "reply here"}
         )
 
     @defer.inlineCallbacks
@@ -278,8 +287,11 @@ class FederationTestCase(unittest.TestCase):
 
         self.federation.register_query_handler("a-question", recv_handler)
 
-        code, response = yield self.mock_resource.trigger("GET",
-            "/_matrix/federation/v1/query/a-question?three=3&four=4", None)
+        code, response = yield self.mock_resource.trigger(
+            "GET",
+            "/_matrix/federation/v1/query/a-question?three=3&four=4",
+            None
+        )
 
         self.assertEquals(200, code)
         self.assertEquals({"another": "response"}, response)
diff --git a/tests/federation/test_pdu_codec.py b/tests/federation/test_pdu_codec.py
deleted file mode 100644
index 0754ef92e8..0000000000
--- a/tests/federation/test_pdu_codec.py
+++ /dev/null
@@ -1,160 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2014 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from tests import unittest
-
-from synapse.federation.pdu_codec import (
-    PduCodec, encode_event_id, decode_event_id
-)
-from synapse.federation.units import Pdu
-#from synapse.api.events.room import MessageEvent
-
-from synapse.server import HomeServer
-
-from mock import Mock
-
-
-class PduCodecTestCase(unittest.TestCase):
-    def setUp(self):
-        self.hs = HomeServer("blargle.net")
-        self.event_factory = self.hs.get_event_factory()
-
-        self.codec = PduCodec(self.hs)
-
-    def test_decode_event_id(self):
-        self.assertEquals(
-            ("foo", "bar.com"),
-            decode_event_id("foo@bar.com", "A")
-        )
-
-        self.assertEquals(
-            ("foo", "bar.com"),
-            decode_event_id("foo", "bar.com")
-        )
-
-    def test_encode_event_id(self):
-        self.assertEquals("A@B", encode_event_id("A", "B"))
-
-    def test_codec_event_id(self):
-        event_id = "aa@bb.com"
-
-        self.assertEquals(
-            event_id,
-            encode_event_id(*decode_event_id(event_id, None))
-        )
-
-        pdu_id = ("aa", "bb.com")
-
-        self.assertEquals(
-            pdu_id,
-            decode_event_id(encode_event_id(*pdu_id), None)
-        )
-
-    def test_event_from_pdu(self):
-        pdu = Pdu(
-            pdu_id="foo",
-            context="rooooom",
-            pdu_type="m.room.message",
-            origin="bar.com",
-            origin_server_ts=12345,
-            depth=5,
-            prev_pdus=[("alice", "bob.com")],
-            is_state=False,
-            content={"msgtype": u"test"},
-        )
-
-        event = self.codec.event_from_pdu(pdu)
-
-        self.assertEquals("foo@bar.com", event.event_id)
-        self.assertEquals(pdu.context, event.room_id)
-        self.assertEquals(pdu.is_state, event.is_state)
-        self.assertEquals(pdu.depth, event.depth)
-        self.assertEquals(["alice@bob.com"], event.prev_events)
-        self.assertEquals(pdu.content, event.content)
-
-    def test_pdu_from_event(self):
-        event = self.event_factory.create_event(
-            etype="m.room.message",
-            event_id="gargh_id",
-            room_id="rooom",
-            user_id="sender",
-            content={"msgtype": u"test"},
-        )
-
-        pdu = self.codec.pdu_from_event(event)
-
-        self.assertEquals(event.event_id, pdu.pdu_id)
-        self.assertEquals(self.hs.hostname, pdu.origin)
-        self.assertEquals(event.room_id, pdu.context)
-        self.assertEquals(event.content, pdu.content)
-        self.assertEquals(event.type, pdu.pdu_type)
-
-        event = self.event_factory.create_event(
-            etype="m.room.message",
-            event_id="gargh_id@bob.com",
-            room_id="rooom",
-            user_id="sender",
-            content={"msgtype": u"test"},
-        )
-
-        pdu = self.codec.pdu_from_event(event)
-
-        self.assertEquals("gargh_id", pdu.pdu_id)
-        self.assertEquals("bob.com", pdu.origin)
-        self.assertEquals(event.room_id, pdu.context)
-        self.assertEquals(event.content, pdu.content)
-        self.assertEquals(event.type, pdu.pdu_type)
-
-    def test_event_from_state_pdu(self):
-        pdu = Pdu(
-            pdu_id="foo",
-            context="rooooom",
-            pdu_type="m.room.topic",
-            origin="bar.com",
-            origin_server_ts=12345,
-            depth=5,
-            prev_pdus=[("alice", "bob.com")],
-            is_state=True,
-            content={"topic": u"test"},
-            state_key="",
-        )
-
-        event = self.codec.event_from_pdu(pdu)
-
-        self.assertEquals("foo@bar.com", event.event_id)
-        self.assertEquals(pdu.context, event.room_id)
-        self.assertEquals(pdu.is_state, event.is_state)
-        self.assertEquals(pdu.depth, event.depth)
-        self.assertEquals(["alice@bob.com"], event.prev_events)
-        self.assertEquals(pdu.content, event.content)
-        self.assertEquals(pdu.state_key, event.state_key)
-
-    def test_pdu_from_state_event(self):
-        event = self.event_factory.create_event(
-            etype="m.room.topic",
-            event_id="gargh_id",
-            room_id="rooom",
-            user_id="sender",
-            content={"topic": u"test"},
-        )
-
-        pdu = self.codec.pdu_from_event(event)
-
-        self.assertEquals(event.event_id, pdu.pdu_id)
-        self.assertEquals(self.hs.hostname, pdu.origin)
-        self.assertEquals(event.room_id, pdu.context)
-        self.assertEquals(event.content, pdu.content)
-        self.assertEquals(event.type, pdu.pdu_type)
-        self.assertEquals(event.state_key, pdu.state_key)
diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index e10a49a8ac..8e164e4be0 100644
--- a/tests/handlers/test_directory.py
+++ b/tests/handlers/test_directory.py
@@ -21,9 +21,8 @@ from mock import Mock
 
 from synapse.server import HomeServer
 from synapse.handlers.directory import DirectoryHandler
-from synapse.storage.directory import RoomAliasMapping
 
-from tests.utils import SQLiteMemoryDbPool
+from tests.utils import SQLiteMemoryDbPool, MockKey
 
 
 class DirectoryHandlers(object):
@@ -41,6 +40,7 @@ class DirectoryTestCase(unittest.TestCase):
         ])
 
         self.query_handlers = {}
+
         def register_query_handler(query_type, handler):
             self.query_handlers[query_type] = handler
         self.mock_federation.register_query_handler = register_query_handler
@@ -48,11 +48,16 @@ class DirectoryTestCase(unittest.TestCase):
         db_pool = SQLiteMemoryDbPool()
         yield db_pool.prepare()
 
-        hs = HomeServer("test",
+        self.mock_config = Mock()
+        self.mock_config.signing_key = [MockKey()]
+
+        hs = HomeServer(
+            "test",
             db_pool=db_pool,
             http_client=None,
             resource_for_federation=Mock(),
             replication_layer=self.mock_federation,
+            config=self.mock_config,
         )
         hs.handlers = DirectoryHandlers(hs)
 
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index 219b2c4c5e..a9d6b2bb17 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -17,16 +17,15 @@ from twisted.internet import defer
 from tests import unittest
 
 from synapse.api.events.room import (
-    InviteJoinEvent, MessageEvent, RoomMemberEvent
+    MessageEvent,
 )
-from synapse.api.constants import Membership
 from synapse.handlers.federation import FederationHandler
 from synapse.server import HomeServer
 from synapse.federation.units import Pdu
 
 from mock import NonCallableMock, ANY
 
-from ..utils import get_mock_call_args, MockKey
+from ..utils import MockKey
 
 
 class FederationTestCase(unittest.TestCase):
@@ -36,6 +35,14 @@ class FederationTestCase(unittest.TestCase):
         self.mock_config = NonCallableMock()
         self.mock_config.signing_key = [MockKey()]
 
+        self.state_handler = NonCallableMock(spec_set=[
+            "annotate_state_groups",
+        ])
+
+        self.auth = NonCallableMock(spec_set=[
+            "check",
+        ])
+
         self.hostname = "test"
         hs = HomeServer(
             self.hostname,
@@ -53,6 +60,8 @@ class FederationTestCase(unittest.TestCase):
                 "federation_handler",
             ]),
             config=self.mock_config,
+            auth=self.auth,
+            state_handler=self.state_handler,
         )
 
         self.datastore = hs.get_datastore()
@@ -65,74 +74,35 @@ class FederationTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def test_msg(self):
         pdu = Pdu(
-            pdu_type=MessageEvent.TYPE,
-            context="foo",
+            type=MessageEvent.TYPE,
+            room_id="foo",
             content={"msgtype": u"fooo"},
             origin_server_ts=0,
-            pdu_id="a",
+            event_id="$a:b",
             origin="b",
         )
 
-        store_id = "ASD"
-        self.datastore.persist_event.return_value = defer.succeed(store_id)
+        self.datastore.persist_event.return_value = defer.succeed(None)
         self.datastore.get_room.return_value = defer.succeed(True)
 
+        self.state_handler.annotate_state_groups.return_value = (
+            defer.succeed(False)
+        )
+
         yield self.handlers.federation_handler.on_receive_pdu(pdu, False)
 
         self.datastore.persist_event.assert_called_once_with(
             ANY, False, is_new_state=False
         )
-        self.notifier.on_new_room_event.assert_called_once_with(ANY, extra_users=[])
-
-    @defer.inlineCallbacks
-    def test_invite_join_target_this(self):
-        room_id = "foo"
-        user_id = "@bob:red"
 
-        pdu = Pdu(
-            pdu_type=InviteJoinEvent.TYPE,
-            user_id=user_id,
-            target_host=self.hostname,
-            context=room_id,
-            content={},
-            origin_server_ts=0,
-            pdu_id="a",
-            origin="b",
+        self.state_handler.annotate_state_groups.assert_called_once_with(
+            ANY,
+            old_state=None,
         )
 
-        yield self.handlers.federation_handler.on_receive_pdu(pdu, False)
+        self.auth.check.assert_called_once_with(ANY, raises=True)
 
-        mem_handler = self.handlers.room_member_handler
-        self.assertEquals(1, mem_handler.change_membership.call_count)
-        call_args = get_mock_call_args(
-            lambda event, do_auth: None,
-            mem_handler.change_membership
+        self.notifier.on_new_room_event.assert_called_once_with(
+            ANY,
+            extra_users=[]
         )
-        self.assertEquals(False, call_args["do_auth"])
-
-        new_event = call_args["event"]
-        self.assertEquals(RoomMemberEvent.TYPE, new_event.type)
-        self.assertEquals(room_id, new_event.room_id)
-        self.assertEquals(user_id, new_event.state_key)
-        self.assertEquals(Membership.JOIN, new_event.membership)
-
-    @defer.inlineCallbacks
-    def test_invite_join_target_other(self):
-        room_id = "foo"
-        user_id = "@bob:red"
-
-        pdu = Pdu(
-            pdu_type=InviteJoinEvent.TYPE,
-            user_id=user_id,
-            state_key="@red:not%s" % self.hostname,
-            context=room_id,
-            content={},
-            origin_server_ts=0,
-            pdu_id="a",
-            origin="b",
-        )
-
-        yield self.handlers.federation_handler.on_receive_pdu(pdu, False)
-
-        mem_handler = self.handlers.room_member_handler
-        self.assertEquals(0, mem_handler.change_membership.call_count)
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index fdc2e8de4a..a6af648def 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -51,6 +51,7 @@ def _expect_edu(destination, edu_type, content, origin="test"):
                 "content": content,
             }
         ],
+        "pdu_failures": [],
     }
 
 def _make_edu_json(origin, edu_type, content):
diff --git a/tests/handlers/test_presencelike.py b/tests/handlers/test_presencelike.py
index 047752ad68..532ecf0f2c 100644
--- a/tests/handlers/test_presencelike.py
+++ b/tests/handlers/test_presencelike.py
@@ -21,7 +21,7 @@ from twisted.internet import defer
 
 from mock import Mock, call, ANY
 
-from ..utils import MockClock
+from ..utils import MockClock, MockKey
 
 from synapse.server import HomeServer
 from synapse.api.constants import PresenceState
@@ -57,6 +57,9 @@ class PresenceAndProfileHandlers(object):
 class PresenceProfilelikeDataTestCase(unittest.TestCase):
 
     def setUp(self):
+        self.mock_config = Mock()
+        self.mock_config.signing_key = [MockKey()]
+
         hs = HomeServer("test",
                 clock=MockClock(),
                 db_pool=None,
@@ -72,6 +75,7 @@ class PresenceProfilelikeDataTestCase(unittest.TestCase):
                 resource_for_federation=Mock(),
                 http_client=None,
                 replication_layer=MockReplication(),
+                config=self.mock_config,
             )
         hs.handlers = PresenceAndProfileHandlers(hs)
 
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index 5dc9b456e1..1660e7e928 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -24,7 +24,7 @@ from synapse.server import HomeServer
 from synapse.handlers.profile import ProfileHandler
 from synapse.api.constants import Membership
 
-from tests.utils import SQLiteMemoryDbPool
+from tests.utils import SQLiteMemoryDbPool, MockKey
 
 
 class ProfileHandlers(object):
@@ -49,12 +49,16 @@ class ProfileTestCase(unittest.TestCase):
         db_pool = SQLiteMemoryDbPool()
         yield db_pool.prepare()
 
+        self.mock_config = Mock()
+        self.mock_config.signing_key = [MockKey()]
+
         hs = HomeServer("test",
                 db_pool=db_pool,
                 http_client=None,
                 handlers=None,
                 resource_for_federation=Mock(),
                 replication_layer=self.mock_federation,
+                config=self.mock_config,
             )
         hs.handlers = ProfileHandlers(hs)
 
diff --git a/tests/handlers/test_room.py b/tests/handlers/test_room.py
index c88d1c8840..55c9f6e142 100644
--- a/tests/handlers/test_room.py
+++ b/tests/handlers/test_room.py
@@ -18,7 +18,7 @@ from twisted.internet import defer
 from tests import unittest
 
 from synapse.api.events.room import (
-    InviteJoinEvent, RoomMemberEvent, RoomConfigEvent
+    RoomMemberEvent,
 )
 from synapse.api.constants import Membership
 from synapse.handlers.room import RoomMemberHandler, RoomCreationHandler
@@ -34,6 +34,7 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
     def setUp(self):
         self.mock_config = NonCallableMock()
         self.mock_config.signing_key = [MockKey()]
+
         self.hostname = "red"
         hs = HomeServer(
             self.hostname,
@@ -57,13 +58,16 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
                 "profile_handler",
                 "federation_handler",
             ]),
-            auth=NonCallableMock(spec_set=["check"]),
-            state_handler=NonCallableMock(spec_set=["handle_new_event"]),
+            auth=NonCallableMock(spec_set=["check", "add_auth_events"]),
+            state_handler=NonCallableMock(spec_set=[
+                "annotate_state_groups",
+            ]),
             config=self.mock_config,
         )
 
         self.federation = NonCallableMock(spec_set=[
             "handle_new_event",
+            "send_invite",
             "get_state_for_room",
         ])
 
@@ -106,7 +110,6 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
 
         joined = ["red", "green"]
 
-        self.state_handler.handle_new_event.return_value = defer.succeed(True)
         self.datastore.get_joined_hosts_for_room.return_value = (
             defer.succeed(joined)
         )
@@ -114,18 +117,29 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
         store_id = "store_id_fooo"
         self.datastore.persist_event.return_value = defer.succeed(store_id)
 
+        self.datastore.get_room_member.return_value = defer.succeed(None)
+
+        event.state_events = {
+            (RoomMemberEvent.TYPE, "@alice:green"): self._create_member(
+                user_id="@alice:green",
+                room_id=room_id,
+            ),
+            (RoomMemberEvent.TYPE, "@bob:red"): self._create_member(
+                user_id="@bob:red",
+                room_id=room_id,
+            ),
+            (RoomMemberEvent.TYPE, target_user_id): event,
+        }
+
         # Actual invocation
         yield self.room_member_handler.change_membership(event)
 
-        self.state_handler.handle_new_event.assert_called_once_with(
-            event, self.snapshot,
-        )
         self.federation.handle_new_event.assert_called_once_with(
             event, self.snapshot,
         )
 
         self.assertEquals(
-            set(["blue", "red", "green"]),
+            set(["red", "green"]),
             set(event.destinations)
         )
 
@@ -144,28 +158,19 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
         room_id = "!foo:red"
         user_id = "@bob:red"
         user = self.hs.parse_userid(user_id)
-        target_user_id = "@bob:red"
-        content = {"membership": Membership.JOIN}
 
-        event = self.hs.get_event_factory().create_event(
-            etype=RoomMemberEvent.TYPE,
+        event = self._create_member(
             user_id=user_id,
-            state_key=target_user_id,
             room_id=room_id,
-            membership=Membership.JOIN,
-            content=content,
         )
 
         joined = ["red", "green"]
 
-        self.state_handler.handle_new_event.return_value = defer.succeed(True)
-
         def get_joined(*args):
             return defer.succeed(joined)
 
         self.datastore.get_joined_hosts_for_room.side_effect = get_joined
 
-
         store_id = "store_id_fooo"
         self.datastore.persist_event.return_value = defer.succeed(store_id)
         self.datastore.get_room.return_value = defer.succeed(1)  # Not None.
@@ -178,12 +183,17 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
         join_signal_observer = Mock()
         self.distributor.observe("user_joined_room", join_signal_observer)
 
+        event.state_events = {
+            (RoomMemberEvent.TYPE, "@alice:green"): self._create_member(
+                user_id="@alice:green",
+                room_id=room_id,
+            ),
+            (RoomMemberEvent.TYPE, user_id): event,
+        }
+
         # Actual invocation
         yield self.room_member_handler.change_membership(event)
 
-        self.state_handler.handle_new_event.assert_called_once_with(
-            event, self.snapshot
-        )
         self.federation.handle_new_event.assert_called_once_with(
             event, self.snapshot
         )
@@ -197,138 +207,32 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
             event
         )
         self.notifier.on_new_room_event.assert_called_once_with(
-                event, extra_users=[user])
-
-        join_signal_observer.assert_called_with(
-                user=user, room_id=room_id)
-
-    @defer.inlineCallbacks
-    def STALE_test_invite_join(self):
-        room_id = "foo"
-        user_id = "@bob:red"
-        target_user_id = "@bob:red"
-        content = {"membership": Membership.JOIN}
-
-        event = self.hs.get_event_factory().create_event(
-            etype=RoomMemberEvent.TYPE,
-            user_id=user_id,
-            target_user_id=target_user_id,
-            room_id=room_id,
-            membership=Membership.JOIN,
-            content=content,
-        )
-
-        joined = ["red", "blue", "green"]
-
-        self.state_handler.handle_new_event.return_value = defer.succeed(True)
-        self.datastore.get_joined_hosts_for_room.return_value = (
-            defer.succeed(joined)
-        )
-
-        store_id = "store_id_fooo"
-        self.datastore.store_room_member.return_value = defer.succeed(store_id)
-        self.datastore.get_room.return_value = defer.succeed(None)
-
-        prev_state = NonCallableMock(name="prev_state")
-        prev_state.membership = Membership.INVITE
-        prev_state.sender = "@foo:blue"
-        self.datastore.get_room_member.return_value = defer.succeed(prev_state)
-
-        # Actual invocation
-        yield self.room_member_handler.change_membership(event)
-
-        self.datastore.get_room_member.assert_called_once_with(
-            target_user_id, room_id
+            event, extra_users=[user]
         )
 
-        self.assertTrue(self.federation.handle_new_event.called)
-        args = self.federation.handle_new_event.call_args[0]
-        invite_join_event = args[0]
-
-        self.assertTrue(InviteJoinEvent.TYPE, invite_join_event.TYPE)
-        self.assertTrue("blue", invite_join_event.target_host)
-        self.assertTrue(room_id, invite_join_event.room_id)
-        self.assertTrue(user_id, invite_join_event.user_id)
-        self.assertFalse(hasattr(invite_join_event, "state_key"))
-
-        self.assertEquals(
-            set(["blue"]),
-            set(invite_join_event.destinations)
-        )
-
-        self.federation.get_state_for_room.assert_called_once_with(
-            "blue", room_id
+        join_signal_observer.assert_called_with(
+            user=user, room_id=room_id
         )
 
-        self.assertFalse(self.datastore.store_room_member.called)
-
-        self.assertFalse(self.notifier.on_new_room_event.called)
-        self.assertFalse(self.state_handler.handle_new_event.called)
-
-    @defer.inlineCallbacks
-    def STALE_test_invite_join_public(self):
-        room_id = "#foo:blue"
-        user_id = "@bob:red"
-        target_user_id = "@bob:red"
-        content = {"membership": Membership.JOIN}
-
-        event = self.hs.get_event_factory().create_event(
+    def _create_member(self, user_id, room_id):
+        return self.hs.get_event_factory().create_event(
             etype=RoomMemberEvent.TYPE,
             user_id=user_id,
-            target_user_id=target_user_id,
+            state_key=user_id,
             room_id=room_id,
             membership=Membership.JOIN,
-            content=content,
-        )
-
-        joined = ["red", "blue", "green"]
-
-        self.state_handler.handle_new_event.return_value = defer.succeed(True)
-        self.datastore.get_joined_hosts_for_room.return_value = (
-            defer.succeed(joined)
-        )
-
-        store_id = "store_id_fooo"
-        self.datastore.store_room_member.return_value = defer.succeed(store_id)
-        self.datastore.get_room.return_value = defer.succeed(None)
-
-        prev_state = NonCallableMock(name="prev_state")
-        prev_state.membership = Membership.INVITE
-        prev_state.sender = "@foo:blue"
-        self.datastore.get_room_member.return_value = defer.succeed(prev_state)
-
-        # Actual invocation
-        yield self.room_member_handler.change_membership(event)
-
-        self.assertTrue(self.federation.handle_new_event.called)
-        args = self.federation.handle_new_event.call_args[0]
-        invite_join_event = args[0]
-
-        self.assertTrue(InviteJoinEvent.TYPE, invite_join_event.TYPE)
-        self.assertTrue("blue", invite_join_event.target_host)
-        self.assertTrue("foo", invite_join_event.room_id)
-        self.assertTrue(user_id, invite_join_event.user_id)
-        self.assertFalse(hasattr(invite_join_event, "state_key"))
-
-        self.assertEquals(
-            set(["blue"]),
-            set(invite_join_event.destinations)
+            content={"membership": Membership.JOIN},
         )
 
-        self.federation.get_state_for_room.assert_called_once_with(
-            "blue", "foo"
-        )
-
-        self.assertFalse(self.datastore.store_room_member.called)
-
-        self.assertFalse(self.notifier.on_new_room_event.called)
-        self.assertFalse(self.state_handler.handle_new_event.called)
-
 
 class RoomCreationTest(unittest.TestCase):
 
     def setUp(self):
         self.hostname = "red"
+
+        self.mock_config = NonCallableMock()
+        self.mock_config.signing_key = [MockKey()]
+
         hs = HomeServer(
             self.hostname,
             db_pool=None,
@@ -345,12 +249,14 @@ class RoomCreationTest(unittest.TestCase):
                 "room_member_handler",
                 "federation_handler",
             ]),
-            auth=NonCallableMock(spec_set=["check"]),
-            state_handler=NonCallableMock(spec_set=["handle_new_event"]),
+            auth=NonCallableMock(spec_set=["check", "add_auth_events"]),
+            state_handler=NonCallableMock(spec_set=[
+                "annotate_state_groups",
+            ]),
             ratelimiter=NonCallableMock(spec_set=[
                 "send_message",
             ]),
-            config=NonCallableMock(),
+            config=self.mock_config,
         )
 
         self.federation = NonCallableMock(spec_set=[
@@ -373,6 +279,11 @@ class RoomCreationTest(unittest.TestCase):
         ])
         self.room_member_handler = self.handlers.room_member_handler
 
+        def annotate(event):
+            event.state_events = {}
+            return defer.succeed(None)
+        self.state_handler.annotate_state_groups.side_effect = annotate
+
         def hosts(room):
             return defer.succeed([])
         self.datastore.get_joined_hosts_for_room.side_effect = hosts
@@ -400,6 +311,6 @@ class RoomCreationTest(unittest.TestCase):
         self.assertEquals(user_id, join_event.user_id)
         self.assertEquals(user_id, join_event.state_key)
 
-        self.assertTrue(self.state_handler.handle_new_event.called)
+        self.assertTrue(self.state_handler.annotate_state_groups.called)
 
         self.assertTrue(self.federation.handle_new_event.called)
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index f1d3b27f74..07acda5eee 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -40,6 +40,7 @@ def _expect_edu(destination, edu_type, content, origin="test"):
                 "content": content,
             }
         ],
+        "pdu_failures": [],
     }
 
 
diff --git a/tests/rest/test_events.py b/tests/rest/test_events.py
index 79b371c04d..4a3234c332 100644
--- a/tests/rest/test_events.py
+++ b/tests/rest/test_events.py
@@ -25,10 +25,7 @@ import synapse.rest.room
 
 from synapse.server import HomeServer
 
-# python imports
-import json
-
-from ..utils import MockHttpResource, MemoryDataStore
+from ..utils import MockHttpResource, SQLiteMemoryDbPool, MockKey
 from .utils import RestTestCase
 
 from mock import Mock, NonCallableMock
@@ -49,7 +46,7 @@ class EventStreamPaginationApiTestCase(unittest.TestCase):
     def tearDown(self):
         pass
 
-    def test_long_poll(self):
+    def TODO_test_long_poll(self):
         # stream from 'end' key, send (self+other) message, expect message.
 
         # stream from 'END', send (self+other) message, expect message.
@@ -64,7 +61,7 @@ class EventStreamPaginationApiTestCase(unittest.TestCase):
 
         pass
 
-    def test_stream_forward(self):
+    def TODO_test_stream_forward(self):
         # stream from START, expect injected items
 
         # stream from 'start' key, expect same content
@@ -80,14 +77,14 @@ class EventStreamPaginationApiTestCase(unittest.TestCase):
         # returned as end key
         pass
 
-    def test_limits(self):
+    def TODO_test_limits(self):
         # stream from a key, expect limit_num items
 
         # stream from START, expect limit_num items
 
         pass
 
-    def test_range(self):
+    def TODO_test_range(self):
         # stream from key to key, expect X items
 
         # stream from key to END, expect X items
@@ -97,7 +94,7 @@ class EventStreamPaginationApiTestCase(unittest.TestCase):
         # stream from START to END, expect all items
         pass
 
-    def test_direction(self):
+    def TODO_test_direction(self):
         # stream from END to START and fwds, expect newest first
 
         # stream from END to START and bwds, expect oldest first
@@ -116,19 +113,20 @@ class EventStreamPermissionsTestCase(RestTestCase):
     def setUp(self):
         self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
 
-        state_handler = Mock(spec=["handle_new_event"])
-        state_handler.handle_new_event.return_value = True
-
         persistence_service = Mock(spec=["get_latest_pdus_in_context"])
         persistence_service.get_latest_pdus_in_context.return_value = []
 
+        self.mock_config = NonCallableMock()
+        self.mock_config.signing_key = [MockKey()]
+
+        db_pool = SQLiteMemoryDbPool()
+        yield db_pool.prepare()
+
         hs = HomeServer(
             "test",
-            db_pool=None,
+            db_pool=db_pool,
             http_client=None,
             replication_layer=Mock(),
-            state_handler=state_handler,
-            datastore=MemoryDataStore(),
             persistence_service=persistence_service,
             clock=Mock(spec=[
                 "call_later",
@@ -139,7 +137,7 @@ class EventStreamPermissionsTestCase(RestTestCase):
             ratelimiter=NonCallableMock(spec_set=[
                 "send_message",
             ]),
-            config=NonCallableMock(),
+            config=self.mock_config,
         )
         self.ratelimiter = hs.get_ratelimiter()
         self.ratelimiter.send_message.return_value = (True, 0)
@@ -148,6 +146,7 @@ class EventStreamPermissionsTestCase(RestTestCase):
         hs.get_handlers().federation_handler = Mock()
 
         hs.get_clock().time_msec.return_value = 1000000
+        hs.get_clock().time.return_value = 1000
 
         synapse.rest.register.register_servlets(hs, self.mock_resource)
         synapse.rest.events.register_servlets(hs, self.mock_resource)
@@ -172,12 +171,14 @@ class EventStreamPermissionsTestCase(RestTestCase):
     def test_stream_basic_permissions(self):
         # invalid token, expect 403
         (code, response) = yield self.mock_resource.trigger_get(
-                           "/events?access_token=%s" % ("invalid" + self.token))
+            "/events?access_token=%s" % ("invalid" + self.token, )
+        )
         self.assertEquals(403, code, msg=str(response))
 
         # valid token, expect content
         (code, response) = yield self.mock_resource.trigger_get(
-                           "/events?access_token=%s&timeout=0" % (self.token))
+            "/events?access_token=%s&timeout=0" % (self.token,)
+        )
         self.assertEquals(200, code, msg=str(response))
         self.assertTrue("chunk" in response)
         self.assertTrue("start" in response)
@@ -185,15 +186,23 @@ class EventStreamPermissionsTestCase(RestTestCase):
 
     @defer.inlineCallbacks
     def test_stream_room_permissions(self):
-        room_id = yield self.create_room_as(self.other_user,
-                                            tok=self.other_token)
+        room_id = yield self.create_room_as(
+            self.other_user,
+            tok=self.other_token
+        )
         yield self.send(room_id, tok=self.other_token)
 
         # invited to room (expect no content for room)
-        yield self.invite(room_id, src=self.other_user, targ=self.user_id,
-                          tok=self.other_token)
+        yield self.invite(
+            room_id,
+            src=self.other_user,
+            targ=self.user_id,
+            tok=self.other_token
+        )
+
         (code, response) = yield self.mock_resource.trigger_get(
-                           "/events?access_token=%s&timeout=0" % (self.token))
+            "/events?access_token=%s&timeout=0" % (self.token,)
+        )
         self.assertEquals(200, code, msg=str(response))
 
         self.assertEquals(0, len(response["chunk"]))
@@ -203,7 +212,7 @@ class EventStreamPermissionsTestCase(RestTestCase):
 
         # left to room (expect no content for room)
 
-    def test_stream_items(self):
+    def TODO_test_stream_items(self):
         # new user, no content
 
         # join room, expect 1 item (join)
diff --git a/tests/rest/test_profile.py b/tests/rest/test_profile.py
index b0f48e7fd8..3a0d1e700a 100644
--- a/tests/rest/test_profile.py
+++ b/tests/rest/test_profile.py
@@ -18,9 +18,9 @@
 from tests import unittest
 from twisted.internet import defer
 
-from mock import Mock
+from mock import Mock, NonCallableMock
 
-from ..utils import MockHttpResource
+from ..utils import MockHttpResource, MockKey
 
 from synapse.api.errors import SynapseError, AuthError
 from synapse.server import HomeServer
@@ -41,6 +41,9 @@ class ProfileTestCase(unittest.TestCase):
             "set_avatar_url",
         ])
 
+        self.mock_config = NonCallableMock()
+        self.mock_config.signing_key = [MockKey()]
+
         hs = HomeServer("test",
             db_pool=None,
             http_client=None,
@@ -48,6 +51,7 @@ class ProfileTestCase(unittest.TestCase):
             federation=Mock(),
             replication_layer=Mock(),
             datastore=None,
+            config=self.mock_config,
         )
 
         def _get_user_by_req(request=None):
diff --git a/tests/rest/test_rooms.py b/tests/rest/test_rooms.py
index 1ce9b8a83d..61b01d369d 100644
--- a/tests/rest/test_rooms.py
+++ b/tests/rest/test_rooms.py
@@ -23,11 +23,14 @@ from synapse.api.constants import Membership
 
 from synapse.server import HomeServer
 
+from tests import unittest
+
 # python imports
 import json
 import urllib
+import types
 
-from ..utils import MockHttpResource, MemoryDataStore
+from ..utils import MockHttpResource, SQLiteMemoryDbPool, MockKey
 from .utils import RestTestCase
 
 from mock import Mock, NonCallableMock
@@ -44,24 +47,21 @@ class RoomPermissionsTestCase(RestTestCase):
     def setUp(self):
         self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
 
-        state_handler = Mock(spec=["handle_new_event"])
-        state_handler.handle_new_event.return_value = True
+        self.mock_config = NonCallableMock()
+        self.mock_config.signing_key = [MockKey()]
 
-        persistence_service = Mock(spec=["get_latest_pdus_in_context"])
-        persistence_service.get_latest_pdus_in_context.return_value = []
+        db_pool = SQLiteMemoryDbPool()
+        yield db_pool.prepare()
 
         hs = HomeServer(
             "red",
-            db_pool=None,
+            db_pool=db_pool,
             http_client=None,
-            datastore=MemoryDataStore(),
             replication_layer=Mock(),
-            state_handler=state_handler,
-            persistence_service=persistence_service,
             ratelimiter=NonCallableMock(spec_set=[
                 "send_message",
             ]),
-            config=NonCallableMock(),
+            config=self.mock_config,
         )
         self.ratelimiter = hs.get_ratelimiter()
         self.ratelimiter.send_message.return_value = (True, 0)
@@ -76,6 +76,10 @@ class RoomPermissionsTestCase(RestTestCase):
             }
         hs.get_auth().get_user_by_token = _get_user_by_token
 
+        def _insert_client_ip(*args, **kwargs):
+            return defer.succeed(None)
+        hs.get_datastore().insert_client_ip = _insert_client_ip
+
         self.auth_user_id = self.rmcreator_id
 
         synapse.rest.room.register_servlets(hs, self.mock_resource)
@@ -147,38 +151,55 @@ class RoomPermissionsTestCase(RestTestCase):
     @defer.inlineCallbacks
     def test_send_message(self):
         msg_content = '{"msgtype":"m.text","body":"hello"}'
-        send_msg_path = ("/rooms/%s/send/m.room.message/mid1" %
-                        (self.created_rmid))
+        send_msg_path = (
+            "/rooms/%s/send/m.room.message/mid1" % (self.created_rmid,)
+        )
 
         # send message in uncreated room, expect 403
         (code, response) = yield self.mock_resource.trigger(
-                           "PUT",
-                           "/rooms/%s/send/m.room.message/mid2" %
-                           (self.uncreated_rmid), msg_content)
+            "PUT",
+            "/rooms/%s/send/m.room.message/mid2" % (self.uncreated_rmid,),
+            msg_content
+        )
         self.assertEquals(403, code, msg=str(response))
 
         # send message in created room not joined (no state), expect 403
         (code, response) = yield self.mock_resource.trigger(
-                           "PUT", send_msg_path, msg_content)
+            "PUT",
+            send_msg_path,
+            msg_content
+        )
         self.assertEquals(403, code, msg=str(response))
 
         # send message in created room and invited, expect 403
-        yield self.invite(room=self.created_rmid, src=self.rmcreator_id,
-                          targ=self.user_id)
+        yield self.invite(
+            room=self.created_rmid,
+            src=self.rmcreator_id,
+            targ=self.user_id
+        )
         (code, response) = yield self.mock_resource.trigger(
-                           "PUT", send_msg_path, msg_content)
+            "PUT",
+            send_msg_path,
+            msg_content
+        )
         self.assertEquals(403, code, msg=str(response))
 
         # send message in created room and joined, expect 200
         yield self.join(room=self.created_rmid, user=self.user_id)
         (code, response) = yield self.mock_resource.trigger(
-                           "PUT", send_msg_path, msg_content)
+            "PUT",
+            send_msg_path,
+            msg_content
+        )
         self.assertEquals(200, code, msg=str(response))
 
         # send message in created room and left, expect 403
         yield self.leave(room=self.created_rmid, user=self.user_id)
         (code, response) = yield self.mock_resource.trigger(
-                           "PUT", send_msg_path, msg_content)
+            "PUT",
+            send_msg_path,
+            msg_content
+        )
         self.assertEquals(403, code, msg=str(response))
 
     @defer.inlineCallbacks
@@ -215,9 +236,14 @@ class RoomPermissionsTestCase(RestTestCase):
 
         # set/get topic in created PRIVATE room and joined, expect 200
         yield self.join(room=self.created_rmid, user=self.user_id)
+
+        # Only room ops can set topic by default
+        self.auth_user_id = self.rmcreator_id
         (code, response) = yield self.mock_resource.trigger(
                            "PUT", topic_path, topic_content)
         self.assertEquals(200, code, msg=str(response))
+        self.auth_user_id = self.user_id
+
         (code, response) = yield self.mock_resource.trigger_get(topic_path)
         self.assertEquals(200, code, msg=str(response))
         self.assert_dict(json.loads(topic_content), response)
@@ -381,45 +407,55 @@ class RoomPermissionsTestCase(RestTestCase):
         # set [invite/join/left] of self, set [invite/join/left] of other,
         # expect all 403s
         for usr in [self.user_id, self.rmcreator_id]:
-            yield self.change_membership(room=room, src=self.user_id,
-                                     targ=usr,
-                                     membership=Membership.INVITE,
-                                     expect_code=403)
-            yield self.change_membership(room=room, src=self.user_id,
-                                     targ=usr,
-                                     membership=Membership.JOIN,
-                                     expect_code=403)
-            yield self.change_membership(room=room, src=self.user_id,
-                                     targ=usr,
-                                     membership=Membership.LEAVE,
-                                     expect_code=403)
+            yield self.change_membership(
+                room=room,
+                src=self.user_id,
+                targ=usr,
+                membership=Membership.INVITE,
+                expect_code=403
+            )
+
+            yield self.change_membership(
+                room=room,
+                src=self.user_id,
+                targ=usr,
+                membership=Membership.JOIN,
+                expect_code=403
+            )
+
+        # It is always valid to LEAVE if you've already left (currently.)
+        yield self.change_membership(
+            room=room,
+            src=self.user_id,
+            targ=self.rmcreator_id,
+            membership=Membership.LEAVE,
+            expect_code=403
+        )
 
 
 class RoomsMemberListTestCase(RestTestCase):
     """ Tests /rooms/$room_id/members/list REST events."""
     user_id = "@sid1:red"
 
+    @defer.inlineCallbacks
     def setUp(self):
         self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
 
-        state_handler = Mock(spec=["handle_new_event"])
-        state_handler.handle_new_event.return_value = True
+        self.mock_config = NonCallableMock()
+        self.mock_config.signing_key = [MockKey()]
 
-        persistence_service = Mock(spec=["get_latest_pdus_in_context"])
-        persistence_service.get_latest_pdus_in_context.return_value = []
+        db_pool = SQLiteMemoryDbPool()
+        yield db_pool.prepare()
 
         hs = HomeServer(
             "red",
-            db_pool=None,
+            db_pool=db_pool,
             http_client=None,
-            datastore=MemoryDataStore(),
             replication_layer=Mock(),
-            state_handler=state_handler,
-            persistence_service=persistence_service,
             ratelimiter=NonCallableMock(spec_set=[
                 "send_message",
             ]),
-            config=NonCallableMock(),
+            config=self.mock_config,
         )
         self.ratelimiter = hs.get_ratelimiter()
         self.ratelimiter.send_message.return_value = (True, 0)
@@ -436,6 +472,10 @@ class RoomsMemberListTestCase(RestTestCase):
             }
         hs.get_auth().get_user_by_token = _get_user_by_token
 
+        def _insert_client_ip(*args, **kwargs):
+            return defer.succeed(None)
+        hs.get_datastore().insert_client_ip = _insert_client_ip
+
         synapse.rest.room.register_servlets(hs, self.mock_resource)
 
     def tearDown(self):
@@ -487,28 +527,26 @@ class RoomsCreateTestCase(RestTestCase):
     """ Tests /rooms and /rooms/$room_id REST events. """
     user_id = "@sid1:red"
 
+    @defer.inlineCallbacks
     def setUp(self):
         self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
         self.auth_user_id = self.user_id
 
-        state_handler = Mock(spec=["handle_new_event"])
-        state_handler.handle_new_event.return_value = True
+        self.mock_config = NonCallableMock()
+        self.mock_config.signing_key = [MockKey()]
 
-        persistence_service = Mock(spec=["get_latest_pdus_in_context"])
-        persistence_service.get_latest_pdus_in_context.return_value = []
+        db_pool = SQLiteMemoryDbPool()
+        yield db_pool.prepare()
 
         hs = HomeServer(
             "red",
-            db_pool=None,
+            db_pool=db_pool,
             http_client=None,
-            datastore=MemoryDataStore(),
             replication_layer=Mock(),
-            state_handler=state_handler,
-            persistence_service=persistence_service,
             ratelimiter=NonCallableMock(spec_set=[
                 "send_message",
             ]),
-            config=NonCallableMock(),
+            config=self.mock_config,
         )
         self.ratelimiter = hs.get_ratelimiter()
         self.ratelimiter.send_message.return_value = (True, 0)
@@ -523,6 +561,10 @@ class RoomsCreateTestCase(RestTestCase):
             }
         hs.get_auth().get_user_by_token = _get_user_by_token
 
+        def _insert_client_ip(*args, **kwargs):
+            return defer.succeed(None)
+        hs.get_datastore().insert_client_ip = _insert_client_ip
+
         synapse.rest.room.register_servlets(hs, self.mock_resource)
 
     def tearDown(self):
@@ -592,24 +634,21 @@ class RoomTopicTestCase(RestTestCase):
         self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
         self.auth_user_id = self.user_id
 
-        state_handler = Mock(spec=["handle_new_event"])
-        state_handler.handle_new_event.return_value = True
+        self.mock_config = NonCallableMock()
+        self.mock_config.signing_key = [MockKey()]
 
-        persistence_service = Mock(spec=["get_latest_pdus_in_context"])
-        persistence_service.get_latest_pdus_in_context.return_value = []
+        db_pool = SQLiteMemoryDbPool()
+        yield db_pool.prepare()
 
         hs = HomeServer(
             "red",
-            db_pool=None,
+            db_pool=db_pool,
             http_client=None,
-            datastore=MemoryDataStore(),
             replication_layer=Mock(),
-            state_handler=state_handler,
-            persistence_service=persistence_service,
             ratelimiter=NonCallableMock(spec_set=[
                 "send_message",
             ]),
-            config=NonCallableMock(),
+            config=self.mock_config,
         )
         self.ratelimiter = hs.get_ratelimiter()
         self.ratelimiter.send_message.return_value = (True, 0)
@@ -622,13 +661,18 @@ class RoomTopicTestCase(RestTestCase):
                 "admin": False,
                 "device_id": None,
             }
+
         hs.get_auth().get_user_by_token = _get_user_by_token
 
+        def _insert_client_ip(*args, **kwargs):
+            return defer.succeed(None)
+        hs.get_datastore().insert_client_ip = _insert_client_ip
+
         synapse.rest.room.register_servlets(hs, self.mock_resource)
 
         # create the room
         self.room_id = yield self.create_room_as(self.user_id)
-        self.path = "/rooms/%s/state/m.room.topic" % self.room_id
+        self.path = "/rooms/%s/state/m.room.topic" % (self.room_id,)
 
     def tearDown(self):
         pass
@@ -706,24 +750,21 @@ class RoomMemberStateTestCase(RestTestCase):
         self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
         self.auth_user_id = self.user_id
 
-        state_handler = Mock(spec=["handle_new_event"])
-        state_handler.handle_new_event.return_value = True
+        self.mock_config = NonCallableMock()
+        self.mock_config.signing_key = [MockKey()]
 
-        persistence_service = Mock(spec=["get_latest_pdus_in_context"])
-        persistence_service.get_latest_pdus_in_context.return_value = []
+        db_pool = SQLiteMemoryDbPool()
+        yield db_pool.prepare()
 
         hs = HomeServer(
             "red",
-            db_pool=None,
+            db_pool=db_pool,
             http_client=None,
-            datastore=MemoryDataStore(),
             replication_layer=Mock(),
-            state_handler=state_handler,
-            persistence_service=persistence_service,
             ratelimiter=NonCallableMock(spec_set=[
                 "send_message",
             ]),
-            config=NonCallableMock(),
+            config=self.mock_config,
         )
         self.ratelimiter = hs.get_ratelimiter()
         self.ratelimiter.send_message.return_value = (True, 0)
@@ -736,13 +777,12 @@ class RoomMemberStateTestCase(RestTestCase):
                 "admin": False,
                 "device_id": None,
             }
-            return {
-                "user": hs.parse_userid(self.auth_user_id),
-                "admin": False,
-                "device_id": None,
-            }
         hs.get_auth().get_user_by_token = _get_user_by_token
 
+        def _insert_client_ip(*args, **kwargs):
+            return defer.succeed(None)
+        hs.get_datastore().insert_client_ip = _insert_client_ip
+
         synapse.rest.room.register_servlets(hs, self.mock_resource)
 
         self.room_id = yield self.create_room_as(self.user_id)
@@ -847,24 +887,21 @@ class RoomMessagesTestCase(RestTestCase):
         self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
         self.auth_user_id = self.user_id
 
-        state_handler = Mock(spec=["handle_new_event"])
-        state_handler.handle_new_event.return_value = True
+        self.mock_config = NonCallableMock()
+        self.mock_config.signing_key = [MockKey()]
 
-        persistence_service = Mock(spec=["get_latest_pdus_in_context"])
-        persistence_service.get_latest_pdus_in_context.return_value = []
+        db_pool = SQLiteMemoryDbPool()
+        yield db_pool.prepare()
 
         hs = HomeServer(
             "red",
-            db_pool=None,
+            db_pool=db_pool,
             http_client=None,
-            datastore=MemoryDataStore(),
             replication_layer=Mock(),
-            state_handler=state_handler,
-            persistence_service=persistence_service,
             ratelimiter=NonCallableMock(spec_set=[
                 "send_message",
             ]),
-            config=NonCallableMock(),
+            config=self.mock_config,
         )
         self.ratelimiter = hs.get_ratelimiter()
         self.ratelimiter.send_message.return_value = (True, 0)
@@ -879,6 +916,10 @@ class RoomMessagesTestCase(RestTestCase):
             }
         hs.get_auth().get_user_by_token = _get_user_by_token
 
+        def _insert_client_ip(*args, **kwargs):
+            return defer.succeed(None)
+        hs.get_datastore().insert_client_ip = _insert_client_ip
+
         synapse.rest.room.register_servlets(hs, self.mock_resource)
 
         self.room_id = yield self.create_room_as(self.user_id)
diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index 3ad9a4b0c0..fabd364be9 100644
--- a/tests/storage/test_base.py
+++ b/tests/storage/test_base.py
@@ -74,7 +74,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def test_select_one_1col(self):
         self.mock_txn.rowcount = 1
-        self.mock_txn.fetchone.return_value = ("Value",)
+        self.mock_txn.fetchall.return_value = [("Value",)]
 
         value = yield self.datastore._simple_select_one_onecol(
                 table="tablename",
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index dae1641ea1..adfe64a980 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -61,6 +61,7 @@ class RedactionTestCase(unittest.TestCase):
             membership=membership,
             content={"membership": membership},
             depth=self.depth,
+            prev_events=[],
         )
 
         event.content.update(extra_content)
@@ -68,6 +69,11 @@ class RedactionTestCase(unittest.TestCase):
         if prev_state:
             event.prev_state = prev_state
 
+        event.state_events = None
+        event.hashes = {}
+        event.prev_state = []
+        event.auth_events = []
+
         # Have to create a join event using the eventfactory
         yield self.store.persist_event(
             event
@@ -85,8 +91,13 @@ class RedactionTestCase(unittest.TestCase):
             room_id=room.to_string(),
             content={"body": body, "msgtype": u"message"},
             depth=self.depth,
+            prev_events=[],
         )
 
+        event.state_events = None
+        event.hashes = {}
+        event.auth_events = []
+
         yield self.store.persist_event(
             event
         )
@@ -102,8 +113,13 @@ class RedactionTestCase(unittest.TestCase):
             content={"reason": reason},
             depth=self.depth,
             redacts=event_id,
+            prev_events=[],
         )
 
+        event.state_events = None
+        event.hashes = {}
+        event.auth_events = []
+
         yield self.store.persist_event(
             event
         )
diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py
index 369a73d917..4ff02c306b 100644
--- a/tests/storage/test_room.py
+++ b/tests/storage/test_room.py
@@ -127,7 +127,7 @@ class RoomEventsStoreTestCase(unittest.TestCase):
         )
 
     @defer.inlineCallbacks
-    def test_room_name(self):
+    def STALE_test_room_name(self):
         name = u"A-Room-Name"
 
         yield self.inject_room_event(
@@ -150,7 +150,7 @@ class RoomEventsStoreTestCase(unittest.TestCase):
         )
 
     @defer.inlineCallbacks
-    def test_room_name(self):
+    def STALE_test_room_topic(self):
         topic = u"A place for things"
 
         yield self.inject_room_event(
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index eae278ee8d..8614e5ca9d 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -51,16 +51,24 @@ class RoomMemberStoreTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def inject_room_member(self, room, user, membership):
         # Have to create a join event using the eventfactory
+        event = self.event_factory.create_event(
+            etype=RoomMemberEvent.TYPE,
+            user_id=user.to_string(),
+            state_key=user.to_string(),
+            room_id=room.to_string(),
+            membership=membership,
+            content={"membership": membership},
+            depth=1,
+            prev_events=[],
+        )
+
+        event.state_events = None
+        event.hashes = {}
+        event.prev_state = {}
+        event.auth_events = {}
+
         yield self.store.persist_event(
-            self.event_factory.create_event(
-                etype=RoomMemberEvent.TYPE,
-                user_id=user.to_string(),
-                state_key=user.to_string(),
-                room_id=room.to_string(),
-                membership=membership,
-                content={"membership": membership},
-                depth=1,
-            )
+            event
         )
 
     @defer.inlineCallbacks
diff --git a/tests/storage/test_stream.py b/tests/storage/test_stream.py
index ab30e6ea25..5038546aee 100644
--- a/tests/storage/test_stream.py
+++ b/tests/storage/test_stream.py
@@ -48,7 +48,7 @@ class StreamStoreTestCase(unittest.TestCase):
         self.depth = 1
 
     @defer.inlineCallbacks
-    def inject_room_member(self, room, user, membership, prev_state=None):
+    def inject_room_member(self, room, user, membership, replaces_state=None):
         self.depth += 1
 
         event = self.event_factory.create_event(
@@ -59,10 +59,17 @@ class StreamStoreTestCase(unittest.TestCase):
             membership=membership,
             content={"membership": membership},
             depth=self.depth,
+            prev_events=[],
         )
 
-        if prev_state:
-            event.prev_state = prev_state
+        event.state_events = None
+        event.hashes = {}
+        event.prev_state = []
+        event.auth_events = []
+
+        if replaces_state:
+            event.prev_state = [(replaces_state, "hash")]
+            event.replaces_state = replaces_state
 
         # Have to create a join event using the eventfactory
         yield self.store.persist_event(
@@ -75,15 +82,22 @@ class StreamStoreTestCase(unittest.TestCase):
     def inject_message(self, room, user, body):
         self.depth += 1
 
+        event = self.event_factory.create_event(
+            etype=MessageEvent.TYPE,
+            user_id=user.to_string(),
+            room_id=room.to_string(),
+            content={"body": body, "msgtype": u"message"},
+            depth=self.depth,
+            prev_events=[],
+        )
+
+        event.state_events = None
+        event.hashes = {}
+        event.auth_events = []
+
         # Have to create a join event using the eventfactory
         yield self.store.persist_event(
-            self.event_factory.create_event(
-                etype=MessageEvent.TYPE,
-                user_id=user.to_string(),
-                room_id=room.to_string(),
-                content={"body": body, "msgtype": u"message"},
-                depth=self.depth,
-            )
+            event
         )
 
     @defer.inlineCallbacks
@@ -206,7 +220,7 @@ class StreamStoreTestCase(unittest.TestCase):
 
         event2 = yield self.inject_room_member(
             self.room1, self.u_alice, Membership.JOIN,
-            prev_state=event1.event_id,
+            replaces_state=event1.event_id,
         )
 
         end = yield self.store.get_room_events_max_id()
@@ -223,4 +237,7 @@ class StreamStoreTestCase(unittest.TestCase):
 
         event = results[0]
 
-        self.assertTrue(hasattr(event, "prev_content"), msg="No prev_content key")
+        self.assertTrue(
+            hasattr(event, "prev_content"),
+            msg="No prev_content key"
+        )
diff --git a/tests/test_state.py b/tests/test_state.py
index 4b1feaf410..3cc358be32 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -15,599 +15,258 @@
 
 from tests import unittest
 from twisted.internet import defer
-from twisted.python.log import PythonLoggingObserver
 
 from synapse.state import StateHandler
-from synapse.storage.pdu import PduEntry
-from synapse.federation.pdu_codec import encode_event_id
-from synapse.federation.units import Pdu
-
-from collections import namedtuple
 
 from mock import Mock
 
-import mock
-
-
-ReturnType = namedtuple(
-    "StateReturnType", ["new_branch", "current_branch"]
-)
-
-
-def _gen_get_power_level(power_level_list):
-    def get_power_level(room_id, user_id):
-        return defer.succeed(power_level_list.get(user_id, None))
-    return get_power_level
 
 class StateTestCase(unittest.TestCase):
     def setUp(self):
-        self.persistence = Mock(spec=[
-            "get_unresolved_state_tree",
-            "update_current_state",
-            "get_latest_pdus_in_context",
-            "get_current_state_pdu",
-            "get_pdu",
-            "get_power_level",
-        ])
-        self.replication = Mock(spec=["get_pdu"])
-
-        hs = Mock(spec=["get_datastore", "get_replication_layer"])
-        hs.get_datastore.return_value = self.persistence
-        hs.get_replication_layer.return_value = self.replication
-        hs.hostname = "bob.com"
-
-        self.state = StateHandler(hs)
-
-    @defer.inlineCallbacks
-    def test_new_state_key(self):
-        # We've never seen anything for this state before
-        new_pdu = new_fake_pdu("A", "test", "mem", "x", None, "u")
-
-        self.persistence.get_power_level.side_effect = _gen_get_power_level({})
-
-        self.persistence.get_unresolved_state_tree.return_value = (
-            (ReturnType([new_pdu], []), None)
-        )
-
-        is_new = yield self.state.handle_new_state(new_pdu)
-
-        self.assertTrue(is_new)
-
-        self.persistence.get_unresolved_state_tree.assert_called_once_with(
-            new_pdu
-        )
-
-        self.assertEqual(1, self.persistence.update_current_state.call_count)
-
-        self.assertFalse(self.replication.get_pdu.called)
-
-    @defer.inlineCallbacks
-    def test_direct_overwrite(self):
-        # We do a direct overwriting of the old state, i.e., the new state
-        # points to the old state.
-
-        old_pdu = new_fake_pdu("A", "test", "mem", "x", None, "u1")
-        new_pdu = new_fake_pdu("B", "test", "mem", "x", "A", "u2")
-
-        self.persistence.get_power_level.side_effect = _gen_get_power_level({
-            "u1": 10,
-            "u2": 5,
-        })
-
-        self.persistence.get_unresolved_state_tree.return_value = (
-            (ReturnType([new_pdu, old_pdu], [old_pdu]), None)
-        )
-
-        is_new = yield self.state.handle_new_state(new_pdu)
-
-        self.assertTrue(is_new)
-
-        self.persistence.get_unresolved_state_tree.assert_called_once_with(
-            new_pdu
+        self.store = Mock(
+            spec_set=[
+                "get_state_groups",
+            ]
         )
+        hs = Mock(spec=["get_datastore"])
+        hs.get_datastore.return_value = self.store
 
-        self.assertEqual(1, self.persistence.update_current_state.call_count)
-
-        self.assertFalse(self.replication.get_pdu.called)
+        self.state = StateHandler(hs)
+        self.event_id = 0
 
     @defer.inlineCallbacks
-    def test_overwrite(self):
-        old_pdu_1 = new_fake_pdu("A", "test", "mem", "x", None, "u1")
-        old_pdu_2 = new_fake_pdu("B", "test", "mem", "x", "A", "u2")
-        new_pdu = new_fake_pdu("C", "test", "mem", "x", "B", "u3")
-
-        self.persistence.get_power_level.side_effect = _gen_get_power_level({
-            "u1": 10,
-            "u2": 5,
-            "u3": 0,
-        })
-
-        self.persistence.get_unresolved_state_tree.return_value = (
-            (ReturnType([new_pdu, old_pdu_2, old_pdu_1], [old_pdu_1]), None)
-        )
+    def test_annotate_with_old_message(self):
+        event = self.create_event(type="test_message", name="event")
 
-        is_new = yield self.state.handle_new_state(new_pdu)
+        old_state = [
+            self.create_event(type="test1", state_key="1"),
+            self.create_event(type="test1", state_key="2"),
+            self.create_event(type="test2", state_key=""),
+        ]
 
-        self.assertTrue(is_new)
+        yield self.state.annotate_state_groups(event, old_state=old_state)
 
-        self.persistence.get_unresolved_state_tree.assert_called_once_with(
-            new_pdu
-        )
+        for k, v in event.old_state_events.items():
+            type, state_key = k
+            self.assertEqual(type, v.type)
+            self.assertEqual(state_key, v.state_key)
 
-        self.assertEqual(1, self.persistence.update_current_state.call_count)
+        self.assertEqual(set(old_state), set(event.old_state_events.values()))
+        self.assertDictEqual(event.old_state_events, event.state_events)
 
-        self.assertFalse(self.replication.get_pdu.called)
+        self.assertIsNone(event.state_group)
 
     @defer.inlineCallbacks
-    def test_power_level_fail(self):
-        # We try to update the state based on an outdated state, and have a
-        # too low power level.
-
-        old_pdu_1 = new_fake_pdu("A", "test", "mem", "x", None, "u1")
-        old_pdu_2 = new_fake_pdu("B", "test", "mem", "x", None, "u2")
-        new_pdu = new_fake_pdu("C", "test", "mem", "x", "A", "u3")
-
-        self.persistence.get_power_level.side_effect = _gen_get_power_level({
-            "u1": 10,
-            "u2": 10,
-            "u3": 5,
-        })
-
-        self.persistence.get_unresolved_state_tree.return_value = (
-            (ReturnType([new_pdu, old_pdu_1], [old_pdu_2, old_pdu_1]), None)
-        )
-
-        is_new = yield self.state.handle_new_state(new_pdu)
-
-        self.assertFalse(is_new)
-
-        self.persistence.get_unresolved_state_tree.assert_called_once_with(
-            new_pdu
-        )
-
-        self.assertEqual(0, self.persistence.update_current_state.call_count)
+    def test_annotate_with_old_state(self):
+        event = self.create_event(type="state", state_key="", name="event")
 
-        self.assertFalse(self.replication.get_pdu.called)
-
-    @defer.inlineCallbacks
-    def test_power_level_succeed(self):
-        # We try to update the state based on an outdated state, but have
-        # sufficient power level to force the update.
-
-        old_pdu_1 = new_fake_pdu("A", "test", "mem", "x", None, "u1")
-        old_pdu_2 = new_fake_pdu("B", "test", "mem", "x", None, "u2")
-        new_pdu = new_fake_pdu("C", "test", "mem", "x", "A", "u3")
-
-        self.persistence.get_power_level.side_effect = _gen_get_power_level({
-            "u1": 10,
-            "u2": 10,
-            "u3": 15,
-        })
-
-        self.persistence.get_unresolved_state_tree.return_value = (
-            (ReturnType([new_pdu, old_pdu_1], [old_pdu_2, old_pdu_1]), None)
-        )
+        old_state = [
+            self.create_event(type="test1", state_key="1"),
+            self.create_event(type="test1", state_key="2"),
+            self.create_event(type="test2", state_key=""),
+        ]
 
-        is_new = yield self.state.handle_new_state(new_pdu)
+        yield self.state.annotate_state_groups(event, old_state=old_state)
 
-        self.assertTrue(is_new)
+        for k, v in event.old_state_events.items():
+            type, state_key = k
+            self.assertEqual(type, v.type)
+            self.assertEqual(state_key, v.state_key)
 
-        self.persistence.get_unresolved_state_tree.assert_called_once_with(
-            new_pdu
+        self.assertEqual(
+            set(old_state + [event]),
+            set(event.old_state_events.values())
         )
 
-        self.assertEqual(1, self.persistence.update_current_state.call_count)
+        self.assertDictEqual(event.old_state_events, event.state_events)
 
-        self.assertFalse(self.replication.get_pdu.called)
+        self.assertIsNone(event.state_group)
 
     @defer.inlineCallbacks
-    def test_power_level_equal_same_len(self):
-        # We try to update the state based on an outdated state, the power
-        # levels are the same and so are the branch lengths
-
-        old_pdu_1 = new_fake_pdu("A", "test", "mem", "x", None, "u1")
-        old_pdu_2 = new_fake_pdu("B", "test", "mem", "x", None, "u2")
-        new_pdu = new_fake_pdu("C", "test", "mem", "x", "A", "u3")
-
-        self.persistence.get_power_level.side_effect = _gen_get_power_level({
-            "u1": 10,
-            "u2": 10,
-            "u3": 10,
-        })
-
-        self.persistence.get_unresolved_state_tree.return_value = (
-            (ReturnType([new_pdu, old_pdu_1], [old_pdu_2, old_pdu_1]), None)
-        )
-
-        is_new = yield self.state.handle_new_state(new_pdu)
+    def test_trivial_annotate_message(self):
+        event = self.create_event(type="test_message", name="event")
+        event.prev_events = []
+
+        old_state = [
+            self.create_event(type="test1", state_key="1"),
+            self.create_event(type="test1", state_key="2"),
+            self.create_event(type="test2", state_key=""),
+        ]
 
-        self.assertTrue(is_new)
+        group_name = "group_name_1"
 
-        self.persistence.get_unresolved_state_tree.assert_called_once_with(
-            new_pdu
-        )
+        self.store.get_state_groups.return_value = {
+            group_name: old_state,
+        }
 
-        self.assertEqual(1, self.persistence.update_current_state.call_count)
+        yield self.state.annotate_state_groups(event)
 
-        self.assertFalse(self.replication.get_pdu.called)
+        for k, v in event.old_state_events.items():
+            type, state_key = k
+            self.assertEqual(type, v.type)
+            self.assertEqual(state_key, v.state_key)
 
-    @defer.inlineCallbacks
-    def test_power_level_equal_diff_len(self):
-        # We try to update the state based on an outdated state, the power
-        # levels are the same but the branch length of the new one is longer.
-
-        old_pdu_1 = new_fake_pdu("A", "test", "mem", "x", None, "u1")
-        old_pdu_2 = new_fake_pdu("B", "test", "mem", "x", None, "u2")
-        old_pdu_3 = new_fake_pdu("C", "test", "mem", "x", "A", "u3")
-        new_pdu = new_fake_pdu("D", "test", "mem", "x", "C", "u4")
-
-        self.persistence.get_power_level.side_effect = _gen_get_power_level({
-            "u1": 10,
-            "u2": 10,
-            "u3": 10,
-            "u4": 10,
-        })
-
-        self.persistence.get_unresolved_state_tree.return_value = (
-            (
-                ReturnType(
-                    [new_pdu, old_pdu_3, old_pdu_1],
-                    [old_pdu_2, old_pdu_1]
-                ),
-                None
-            )
+        self.assertEqual(
+            set([e.event_id for e in old_state]),
+            set([e.event_id for e in event.old_state_events.values()])
         )
 
-        is_new = yield self.state.handle_new_state(new_pdu)
-
-        self.assertTrue(is_new)
-
-        self.persistence.get_unresolved_state_tree.assert_called_once_with(
-            new_pdu
+        self.assertDictEqual(
+            {
+                k: v.event_id
+                for k, v in event.old_state_events.items()
+            },
+            {
+                k: v.event_id
+                for k, v in event.state_events.items()
+            }
         )
 
-        self.assertEqual(1, self.persistence.update_current_state.call_count)
-
-        self.assertFalse(self.replication.get_pdu.called)
+        self.assertEqual(group_name, event.state_group)
 
     @defer.inlineCallbacks
-    def test_missing_pdu(self):
-        # We try to update state against a PDU we haven't yet seen,
-        # triggering a get_pdu request
-
-        # The pdu we haven't seen
-        old_pdu_1 = new_fake_pdu(
-            "A", "test", "mem", "x", None, "u1", depth=0
-        )
-
-        old_pdu_2 = new_fake_pdu(
-            "B", "test", "mem", "x", "A", "u2", depth=1
-        )
-        new_pdu = new_fake_pdu(
-            "C", "test", "mem", "x", "A", "u3", depth=2
-        )
-
-        self.persistence.get_power_level.side_effect = _gen_get_power_level({
-            "u1": 10,
-            "u2": 10,
-            "u3": 20,
-        })
-
-        # The return_value of `get_unresolved_state_tree`, which changes after
-        # the call to get_pdu
-        tree_to_return = [(ReturnType([new_pdu], [old_pdu_2]), 0)]
-
-        def return_tree(p):
-            return tree_to_return[0]
-
-        def set_return_tree(destination, pdu_origin, pdu_id, outlier=False):
-            tree_to_return[0] = (
-                ReturnType(
-                    [new_pdu, old_pdu_1], [old_pdu_2, old_pdu_1]
-                ),
-                None
-            )
-            return defer.succeed(None)
-
-        self.persistence.get_unresolved_state_tree.side_effect = return_tree
+    def test_trivial_annotate_state(self):
+        event = self.create_event(type="state", state_key="", name="event")
+        event.prev_events = []
+
+        old_state = [
+            self.create_event(type="test1", state_key="1"),
+            self.create_event(type="test1", state_key="2"),
+            self.create_event(type="test2", state_key=""),
+        ]
 
-        self.replication.get_pdu.side_effect = set_return_tree
+        group_name = "group_name_1"
 
-        self.persistence.get_pdu.return_value = None
+        self.store.get_state_groups.return_value = {
+            group_name: old_state,
+        }
 
-        is_new = yield self.state.handle_new_state(new_pdu)
+        yield self.state.annotate_state_groups(event)
 
-        self.assertTrue(is_new)
+        for k, v in event.old_state_events.items():
+            type, state_key = k
+            self.assertEqual(type, v.type)
+            self.assertEqual(state_key, v.state_key)
 
-        self.replication.get_pdu.assert_called_with(
-            destination=new_pdu.origin,
-            pdu_origin=old_pdu_1.origin,
-            pdu_id=old_pdu_1.pdu_id,
-            outlier=True
+        self.assertEqual(
+            set([e.event_id for e in old_state]),
+            set([e.event_id for e in event.old_state_events.values()])
         )
 
-        self.persistence.get_unresolved_state_tree.assert_called_with(
-            new_pdu
+        self.assertEqual(
+            set([e.event_id for e in old_state] + [event.event_id]),
+            set([e.event_id for e in event.state_events.values()])
         )
 
-        self.assertEquals(
-            2, self.persistence.get_unresolved_state_tree.call_count
+        new_state = {
+            k: v.event_id
+            for k, v in event.state_events.items()
+        }
+        old_state = {
+            k: v.event_id
+            for k, v in event.old_state_events.items()
+        }
+        old_state[(event.type, event.state_key)] = event.event_id
+        self.assertDictEqual(
+            old_state,
+            new_state
         )
 
-        self.assertEqual(1, self.persistence.update_current_state.call_count)
+        self.assertIsNone(event.state_group)
 
     @defer.inlineCallbacks
-    def test_missing_pdu_depth_1(self):
-        # We try to update state against a PDU we haven't yet seen,
-        # triggering a get_pdu request
-
-        # The pdu we haven't seen
-        old_pdu_1 = new_fake_pdu(
-            "A", "test", "mem", "x", None, "u1", depth=0
-        )
-
-        old_pdu_2 = new_fake_pdu(
-            "B", "test", "mem", "x", "A", "u2", depth=2
-        )
-        old_pdu_3 = new_fake_pdu(
-            "C", "test", "mem", "x", "B", "u3", depth=3
-        )
-        new_pdu = new_fake_pdu(
-            "D", "test", "mem", "x", "A", "u4", depth=4
-        )
-
-        self.persistence.get_power_level.side_effect = _gen_get_power_level({
-            "u1": 10,
-            "u2": 10,
-            "u3": 10,
-            "u4": 20,
-        })
-
-        # The return_value of `get_unresolved_state_tree`, which changes after
-        # the call to get_pdu
-        tree_to_return = [
-            (
-                ReturnType([new_pdu], [old_pdu_3]),
-                0
-            ),
-            (
-                ReturnType(
-                    [new_pdu, old_pdu_1], [old_pdu_3]
-                ),
-                1
-            ),
-            (
-                ReturnType(
-                    [new_pdu, old_pdu_1], [old_pdu_3, old_pdu_2, old_pdu_1]
-                ),
-                None
-            ),
+    def test_resolve_message_conflict(self):
+        event = self.create_event(type="test_message", name="event")
+        event.prev_events = []
+
+        old_state_1 = [
+            self.create_event(type="test1", state_key="1"),
+            self.create_event(type="test1", state_key="2"),
+            self.create_event(type="test2", state_key=""),
         ]
 
-        to_return = [0]
-
-        def return_tree(p):
-            return tree_to_return[to_return[0]]
-
-        def set_return_tree(destination, pdu_origin, pdu_id, outlier=False):
-            to_return[0] += 1
-            return defer.succeed(None)
-
-        self.persistence.get_unresolved_state_tree.side_effect = return_tree
-
-        self.replication.get_pdu.side_effect = set_return_tree
-
-        self.persistence.get_pdu.return_value = None
-
-        is_new = yield self.state.handle_new_state(new_pdu)
+        old_state_2 = [
+            self.create_event(type="test1", state_key="1"),
+            self.create_event(type="test3", state_key="2"),
+            self.create_event(type="test4", state_key=""),
+        ]
 
-        self.assertTrue(is_new)
+        group_name_1 = "group_name_1"
+        group_name_2 = "group_name_2"
 
-        self.assertEqual(2, self.replication.get_pdu.call_count)
+        self.store.get_state_groups.return_value = {
+            group_name_1: old_state_1,
+            group_name_2: old_state_2,
+        }
 
-        self.replication.get_pdu.assert_has_calls(
-            [
-                mock.call(
-                    destination=new_pdu.origin,
-                    pdu_origin=old_pdu_1.origin,
-                    pdu_id=old_pdu_1.pdu_id,
-                    outlier=True
-                ),
-                mock.call(
-                    destination=old_pdu_3.origin,
-                    pdu_origin=old_pdu_2.origin,
-                    pdu_id=old_pdu_2.pdu_id,
-                    outlier=True
-                ),
-            ]
-        )
+        yield self.state.annotate_state_groups(event)
 
-        self.persistence.get_unresolved_state_tree.assert_called_with(
-            new_pdu
-        )
+        self.assertEqual(len(event.old_state_events), 5)
 
-        self.assertEquals(
-            3, self.persistence.get_unresolved_state_tree.call_count
+        self.assertEqual(
+            set([e.event_id for e in event.state_events.values()]),
+            set([e.event_id for e in event.old_state_events.values()])
         )
 
-        self.assertEqual(1, self.persistence.update_current_state.call_count)
+        self.assertIsNone(event.state_group)
 
     @defer.inlineCallbacks
-    def test_missing_pdu_depth_2(self):
-        # We try to update state against a PDU we haven't yet seen,
-        # triggering a get_pdu request
-
-        # The pdu we haven't seen
-        old_pdu_1 = new_fake_pdu(
-            "A", "test", "mem", "x", None, "u1", depth=0
-        )
-
-        old_pdu_2 = new_fake_pdu(
-            "B", "test", "mem", "x", "A", "u2", depth=2
-        )
-        old_pdu_3 = new_fake_pdu(
-            "C", "test", "mem", "x", "B", "u3", depth=3
-        )
-        new_pdu = new_fake_pdu(
-            "D", "test", "mem", "x", "A", "u4", depth=1
-        )
-
-        self.persistence.get_power_level.side_effect = _gen_get_power_level({
-            "u1": 10,
-            "u2": 10,
-            "u3": 10,
-            "u4": 20,
-        })
-
-        # The return_value of `get_unresolved_state_tree`, which changes after
-        # the call to get_pdu
-        tree_to_return = [
-            (
-                ReturnType([new_pdu], [old_pdu_3]),
-                1,
-            ),
-            (
-                ReturnType(
-                    [new_pdu], [old_pdu_3, old_pdu_2]
-                ),
-                0,
-            ),
-            (
-                ReturnType(
-                    [new_pdu, old_pdu_1], [old_pdu_3, old_pdu_2, old_pdu_1]
-                ),
-                None
-            ),
+    def test_resolve_state_conflict(self):
+        event = self.create_event(type="test4", state_key="", name="event")
+        event.prev_events = []
+
+        old_state_1 = [
+            self.create_event(type="test1", state_key="1"),
+            self.create_event(type="test1", state_key="2"),
+            self.create_event(type="test2", state_key=""),
         ]
 
-        to_return = [0]
-
-        def return_tree(p):
-            return tree_to_return[to_return[0]]
-
-        def set_return_tree(destination, pdu_origin, pdu_id, outlier=False):
-            to_return[0] += 1
-            return defer.succeed(None)
-
-        self.persistence.get_unresolved_state_tree.side_effect = return_tree
-
-        self.replication.get_pdu.side_effect = set_return_tree
-
-        self.persistence.get_pdu.return_value = None
-
-        is_new = yield self.state.handle_new_state(new_pdu)
-
-        self.assertTrue(is_new)
-
-        self.assertEqual(2, self.replication.get_pdu.call_count)
-
-        self.replication.get_pdu.assert_has_calls(
-            [
-                mock.call(
-                    destination=old_pdu_3.origin,
-                    pdu_origin=old_pdu_2.origin,
-                    pdu_id=old_pdu_2.pdu_id,
-                    outlier=True
-                ),
-                mock.call(
-                    destination=new_pdu.origin,
-                    pdu_origin=old_pdu_1.origin,
-                    pdu_id=old_pdu_1.pdu_id,
-                    outlier=True
-                ),
-            ]
-        )
-
-        self.persistence.get_unresolved_state_tree.assert_called_with(
-            new_pdu
-        )
-
-        self.assertEquals(
-            3, self.persistence.get_unresolved_state_tree.call_count
-        )
-
-        self.assertEqual(1, self.persistence.update_current_state.call_count)
-
-    @defer.inlineCallbacks
-    def test_no_common_ancestor(self):
-        # We do a direct overwriting of the old state, i.e., the new state
-        # points to the old state.
+        old_state_2 = [
+            self.create_event(type="test1", state_key="1"),
+            self.create_event(type="test3", state_key="2"),
+            self.create_event(type="test4", state_key=""),
+        ]
 
-        old_pdu = new_fake_pdu("A", "test", "mem", "x", None, "u1")
-        new_pdu = new_fake_pdu("B", "test", "mem", "x", None, "u2")
+        group_name_1 = "group_name_1"
+        group_name_2 = "group_name_2"
 
-        self.persistence.get_power_level.side_effect = _gen_get_power_level({
-            "u1": 5,
-            "u2": 10,
-        })
+        self.store.get_state_groups.return_value = {
+            group_name_1: old_state_1,
+            group_name_2: old_state_2,
+        }
 
-        self.persistence.get_unresolved_state_tree.return_value = (
-            (ReturnType([new_pdu], [old_pdu]), None)
-        )
+        yield self.state.annotate_state_groups(event)
 
-        is_new = yield self.state.handle_new_state(new_pdu)
+        self.assertEqual(len(event.old_state_events), 5)
 
-        self.assertTrue(is_new)
+        expected_new = event.old_state_events
+        expected_new[(event.type, event.state_key)] = event
 
-        self.persistence.get_unresolved_state_tree.assert_called_once_with(
-            new_pdu
+        self.assertEqual(
+            set([e.event_id for e in expected_new.values()]),
+            set([e.event_id for e in event.state_events.values()]),
         )
 
-        self.assertEqual(1, self.persistence.update_current_state.call_count)
-
-        self.assertFalse(self.replication.get_pdu.called)
-
-    @defer.inlineCallbacks
-    def test_new_event(self):
-        event = Mock()
-        event.event_id = "12123123@test"
+        self.assertIsNone(event.state_group)
 
-        state_pdu = new_fake_pdu("C", "test", "mem", "x", "A", 20)
+    def create_event(self, name=None, type=None, state_key=None):
+        self.event_id += 1
+        event_id = str(self.event_id)
 
-        snapshot = Mock()
-        snapshot.prev_state_pdu = state_pdu
-        event_id = "pdu_id@origin.com"
+        if not name:
+            if state_key is not None:
+                name = "<%s-%s>" % (type, state_key)
+            else:
+                name = "<%s>" % (type, )
 
-        def fill_out_prev_events(event):
-            event.prev_events = [event_id]
-            event.depth = 6
-        snapshot.fill_out_prev_events = fill_out_prev_events
+        event = Mock(name=name, spec=[])
+        event.type = type
 
-        yield self.state.handle_new_event(event, snapshot)
-
-        self.assertLess(5, event.depth)
-
-        self.assertEquals(1, len(event.prev_events))
-
-        prev_id = event.prev_events[0]
-
-        self.assertEqual(event_id, prev_id)
-
-        self.assertEqual(
-            encode_event_id(state_pdu.pdu_id, state_pdu.origin),
-            event.prev_state
-        )
+        if state_key is not None:
+            event.state_key = state_key
+        event.event_id = event_id
 
+        event.user_id = "@user_id:example.com"
+        event.room_id = "!room_id:example.com"
 
-def new_fake_pdu(pdu_id, context, pdu_type, state_key, prev_state_id,
-                 user_id, depth=0):
-    new_pdu = Pdu(
-        pdu_id=pdu_id,
-        pdu_type=pdu_type,
-        state_key=state_key,
-        user_id=user_id,
-        prev_state_id=prev_state_id,
-        origin="example.com",
-        context="context",
-        origin_server_ts=1405353060021,
-        depth=depth,
-        content_json="{}",
-        unrecognized_keys="{}",
-        outlier=True,
-        is_state=True,
-        prev_state_origin="example.com",
-        have_processed=True,
-        content={},
-    )
-
-    return new_pdu
+        return event
diff --git a/tests/utils.py b/tests/utils.py
index 60fd6085ac..d8be73dba8 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -118,13 +118,14 @@ class MockHttpResource(HttpServer):
 class MockKey(object):
     alg = "mock_alg"
     version = "mock_version"
+    signature = b"\x9a\x87$"
 
     @property
     def verify_key(self):
         return self
 
     def sign(self, message):
-        return b"\x9a\x87$"
+        return self
 
     def verify(self, message, sig):
         assert sig == b"\x9a\x87$"