summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/events/snapshot.py1
-rw-r--r--synapse/federation/replication.py291
-rw-r--r--synapse/federation/transaction_queue.py314
-rw-r--r--synapse/state.py127
-rw-r--r--synapse/storage/__init__.py40
-rw-r--r--synapse/storage/_base.py21
-rw-r--r--synapse/storage/rejections.py33
-rw-r--r--synapse/storage/schema/delta/v12.sql21
-rw-r--r--synapse/storage/schema/im.sql7
-rw-r--r--tests/test_state.py428
10 files changed, 866 insertions, 417 deletions
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index 6bbba8d6ba..7e98bdef28 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -20,3 +20,4 @@ class EventContext(object):
         self.current_state = current_state
         self.auth_events = auth_events
         self.state_group = None
+        self.rejected = False
diff --git a/synapse/federation/replication.py b/synapse/federation/replication.py
index 6620532a60..accf95e406 100644
--- a/synapse/federation/replication.py
+++ b/synapse/federation/replication.py
@@ -20,8 +20,8 @@ a given transport.
 from twisted.internet import defer
 
 from .units import Transaction, Edu
-
 from .persistence import TransactionActions
+from .transaction_queue import TransactionQueue
 
 from synapse.util.logutils import log_function
 from synapse.util.logcontext import PreserveLoggingContext
@@ -62,7 +62,7 @@ class ReplicationLayer(object):
         # self.pdu_actions = PduActions(self.store)
         self.transaction_actions = TransactionActions(self.store)
 
-        self._transaction_queue = _TransactionQueue(
+        self._transaction_queue = TransactionQueue(
             hs, self.transaction_actions, transport_layer
         )
 
@@ -662,290 +662,3 @@ class ReplicationLayer(object):
         event.internal_metadata.outlier = outlier
 
         return event
-
-
-class _TransactionQueue(object):
-    """This class makes sure we only have one transaction in flight at
-    a time for a given destination.
-
-    It batches pending PDUs into single transactions.
-    """
-
-    def __init__(self, hs, transaction_actions, transport_layer):
-        self.server_name = hs.hostname
-        self.transaction_actions = transaction_actions
-        self.transport_layer = transport_layer
-
-        self._clock = hs.get_clock()
-        self.store = hs.get_datastore()
-
-        # Is a mapping from destinations -> deferreds. Used to keep track
-        # of which destinations have transactions in flight and when they are
-        # done
-        self.pending_transactions = {}
-
-        # Is a mapping from destination -> list of
-        # tuple(pending pdus, deferred, order)
-        self.pending_pdus_by_dest = {}
-        # destination -> list of tuple(edu, deferred)
-        self.pending_edus_by_dest = {}
-
-        # destination -> list of tuple(failure, deferred)
-        self.pending_failures_by_dest = {}
-
-        # HACK to get unique tx id
-        self._next_txn_id = int(self._clock.time_msec())
-
-    @defer.inlineCallbacks
-    @log_function
-    def enqueue_pdu(self, pdu, destinations, order):
-        # We loop through all destinations to see whether we already have
-        # a transaction in progress. If we do, stick it in the pending_pdus
-        # table and we'll get back to it later.
-
-        destinations = set(destinations)
-        destinations.discard(self.server_name)
-        destinations.discard("localhost")
-
-        logger.debug("Sending to: %s", str(destinations))
-
-        if not destinations:
-            return
-
-        deferreds = []
-
-        for destination in destinations:
-            deferred = defer.Deferred()
-            self.pending_pdus_by_dest.setdefault(destination, []).append(
-                (pdu, deferred, order)
-            )
-
-            def eb(failure):
-                if not deferred.called:
-                    deferred.errback(failure)
-                else:
-                    logger.warn("Failed to send pdu", failure)
-
-            with PreserveLoggingContext():
-                self._attempt_new_transaction(destination).addErrback(eb)
-
-            deferreds.append(deferred)
-
-        yield defer.DeferredList(deferreds)
-
-    # NO inlineCallbacks
-    def enqueue_edu(self, edu):
-        destination = edu.destination
-
-        if destination == self.server_name:
-            return
-
-        deferred = defer.Deferred()
-        self.pending_edus_by_dest.setdefault(destination, []).append(
-            (edu, deferred)
-        )
-
-        def eb(failure):
-            if not deferred.called:
-                deferred.errback(failure)
-            else:
-                logger.warn("Failed to send edu", failure)
-
-        with PreserveLoggingContext():
-            self._attempt_new_transaction(destination).addErrback(eb)
-
-        return deferred
-
-    @defer.inlineCallbacks
-    def enqueue_failure(self, failure, destination):
-        deferred = defer.Deferred()
-
-        self.pending_failures_by_dest.setdefault(
-            destination, []
-        ).append(
-            (failure, deferred)
-        )
-
-        yield deferred
-
-    @defer.inlineCallbacks
-    @log_function
-    def _attempt_new_transaction(self, destination):
-
-        (retry_last_ts, retry_interval) = (0, 0)
-        retry_timings = yield self.store.get_destination_retry_timings(
-            destination
-        )
-        if retry_timings:
-            (retry_last_ts, retry_interval) = (
-                retry_timings.retry_last_ts, retry_timings.retry_interval
-            )
-            if retry_last_ts + retry_interval > int(self._clock.time_msec()):
-                logger.info(
-                    "TX [%s] not ready for retry yet - "
-                    "dropping transaction for now",
-                    destination,
-                )
-                return
-            else:
-                logger.info("TX [%s] is ready for retry", destination)
-
-        logger.info("TX [%s] _attempt_new_transaction", destination)
-
-        if destination in self.pending_transactions:
-            # XXX: pending_transactions can get stuck on by a never-ending
-            # request at which point pending_pdus_by_dest just keeps growing.
-            # we need application-layer timeouts of some flavour of these
-            # requests
-            return
-
-        # list of (pending_pdu, deferred, order)
-        pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
-        pending_edus = self.pending_edus_by_dest.pop(destination, [])
-        pending_failures = self.pending_failures_by_dest.pop(destination, [])
-
-        if pending_pdus:
-            logger.info("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
-                        destination, len(pending_pdus))
-
-        if not pending_pdus and not pending_edus and not pending_failures:
-            return
-
-        logger.debug(
-            "TX [%s] Attempting new transaction"
-            " (pdus: %d, edus: %d, failures: %d)",
-            destination,
-            len(pending_pdus),
-            len(pending_edus),
-            len(pending_failures)
-        )
-
-        # Sort based on the order field
-        pending_pdus.sort(key=lambda t: t[2])
-
-        pdus = [x[0] for x in pending_pdus]
-        edus = [x[0] for x in pending_edus]
-        failures = [x[0].get_dict() for x in pending_failures]
-        deferreds = [
-            x[1]
-            for x in pending_pdus + pending_edus + pending_failures
-        ]
-
-        try:
-            self.pending_transactions[destination] = 1
-
-            logger.debug("TX [%s] Persisting transaction...", destination)
-
-            transaction = Transaction.create_new(
-                origin_server_ts=int(self._clock.time_msec()),
-                transaction_id=str(self._next_txn_id),
-                origin=self.server_name,
-                destination=destination,
-                pdus=pdus,
-                edus=edus,
-                pdu_failures=failures,
-            )
-
-            self._next_txn_id += 1
-
-            yield self.transaction_actions.prepare_to_send(transaction)
-
-            logger.debug("TX [%s] Persisted transaction", destination)
-            logger.info(
-                "TX [%s] Sending transaction [%s]",
-                destination,
-                transaction.transaction_id,
-            )
-
-            # Actually send the transaction
-
-            # FIXME (erikj): This is a bit of a hack to make the Pdu age
-            # keys work
-            def json_data_cb():
-                data = transaction.get_dict()
-                now = int(self._clock.time_msec())
-                if "pdus" in data:
-                    for p in data["pdus"]:
-                        if "age_ts" in p:
-                            unsigned = p.setdefault("unsigned", {})
-                            unsigned["age"] = now - int(p["age_ts"])
-                            del p["age_ts"]
-                return data
-
-            code, response = yield self.transport_layer.send_transaction(
-                transaction, json_data_cb
-            )
-
-            logger.info("TX [%s] got %d response", destination, code)
-
-            logger.debug("TX [%s] Sent transaction", destination)
-            logger.debug("TX [%s] Marking as delivered...", destination)
-
-            yield self.transaction_actions.delivered(
-                transaction, code, response
-            )
-
-            logger.debug("TX [%s] Marked as delivered", destination)
-            logger.debug("TX [%s] Yielding to callbacks...", destination)
-
-            for deferred in deferreds:
-                if code == 200:
-                    if retry_last_ts:
-                        # this host is alive! reset retry schedule
-                        yield self.store.set_destination_retry_timings(
-                            destination, 0, 0
-                        )
-                    deferred.callback(None)
-                else:
-                    self.set_retrying(destination, retry_interval)
-                    deferred.errback(RuntimeError("Got status %d" % code))
-
-                # Ensures we don't continue until all callbacks on that
-                # deferred have fired
-                try:
-                    yield deferred
-                except:
-                    pass
-
-            logger.debug("TX [%s] Yielded to callbacks", destination)
-
-        except Exception as e:
-            # We capture this here as there as nothing actually listens
-            # for this finishing functions deferred.
-            logger.warn(
-                "TX [%s] Problem in _attempt_transaction: %s",
-                destination,
-                e,
-            )
-
-            self.set_retrying(destination, retry_interval)
-
-            for deferred in deferreds:
-                if not deferred.called:
-                    deferred.errback(e)
-
-        finally:
-            # We want to be *very* sure we delete this after we stop processing
-            self.pending_transactions.pop(destination, None)
-
-            # Check to see if there is anything else to send.
-            self._attempt_new_transaction(destination)
-
-    @defer.inlineCallbacks
-    def set_retrying(self, destination, retry_interval):
-        # track that this destination is having problems and we should
-        # give it a chance to recover before trying it again
-
-        if retry_interval:
-            retry_interval *= 2
-            # plateau at hourly retries for now
-            if retry_interval >= 60 * 60 * 1000:
-                retry_interval = 60 * 60 * 1000
-        else:
-            retry_interval = 2000  # try again at first after 2 seconds
-
-        yield self.store.set_destination_retry_timings(
-            destination,
-            int(self._clock.time_msec()),
-            retry_interval
-        )
diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py
new file mode 100644
index 0000000000..c2cb4a1c49
--- /dev/null
+++ b/synapse/federation/transaction_queue.py
@@ -0,0 +1,314 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014, 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from twisted.internet import defer
+
+from .units import Transaction
+
+from synapse.util.logutils import log_function
+from synapse.util.logcontext import PreserveLoggingContext
+
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+
+class TransactionQueue(object):
+    """This class makes sure we only have one transaction in flight at
+    a time for a given destination.
+
+    It batches pending PDUs into single transactions.
+    """
+
+    def __init__(self, hs, transaction_actions, transport_layer):
+        self.server_name = hs.hostname
+        self.transaction_actions = transaction_actions
+        self.transport_layer = transport_layer
+
+        self._clock = hs.get_clock()
+        self.store = hs.get_datastore()
+
+        # Is a mapping from destinations -> deferreds. Used to keep track
+        # of which destinations have transactions in flight and when they are
+        # done
+        self.pending_transactions = {}
+
+        # Is a mapping from destination -> list of
+        # tuple(pending pdus, deferred, order)
+        self.pending_pdus_by_dest = {}
+        # destination -> list of tuple(edu, deferred)
+        self.pending_edus_by_dest = {}
+
+        # destination -> list of tuple(failure, deferred)
+        self.pending_failures_by_dest = {}
+
+        # HACK to get unique tx id
+        self._next_txn_id = int(self._clock.time_msec())
+
+    @defer.inlineCallbacks
+    @log_function
+    def enqueue_pdu(self, pdu, destinations, order):
+        # We loop through all destinations to see whether we already have
+        # a transaction in progress. If we do, stick it in the pending_pdus
+        # table and we'll get back to it later.
+
+        destinations = set(destinations)
+        destinations.discard(self.server_name)
+        destinations.discard("localhost")
+
+        logger.debug("Sending to: %s", str(destinations))
+
+        if not destinations:
+            return
+
+        deferreds = []
+
+        for destination in destinations:
+            deferred = defer.Deferred()
+            self.pending_pdus_by_dest.setdefault(destination, []).append(
+                (pdu, deferred, order)
+            )
+
+            def eb(failure):
+                if not deferred.called:
+                    deferred.errback(failure)
+                else:
+                    logger.warn("Failed to send pdu", failure)
+
+            with PreserveLoggingContext():
+                self._attempt_new_transaction(destination).addErrback(eb)
+
+            deferreds.append(deferred)
+
+        yield defer.DeferredList(deferreds)
+
+    # NO inlineCallbacks
+    def enqueue_edu(self, edu):
+        destination = edu.destination
+
+        if destination == self.server_name:
+            return
+
+        deferred = defer.Deferred()
+        self.pending_edus_by_dest.setdefault(destination, []).append(
+            (edu, deferred)
+        )
+
+        def eb(failure):
+            if not deferred.called:
+                deferred.errback(failure)
+            else:
+                logger.warn("Failed to send edu", failure)
+
+        with PreserveLoggingContext():
+            self._attempt_new_transaction(destination).addErrback(eb)
+
+        return deferred
+
+    @defer.inlineCallbacks
+    def enqueue_failure(self, failure, destination):
+        deferred = defer.Deferred()
+
+        self.pending_failures_by_dest.setdefault(
+            destination, []
+        ).append(
+            (failure, deferred)
+        )
+
+        yield deferred
+
+    @defer.inlineCallbacks
+    @log_function
+    def _attempt_new_transaction(self, destination):
+
+        (retry_last_ts, retry_interval) = (0, 0)
+        retry_timings = yield self.store.get_destination_retry_timings(
+            destination
+        )
+        if retry_timings:
+            (retry_last_ts, retry_interval) = (
+                retry_timings.retry_last_ts, retry_timings.retry_interval
+            )
+            if retry_last_ts + retry_interval > int(self._clock.time_msec()):
+                logger.info(
+                    "TX [%s] not ready for retry yet - "
+                    "dropping transaction for now",
+                    destination,
+                )
+                return
+            else:
+                logger.info("TX [%s] is ready for retry", destination)
+
+        logger.info("TX [%s] _attempt_new_transaction", destination)
+
+        if destination in self.pending_transactions:
+            # XXX: pending_transactions can get stuck on by a never-ending
+            # request at which point pending_pdus_by_dest just keeps growing.
+            # we need application-layer timeouts of some flavour of these
+            # requests
+            return
+
+        # list of (pending_pdu, deferred, order)
+        pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
+        pending_edus = self.pending_edus_by_dest.pop(destination, [])
+        pending_failures = self.pending_failures_by_dest.pop(destination, [])
+
+        if pending_pdus:
+            logger.info("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
+                        destination, len(pending_pdus))
+
+        if not pending_pdus and not pending_edus and not pending_failures:
+            return
+
+        logger.debug(
+            "TX [%s] Attempting new transaction"
+            " (pdus: %d, edus: %d, failures: %d)",
+            destination,
+            len(pending_pdus),
+            len(pending_edus),
+            len(pending_failures)
+        )
+
+        # Sort based on the order field
+        pending_pdus.sort(key=lambda t: t[2])
+
+        pdus = [x[0] for x in pending_pdus]
+        edus = [x[0] for x in pending_edus]
+        failures = [x[0].get_dict() for x in pending_failures]
+        deferreds = [
+            x[1]
+            for x in pending_pdus + pending_edus + pending_failures
+        ]
+
+        try:
+            self.pending_transactions[destination] = 1
+
+            logger.debug("TX [%s] Persisting transaction...", destination)
+
+            transaction = Transaction.create_new(
+                origin_server_ts=int(self._clock.time_msec()),
+                transaction_id=str(self._next_txn_id),
+                origin=self.server_name,
+                destination=destination,
+                pdus=pdus,
+                edus=edus,
+                pdu_failures=failures,
+            )
+
+            self._next_txn_id += 1
+
+            yield self.transaction_actions.prepare_to_send(transaction)
+
+            logger.debug("TX [%s] Persisted transaction", destination)
+            logger.info(
+                "TX [%s] Sending transaction [%s]",
+                destination,
+                transaction.transaction_id,
+            )
+
+            # Actually send the transaction
+
+            # FIXME (erikj): This is a bit of a hack to make the Pdu age
+            # keys work
+            def json_data_cb():
+                data = transaction.get_dict()
+                now = int(self._clock.time_msec())
+                if "pdus" in data:
+                    for p in data["pdus"]:
+                        if "age_ts" in p:
+                            unsigned = p.setdefault("unsigned", {})
+                            unsigned["age"] = now - int(p["age_ts"])
+                            del p["age_ts"]
+                return data
+
+            code, response = yield self.transport_layer.send_transaction(
+                transaction, json_data_cb
+            )
+
+            logger.info("TX [%s] got %d response", destination, code)
+
+            logger.debug("TX [%s] Sent transaction", destination)
+            logger.debug("TX [%s] Marking as delivered...", destination)
+
+            yield self.transaction_actions.delivered(
+                transaction, code, response
+            )
+
+            logger.debug("TX [%s] Marked as delivered", destination)
+            logger.debug("TX [%s] Yielding to callbacks...", destination)
+
+            for deferred in deferreds:
+                if code == 200:
+                    if retry_last_ts:
+                        # this host is alive! reset retry schedule
+                        yield self.store.set_destination_retry_timings(
+                            destination, 0, 0
+                        )
+                    deferred.callback(None)
+                else:
+                    self.set_retrying(destination, retry_interval)
+                    deferred.errback(RuntimeError("Got status %d" % code))
+
+                # Ensures we don't continue until all callbacks on that
+                # deferred have fired
+                try:
+                    yield deferred
+                except:
+                    pass
+
+            logger.debug("TX [%s] Yielded to callbacks", destination)
+
+        except Exception as e:
+            # We capture this here as there as nothing actually listens
+            # for this finishing functions deferred.
+            logger.warn(
+                "TX [%s] Problem in _attempt_transaction: %s",
+                destination,
+                e,
+            )
+
+            self.set_retrying(destination, retry_interval)
+
+            for deferred in deferreds:
+                if not deferred.called:
+                    deferred.errback(e)
+
+        finally:
+            # We want to be *very* sure we delete this after we stop processing
+            self.pending_transactions.pop(destination, None)
+
+            # Check to see if there is anything else to send.
+            self._attempt_new_transaction(destination)
+
+    @defer.inlineCallbacks
+    def set_retrying(self, destination, retry_interval):
+        # track that this destination is having problems and we should
+        # give it a chance to recover before trying it again
+
+        if retry_interval:
+            retry_interval *= 2
+            # plateau at hourly retries for now
+            if retry_interval >= 60 * 60 * 1000:
+                retry_interval = 60 * 60 * 1000
+        else:
+            retry_interval = 2000  # try again at first after 2 seconds
+
+        yield self.store.set_destination_retry_timings(
+            destination,
+            int(self._clock.time_msec()),
+            retry_interval
+        )
diff --git a/synapse/state.py b/synapse/state.py
index 8144fa02b4..d9fdfb34be 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -19,6 +19,7 @@ from twisted.internet import defer
 from synapse.util.logutils import log_function
 from synapse.util.async import run_on_reactor
 from synapse.api.constants import EventTypes
+from synapse.api.errors import AuthError
 from synapse.events.snapshot import EventContext
 
 from collections import namedtuple
@@ -42,6 +43,8 @@ class StateHandler(object):
 
     def __init__(self, hs):
         self.store = hs.get_datastore()
+        # self.auth = hs.get_auth()
+        self.hs = hs
 
     @defer.inlineCallbacks
     def get_current_state(self, room_id, event_type=None, state_key=""):
@@ -210,64 +213,96 @@ class StateHandler(object):
         else:
             prev_states = []
 
+        auth_events = {
+            k: e for k, e in unconflicted_state.items()
+            if k[0] in (EventTypes.Create, EventTypes.Member, EventTypes.PowerLevels,)
+        }
+
         try:
-            new_state = {}
-            new_state.update(unconflicted_state)
-            for key, events in conflicted_state.items():
-                new_state[key] = self._resolve_state_events(events)
+            resolved_state = self._resolve_state_events(
+                conflicted_state, auth_events
+            )
         except:
             logger.exception("Failed to resolve state")
             raise
 
+        new_state = unconflicted_state
+        new_state.update(resolved_state)
+
         defer.returnValue((None, new_state, prev_states))
 
-    def _get_power_level_from_event_state(self, event, user_id):
-        if hasattr(event, "old_state_events") and event.old_state_events:
-            key = (EventTypes.PowerLevels, "", )
-            power_level_event = event.old_state_events.get(key)
-            level = None
-            if power_level_event:
-                level = power_level_event.content.get("users", {}).get(
-                    user_id
+    @log_function
+    def _resolve_state_events(self, conflicted_state, auth_events):
+        """ This is where we actually decide which of the conflicted state to
+        use.
+
+        We resolve conflicts in the following order:
+            1. power levels
+            2. memberships
+            3. other events.
+
+        :param conflicted_state:
+        :param auth_events:
+        :return:
+        """
+        resolved_state = {}
+        power_key = (EventTypes.PowerLevels, "")
+        if power_key in conflicted_state.items():
+            power_levels = conflicted_state[power_key]
+            resolved_state[power_key] = self._resolve_auth_events(power_levels)
+
+        auth_events.update(resolved_state)
+
+        for key, events in conflicted_state.items():
+            if key[0] == EventTypes.Member:
+                resolved_state[key] = self._resolve_auth_events(
+                    events,
+                    auth_events
                 )
-                if not level:
-                    level = power_level_event.content.get("users_default", 0)
 
-            return level
-        else:
-            return 0
+        auth_events.update(resolved_state)
 
-    @log_function
-    def _resolve_state_events(self, events):
-        curr_events = events
+        for key, events in conflicted_state.items():
+            if key not in resolved_state:
+                resolved_state[key] = self._resolve_normal_events(
+                    events, auth_events
+                )
 
-        new_powers = [
-            self._get_power_level_from_event_state(e, e.user_id)
-            for e in curr_events
-        ]
+        return resolved_state
 
-        new_powers = [
-            int(p) if p else 0 for p in new_powers
-        ]
+    def _resolve_auth_events(self, events, auth_events):
+        reverse = [i for i in reversed(self._ordered_events(events))]
 
-        max_power = max(new_powers)
+        auth_events = dict(auth_events)
 
-        curr_events = [
-            z[0] for z in zip(curr_events, new_powers)
-            if z[1] == max_power
-        ]
+        prev_event = reverse[0]
+        for event in reverse[1:]:
+            auth_events[(prev_event.type, prev_event.state_key)] = prev_event
+            try:
+                # FIXME: hs.get_auth() is bad style, but we need to do it to
+                # get around circular deps.
+                self.hs.get_auth().check(event, auth_events)
+                prev_event = event
+            except AuthError:
+                return prev_event
 
-        if not curr_events:
-            raise RuntimeError("Max didn't get a max?")
-        elif len(curr_events) == 1:
-            return curr_events[0]
-
-        # TODO: For now, just choose the one with the largest event_id.
-        return (
-            sorted(
-                curr_events,
-                key=lambda e: hashlib.sha1(
-                    e.event_id + e.user_id + e.room_id + e.type
-                ).hexdigest()
-            )[0]
-        )
+        return event
+
+    def _resolve_normal_events(self, events, auth_events):
+        for event in self._ordered_events(events):
+            try:
+                # FIXME: hs.get_auth() is bad style, but we need to do it to
+                # get around circular deps.
+                self.hs.get_auth().check(event, auth_events)
+                return event
+            except AuthError:
+                pass
+
+        # Oh dear.
+        return event
+
+    def _ordered_events(self, events):
+        def key_func(e):
+            return -int(e.depth), hashlib.sha1(e.event_id).hexdigest()
+
+        return sorted(events, key=key_func)
\ No newline at end of file
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 4beb951b9f..4f09909607 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -30,6 +30,7 @@ from .transactions import TransactionStore
 from .keys import KeyStore
 from .event_federation import EventFederationStore
 from .media_repository import MediaRepositoryStore
+from .rejections import RejectionsStore
 
 from .state import StateStore
 from .signatures import SignatureStore
@@ -66,7 +67,7 @@ SCHEMAS = [
 
 # Remember to update this number every time an incompatible change is made to
 # database schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 11
+SCHEMA_VERSION = 12
 
 
 class _RollbackButIsFineException(Exception):
@@ -82,6 +83,7 @@ class DataStore(RoomMemberStore, RoomStore,
                 DirectoryStore, KeyStore, StateStore, SignatureStore,
                 EventFederationStore,
                 MediaRepositoryStore,
+                RejectionsStore,
                 ):
 
     def __init__(self, hs):
@@ -224,6 +226,9 @@ class DataStore(RoomMemberStore, RoomStore,
         if not outlier:
             self._store_state_groups_txn(txn, event, context)
 
+        if context.rejected:
+            self._store_rejections_txn(txn, event.event_id, context.rejected)
+
         if current_state:
             txn.execute(
                 "DELETE FROM current_state_events WHERE room_id = ?",
@@ -262,7 +267,7 @@ class DataStore(RoomMemberStore, RoomStore,
                 or_replace=True,
             )
 
-            if is_new_state:
+            if is_new_state and not context.rejected:
                 self._simple_insert_txn(
                     txn,
                     "current_state_events",
@@ -288,7 +293,7 @@ class DataStore(RoomMemberStore, RoomStore,
                     or_ignore=True,
                 )
 
-            if not backfilled:
+            if not backfilled and not context.rejected:
                 self._simple_insert_txn(
                     txn,
                     table="state_forward_extremities",
@@ -417,6 +422,35 @@ class DataStore(RoomMemberStore, RoomStore,
             ],
         )
 
+    def have_events(self, event_ids):
+        """Given a list of event ids, check if we have already processed them.
+
+        Returns:
+            dict: Has an entry for each event id we already have seen. Maps to
+            the rejected reason string if we rejected the event, else maps to
+            None.
+        """
+        def f(txn):
+            sql = (
+                "SELECT e.event_id, reason FROM events as e "
+                "LEFT JOIN rejections as r ON e.event_id = r.event_id "
+                "WHERE event_id = ?"
+            )
+
+            res = {}
+            for event_id in event_ids:
+                txn.execute(sql, (event_id,))
+                row = txn.fetchone()
+                if row:
+                    _, rejected = row
+                    res[event_id] = rejected
+
+            return res
+
+        return self.runInteraction(
+            "have_events", f,
+        )
+
 
 def schema_path(schema):
     """ Get a filesystem path for the named database schema
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index f660fc6eaf..2075a018b2 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -458,10 +458,12 @@ class SQLBaseStore(object):
         return [e for e in events if e]
 
     def _get_event_txn(self, txn, event_id, check_redacted=True,
-                       get_prev_content=False):
+                       get_prev_content=False, allow_rejected=False):
         sql = (
-            "SELECT internal_metadata, json, r.event_id FROM event_json as e "
+            "SELECT internal_metadata, json, r.event_id, reason "
+            "FROM event_json as e "
             "LEFT JOIN redactions as r ON e.event_id = r.redacts "
+            "LEFT JOIN rejections as rej on rej.event_id = e.event_id  "
             "WHERE e.event_id = ? "
             "LIMIT 1 "
         )
@@ -473,13 +475,16 @@ class SQLBaseStore(object):
         if not res:
             return None
 
-        internal_metadata, js, redacted = res
+        internal_metadata, js, redacted, rejected_reason = res
 
-        return self._get_event_from_row_txn(
-            txn, internal_metadata, js, redacted,
-            check_redacted=check_redacted,
-            get_prev_content=get_prev_content,
-        )
+        if allow_rejected or not rejected_reason:
+            return self._get_event_from_row_txn(
+                txn, internal_metadata, js, redacted,
+                check_redacted=check_redacted,
+                get_prev_content=get_prev_content,
+            )
+        else:
+            return None
 
     def _get_event_from_row_txn(self, txn, internal_metadata, js, redacted,
                                 check_redacted=True, get_prev_content=False):
diff --git a/synapse/storage/rejections.py b/synapse/storage/rejections.py
new file mode 100644
index 0000000000..7d38b31f44
--- /dev/null
+++ b/synapse/storage/rejections.py
@@ -0,0 +1,33 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014, 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import SQLBaseStore
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class RejectionsStore(SQLBaseStore):
+    def _store_rejections_txn(self, txn, event_id, reason):
+        self._simple_insert_txn(
+            txn,
+            table="rejections",
+            values={
+                "event_id": event_id,
+                "reason": reason,
+                "last_failure": self._clock.time_msec(),
+            }
+        )
diff --git a/synapse/storage/schema/delta/v12.sql b/synapse/storage/schema/delta/v12.sql
new file mode 100644
index 0000000000..bd2a8b1bb5
--- /dev/null
+++ b/synapse/storage/schema/delta/v12.sql
@@ -0,0 +1,21 @@
+/* Copyright 2015 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE IF NOT EXISTS rejections(
+    event_id TEXT NOT NULL,
+    reason TEXT NOT NULL,
+    last_check TEXT NOT NULL,
+    CONSTRAINT ev_id UNIQUE (event_id) ON CONFLICT REPLACE
+);
diff --git a/synapse/storage/schema/im.sql b/synapse/storage/schema/im.sql
index dd00c1cd2f..bc7c6b6ed5 100644
--- a/synapse/storage/schema/im.sql
+++ b/synapse/storage/schema/im.sql
@@ -123,3 +123,10 @@ CREATE TABLE IF NOT EXISTS room_hosts(
 );
 
 CREATE INDEX IF NOT EXISTS room_hosts_room_id ON room_hosts (room_id);
+
+CREATE TABLE IF NOT EXISTS rejections(
+    event_id TEXT NOT NULL,
+    reason TEXT NOT NULL,
+    last_check TEXT NOT NULL,
+    CONSTRAINT ev_id UNIQUE (event_id) ON CONFLICT REPLACE
+);
diff --git a/tests/test_state.py b/tests/test_state.py
index 98ad9e54cd..019e794aa2 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -16,11 +16,120 @@
 from tests import unittest
 from twisted.internet import defer
 
+from synapse.events import FrozenEvent
+from synapse.api.auth import Auth
+from synapse.api.constants import EventTypes, Membership
 from synapse.state import StateHandler
 
 from mock import Mock
 
 
+_next_event_id = 1000
+
+
+def create_event(name=None, type=None, state_key=None, depth=2, event_id=None,
+                 prev_events=[], **kwargs):
+    global _next_event_id
+
+    if not event_id:
+        _next_event_id += 1
+        event_id = str(_next_event_id)
+
+    if not name:
+        if state_key is not None:
+            name = "<%s-%s, %s>" % (type, state_key, event_id,)
+        else:
+            name = "<%s, %s>" % (type, event_id,)
+
+    d = {
+        "event_id": event_id,
+        "type": type,
+        "sender": "@user_id:example.com",
+        "room_id": "!room_id:example.com",
+        "depth": depth,
+        "prev_events": prev_events,
+    }
+
+    if state_key is not None:
+        d["state_key"] = state_key
+
+    d.update(kwargs)
+
+    event = FrozenEvent(d)
+
+    return event
+
+
+class StateGroupStore(object):
+    def __init__(self):
+        self._event_to_state_group = {}
+        self._group_to_state = {}
+
+        self._next_group = 1
+
+    def get_state_groups(self, event_ids):
+        groups = {}
+        for event_id in event_ids:
+            group = self._event_to_state_group.get(event_id)
+            if group:
+                groups[group] = self._group_to_state[group]
+
+        return defer.succeed(groups)
+
+    def store_state_groups(self, event, context):
+        if context.current_state is None:
+            return
+
+        state_events = context.current_state
+
+        if event.is_state():
+            state_events[(event.type, event.state_key)] = event
+
+        state_group = context.state_group
+        if not state_group:
+            state_group = self._next_group
+            self._next_group += 1
+
+            self._group_to_state[state_group] = state_events.values()
+
+        self._event_to_state_group[event.event_id] = state_group
+
+
+class DictObj(dict):
+    def __init__(self, **kwargs):
+        super(DictObj, self).__init__(kwargs)
+        self.__dict__ = self
+
+
+class Graph(object):
+    def __init__(self, nodes, edges):
+        events = {}
+        clobbered = set(events.keys())
+
+        for event_id, fields in nodes.items():
+            refs = edges.get(event_id)
+            if refs:
+                clobbered.difference_update(refs)
+                prev_events = [(r, {}) for r in refs]
+            else:
+                prev_events = []
+
+            events[event_id] = create_event(
+                event_id=event_id,
+                prev_events=prev_events,
+                **fields
+            )
+
+        self._leaves = clobbered
+        self._events = sorted(events.values(), key=lambda e: e.depth)
+
+    def walk(self):
+        return iter(self._events)
+
+    def get_leaves(self):
+        return (self._events[i] for i in self._leaves)
+
+
 class StateTestCase(unittest.TestCase):
     def setUp(self):
         self.store = Mock(
@@ -29,20 +138,188 @@ class StateTestCase(unittest.TestCase):
                 "add_event_hashes",
             ]
         )
-        hs = Mock(spec=["get_datastore"])
+        hs = Mock(spec=["get_datastore", "get_auth", "get_state_handler"])
         hs.get_datastore.return_value = self.store
+        hs.get_state_handler.return_value = None
+        hs.get_auth.return_value = Auth(hs)
 
         self.state = StateHandler(hs)
         self.event_id = 0
 
     @defer.inlineCallbacks
+    def test_branch_no_conflict(self):
+        graph = Graph(
+            nodes={
+                "START": DictObj(
+                    type=EventTypes.Create,
+                    state_key="",
+                    depth=1,
+                ),
+                "A": DictObj(
+                    type=EventTypes.Message,
+                    depth=2,
+                ),
+                "B": DictObj(
+                    type=EventTypes.Message,
+                    depth=3,
+                ),
+                "C": DictObj(
+                    type=EventTypes.Name,
+                    state_key="",
+                    depth=3,
+                ),
+                "D": DictObj(
+                    type=EventTypes.Message,
+                    depth=4,
+                ),
+            },
+            edges={
+                "A": ["START"],
+                "B": ["A"],
+                "C": ["A"],
+                "D": ["B", "C"]
+            }
+        )
+
+        store = StateGroupStore()
+        self.store.get_state_groups.side_effect = store.get_state_groups
+
+        context_store = {}
+
+        for event in graph.walk():
+            context = yield self.state.compute_event_context(event)
+            store.store_state_groups(event, context)
+            context_store[event.event_id] = context
+
+        self.assertEqual(2, len(context_store["D"].current_state))
+
+    @defer.inlineCallbacks
+    def test_branch_basic_conflict(self):
+        graph = Graph(
+            nodes={
+                "START": DictObj(
+                    type=EventTypes.Create,
+                    state_key="creator",
+                    content={"membership": "@user_id:example.com"},
+                    depth=1,
+                ),
+                "A": DictObj(
+                    type=EventTypes.Member,
+                    state_key="@user_id:example.com",
+                    content={"membership": Membership.JOIN},
+                    membership=Membership.JOIN,
+                    depth=2,
+                ),
+                "B": DictObj(
+                    type=EventTypes.Name,
+                    state_key="",
+                    depth=3,
+                ),
+                "C": DictObj(
+                    type=EventTypes.Name,
+                    state_key="",
+                    depth=4,
+                ),
+                "D": DictObj(
+                    type=EventTypes.Message,
+                    depth=5,
+                ),
+            },
+            edges={
+                "A": ["START"],
+                "B": ["A"],
+                "C": ["A"],
+                "D": ["B", "C"]
+            }
+        )
+
+        store = StateGroupStore()
+        self.store.get_state_groups.side_effect = store.get_state_groups
+
+        context_store = {}
+
+        for event in graph.walk():
+            context = yield self.state.compute_event_context(event)
+            store.store_state_groups(event, context)
+            context_store[event.event_id] = context
+
+        self.assertSetEqual(
+            {"START", "A", "C"},
+            {e.event_id for e in context_store["D"].current_state.values()}
+        )
+
+    @defer.inlineCallbacks
+    def test_branch_have_banned_conflict(self):
+        graph = Graph(
+            nodes={
+                "START": DictObj(
+                    type=EventTypes.Create,
+                    state_key="creator",
+                    content={"membership": "@user_id:example.com"},
+                    depth=1,
+                ),
+                "A": DictObj(
+                    type=EventTypes.Member,
+                    state_key="@user_id:example.com",
+                    content={"membership": Membership.JOIN},
+                    membership=Membership.JOIN,
+                    depth=2,
+                ),
+                "B": DictObj(
+                    type=EventTypes.Name,
+                    state_key="",
+                    depth=3,
+                ),
+                "C": DictObj(
+                    type=EventTypes.Member,
+                    state_key="@user_id_2:example.com",
+                    content={"membership": Membership.BAN},
+                    membership=Membership.BAN,
+                    depth=4,
+                ),
+                "D": DictObj(
+                    type=EventTypes.Name,
+                    state_key="",
+                    depth=4,
+                    sender="@user_id_2:example.com",
+                ),
+                "E": DictObj(
+                    type=EventTypes.Message,
+                    depth=5,
+                ),
+            },
+            edges={
+                "A": ["START"],
+                "B": ["A"],
+                "C": ["B"],
+                "D": ["B"],
+                "E": ["C", "D"]
+            }
+        )
+
+        store = StateGroupStore()
+        self.store.get_state_groups.side_effect = store.get_state_groups
+
+        context_store = {}
+
+        for event in graph.walk():
+            context = yield self.state.compute_event_context(event)
+            store.store_state_groups(event, context)
+            context_store[event.event_id] = context
+
+        self.assertSetEqual(
+            {"START", "A", "B", "C"},
+            {e.event_id for e in context_store["E"].current_state.values()}
+        )
+
+    @defer.inlineCallbacks
     def test_annotate_with_old_message(self):
-        event = self.create_event(type="test_message", name="event")
+        event = create_event(type="test_message", name="event")
 
         old_state = [
-            self.create_event(type="test1", state_key="1"),
-            self.create_event(type="test1", state_key="2"),
-            self.create_event(type="test2", state_key=""),
+            create_event(type="test1", state_key="1"),
+            create_event(type="test1", state_key="2"),
+            create_event(type="test2", state_key=""),
         ]
 
         context = yield self.state.compute_event_context(
@@ -62,12 +339,12 @@ class StateTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_annotate_with_old_state(self):
-        event = self.create_event(type="state", state_key="", name="event")
+        event = create_event(type="state", state_key="", name="event")
 
         old_state = [
-            self.create_event(type="test1", state_key="1"),
-            self.create_event(type="test1", state_key="2"),
-            self.create_event(type="test2", state_key=""),
+            create_event(type="test1", state_key="1"),
+            create_event(type="test1", state_key="2"),
+            create_event(type="test2", state_key=""),
         ]
 
         context = yield self.state.compute_event_context(
@@ -88,13 +365,12 @@ class StateTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_trivial_annotate_message(self):
-        event = self.create_event(type="test_message", name="event")
-        event.prev_events = []
+        event = create_event(type="test_message", name="event")
 
         old_state = [
-            self.create_event(type="test1", state_key="1"),
-            self.create_event(type="test1", state_key="2"),
-            self.create_event(type="test2", state_key=""),
+            create_event(type="test1", state_key="1"),
+            create_event(type="test1", state_key="2"),
+            create_event(type="test2", state_key=""),
         ]
 
         group_name = "group_name_1"
@@ -119,13 +395,12 @@ class StateTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_trivial_annotate_state(self):
-        event = self.create_event(type="state", state_key="", name="event")
-        event.prev_events = []
+        event = create_event(type="state", state_key="", name="event")
 
         old_state = [
-            self.create_event(type="test1", state_key="1"),
-            self.create_event(type="test1", state_key="2"),
-            self.create_event(type="test2", state_key=""),
+            create_event(type="test1", state_key="1"),
+            create_event(type="test1", state_key="2"),
+            create_event(type="test2", state_key=""),
         ]
 
         group_name = "group_name_1"
@@ -150,30 +425,21 @@ class StateTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_resolve_message_conflict(self):
-        event = self.create_event(type="test_message", name="event")
-        event.prev_events = []
+        event = create_event(type="test_message", name="event")
 
         old_state_1 = [
-            self.create_event(type="test1", state_key="1"),
-            self.create_event(type="test1", state_key="2"),
-            self.create_event(type="test2", state_key=""),
+            create_event(type="test1", state_key="1"),
+            create_event(type="test1", state_key="2"),
+            create_event(type="test2", state_key=""),
         ]
 
         old_state_2 = [
-            self.create_event(type="test1", state_key="1"),
-            self.create_event(type="test3", state_key="2"),
-            self.create_event(type="test4", state_key=""),
+            create_event(type="test1", state_key="1"),
+            create_event(type="test3", state_key="2"),
+            create_event(type="test4", state_key=""),
         ]
 
-        group_name_1 = "group_name_1"
-        group_name_2 = "group_name_2"
-
-        self.store.get_state_groups.return_value = {
-            group_name_1: old_state_1,
-            group_name_2: old_state_2,
-        }
-
-        context = yield self.state.compute_event_context(event)
+        context = yield self._get_context(event, old_state_1, old_state_2)
 
         self.assertEqual(len(context.current_state), 5)
 
@@ -181,56 +447,76 @@ class StateTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_resolve_state_conflict(self):
-        event = self.create_event(type="test4", state_key="", name="event")
-        event.prev_events = []
+        event = create_event(type="test4", state_key="", name="event")
 
         old_state_1 = [
-            self.create_event(type="test1", state_key="1"),
-            self.create_event(type="test1", state_key="2"),
-            self.create_event(type="test2", state_key=""),
+            create_event(type="test1", state_key="1"),
+            create_event(type="test1", state_key="2"),
+            create_event(type="test2", state_key=""),
         ]
 
         old_state_2 = [
-            self.create_event(type="test1", state_key="1"),
-            self.create_event(type="test3", state_key="2"),
-            self.create_event(type="test4", state_key=""),
+            create_event(type="test1", state_key="1"),
+            create_event(type="test3", state_key="2"),
+            create_event(type="test4", state_key=""),
         ]
 
-        group_name_1 = "group_name_1"
-        group_name_2 = "group_name_2"
-
-        self.store.get_state_groups.return_value = {
-            group_name_1: old_state_1,
-            group_name_2: old_state_2,
-        }
-
-        context = yield self.state.compute_event_context(event)
+        context = yield self._get_context(event, old_state_1, old_state_2)
 
         self.assertEqual(len(context.current_state), 5)
 
         self.assertIsNone(context.state_group)
 
-    def create_event(self, name=None, type=None, state_key=None):
-        self.event_id += 1
-        event_id = str(self.event_id)
+    @defer.inlineCallbacks
+    def test_standard_depth_conflict(self):
+        event = create_event(type="test4", name="event")
+
+        member_event = create_event(
+            type=EventTypes.Member,
+            state_key="@user_id:example.com",
+            content={
+                "membership": Membership.JOIN,
+            }
+        )
 
-        if not name:
-            if state_key is not None:
-                name = "<%s-%s>" % (type, state_key)
-            else:
-                name = "<%s>" % (type, )
+        old_state_1 = [
+            member_event,
+            create_event(type="test1", state_key="1", depth=1),
+        ]
+
+        old_state_2 = [
+            member_event,
+            create_event(type="test1", state_key="1", depth=2),
+        ]
 
-        event = Mock(name=name, spec=[])
-        event.type = type
+        context = yield self._get_context(event, old_state_1, old_state_2)
 
-        if state_key is not None:
-            event.state_key = state_key
-        event.event_id = event_id
+        self.assertEqual(old_state_2[1], context.current_state[("test1", "1")])
+
+        # Reverse the depth to make sure we are actually using the depths
+        # during state resolution.
+
+        old_state_1 = [
+            member_event,
+            create_event(type="test1", state_key="1", depth=2),
+        ]
+
+        old_state_2 = [
+            member_event,
+            create_event(type="test1", state_key="1", depth=1),
+        ]
+
+        context = yield self._get_context(event, old_state_1, old_state_2)
+
+        self.assertEqual(old_state_1[1], context.current_state[("test1", "1")])
 
-        event.is_state = lambda: (state_key is not None)
-        event.unsigned = {}
+    def _get_context(self, event, old_state_1, old_state_2):
+        group_name_1 = "group_name_1"
+        group_name_2 = "group_name_2"
 
-        event.user_id = "@user_id:example.com"
-        event.room_id = "!room_id:example.com"
+        self.store.get_state_groups.return_value = {
+            group_name_1: old_state_1,
+            group_name_2: old_state_2,
+        }
 
-        return event
+        return self.state.compute_event_context(event)