diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/api/test_auth.py | 145 | ||||
-rw-r--r-- | tests/rest/client/v1/test_presence.py | 4 | ||||
-rw-r--r-- | tests/rest/client/v1/test_rooms.py | 28 | ||||
-rw-r--r-- | tests/rest/client/v1/test_typing.py | 2 | ||||
-rw-r--r-- | tests/rest/client/v1/utils.py | 3 | ||||
-rw-r--r-- | tests/rest/client/v2_alpha/__init__.py | 2 | ||||
-rw-r--r-- | tests/storage/event_injector.py | 81 | ||||
-rw-r--r-- | tests/storage/test_events.py | 116 | ||||
-rw-r--r-- | tests/storage/test_room.py | 2 | ||||
-rw-r--r-- | tests/storage/test_stream.py | 68 | ||||
-rw-r--r-- | tests/test_state.py | 37 |
11 files changed, 403 insertions, 85 deletions
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index 22fc804331..c96273480d 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -19,17 +19,21 @@ from mock import Mock from synapse.api.auth import Auth from synapse.api.errors import AuthError +from synapse.types import UserID +from tests.utils import setup_test_homeserver + +import pymacaroons class AuthTestCase(unittest.TestCase): + @defer.inlineCallbacks def setUp(self): self.state_handler = Mock() self.store = Mock() - self.hs = Mock() + self.hs = yield setup_test_homeserver(handlers=None) self.hs.get_datastore = Mock(return_value=self.store) - self.hs.get_state_handler = Mock(return_value=self.state_handler) self.auth = Auth(self.hs) self.test_user = "@foo:bar" @@ -133,3 +137,140 @@ class AuthTestCase(unittest.TestCase): request.requestHeaders.getRawHeaders = Mock(return_value=[""]) d = self.auth.get_user_by_req(request) self.failureResultOf(d, AuthError) + + @defer.inlineCallbacks + def test_get_user_from_macaroon(self): + # TODO(danielwh): Remove this mock when we remove the + # get_user_by_access_token fallback. + self.store.get_user_by_access_token = Mock( + return_value={"name": "@baldrick:matrix.org"} + ) + + user_id = "@baldrick:matrix.org" + macaroon = pymacaroons.Macaroon( + location=self.hs.config.server_name, + identifier="key", + key=self.hs.config.macaroon_secret_key) + macaroon.add_first_party_caveat("gen = 1") + macaroon.add_first_party_caveat("type = access") + macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) + user_info = yield self.auth._get_user_from_macaroon(macaroon.serialize()) + user = user_info["user"] + self.assertEqual(UserID.from_string(user_id), user) + + @defer.inlineCallbacks + def test_get_user_from_macaroon_user_db_mismatch(self): + self.store.get_user_by_access_token = Mock( + return_value={"name": "@percy:matrix.org"} + ) + + user = "@baldrick:matrix.org" + macaroon = pymacaroons.Macaroon( + location=self.hs.config.server_name, + identifier="key", + key=self.hs.config.macaroon_secret_key) + macaroon.add_first_party_caveat("gen = 1") + macaroon.add_first_party_caveat("type = access") + macaroon.add_first_party_caveat("user_id = %s" % (user,)) + with self.assertRaises(AuthError) as cm: + yield self.auth._get_user_from_macaroon(macaroon.serialize()) + self.assertEqual(401, cm.exception.code) + self.assertIn("User mismatch", cm.exception.msg) + + @defer.inlineCallbacks + def test_get_user_from_macaroon_missing_caveat(self): + # TODO(danielwh): Remove this mock when we remove the + # get_user_by_access_token fallback. + self.store.get_user_by_access_token = Mock( + return_value={"name": "@baldrick:matrix.org"} + ) + + macaroon = pymacaroons.Macaroon( + location=self.hs.config.server_name, + identifier="key", + key=self.hs.config.macaroon_secret_key) + macaroon.add_first_party_caveat("gen = 1") + macaroon.add_first_party_caveat("type = access") + + with self.assertRaises(AuthError) as cm: + yield self.auth._get_user_from_macaroon(macaroon.serialize()) + self.assertEqual(401, cm.exception.code) + self.assertIn("No user caveat", cm.exception.msg) + + @defer.inlineCallbacks + def test_get_user_from_macaroon_wrong_key(self): + # TODO(danielwh): Remove this mock when we remove the + # get_user_by_access_token fallback. + self.store.get_user_by_access_token = Mock( + return_value={"name": "@baldrick:matrix.org"} + ) + + user = "@baldrick:matrix.org" + macaroon = pymacaroons.Macaroon( + location=self.hs.config.server_name, + identifier="key", + key=self.hs.config.macaroon_secret_key + "wrong") + macaroon.add_first_party_caveat("gen = 1") + macaroon.add_first_party_caveat("type = access") + macaroon.add_first_party_caveat("user_id = %s" % (user,)) + + with self.assertRaises(AuthError) as cm: + yield self.auth._get_user_from_macaroon(macaroon.serialize()) + self.assertEqual(401, cm.exception.code) + self.assertIn("Invalid macaroon", cm.exception.msg) + + @defer.inlineCallbacks + def test_get_user_from_macaroon_unknown_caveat(self): + # TODO(danielwh): Remove this mock when we remove the + # get_user_by_access_token fallback. + self.store.get_user_by_access_token = Mock( + return_value={"name": "@baldrick:matrix.org"} + ) + + user = "@baldrick:matrix.org" + macaroon = pymacaroons.Macaroon( + location=self.hs.config.server_name, + identifier="key", + key=self.hs.config.macaroon_secret_key) + macaroon.add_first_party_caveat("gen = 1") + macaroon.add_first_party_caveat("type = access") + macaroon.add_first_party_caveat("user_id = %s" % (user,)) + macaroon.add_first_party_caveat("cunning > fox") + + with self.assertRaises(AuthError) as cm: + yield self.auth._get_user_from_macaroon(macaroon.serialize()) + self.assertEqual(401, cm.exception.code) + self.assertIn("Invalid macaroon", cm.exception.msg) + + @defer.inlineCallbacks + def test_get_user_from_macaroon_expired(self): + # TODO(danielwh): Remove this mock when we remove the + # get_user_by_access_token fallback. + self.store.get_user_by_access_token = Mock( + return_value={"name": "@baldrick:matrix.org"} + ) + + self.store.get_user_by_access_token = Mock( + return_value={"name": "@baldrick:matrix.org"} + ) + + user = "@baldrick:matrix.org" + macaroon = pymacaroons.Macaroon( + location=self.hs.config.server_name, + identifier="key", + key=self.hs.config.macaroon_secret_key) + macaroon.add_first_party_caveat("gen = 1") + macaroon.add_first_party_caveat("type = access") + macaroon.add_first_party_caveat("user_id = %s" % (user,)) + macaroon.add_first_party_caveat("time < 1") # ms + + self.hs.clock.now = 5000 # seconds + + yield self.auth._get_user_from_macaroon(macaroon.serialize()) + # TODO(daniel): Turn on the check that we validate expiration, when we + # validate expiration (and remove the above line, which will start + # throwing). + # with self.assertRaises(AuthError) as cm: + # yield self.auth._get_user_from_macaroon(macaroon.serialize()) + # self.assertEqual(401, cm.exception.code) + # self.assertIn("Invalid macaroon", cm.exception.msg) diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py index 91547bdd06..2ee3da0b34 100644 --- a/tests/rest/client/v1/test_presence.py +++ b/tests/rest/client/v1/test_presence.py @@ -76,7 +76,7 @@ class PresenceStateTestCase(unittest.TestCase): "token_id": 1, } - hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token + hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token room_member_handler = hs.handlers.room_member_handler = Mock( spec=[ @@ -169,7 +169,7 @@ class PresenceListTestCase(unittest.TestCase): ] ) - hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token + hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token presence.register_servlets(hs, self.mock_resource) diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 34ab47d02e..a2123be81b 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -59,7 +59,7 @@ class RoomPermissionsTestCase(RestTestCase): "user": UserID.from_string(self.auth_user_id), "token_id": 1, } - hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token + hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token def _insert_client_ip(*args, **kwargs): return defer.succeed(None) @@ -239,7 +239,7 @@ class RoomPermissionsTestCase(RestTestCase): "PUT", topic_path, topic_content) self.assertEquals(403, code, msg=str(response)) (code, response) = yield self.mock_resource.trigger_get(topic_path) - self.assertEquals(403, code, msg=str(response)) + self.assertEquals(200, code, msg=str(response)) # get topic in PUBLIC room, not joined, expect 403 (code, response) = yield self.mock_resource.trigger_get( @@ -301,11 +301,11 @@ class RoomPermissionsTestCase(RestTestCase): room=room, expect_code=200) # get membership of self, get membership of other, private room + left - # expect all 403s + # expect all 200s yield self.leave(room=room, user=self.user_id) yield self._test_get_membership( members=[self.user_id, self.rmcreator_id], - room=room, expect_code=403) + room=room, expect_code=200) @defer.inlineCallbacks def test_membership_public_room_perms(self): @@ -326,11 +326,11 @@ class RoomPermissionsTestCase(RestTestCase): room=room, expect_code=200) # get membership of self, get membership of other, public room + left - # expect 403. + # expect 200. yield self.leave(room=room, user=self.user_id) yield self._test_get_membership( members=[self.user_id, self.rmcreator_id], - room=room, expect_code=403) + room=room, expect_code=200) @defer.inlineCallbacks def test_invited_permissions(self): @@ -444,7 +444,7 @@ class RoomsMemberListTestCase(RestTestCase): "user": UserID.from_string(self.auth_user_id), "token_id": 1, } - hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token + hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token def _insert_client_ip(*args, **kwargs): return defer.succeed(None) @@ -492,9 +492,9 @@ class RoomsMemberListTestCase(RestTestCase): self.assertEquals(200, code, msg=str(response)) yield self.leave(room=room_id, user=self.user_id) - # can no longer see list, you've left. + # can see old list once left (code, response) = yield self.mock_resource.trigger_get(room_path) - self.assertEquals(403, code, msg=str(response)) + self.assertEquals(200, code, msg=str(response)) class RoomsCreateTestCase(RestTestCase): @@ -522,7 +522,7 @@ class RoomsCreateTestCase(RestTestCase): "user": UserID.from_string(self.auth_user_id), "token_id": 1, } - hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token + hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token def _insert_client_ip(*args, **kwargs): return defer.succeed(None) @@ -614,7 +614,7 @@ class RoomTopicTestCase(RestTestCase): "token_id": 1, } - hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token + hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token def _insert_client_ip(*args, **kwargs): return defer.succeed(None) @@ -718,7 +718,7 @@ class RoomMemberStateTestCase(RestTestCase): "user": UserID.from_string(self.auth_user_id), "token_id": 1, } - hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token + hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token def _insert_client_ip(*args, **kwargs): return defer.succeed(None) @@ -843,7 +843,7 @@ class RoomMessagesTestCase(RestTestCase): "user": UserID.from_string(self.auth_user_id), "token_id": 1, } - hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token + hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token def _insert_client_ip(*args, **kwargs): return defer.succeed(None) @@ -938,7 +938,7 @@ class RoomInitialSyncTestCase(RestTestCase): "user": UserID.from_string(self.auth_user_id), "token_id": 1, } - hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token + hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token def _insert_client_ip(*args, **kwargs): return defer.succeed(None) diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py index 1c4519406d..6395ce79db 100644 --- a/tests/rest/client/v1/test_typing.py +++ b/tests/rest/client/v1/test_typing.py @@ -67,7 +67,7 @@ class RoomTypingTestCase(RestTestCase): "token_id": 1, } - hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token + hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token def _insert_client_ip(*args, **kwargs): return defer.succeed(None) diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index c472d53043..85096a0326 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -37,9 +37,6 @@ class RestTestCase(unittest.TestCase): self.mock_resource = None self.auth_user_id = None - def mock_get_user_by_access_token(self, token=None): - return self.auth_user_id - @defer.inlineCallbacks def create_room_as(self, room_creator, is_public=True, tok=None): temp_id = self.auth_user_id diff --git a/tests/rest/client/v2_alpha/__init__.py b/tests/rest/client/v2_alpha/__init__.py index ef972a53aa..f45570a1c0 100644 --- a/tests/rest/client/v2_alpha/__init__.py +++ b/tests/rest/client/v2_alpha/__init__.py @@ -48,7 +48,7 @@ class V2AlphaRestTestCase(unittest.TestCase): "user": UserID.from_string(self.USER_ID), "token_id": 1, } - hs.get_auth().get_user_by_access_token = _get_user_by_access_token + hs.get_auth()._get_user_by_access_token = _get_user_by_access_token for r in self.TO_REGISTER: r.register_servlets(hs, self.mock_resource) diff --git a/tests/storage/event_injector.py b/tests/storage/event_injector.py new file mode 100644 index 0000000000..42bd8928bd --- /dev/null +++ b/tests/storage/event_injector.py @@ -0,0 +1,81 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from tests import unittest +from twisted.internet import defer + +from synapse.api.constants import EventTypes, Membership +from synapse.types import UserID, RoomID + +from tests.utils import setup_test_homeserver + +from mock import Mock + + +class EventInjector: + def __init__(self, hs): + self.hs = hs + self.store = hs.get_datastore() + self.message_handler = hs.get_handlers().message_handler + self.event_builder_factory = hs.get_event_builder_factory() + + @defer.inlineCallbacks + def create_room(self, room): + builder = self.event_builder_factory.new({ + "type": EventTypes.Create, + "room_id": room.to_string(), + "content": {}, + }) + + event, context = yield self.message_handler._create_new_client_event( + builder + ) + + yield self.store.persist_event(event, context) + + @defer.inlineCallbacks + def inject_room_member(self, room, user, membership): + builder = self.event_builder_factory.new({ + "type": EventTypes.Member, + "sender": user.to_string(), + "state_key": user.to_string(), + "room_id": room.to_string(), + "content": {"membership": membership}, + }) + + event, context = yield self.message_handler._create_new_client_event( + builder + ) + + yield self.store.persist_event(event, context) + + defer.returnValue(event) + + @defer.inlineCallbacks + def inject_message(self, room, user, body): + builder = self.event_builder_factory.new({ + "type": EventTypes.Message, + "sender": user.to_string(), + "state_key": user.to_string(), + "room_id": room.to_string(), + "content": {"body": body, "msgtype": u"message"}, + }) + + event, context = yield self.message_handler._create_new_client_event( + builder + ) + + yield self.store.persist_event(event, context) diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py new file mode 100644 index 0000000000..313013009e --- /dev/null +++ b/tests/storage/test_events.py @@ -0,0 +1,116 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import uuid +from mock.mock import Mock +from synapse.types import RoomID, UserID + +from tests import unittest +from twisted.internet import defer +from tests.storage.event_injector import EventInjector + +from tests.utils import setup_test_homeserver + + +class EventsStoreTestCase(unittest.TestCase): + + @defer.inlineCallbacks + def setUp(self): + self.hs = yield setup_test_homeserver( + resource_for_federation=Mock(), + http_client=None, + ) + self.store = self.hs.get_datastore() + self.db_pool = self.hs.get_db_pool() + self.message_handler = self.hs.get_handlers().message_handler + self.event_injector = EventInjector(self.hs) + + @defer.inlineCallbacks + def test_count_daily_messages(self): + self.db_pool.runQuery("DELETE FROM stats_reporting") + + self.hs.clock.now = 100 + + # Never reported before, and nothing which could be reported + count = yield self.store.count_daily_messages() + self.assertIsNone(count) + count = yield self.db_pool.runQuery("SELECT COUNT(*) FROM stats_reporting") + self.assertEqual([(0,)], count) + + # Create something to report + room = RoomID.from_string("!abc123:test") + user = UserID.from_string("@raccoonlover:test") + yield self.event_injector.create_room(room) + + self.base_event = yield self._get_last_stream_token() + + yield self.event_injector.inject_message(room, user, "Raccoons are really cute") + + # Never reported before, something could be reported, but isn't because + # it isn't old enough. + count = yield self.store.count_daily_messages() + self.assertIsNone(count) + self._assert_stats_reporting(1, self.hs.clock.now) + + # Already reported yesterday, two new events from today. + yield self.event_injector.inject_message(room, user, "Yeah they are!") + yield self.event_injector.inject_message(room, user, "Incredibly!") + self.hs.clock.now += 60 * 60 * 24 + count = yield self.store.count_daily_messages() + self.assertEqual(2, count) # 2 since yesterday + self._assert_stats_reporting(3, self.hs.clock.now) # 3 ever + + # Last reported too recently. + yield self.event_injector.inject_message(room, user, "Who could disagree?") + self.hs.clock.now += 60 * 60 * 22 + count = yield self.store.count_daily_messages() + self.assertIsNone(count) + self._assert_stats_reporting(4, self.hs.clock.now) + + # Last reported too long ago + yield self.event_injector.inject_message(room, user, "No one.") + self.hs.clock.now += 60 * 60 * 26 + count = yield self.store.count_daily_messages() + self.assertIsNone(count) + self._assert_stats_reporting(5, self.hs.clock.now) + + # And now let's actually report something + yield self.event_injector.inject_message(room, user, "Indeed.") + yield self.event_injector.inject_message(room, user, "Indeed.") + yield self.event_injector.inject_message(room, user, "Indeed.") + # A little over 24 hours is fine :) + self.hs.clock.now += (60 * 60 * 24) + 50 + count = yield self.store.count_daily_messages() + self.assertEqual(3, count) + self._assert_stats_reporting(8, self.hs.clock.now) + + @defer.inlineCallbacks + def _get_last_stream_token(self): + rows = yield self.db_pool.runQuery( + "SELECT stream_ordering" + " FROM events" + " ORDER BY stream_ordering DESC" + " LIMIT 1" + ) + if not rows: + defer.returnValue(0) + else: + defer.returnValue(rows[0][0]) + + @defer.inlineCallbacks + def _assert_stats_reporting(self, messages, time): + rows = yield self.db_pool.runQuery( + "SELECT reported_stream_token, reported_time FROM stats_reporting" + ) + self.assertEqual([(self.base_event + messages, time,)], rows) diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py index ab7625a3ca..caffce64e3 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py @@ -85,7 +85,7 @@ class RoomEventsStoreTestCase(unittest.TestCase): # Room events need the full datastore, for persist_event() and # get_room_state() self.store = hs.get_datastore() - self.event_factory = hs.get_event_factory(); + self.event_factory = hs.get_event_factory() self.room = RoomID.from_string("!abcde:test") diff --git a/tests/storage/test_stream.py b/tests/storage/test_stream.py index 0c9b89d765..a658a789aa 100644 --- a/tests/storage/test_stream.py +++ b/tests/storage/test_stream.py @@ -19,6 +19,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership from synapse.types import UserID, RoomID +from tests.storage.event_injector import EventInjector from tests.utils import setup_test_homeserver @@ -36,6 +37,7 @@ class StreamStoreTestCase(unittest.TestCase): self.store = hs.get_datastore() self.event_builder_factory = hs.get_event_builder_factory() + self.event_injector = EventInjector(hs) self.handlers = hs.get_handlers() self.message_handler = self.handlers.message_handler @@ -45,60 +47,20 @@ class StreamStoreTestCase(unittest.TestCase): self.room1 = RoomID.from_string("!abc123:test") self.room2 = RoomID.from_string("!xyx987:test") - self.depth = 1 - - @defer.inlineCallbacks - def inject_room_member(self, room, user, membership): - self.depth += 1 - - builder = self.event_builder_factory.new({ - "type": EventTypes.Member, - "sender": user.to_string(), - "state_key": user.to_string(), - "room_id": room.to_string(), - "content": {"membership": membership}, - }) - - event, context = yield self.message_handler._create_new_client_event( - builder - ) - - yield self.store.persist_event(event, context) - - defer.returnValue(event) - - @defer.inlineCallbacks - def inject_message(self, room, user, body): - self.depth += 1 - - builder = self.event_builder_factory.new({ - "type": EventTypes.Message, - "sender": user.to_string(), - "state_key": user.to_string(), - "room_id": room.to_string(), - "content": {"body": body, "msgtype": u"message"}, - }) - - event, context = yield self.message_handler._create_new_client_event( - builder - ) - - yield self.store.persist_event(event, context) - @defer.inlineCallbacks def test_event_stream_get_other(self): # Both bob and alice joins the room - yield self.inject_room_member( + yield self.event_injector.inject_room_member( self.room1, self.u_alice, Membership.JOIN ) - yield self.inject_room_member( + yield self.event_injector.inject_room_member( self.room1, self.u_bob, Membership.JOIN ) # Initial stream key: start = yield self.store.get_room_events_max_id() - yield self.inject_message(self.room1, self.u_alice, u"test") + yield self.event_injector.inject_message(self.room1, self.u_alice, u"test") end = yield self.store.get_room_events_max_id() @@ -125,17 +87,17 @@ class StreamStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_event_stream_get_own(self): # Both bob and alice joins the room - yield self.inject_room_member( + yield self.event_injector.inject_room_member( self.room1, self.u_alice, Membership.JOIN ) - yield self.inject_room_member( + yield self.event_injector.inject_room_member( self.room1, self.u_bob, Membership.JOIN ) # Initial stream key: start = yield self.store.get_room_events_max_id() - yield self.inject_message(self.room1, self.u_alice, u"test") + yield self.event_injector.inject_message(self.room1, self.u_alice, u"test") end = yield self.store.get_room_events_max_id() @@ -162,22 +124,22 @@ class StreamStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_event_stream_join_leave(self): # Both bob and alice joins the room - yield self.inject_room_member( + yield self.event_injector.inject_room_member( self.room1, self.u_alice, Membership.JOIN ) - yield self.inject_room_member( + yield self.event_injector.inject_room_member( self.room1, self.u_bob, Membership.JOIN ) # Then bob leaves again. - yield self.inject_room_member( + yield self.event_injector.inject_room_member( self.room1, self.u_bob, Membership.LEAVE ) # Initial stream key: start = yield self.store.get_room_events_max_id() - yield self.inject_message(self.room1, self.u_alice, u"test") + yield self.event_injector.inject_message(self.room1, self.u_alice, u"test") end = yield self.store.get_room_events_max_id() @@ -193,17 +155,17 @@ class StreamStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_event_stream_prev_content(self): - yield self.inject_room_member( + yield self.event_injector.inject_room_member( self.room1, self.u_bob, Membership.JOIN ) - event1 = yield self.inject_room_member( + event1 = yield self.event_injector.inject_room_member( self.room1, self.u_alice, Membership.JOIN ) start = yield self.store.get_room_events_max_id() - event2 = yield self.inject_room_member( + event2 = yield self.event_injector.inject_room_member( self.room1, self.u_alice, Membership.JOIN, ) diff --git a/tests/test_state.py b/tests/test_state.py index 5845358754..55f37c521f 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -204,8 +204,8 @@ class StateTestCase(unittest.TestCase): nodes={ "START": DictObj( type=EventTypes.Create, - state_key="creator", - content={"membership": "@user_id:example.com"}, + state_key="", + content={"creator": "@user_id:example.com"}, depth=1, ), "A": DictObj( @@ -259,8 +259,8 @@ class StateTestCase(unittest.TestCase): nodes={ "START": DictObj( type=EventTypes.Create, - state_key="creator", - content={"membership": "@user_id:example.com"}, + state_key="", + content={"creator": "@user_id:example.com"}, depth=1, ), "A": DictObj( @@ -432,13 +432,19 @@ class StateTestCase(unittest.TestCase): def test_resolve_message_conflict(self): event = create_event(type="test_message", name="event") + creation = create_event( + type=EventTypes.Create, state_key="" + ) + old_state_1 = [ + creation, create_event(type="test1", state_key="1"), create_event(type="test1", state_key="2"), create_event(type="test2", state_key=""), ] old_state_2 = [ + creation, create_event(type="test1", state_key="1"), create_event(type="test3", state_key="2"), create_event(type="test4", state_key=""), @@ -446,7 +452,7 @@ class StateTestCase(unittest.TestCase): context = yield self._get_context(event, old_state_1, old_state_2) - self.assertEqual(len(context.current_state), 5) + self.assertEqual(len(context.current_state), 6) self.assertIsNone(context.state_group) @@ -454,13 +460,19 @@ class StateTestCase(unittest.TestCase): def test_resolve_state_conflict(self): event = create_event(type="test4", state_key="", name="event") + creation = create_event( + type=EventTypes.Create, state_key="" + ) + old_state_1 = [ + creation, create_event(type="test1", state_key="1"), create_event(type="test1", state_key="2"), create_event(type="test2", state_key=""), ] old_state_2 = [ + creation, create_event(type="test1", state_key="1"), create_event(type="test3", state_key="2"), create_event(type="test4", state_key=""), @@ -468,7 +480,7 @@ class StateTestCase(unittest.TestCase): context = yield self._get_context(event, old_state_1, old_state_2) - self.assertEqual(len(context.current_state), 5) + self.assertEqual(len(context.current_state), 6) self.assertIsNone(context.state_group) @@ -484,36 +496,45 @@ class StateTestCase(unittest.TestCase): } ) + creation = create_event( + type=EventTypes.Create, state_key="", + content={"creator": "@foo:bar"} + ) + old_state_1 = [ + creation, member_event, create_event(type="test1", state_key="1", depth=1), ] old_state_2 = [ + creation, member_event, create_event(type="test1", state_key="1", depth=2), ] context = yield self._get_context(event, old_state_1, old_state_2) - self.assertEqual(old_state_2[1], context.current_state[("test1", "1")]) + self.assertEqual(old_state_2[2], context.current_state[("test1", "1")]) # Reverse the depth to make sure we are actually using the depths # during state resolution. old_state_1 = [ + creation, member_event, create_event(type="test1", state_key="1", depth=2), ] old_state_2 = [ + creation, member_event, create_event(type="test1", state_key="1", depth=1), ] context = yield self._get_context(event, old_state_1, old_state_2) - self.assertEqual(old_state_1[1], context.current_state[("test1", "1")]) + self.assertEqual(old_state_1[2], context.current_state[("test1", "1")]) def _get_context(self, event, old_state_1, old_state_2): group_name_1 = "group_name_1" |