summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--tests/api/test_auth.py35
-rw-r--r--tests/handlers/test_auth.py80
-rw-r--r--tests/handlers/test_register.py51
-rw-r--r--tests/handlers/test_typing.py1
-rw-r--r--tests/storage/test__init__.py65
-rw-r--r--tests/util/caches/test_descriptors.py101
-rw-r--r--tests/utils.py9
7 files changed, 310 insertions, 32 deletions
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index 5f158ec4b9..a82d737e71 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -46,7 +46,7 @@ class AuthTestCase(unittest.TestCase):
         self.auth = Auth(self.hs)
 
         self.test_user = "@foo:bar"
-        self.test_token = "_test_token_"
+        self.test_token = b"_test_token_"
 
         # this is overridden for the appservice tests
         self.store.get_app_service_by_token = Mock(return_value=None)
@@ -61,7 +61,7 @@ class AuthTestCase(unittest.TestCase):
         self.store.get_user_by_access_token = Mock(return_value=user_info)
 
         request = Mock(args={})
-        request.args["access_token"] = [self.test_token]
+        request.args[b"access_token"] = [self.test_token]
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
         requester = yield self.auth.get_user_by_req(request)
         self.assertEquals(requester.user.to_string(), self.test_user)
@@ -70,7 +70,7 @@ class AuthTestCase(unittest.TestCase):
         self.store.get_user_by_access_token = Mock(return_value=None)
 
         request = Mock(args={})
-        request.args["access_token"] = [self.test_token]
+        request.args[b"access_token"] = [self.test_token]
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
         d = self.auth.get_user_by_req(request)
         self.failureResultOf(d, AuthError)
@@ -98,7 +98,7 @@ class AuthTestCase(unittest.TestCase):
 
         request = Mock(args={})
         request.getClientIP.return_value = "127.0.0.1"
-        request.args["access_token"] = [self.test_token]
+        request.args[b"access_token"] = [self.test_token]
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
         requester = yield self.auth.get_user_by_req(request)
         self.assertEquals(requester.user.to_string(), self.test_user)
@@ -115,7 +115,7 @@ class AuthTestCase(unittest.TestCase):
 
         request = Mock(args={})
         request.getClientIP.return_value = "192.168.10.10"
-        request.args["access_token"] = [self.test_token]
+        request.args[b"access_token"] = [self.test_token]
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
         requester = yield self.auth.get_user_by_req(request)
         self.assertEquals(requester.user.to_string(), self.test_user)
@@ -131,7 +131,7 @@ class AuthTestCase(unittest.TestCase):
 
         request = Mock(args={})
         request.getClientIP.return_value = "131.111.8.42"
-        request.args["access_token"] = [self.test_token]
+        request.args[b"access_token"] = [self.test_token]
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
         d = self.auth.get_user_by_req(request)
         self.failureResultOf(d, AuthError)
@@ -141,7 +141,7 @@ class AuthTestCase(unittest.TestCase):
         self.store.get_user_by_access_token = Mock(return_value=None)
 
         request = Mock(args={})
-        request.args["access_token"] = [self.test_token]
+        request.args[b"access_token"] = [self.test_token]
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
         d = self.auth.get_user_by_req(request)
         self.failureResultOf(d, AuthError)
@@ -158,7 +158,7 @@ class AuthTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_get_user_by_req_appservice_valid_token_valid_user_id(self):
-        masquerading_user_id = "@doppelganger:matrix.org"
+        masquerading_user_id = b"@doppelganger:matrix.org"
         app_service = Mock(
             token="foobar", url="a_url", sender=self.test_user,
             ip_range_whitelist=None,
@@ -169,14 +169,17 @@ class AuthTestCase(unittest.TestCase):
 
         request = Mock(args={})
         request.getClientIP.return_value = "127.0.0.1"
-        request.args["access_token"] = [self.test_token]
-        request.args["user_id"] = [masquerading_user_id]
+        request.args[b"access_token"] = [self.test_token]
+        request.args[b"user_id"] = [masquerading_user_id]
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
         requester = yield self.auth.get_user_by_req(request)
-        self.assertEquals(requester.user.to_string(), masquerading_user_id)
+        self.assertEquals(
+            requester.user.to_string(),
+            masquerading_user_id.decode('utf8')
+        )
 
     def test_get_user_by_req_appservice_valid_token_bad_user_id(self):
-        masquerading_user_id = "@doppelganger:matrix.org"
+        masquerading_user_id = b"@doppelganger:matrix.org"
         app_service = Mock(
             token="foobar", url="a_url", sender=self.test_user,
             ip_range_whitelist=None,
@@ -187,8 +190,8 @@ class AuthTestCase(unittest.TestCase):
 
         request = Mock(args={})
         request.getClientIP.return_value = "127.0.0.1"
-        request.args["access_token"] = [self.test_token]
-        request.args["user_id"] = [masquerading_user_id]
+        request.args[b"access_token"] = [self.test_token]
+        request.args[b"user_id"] = [masquerading_user_id]
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
         d = self.auth.get_user_by_req(request)
         self.failureResultOf(d, AuthError)
@@ -418,7 +421,7 @@ class AuthTestCase(unittest.TestCase):
 
         # check the token works
         request = Mock(args={})
-        request.args["access_token"] = [token]
+        request.args[b"access_token"] = [token.encode('ascii')]
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
         requester = yield self.auth.get_user_by_req(request, allow_guest=True)
         self.assertEqual(UserID.from_string(USER_ID), requester.user)
@@ -431,7 +434,7 @@ class AuthTestCase(unittest.TestCase):
 
         # the token should *not* work now
         request = Mock(args={})
-        request.args["access_token"] = [guest_tok]
+        request.args[b"access_token"] = [guest_tok.encode('ascii')]
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
 
         with self.assertRaises(AuthError) as cm:
diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py
index 2e5e8e4dec..55eab9e9cf 100644
--- a/tests/handlers/test_auth.py
+++ b/tests/handlers/test_auth.py
@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from mock import Mock
 
 import pymacaroons
 
@@ -19,6 +20,7 @@ from twisted.internet import defer
 
 import synapse
 import synapse.api.errors
+from synapse.api.errors import AuthError
 from synapse.handlers.auth import AuthHandler
 
 from tests import unittest
@@ -37,6 +39,10 @@ class AuthTestCase(unittest.TestCase):
         self.hs.handlers = AuthHandlers(self.hs)
         self.auth_handler = self.hs.handlers.auth_handler
         self.macaroon_generator = self.hs.get_macaroon_generator()
+        # MAU tests
+        self.hs.config.max_mau_value = 50
+        self.small_number_of_users = 1
+        self.large_number_of_users = 100
 
     def test_token_is_a_macaroon(self):
         token = self.macaroon_generator.generate_access_token("some_user")
@@ -71,38 +77,37 @@ class AuthTestCase(unittest.TestCase):
         v.satisfy_general(verify_nonce)
         v.verify(macaroon, self.hs.config.macaroon_secret_key)
 
+    @defer.inlineCallbacks
     def test_short_term_login_token_gives_user_id(self):
         self.hs.clock.now = 1000
 
         token = self.macaroon_generator.generate_short_term_login_token(
             "a_user", 5000
         )
-
-        self.assertEqual(
-            "a_user",
-            self.auth_handler.validate_short_term_login_token_and_get_user_id(
-                token
-            )
+        user_id = yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
+            token
         )
+        self.assertEqual("a_user", user_id)
 
         # when we advance the clock, the token should be rejected
         self.hs.clock.now = 6000
         with self.assertRaises(synapse.api.errors.AuthError):
-            self.auth_handler.validate_short_term_login_token_and_get_user_id(
+            yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
                 token
             )
 
+    @defer.inlineCallbacks
     def test_short_term_login_token_cannot_replace_user_id(self):
         token = self.macaroon_generator.generate_short_term_login_token(
             "a_user", 5000
         )
         macaroon = pymacaroons.Macaroon.deserialize(token)
 
+        user_id = yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
+            macaroon.serialize()
+        )
         self.assertEqual(
-            "a_user",
-            self.auth_handler.validate_short_term_login_token_and_get_user_id(
-                macaroon.serialize()
-            )
+            "a_user", user_id
         )
 
         # add another "user_id" caveat, which might allow us to override the
@@ -110,6 +115,57 @@ class AuthTestCase(unittest.TestCase):
         macaroon.add_first_party_caveat("user_id = b_user")
 
         with self.assertRaises(synapse.api.errors.AuthError):
-            self.auth_handler.validate_short_term_login_token_and_get_user_id(
+            yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
                 macaroon.serialize()
             )
+
+    @defer.inlineCallbacks
+    def test_mau_limits_disabled(self):
+        self.hs.config.limit_usage_by_mau = False
+        # Ensure does not throw exception
+        yield self.auth_handler.get_access_token_for_user_id('user_a')
+
+        yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
+            self._get_macaroon().serialize()
+        )
+
+    @defer.inlineCallbacks
+    def test_mau_limits_exceeded(self):
+        self.hs.config.limit_usage_by_mau = True
+        self.hs.get_datastore().count_monthly_users = Mock(
+            return_value=defer.succeed(self.large_number_of_users)
+        )
+
+        with self.assertRaises(AuthError):
+            yield self.auth_handler.get_access_token_for_user_id('user_a')
+
+        self.hs.get_datastore().count_monthly_users = Mock(
+            return_value=defer.succeed(self.large_number_of_users)
+        )
+        with self.assertRaises(AuthError):
+            yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
+                self._get_macaroon().serialize()
+            )
+
+    @defer.inlineCallbacks
+    def test_mau_limits_not_exceeded(self):
+        self.hs.config.limit_usage_by_mau = True
+
+        self.hs.get_datastore().count_monthly_users = Mock(
+            return_value=defer.succeed(self.small_number_of_users)
+        )
+        # Ensure does not raise exception
+        yield self.auth_handler.get_access_token_for_user_id('user_a')
+
+        self.hs.get_datastore().count_monthly_users = Mock(
+            return_value=defer.succeed(self.small_number_of_users)
+        )
+        yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
+            self._get_macaroon().serialize()
+        )
+
+    def _get_macaroon(self):
+        token = self.macaroon_generator.generate_short_term_login_token(
+            "user_a", 5000
+        )
+        return pymacaroons.Macaroon.deserialize(token)
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index 025fa1be81..0937d71cf6 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -17,6 +17,7 @@ from mock import Mock
 
 from twisted.internet import defer
 
+from synapse.api.errors import RegistrationError
 from synapse.handlers.register import RegistrationHandler
 from synapse.types import UserID, create_requester
 
@@ -77,3 +78,53 @@ class RegistrationTestCase(unittest.TestCase):
             requester, local_part, display_name)
         self.assertEquals(result_user_id, user_id)
         self.assertEquals(result_token, 'secret')
+
+    @defer.inlineCallbacks
+    def test_cannot_register_when_mau_limits_exceeded(self):
+        local_part = "someone"
+        display_name = "someone"
+        requester = create_requester("@as:test")
+        store = self.hs.get_datastore()
+        self.hs.config.limit_usage_by_mau = False
+        self.hs.config.max_mau_value = 50
+        lots_of_users = 100
+        small_number_users = 1
+
+        store.count_monthly_users = Mock(return_value=defer.succeed(lots_of_users))
+
+        # Ensure does not throw exception
+        yield self.handler.get_or_create_user(requester, 'a', display_name)
+
+        self.hs.config.limit_usage_by_mau = True
+
+        with self.assertRaises(RegistrationError):
+            yield self.handler.get_or_create_user(requester, 'b', display_name)
+
+        store.count_monthly_users = Mock(return_value=defer.succeed(small_number_users))
+
+        self._macaroon_mock_generator("another_secret")
+
+        # Ensure does not throw exception
+        yield self.handler.get_or_create_user("@neil:matrix.org", 'c', "Neil")
+
+        self._macaroon_mock_generator("another another secret")
+        store.count_monthly_users = Mock(return_value=defer.succeed(lots_of_users))
+
+        with self.assertRaises(RegistrationError):
+            yield self.handler.register(localpart=local_part)
+
+        self._macaroon_mock_generator("another another secret")
+        store.count_monthly_users = Mock(return_value=defer.succeed(lots_of_users))
+
+        with self.assertRaises(RegistrationError):
+            yield self.handler.register_saml2(local_part)
+
+    def _macaroon_mock_generator(self, secret):
+        """
+        Reset macaroon generator in the case where the test creates multiple users
+        """
+        macaroon_generator = Mock(
+            generate_access_token=Mock(return_value=secret))
+        self.hs.get_macaroon_generator = Mock(return_value=macaroon_generator)
+        self.hs.handlers = RegistrationHandlers(self.hs)
+        self.handler = self.hs.get_handlers().registration_handler
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index b08856f763..2c263af1a3 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -44,7 +44,6 @@ def _expect_edu(destination, edu_type, content, origin="test"):
                 "content": content,
             }
         ],
-        "pdu_failures": [],
     }
 
 
diff --git a/tests/storage/test__init__.py b/tests/storage/test__init__.py
new file mode 100644
index 0000000000..f19cb1265c
--- /dev/null
+++ b/tests/storage/test__init__.py
@@ -0,0 +1,65 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+import tests.utils
+
+
+class InitTestCase(tests.unittest.TestCase):
+    def __init__(self, *args, **kwargs):
+        super(InitTestCase, self).__init__(*args, **kwargs)
+        self.store = None  # type: synapse.storage.DataStore
+
+    @defer.inlineCallbacks
+    def setUp(self):
+        hs = yield tests.utils.setup_test_homeserver()
+
+        hs.config.max_mau_value = 50
+        hs.config.limit_usage_by_mau = True
+        self.store = hs.get_datastore()
+        self.clock = hs.get_clock()
+
+    @defer.inlineCallbacks
+    def test_count_monthly_users(self):
+        count = yield self.store.count_monthly_users()
+        self.assertEqual(0, count)
+
+        yield self._insert_user_ips("@user:server1")
+        yield self._insert_user_ips("@user:server2")
+
+        count = yield self.store.count_monthly_users()
+        self.assertEqual(2, count)
+
+    @defer.inlineCallbacks
+    def _insert_user_ips(self, user):
+        """
+        Helper function to populate user_ips without using batch insertion infra
+        args:
+            user (str):  specify username i.e. @user:server.com
+        """
+        yield self.store._simple_upsert(
+            table="user_ips",
+            keyvalues={
+                "user_id": user,
+                "access_token": "access_token",
+                "ip": "ip",
+                "user_agent": "user_agent",
+                "device_id": "device_id",
+            },
+            values={
+                "last_seen": self.clock.time_msec(),
+            }
+        )
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 8176a7dabd..ca8a7c907f 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -273,3 +273,104 @@ class DescriptorTestCase(unittest.TestCase):
         r = yield obj.fn(2, 3)
         self.assertEqual(r, 'chips')
         obj.mock.assert_not_called()
+
+
+class CachedListDescriptorTestCase(unittest.TestCase):
+    @defer.inlineCallbacks
+    def test_cache(self):
+        class Cls(object):
+            def __init__(self):
+                self.mock = mock.Mock()
+
+            @descriptors.cached()
+            def fn(self, arg1, arg2):
+                pass
+
+            @descriptors.cachedList("fn", "args1", inlineCallbacks=True)
+            def list_fn(self, args1, arg2):
+                assert (
+                    logcontext.LoggingContext.current_context().request == "c1"
+                )
+                # we want this to behave like an asynchronous function
+                yield run_on_reactor()
+                assert (
+                    logcontext.LoggingContext.current_context().request == "c1"
+                )
+                defer.returnValue(self.mock(args1, arg2))
+
+        with logcontext.LoggingContext() as c1:
+            c1.request = "c1"
+            obj = Cls()
+            obj.mock.return_value = {10: 'fish', 20: 'chips'}
+            d1 = obj.list_fn([10, 20], 2)
+            self.assertEqual(
+                logcontext.LoggingContext.current_context(),
+                logcontext.LoggingContext.sentinel,
+            )
+            r = yield d1
+            self.assertEqual(
+                logcontext.LoggingContext.current_context(),
+                c1
+            )
+            obj.mock.assert_called_once_with([10, 20], 2)
+            self.assertEqual(r, {10: 'fish', 20: 'chips'})
+            obj.mock.reset_mock()
+
+            # a call with different params should call the mock again
+            obj.mock.return_value = {30: 'peas'}
+            r = yield obj.list_fn([20, 30], 2)
+            obj.mock.assert_called_once_with([30], 2)
+            self.assertEqual(r, {20: 'chips', 30: 'peas'})
+            obj.mock.reset_mock()
+
+            # all the values should now be cached
+            r = yield obj.fn(10, 2)
+            self.assertEqual(r, 'fish')
+            r = yield obj.fn(20, 2)
+            self.assertEqual(r, 'chips')
+            r = yield obj.fn(30, 2)
+            self.assertEqual(r, 'peas')
+            r = yield obj.list_fn([10, 20, 30], 2)
+            obj.mock.assert_not_called()
+            self.assertEqual(r, {10: 'fish', 20: 'chips', 30: 'peas'})
+
+    @defer.inlineCallbacks
+    def test_invalidate(self):
+        """Make sure that invalidation callbacks are called."""
+        class Cls(object):
+            def __init__(self):
+                self.mock = mock.Mock()
+
+            @descriptors.cached()
+            def fn(self, arg1, arg2):
+                pass
+
+            @descriptors.cachedList("fn", "args1", inlineCallbacks=True)
+            def list_fn(self, args1, arg2):
+                # we want this to behave like an asynchronous function
+                yield run_on_reactor()
+                defer.returnValue(self.mock(args1, arg2))
+
+        obj = Cls()
+        invalidate0 = mock.Mock()
+        invalidate1 = mock.Mock()
+
+        # cache miss
+        obj.mock.return_value = {10: 'fish', 20: 'chips'}
+        r1 = yield obj.list_fn([10, 20], 2, on_invalidate=invalidate0)
+        obj.mock.assert_called_once_with([10, 20], 2)
+        self.assertEqual(r1, {10: 'fish', 20: 'chips'})
+        obj.mock.reset_mock()
+
+        # cache hit
+        r2 = yield obj.list_fn([10, 20], 2, on_invalidate=invalidate1)
+        obj.mock.assert_not_called()
+        self.assertEqual(r2, {10: 'fish', 20: 'chips'})
+
+        invalidate0.assert_not_called()
+        invalidate1.assert_not_called()
+
+        # now if we invalidate the keys, both invalidations should get called
+        obj.fn.invalidate((10, 2))
+        invalidate0.assert_called_once()
+        invalidate1.assert_called_once()
diff --git a/tests/utils.py b/tests/utils.py
index c3dbff8507..9bff3ff3b9 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -193,7 +193,7 @@ class MockHttpResource(HttpServer):
         self.prefix = prefix
 
     def trigger_get(self, path):
-        return self.trigger("GET", path, None)
+        return self.trigger(b"GET", path, None)
 
     @patch('twisted.web.http.Request')
     @defer.inlineCallbacks
@@ -227,7 +227,7 @@ class MockHttpResource(HttpServer):
 
         headers = {}
         if federation_auth:
-            headers[b"Authorization"] = ["X-Matrix origin=test,key=,sig="]
+            headers[b"Authorization"] = [b"X-Matrix origin=test,key=,sig="]
         mock_request.requestHeaders.getRawHeaders = mock_getRawHeaders(headers)
 
         # return the right path if the event requires it
@@ -241,6 +241,9 @@ class MockHttpResource(HttpServer):
         except Exception:
             pass
 
+        if isinstance(path, bytes):
+            path = path.decode('utf8')
+
         for (method, pattern, func) in self.callbacks:
             if http_method != method:
                 continue
@@ -249,7 +252,7 @@ class MockHttpResource(HttpServer):
             if matcher:
                 try:
                     args = [
-                        urlparse.unquote(u).decode("UTF-8")
+                        urlparse.unquote(u)
                         for u in matcher.groups()
                     ]