summary refs log tree commit diff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--changelog.d/9626.feature1
-rw-r--r--docs/spam_checker.md8
-rw-r--r--synapse/events/spamcheck.py29
-rw-r--r--synapse/handlers/register.py4
-rw-r--r--tests/handlers/test_register.py31
5 files changed, 67 insertions, 6 deletions
diff --git a/changelog.d/9626.feature b/changelog.d/9626.feature
new file mode 100644
index 0000000000..eacba6201b
--- /dev/null
+++ b/changelog.d/9626.feature
@@ -0,0 +1 @@
+Tell spam checker modules about the SSO IdP a user registered through if one was used.
diff --git a/docs/spam_checker.md b/docs/spam_checker.md
index 2020eb9006..52947f605e 100644
--- a/docs/spam_checker.md
+++ b/docs/spam_checker.md
@@ -69,7 +69,13 @@ class ExampleSpamChecker:
     async def check_username_for_spam(self, user_profile):
         return False  # allow all usernames
 
-    async def check_registration_for_spam(self, email_threepid, username, request_info):
+    async def check_registration_for_spam(
+        self,
+        email_threepid,
+        username,
+        request_info,
+        auth_provider_id,
+    ):
         return RegistrationBehaviour.ALLOW  # allow all registrations
 
     async def check_media_file_for_spam(self, file_wrapper, file_info):
diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py
index 8cfc0bb3cb..a9185987a2 100644
--- a/synapse/events/spamcheck.py
+++ b/synapse/events/spamcheck.py
@@ -15,6 +15,7 @@
 # limitations under the License.
 
 import inspect
+import logging
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
 
 from synapse.rest.media.v1._base import FileInfo
@@ -27,6 +28,8 @@ if TYPE_CHECKING:
     import synapse.events
     import synapse.server
 
+logger = logging.getLogger(__name__)
+
 
 class SpamChecker:
     def __init__(self, hs: "synapse.server.HomeServer"):
@@ -190,6 +193,7 @@ class SpamChecker:
         email_threepid: Optional[dict],
         username: Optional[str],
         request_info: Collection[Tuple[str, str]],
+        auth_provider_id: Optional[str] = None,
     ) -> RegistrationBehaviour:
         """Checks if we should allow the given registration request.
 
@@ -198,6 +202,9 @@ class SpamChecker:
             username: The request user name, if any
             request_info: List of tuples of user agent and IP that
                 were used during the registration process.
+            auth_provider_id: The SSO IdP the user used, e.g "oidc", "saml",
+                "cas". If any. Note this does not include users registered
+                via a password provider.
 
         Returns:
             Enum for how the request should be handled
@@ -208,9 +215,25 @@ class SpamChecker:
             # spam checker
             checker = getattr(spam_checker, "check_registration_for_spam", None)
             if checker:
-                behaviour = await maybe_awaitable(
-                    checker(email_threepid, username, request_info)
-                )
+                # Provide auth_provider_id if the function supports it
+                checker_args = inspect.signature(checker)
+                if len(checker_args.parameters) == 4:
+                    d = checker(
+                        email_threepid,
+                        username,
+                        request_info,
+                        auth_provider_id,
+                    )
+                elif len(checker_args.parameters) == 3:
+                    d = checker(email_threepid, username, request_info)
+                else:
+                    logger.error(
+                        "Invalid signature for %s.check_registration_for_spam. Denying registration",
+                        spam_checker.__module__,
+                    )
+                    return RegistrationBehaviour.DENY
+
+                behaviour = await maybe_awaitable(d)
                 assert isinstance(behaviour, RegistrationBehaviour)
                 if behaviour != RegistrationBehaviour.ALLOW:
                     return behaviour
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index cd001e87c7..1abc8875cb 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -198,8 +198,7 @@ class RegistrationHandler(BaseHandler):
               admin api, otherwise False.
             user_agent_ips: Tuples of IP addresses and user-agents used
                 during the registration process.
-            auth_provider_id: The SSO IdP the user used, if any (just used for the
-                prometheus metrics).
+            auth_provider_id: The SSO IdP the user used, if any.
         Returns:
             The registered user_id.
         Raises:
@@ -211,6 +210,7 @@ class RegistrationHandler(BaseHandler):
             threepid,
             localpart,
             user_agent_ips or [],
+            auth_provider_id=auth_provider_id,
         )
 
         if result == RegistrationBehaviour.DENY:
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index bdf3d0a8a2..94b6903594 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -517,6 +517,37 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
 
         self.assertTrue(requester.shadow_banned)
 
+    def test_spam_checker_receives_sso_type(self):
+        """Test rejecting registration based on SSO type"""
+
+        class BanBadIdPUser:
+            def check_registration_for_spam(
+                self, email_threepid, username, request_info, auth_provider_id=None
+            ):
+                # Reject any user coming from CAS and whose username contains profanity
+                if auth_provider_id == "cas" and "flimflob" in username:
+                    return RegistrationBehaviour.DENY
+                return RegistrationBehaviour.ALLOW
+
+        # Configure a spam checker that denies a certain user on a specific IdP
+        spam_checker = self.hs.get_spam_checker()
+        spam_checker.spam_checkers = [BanBadIdPUser()]
+
+        f = self.get_failure(
+            self.handler.register_user(localpart="bobflimflob", auth_provider_id="cas"),
+            SynapseError,
+        )
+        exception = f.value
+
+        # We return 429 from the spam checker for denied registrations
+        self.assertIsInstance(exception, SynapseError)
+        self.assertEqual(exception.code, 429)
+
+        # Check the same username can register using SAML
+        self.get_success(
+            self.handler.register_user(localpart="bobflimflob", auth_provider_id="saml")
+        )
+
     async def get_or_create_user(
         self, requester, localpart, displayname, password_hash=None
     ):