summary refs log tree commit diff
diff options
context:
space:
mode:
-rwxr-xr-xsetup.py2
-rw-r--r--synapse/push/__init__.py97
-rw-r--r--synapse/push/baserules.py49
-rw-r--r--synapse/push/httppusher.py11
-rw-r--r--synapse/python_dependencies.py2
-rw-r--r--synapse/rest/client/v1/push_rule.py10
-rw-r--r--synapse/rest/client/v1/pusher.py11
-rw-r--r--synapse/state.py134
-rw-r--r--synapse/storage/__init__.py5
-rw-r--r--tests/test_state.py428
10 files changed, 606 insertions, 143 deletions
diff --git a/setup.py b/setup.py
index 043cd044a7..3249e87a96 100755
--- a/setup.py
+++ b/setup.py
@@ -33,7 +33,7 @@ setup(
     install_requires=[
         "syutil==0.0.2",
         "matrix_angular_sdk==0.6.0",
-        "Twisted>=14.0.0",
+        "Twisted==14.0.2",
         "service_identity>=1.0.0",
         "pyopenssl>=0.14",
         "pyyaml",
diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py
index 472ede5480..cc05278c8c 100644
--- a/synapse/push/__init__.py
+++ b/synapse/push/__init__.py
@@ -16,13 +16,15 @@
 from twisted.internet import defer
 
 from synapse.streams.config import PaginationConfig
-from synapse.types import StreamToken
+from synapse.types import StreamToken, UserID
 
 import synapse.util.async
+import baserules
 
 import logging
 import fnmatch
 import json
+import re
 
 logger = logging.getLogger(__name__)
 
@@ -33,6 +35,8 @@ class Pusher(object):
     GIVE_UP_AFTER = 24 * 60 * 60 * 1000
     DEFAULT_ACTIONS = ['notify']
 
+    INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$")
+
     def __init__(self, _hs, instance_handle, user_name, app_id,
                  app_display_name, device_display_name, pushkey, pushkey_ts,
                  data, last_token, last_success, failing_since):
@@ -76,13 +80,44 @@ class Pusher(object):
         rules = yield self.store.get_push_rules_for_user_name(self.user_name)
 
         for r in rules:
+            r['conditions'] = json.loads(r['conditions'])
+            r['actions'] = json.loads(r['actions'])
+
+        user_name_localpart = UserID.from_string(self.user_name).localpart
+
+        rules.extend(baserules.make_base_rules(user_name_localpart))
+
+        # get *our* member event for display name matching
+        member_events_for_room = yield self.store.get_current_state(
+            room_id=ev['room_id'],
+            event_type='m.room.member',
+            state_key=None
+        )
+        my_display_name = None
+        room_member_count = 0
+        for mev in member_events_for_room:
+            if mev.content['membership'] != 'join':
+                continue
+
+            # This loop does two things:
+            # 1) Find our current display name
+            if mev.state_key == self.user_name:
+                my_display_name = mev.content['displayname']
+
+            # and 2) Get the number of people in that room
+            room_member_count += 1
+
+        for r in rules:
             matches = True
 
-            conditions = json.loads(r['conditions'])
-            actions = json.loads(r['actions'])
+            conditions = r['conditions']
+            actions = r['actions']
 
             for c in conditions:
-                matches &= self._event_fulfills_condition(ev, c)
+                matches &= self._event_fulfills_condition(
+                    ev, c, display_name=my_display_name,
+                    room_member_count=room_member_count
+                )
             # ignore rules with no actions (we have an explict 'dont_notify'
             if len(actions) == 0:
                 logger.warn(
@@ -95,7 +130,7 @@ class Pusher(object):
 
         defer.returnValue(Pusher.DEFAULT_ACTIONS)
 
-    def _event_fulfills_condition(self, ev, condition):
+    def _event_fulfills_condition(self, ev, condition, display_name, room_member_count):
         if condition['kind'] == 'event_match':
             if 'pattern' not in condition:
                 logger.warn("event_match condition with no pattern")
@@ -103,13 +138,49 @@ class Pusher(object):
             pat = condition['pattern']
 
             val = _value_for_dotted_key(condition['key'], ev)
-            if fnmatch.fnmatch(val, pat):
-                return True
-            return False
+            if val is None:
+                return False
+            return fnmatch.fnmatch(val.upper(), pat.upper())
         elif condition['kind'] == 'device':
             if 'instance_handle' not in condition:
                 return True
             return condition['instance_handle'] == self.instance_handle
+        elif condition['kind'] == 'contains_display_name':
+            # This is special because display names can be different
+            # between rooms and so you can't really hard code it in a rule.
+            # Optimisation: we should cache these names and update them from
+            # the event stream.
+            if 'content' not in ev or 'body' not in ev['content']:
+                return False
+            if not display_name:
+                return False
+            return fnmatch.fnmatch(
+                ev['content']['body'].upper(), "*%s*" % (display_name.upper(),)
+            )
+        elif condition['kind'] == 'room_member_count':
+            if 'is' not in condition:
+                return False
+            m = Pusher.INEQUALITY_EXPR.match(condition['is'])
+            if not m:
+                return False
+            ineq = m.group(1)
+            rhs = m.group(2)
+            if not rhs.isdigit():
+                return False
+            rhs = int(rhs)
+
+            if ineq == '' or ineq == '==':
+                return room_member_count == rhs
+            elif ineq == '<':
+                return room_member_count < rhs
+            elif ineq == '>':
+                return room_member_count > rhs
+            elif ineq == '>=':
+                return room_member_count >= rhs
+            elif ineq == '<=':
+                return room_member_count <= rhs
+            else:
+                return False
         else:
             return True
 
@@ -123,6 +194,16 @@ class Pusher(object):
         if name_aliases[0] is not None:
             ctx['name'] = name_aliases[0]
 
+        their_member_events_for_room = yield self.store.get_current_state(
+            room_id=ev['room_id'],
+            event_type='m.room.member',
+            state_key=ev['user_id']
+        )
+        if len(their_member_events_for_room) > 0:
+            dn = their_member_events_for_room[0].content['displayname']
+            if dn is not None:
+                ctx['sender_display_name'] = dn
+
         defer.returnValue(ctx)
 
     @defer.inlineCallbacks
diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py
new file mode 100644
index 0000000000..bd162baade
--- /dev/null
+++ b/synapse/push/baserules.py
@@ -0,0 +1,49 @@
+def make_base_rules(user_name):
+    """
+    Nominally we reserve priority class 0 for these rules, although
+    in practice we just append them to the end so we don't actually need it.
+    """
+    return [
+        {
+            'conditions': [
+                {
+                    'kind': 'event_match',
+                    'key': 'content.body',
+                    'pattern': '*%s*' % (user_name,), # Matrix ID match
+                }
+            ],
+            'actions': [
+                'notify',
+                {
+                    'set_sound': 'default'
+                }
+            ]
+        },
+        {
+            'conditions': [
+                {
+                    'kind': 'contains_display_name'
+                }
+            ],
+            'actions': [
+                'notify',
+                {
+                    'set_sound': 'default'
+                }
+            ]
+        },
+        {
+            'conditions': [
+                {
+                    'kind': 'room_member_count',
+                    'is': '2'
+                }
+            ],
+            'actions': [
+                'notify',
+                {
+                    'set_sound': 'default'
+                }
+            ]
+        }
+    ]
\ No newline at end of file
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index ab128e31e5..d4c5f03b01 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -67,10 +67,7 @@ class HttpPusher(Pusher):
             'notification': {
                 'id': event['event_id'],
                 'type': event['type'],
-                'from': event['user_id'],
-                # we may have to fetch this over federation and we
-                # can't trust it anyway: is it worth it?
-                #'from_display_name': 'Steve Stevington'
+                'sender': event['user_id'],
                 'counts': {  # -- we don't mark messages as read yet so
                              # we have no way of knowing
                     # Just set the badge to 1 until we have read receipts
@@ -90,9 +87,13 @@ class HttpPusher(Pusher):
         }
         if event['type'] == 'm.room.member':
             d['notification']['membership'] = event['content']['membership']
+        if 'content' in event:
+            d['notification']['content'] = event['content']
 
         if len(ctx['aliases']):
             d['notification']['room_alias'] = ctx['aliases'][0]
+        if 'sender_display_name' in ctx:
+            d['notification']['sender_display_name'] = ctx['sender_display_name']
         if 'name' in ctx:
             d['notification']['room_name'] = ctx['name']
 
@@ -119,7 +120,7 @@ class HttpPusher(Pusher):
             'notification': {
                 'id': '',
                 'type': None,
-                'from': '',
+                'sender': '',
                 'counts': {
                     'unread': 0,
                     'missed_calls': 0
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index 4182ad990f..826a36f203 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -6,7 +6,7 @@ logger = logging.getLogger(__name__)
 REQUIREMENTS = {
     "syutil==0.0.2": ["syutil"],
     "matrix_angular_sdk==0.6.0": ["syweb>=0.6.0"],
-    "Twisted>=14.0.0": ["twisted>=14.0.0"],
+    "Twisted==14.0.2": ["twisted==14.0.2"],
     "service_identity>=1.0.0": ["service_identity>=1.0.0"],
     "pyopenssl>=0.14": ["OpenSSL>=0.14"],
     "pyyaml": ["yaml"],
diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py
index 2b1e930326..0f78fa667c 100644
--- a/synapse/rest/client/v1/push_rule.py
+++ b/synapse/rest/client/v1/push_rule.py
@@ -26,11 +26,11 @@ import json
 class PushRuleRestServlet(ClientV1RestServlet):
     PATTERN = client_path_pattern("/pushrules/.*$")
     PRIORITY_CLASS_MAP = {
-        'underride': 0,
-        'sender': 1,
-        'room': 2,
-        'content': 3,
-        'override': 4,
+        'underride': 1,
+        'sender': 2,
+        'room': 3,
+        'content': 4,
+        'override': 5,
     }
     PRIORITY_CLASS_INVERSE_MAP = {v: k for k, v in PRIORITY_CLASS_MAP.items()}
     SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR = (
diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py
index 72d5e9e476..353a4a6589 100644
--- a/synapse/rest/client/v1/pusher.py
+++ b/synapse/rest/client/v1/pusher.py
@@ -31,6 +31,16 @@ class PusherRestServlet(ClientV1RestServlet):
 
         content = _parse_json(request)
 
+        pusher_pool = self.hs.get_pusherpool()
+
+        if ('pushkey' in content and 'app_id' in content
+                    and 'kind' in content and
+                    content['kind'] is None):
+            yield pusher_pool.remove_pusher(
+                content['app_id'], content['pushkey']
+            )
+            defer.returnValue((200, {}))
+
         reqd = ['instance_handle', 'kind', 'app_id', 'app_display_name',
                 'device_display_name', 'pushkey', 'lang', 'data']
         missing = []
@@ -41,7 +51,6 @@ class PusherRestServlet(ClientV1RestServlet):
             raise SynapseError(400, "Missing parameters: "+','.join(missing),
                                errcode=Codes.MISSING_PARAM)
 
-        pusher_pool = self.hs.get_pusherpool()
         try:
             yield pusher_pool.add_pusher(
                 user_name=user.to_string(),
diff --git a/synapse/state.py b/synapse/state.py
index 8144fa02b4..081bc31bb5 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -19,6 +19,7 @@ from twisted.internet import defer
 from synapse.util.logutils import log_function
 from synapse.util.async import run_on_reactor
 from synapse.api.constants import EventTypes
+from synapse.api.errors import AuthError
 from synapse.events.snapshot import EventContext
 
 from collections import namedtuple
@@ -36,12 +37,16 @@ def _get_state_key_from_event(event):
 KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key"))
 
 
+AuthEventTypes = (EventTypes.Create, EventTypes.Member, EventTypes.PowerLevels,)
+
+
 class StateHandler(object):
     """ Responsible for doing state conflict resolution.
     """
 
     def __init__(self, hs):
         self.store = hs.get_datastore()
+        self.hs = hs
 
     @defer.inlineCallbacks
     def get_current_state(self, room_id, event_type=None, state_key=""):
@@ -210,64 +215,93 @@ class StateHandler(object):
         else:
             prev_states = []
 
+        auth_events = {
+            k: e for k, e in unconflicted_state.items()
+            if k[0] in AuthEventTypes
+        }
+
         try:
-            new_state = {}
-            new_state.update(unconflicted_state)
-            for key, events in conflicted_state.items():
-                new_state[key] = self._resolve_state_events(events)
+            resolved_state = self._resolve_state_events(
+                conflicted_state, auth_events
+            )
         except:
             logger.exception("Failed to resolve state")
             raise
 
-        defer.returnValue((None, new_state, prev_states))
-
-    def _get_power_level_from_event_state(self, event, user_id):
-        if hasattr(event, "old_state_events") and event.old_state_events:
-            key = (EventTypes.PowerLevels, "", )
-            power_level_event = event.old_state_events.get(key)
-            level = None
-            if power_level_event:
-                level = power_level_event.content.get("users", {}).get(
-                    user_id
-                )
-                if not level:
-                    level = power_level_event.content.get("users_default", 0)
+        new_state = unconflicted_state
+        new_state.update(resolved_state)
 
-            return level
-        else:
-            return 0
+        defer.returnValue((None, new_state, prev_states))
 
     @log_function
-    def _resolve_state_events(self, events):
-        curr_events = events
-
-        new_powers = [
-            self._get_power_level_from_event_state(e, e.user_id)
-            for e in curr_events
-        ]
-
-        new_powers = [
-            int(p) if p else 0 for p in new_powers
-        ]
+    def _resolve_state_events(self, conflicted_state, auth_events):
+        """ This is where we actually decide which of the conflicted state to
+        use.
+
+        We resolve conflicts in the following order:
+            1. power levels
+            2. memberships
+            3. other events.
+        """
+        resolved_state = {}
+        power_key = (EventTypes.PowerLevels, "")
+        if power_key in conflicted_state.items():
+            power_levels = conflicted_state[power_key]
+            resolved_state[power_key] = self._resolve_auth_events(power_levels)
+
+        auth_events.update(resolved_state)
+
+        for key, events in conflicted_state.items():
+            if key[0] == EventTypes.Member:
+                resolved_state[key] = self._resolve_auth_events(
+                    events,
+                    auth_events
+                )
 
-        max_power = max(new_powers)
+        auth_events.update(resolved_state)
 
-        curr_events = [
-            z[0] for z in zip(curr_events, new_powers)
-            if z[1] == max_power
-        ]
+        for key, events in conflicted_state.items():
+            if key not in resolved_state:
+                resolved_state[key] = self._resolve_normal_events(
+                    events, auth_events
+                )
 
-        if not curr_events:
-            raise RuntimeError("Max didn't get a max?")
-        elif len(curr_events) == 1:
-            return curr_events[0]
-
-        # TODO: For now, just choose the one with the largest event_id.
-        return (
-            sorted(
-                curr_events,
-                key=lambda e: hashlib.sha1(
-                    e.event_id + e.user_id + e.room_id + e.type
-                ).hexdigest()
-            )[0]
-        )
+        return resolved_state
+
+    def _resolve_auth_events(self, events, auth_events):
+        reverse = [i for i in reversed(self._ordered_events(events))]
+
+        auth_events = dict(auth_events)
+
+        prev_event = reverse[0]
+        for event in reverse[1:]:
+            auth_events[(prev_event.type, prev_event.state_key)] = prev_event
+            try:
+                # FIXME: hs.get_auth() is bad style, but we need to do it to
+                # get around circular deps.
+                self.hs.get_auth().check(event, auth_events)
+                prev_event = event
+            except AuthError:
+                return prev_event
+
+        return event
+
+    def _resolve_normal_events(self, events, auth_events):
+        for event in self._ordered_events(events):
+            try:
+                # FIXME: hs.get_auth() is bad style, but we need to do it to
+                # get around circular deps.
+                self.hs.get_auth().check(event, auth_events)
+                return event
+            except AuthError:
+                pass
+
+        # Use the last event (the one with the least depth) if they all fail
+        # the auth check.
+        return event
+
+    def _ordered_events(self, events):
+        def key_func(e):
+            return -int(e.depth), hashlib.sha1(e.event_id).hexdigest()
+
+        return sorted(events, key=key_func)
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 89a1e60c2b..abddb22ac7 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -377,9 +377,12 @@ class DataStore(RoomMemberStore, RoomStore,
             "redacted": del_sql,
         }
 
-        if event_type:
+        if event_type and state_key is not None:
             sql += " AND s.type = ? AND s.state_key = ? "
             args = (room_id, event_type, state_key)
+        elif event_type:
+            sql += " AND s.type = ?"
+            args = (room_id, event_type)
         else:
             args = (room_id, )
 
diff --git a/tests/test_state.py b/tests/test_state.py
index 98ad9e54cd..019e794aa2 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -16,11 +16,120 @@
 from tests import unittest
 from twisted.internet import defer
 
+from synapse.events import FrozenEvent
+from synapse.api.auth import Auth
+from synapse.api.constants import EventTypes, Membership
 from synapse.state import StateHandler
 
 from mock import Mock
 
 
+_next_event_id = 1000
+
+
+def create_event(name=None, type=None, state_key=None, depth=2, event_id=None,
+                 prev_events=[], **kwargs):
+    global _next_event_id
+
+    if not event_id:
+        _next_event_id += 1
+        event_id = str(_next_event_id)
+
+    if not name:
+        if state_key is not None:
+            name = "<%s-%s, %s>" % (type, state_key, event_id,)
+        else:
+            name = "<%s, %s>" % (type, event_id,)
+
+    d = {
+        "event_id": event_id,
+        "type": type,
+        "sender": "@user_id:example.com",
+        "room_id": "!room_id:example.com",
+        "depth": depth,
+        "prev_events": prev_events,
+    }
+
+    if state_key is not None:
+        d["state_key"] = state_key
+
+    d.update(kwargs)
+
+    event = FrozenEvent(d)
+
+    return event
+
+
+class StateGroupStore(object):
+    def __init__(self):
+        self._event_to_state_group = {}
+        self._group_to_state = {}
+
+        self._next_group = 1
+
+    def get_state_groups(self, event_ids):
+        groups = {}
+        for event_id in event_ids:
+            group = self._event_to_state_group.get(event_id)
+            if group:
+                groups[group] = self._group_to_state[group]
+
+        return defer.succeed(groups)
+
+    def store_state_groups(self, event, context):
+        if context.current_state is None:
+            return
+
+        state_events = context.current_state
+
+        if event.is_state():
+            state_events[(event.type, event.state_key)] = event
+
+        state_group = context.state_group
+        if not state_group:
+            state_group = self._next_group
+            self._next_group += 1
+
+            self._group_to_state[state_group] = state_events.values()
+
+        self._event_to_state_group[event.event_id] = state_group
+
+
+class DictObj(dict):
+    def __init__(self, **kwargs):
+        super(DictObj, self).__init__(kwargs)
+        self.__dict__ = self
+
+
+class Graph(object):
+    def __init__(self, nodes, edges):
+        events = {}
+        clobbered = set(events.keys())
+
+        for event_id, fields in nodes.items():
+            refs = edges.get(event_id)
+            if refs:
+                clobbered.difference_update(refs)
+                prev_events = [(r, {}) for r in refs]
+            else:
+                prev_events = []
+
+            events[event_id] = create_event(
+                event_id=event_id,
+                prev_events=prev_events,
+                **fields
+            )
+
+        self._leaves = clobbered
+        self._events = sorted(events.values(), key=lambda e: e.depth)
+
+    def walk(self):
+        return iter(self._events)
+
+    def get_leaves(self):
+        return (self._events[i] for i in self._leaves)
+
+
 class StateTestCase(unittest.TestCase):
     def setUp(self):
         self.store = Mock(
@@ -29,20 +138,188 @@ class StateTestCase(unittest.TestCase):
                 "add_event_hashes",
             ]
         )
-        hs = Mock(spec=["get_datastore"])
+        hs = Mock(spec=["get_datastore", "get_auth", "get_state_handler"])
         hs.get_datastore.return_value = self.store
+        hs.get_state_handler.return_value = None
+        hs.get_auth.return_value = Auth(hs)
 
         self.state = StateHandler(hs)
         self.event_id = 0
 
     @defer.inlineCallbacks
+    def test_branch_no_conflict(self):
+        graph = Graph(
+            nodes={
+                "START": DictObj(
+                    type=EventTypes.Create,
+                    state_key="",
+                    depth=1,
+                ),
+                "A": DictObj(
+                    type=EventTypes.Message,
+                    depth=2,
+                ),
+                "B": DictObj(
+                    type=EventTypes.Message,
+                    depth=3,
+                ),
+                "C": DictObj(
+                    type=EventTypes.Name,
+                    state_key="",
+                    depth=3,
+                ),
+                "D": DictObj(
+                    type=EventTypes.Message,
+                    depth=4,
+                ),
+            },
+            edges={
+                "A": ["START"],
+                "B": ["A"],
+                "C": ["A"],
+                "D": ["B", "C"]
+            }
+        )
+
+        store = StateGroupStore()
+        self.store.get_state_groups.side_effect = store.get_state_groups
+
+        context_store = {}
+
+        for event in graph.walk():
+            context = yield self.state.compute_event_context(event)
+            store.store_state_groups(event, context)
+            context_store[event.event_id] = context
+
+        self.assertEqual(2, len(context_store["D"].current_state))
+
+    @defer.inlineCallbacks
+    def test_branch_basic_conflict(self):
+        graph = Graph(
+            nodes={
+                "START": DictObj(
+                    type=EventTypes.Create,
+                    state_key="creator",
+                    content={"membership": "@user_id:example.com"},
+                    depth=1,
+                ),
+                "A": DictObj(
+                    type=EventTypes.Member,
+                    state_key="@user_id:example.com",
+                    content={"membership": Membership.JOIN},
+                    membership=Membership.JOIN,
+                    depth=2,
+                ),
+                "B": DictObj(
+                    type=EventTypes.Name,
+                    state_key="",
+                    depth=3,
+                ),
+                "C": DictObj(
+                    type=EventTypes.Name,
+                    state_key="",
+                    depth=4,
+                ),
+                "D": DictObj(
+                    type=EventTypes.Message,
+                    depth=5,
+                ),
+            },
+            edges={
+                "A": ["START"],
+                "B": ["A"],
+                "C": ["A"],
+                "D": ["B", "C"]
+            }
+        )
+
+        store = StateGroupStore()
+        self.store.get_state_groups.side_effect = store.get_state_groups
+
+        context_store = {}
+
+        for event in graph.walk():
+            context = yield self.state.compute_event_context(event)
+            store.store_state_groups(event, context)
+            context_store[event.event_id] = context
+
+        self.assertSetEqual(
+            {"START", "A", "C"},
+            {e.event_id for e in context_store["D"].current_state.values()}
+        )
+
+    @defer.inlineCallbacks
+    def test_branch_have_banned_conflict(self):
+        graph = Graph(
+            nodes={
+                "START": DictObj(
+                    type=EventTypes.Create,
+                    state_key="creator",
+                    content={"membership": "@user_id:example.com"},
+                    depth=1,
+                ),
+                "A": DictObj(
+                    type=EventTypes.Member,
+                    state_key="@user_id:example.com",
+                    content={"membership": Membership.JOIN},
+                    membership=Membership.JOIN,
+                    depth=2,
+                ),
+                "B": DictObj(
+                    type=EventTypes.Name,
+                    state_key="",
+                    depth=3,
+                ),
+                "C": DictObj(
+                    type=EventTypes.Member,
+                    state_key="@user_id_2:example.com",
+                    content={"membership": Membership.BAN},
+                    membership=Membership.BAN,
+                    depth=4,
+                ),
+                "D": DictObj(
+                    type=EventTypes.Name,
+                    state_key="",
+                    depth=4,
+                    sender="@user_id_2:example.com",
+                ),
+                "E": DictObj(
+                    type=EventTypes.Message,
+                    depth=5,
+                ),
+            },
+            edges={
+                "A": ["START"],
+                "B": ["A"],
+                "C": ["B"],
+                "D": ["B"],
+                "E": ["C", "D"]
+            }
+        )
+
+        store = StateGroupStore()
+        self.store.get_state_groups.side_effect = store.get_state_groups
+
+        context_store = {}
+
+        for event in graph.walk():
+            context = yield self.state.compute_event_context(event)
+            store.store_state_groups(event, context)
+            context_store[event.event_id] = context
+
+        self.assertSetEqual(
+            {"START", "A", "B", "C"},
+            {e.event_id for e in context_store["E"].current_state.values()}
+        )
+
+    @defer.inlineCallbacks
     def test_annotate_with_old_message(self):
-        event = self.create_event(type="test_message", name="event")
+        event = create_event(type="test_message", name="event")
 
         old_state = [
-            self.create_event(type="test1", state_key="1"),
-            self.create_event(type="test1", state_key="2"),
-            self.create_event(type="test2", state_key=""),
+            create_event(type="test1", state_key="1"),
+            create_event(type="test1", state_key="2"),
+            create_event(type="test2", state_key=""),
         ]
 
         context = yield self.state.compute_event_context(
@@ -62,12 +339,12 @@ class StateTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_annotate_with_old_state(self):
-        event = self.create_event(type="state", state_key="", name="event")
+        event = create_event(type="state", state_key="", name="event")
 
         old_state = [
-            self.create_event(type="test1", state_key="1"),
-            self.create_event(type="test1", state_key="2"),
-            self.create_event(type="test2", state_key=""),
+            create_event(type="test1", state_key="1"),
+            create_event(type="test1", state_key="2"),
+            create_event(type="test2", state_key=""),
         ]
 
         context = yield self.state.compute_event_context(
@@ -88,13 +365,12 @@ class StateTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_trivial_annotate_message(self):
-        event = self.create_event(type="test_message", name="event")
-        event.prev_events = []
+        event = create_event(type="test_message", name="event")
 
         old_state = [
-            self.create_event(type="test1", state_key="1"),
-            self.create_event(type="test1", state_key="2"),
-            self.create_event(type="test2", state_key=""),
+            create_event(type="test1", state_key="1"),
+            create_event(type="test1", state_key="2"),
+            create_event(type="test2", state_key=""),
         ]
 
         group_name = "group_name_1"
@@ -119,13 +395,12 @@ class StateTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_trivial_annotate_state(self):
-        event = self.create_event(type="state", state_key="", name="event")
-        event.prev_events = []
+        event = create_event(type="state", state_key="", name="event")
 
         old_state = [
-            self.create_event(type="test1", state_key="1"),
-            self.create_event(type="test1", state_key="2"),
-            self.create_event(type="test2", state_key=""),
+            create_event(type="test1", state_key="1"),
+            create_event(type="test1", state_key="2"),
+            create_event(type="test2", state_key=""),
         ]
 
         group_name = "group_name_1"
@@ -150,30 +425,21 @@ class StateTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_resolve_message_conflict(self):
-        event = self.create_event(type="test_message", name="event")
-        event.prev_events = []
+        event = create_event(type="test_message", name="event")
 
         old_state_1 = [
-            self.create_event(type="test1", state_key="1"),
-            self.create_event(type="test1", state_key="2"),
-            self.create_event(type="test2", state_key=""),
+            create_event(type="test1", state_key="1"),
+            create_event(type="test1", state_key="2"),
+            create_event(type="test2", state_key=""),
         ]
 
         old_state_2 = [
-            self.create_event(type="test1", state_key="1"),
-            self.create_event(type="test3", state_key="2"),
-            self.create_event(type="test4", state_key=""),
+            create_event(type="test1", state_key="1"),
+            create_event(type="test3", state_key="2"),
+            create_event(type="test4", state_key=""),
         ]
 
-        group_name_1 = "group_name_1"
-        group_name_2 = "group_name_2"
-
-        self.store.get_state_groups.return_value = {
-            group_name_1: old_state_1,
-            group_name_2: old_state_2,
-        }
-
-        context = yield self.state.compute_event_context(event)
+        context = yield self._get_context(event, old_state_1, old_state_2)
 
         self.assertEqual(len(context.current_state), 5)
 
@@ -181,56 +447,76 @@ class StateTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_resolve_state_conflict(self):
-        event = self.create_event(type="test4", state_key="", name="event")
-        event.prev_events = []
+        event = create_event(type="test4", state_key="", name="event")
 
         old_state_1 = [
-            self.create_event(type="test1", state_key="1"),
-            self.create_event(type="test1", state_key="2"),
-            self.create_event(type="test2", state_key=""),
+            create_event(type="test1", state_key="1"),
+            create_event(type="test1", state_key="2"),
+            create_event(type="test2", state_key=""),
         ]
 
         old_state_2 = [
-            self.create_event(type="test1", state_key="1"),
-            self.create_event(type="test3", state_key="2"),
-            self.create_event(type="test4", state_key=""),
+            create_event(type="test1", state_key="1"),
+            create_event(type="test3", state_key="2"),
+            create_event(type="test4", state_key=""),
         ]
 
-        group_name_1 = "group_name_1"
-        group_name_2 = "group_name_2"
-
-        self.store.get_state_groups.return_value = {
-            group_name_1: old_state_1,
-            group_name_2: old_state_2,
-        }
-
-        context = yield self.state.compute_event_context(event)
+        context = yield self._get_context(event, old_state_1, old_state_2)
 
         self.assertEqual(len(context.current_state), 5)
 
         self.assertIsNone(context.state_group)
 
-    def create_event(self, name=None, type=None, state_key=None):
-        self.event_id += 1
-        event_id = str(self.event_id)
+    @defer.inlineCallbacks
+    def test_standard_depth_conflict(self):
+        event = create_event(type="test4", name="event")
+
+        member_event = create_event(
+            type=EventTypes.Member,
+            state_key="@user_id:example.com",
+            content={
+                "membership": Membership.JOIN,
+            }
+        )
 
-        if not name:
-            if state_key is not None:
-                name = "<%s-%s>" % (type, state_key)
-            else:
-                name = "<%s>" % (type, )
+        old_state_1 = [
+            member_event,
+            create_event(type="test1", state_key="1", depth=1),
+        ]
+
+        old_state_2 = [
+            member_event,
+            create_event(type="test1", state_key="1", depth=2),
+        ]
 
-        event = Mock(name=name, spec=[])
-        event.type = type
+        context = yield self._get_context(event, old_state_1, old_state_2)
 
-        if state_key is not None:
-            event.state_key = state_key
-        event.event_id = event_id
+        self.assertEqual(old_state_2[1], context.current_state[("test1", "1")])
+
+        # Reverse the depth to make sure we are actually using the depths
+        # during state resolution.
+
+        old_state_1 = [
+            member_event,
+            create_event(type="test1", state_key="1", depth=2),
+        ]
+
+        old_state_2 = [
+            member_event,
+            create_event(type="test1", state_key="1", depth=1),
+        ]
+
+        context = yield self._get_context(event, old_state_1, old_state_2)
+
+        self.assertEqual(old_state_1[1], context.current_state[("test1", "1")])
 
-        event.is_state = lambda: (state_key is not None)
-        event.unsigned = {}
+    def _get_context(self, event, old_state_1, old_state_2):
+        group_name_1 = "group_name_1"
+        group_name_2 = "group_name_2"
 
-        event.user_id = "@user_id:example.com"
-        event.room_id = "!room_id:example.com"
+        self.store.get_state_groups.return_value = {
+            group_name_1: old_state_1,
+            group_name_2: old_state_2,
+        }
 
-        return event
+        return self.state.compute_event_context(event)