diff --git a/changelog.d/9626.feature b/changelog.d/9626.feature
new file mode 100644
index 0000000000..eacba6201b
--- /dev/null
+++ b/changelog.d/9626.feature
@@ -0,0 +1 @@
+Tell spam checker modules about the SSO IdP a user registered through if one was used.
diff --git a/docs/spam_checker.md b/docs/spam_checker.md
index 2020eb9006..52947f605e 100644
--- a/docs/spam_checker.md
+++ b/docs/spam_checker.md
@@ -69,7 +69,13 @@ class ExampleSpamChecker:
async def check_username_for_spam(self, user_profile):
return False # allow all usernames
- async def check_registration_for_spam(self, email_threepid, username, request_info):
+ async def check_registration_for_spam(
+ self,
+ email_threepid,
+ username,
+ request_info,
+ auth_provider_id,
+ ):
return RegistrationBehaviour.ALLOW # allow all registrations
async def check_media_file_for_spam(self, file_wrapper, file_info):
diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py
index 8cfc0bb3cb..a9185987a2 100644
--- a/synapse/events/spamcheck.py
+++ b/synapse/events/spamcheck.py
@@ -15,6 +15,7 @@
# limitations under the License.
import inspect
+import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from synapse.rest.media.v1._base import FileInfo
@@ -27,6 +28,8 @@ if TYPE_CHECKING:
import synapse.events
import synapse.server
+logger = logging.getLogger(__name__)
+
class SpamChecker:
def __init__(self, hs: "synapse.server.HomeServer"):
@@ -190,6 +193,7 @@ class SpamChecker:
email_threepid: Optional[dict],
username: Optional[str],
request_info: Collection[Tuple[str, str]],
+ auth_provider_id: Optional[str] = None,
) -> RegistrationBehaviour:
"""Checks if we should allow the given registration request.
@@ -198,6 +202,9 @@ class SpamChecker:
username: The request user name, if any
request_info: List of tuples of user agent and IP that
were used during the registration process.
+ auth_provider_id: The SSO IdP the user used, e.g "oidc", "saml",
+ "cas". If any. Note this does not include users registered
+ via a password provider.
Returns:
Enum for how the request should be handled
@@ -208,9 +215,25 @@ class SpamChecker:
# spam checker
checker = getattr(spam_checker, "check_registration_for_spam", None)
if checker:
- behaviour = await maybe_awaitable(
- checker(email_threepid, username, request_info)
- )
+ # Provide auth_provider_id if the function supports it
+ checker_args = inspect.signature(checker)
+ if len(checker_args.parameters) == 4:
+ d = checker(
+ email_threepid,
+ username,
+ request_info,
+ auth_provider_id,
+ )
+ elif len(checker_args.parameters) == 3:
+ d = checker(email_threepid, username, request_info)
+ else:
+ logger.error(
+ "Invalid signature for %s.check_registration_for_spam. Denying registration",
+ spam_checker.__module__,
+ )
+ return RegistrationBehaviour.DENY
+
+ behaviour = await maybe_awaitable(d)
assert isinstance(behaviour, RegistrationBehaviour)
if behaviour != RegistrationBehaviour.ALLOW:
return behaviour
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index cd001e87c7..1abc8875cb 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -198,8 +198,7 @@ class RegistrationHandler(BaseHandler):
admin api, otherwise False.
user_agent_ips: Tuples of IP addresses and user-agents used
during the registration process.
- auth_provider_id: The SSO IdP the user used, if any (just used for the
- prometheus metrics).
+ auth_provider_id: The SSO IdP the user used, if any.
Returns:
The registered user_id.
Raises:
@@ -211,6 +210,7 @@ class RegistrationHandler(BaseHandler):
threepid,
localpart,
user_agent_ips or [],
+ auth_provider_id=auth_provider_id,
)
if result == RegistrationBehaviour.DENY:
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index bdf3d0a8a2..94b6903594 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -517,6 +517,37 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.assertTrue(requester.shadow_banned)
+ def test_spam_checker_receives_sso_type(self):
+ """Test rejecting registration based on SSO type"""
+
+ class BanBadIdPUser:
+ def check_registration_for_spam(
+ self, email_threepid, username, request_info, auth_provider_id=None
+ ):
+ # Reject any user coming from CAS and whose username contains profanity
+ if auth_provider_id == "cas" and "flimflob" in username:
+ return RegistrationBehaviour.DENY
+ return RegistrationBehaviour.ALLOW
+
+ # Configure a spam checker that denies a certain user on a specific IdP
+ spam_checker = self.hs.get_spam_checker()
+ spam_checker.spam_checkers = [BanBadIdPUser()]
+
+ f = self.get_failure(
+ self.handler.register_user(localpart="bobflimflob", auth_provider_id="cas"),
+ SynapseError,
+ )
+ exception = f.value
+
+ # We return 429 from the spam checker for denied registrations
+ self.assertIsInstance(exception, SynapseError)
+ self.assertEqual(exception.code, 429)
+
+ # Check the same username can register using SAML
+ self.get_success(
+ self.handler.register_user(localpart="bobflimflob", auth_provider_id="saml")
+ )
+
async def get_or_create_user(
self, requester, localpart, displayname, password_hash=None
):
|