diff options
author | Erik Johnston <erik@matrix.org> | 2018-08-06 13:33:54 +0100 |
---|---|---|
committer | Erik Johnston <erik@matrix.org> | 2018-08-06 13:33:54 +0100 |
commit | 49a316395831a0676160833c39daa0bb63ceea09 (patch) | |
tree | f1c4db4121d7890a7ba390c01b0e81e329f09b22 /tests | |
parent | Merge branch 'release-v0.33.1' of github.com:matrix-org/synapse into matrix-o... (diff) | |
parent | Return M_NOT_FOUND when a profile could not be found. (#3596) (diff) | |
download | synapse-49a316395831a0676160833c39daa0bb63ceea09.tar.xz |
Merge branch 'develop' of github.com:matrix-org/synapse into matrix-org-hotfixes
Diffstat (limited to 'tests')
-rw-r--r-- | tests/api/test_auth.py | 35 | ||||
-rw-r--r-- | tests/handlers/test_auth.py | 80 | ||||
-rw-r--r-- | tests/handlers/test_register.py | 51 | ||||
-rw-r--r-- | tests/handlers/test_typing.py | 1 | ||||
-rw-r--r-- | tests/replication/slave/storage/_base.py | 37 | ||||
-rw-r--r-- | tests/storage/test__init__.py | 65 | ||||
-rw-r--r-- | tests/storage/test_state.py | 319 | ||||
-rw-r--r-- | tests/util/caches/test_descriptors.py | 101 | ||||
-rw-r--r-- | tests/utils.py | 9 |
9 files changed, 659 insertions, 39 deletions
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index 5f158ec4b9..a82d737e71 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -46,7 +46,7 @@ class AuthTestCase(unittest.TestCase): self.auth = Auth(self.hs) self.test_user = "@foo:bar" - self.test_token = "_test_token_" + self.test_token = b"_test_token_" # this is overridden for the appservice tests self.store.get_app_service_by_token = Mock(return_value=None) @@ -61,7 +61,7 @@ class AuthTestCase(unittest.TestCase): self.store.get_user_by_access_token = Mock(return_value=user_info) request = Mock(args={}) - request.args["access_token"] = [self.test_token] + request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() requester = yield self.auth.get_user_by_req(request) self.assertEquals(requester.user.to_string(), self.test_user) @@ -70,7 +70,7 @@ class AuthTestCase(unittest.TestCase): self.store.get_user_by_access_token = Mock(return_value=None) request = Mock(args={}) - request.args["access_token"] = [self.test_token] + request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() d = self.auth.get_user_by_req(request) self.failureResultOf(d, AuthError) @@ -98,7 +98,7 @@ class AuthTestCase(unittest.TestCase): request = Mock(args={}) request.getClientIP.return_value = "127.0.0.1" - request.args["access_token"] = [self.test_token] + request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() requester = yield self.auth.get_user_by_req(request) self.assertEquals(requester.user.to_string(), self.test_user) @@ -115,7 +115,7 @@ class AuthTestCase(unittest.TestCase): request = Mock(args={}) request.getClientIP.return_value = "192.168.10.10" - request.args["access_token"] = [self.test_token] + request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() requester = yield self.auth.get_user_by_req(request) self.assertEquals(requester.user.to_string(), self.test_user) @@ -131,7 +131,7 @@ class AuthTestCase(unittest.TestCase): request = Mock(args={}) request.getClientIP.return_value = "131.111.8.42" - request.args["access_token"] = [self.test_token] + request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() d = self.auth.get_user_by_req(request) self.failureResultOf(d, AuthError) @@ -141,7 +141,7 @@ class AuthTestCase(unittest.TestCase): self.store.get_user_by_access_token = Mock(return_value=None) request = Mock(args={}) - request.args["access_token"] = [self.test_token] + request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() d = self.auth.get_user_by_req(request) self.failureResultOf(d, AuthError) @@ -158,7 +158,7 @@ class AuthTestCase(unittest.TestCase): @defer.inlineCallbacks def test_get_user_by_req_appservice_valid_token_valid_user_id(self): - masquerading_user_id = "@doppelganger:matrix.org" + masquerading_user_id = b"@doppelganger:matrix.org" app_service = Mock( token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None, @@ -169,14 +169,17 @@ class AuthTestCase(unittest.TestCase): request = Mock(args={}) request.getClientIP.return_value = "127.0.0.1" - request.args["access_token"] = [self.test_token] - request.args["user_id"] = [masquerading_user_id] + request.args[b"access_token"] = [self.test_token] + request.args[b"user_id"] = [masquerading_user_id] request.requestHeaders.getRawHeaders = mock_getRawHeaders() requester = yield self.auth.get_user_by_req(request) - self.assertEquals(requester.user.to_string(), masquerading_user_id) + self.assertEquals( + requester.user.to_string(), + masquerading_user_id.decode('utf8') + ) def test_get_user_by_req_appservice_valid_token_bad_user_id(self): - masquerading_user_id = "@doppelganger:matrix.org" + masquerading_user_id = b"@doppelganger:matrix.org" app_service = Mock( token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None, @@ -187,8 +190,8 @@ class AuthTestCase(unittest.TestCase): request = Mock(args={}) request.getClientIP.return_value = "127.0.0.1" - request.args["access_token"] = [self.test_token] - request.args["user_id"] = [masquerading_user_id] + request.args[b"access_token"] = [self.test_token] + request.args[b"user_id"] = [masquerading_user_id] request.requestHeaders.getRawHeaders = mock_getRawHeaders() d = self.auth.get_user_by_req(request) self.failureResultOf(d, AuthError) @@ -418,7 +421,7 @@ class AuthTestCase(unittest.TestCase): # check the token works request = Mock(args={}) - request.args["access_token"] = [token] + request.args[b"access_token"] = [token.encode('ascii')] request.requestHeaders.getRawHeaders = mock_getRawHeaders() requester = yield self.auth.get_user_by_req(request, allow_guest=True) self.assertEqual(UserID.from_string(USER_ID), requester.user) @@ -431,7 +434,7 @@ class AuthTestCase(unittest.TestCase): # the token should *not* work now request = Mock(args={}) - request.args["access_token"] = [guest_tok] + request.args[b"access_token"] = [guest_tok.encode('ascii')] request.requestHeaders.getRawHeaders = mock_getRawHeaders() with self.assertRaises(AuthError) as cm: diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index 2e5e8e4dec..55eab9e9cf 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from mock import Mock import pymacaroons @@ -19,6 +20,7 @@ from twisted.internet import defer import synapse import synapse.api.errors +from synapse.api.errors import AuthError from synapse.handlers.auth import AuthHandler from tests import unittest @@ -37,6 +39,10 @@ class AuthTestCase(unittest.TestCase): self.hs.handlers = AuthHandlers(self.hs) self.auth_handler = self.hs.handlers.auth_handler self.macaroon_generator = self.hs.get_macaroon_generator() + # MAU tests + self.hs.config.max_mau_value = 50 + self.small_number_of_users = 1 + self.large_number_of_users = 100 def test_token_is_a_macaroon(self): token = self.macaroon_generator.generate_access_token("some_user") @@ -71,38 +77,37 @@ class AuthTestCase(unittest.TestCase): v.satisfy_general(verify_nonce) v.verify(macaroon, self.hs.config.macaroon_secret_key) + @defer.inlineCallbacks def test_short_term_login_token_gives_user_id(self): self.hs.clock.now = 1000 token = self.macaroon_generator.generate_short_term_login_token( "a_user", 5000 ) - - self.assertEqual( - "a_user", - self.auth_handler.validate_short_term_login_token_and_get_user_id( - token - ) + user_id = yield self.auth_handler.validate_short_term_login_token_and_get_user_id( + token ) + self.assertEqual("a_user", user_id) # when we advance the clock, the token should be rejected self.hs.clock.now = 6000 with self.assertRaises(synapse.api.errors.AuthError): - self.auth_handler.validate_short_term_login_token_and_get_user_id( + yield self.auth_handler.validate_short_term_login_token_and_get_user_id( token ) + @defer.inlineCallbacks def test_short_term_login_token_cannot_replace_user_id(self): token = self.macaroon_generator.generate_short_term_login_token( "a_user", 5000 ) macaroon = pymacaroons.Macaroon.deserialize(token) + user_id = yield self.auth_handler.validate_short_term_login_token_and_get_user_id( + macaroon.serialize() + ) self.assertEqual( - "a_user", - self.auth_handler.validate_short_term_login_token_and_get_user_id( - macaroon.serialize() - ) + "a_user", user_id ) # add another "user_id" caveat, which might allow us to override the @@ -110,6 +115,57 @@ class AuthTestCase(unittest.TestCase): macaroon.add_first_party_caveat("user_id = b_user") with self.assertRaises(synapse.api.errors.AuthError): - self.auth_handler.validate_short_term_login_token_and_get_user_id( + yield self.auth_handler.validate_short_term_login_token_and_get_user_id( macaroon.serialize() ) + + @defer.inlineCallbacks + def test_mau_limits_disabled(self): + self.hs.config.limit_usage_by_mau = False + # Ensure does not throw exception + yield self.auth_handler.get_access_token_for_user_id('user_a') + + yield self.auth_handler.validate_short_term_login_token_and_get_user_id( + self._get_macaroon().serialize() + ) + + @defer.inlineCallbacks + def test_mau_limits_exceeded(self): + self.hs.config.limit_usage_by_mau = True + self.hs.get_datastore().count_monthly_users = Mock( + return_value=defer.succeed(self.large_number_of_users) + ) + + with self.assertRaises(AuthError): + yield self.auth_handler.get_access_token_for_user_id('user_a') + + self.hs.get_datastore().count_monthly_users = Mock( + return_value=defer.succeed(self.large_number_of_users) + ) + with self.assertRaises(AuthError): + yield self.auth_handler.validate_short_term_login_token_and_get_user_id( + self._get_macaroon().serialize() + ) + + @defer.inlineCallbacks + def test_mau_limits_not_exceeded(self): + self.hs.config.limit_usage_by_mau = True + + self.hs.get_datastore().count_monthly_users = Mock( + return_value=defer.succeed(self.small_number_of_users) + ) + # Ensure does not raise exception + yield self.auth_handler.get_access_token_for_user_id('user_a') + + self.hs.get_datastore().count_monthly_users = Mock( + return_value=defer.succeed(self.small_number_of_users) + ) + yield self.auth_handler.validate_short_term_login_token_and_get_user_id( + self._get_macaroon().serialize() + ) + + def _get_macaroon(self): + token = self.macaroon_generator.generate_short_term_login_token( + "user_a", 5000 + ) + return pymacaroons.Macaroon.deserialize(token) diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index 025fa1be81..0937d71cf6 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -17,6 +17,7 @@ from mock import Mock from twisted.internet import defer +from synapse.api.errors import RegistrationError from synapse.handlers.register import RegistrationHandler from synapse.types import UserID, create_requester @@ -77,3 +78,53 @@ class RegistrationTestCase(unittest.TestCase): requester, local_part, display_name) self.assertEquals(result_user_id, user_id) self.assertEquals(result_token, 'secret') + + @defer.inlineCallbacks + def test_cannot_register_when_mau_limits_exceeded(self): + local_part = "someone" + display_name = "someone" + requester = create_requester("@as:test") + store = self.hs.get_datastore() + self.hs.config.limit_usage_by_mau = False + self.hs.config.max_mau_value = 50 + lots_of_users = 100 + small_number_users = 1 + + store.count_monthly_users = Mock(return_value=defer.succeed(lots_of_users)) + + # Ensure does not throw exception + yield self.handler.get_or_create_user(requester, 'a', display_name) + + self.hs.config.limit_usage_by_mau = True + + with self.assertRaises(RegistrationError): + yield self.handler.get_or_create_user(requester, 'b', display_name) + + store.count_monthly_users = Mock(return_value=defer.succeed(small_number_users)) + + self._macaroon_mock_generator("another_secret") + + # Ensure does not throw exception + yield self.handler.get_or_create_user("@neil:matrix.org", 'c', "Neil") + + self._macaroon_mock_generator("another another secret") + store.count_monthly_users = Mock(return_value=defer.succeed(lots_of_users)) + + with self.assertRaises(RegistrationError): + yield self.handler.register(localpart=local_part) + + self._macaroon_mock_generator("another another secret") + store.count_monthly_users = Mock(return_value=defer.succeed(lots_of_users)) + + with self.assertRaises(RegistrationError): + yield self.handler.register_saml2(local_part) + + def _macaroon_mock_generator(self, secret): + """ + Reset macaroon generator in the case where the test creates multiple users + """ + macaroon_generator = Mock( + generate_access_token=Mock(return_value=secret)) + self.hs.get_macaroon_generator = Mock(return_value=macaroon_generator) + self.hs.handlers = RegistrationHandlers(self.hs) + self.handler = self.hs.get_handlers().registration_handler diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index b08856f763..2c263af1a3 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -44,7 +44,6 @@ def _expect_edu(destination, edu_type, content, origin="test"): "content": content, } ], - "pdu_failures": [], } diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index 8708c8a196..a103e7be80 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -11,23 +11,44 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import tempfile from mock import Mock, NonCallableMock from twisted.internet import defer, reactor +from twisted.internet.defer import Deferred from synapse.replication.tcp.client import ( ReplicationClientFactory, ReplicationClientHandler, ) from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory +from synapse.util.logcontext import PreserveLoggingContext, make_deferred_yieldable from tests import unittest from tests.utils import setup_test_homeserver +class TestReplicationClientHandler(ReplicationClientHandler): + """Overrides on_rdata so that we can wait for it to happen""" + def __init__(self, store): + super(TestReplicationClientHandler, self).__init__(store) + self._rdata_awaiters = [] + + def await_replication(self): + d = Deferred() + self._rdata_awaiters.append(d) + return make_deferred_yieldable(d) + + def on_rdata(self, stream_name, token, rows): + awaiters = self._rdata_awaiters + self._rdata_awaiters = [] + super(TestReplicationClientHandler, self).on_rdata(stream_name, token, rows) + with PreserveLoggingContext(): + for a in awaiters: + a.callback(None) + + class BaseSlavedStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): @@ -52,7 +73,7 @@ class BaseSlavedStoreTestCase(unittest.TestCase): self.addCleanup(listener.stopListening) self.streamer = server_factory.streamer - self.replication_handler = ReplicationClientHandler(self.slaved_store) + self.replication_handler = TestReplicationClientHandler(self.slaved_store) client_factory = ReplicationClientFactory( self.hs, "client_name", self.replication_handler ) @@ -60,12 +81,14 @@ class BaseSlavedStoreTestCase(unittest.TestCase): self.addCleanup(client_factory.stopTrying) self.addCleanup(client_connector.disconnect) - @defer.inlineCallbacks def replicate(self): - yield self.streamer.on_notifier_poke() - d = self.replication_handler.await_sync("replication_test") - self.streamer.send_sync_to_all_connections("replication_test") - yield d + """Tell the master side of replication that something has happened, and then + wait for the replication to occur. + """ + # xxx: should we be more specific in what we wait for? + d = self.replication_handler.await_replication() + self.streamer.on_notifier_poke() + return d @defer.inlineCallbacks def check(self, method, args, expected_result=None): diff --git a/tests/storage/test__init__.py b/tests/storage/test__init__.py new file mode 100644 index 0000000000..f19cb1265c --- /dev/null +++ b/tests/storage/test__init__.py @@ -0,0 +1,65 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +import tests.utils + + +class InitTestCase(tests.unittest.TestCase): + def __init__(self, *args, **kwargs): + super(InitTestCase, self).__init__(*args, **kwargs) + self.store = None # type: synapse.storage.DataStore + + @defer.inlineCallbacks + def setUp(self): + hs = yield tests.utils.setup_test_homeserver() + + hs.config.max_mau_value = 50 + hs.config.limit_usage_by_mau = True + self.store = hs.get_datastore() + self.clock = hs.get_clock() + + @defer.inlineCallbacks + def test_count_monthly_users(self): + count = yield self.store.count_monthly_users() + self.assertEqual(0, count) + + yield self._insert_user_ips("@user:server1") + yield self._insert_user_ips("@user:server2") + + count = yield self.store.count_monthly_users() + self.assertEqual(2, count) + + @defer.inlineCallbacks + def _insert_user_ips(self, user): + """ + Helper function to populate user_ips without using batch insertion infra + args: + user (str): specify username i.e. @user:server.com + """ + yield self.store._simple_upsert( + table="user_ips", + keyvalues={ + "user_id": user, + "access_token": "access_token", + "ip": "ip", + "user_agent": "user_agent", + "device_id": "device_id", + }, + values={ + "last_seen": self.clock.time_msec(), + } + ) diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py new file mode 100644 index 0000000000..7a76d67b8c --- /dev/null +++ b/tests/storage/test_state.py @@ -0,0 +1,319 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from twisted.internet import defer + +from synapse.api.constants import EventTypes, Membership +from synapse.types import RoomID, UserID + +import tests.unittest +import tests.utils + +logger = logging.getLogger(__name__) + + +class StateStoreTestCase(tests.unittest.TestCase): + def __init__(self, *args, **kwargs): + super(StateStoreTestCase, self).__init__(*args, **kwargs) + self.store = None # type: synapse.storage.DataStore + + @defer.inlineCallbacks + def setUp(self): + hs = yield tests.utils.setup_test_homeserver() + + self.store = hs.get_datastore() + self.event_builder_factory = hs.get_event_builder_factory() + self.event_creation_handler = hs.get_event_creation_handler() + + self.u_alice = UserID.from_string("@alice:test") + self.u_bob = UserID.from_string("@bob:test") + + self.room = RoomID.from_string("!abc123:test") + + yield self.store.store_room( + self.room.to_string(), + room_creator_user_id="@creator:text", + is_public=True + ) + + @defer.inlineCallbacks + def inject_state_event(self, room, sender, typ, state_key, content): + builder = self.event_builder_factory.new({ + "type": typ, + "sender": sender.to_string(), + "state_key": state_key, + "room_id": room.to_string(), + "content": content, + }) + + event, context = yield self.event_creation_handler.create_new_client_event( + builder + ) + + yield self.store.persist_event(event, context) + + defer.returnValue(event) + + def assertStateMapEqual(self, s1, s2): + for t in s1: + # just compare event IDs for simplicity + self.assertEqual(s1[t].event_id, s2[t].event_id) + self.assertEqual(len(s1), len(s2)) + + @defer.inlineCallbacks + def test_get_state_for_event(self): + + # this defaults to a linear DAG as each new injection defaults to whatever + # forward extremities are currently in the DB for this room. + e1 = yield self.inject_state_event( + self.room, self.u_alice, EventTypes.Create, '', {}, + ) + e2 = yield self.inject_state_event( + self.room, self.u_alice, EventTypes.Name, '', { + "name": "test room" + }, + ) + e3 = yield self.inject_state_event( + self.room, self.u_alice, EventTypes.Member, self.u_alice.to_string(), { + "membership": Membership.JOIN + }, + ) + e4 = yield self.inject_state_event( + self.room, self.u_bob, EventTypes.Member, self.u_bob.to_string(), { + "membership": Membership.JOIN + }, + ) + e5 = yield self.inject_state_event( + self.room, self.u_bob, EventTypes.Member, self.u_bob.to_string(), { + "membership": Membership.LEAVE + }, + ) + + # check we get the full state as of the final event + state = yield self.store.get_state_for_event( + e5.event_id, None, filtered_types=None + ) + + self.assertIsNotNone(e4) + + self.assertStateMapEqual({ + (e1.type, e1.state_key): e1, + (e2.type, e2.state_key): e2, + (e3.type, e3.state_key): e3, + # e4 is overwritten by e5 + (e5.type, e5.state_key): e5, + }, state) + + # check we can filter to the m.room.name event (with a '' state key) + state = yield self.store.get_state_for_event( + e5.event_id, [(EventTypes.Name, '')], filtered_types=None + ) + + self.assertStateMapEqual({ + (e2.type, e2.state_key): e2, + }, state) + + # check we can filter to the m.room.name event (with a wildcard None state key) + state = yield self.store.get_state_for_event( + e5.event_id, [(EventTypes.Name, None)], filtered_types=None + ) + + self.assertStateMapEqual({ + (e2.type, e2.state_key): e2, + }, state) + + # check we can grab the m.room.member events (with a wildcard None state key) + state = yield self.store.get_state_for_event( + e5.event_id, [(EventTypes.Member, None)], filtered_types=None + ) + + self.assertStateMapEqual({ + (e3.type, e3.state_key): e3, + (e5.type, e5.state_key): e5, + }, state) + + # check we can use filter_types to grab a specific room member + # without filtering out the other event types + state = yield self.store.get_state_for_event( + e5.event_id, [(EventTypes.Member, self.u_alice.to_string())], + filtered_types=[EventTypes.Member], + ) + + self.assertStateMapEqual({ + (e1.type, e1.state_key): e1, + (e2.type, e2.state_key): e2, + (e3.type, e3.state_key): e3, + }, state) + + # check that types=[], filtered_types=[EventTypes.Member] + # doesn't return all members + state = yield self.store.get_state_for_event( + e5.event_id, [], filtered_types=[EventTypes.Member], + ) + + self.assertStateMapEqual({ + (e1.type, e1.state_key): e1, + (e2.type, e2.state_key): e2, + }, state) + + ####################################################### + # _get_some_state_from_cache tests against a full cache + ####################################################### + + room_id = self.room.to_string() + group_ids = yield self.store.get_state_groups_ids(room_id, [e5.event_id]) + group = group_ids.keys()[0] + + # test _get_some_state_from_cache correctly filters out members with types=[] + (state_dict, is_all) = yield self.store._get_some_state_from_cache( + group, [], filtered_types=[EventTypes.Member] + ) + + self.assertEqual(is_all, True) + self.assertDictEqual({ + (e1.type, e1.state_key): e1.event_id, + (e2.type, e2.state_key): e2.event_id, + }, state_dict) + + # test _get_some_state_from_cache correctly filters in members with wildcard types + (state_dict, is_all) = yield self.store._get_some_state_from_cache( + group, [(EventTypes.Member, None)], filtered_types=[EventTypes.Member] + ) + + self.assertEqual(is_all, True) + self.assertDictEqual({ + (e1.type, e1.state_key): e1.event_id, + (e2.type, e2.state_key): e2.event_id, + (e3.type, e3.state_key): e3.event_id, + # e4 is overwritten by e5 + (e5.type, e5.state_key): e5.event_id, + }, state_dict) + + # test _get_some_state_from_cache correctly filters in members with specific types + (state_dict, is_all) = yield self.store._get_some_state_from_cache( + group, [(EventTypes.Member, e5.state_key)], filtered_types=[EventTypes.Member] + ) + + self.assertEqual(is_all, True) + self.assertDictEqual({ + (e1.type, e1.state_key): e1.event_id, + (e2.type, e2.state_key): e2.event_id, + (e5.type, e5.state_key): e5.event_id, + }, state_dict) + + # test _get_some_state_from_cache correctly filters in members with specific types + # and no filtered_types + (state_dict, is_all) = yield self.store._get_some_state_from_cache( + group, [(EventTypes.Member, e5.state_key)], filtered_types=None + ) + + self.assertEqual(is_all, True) + self.assertDictEqual({ + (e5.type, e5.state_key): e5.event_id, + }, state_dict) + + ####################################################### + # deliberately remove e2 (room name) from the _state_group_cache + + (is_all, known_absent, state_dict_ids) = self.store._state_group_cache.get(group) + + self.assertEqual(is_all, True) + self.assertEqual(known_absent, set()) + self.assertDictEqual(state_dict_ids, { + (e1.type, e1.state_key): e1.event_id, + (e2.type, e2.state_key): e2.event_id, + (e3.type, e3.state_key): e3.event_id, + # e4 is overwritten by e5 + (e5.type, e5.state_key): e5.event_id, + }) + + state_dict_ids.pop((e2.type, e2.state_key)) + self.store._state_group_cache.invalidate(group) + self.store._state_group_cache.update( + sequence=self.store._state_group_cache.sequence, + key=group, + value=state_dict_ids, + # list fetched keys so it knows it's partial + fetched_keys=( + (e1.type, e1.state_key), + (e3.type, e3.state_key), + (e5.type, e5.state_key), + ) + ) + + (is_all, known_absent, state_dict_ids) = self.store._state_group_cache.get(group) + + self.assertEqual(is_all, False) + self.assertEqual(known_absent, set([ + (e1.type, e1.state_key), + (e3.type, e3.state_key), + (e5.type, e5.state_key), + ])) + self.assertDictEqual(state_dict_ids, { + (e1.type, e1.state_key): e1.event_id, + (e3.type, e3.state_key): e3.event_id, + (e5.type, e5.state_key): e5.event_id, + }) + + ############################################ + # test that things work with a partial cache + + # test _get_some_state_from_cache correctly filters out members with types=[] + room_id = self.room.to_string() + (state_dict, is_all) = yield self.store._get_some_state_from_cache( + group, [], filtered_types=[EventTypes.Member] + ) + + self.assertEqual(is_all, False) + self.assertDictEqual({ + (e1.type, e1.state_key): e1.event_id, + }, state_dict) + + # test _get_some_state_from_cache correctly filters in members wildcard types + (state_dict, is_all) = yield self.store._get_some_state_from_cache( + group, [(EventTypes.Member, None)], filtered_types=[EventTypes.Member] + ) + + self.assertEqual(is_all, False) + self.assertDictEqual({ + (e1.type, e1.state_key): e1.event_id, + (e3.type, e3.state_key): e3.event_id, + # e4 is overwritten by e5 + (e5.type, e5.state_key): e5.event_id, + }, state_dict) + + # test _get_some_state_from_cache correctly filters in members with specific types + (state_dict, is_all) = yield self.store._get_some_state_from_cache( + group, [(EventTypes.Member, e5.state_key)], filtered_types=[EventTypes.Member] + ) + + self.assertEqual(is_all, False) + self.assertDictEqual({ + (e1.type, e1.state_key): e1.event_id, + (e5.type, e5.state_key): e5.event_id, + }, state_dict) + + # test _get_some_state_from_cache correctly filters in members with specific types + # and no filtered_types + (state_dict, is_all) = yield self.store._get_some_state_from_cache( + group, [(EventTypes.Member, e5.state_key)], filtered_types=None + ) + + self.assertEqual(is_all, True) + self.assertDictEqual({ + (e5.type, e5.state_key): e5.event_id, + }, state_dict) diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index 8176a7dabd..ca8a7c907f 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -273,3 +273,104 @@ class DescriptorTestCase(unittest.TestCase): r = yield obj.fn(2, 3) self.assertEqual(r, 'chips') obj.mock.assert_not_called() + + +class CachedListDescriptorTestCase(unittest.TestCase): + @defer.inlineCallbacks + def test_cache(self): + class Cls(object): + def __init__(self): + self.mock = mock.Mock() + + @descriptors.cached() + def fn(self, arg1, arg2): + pass + + @descriptors.cachedList("fn", "args1", inlineCallbacks=True) + def list_fn(self, args1, arg2): + assert ( + logcontext.LoggingContext.current_context().request == "c1" + ) + # we want this to behave like an asynchronous function + yield run_on_reactor() + assert ( + logcontext.LoggingContext.current_context().request == "c1" + ) + defer.returnValue(self.mock(args1, arg2)) + + with logcontext.LoggingContext() as c1: + c1.request = "c1" + obj = Cls() + obj.mock.return_value = {10: 'fish', 20: 'chips'} + d1 = obj.list_fn([10, 20], 2) + self.assertEqual( + logcontext.LoggingContext.current_context(), + logcontext.LoggingContext.sentinel, + ) + r = yield d1 + self.assertEqual( + logcontext.LoggingContext.current_context(), + c1 + ) + obj.mock.assert_called_once_with([10, 20], 2) + self.assertEqual(r, {10: 'fish', 20: 'chips'}) + obj.mock.reset_mock() + + # a call with different params should call the mock again + obj.mock.return_value = {30: 'peas'} + r = yield obj.list_fn([20, 30], 2) + obj.mock.assert_called_once_with([30], 2) + self.assertEqual(r, {20: 'chips', 30: 'peas'}) + obj.mock.reset_mock() + + # all the values should now be cached + r = yield obj.fn(10, 2) + self.assertEqual(r, 'fish') + r = yield obj.fn(20, 2) + self.assertEqual(r, 'chips') + r = yield obj.fn(30, 2) + self.assertEqual(r, 'peas') + r = yield obj.list_fn([10, 20, 30], 2) + obj.mock.assert_not_called() + self.assertEqual(r, {10: 'fish', 20: 'chips', 30: 'peas'}) + + @defer.inlineCallbacks + def test_invalidate(self): + """Make sure that invalidation callbacks are called.""" + class Cls(object): + def __init__(self): + self.mock = mock.Mock() + + @descriptors.cached() + def fn(self, arg1, arg2): + pass + + @descriptors.cachedList("fn", "args1", inlineCallbacks=True) + def list_fn(self, args1, arg2): + # we want this to behave like an asynchronous function + yield run_on_reactor() + defer.returnValue(self.mock(args1, arg2)) + + obj = Cls() + invalidate0 = mock.Mock() + invalidate1 = mock.Mock() + + # cache miss + obj.mock.return_value = {10: 'fish', 20: 'chips'} + r1 = yield obj.list_fn([10, 20], 2, on_invalidate=invalidate0) + obj.mock.assert_called_once_with([10, 20], 2) + self.assertEqual(r1, {10: 'fish', 20: 'chips'}) + obj.mock.reset_mock() + + # cache hit + r2 = yield obj.list_fn([10, 20], 2, on_invalidate=invalidate1) + obj.mock.assert_not_called() + self.assertEqual(r2, {10: 'fish', 20: 'chips'}) + + invalidate0.assert_not_called() + invalidate1.assert_not_called() + + # now if we invalidate the keys, both invalidations should get called + obj.fn.invalidate((10, 2)) + invalidate0.assert_called_once() + invalidate1.assert_called_once() diff --git a/tests/utils.py b/tests/utils.py index c3dbff8507..9bff3ff3b9 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -193,7 +193,7 @@ class MockHttpResource(HttpServer): self.prefix = prefix def trigger_get(self, path): - return self.trigger("GET", path, None) + return self.trigger(b"GET", path, None) @patch('twisted.web.http.Request') @defer.inlineCallbacks @@ -227,7 +227,7 @@ class MockHttpResource(HttpServer): headers = {} if federation_auth: - headers[b"Authorization"] = ["X-Matrix origin=test,key=,sig="] + headers[b"Authorization"] = [b"X-Matrix origin=test,key=,sig="] mock_request.requestHeaders.getRawHeaders = mock_getRawHeaders(headers) # return the right path if the event requires it @@ -241,6 +241,9 @@ class MockHttpResource(HttpServer): except Exception: pass + if isinstance(path, bytes): + path = path.decode('utf8') + for (method, pattern, func) in self.callbacks: if http_method != method: continue @@ -249,7 +252,7 @@ class MockHttpResource(HttpServer): if matcher: try: args = [ - urlparse.unquote(u).decode("UTF-8") + urlparse.unquote(u) for u in matcher.groups() ] |