summary refs log tree commit diff
path: root/tests/handlers
diff options
context:
space:
mode:
Diffstat (limited to 'tests/handlers')
-rw-r--r--tests/handlers/test_oidc.py18
-rw-r--r--tests/handlers/test_presence.py2
-rw-r--r--tests/handlers/test_profile.py60
-rw-r--r--tests/handlers/test_register.py52
-rw-r--r--tests/handlers/test_typing.py32
-rw-r--r--tests/handlers/test_user_directory.py6
6 files changed, 139 insertions, 31 deletions
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index 1bb25ab684..f92f3b8c15 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -374,12 +374,16 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
         self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
         self.handler._auth_handler.complete_sso_login = simple_async_mock()
-        request = Mock(spec=["args", "getCookie", "addCookie"])
+        request = Mock(
+            spec=["args", "getCookie", "addCookie", "requestHeaders", "getClientIP"]
+        )
 
         code = "code"
         state = "state"
         nonce = "nonce"
         client_redirect_url = "http://client/redirect"
+        user_agent = "Browser"
+        ip_address = "10.0.0.1"
         session = self.handler._generate_oidc_session_token(
             state=state,
             nonce=nonce,
@@ -392,6 +396,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
         request.args[b"code"] = [code.encode("utf-8")]
         request.args[b"state"] = [state.encode("utf-8")]
 
+        request.requestHeaders = Mock(spec=["getRawHeaders"])
+        request.requestHeaders.getRawHeaders.return_value = [user_agent.encode("ascii")]
+        request.getClientIP.return_value = ip_address
+
         yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
 
         self.handler._auth_handler.complete_sso_login.assert_called_once_with(
@@ -399,7 +407,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
         )
         self.handler._exchange_code.assert_called_once_with(code)
         self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce)
-        self.handler._map_userinfo_to_user.assert_called_once_with(userinfo, token)
+        self.handler._map_userinfo_to_user.assert_called_once_with(
+            userinfo, token, user_agent, ip_address
+        )
         self.handler._fetch_userinfo.assert_not_called()
         self.handler._render_error.assert_not_called()
 
@@ -431,7 +441,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
         )
         self.handler._exchange_code.assert_called_once_with(code)
         self.handler._parse_id_token.assert_not_called()
-        self.handler._map_userinfo_to_user.assert_called_once_with(userinfo, token)
+        self.handler._map_userinfo_to_user.assert_called_once_with(
+            userinfo, token, user_agent, ip_address
+        )
         self.handler._fetch_userinfo.assert_called_once_with(token)
         self.handler._render_error.assert_not_called()
 
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index 05ea40a7de..306dcfe944 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -19,6 +19,7 @@ from mock import Mock, call
 from signedjson.key import generate_signing_key
 
 from synapse.api.constants import EventTypes, Membership, PresenceState
+from synapse.api.presence import UserPresenceState
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
 from synapse.events.builder import EventBuilder
 from synapse.handlers.presence import (
@@ -32,7 +33,6 @@ from synapse.handlers.presence import (
     handle_update,
 )
 from synapse.rest.client.v1 import room
-from synapse.storage.presence import UserPresenceState
 from synapse.types import UserID, get_domain_from_id
 
 from tests import unittest
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index d70e1fc608..60ebc95f3e 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -64,14 +64,16 @@ class ProfileTestCase(unittest.TestCase):
         self.bob = UserID.from_string("@4567:test")
         self.alice = UserID.from_string("@alice:remote")
 
-        yield self.store.create_profile(self.frank.localpart)
+        yield defer.ensureDeferred(self.store.create_profile(self.frank.localpart))
 
         self.handler = hs.get_profile_handler()
         self.hs = hs
 
     @defer.inlineCallbacks
     def test_get_my_name(self):
-        yield self.store.set_profile_displayname(self.frank.localpart, "Frank")
+        yield defer.ensureDeferred(
+            self.store.set_profile_displayname(self.frank.localpart, "Frank")
+        )
 
         displayname = yield defer.ensureDeferred(
             self.handler.get_displayname(self.frank)
@@ -104,7 +106,12 @@ class ProfileTestCase(unittest.TestCase):
         )
 
         self.assertEquals(
-            (yield self.store.get_profile_displayname(self.frank.localpart)), "Frank",
+            (
+                yield defer.ensureDeferred(
+                    self.store.get_profile_displayname(self.frank.localpart)
+                )
+            ),
+            "Frank",
         )
 
     @defer.inlineCallbacks
@@ -112,10 +119,17 @@ class ProfileTestCase(unittest.TestCase):
         self.hs.config.enable_set_displayname = False
 
         # Setting displayname for the first time is allowed
-        yield self.store.set_profile_displayname(self.frank.localpart, "Frank")
+        yield defer.ensureDeferred(
+            self.store.set_profile_displayname(self.frank.localpart, "Frank")
+        )
 
         self.assertEquals(
-            (yield self.store.get_profile_displayname(self.frank.localpart)), "Frank",
+            (
+                yield defer.ensureDeferred(
+                    self.store.get_profile_displayname(self.frank.localpart)
+                )
+            ),
+            "Frank",
         )
 
         # Setting displayname a second time is forbidden
@@ -157,8 +171,10 @@ class ProfileTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_incoming_fed_query(self):
-        yield self.store.create_profile("caroline")
-        yield self.store.set_profile_displayname("caroline", "Caroline")
+        yield defer.ensureDeferred(self.store.create_profile("caroline"))
+        yield defer.ensureDeferred(
+            self.store.set_profile_displayname("caroline", "Caroline")
+        )
 
         response = yield defer.ensureDeferred(
             self.query_handlers["profile"](
@@ -170,8 +186,10 @@ class ProfileTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_get_my_avatar(self):
-        yield self.store.set_profile_avatar_url(
-            self.frank.localpart, "http://my.server/me.png"
+        yield defer.ensureDeferred(
+            self.store.set_profile_avatar_url(
+                self.frank.localpart, "http://my.server/me.png"
+            )
         )
         avatar_url = yield defer.ensureDeferred(self.handler.get_avatar_url(self.frank))
 
@@ -188,7 +206,11 @@ class ProfileTestCase(unittest.TestCase):
         )
 
         self.assertEquals(
-            (yield self.store.get_profile_avatar_url(self.frank.localpart)),
+            (
+                yield defer.ensureDeferred(
+                    self.store.get_profile_avatar_url(self.frank.localpart)
+                )
+            ),
             "http://my.server/pic.gif",
         )
 
@@ -202,7 +224,11 @@ class ProfileTestCase(unittest.TestCase):
         )
 
         self.assertEquals(
-            (yield self.store.get_profile_avatar_url(self.frank.localpart)),
+            (
+                yield defer.ensureDeferred(
+                    self.store.get_profile_avatar_url(self.frank.localpart)
+                )
+            ),
             "http://my.server/me.png",
         )
 
@@ -211,12 +237,18 @@ class ProfileTestCase(unittest.TestCase):
         self.hs.config.enable_set_avatar_url = False
 
         # Setting displayname for the first time is allowed
-        yield self.store.set_profile_avatar_url(
-            self.frank.localpart, "http://my.server/me.png"
+        yield defer.ensureDeferred(
+            self.store.set_profile_avatar_url(
+                self.frank.localpart, "http://my.server/me.png"
+            )
         )
 
         self.assertEquals(
-            (yield self.store.get_profile_avatar_url(self.frank.localpart)),
+            (
+                yield defer.ensureDeferred(
+                    self.store.get_profile_avatar_url(self.frank.localpart)
+                )
+            ),
             "http://my.server/me.png",
         )
 
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index e364b1bd62..5c92d0e8c9 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -17,18 +17,21 @@ from mock import Mock
 
 from twisted.internet import defer
 
+from synapse.api.auth import Auth
 from synapse.api.constants import UserTypes
 from synapse.api.errors import Codes, ResourceLimitError, SynapseError
 from synapse.handlers.register import RegistrationHandler
+from synapse.spam_checker_api import RegistrationBehaviour
 from synapse.types import RoomAlias, UserID, create_requester
 
 from tests.test_utils import make_awaitable
 from tests.unittest import override_config
+from tests.utils import mock_getRawHeaders
 
 from .. import unittest
 
 
-class RegistrationHandlers(object):
+class RegistrationHandlers:
     def __init__(self, hs):
         self.registration_handler = RegistrationHandler(hs)
 
@@ -475,6 +478,53 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
             self.handler.register_user(localpart=invalid_user_id), SynapseError
         )
 
+    def test_spam_checker_deny(self):
+        """A spam checker can deny registration, which results in an error."""
+
+        class DenyAll:
+            def check_registration_for_spam(
+                self, email_threepid, username, request_info
+            ):
+                return RegistrationBehaviour.DENY
+
+        # Configure a spam checker that denies all users.
+        spam_checker = self.hs.get_spam_checker()
+        spam_checker.spam_checkers = [DenyAll()]
+
+        self.get_failure(self.handler.register_user(localpart="user"), SynapseError)
+
+    def test_spam_checker_shadow_ban(self):
+        """A spam checker can choose to shadow-ban a user, which allows registration to succeed."""
+
+        class BanAll:
+            def check_registration_for_spam(
+                self, email_threepid, username, request_info
+            ):
+                return RegistrationBehaviour.SHADOW_BAN
+
+        # Configure a spam checker that denies all users.
+        spam_checker = self.hs.get_spam_checker()
+        spam_checker.spam_checkers = [BanAll()]
+
+        user_id = self.get_success(self.handler.register_user(localpart="user"))
+
+        # Get an access token.
+        token = self.macaroon_generator.generate_access_token(user_id)
+        self.get_success(
+            self.store.add_access_token_to_user(
+                user_id=user_id, token=token, device_id=None, valid_until_ms=None
+            )
+        )
+
+        # Ensure the user was marked as shadow-banned.
+        request = Mock(args={})
+        request.args[b"access_token"] = [token.encode("ascii")]
+        request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+        auth = Auth(self.hs)
+        requester = self.get_success(auth.get_user_by_req(request))
+
+        self.assertTrue(requester.shadow_banned)
+
     async def get_or_create_user(
         self, requester, localpart, displayname, password_hash=None
     ):
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 64afd581bc..81c1839637 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -21,7 +21,7 @@ from mock import ANY, Mock, call
 from twisted.internet import defer
 
 from synapse.api.errors import AuthError
-from synapse.types import UserID
+from synapse.types import UserID, create_requester
 
 from tests import unittest
 from tests.test_utils import make_awaitable
@@ -144,9 +144,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
 
         self.datastore.get_users_in_room = get_users_in_room
 
-        self.datastore.get_user_directory_stream_pos.return_value = (
+        self.datastore.get_user_directory_stream_pos.side_effect = (
             # we deliberately return a non-None stream pos to avoid doing an initial_spam
-            defer.succeed(1)
+            lambda: make_awaitable(1)
         )
 
         self.datastore.get_current_state_deltas.return_value = (0, None)
@@ -156,7 +156,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
             ([], 0)
         )
         self.datastore.delete_device_msgs_for_remote = lambda *args, **kargs: None
-        self.datastore.set_received_txn_response = lambda *args, **kwargs: defer.succeed(
+        self.datastore.set_received_txn_response = lambda *args, **kwargs: make_awaitable(
             None
         )
 
@@ -167,7 +167,10 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
 
         self.get_success(
             self.handler.started_typing(
-                target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID, timeout=20000
+                target_user=U_APPLE,
+                requester=create_requester(U_APPLE),
+                room_id=ROOM_ID,
+                timeout=20000,
             )
         )
 
@@ -194,7 +197,10 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
 
         self.get_success(
             self.handler.started_typing(
-                target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID, timeout=20000
+                target_user=U_APPLE,
+                requester=create_requester(U_APPLE),
+                room_id=ROOM_ID,
+                timeout=20000,
             )
         )
 
@@ -269,7 +275,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
 
         self.get_success(
             self.handler.stopped_typing(
-                target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID
+                target_user=U_APPLE,
+                requester=create_requester(U_APPLE),
+                room_id=ROOM_ID,
             )
         )
 
@@ -309,7 +317,10 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
 
         self.get_success(
             self.handler.started_typing(
-                target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID, timeout=10000
+                target_user=U_APPLE,
+                requester=create_requester(U_APPLE),
+                room_id=ROOM_ID,
+                timeout=10000,
             )
         )
 
@@ -348,7 +359,10 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
 
         self.get_success(
             self.handler.started_typing(
-                target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID, timeout=10000
+                target_user=U_APPLE,
+                requester=create_requester(U_APPLE),
+                room_id=ROOM_ID,
+                timeout=10000,
             )
         )
 
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 31ed89a5cd..87be94111f 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -238,7 +238,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
 
     def test_spam_checker(self):
         """
-        A user which fails to the spam checks will not appear in search results.
+        A user which fails the spam checks will not appear in search results.
         """
         u1 = self.register_user("user1", "pass")
         u1_token = self.login(u1, "pass")
@@ -269,7 +269,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
         # Configure a spam checker that does not filter any users.
         spam_checker = self.hs.get_spam_checker()
 
-        class AllowAll(object):
+        class AllowAll:
             def check_username_for_spam(self, user_profile):
                 # Allow all users.
                 return False
@@ -282,7 +282,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
         self.assertEqual(len(s["results"]), 1)
 
         # Configure a spam checker that filters all users.
-        class BlockAll(object):
+        class BlockAll:
             def check_username_for_spam(self, user_profile):
                 # All users are spammy.
                 return True