summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/api/test_filtering.py512
-rw-r--r--tests/handlers/test_federation.py2
-rw-r--r--tests/rest/client/v1/test_presence.py4
-rw-r--r--tests/rest/client/v1/test_profile.py2
-rw-r--r--tests/rest/client/v1/test_rooms.py7
-rw-r--r--tests/rest/client/v1/test_typing.py1
-rw-r--r--tests/rest/client/v2_alpha/__init__.py10
-rw-r--r--tests/rest/client/v2_alpha/test_filter.py95
-rw-r--r--tests/storage/test_registration.py10
-rw-r--r--tests/test_state.py428
10 files changed, 993 insertions, 78 deletions
diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py
new file mode 100644
index 0000000000..babf4c37f1
--- /dev/null
+++ b/tests/api/test_filtering.py
@@ -0,0 +1,512 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections import namedtuple
+from tests import unittest
+from twisted.internet import defer
+
+from mock import Mock, NonCallableMock
+from tests.utils import (
+    MockHttpResource, MockClock, DeferredMockCallable, SQLiteMemoryDbPool,
+    MockKey
+)
+
+from synapse.server import HomeServer
+from synapse.types import UserID
+from synapse.api.filtering import Filter
+
+user_localpart = "test_user"
+MockEvent = namedtuple("MockEvent", "sender type room_id")
+
+class FilteringTestCase(unittest.TestCase):
+
+    @defer.inlineCallbacks
+    def setUp(self):
+        db_pool = SQLiteMemoryDbPool()
+        yield db_pool.prepare()
+
+        self.mock_config = NonCallableMock()
+        self.mock_config.signing_key = [MockKey()]
+
+        self.mock_federation_resource = MockHttpResource()
+
+        self.mock_http_client = Mock(spec=[])
+        self.mock_http_client.put_json = DeferredMockCallable()
+
+        hs = HomeServer("test",
+            db_pool=db_pool,
+            handlers=None,
+            http_client=self.mock_http_client,
+            config=self.mock_config,
+            keyring=Mock(),
+        )
+
+        self.filtering = hs.get_filtering()
+        self.filter = Filter({})
+
+        self.datastore = hs.get_datastore()
+
+    def test_definition_types_works_with_literals(self):
+        definition = {
+            "types": ["m.room.message", "org.matrix.foo.bar"]
+        }
+        event = MockEvent(
+            sender="@foo:bar",
+            type="m.room.message",
+            room_id="!foo:bar"
+        )
+        self.assertTrue(
+            self.filter._passes_definition(definition, event)
+        )
+
+    def test_definition_types_works_with_wildcards(self):
+        definition = {
+            "types": ["m.*", "org.matrix.foo.bar"]
+        }
+        event = MockEvent(
+            sender="@foo:bar",
+            type="m.room.message",
+            room_id="!foo:bar"
+        )
+        self.assertTrue(
+            self.filter._passes_definition(definition, event)
+        )
+
+    def test_definition_types_works_with_unknowns(self):
+        definition = {
+            "types": ["m.room.message", "org.matrix.foo.bar"]
+        }
+        event = MockEvent(
+            sender="@foo:bar",
+            type="now.for.something.completely.different",
+            room_id="!foo:bar"
+        )
+        self.assertFalse(
+            self.filter._passes_definition(definition, event)
+        )
+
+    def test_definition_not_types_works_with_literals(self):
+        definition = {
+            "not_types": ["m.room.message", "org.matrix.foo.bar"]
+        }
+        event = MockEvent(
+            sender="@foo:bar",
+            type="m.room.message",
+            room_id="!foo:bar"
+        )
+        self.assertFalse(
+            self.filter._passes_definition(definition, event)
+        )
+
+    def test_definition_not_types_works_with_wildcards(self):
+        definition = {
+            "not_types": ["m.room.message", "org.matrix.*"]
+        }
+        event = MockEvent(
+            sender="@foo:bar",
+            type="org.matrix.custom.event",
+            room_id="!foo:bar"
+        )
+        self.assertFalse(
+            self.filter._passes_definition(definition, event)
+        )
+
+    def test_definition_not_types_works_with_unknowns(self):
+        definition = {
+            "not_types": ["m.*", "org.*"]
+        }
+        event = MockEvent(
+            sender="@foo:bar",
+            type="com.nom.nom.nom",
+            room_id="!foo:bar"
+        )
+        self.assertTrue(
+            self.filter._passes_definition(definition, event)
+        )
+
+    def test_definition_not_types_takes_priority_over_types(self):
+        definition = {
+            "not_types": ["m.*", "org.*"],
+            "types": ["m.room.message", "m.room.topic"]
+        }
+        event = MockEvent(
+            sender="@foo:bar",
+            type="m.room.topic",
+            room_id="!foo:bar"
+        )
+        self.assertFalse(
+            self.filter._passes_definition(definition, event)
+        )
+
+    def test_definition_senders_works_with_literals(self):
+        definition = {
+            "senders": ["@flibble:wibble"]
+        }
+        event = MockEvent(
+            sender="@flibble:wibble",
+            type="com.nom.nom.nom",
+            room_id="!foo:bar"
+        )
+        self.assertTrue(
+            self.filter._passes_definition(definition, event)
+        )
+
+    def test_definition_senders_works_with_unknowns(self):
+        definition = {
+            "senders": ["@flibble:wibble"]
+        }
+        event = MockEvent(
+            sender="@challenger:appears",
+            type="com.nom.nom.nom",
+            room_id="!foo:bar"
+        )
+        self.assertFalse(
+            self.filter._passes_definition(definition, event)
+        )
+
+    def test_definition_not_senders_works_with_literals(self):
+        definition = {
+            "not_senders": ["@flibble:wibble"]
+        }
+        event = MockEvent(
+            sender="@flibble:wibble",
+            type="com.nom.nom.nom",
+            room_id="!foo:bar"
+        )
+        self.assertFalse(
+            self.filter._passes_definition(definition, event)
+        )
+
+    def test_definition_not_senders_works_with_unknowns(self):
+        definition = {
+            "not_senders": ["@flibble:wibble"]
+        }
+        event = MockEvent(
+            sender="@challenger:appears",
+            type="com.nom.nom.nom",
+            room_id="!foo:bar"
+        )
+        self.assertTrue(
+            self.filter._passes_definition(definition, event)
+        )
+
+    def test_definition_not_senders_takes_priority_over_senders(self):
+        definition = {
+            "not_senders": ["@misspiggy:muppets"],
+            "senders": ["@kermit:muppets", "@misspiggy:muppets"]
+        }
+        event = MockEvent(
+            sender="@misspiggy:muppets",
+            type="m.room.topic",
+            room_id="!foo:bar"
+        )
+        self.assertFalse(
+            self.filter._passes_definition(definition, event)
+        )
+
+    def test_definition_rooms_works_with_literals(self):
+        definition = {
+            "rooms": ["!secretbase:unknown"]
+        }
+        event = MockEvent(
+            sender="@foo:bar",
+            type="m.room.message",
+            room_id="!secretbase:unknown"
+        )
+        self.assertTrue(
+            self.filter._passes_definition(definition, event)
+        )
+
+    def test_definition_rooms_works_with_unknowns(self):
+        definition = {
+            "rooms": ["!secretbase:unknown"]
+        }
+        event = MockEvent(
+            sender="@foo:bar",
+            type="m.room.message",
+            room_id="!anothersecretbase:unknown"
+        )
+        self.assertFalse(
+            self.filter._passes_definition(definition, event)
+        )
+
+    def test_definition_not_rooms_works_with_literals(self):
+        definition = {
+            "not_rooms": ["!anothersecretbase:unknown"]
+        }
+        event = MockEvent(
+            sender="@foo:bar",
+            type="m.room.message",
+            room_id="!anothersecretbase:unknown"
+        )
+        self.assertFalse(
+            self.filter._passes_definition(definition, event)
+        )
+
+    def test_definition_not_rooms_works_with_unknowns(self):
+        definition = {
+            "not_rooms": ["!secretbase:unknown"]
+        }
+        event = MockEvent(
+            sender="@foo:bar",
+            type="m.room.message",
+            room_id="!anothersecretbase:unknown"
+        )
+        self.assertTrue(
+            self.filter._passes_definition(definition, event)
+        )
+
+    def test_definition_not_rooms_takes_priority_over_rooms(self):
+        definition = {
+            "not_rooms": ["!secretbase:unknown"],
+            "rooms": ["!secretbase:unknown"]
+        }
+        event = MockEvent(
+            sender="@foo:bar",
+            type="m.room.message",
+            room_id="!secretbase:unknown"
+        )
+        self.assertFalse(
+            self.filter._passes_definition(definition, event)
+        )
+
+    def test_definition_combined_event(self):
+        definition = {
+            "not_senders": ["@misspiggy:muppets"],
+            "senders": ["@kermit:muppets"],
+            "rooms": ["!stage:unknown"],
+            "not_rooms": ["!piggyshouse:muppets"],
+            "types": ["m.room.message", "muppets.kermit.*"],
+            "not_types": ["muppets.misspiggy.*"]
+        }
+        event = MockEvent(
+            sender="@kermit:muppets",  # yup
+            type="m.room.message",  # yup
+            room_id="!stage:unknown"  # yup
+        )
+        self.assertTrue(
+            self.filter._passes_definition(definition, event)
+        )
+
+    def test_definition_combined_event_bad_sender(self):
+        definition = {
+            "not_senders": ["@misspiggy:muppets"],
+            "senders": ["@kermit:muppets"],
+            "rooms": ["!stage:unknown"],
+            "not_rooms": ["!piggyshouse:muppets"],
+            "types": ["m.room.message", "muppets.kermit.*"],
+            "not_types": ["muppets.misspiggy.*"]
+        }
+        event = MockEvent(
+            sender="@misspiggy:muppets",  # nope
+            type="m.room.message",  # yup
+            room_id="!stage:unknown"  # yup
+        )
+        self.assertFalse(
+            self.filter._passes_definition(definition, event)
+        )
+
+    def test_definition_combined_event_bad_room(self):
+        definition = {
+            "not_senders": ["@misspiggy:muppets"],
+            "senders": ["@kermit:muppets"],
+            "rooms": ["!stage:unknown"],
+            "not_rooms": ["!piggyshouse:muppets"],
+            "types": ["m.room.message", "muppets.kermit.*"],
+            "not_types": ["muppets.misspiggy.*"]
+        }
+        event = MockEvent(
+            sender="@kermit:muppets",  # yup
+            type="m.room.message",  # yup
+            room_id="!piggyshouse:muppets"  # nope
+        )
+        self.assertFalse(
+            self.filter._passes_definition(definition, event)
+        )
+
+    def test_definition_combined_event_bad_type(self):
+        definition = {
+            "not_senders": ["@misspiggy:muppets"],
+            "senders": ["@kermit:muppets"],
+            "rooms": ["!stage:unknown"],
+            "not_rooms": ["!piggyshouse:muppets"],
+            "types": ["m.room.message", "muppets.kermit.*"],
+            "not_types": ["muppets.misspiggy.*"]
+        }
+        event = MockEvent(
+            sender="@kermit:muppets",  # yup
+            type="muppets.misspiggy.kisses",  # nope
+            room_id="!stage:unknown"  # yup
+        )
+        self.assertFalse(
+            self.filter._passes_definition(definition, event)
+        )
+
+    @defer.inlineCallbacks
+    def test_filter_public_user_data_match(self):
+        user_filter_json = {
+            "public_user_data": {
+                "types": ["m.*"]
+            }
+        }
+        user = UserID.from_string("@" + user_localpart + ":test")
+        filter_id = yield self.datastore.add_user_filter(
+            user_localpart=user_localpart,
+            user_filter=user_filter_json,
+        )
+        event = MockEvent(
+            sender="@foo:bar",
+            type="m.profile",
+            room_id="!foo:bar"
+        )
+        events = [event]
+
+        user_filter = yield self.filtering.get_user_filter(
+            user_localpart=user_localpart,
+            filter_id=filter_id,
+        )
+
+        results = user_filter.filter_public_user_data(events=events)
+        self.assertEquals(events, results)
+
+    @defer.inlineCallbacks
+    def test_filter_public_user_data_no_match(self):
+        user_filter_json = {
+            "public_user_data": {
+                "types": ["m.*"]
+            }
+        }
+        user = UserID.from_string("@" + user_localpart + ":test")
+        filter_id = yield self.datastore.add_user_filter(
+            user_localpart=user_localpart,
+            user_filter=user_filter_json,
+        )
+        event = MockEvent(
+            sender="@foo:bar",
+            type="custom.avatar.3d.crazy",
+            room_id="!foo:bar"
+        )
+        events = [event]
+
+        user_filter = yield self.filtering.get_user_filter(
+            user_localpart=user_localpart,
+            filter_id=filter_id,
+        )
+
+        results = user_filter.filter_public_user_data(events=events)
+        self.assertEquals([], results)
+
+    @defer.inlineCallbacks
+    def test_filter_room_state_match(self):
+        user_filter_json = {
+            "room": {
+                "state": {
+                    "types": ["m.*"]
+                }
+            }
+        }
+        user = UserID.from_string("@" + user_localpart + ":test")
+        filter_id = yield self.datastore.add_user_filter(
+            user_localpart=user_localpart,
+            user_filter=user_filter_json,
+        )
+        event = MockEvent(
+            sender="@foo:bar",
+            type="m.room.topic",
+            room_id="!foo:bar"
+        )
+        events = [event]
+
+        user_filter = yield self.filtering.get_user_filter(
+            user_localpart=user_localpart,
+            filter_id=filter_id,
+        )
+
+        results = user_filter.filter_room_state(events=events)
+        self.assertEquals(events, results)
+
+    @defer.inlineCallbacks
+    def test_filter_room_state_no_match(self):
+        user_filter_json = {
+            "room": {
+                "state": {
+                    "types": ["m.*"]
+                }
+            }
+        }
+        user = UserID.from_string("@" + user_localpart + ":test")
+        filter_id = yield self.datastore.add_user_filter(
+            user_localpart=user_localpart,
+            user_filter=user_filter_json,
+        )
+        event = MockEvent(
+            sender="@foo:bar",
+            type="org.matrix.custom.event",
+            room_id="!foo:bar"
+        )
+        events = [event]
+
+        user_filter = yield self.filtering.get_user_filter(
+            user_localpart=user_localpart,
+            filter_id=filter_id,
+        )
+
+        results = user_filter.filter_room_state(events)
+        self.assertEquals([], results)
+
+    @defer.inlineCallbacks
+    def test_add_filter(self):
+        user_filter_json = {
+            "room": {
+                "state": {
+                    "types": ["m.*"]
+                }
+            }
+        }
+
+        filter_id = yield self.filtering.add_user_filter(
+            user_localpart=user_localpart,
+            user_filter=user_filter_json,
+        )
+
+        self.assertEquals(filter_id, 0)
+        self.assertEquals(user_filter_json,
+            (yield self.datastore.get_user_filter(
+                user_localpart=user_localpart,
+                filter_id=0,
+            ))
+        )
+
+    @defer.inlineCallbacks
+    def test_get_filter(self):
+        user_filter_json = {
+            "room": {
+                "state": {
+                    "types": ["m.*"]
+                }
+            }
+        }
+
+        filter_id = yield self.datastore.add_user_filter(
+            user_localpart=user_localpart,
+            user_filter=user_filter_json,
+        )
+
+        filter = yield self.filtering.get_user_filter(
+            user_localpart=user_localpart,
+            filter_id=filter_id,
+        )
+
+        self.assertEquals(filter.filter_json, user_filter_json)
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index ed21defd13..44dbce6bea 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -52,6 +52,7 @@ class FederationTestCase(unittest.TestCase):
                 "get_room",
                 "get_destination_retry_timings",
                 "set_destination_retry_timings",
+                "have_events",
             ]),
             resource_for_federation=NonCallableMock(),
             http_client=NonCallableMock(spec_set=[]),
@@ -90,6 +91,7 @@ class FederationTestCase(unittest.TestCase):
         self.datastore.persist_event.return_value = defer.succeed(None)
         self.datastore.get_room.return_value = defer.succeed(True)
         self.auth.check_host_in_room.return_value = defer.succeed(True)
+        self.datastore.have_events.return_value = defer.succeed({})
 
         def annotate(ev, old_state=None):
             context = Mock()
diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py
index 65d5cc4916..f849120a3e 100644
--- a/tests/rest/client/v1/test_presence.py
+++ b/tests/rest/client/v1/test_presence.py
@@ -75,6 +75,7 @@ class PresenceStateTestCase(unittest.TestCase):
                 "user": UserID.from_string(myid),
                 "admin": False,
                 "device_id": None,
+                "token_id": 1,
             }
 
         hs.get_auth().get_user_by_token = _get_user_by_token
@@ -165,6 +166,7 @@ class PresenceListTestCase(unittest.TestCase):
                 "user": UserID.from_string(myid),
                 "admin": False,
                 "device_id": None,
+                "token_id": 1,
             }
 
         hs.handlers.room_member_handler = Mock(
@@ -282,7 +284,7 @@ class PresenceEventStreamTestCase(unittest.TestCase):
         hs.get_clock().time_msec.return_value = 1000000
 
         def _get_user_by_req(req=None):
-            return UserID.from_string(myid)
+            return (UserID.from_string(myid), "")
 
         hs.get_auth().get_user_by_req = _get_user_by_req
 
diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py
index 39cd68d829..6a2085276a 100644
--- a/tests/rest/client/v1/test_profile.py
+++ b/tests/rest/client/v1/test_profile.py
@@ -58,7 +58,7 @@ class ProfileTestCase(unittest.TestCase):
         )
 
         def _get_user_by_req(request=None):
-            return UserID.from_string(myid)
+            return (UserID.from_string(myid), "")
 
         hs.get_auth().get_user_by_req = _get_user_by_req
 
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index 76ed550b75..81ead10e76 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -70,6 +70,7 @@ class RoomPermissionsTestCase(RestTestCase):
                 "user": UserID.from_string(self.auth_user_id),
                 "admin": False,
                 "device_id": None,
+                "token_id": 1,
             }
         hs.get_auth().get_user_by_token = _get_user_by_token
 
@@ -466,6 +467,7 @@ class RoomsMemberListTestCase(RestTestCase):
                 "user": UserID.from_string(self.auth_user_id),
                 "admin": False,
                 "device_id": None,
+                "token_id": 1,
             }
         hs.get_auth().get_user_by_token = _get_user_by_token
 
@@ -555,6 +557,7 @@ class RoomsCreateTestCase(RestTestCase):
                 "user": UserID.from_string(self.auth_user_id),
                 "admin": False,
                 "device_id": None,
+                "token_id": 1,
             }
         hs.get_auth().get_user_by_token = _get_user_by_token
 
@@ -657,6 +660,7 @@ class RoomTopicTestCase(RestTestCase):
                 "user": UserID.from_string(self.auth_user_id),
                 "admin": False,
                 "device_id": None,
+                "token_id": 1,
             }
 
         hs.get_auth().get_user_by_token = _get_user_by_token
@@ -773,6 +777,7 @@ class RoomMemberStateTestCase(RestTestCase):
                 "user": UserID.from_string(self.auth_user_id),
                 "admin": False,
                 "device_id": None,
+                "token_id": 1,
             }
         hs.get_auth().get_user_by_token = _get_user_by_token
 
@@ -909,6 +914,7 @@ class RoomMessagesTestCase(RestTestCase):
                 "user": UserID.from_string(self.auth_user_id),
                 "admin": False,
                 "device_id": None,
+                "token_id": 1,
             }
         hs.get_auth().get_user_by_token = _get_user_by_token
 
@@ -1013,6 +1019,7 @@ class RoomInitialSyncTestCase(RestTestCase):
                 "user": UserID.from_string(self.auth_user_id),
                 "admin": False,
                 "device_id": None,
+                "token_id": 1,
             }
         hs.get_auth().get_user_by_token = _get_user_by_token
 
diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py
index c89b37d004..c5d5b06da3 100644
--- a/tests/rest/client/v1/test_typing.py
+++ b/tests/rest/client/v1/test_typing.py
@@ -73,6 +73,7 @@ class RoomTypingTestCase(RestTestCase):
                 "user": UserID.from_string(self.auth_user_id),
                 "admin": False,
                 "device_id": None,
+                "token_id": 1,
             }
 
         hs.get_auth().get_user_by_token = _get_user_by_token
diff --git a/tests/rest/client/v2_alpha/__init__.py b/tests/rest/client/v2_alpha/__init__.py
index f59745e13c..fa70575c57 100644
--- a/tests/rest/client/v2_alpha/__init__.py
+++ b/tests/rest/client/v2_alpha/__init__.py
@@ -39,9 +39,7 @@ class V2AlphaRestTestCase(unittest.TestCase):
 
         hs = HomeServer("test",
             db_pool=None,
-            datastore=Mock(spec=[
-                "insert_client_ip",
-            ]),
+            datastore=self.make_datastore_mock(),
             http_client=None,
             resource_for_client=self.mock_resource,
             resource_for_federation=self.mock_resource,
@@ -53,8 +51,14 @@ class V2AlphaRestTestCase(unittest.TestCase):
                 "user": UserID.from_string(self.USER_ID),
                 "admin": False,
                 "device_id": None,
+                "token_id": 1,
             }
         hs.get_auth().get_user_by_token = _get_user_by_token
 
         for r in self.TO_REGISTER:
             r.register_servlets(hs, self.mock_resource)
+
+    def make_datastore_mock(self):
+        return Mock(spec=[
+            "insert_client_ip",
+        ])
diff --git a/tests/rest/client/v2_alpha/test_filter.py b/tests/rest/client/v2_alpha/test_filter.py
new file mode 100644
index 0000000000..80ddabf818
--- /dev/null
+++ b/tests/rest/client/v2_alpha/test_filter.py
@@ -0,0 +1,95 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+from mock import Mock
+
+from . import V2AlphaRestTestCase
+
+from synapse.rest.client.v2_alpha import filter
+
+from synapse.api.errors import StoreError
+
+
+class FilterTestCase(V2AlphaRestTestCase):
+    USER_ID = "@apple:test"
+    TO_REGISTER = [filter]
+
+    def make_datastore_mock(self):
+        datastore = super(FilterTestCase, self).make_datastore_mock()
+
+        self._user_filters = {}
+
+        def add_user_filter(user_localpart, definition):
+            filters = self._user_filters.setdefault(user_localpart, [])
+            filter_id = len(filters)
+            filters.append(definition)
+            return defer.succeed(filter_id)
+        datastore.add_user_filter = add_user_filter
+
+        def get_user_filter(user_localpart, filter_id):
+            if user_localpart not in self._user_filters:
+                raise StoreError(404, "No user")
+            filters = self._user_filters[user_localpart]
+            if filter_id >= len(filters):
+                raise StoreError(404, "No filter")
+            return defer.succeed(filters[filter_id])
+        datastore.get_user_filter = get_user_filter
+
+        return datastore
+
+    @defer.inlineCallbacks
+    def test_add_filter(self):
+        (code, response) = yield self.mock_resource.trigger("POST",
+            "/user/%s/filter" % (self.USER_ID),
+            '{"type": ["m.*"]}'
+        )
+        self.assertEquals(200, code)
+        self.assertEquals({"filter_id": "0"}, response)
+
+        self.assertIn("apple", self._user_filters)
+        self.assertEquals(len(self._user_filters["apple"]), 1)
+        self.assertEquals({"type": ["m.*"]}, self._user_filters["apple"][0])
+
+    @defer.inlineCallbacks
+    def test_get_filter(self):
+        self._user_filters["apple"] = [
+            {"type": ["m.*"]}
+        ]
+
+        (code, response) = yield self.mock_resource.trigger("GET",
+            "/user/%s/filter/0" % (self.USER_ID), None
+        )
+        self.assertEquals(200, code)
+        self.assertEquals({"type": ["m.*"]}, response)
+
+    @defer.inlineCallbacks
+    def test_get_filter_no_id(self):
+        self._user_filters["apple"] = [
+            {"type": ["m.*"]}
+        ]
+
+        (code, response) = yield self.mock_resource.trigger("GET",
+            "/user/%s/filter/2" % (self.USER_ID), None
+        )
+        self.assertEquals(404, code)
+
+    @defer.inlineCallbacks
+    def test_get_filter_no_user(self):
+        (code, response) = yield self.mock_resource.trigger("GET",
+            "/user/%s/filter/0" % (self.USER_ID), None
+        )
+        self.assertEquals(404, code)
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index 84bfde7568..6f8bea2f61 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -53,7 +53,10 @@ class RegistrationStoreTestCase(unittest.TestCase):
         )
 
         self.assertEquals(
-            {"admin": 0, "device_id": None, "name": self.user_id},
+            {"admin": 0,
+             "device_id": None,
+             "name": self.user_id,
+             "token_id": 1},
             (yield self.store.get_user_by_token(self.tokens[0]))
         )
 
@@ -63,7 +66,10 @@ class RegistrationStoreTestCase(unittest.TestCase):
         yield self.store.add_access_token_to_user(self.user_id, self.tokens[1])
 
         self.assertEquals(
-            {"admin": 0, "device_id": None, "name": self.user_id},
+            {"admin": 0,
+             "device_id": None,
+             "name": self.user_id,
+             "token_id": 2},
             (yield self.store.get_user_by_token(self.tokens[1]))
         )
 
diff --git a/tests/test_state.py b/tests/test_state.py
index 98ad9e54cd..019e794aa2 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -16,11 +16,120 @@
 from tests import unittest
 from twisted.internet import defer
 
+from synapse.events import FrozenEvent
+from synapse.api.auth import Auth
+from synapse.api.constants import EventTypes, Membership
 from synapse.state import StateHandler
 
 from mock import Mock
 
 
+_next_event_id = 1000
+
+
+def create_event(name=None, type=None, state_key=None, depth=2, event_id=None,
+                 prev_events=[], **kwargs):
+    global _next_event_id
+
+    if not event_id:
+        _next_event_id += 1
+        event_id = str(_next_event_id)
+
+    if not name:
+        if state_key is not None:
+            name = "<%s-%s, %s>" % (type, state_key, event_id,)
+        else:
+            name = "<%s, %s>" % (type, event_id,)
+
+    d = {
+        "event_id": event_id,
+        "type": type,
+        "sender": "@user_id:example.com",
+        "room_id": "!room_id:example.com",
+        "depth": depth,
+        "prev_events": prev_events,
+    }
+
+    if state_key is not None:
+        d["state_key"] = state_key
+
+    d.update(kwargs)
+
+    event = FrozenEvent(d)
+
+    return event
+
+
+class StateGroupStore(object):
+    def __init__(self):
+        self._event_to_state_group = {}
+        self._group_to_state = {}
+
+        self._next_group = 1
+
+    def get_state_groups(self, event_ids):
+        groups = {}
+        for event_id in event_ids:
+            group = self._event_to_state_group.get(event_id)
+            if group:
+                groups[group] = self._group_to_state[group]
+
+        return defer.succeed(groups)
+
+    def store_state_groups(self, event, context):
+        if context.current_state is None:
+            return
+
+        state_events = context.current_state
+
+        if event.is_state():
+            state_events[(event.type, event.state_key)] = event
+
+        state_group = context.state_group
+        if not state_group:
+            state_group = self._next_group
+            self._next_group += 1
+
+            self._group_to_state[state_group] = state_events.values()
+
+        self._event_to_state_group[event.event_id] = state_group
+
+
+class DictObj(dict):
+    def __init__(self, **kwargs):
+        super(DictObj, self).__init__(kwargs)
+        self.__dict__ = self
+
+
+class Graph(object):
+    def __init__(self, nodes, edges):
+        events = {}
+        clobbered = set(events.keys())
+
+        for event_id, fields in nodes.items():
+            refs = edges.get(event_id)
+            if refs:
+                clobbered.difference_update(refs)
+                prev_events = [(r, {}) for r in refs]
+            else:
+                prev_events = []
+
+            events[event_id] = create_event(
+                event_id=event_id,
+                prev_events=prev_events,
+                **fields
+            )
+
+        self._leaves = clobbered
+        self._events = sorted(events.values(), key=lambda e: e.depth)
+
+    def walk(self):
+        return iter(self._events)
+
+    def get_leaves(self):
+        return (self._events[i] for i in self._leaves)
+
+
 class StateTestCase(unittest.TestCase):
     def setUp(self):
         self.store = Mock(
@@ -29,20 +138,188 @@ class StateTestCase(unittest.TestCase):
                 "add_event_hashes",
             ]
         )
-        hs = Mock(spec=["get_datastore"])
+        hs = Mock(spec=["get_datastore", "get_auth", "get_state_handler"])
         hs.get_datastore.return_value = self.store
+        hs.get_state_handler.return_value = None
+        hs.get_auth.return_value = Auth(hs)
 
         self.state = StateHandler(hs)
         self.event_id = 0
 
     @defer.inlineCallbacks
+    def test_branch_no_conflict(self):
+        graph = Graph(
+            nodes={
+                "START": DictObj(
+                    type=EventTypes.Create,
+                    state_key="",
+                    depth=1,
+                ),
+                "A": DictObj(
+                    type=EventTypes.Message,
+                    depth=2,
+                ),
+                "B": DictObj(
+                    type=EventTypes.Message,
+                    depth=3,
+                ),
+                "C": DictObj(
+                    type=EventTypes.Name,
+                    state_key="",
+                    depth=3,
+                ),
+                "D": DictObj(
+                    type=EventTypes.Message,
+                    depth=4,
+                ),
+            },
+            edges={
+                "A": ["START"],
+                "B": ["A"],
+                "C": ["A"],
+                "D": ["B", "C"]
+            }
+        )
+
+        store = StateGroupStore()
+        self.store.get_state_groups.side_effect = store.get_state_groups
+
+        context_store = {}
+
+        for event in graph.walk():
+            context = yield self.state.compute_event_context(event)
+            store.store_state_groups(event, context)
+            context_store[event.event_id] = context
+
+        self.assertEqual(2, len(context_store["D"].current_state))
+
+    @defer.inlineCallbacks
+    def test_branch_basic_conflict(self):
+        graph = Graph(
+            nodes={
+                "START": DictObj(
+                    type=EventTypes.Create,
+                    state_key="creator",
+                    content={"membership": "@user_id:example.com"},
+                    depth=1,
+                ),
+                "A": DictObj(
+                    type=EventTypes.Member,
+                    state_key="@user_id:example.com",
+                    content={"membership": Membership.JOIN},
+                    membership=Membership.JOIN,
+                    depth=2,
+                ),
+                "B": DictObj(
+                    type=EventTypes.Name,
+                    state_key="",
+                    depth=3,
+                ),
+                "C": DictObj(
+                    type=EventTypes.Name,
+                    state_key="",
+                    depth=4,
+                ),
+                "D": DictObj(
+                    type=EventTypes.Message,
+                    depth=5,
+                ),
+            },
+            edges={
+                "A": ["START"],
+                "B": ["A"],
+                "C": ["A"],
+                "D": ["B", "C"]
+            }
+        )
+
+        store = StateGroupStore()
+        self.store.get_state_groups.side_effect = store.get_state_groups
+
+        context_store = {}
+
+        for event in graph.walk():
+            context = yield self.state.compute_event_context(event)
+            store.store_state_groups(event, context)
+            context_store[event.event_id] = context
+
+        self.assertSetEqual(
+            {"START", "A", "C"},
+            {e.event_id for e in context_store["D"].current_state.values()}
+        )
+
+    @defer.inlineCallbacks
+    def test_branch_have_banned_conflict(self):
+        graph = Graph(
+            nodes={
+                "START": DictObj(
+                    type=EventTypes.Create,
+                    state_key="creator",
+                    content={"membership": "@user_id:example.com"},
+                    depth=1,
+                ),
+                "A": DictObj(
+                    type=EventTypes.Member,
+                    state_key="@user_id:example.com",
+                    content={"membership": Membership.JOIN},
+                    membership=Membership.JOIN,
+                    depth=2,
+                ),
+                "B": DictObj(
+                    type=EventTypes.Name,
+                    state_key="",
+                    depth=3,
+                ),
+                "C": DictObj(
+                    type=EventTypes.Member,
+                    state_key="@user_id_2:example.com",
+                    content={"membership": Membership.BAN},
+                    membership=Membership.BAN,
+                    depth=4,
+                ),
+                "D": DictObj(
+                    type=EventTypes.Name,
+                    state_key="",
+                    depth=4,
+                    sender="@user_id_2:example.com",
+                ),
+                "E": DictObj(
+                    type=EventTypes.Message,
+                    depth=5,
+                ),
+            },
+            edges={
+                "A": ["START"],
+                "B": ["A"],
+                "C": ["B"],
+                "D": ["B"],
+                "E": ["C", "D"]
+            }
+        )
+
+        store = StateGroupStore()
+        self.store.get_state_groups.side_effect = store.get_state_groups
+
+        context_store = {}
+
+        for event in graph.walk():
+            context = yield self.state.compute_event_context(event)
+            store.store_state_groups(event, context)
+            context_store[event.event_id] = context
+
+        self.assertSetEqual(
+            {"START", "A", "B", "C"},
+            {e.event_id for e in context_store["E"].current_state.values()}
+        )
+
+    @defer.inlineCallbacks
     def test_annotate_with_old_message(self):
-        event = self.create_event(type="test_message", name="event")
+        event = create_event(type="test_message", name="event")
 
         old_state = [
-            self.create_event(type="test1", state_key="1"),
-            self.create_event(type="test1", state_key="2"),
-            self.create_event(type="test2", state_key=""),
+            create_event(type="test1", state_key="1"),
+            create_event(type="test1", state_key="2"),
+            create_event(type="test2", state_key=""),
         ]
 
         context = yield self.state.compute_event_context(
@@ -62,12 +339,12 @@ class StateTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_annotate_with_old_state(self):
-        event = self.create_event(type="state", state_key="", name="event")
+        event = create_event(type="state", state_key="", name="event")
 
         old_state = [
-            self.create_event(type="test1", state_key="1"),
-            self.create_event(type="test1", state_key="2"),
-            self.create_event(type="test2", state_key=""),
+            create_event(type="test1", state_key="1"),
+            create_event(type="test1", state_key="2"),
+            create_event(type="test2", state_key=""),
         ]
 
         context = yield self.state.compute_event_context(
@@ -88,13 +365,12 @@ class StateTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_trivial_annotate_message(self):
-        event = self.create_event(type="test_message", name="event")
-        event.prev_events = []
+        event = create_event(type="test_message", name="event")
 
         old_state = [
-            self.create_event(type="test1", state_key="1"),
-            self.create_event(type="test1", state_key="2"),
-            self.create_event(type="test2", state_key=""),
+            create_event(type="test1", state_key="1"),
+            create_event(type="test1", state_key="2"),
+            create_event(type="test2", state_key=""),
         ]
 
         group_name = "group_name_1"
@@ -119,13 +395,12 @@ class StateTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_trivial_annotate_state(self):
-        event = self.create_event(type="state", state_key="", name="event")
-        event.prev_events = []
+        event = create_event(type="state", state_key="", name="event")
 
         old_state = [
-            self.create_event(type="test1", state_key="1"),
-            self.create_event(type="test1", state_key="2"),
-            self.create_event(type="test2", state_key=""),
+            create_event(type="test1", state_key="1"),
+            create_event(type="test1", state_key="2"),
+            create_event(type="test2", state_key=""),
         ]
 
         group_name = "group_name_1"
@@ -150,30 +425,21 @@ class StateTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_resolve_message_conflict(self):
-        event = self.create_event(type="test_message", name="event")
-        event.prev_events = []
+        event = create_event(type="test_message", name="event")
 
         old_state_1 = [
-            self.create_event(type="test1", state_key="1"),
-            self.create_event(type="test1", state_key="2"),
-            self.create_event(type="test2", state_key=""),
+            create_event(type="test1", state_key="1"),
+            create_event(type="test1", state_key="2"),
+            create_event(type="test2", state_key=""),
         ]
 
         old_state_2 = [
-            self.create_event(type="test1", state_key="1"),
-            self.create_event(type="test3", state_key="2"),
-            self.create_event(type="test4", state_key=""),
+            create_event(type="test1", state_key="1"),
+            create_event(type="test3", state_key="2"),
+            create_event(type="test4", state_key=""),
         ]
 
-        group_name_1 = "group_name_1"
-        group_name_2 = "group_name_2"
-
-        self.store.get_state_groups.return_value = {
-            group_name_1: old_state_1,
-            group_name_2: old_state_2,
-        }
-
-        context = yield self.state.compute_event_context(event)
+        context = yield self._get_context(event, old_state_1, old_state_2)
 
         self.assertEqual(len(context.current_state), 5)
 
@@ -181,56 +447,76 @@ class StateTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_resolve_state_conflict(self):
-        event = self.create_event(type="test4", state_key="", name="event")
-        event.prev_events = []
+        event = create_event(type="test4", state_key="", name="event")
 
         old_state_1 = [
-            self.create_event(type="test1", state_key="1"),
-            self.create_event(type="test1", state_key="2"),
-            self.create_event(type="test2", state_key=""),
+            create_event(type="test1", state_key="1"),
+            create_event(type="test1", state_key="2"),
+            create_event(type="test2", state_key=""),
         ]
 
         old_state_2 = [
-            self.create_event(type="test1", state_key="1"),
-            self.create_event(type="test3", state_key="2"),
-            self.create_event(type="test4", state_key=""),
+            create_event(type="test1", state_key="1"),
+            create_event(type="test3", state_key="2"),
+            create_event(type="test4", state_key=""),
         ]
 
-        group_name_1 = "group_name_1"
-        group_name_2 = "group_name_2"
-
-        self.store.get_state_groups.return_value = {
-            group_name_1: old_state_1,
-            group_name_2: old_state_2,
-        }
-
-        context = yield self.state.compute_event_context(event)
+        context = yield self._get_context(event, old_state_1, old_state_2)
 
         self.assertEqual(len(context.current_state), 5)
 
         self.assertIsNone(context.state_group)
 
-    def create_event(self, name=None, type=None, state_key=None):
-        self.event_id += 1
-        event_id = str(self.event_id)
+    @defer.inlineCallbacks
+    def test_standard_depth_conflict(self):
+        event = create_event(type="test4", name="event")
+
+        member_event = create_event(
+            type=EventTypes.Member,
+            state_key="@user_id:example.com",
+            content={
+                "membership": Membership.JOIN,
+            }
+        )
 
-        if not name:
-            if state_key is not None:
-                name = "<%s-%s>" % (type, state_key)
-            else:
-                name = "<%s>" % (type, )
+        old_state_1 = [
+            member_event,
+            create_event(type="test1", state_key="1", depth=1),
+        ]
+
+        old_state_2 = [
+            member_event,
+            create_event(type="test1", state_key="1", depth=2),
+        ]
 
-        event = Mock(name=name, spec=[])
-        event.type = type
+        context = yield self._get_context(event, old_state_1, old_state_2)
 
-        if state_key is not None:
-            event.state_key = state_key
-        event.event_id = event_id
+        self.assertEqual(old_state_2[1], context.current_state[("test1", "1")])
+
+        # Reverse the depth to make sure we are actually using the depths
+        # during state resolution.
+
+        old_state_1 = [
+            member_event,
+            create_event(type="test1", state_key="1", depth=2),
+        ]
+
+        old_state_2 = [
+            member_event,
+            create_event(type="test1", state_key="1", depth=1),
+        ]
+
+        context = yield self._get_context(event, old_state_1, old_state_2)
+
+        self.assertEqual(old_state_1[1], context.current_state[("test1", "1")])
 
-        event.is_state = lambda: (state_key is not None)
-        event.unsigned = {}
+    def _get_context(self, event, old_state_1, old_state_2):
+        group_name_1 = "group_name_1"
+        group_name_2 = "group_name_2"
 
-        event.user_id = "@user_id:example.com"
-        event.room_id = "!room_id:example.com"
+        self.store.get_state_groups.return_value = {
+            group_name_1: old_state_1,
+            group_name_2: old_state_2,
+        }
 
-        return event
+        return self.state.compute_event_context(event)