diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/api/test_auth.py | 35 | ||||
-rw-r--r-- | tests/util/caches/test_descriptors.py | 101 | ||||
-rw-r--r-- | tests/utils.py | 9 |
3 files changed, 126 insertions, 19 deletions
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index 5f158ec4b9..a82d737e71 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -46,7 +46,7 @@ class AuthTestCase(unittest.TestCase): self.auth = Auth(self.hs) self.test_user = "@foo:bar" - self.test_token = "_test_token_" + self.test_token = b"_test_token_" # this is overridden for the appservice tests self.store.get_app_service_by_token = Mock(return_value=None) @@ -61,7 +61,7 @@ class AuthTestCase(unittest.TestCase): self.store.get_user_by_access_token = Mock(return_value=user_info) request = Mock(args={}) - request.args["access_token"] = [self.test_token] + request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() requester = yield self.auth.get_user_by_req(request) self.assertEquals(requester.user.to_string(), self.test_user) @@ -70,7 +70,7 @@ class AuthTestCase(unittest.TestCase): self.store.get_user_by_access_token = Mock(return_value=None) request = Mock(args={}) - request.args["access_token"] = [self.test_token] + request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() d = self.auth.get_user_by_req(request) self.failureResultOf(d, AuthError) @@ -98,7 +98,7 @@ class AuthTestCase(unittest.TestCase): request = Mock(args={}) request.getClientIP.return_value = "127.0.0.1" - request.args["access_token"] = [self.test_token] + request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() requester = yield self.auth.get_user_by_req(request) self.assertEquals(requester.user.to_string(), self.test_user) @@ -115,7 +115,7 @@ class AuthTestCase(unittest.TestCase): request = Mock(args={}) request.getClientIP.return_value = "192.168.10.10" - request.args["access_token"] = [self.test_token] + request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() requester = yield self.auth.get_user_by_req(request) self.assertEquals(requester.user.to_string(), self.test_user) @@ -131,7 +131,7 @@ class AuthTestCase(unittest.TestCase): request = Mock(args={}) request.getClientIP.return_value = "131.111.8.42" - request.args["access_token"] = [self.test_token] + request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() d = self.auth.get_user_by_req(request) self.failureResultOf(d, AuthError) @@ -141,7 +141,7 @@ class AuthTestCase(unittest.TestCase): self.store.get_user_by_access_token = Mock(return_value=None) request = Mock(args={}) - request.args["access_token"] = [self.test_token] + request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() d = self.auth.get_user_by_req(request) self.failureResultOf(d, AuthError) @@ -158,7 +158,7 @@ class AuthTestCase(unittest.TestCase): @defer.inlineCallbacks def test_get_user_by_req_appservice_valid_token_valid_user_id(self): - masquerading_user_id = "@doppelganger:matrix.org" + masquerading_user_id = b"@doppelganger:matrix.org" app_service = Mock( token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None, @@ -169,14 +169,17 @@ class AuthTestCase(unittest.TestCase): request = Mock(args={}) request.getClientIP.return_value = "127.0.0.1" - request.args["access_token"] = [self.test_token] - request.args["user_id"] = [masquerading_user_id] + request.args[b"access_token"] = [self.test_token] + request.args[b"user_id"] = [masquerading_user_id] request.requestHeaders.getRawHeaders = mock_getRawHeaders() requester = yield self.auth.get_user_by_req(request) - self.assertEquals(requester.user.to_string(), masquerading_user_id) + self.assertEquals( + requester.user.to_string(), + masquerading_user_id.decode('utf8') + ) def test_get_user_by_req_appservice_valid_token_bad_user_id(self): - masquerading_user_id = "@doppelganger:matrix.org" + masquerading_user_id = b"@doppelganger:matrix.org" app_service = Mock( token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None, @@ -187,8 +190,8 @@ class AuthTestCase(unittest.TestCase): request = Mock(args={}) request.getClientIP.return_value = "127.0.0.1" - request.args["access_token"] = [self.test_token] - request.args["user_id"] = [masquerading_user_id] + request.args[b"access_token"] = [self.test_token] + request.args[b"user_id"] = [masquerading_user_id] request.requestHeaders.getRawHeaders = mock_getRawHeaders() d = self.auth.get_user_by_req(request) self.failureResultOf(d, AuthError) @@ -418,7 +421,7 @@ class AuthTestCase(unittest.TestCase): # check the token works request = Mock(args={}) - request.args["access_token"] = [token] + request.args[b"access_token"] = [token.encode('ascii')] request.requestHeaders.getRawHeaders = mock_getRawHeaders() requester = yield self.auth.get_user_by_req(request, allow_guest=True) self.assertEqual(UserID.from_string(USER_ID), requester.user) @@ -431,7 +434,7 @@ class AuthTestCase(unittest.TestCase): # the token should *not* work now request = Mock(args={}) - request.args["access_token"] = [guest_tok] + request.args[b"access_token"] = [guest_tok.encode('ascii')] request.requestHeaders.getRawHeaders = mock_getRawHeaders() with self.assertRaises(AuthError) as cm: diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index 8176a7dabd..ca8a7c907f 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -273,3 +273,104 @@ class DescriptorTestCase(unittest.TestCase): r = yield obj.fn(2, 3) self.assertEqual(r, 'chips') obj.mock.assert_not_called() + + +class CachedListDescriptorTestCase(unittest.TestCase): + @defer.inlineCallbacks + def test_cache(self): + class Cls(object): + def __init__(self): + self.mock = mock.Mock() + + @descriptors.cached() + def fn(self, arg1, arg2): + pass + + @descriptors.cachedList("fn", "args1", inlineCallbacks=True) + def list_fn(self, args1, arg2): + assert ( + logcontext.LoggingContext.current_context().request == "c1" + ) + # we want this to behave like an asynchronous function + yield run_on_reactor() + assert ( + logcontext.LoggingContext.current_context().request == "c1" + ) + defer.returnValue(self.mock(args1, arg2)) + + with logcontext.LoggingContext() as c1: + c1.request = "c1" + obj = Cls() + obj.mock.return_value = {10: 'fish', 20: 'chips'} + d1 = obj.list_fn([10, 20], 2) + self.assertEqual( + logcontext.LoggingContext.current_context(), + logcontext.LoggingContext.sentinel, + ) + r = yield d1 + self.assertEqual( + logcontext.LoggingContext.current_context(), + c1 + ) + obj.mock.assert_called_once_with([10, 20], 2) + self.assertEqual(r, {10: 'fish', 20: 'chips'}) + obj.mock.reset_mock() + + # a call with different params should call the mock again + obj.mock.return_value = {30: 'peas'} + r = yield obj.list_fn([20, 30], 2) + obj.mock.assert_called_once_with([30], 2) + self.assertEqual(r, {20: 'chips', 30: 'peas'}) + obj.mock.reset_mock() + + # all the values should now be cached + r = yield obj.fn(10, 2) + self.assertEqual(r, 'fish') + r = yield obj.fn(20, 2) + self.assertEqual(r, 'chips') + r = yield obj.fn(30, 2) + self.assertEqual(r, 'peas') + r = yield obj.list_fn([10, 20, 30], 2) + obj.mock.assert_not_called() + self.assertEqual(r, {10: 'fish', 20: 'chips', 30: 'peas'}) + + @defer.inlineCallbacks + def test_invalidate(self): + """Make sure that invalidation callbacks are called.""" + class Cls(object): + def __init__(self): + self.mock = mock.Mock() + + @descriptors.cached() + def fn(self, arg1, arg2): + pass + + @descriptors.cachedList("fn", "args1", inlineCallbacks=True) + def list_fn(self, args1, arg2): + # we want this to behave like an asynchronous function + yield run_on_reactor() + defer.returnValue(self.mock(args1, arg2)) + + obj = Cls() + invalidate0 = mock.Mock() + invalidate1 = mock.Mock() + + # cache miss + obj.mock.return_value = {10: 'fish', 20: 'chips'} + r1 = yield obj.list_fn([10, 20], 2, on_invalidate=invalidate0) + obj.mock.assert_called_once_with([10, 20], 2) + self.assertEqual(r1, {10: 'fish', 20: 'chips'}) + obj.mock.reset_mock() + + # cache hit + r2 = yield obj.list_fn([10, 20], 2, on_invalidate=invalidate1) + obj.mock.assert_not_called() + self.assertEqual(r2, {10: 'fish', 20: 'chips'}) + + invalidate0.assert_not_called() + invalidate1.assert_not_called() + + # now if we invalidate the keys, both invalidations should get called + obj.fn.invalidate((10, 2)) + invalidate0.assert_called_once() + invalidate1.assert_called_once() diff --git a/tests/utils.py b/tests/utils.py index c3dbff8507..9bff3ff3b9 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -193,7 +193,7 @@ class MockHttpResource(HttpServer): self.prefix = prefix def trigger_get(self, path): - return self.trigger("GET", path, None) + return self.trigger(b"GET", path, None) @patch('twisted.web.http.Request') @defer.inlineCallbacks @@ -227,7 +227,7 @@ class MockHttpResource(HttpServer): headers = {} if federation_auth: - headers[b"Authorization"] = ["X-Matrix origin=test,key=,sig="] + headers[b"Authorization"] = [b"X-Matrix origin=test,key=,sig="] mock_request.requestHeaders.getRawHeaders = mock_getRawHeaders(headers) # return the right path if the event requires it @@ -241,6 +241,9 @@ class MockHttpResource(HttpServer): except Exception: pass + if isinstance(path, bytes): + path = path.decode('utf8') + for (method, pattern, func) in self.callbacks: if http_method != method: continue @@ -249,7 +252,7 @@ class MockHttpResource(HttpServer): if matcher: try: args = [ - urlparse.unquote(u).decode("UTF-8") + urlparse.unquote(u) for u in matcher.groups() ] |