diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py
new file mode 100644
index 0000000000..babf4c37f1
--- /dev/null
+++ b/tests/api/test_filtering.py
@@ -0,0 +1,512 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections import namedtuple
+from tests import unittest
+from twisted.internet import defer
+
+from mock import Mock, NonCallableMock
+from tests.utils import (
+ MockHttpResource, MockClock, DeferredMockCallable, SQLiteMemoryDbPool,
+ MockKey
+)
+
+from synapse.server import HomeServer
+from synapse.types import UserID
+from synapse.api.filtering import Filter
+
+user_localpart = "test_user"
+MockEvent = namedtuple("MockEvent", "sender type room_id")
+
+class FilteringTestCase(unittest.TestCase):
+
+ @defer.inlineCallbacks
+ def setUp(self):
+ db_pool = SQLiteMemoryDbPool()
+ yield db_pool.prepare()
+
+ self.mock_config = NonCallableMock()
+ self.mock_config.signing_key = [MockKey()]
+
+ self.mock_federation_resource = MockHttpResource()
+
+ self.mock_http_client = Mock(spec=[])
+ self.mock_http_client.put_json = DeferredMockCallable()
+
+ hs = HomeServer("test",
+ db_pool=db_pool,
+ handlers=None,
+ http_client=self.mock_http_client,
+ config=self.mock_config,
+ keyring=Mock(),
+ )
+
+ self.filtering = hs.get_filtering()
+ self.filter = Filter({})
+
+ self.datastore = hs.get_datastore()
+
+ def test_definition_types_works_with_literals(self):
+ definition = {
+ "types": ["m.room.message", "org.matrix.foo.bar"]
+ }
+ event = MockEvent(
+ sender="@foo:bar",
+ type="m.room.message",
+ room_id="!foo:bar"
+ )
+ self.assertTrue(
+ self.filter._passes_definition(definition, event)
+ )
+
+ def test_definition_types_works_with_wildcards(self):
+ definition = {
+ "types": ["m.*", "org.matrix.foo.bar"]
+ }
+ event = MockEvent(
+ sender="@foo:bar",
+ type="m.room.message",
+ room_id="!foo:bar"
+ )
+ self.assertTrue(
+ self.filter._passes_definition(definition, event)
+ )
+
+ def test_definition_types_works_with_unknowns(self):
+ definition = {
+ "types": ["m.room.message", "org.matrix.foo.bar"]
+ }
+ event = MockEvent(
+ sender="@foo:bar",
+ type="now.for.something.completely.different",
+ room_id="!foo:bar"
+ )
+ self.assertFalse(
+ self.filter._passes_definition(definition, event)
+ )
+
+ def test_definition_not_types_works_with_literals(self):
+ definition = {
+ "not_types": ["m.room.message", "org.matrix.foo.bar"]
+ }
+ event = MockEvent(
+ sender="@foo:bar",
+ type="m.room.message",
+ room_id="!foo:bar"
+ )
+ self.assertFalse(
+ self.filter._passes_definition(definition, event)
+ )
+
+ def test_definition_not_types_works_with_wildcards(self):
+ definition = {
+ "not_types": ["m.room.message", "org.matrix.*"]
+ }
+ event = MockEvent(
+ sender="@foo:bar",
+ type="org.matrix.custom.event",
+ room_id="!foo:bar"
+ )
+ self.assertFalse(
+ self.filter._passes_definition(definition, event)
+ )
+
+ def test_definition_not_types_works_with_unknowns(self):
+ definition = {
+ "not_types": ["m.*", "org.*"]
+ }
+ event = MockEvent(
+ sender="@foo:bar",
+ type="com.nom.nom.nom",
+ room_id="!foo:bar"
+ )
+ self.assertTrue(
+ self.filter._passes_definition(definition, event)
+ )
+
+ def test_definition_not_types_takes_priority_over_types(self):
+ definition = {
+ "not_types": ["m.*", "org.*"],
+ "types": ["m.room.message", "m.room.topic"]
+ }
+ event = MockEvent(
+ sender="@foo:bar",
+ type="m.room.topic",
+ room_id="!foo:bar"
+ )
+ self.assertFalse(
+ self.filter._passes_definition(definition, event)
+ )
+
+ def test_definition_senders_works_with_literals(self):
+ definition = {
+ "senders": ["@flibble:wibble"]
+ }
+ event = MockEvent(
+ sender="@flibble:wibble",
+ type="com.nom.nom.nom",
+ room_id="!foo:bar"
+ )
+ self.assertTrue(
+ self.filter._passes_definition(definition, event)
+ )
+
+ def test_definition_senders_works_with_unknowns(self):
+ definition = {
+ "senders": ["@flibble:wibble"]
+ }
+ event = MockEvent(
+ sender="@challenger:appears",
+ type="com.nom.nom.nom",
+ room_id="!foo:bar"
+ )
+ self.assertFalse(
+ self.filter._passes_definition(definition, event)
+ )
+
+ def test_definition_not_senders_works_with_literals(self):
+ definition = {
+ "not_senders": ["@flibble:wibble"]
+ }
+ event = MockEvent(
+ sender="@flibble:wibble",
+ type="com.nom.nom.nom",
+ room_id="!foo:bar"
+ )
+ self.assertFalse(
+ self.filter._passes_definition(definition, event)
+ )
+
+ def test_definition_not_senders_works_with_unknowns(self):
+ definition = {
+ "not_senders": ["@flibble:wibble"]
+ }
+ event = MockEvent(
+ sender="@challenger:appears",
+ type="com.nom.nom.nom",
+ room_id="!foo:bar"
+ )
+ self.assertTrue(
+ self.filter._passes_definition(definition, event)
+ )
+
+ def test_definition_not_senders_takes_priority_over_senders(self):
+ definition = {
+ "not_senders": ["@misspiggy:muppets"],
+ "senders": ["@kermit:muppets", "@misspiggy:muppets"]
+ }
+ event = MockEvent(
+ sender="@misspiggy:muppets",
+ type="m.room.topic",
+ room_id="!foo:bar"
+ )
+ self.assertFalse(
+ self.filter._passes_definition(definition, event)
+ )
+
+ def test_definition_rooms_works_with_literals(self):
+ definition = {
+ "rooms": ["!secretbase:unknown"]
+ }
+ event = MockEvent(
+ sender="@foo:bar",
+ type="m.room.message",
+ room_id="!secretbase:unknown"
+ )
+ self.assertTrue(
+ self.filter._passes_definition(definition, event)
+ )
+
+ def test_definition_rooms_works_with_unknowns(self):
+ definition = {
+ "rooms": ["!secretbase:unknown"]
+ }
+ event = MockEvent(
+ sender="@foo:bar",
+ type="m.room.message",
+ room_id="!anothersecretbase:unknown"
+ )
+ self.assertFalse(
+ self.filter._passes_definition(definition, event)
+ )
+
+ def test_definition_not_rooms_works_with_literals(self):
+ definition = {
+ "not_rooms": ["!anothersecretbase:unknown"]
+ }
+ event = MockEvent(
+ sender="@foo:bar",
+ type="m.room.message",
+ room_id="!anothersecretbase:unknown"
+ )
+ self.assertFalse(
+ self.filter._passes_definition(definition, event)
+ )
+
+ def test_definition_not_rooms_works_with_unknowns(self):
+ definition = {
+ "not_rooms": ["!secretbase:unknown"]
+ }
+ event = MockEvent(
+ sender="@foo:bar",
+ type="m.room.message",
+ room_id="!anothersecretbase:unknown"
+ )
+ self.assertTrue(
+ self.filter._passes_definition(definition, event)
+ )
+
+ def test_definition_not_rooms_takes_priority_over_rooms(self):
+ definition = {
+ "not_rooms": ["!secretbase:unknown"],
+ "rooms": ["!secretbase:unknown"]
+ }
+ event = MockEvent(
+ sender="@foo:bar",
+ type="m.room.message",
+ room_id="!secretbase:unknown"
+ )
+ self.assertFalse(
+ self.filter._passes_definition(definition, event)
+ )
+
+ def test_definition_combined_event(self):
+ definition = {
+ "not_senders": ["@misspiggy:muppets"],
+ "senders": ["@kermit:muppets"],
+ "rooms": ["!stage:unknown"],
+ "not_rooms": ["!piggyshouse:muppets"],
+ "types": ["m.room.message", "muppets.kermit.*"],
+ "not_types": ["muppets.misspiggy.*"]
+ }
+ event = MockEvent(
+ sender="@kermit:muppets", # yup
+ type="m.room.message", # yup
+ room_id="!stage:unknown" # yup
+ )
+ self.assertTrue(
+ self.filter._passes_definition(definition, event)
+ )
+
+ def test_definition_combined_event_bad_sender(self):
+ definition = {
+ "not_senders": ["@misspiggy:muppets"],
+ "senders": ["@kermit:muppets"],
+ "rooms": ["!stage:unknown"],
+ "not_rooms": ["!piggyshouse:muppets"],
+ "types": ["m.room.message", "muppets.kermit.*"],
+ "not_types": ["muppets.misspiggy.*"]
+ }
+ event = MockEvent(
+ sender="@misspiggy:muppets", # nope
+ type="m.room.message", # yup
+ room_id="!stage:unknown" # yup
+ )
+ self.assertFalse(
+ self.filter._passes_definition(definition, event)
+ )
+
+ def test_definition_combined_event_bad_room(self):
+ definition = {
+ "not_senders": ["@misspiggy:muppets"],
+ "senders": ["@kermit:muppets"],
+ "rooms": ["!stage:unknown"],
+ "not_rooms": ["!piggyshouse:muppets"],
+ "types": ["m.room.message", "muppets.kermit.*"],
+ "not_types": ["muppets.misspiggy.*"]
+ }
+ event = MockEvent(
+ sender="@kermit:muppets", # yup
+ type="m.room.message", # yup
+ room_id="!piggyshouse:muppets" # nope
+ )
+ self.assertFalse(
+ self.filter._passes_definition(definition, event)
+ )
+
+ def test_definition_combined_event_bad_type(self):
+ definition = {
+ "not_senders": ["@misspiggy:muppets"],
+ "senders": ["@kermit:muppets"],
+ "rooms": ["!stage:unknown"],
+ "not_rooms": ["!piggyshouse:muppets"],
+ "types": ["m.room.message", "muppets.kermit.*"],
+ "not_types": ["muppets.misspiggy.*"]
+ }
+ event = MockEvent(
+ sender="@kermit:muppets", # yup
+ type="muppets.misspiggy.kisses", # nope
+ room_id="!stage:unknown" # yup
+ )
+ self.assertFalse(
+ self.filter._passes_definition(definition, event)
+ )
+
+ @defer.inlineCallbacks
+ def test_filter_public_user_data_match(self):
+ user_filter_json = {
+ "public_user_data": {
+ "types": ["m.*"]
+ }
+ }
+ user = UserID.from_string("@" + user_localpart + ":test")
+ filter_id = yield self.datastore.add_user_filter(
+ user_localpart=user_localpart,
+ user_filter=user_filter_json,
+ )
+ event = MockEvent(
+ sender="@foo:bar",
+ type="m.profile",
+ room_id="!foo:bar"
+ )
+ events = [event]
+
+ user_filter = yield self.filtering.get_user_filter(
+ user_localpart=user_localpart,
+ filter_id=filter_id,
+ )
+
+ results = user_filter.filter_public_user_data(events=events)
+ self.assertEquals(events, results)
+
+ @defer.inlineCallbacks
+ def test_filter_public_user_data_no_match(self):
+ user_filter_json = {
+ "public_user_data": {
+ "types": ["m.*"]
+ }
+ }
+ user = UserID.from_string("@" + user_localpart + ":test")
+ filter_id = yield self.datastore.add_user_filter(
+ user_localpart=user_localpart,
+ user_filter=user_filter_json,
+ )
+ event = MockEvent(
+ sender="@foo:bar",
+ type="custom.avatar.3d.crazy",
+ room_id="!foo:bar"
+ )
+ events = [event]
+
+ user_filter = yield self.filtering.get_user_filter(
+ user_localpart=user_localpart,
+ filter_id=filter_id,
+ )
+
+ results = user_filter.filter_public_user_data(events=events)
+ self.assertEquals([], results)
+
+ @defer.inlineCallbacks
+ def test_filter_room_state_match(self):
+ user_filter_json = {
+ "room": {
+ "state": {
+ "types": ["m.*"]
+ }
+ }
+ }
+ user = UserID.from_string("@" + user_localpart + ":test")
+ filter_id = yield self.datastore.add_user_filter(
+ user_localpart=user_localpart,
+ user_filter=user_filter_json,
+ )
+ event = MockEvent(
+ sender="@foo:bar",
+ type="m.room.topic",
+ room_id="!foo:bar"
+ )
+ events = [event]
+
+ user_filter = yield self.filtering.get_user_filter(
+ user_localpart=user_localpart,
+ filter_id=filter_id,
+ )
+
+ results = user_filter.filter_room_state(events=events)
+ self.assertEquals(events, results)
+
+ @defer.inlineCallbacks
+ def test_filter_room_state_no_match(self):
+ user_filter_json = {
+ "room": {
+ "state": {
+ "types": ["m.*"]
+ }
+ }
+ }
+ user = UserID.from_string("@" + user_localpart + ":test")
+ filter_id = yield self.datastore.add_user_filter(
+ user_localpart=user_localpart,
+ user_filter=user_filter_json,
+ )
+ event = MockEvent(
+ sender="@foo:bar",
+ type="org.matrix.custom.event",
+ room_id="!foo:bar"
+ )
+ events = [event]
+
+ user_filter = yield self.filtering.get_user_filter(
+ user_localpart=user_localpart,
+ filter_id=filter_id,
+ )
+
+ results = user_filter.filter_room_state(events)
+ self.assertEquals([], results)
+
+ @defer.inlineCallbacks
+ def test_add_filter(self):
+ user_filter_json = {
+ "room": {
+ "state": {
+ "types": ["m.*"]
+ }
+ }
+ }
+
+ filter_id = yield self.filtering.add_user_filter(
+ user_localpart=user_localpart,
+ user_filter=user_filter_json,
+ )
+
+ self.assertEquals(filter_id, 0)
+ self.assertEquals(user_filter_json,
+ (yield self.datastore.get_user_filter(
+ user_localpart=user_localpart,
+ filter_id=0,
+ ))
+ )
+
+ @defer.inlineCallbacks
+ def test_get_filter(self):
+ user_filter_json = {
+ "room": {
+ "state": {
+ "types": ["m.*"]
+ }
+ }
+ }
+
+ filter_id = yield self.datastore.add_user_filter(
+ user_localpart=user_localpart,
+ user_filter=user_filter_json,
+ )
+
+ filter = yield self.filtering.get_user_filter(
+ user_localpart=user_localpart,
+ filter_id=filter_id,
+ )
+
+ self.assertEquals(filter.filter_json, user_filter_json)
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index ed21defd13..44dbce6bea 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -52,6 +52,7 @@ class FederationTestCase(unittest.TestCase):
"get_room",
"get_destination_retry_timings",
"set_destination_retry_timings",
+ "have_events",
]),
resource_for_federation=NonCallableMock(),
http_client=NonCallableMock(spec_set=[]),
@@ -90,6 +91,7 @@ class FederationTestCase(unittest.TestCase):
self.datastore.persist_event.return_value = defer.succeed(None)
self.datastore.get_room.return_value = defer.succeed(True)
self.auth.check_host_in_room.return_value = defer.succeed(True)
+ self.datastore.have_events.return_value = defer.succeed({})
def annotate(ev, old_state=None):
context = Mock()
diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py
index 65d5cc4916..f849120a3e 100644
--- a/tests/rest/client/v1/test_presence.py
+++ b/tests/rest/client/v1/test_presence.py
@@ -75,6 +75,7 @@ class PresenceStateTestCase(unittest.TestCase):
"user": UserID.from_string(myid),
"admin": False,
"device_id": None,
+ "token_id": 1,
}
hs.get_auth().get_user_by_token = _get_user_by_token
@@ -165,6 +166,7 @@ class PresenceListTestCase(unittest.TestCase):
"user": UserID.from_string(myid),
"admin": False,
"device_id": None,
+ "token_id": 1,
}
hs.handlers.room_member_handler = Mock(
@@ -282,7 +284,7 @@ class PresenceEventStreamTestCase(unittest.TestCase):
hs.get_clock().time_msec.return_value = 1000000
def _get_user_by_req(req=None):
- return UserID.from_string(myid)
+ return (UserID.from_string(myid), "")
hs.get_auth().get_user_by_req = _get_user_by_req
diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py
index 39cd68d829..6a2085276a 100644
--- a/tests/rest/client/v1/test_profile.py
+++ b/tests/rest/client/v1/test_profile.py
@@ -58,7 +58,7 @@ class ProfileTestCase(unittest.TestCase):
)
def _get_user_by_req(request=None):
- return UserID.from_string(myid)
+ return (UserID.from_string(myid), "")
hs.get_auth().get_user_by_req = _get_user_by_req
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index 76ed550b75..81ead10e76 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -70,6 +70,7 @@ class RoomPermissionsTestCase(RestTestCase):
"user": UserID.from_string(self.auth_user_id),
"admin": False,
"device_id": None,
+ "token_id": 1,
}
hs.get_auth().get_user_by_token = _get_user_by_token
@@ -466,6 +467,7 @@ class RoomsMemberListTestCase(RestTestCase):
"user": UserID.from_string(self.auth_user_id),
"admin": False,
"device_id": None,
+ "token_id": 1,
}
hs.get_auth().get_user_by_token = _get_user_by_token
@@ -555,6 +557,7 @@ class RoomsCreateTestCase(RestTestCase):
"user": UserID.from_string(self.auth_user_id),
"admin": False,
"device_id": None,
+ "token_id": 1,
}
hs.get_auth().get_user_by_token = _get_user_by_token
@@ -657,6 +660,7 @@ class RoomTopicTestCase(RestTestCase):
"user": UserID.from_string(self.auth_user_id),
"admin": False,
"device_id": None,
+ "token_id": 1,
}
hs.get_auth().get_user_by_token = _get_user_by_token
@@ -773,6 +777,7 @@ class RoomMemberStateTestCase(RestTestCase):
"user": UserID.from_string(self.auth_user_id),
"admin": False,
"device_id": None,
+ "token_id": 1,
}
hs.get_auth().get_user_by_token = _get_user_by_token
@@ -909,6 +914,7 @@ class RoomMessagesTestCase(RestTestCase):
"user": UserID.from_string(self.auth_user_id),
"admin": False,
"device_id": None,
+ "token_id": 1,
}
hs.get_auth().get_user_by_token = _get_user_by_token
@@ -1013,6 +1019,7 @@ class RoomInitialSyncTestCase(RestTestCase):
"user": UserID.from_string(self.auth_user_id),
"admin": False,
"device_id": None,
+ "token_id": 1,
}
hs.get_auth().get_user_by_token = _get_user_by_token
diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py
index c89b37d004..c5d5b06da3 100644
--- a/tests/rest/client/v1/test_typing.py
+++ b/tests/rest/client/v1/test_typing.py
@@ -73,6 +73,7 @@ class RoomTypingTestCase(RestTestCase):
"user": UserID.from_string(self.auth_user_id),
"admin": False,
"device_id": None,
+ "token_id": 1,
}
hs.get_auth().get_user_by_token = _get_user_by_token
diff --git a/tests/rest/client/v2_alpha/__init__.py b/tests/rest/client/v2_alpha/__init__.py
index f59745e13c..fa70575c57 100644
--- a/tests/rest/client/v2_alpha/__init__.py
+++ b/tests/rest/client/v2_alpha/__init__.py
@@ -39,9 +39,7 @@ class V2AlphaRestTestCase(unittest.TestCase):
hs = HomeServer("test",
db_pool=None,
- datastore=Mock(spec=[
- "insert_client_ip",
- ]),
+ datastore=self.make_datastore_mock(),
http_client=None,
resource_for_client=self.mock_resource,
resource_for_federation=self.mock_resource,
@@ -53,8 +51,14 @@ class V2AlphaRestTestCase(unittest.TestCase):
"user": UserID.from_string(self.USER_ID),
"admin": False,
"device_id": None,
+ "token_id": 1,
}
hs.get_auth().get_user_by_token = _get_user_by_token
for r in self.TO_REGISTER:
r.register_servlets(hs, self.mock_resource)
+
+ def make_datastore_mock(self):
+ return Mock(spec=[
+ "insert_client_ip",
+ ])
diff --git a/tests/rest/client/v2_alpha/test_filter.py b/tests/rest/client/v2_alpha/test_filter.py
new file mode 100644
index 0000000000..80ddabf818
--- /dev/null
+++ b/tests/rest/client/v2_alpha/test_filter.py
@@ -0,0 +1,95 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+from mock import Mock
+
+from . import V2AlphaRestTestCase
+
+from synapse.rest.client.v2_alpha import filter
+
+from synapse.api.errors import StoreError
+
+
+class FilterTestCase(V2AlphaRestTestCase):
+ USER_ID = "@apple:test"
+ TO_REGISTER = [filter]
+
+ def make_datastore_mock(self):
+ datastore = super(FilterTestCase, self).make_datastore_mock()
+
+ self._user_filters = {}
+
+ def add_user_filter(user_localpart, definition):
+ filters = self._user_filters.setdefault(user_localpart, [])
+ filter_id = len(filters)
+ filters.append(definition)
+ return defer.succeed(filter_id)
+ datastore.add_user_filter = add_user_filter
+
+ def get_user_filter(user_localpart, filter_id):
+ if user_localpart not in self._user_filters:
+ raise StoreError(404, "No user")
+ filters = self._user_filters[user_localpart]
+ if filter_id >= len(filters):
+ raise StoreError(404, "No filter")
+ return defer.succeed(filters[filter_id])
+ datastore.get_user_filter = get_user_filter
+
+ return datastore
+
+ @defer.inlineCallbacks
+ def test_add_filter(self):
+ (code, response) = yield self.mock_resource.trigger("POST",
+ "/user/%s/filter" % (self.USER_ID),
+ '{"type": ["m.*"]}'
+ )
+ self.assertEquals(200, code)
+ self.assertEquals({"filter_id": "0"}, response)
+
+ self.assertIn("apple", self._user_filters)
+ self.assertEquals(len(self._user_filters["apple"]), 1)
+ self.assertEquals({"type": ["m.*"]}, self._user_filters["apple"][0])
+
+ @defer.inlineCallbacks
+ def test_get_filter(self):
+ self._user_filters["apple"] = [
+ {"type": ["m.*"]}
+ ]
+
+ (code, response) = yield self.mock_resource.trigger("GET",
+ "/user/%s/filter/0" % (self.USER_ID), None
+ )
+ self.assertEquals(200, code)
+ self.assertEquals({"type": ["m.*"]}, response)
+
+ @defer.inlineCallbacks
+ def test_get_filter_no_id(self):
+ self._user_filters["apple"] = [
+ {"type": ["m.*"]}
+ ]
+
+ (code, response) = yield self.mock_resource.trigger("GET",
+ "/user/%s/filter/2" % (self.USER_ID), None
+ )
+ self.assertEquals(404, code)
+
+ @defer.inlineCallbacks
+ def test_get_filter_no_user(self):
+ (code, response) = yield self.mock_resource.trigger("GET",
+ "/user/%s/filter/0" % (self.USER_ID), None
+ )
+ self.assertEquals(404, code)
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index 84bfde7568..6f8bea2f61 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -53,7 +53,10 @@ class RegistrationStoreTestCase(unittest.TestCase):
)
self.assertEquals(
- {"admin": 0, "device_id": None, "name": self.user_id},
+ {"admin": 0,
+ "device_id": None,
+ "name": self.user_id,
+ "token_id": 1},
(yield self.store.get_user_by_token(self.tokens[0]))
)
@@ -63,7 +66,10 @@ class RegistrationStoreTestCase(unittest.TestCase):
yield self.store.add_access_token_to_user(self.user_id, self.tokens[1])
self.assertEquals(
- {"admin": 0, "device_id": None, "name": self.user_id},
+ {"admin": 0,
+ "device_id": None,
+ "name": self.user_id,
+ "token_id": 2},
(yield self.store.get_user_by_token(self.tokens[1]))
)
diff --git a/tests/test_state.py b/tests/test_state.py
index 98ad9e54cd..019e794aa2 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -16,11 +16,120 @@
from tests import unittest
from twisted.internet import defer
+from synapse.events import FrozenEvent
+from synapse.api.auth import Auth
+from synapse.api.constants import EventTypes, Membership
from synapse.state import StateHandler
from mock import Mock
+_next_event_id = 1000
+
+
+def create_event(name=None, type=None, state_key=None, depth=2, event_id=None,
+ prev_events=[], **kwargs):
+ global _next_event_id
+
+ if not event_id:
+ _next_event_id += 1
+ event_id = str(_next_event_id)
+
+ if not name:
+ if state_key is not None:
+ name = "<%s-%s, %s>" % (type, state_key, event_id,)
+ else:
+ name = "<%s, %s>" % (type, event_id,)
+
+ d = {
+ "event_id": event_id,
+ "type": type,
+ "sender": "@user_id:example.com",
+ "room_id": "!room_id:example.com",
+ "depth": depth,
+ "prev_events": prev_events,
+ }
+
+ if state_key is not None:
+ d["state_key"] = state_key
+
+ d.update(kwargs)
+
+ event = FrozenEvent(d)
+
+ return event
+
+
+class StateGroupStore(object):
+ def __init__(self):
+ self._event_to_state_group = {}
+ self._group_to_state = {}
+
+ self._next_group = 1
+
+ def get_state_groups(self, event_ids):
+ groups = {}
+ for event_id in event_ids:
+ group = self._event_to_state_group.get(event_id)
+ if group:
+ groups[group] = self._group_to_state[group]
+
+ return defer.succeed(groups)
+
+ def store_state_groups(self, event, context):
+ if context.current_state is None:
+ return
+
+ state_events = context.current_state
+
+ if event.is_state():
+ state_events[(event.type, event.state_key)] = event
+
+ state_group = context.state_group
+ if not state_group:
+ state_group = self._next_group
+ self._next_group += 1
+
+ self._group_to_state[state_group] = state_events.values()
+
+ self._event_to_state_group[event.event_id] = state_group
+
+
+class DictObj(dict):
+ def __init__(self, **kwargs):
+ super(DictObj, self).__init__(kwargs)
+ self.__dict__ = self
+
+
+class Graph(object):
+ def __init__(self, nodes, edges):
+ events = {}
+ clobbered = set(events.keys())
+
+ for event_id, fields in nodes.items():
+ refs = edges.get(event_id)
+ if refs:
+ clobbered.difference_update(refs)
+ prev_events = [(r, {}) for r in refs]
+ else:
+ prev_events = []
+
+ events[event_id] = create_event(
+ event_id=event_id,
+ prev_events=prev_events,
+ **fields
+ )
+
+ self._leaves = clobbered
+ self._events = sorted(events.values(), key=lambda e: e.depth)
+
+ def walk(self):
+ return iter(self._events)
+
+ def get_leaves(self):
+ return (self._events[i] for i in self._leaves)
+
+
class StateTestCase(unittest.TestCase):
def setUp(self):
self.store = Mock(
@@ -29,20 +138,188 @@ class StateTestCase(unittest.TestCase):
"add_event_hashes",
]
)
- hs = Mock(spec=["get_datastore"])
+ hs = Mock(spec=["get_datastore", "get_auth", "get_state_handler"])
hs.get_datastore.return_value = self.store
+ hs.get_state_handler.return_value = None
+ hs.get_auth.return_value = Auth(hs)
self.state = StateHandler(hs)
self.event_id = 0
@defer.inlineCallbacks
+ def test_branch_no_conflict(self):
+ graph = Graph(
+ nodes={
+ "START": DictObj(
+ type=EventTypes.Create,
+ state_key="",
+ depth=1,
+ ),
+ "A": DictObj(
+ type=EventTypes.Message,
+ depth=2,
+ ),
+ "B": DictObj(
+ type=EventTypes.Message,
+ depth=3,
+ ),
+ "C": DictObj(
+ type=EventTypes.Name,
+ state_key="",
+ depth=3,
+ ),
+ "D": DictObj(
+ type=EventTypes.Message,
+ depth=4,
+ ),
+ },
+ edges={
+ "A": ["START"],
+ "B": ["A"],
+ "C": ["A"],
+ "D": ["B", "C"]
+ }
+ )
+
+ store = StateGroupStore()
+ self.store.get_state_groups.side_effect = store.get_state_groups
+
+ context_store = {}
+
+ for event in graph.walk():
+ context = yield self.state.compute_event_context(event)
+ store.store_state_groups(event, context)
+ context_store[event.event_id] = context
+
+ self.assertEqual(2, len(context_store["D"].current_state))
+
+ @defer.inlineCallbacks
+ def test_branch_basic_conflict(self):
+ graph = Graph(
+ nodes={
+ "START": DictObj(
+ type=EventTypes.Create,
+ state_key="creator",
+ content={"membership": "@user_id:example.com"},
+ depth=1,
+ ),
+ "A": DictObj(
+ type=EventTypes.Member,
+ state_key="@user_id:example.com",
+ content={"membership": Membership.JOIN},
+ membership=Membership.JOIN,
+ depth=2,
+ ),
+ "B": DictObj(
+ type=EventTypes.Name,
+ state_key="",
+ depth=3,
+ ),
+ "C": DictObj(
+ type=EventTypes.Name,
+ state_key="",
+ depth=4,
+ ),
+ "D": DictObj(
+ type=EventTypes.Message,
+ depth=5,
+ ),
+ },
+ edges={
+ "A": ["START"],
+ "B": ["A"],
+ "C": ["A"],
+ "D": ["B", "C"]
+ }
+ )
+
+ store = StateGroupStore()
+ self.store.get_state_groups.side_effect = store.get_state_groups
+
+ context_store = {}
+
+ for event in graph.walk():
+ context = yield self.state.compute_event_context(event)
+ store.store_state_groups(event, context)
+ context_store[event.event_id] = context
+
+ self.assertSetEqual(
+ {"START", "A", "C"},
+ {e.event_id for e in context_store["D"].current_state.values()}
+ )
+
+ @defer.inlineCallbacks
+ def test_branch_have_banned_conflict(self):
+ graph = Graph(
+ nodes={
+ "START": DictObj(
+ type=EventTypes.Create,
+ state_key="creator",
+ content={"membership": "@user_id:example.com"},
+ depth=1,
+ ),
+ "A": DictObj(
+ type=EventTypes.Member,
+ state_key="@user_id:example.com",
+ content={"membership": Membership.JOIN},
+ membership=Membership.JOIN,
+ depth=2,
+ ),
+ "B": DictObj(
+ type=EventTypes.Name,
+ state_key="",
+ depth=3,
+ ),
+ "C": DictObj(
+ type=EventTypes.Member,
+ state_key="@user_id_2:example.com",
+ content={"membership": Membership.BAN},
+ membership=Membership.BAN,
+ depth=4,
+ ),
+ "D": DictObj(
+ type=EventTypes.Name,
+ state_key="",
+ depth=4,
+ sender="@user_id_2:example.com",
+ ),
+ "E": DictObj(
+ type=EventTypes.Message,
+ depth=5,
+ ),
+ },
+ edges={
+ "A": ["START"],
+ "B": ["A"],
+ "C": ["B"],
+ "D": ["B"],
+ "E": ["C", "D"]
+ }
+ )
+
+ store = StateGroupStore()
+ self.store.get_state_groups.side_effect = store.get_state_groups
+
+ context_store = {}
+
+ for event in graph.walk():
+ context = yield self.state.compute_event_context(event)
+ store.store_state_groups(event, context)
+ context_store[event.event_id] = context
+
+ self.assertSetEqual(
+ {"START", "A", "B", "C"},
+ {e.event_id for e in context_store["E"].current_state.values()}
+ )
+
+ @defer.inlineCallbacks
def test_annotate_with_old_message(self):
- event = self.create_event(type="test_message", name="event")
+ event = create_event(type="test_message", name="event")
old_state = [
- self.create_event(type="test1", state_key="1"),
- self.create_event(type="test1", state_key="2"),
- self.create_event(type="test2", state_key=""),
+ create_event(type="test1", state_key="1"),
+ create_event(type="test1", state_key="2"),
+ create_event(type="test2", state_key=""),
]
context = yield self.state.compute_event_context(
@@ -62,12 +339,12 @@ class StateTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_annotate_with_old_state(self):
- event = self.create_event(type="state", state_key="", name="event")
+ event = create_event(type="state", state_key="", name="event")
old_state = [
- self.create_event(type="test1", state_key="1"),
- self.create_event(type="test1", state_key="2"),
- self.create_event(type="test2", state_key=""),
+ create_event(type="test1", state_key="1"),
+ create_event(type="test1", state_key="2"),
+ create_event(type="test2", state_key=""),
]
context = yield self.state.compute_event_context(
@@ -88,13 +365,12 @@ class StateTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_trivial_annotate_message(self):
- event = self.create_event(type="test_message", name="event")
- event.prev_events = []
+ event = create_event(type="test_message", name="event")
old_state = [
- self.create_event(type="test1", state_key="1"),
- self.create_event(type="test1", state_key="2"),
- self.create_event(type="test2", state_key=""),
+ create_event(type="test1", state_key="1"),
+ create_event(type="test1", state_key="2"),
+ create_event(type="test2", state_key=""),
]
group_name = "group_name_1"
@@ -119,13 +395,12 @@ class StateTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_trivial_annotate_state(self):
- event = self.create_event(type="state", state_key="", name="event")
- event.prev_events = []
+ event = create_event(type="state", state_key="", name="event")
old_state = [
- self.create_event(type="test1", state_key="1"),
- self.create_event(type="test1", state_key="2"),
- self.create_event(type="test2", state_key=""),
+ create_event(type="test1", state_key="1"),
+ create_event(type="test1", state_key="2"),
+ create_event(type="test2", state_key=""),
]
group_name = "group_name_1"
@@ -150,30 +425,21 @@ class StateTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_resolve_message_conflict(self):
- event = self.create_event(type="test_message", name="event")
- event.prev_events = []
+ event = create_event(type="test_message", name="event")
old_state_1 = [
- self.create_event(type="test1", state_key="1"),
- self.create_event(type="test1", state_key="2"),
- self.create_event(type="test2", state_key=""),
+ create_event(type="test1", state_key="1"),
+ create_event(type="test1", state_key="2"),
+ create_event(type="test2", state_key=""),
]
old_state_2 = [
- self.create_event(type="test1", state_key="1"),
- self.create_event(type="test3", state_key="2"),
- self.create_event(type="test4", state_key=""),
+ create_event(type="test1", state_key="1"),
+ create_event(type="test3", state_key="2"),
+ create_event(type="test4", state_key=""),
]
- group_name_1 = "group_name_1"
- group_name_2 = "group_name_2"
-
- self.store.get_state_groups.return_value = {
- group_name_1: old_state_1,
- group_name_2: old_state_2,
- }
-
- context = yield self.state.compute_event_context(event)
+ context = yield self._get_context(event, old_state_1, old_state_2)
self.assertEqual(len(context.current_state), 5)
@@ -181,56 +447,76 @@ class StateTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_resolve_state_conflict(self):
- event = self.create_event(type="test4", state_key="", name="event")
- event.prev_events = []
+ event = create_event(type="test4", state_key="", name="event")
old_state_1 = [
- self.create_event(type="test1", state_key="1"),
- self.create_event(type="test1", state_key="2"),
- self.create_event(type="test2", state_key=""),
+ create_event(type="test1", state_key="1"),
+ create_event(type="test1", state_key="2"),
+ create_event(type="test2", state_key=""),
]
old_state_2 = [
- self.create_event(type="test1", state_key="1"),
- self.create_event(type="test3", state_key="2"),
- self.create_event(type="test4", state_key=""),
+ create_event(type="test1", state_key="1"),
+ create_event(type="test3", state_key="2"),
+ create_event(type="test4", state_key=""),
]
- group_name_1 = "group_name_1"
- group_name_2 = "group_name_2"
-
- self.store.get_state_groups.return_value = {
- group_name_1: old_state_1,
- group_name_2: old_state_2,
- }
-
- context = yield self.state.compute_event_context(event)
+ context = yield self._get_context(event, old_state_1, old_state_2)
self.assertEqual(len(context.current_state), 5)
self.assertIsNone(context.state_group)
- def create_event(self, name=None, type=None, state_key=None):
- self.event_id += 1
- event_id = str(self.event_id)
+ @defer.inlineCallbacks
+ def test_standard_depth_conflict(self):
+ event = create_event(type="test4", name="event")
+
+ member_event = create_event(
+ type=EventTypes.Member,
+ state_key="@user_id:example.com",
+ content={
+ "membership": Membership.JOIN,
+ }
+ )
- if not name:
- if state_key is not None:
- name = "<%s-%s>" % (type, state_key)
- else:
- name = "<%s>" % (type, )
+ old_state_1 = [
+ member_event,
+ create_event(type="test1", state_key="1", depth=1),
+ ]
+
+ old_state_2 = [
+ member_event,
+ create_event(type="test1", state_key="1", depth=2),
+ ]
- event = Mock(name=name, spec=[])
- event.type = type
+ context = yield self._get_context(event, old_state_1, old_state_2)
- if state_key is not None:
- event.state_key = state_key
- event.event_id = event_id
+ self.assertEqual(old_state_2[1], context.current_state[("test1", "1")])
+
+ # Reverse the depth to make sure we are actually using the depths
+ # during state resolution.
+
+ old_state_1 = [
+ member_event,
+ create_event(type="test1", state_key="1", depth=2),
+ ]
+
+ old_state_2 = [
+ member_event,
+ create_event(type="test1", state_key="1", depth=1),
+ ]
+
+ context = yield self._get_context(event, old_state_1, old_state_2)
+
+ self.assertEqual(old_state_1[1], context.current_state[("test1", "1")])
- event.is_state = lambda: (state_key is not None)
- event.unsigned = {}
+ def _get_context(self, event, old_state_1, old_state_2):
+ group_name_1 = "group_name_1"
+ group_name_2 = "group_name_2"
- event.user_id = "@user_id:example.com"
- event.room_id = "!room_id:example.com"
+ self.store.get_state_groups.return_value = {
+ group_name_1: old_state_1,
+ group_name_2: old_state_2,
+ }
- return event
+ return self.state.compute_event_context(event)
|