diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/test_distributor.py | 2 | ||||
-rw-r--r-- | tests/test_event_auth.py | 151 | ||||
-rw-r--r-- | tests/test_state.py | 16 | ||||
-rw-r--r-- | tests/util/caches/test_descriptors.py | 2 | ||||
-rw-r--r-- | tests/util/test_stream_change_cache.py | 198 |
5 files changed, 363 insertions, 6 deletions
diff --git a/tests/test_distributor.py b/tests/test_distributor.py index 010aeaee7e..c066381698 100644 --- a/tests/test_distributor.py +++ b/tests/test_distributor.py @@ -19,7 +19,6 @@ from twisted.internet import defer from mock import Mock, patch from synapse.util.distributor import Distributor -from synapse.util.async import run_on_reactor class DistributorTestCase(unittest.TestCase): @@ -95,7 +94,6 @@ class DistributorTestCase(unittest.TestCase): @defer.inlineCallbacks def observer(): - yield run_on_reactor() raise MyException("Oopsie") self.dist.observe("whail", observer) diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py new file mode 100644 index 0000000000..d08e19c53a --- /dev/null +++ b/tests/test_event_auth.py @@ -0,0 +1,151 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse import event_auth +from synapse.api.errors import AuthError +from synapse.events import FrozenEvent +import unittest + + +class EventAuthTestCase(unittest.TestCase): + def test_random_users_cannot_send_state_before_first_pl(self): + """ + Check that, before the first PL lands, the creator is the only user + that can send a state event. + """ + creator = "@creator:example.com" + joiner = "@joiner:example.com" + auth_events = { + ("m.room.create", ""): _create_event(creator), + ("m.room.member", creator): _join_event(creator), + ("m.room.member", joiner): _join_event(joiner), + } + + # creator should be able to send state + event_auth.check( + _random_state_event(creator), auth_events, + do_sig_check=False, + ) + + # joiner should not be able to send state + self.assertRaises( + AuthError, + event_auth.check, + _random_state_event(joiner), + auth_events, + do_sig_check=False, + ), + + def test_state_default_level(self): + """ + Check that users above the state_default level can send state and + those below cannot + """ + creator = "@creator:example.com" + pleb = "@joiner:example.com" + king = "@joiner2:example.com" + + auth_events = { + ("m.room.create", ""): _create_event(creator), + ("m.room.member", creator): _join_event(creator), + ("m.room.power_levels", ""): _power_levels_event(creator, { + "state_default": "30", + "users": { + pleb: "29", + king: "30", + }, + }), + ("m.room.member", pleb): _join_event(pleb), + ("m.room.member", king): _join_event(king), + } + + # pleb should not be able to send state + self.assertRaises( + AuthError, + event_auth.check, + _random_state_event(pleb), + auth_events, + do_sig_check=False, + ), + + # king should be able to send state + event_auth.check( + _random_state_event(king), auth_events, + do_sig_check=False, + ) + + +# helpers for making events + +TEST_ROOM_ID = "!test:room" + + +def _create_event(user_id): + return FrozenEvent({ + "room_id": TEST_ROOM_ID, + "event_id": _get_event_id(), + "type": "m.room.create", + "sender": user_id, + "content": { + "creator": user_id, + }, + }) + + +def _join_event(user_id): + return FrozenEvent({ + "room_id": TEST_ROOM_ID, + "event_id": _get_event_id(), + "type": "m.room.member", + "sender": user_id, + "state_key": user_id, + "content": { + "membership": "join", + }, + }) + + +def _power_levels_event(sender, content): + return FrozenEvent({ + "room_id": TEST_ROOM_ID, + "event_id": _get_event_id(), + "type": "m.room.power_levels", + "sender": sender, + "state_key": "", + "content": content, + }) + + +def _random_state_event(sender): + return FrozenEvent({ + "room_id": TEST_ROOM_ID, + "event_id": _get_event_id(), + "type": "test.state", + "sender": sender, + "state_key": "", + "content": { + "membership": "join", + }, + }) + + +event_count = 0 + + +def _get_event_id(): + global event_count + c = event_count + event_count += 1 + return "!%i:example.com" % (c, ) diff --git a/tests/test_state.py b/tests/test_state.py index a5c5e55951..71c412faf4 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -606,6 +606,14 @@ class StateTestCase(unittest.TestCase): } ) + power_levels = create_event( + type=EventTypes.PowerLevels, state_key="", + content={"users": { + "@foo:bar": "100", + "@user_id:example.com": "100", + }} + ) + creation = create_event( type=EventTypes.Create, state_key="", content={"creator": "@foo:bar"} @@ -613,12 +621,14 @@ class StateTestCase(unittest.TestCase): old_state_1 = [ creation, + power_levels, member_event, create_event(type="test1", state_key="1", depth=1), ] old_state_2 = [ creation, + power_levels, member_event, create_event(type="test1", state_key="1", depth=2), ] @@ -633,7 +643,7 @@ class StateTestCase(unittest.TestCase): ) self.assertEqual( - old_state_2[2].event_id, context.current_state_ids[("test1", "1")] + old_state_2[3].event_id, context.current_state_ids[("test1", "1")] ) # Reverse the depth to make sure we are actually using the depths @@ -641,12 +651,14 @@ class StateTestCase(unittest.TestCase): old_state_1 = [ creation, + power_levels, member_event, create_event(type="test1", state_key="1", depth=2), ] old_state_2 = [ creation, + power_levels, member_event, create_event(type="test1", state_key="1", depth=1), ] @@ -659,7 +671,7 @@ class StateTestCase(unittest.TestCase): ) self.assertEqual( - old_state_1[2].event_id, context.current_state_ids[("test1", "1")] + old_state_1[3].event_id, context.current_state_ids[("test1", "1")] ) def _get_context(self, event, prev_event_id_1, old_state_1, prev_event_id_2, diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index 2516fe40f4..24754591df 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -18,7 +18,6 @@ import logging import mock from synapse.api.errors import SynapseError -from synapse.util import async from synapse.util import logcontext from twisted.internet import defer from synapse.util.caches import descriptors @@ -195,7 +194,6 @@ class DescriptorTestCase(unittest.TestCase): def fn(self, arg1): @defer.inlineCallbacks def inner_fn(): - yield async.run_on_reactor() raise SynapseError(400, "blah") return inner_fn() diff --git a/tests/util/test_stream_change_cache.py b/tests/util/test_stream_change_cache.py new file mode 100644 index 0000000000..67ece166c7 --- /dev/null +++ b/tests/util/test_stream_change_cache.py @@ -0,0 +1,198 @@ +from tests import unittest +from mock import patch + +from synapse.util.caches.stream_change_cache import StreamChangeCache + + +class StreamChangeCacheTests(unittest.TestCase): + """ + Tests for StreamChangeCache. + """ + + def test_prefilled_cache(self): + """ + Providing a prefilled cache to StreamChangeCache will result in a cache + with the prefilled-cache entered in. + """ + cache = StreamChangeCache("#test", 1, prefilled_cache={"user@foo.com": 2}) + self.assertTrue(cache.has_entity_changed("user@foo.com", 1)) + + def test_has_entity_changed(self): + """ + StreamChangeCache.entity_has_changed will mark entities as changed, and + has_entity_changed will observe the changed entities. + """ + cache = StreamChangeCache("#test", 3) + + cache.entity_has_changed("user@foo.com", 6) + cache.entity_has_changed("bar@baz.net", 7) + + # If it's been changed after that stream position, return True + self.assertTrue(cache.has_entity_changed("user@foo.com", 4)) + self.assertTrue(cache.has_entity_changed("bar@baz.net", 4)) + + # If it's been changed at that stream position, return False + self.assertFalse(cache.has_entity_changed("user@foo.com", 6)) + + # If there's no changes after that stream position, return False + self.assertFalse(cache.has_entity_changed("user@foo.com", 7)) + + # If the entity does not exist, return False. + self.assertFalse(cache.has_entity_changed("not@here.website", 7)) + + # If we request before the stream cache's earliest known position, + # return True, whether it's a known entity or not. + self.assertTrue(cache.has_entity_changed("user@foo.com", 0)) + self.assertTrue(cache.has_entity_changed("not@here.website", 0)) + + @patch("synapse.util.caches.CACHE_SIZE_FACTOR", 1.0) + def test_has_entity_changed_pops_off_start(self): + """ + StreamChangeCache.entity_has_changed will respect the max size and + purge the oldest items upon reaching that max size. + """ + cache = StreamChangeCache("#test", 1, max_size=2) + + cache.entity_has_changed("user@foo.com", 2) + cache.entity_has_changed("bar@baz.net", 3) + cache.entity_has_changed("user@elsewhere.org", 4) + + # The cache is at the max size, 2 + self.assertEqual(len(cache._cache), 2) + + # The oldest item has been popped off + self.assertTrue("user@foo.com" not in cache._entity_to_key) + + # If we update an existing entity, it keeps the two existing entities + cache.entity_has_changed("bar@baz.net", 5) + self.assertEqual( + set(["bar@baz.net", "user@elsewhere.org"]), set(cache._entity_to_key) + ) + + def test_get_all_entities_changed(self): + """ + StreamChangeCache.get_all_entities_changed will return all changed + entities since the given position. If the position is before the start + of the known stream, it returns None instead. + """ + cache = StreamChangeCache("#test", 1) + + cache.entity_has_changed("user@foo.com", 2) + cache.entity_has_changed("bar@baz.net", 3) + cache.entity_has_changed("user@elsewhere.org", 4) + + self.assertEqual( + cache.get_all_entities_changed(1), + ["user@foo.com", "bar@baz.net", "user@elsewhere.org"], + ) + self.assertEqual( + cache.get_all_entities_changed(2), ["bar@baz.net", "user@elsewhere.org"] + ) + self.assertEqual(cache.get_all_entities_changed(3), ["user@elsewhere.org"]) + self.assertEqual(cache.get_all_entities_changed(0), None) + + def test_has_any_entity_changed(self): + """ + StreamChangeCache.has_any_entity_changed will return True if any + entities have been changed since the provided stream position, and + False if they have not. If the cache has entries and the provided + stream position is before it, it will return True, otherwise False if + the cache has no entries. + """ + cache = StreamChangeCache("#test", 1) + + # With no entities, it returns False for the past, present, and future. + self.assertFalse(cache.has_any_entity_changed(0)) + self.assertFalse(cache.has_any_entity_changed(1)) + self.assertFalse(cache.has_any_entity_changed(2)) + + # We add an entity + cache.entity_has_changed("user@foo.com", 2) + + # With an entity, it returns True for the past, the stream start + # position, and False for the stream position the entity was changed + # on and ones after it. + self.assertTrue(cache.has_any_entity_changed(0)) + self.assertTrue(cache.has_any_entity_changed(1)) + self.assertFalse(cache.has_any_entity_changed(2)) + self.assertFalse(cache.has_any_entity_changed(3)) + + def test_get_entities_changed(self): + """ + StreamChangeCache.get_entities_changed will return the entities in the + given list that have changed since the provided stream ID. If the + stream position is earlier than the earliest known position, it will + return all of the entities queried for. + """ + cache = StreamChangeCache("#test", 1) + + cache.entity_has_changed("user@foo.com", 2) + cache.entity_has_changed("bar@baz.net", 3) + cache.entity_has_changed("user@elsewhere.org", 4) + + # Query all the entries, but mid-way through the stream. We should only + # get the ones after that point. + self.assertEqual( + cache.get_entities_changed( + ["user@foo.com", "bar@baz.net", "user@elsewhere.org"], stream_pos=2 + ), + set(["bar@baz.net", "user@elsewhere.org"]), + ) + + # Query all the entries mid-way through the stream, but include one + # that doesn't exist in it. We should get back the one that doesn't + # exist, too. + self.assertEqual( + cache.get_entities_changed( + [ + "user@foo.com", + "bar@baz.net", + "user@elsewhere.org", + "not@here.website", + ], + stream_pos=2, + ), + set(["bar@baz.net", "user@elsewhere.org", "not@here.website"]), + ) + + # Query all the entries, but before the first known point. We will get + # all the entries we queried for, including ones that don't exist. + self.assertEqual( + cache.get_entities_changed( + [ + "user@foo.com", + "bar@baz.net", + "user@elsewhere.org", + "not@here.website", + ], + stream_pos=0, + ), + set( + [ + "user@foo.com", + "bar@baz.net", + "user@elsewhere.org", + "not@here.website", + ] + ), + ) + + def test_max_pos(self): + """ + StreamChangeCache.get_max_pos_of_last_change will return the most + recent point where the entity could have changed. If the entity is not + known, the stream start is provided instead. + """ + cache = StreamChangeCache("#test", 1) + + cache.entity_has_changed("user@foo.com", 2) + cache.entity_has_changed("bar@baz.net", 3) + cache.entity_has_changed("user@elsewhere.org", 4) + + # Known entities will return the point where they were changed. + self.assertEqual(cache.get_max_pos_of_last_change("user@foo.com"), 2) + self.assertEqual(cache.get_max_pos_of_last_change("bar@baz.net"), 3) + self.assertEqual(cache.get_max_pos_of_last_change("user@elsewhere.org"), 4) + + # Unknown entities will return the stream start position. + self.assertEqual(cache.get_max_pos_of_last_change("not@here.website"), 1) |