summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/8034.feature1
-rw-r--r--synapse/events/spamcheck.py35
-rw-r--r--synapse/handlers/auth.py8
-rw-r--r--synapse/handlers/cas_handler.py11
-rw-r--r--synapse/handlers/oidc_handler.py21
-rw-r--r--synapse/handlers/register.py26
-rw-r--r--synapse/handlers/saml_handler.py18
-rw-r--r--synapse/rest/client/v2_alpha/register.py5
-rw-r--r--synapse/spam_checker_api/__init__.py11
-rw-r--r--synapse/storage/databases/main/schema/delta/58/07persist_ui_auth_ips.sql25
-rw-r--r--synapse/storage/databases/main/ui_auth.py39
-rw-r--r--tests/handlers/test_oidc.py18
-rw-r--r--tests/handlers/test_register.py52
-rw-r--r--tests/handlers/test_user_directory.py6
14 files changed, 258 insertions, 18 deletions
diff --git a/changelog.d/8034.feature b/changelog.d/8034.feature
new file mode 100644
index 0000000000..813e6d0903
--- /dev/null
+++ b/changelog.d/8034.feature
@@ -0,0 +1 @@
+Add support for shadow-banning users (ignoring any message send requests).
diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py
index 1ffc9525d1..a7cddac974 100644
--- a/synapse/events/spamcheck.py
+++ b/synapse/events/spamcheck.py
@@ -15,9 +15,10 @@
 # limitations under the License.
 
 import inspect
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional, Tuple
 
-from synapse.spam_checker_api import SpamCheckerApi
+from synapse.spam_checker_api import RegistrationBehaviour, SpamCheckerApi
+from synapse.types import Collection
 
 MYPY = False
 if MYPY:
@@ -160,3 +161,33 @@ class SpamChecker(object):
                     return True
 
         return False
+
+    def check_registration_for_spam(
+        self,
+        email_threepid: Optional[dict],
+        username: Optional[str],
+        request_info: Collection[Tuple[str, str]],
+    ) -> RegistrationBehaviour:
+        """Checks if we should allow the given registration request.
+
+        Args:
+            email_threepid: The email threepid used for registering, if any
+            username: The request user name, if any
+            request_info: List of tuples of user agent and IP that
+                were used during the registration process.
+
+        Returns:
+            Enum for how the request should be handled
+        """
+
+        for spam_checker in self.spam_checkers:
+            # For backwards compatibility, only run if the method exists on the
+            # spam checker
+            checker = getattr(spam_checker, "check_registration_for_spam", None)
+            if checker:
+                behaviour = checker(email_threepid, username, request_info)
+                assert isinstance(behaviour, RegistrationBehaviour)
+                if behaviour != RegistrationBehaviour.ALLOW:
+                    return behaviour
+
+        return RegistrationBehaviour.ALLOW
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 68d6870e40..654f58ddae 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -364,6 +364,14 @@ class AuthHandler(BaseHandler):
             # authentication flow.
             await self.store.set_ui_auth_clientdict(sid, clientdict)
 
+        user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[
+            0
+        ].decode("ascii", "surrogateescape")
+
+        await self.store.add_user_agent_ip_to_ui_auth_session(
+            session.session_id, user_agent, clientip
+        )
+
         if not authdict:
             raise InteractiveAuthIncompleteError(
                 session.session_id, self._auth_dict_for_flows(flows, session.session_id)
diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py
index 786e608fa2..a4cc4b9a5a 100644
--- a/synapse/handlers/cas_handler.py
+++ b/synapse/handlers/cas_handler.py
@@ -35,6 +35,7 @@ class CasHandler:
     """
 
     def __init__(self, hs):
+        self.hs = hs
         self._hostname = hs.hostname
         self._auth_handler = hs.get_auth_handler()
         self._registration_handler = hs.get_registration_handler()
@@ -210,8 +211,16 @@ class CasHandler:
 
         else:
             if not registered_user_id:
+                # Pull out the user-agent and IP from the request.
+                user_agent = request.requestHeaders.getRawHeaders(
+                    b"User-Agent", default=[b""]
+                )[0].decode("ascii", "surrogateescape")
+                ip_address = self.hs.get_ip_from_request(request)
+
                 registered_user_id = await self._registration_handler.register_user(
-                    localpart=localpart, default_display_name=user_display_name
+                    localpart=localpart,
+                    default_display_name=user_display_name,
+                    user_agent_ips=(user_agent, ip_address),
                 )
 
             await self._auth_handler.complete_sso_login(
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index dd3703cbd2..c5bd2fea68 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -93,6 +93,7 @@ class OidcHandler:
     """
 
     def __init__(self, hs: "HomeServer"):
+        self.hs = hs
         self._callback_url = hs.config.oidc_callback_url  # type: str
         self._scopes = hs.config.oidc_scopes  # type: List[str]
         self._client_auth = ClientAuth(
@@ -689,9 +690,17 @@ class OidcHandler:
                 self._render_error(request, "invalid_token", str(e))
                 return
 
+        # Pull out the user-agent and IP from the request.
+        user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[
+            0
+        ].decode("ascii", "surrogateescape")
+        ip_address = self.hs.get_ip_from_request(request)
+
         # Call the mapper to register/login the user
         try:
-            user_id = await self._map_userinfo_to_user(userinfo, token)
+            user_id = await self._map_userinfo_to_user(
+                userinfo, token, user_agent, ip_address
+            )
         except MappingException as e:
             logger.exception("Could not map user")
             self._render_error(request, "mapping_error", str(e))
@@ -828,7 +837,9 @@ class OidcHandler:
         now = self._clock.time_msec()
         return now < expiry
 
-    async def _map_userinfo_to_user(self, userinfo: UserInfo, token: Token) -> str:
+    async def _map_userinfo_to_user(
+        self, userinfo: UserInfo, token: Token, user_agent: str, ip_address: str
+    ) -> str:
         """Maps a UserInfo object to a mxid.
 
         UserInfo should have a claim that uniquely identifies users. This claim
@@ -843,6 +854,8 @@ class OidcHandler:
         Args:
             userinfo: an object representing the user
             token: a dict with the tokens obtained from the provider
+            user_agent: The user agent of the client making the request.
+            ip_address: The IP address of the client making the request.
 
         Raises:
             MappingException: if there was an error while mapping some properties
@@ -899,7 +912,9 @@ class OidcHandler:
         # It's the first time this user is logging in and the mapped mxid was
         # not taken, register the user
         registered_user_id = await self._registration_handler.register_user(
-            localpart=localpart, default_display_name=attributes["display_name"],
+            localpart=localpart,
+            default_display_name=attributes["display_name"],
+            user_agent_ips=(user_agent, ip_address),
         )
 
         await self._datastore.record_user_external_id(
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index ccd96e4626..cde2dbca92 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -26,6 +26,7 @@ from synapse.replication.http.register import (
     ReplicationPostRegisterActionsServlet,
     ReplicationRegisterServlet,
 )
+from synapse.spam_checker_api import RegistrationBehaviour
 from synapse.storage.state import StateFilter
 from synapse.types import RoomAlias, UserID, create_requester
 
@@ -52,6 +53,8 @@ class RegistrationHandler(BaseHandler):
         self.macaroon_gen = hs.get_macaroon_generator()
         self._server_notices_mxid = hs.config.server_notices_mxid
 
+        self.spam_checker = hs.get_spam_checker()
+
         if hs.config.worker_app:
             self._register_client = ReplicationRegisterServlet.make_client(hs)
             self._register_device_client = RegisterDeviceReplicationServlet.make_client(
@@ -144,7 +147,7 @@ class RegistrationHandler(BaseHandler):
         address=None,
         bind_emails=[],
         by_admin=False,
-        shadow_banned=False,
+        user_agent_ips=None,
     ):
         """Registers a new client on the server.
 
@@ -162,7 +165,8 @@ class RegistrationHandler(BaseHandler):
             bind_emails (List[str]): list of emails to bind to this account.
             by_admin (bool): True if this registration is being made via the
               admin api, otherwise False.
-            shadow_banned (bool): Shadow-ban the created user.
+            user_agent_ips (List[(str, str)]): Tuples of IP addresses and user-agents used
+                during the registration process.
         Returns:
             str: user_id
         Raises:
@@ -170,6 +174,24 @@ class RegistrationHandler(BaseHandler):
         """
         self.check_registration_ratelimit(address)
 
+        result = self.spam_checker.check_registration_for_spam(
+            threepid, localpart, user_agent_ips or [],
+        )
+
+        if result == RegistrationBehaviour.DENY:
+            logger.info(
+                "Blocked registration of %r", localpart,
+            )
+            # We return a 429 to make it not obvious that they've been
+            # denied.
+            raise SynapseError(429, "Rate limited")
+
+        shadow_banned = result == RegistrationBehaviour.SHADOW_BAN
+        if shadow_banned:
+            logger.info(
+                "Shadow banning registration of %r", localpart,
+            )
+
         # do not check_auth_blocking if the call is coming through the Admin API
         if not by_admin:
             await self.auth.check_auth_blocking(threepid=threepid)
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index c1fcb98454..b426199aa6 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -54,6 +54,7 @@ class Saml2SessionData:
 
 class SamlHandler:
     def __init__(self, hs: "synapse.server.HomeServer"):
+        self.hs = hs
         self._saml_client = Saml2Client(hs.config.saml2_sp_config)
         self._auth = hs.get_auth()
         self._auth_handler = hs.get_auth_handler()
@@ -133,8 +134,14 @@ class SamlHandler:
         # the dict.
         self.expire_sessions()
 
+        # Pull out the user-agent and IP from the request.
+        user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[
+            0
+        ].decode("ascii", "surrogateescape")
+        ip_address = self.hs.get_ip_from_request(request)
+
         user_id, current_session = await self._map_saml_response_to_user(
-            resp_bytes, relay_state
+            resp_bytes, relay_state, user_agent, ip_address
         )
 
         # Complete the interactive auth session or the login.
@@ -147,7 +154,11 @@ class SamlHandler:
             await self._auth_handler.complete_sso_login(user_id, request, relay_state)
 
     async def _map_saml_response_to_user(
-        self, resp_bytes: str, client_redirect_url: str
+        self,
+        resp_bytes: str,
+        client_redirect_url: str,
+        user_agent: str,
+        ip_address: str,
     ) -> Tuple[str, Optional[Saml2SessionData]]:
         """
         Given a sample response, retrieve the cached session and user for it.
@@ -155,6 +166,8 @@ class SamlHandler:
         Args:
             resp_bytes: The SAML response.
             client_redirect_url: The redirect URL passed in by the client.
+            user_agent: The user agent of the client making the request.
+            ip_address: The IP address of the client making the request.
 
         Returns:
              Tuple of the user ID and SAML session associated with this response.
@@ -291,6 +304,7 @@ class SamlHandler:
                 localpart=localpart,
                 default_display_name=displayname,
                 bind_emails=emails,
+                user_agent_ips=(user_agent, ip_address),
             )
 
             await self._datastore.record_user_external_id(
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index 7290fd0756..be0e680ac5 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -591,12 +591,17 @@ class RegisterRestServlet(RestServlet):
                                 Codes.THREEPID_IN_USE,
                             )
 
+            entries = await self.store.get_user_agents_ips_to_ui_auth_session(
+                session_id
+            )
+
             registered_user_id = await self.registration_handler.register_user(
                 localpart=desired_username,
                 password_hash=password_hash,
                 guest_access_token=guest_access_token,
                 threepid=threepid,
                 address=client_addr,
+                user_agent_ips=entries,
             )
             # Necessary due to auth checks prior to the threepid being
             # written to the db
diff --git a/synapse/spam_checker_api/__init__.py b/synapse/spam_checker_api/__init__.py
index 7f63f1bfa0..9be92e2565 100644
--- a/synapse/spam_checker_api/__init__.py
+++ b/synapse/spam_checker_api/__init__.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+from enum import Enum
 
 from twisted.internet import defer
 
@@ -25,6 +26,16 @@ if MYPY:
 logger = logging.getLogger(__name__)
 
 
+class RegistrationBehaviour(Enum):
+    """
+    Enum to define whether a registration request should allowed, denied, or shadow-banned.
+    """
+
+    ALLOW = "allow"
+    SHADOW_BAN = "shadow_ban"
+    DENY = "deny"
+
+
 class SpamCheckerApi(object):
     """A proxy object that gets passed to spam checkers so they can get
     access to rooms and other relevant information.
diff --git a/synapse/storage/databases/main/schema/delta/58/07persist_ui_auth_ips.sql b/synapse/storage/databases/main/schema/delta/58/07persist_ui_auth_ips.sql
new file mode 100644
index 0000000000..4cc96a5341
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/07persist_ui_auth_ips.sql
@@ -0,0 +1,25 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- A table of the IP address and user-agent used to complete each step of a
+-- user-interactive authentication session.
+CREATE TABLE IF NOT EXISTS ui_auth_sessions_ips(
+    session_id TEXT NOT NULL,
+    ip TEXT NOT NULL,
+    user_agent TEXT NOT NULL,
+    UNIQUE (session_id, ip, user_agent),
+    FOREIGN KEY (session_id)
+        REFERENCES ui_auth_sessions (session_id)
+);
diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py
index 6281a41a3d..9eef8e57c5 100644
--- a/synapse/storage/databases/main/ui_auth.py
+++ b/synapse/storage/databases/main/ui_auth.py
@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import attr
 
@@ -260,6 +260,34 @@ class UIAuthWorkerStore(SQLBaseStore):
 
         return serverdict.get(key, default)
 
+    async def add_user_agent_ip_to_ui_auth_session(
+        self, session_id: str, user_agent: str, ip: str,
+    ):
+        """Add the given user agent / IP to the tracking table
+        """
+        await self.db_pool.simple_upsert(
+            table="ui_auth_sessions_ips",
+            keyvalues={"session_id": session_id, "user_agent": user_agent, "ip": ip},
+            values={},
+            desc="add_user_agent_ip_to_ui_auth_session",
+        )
+
+    async def get_user_agents_ips_to_ui_auth_session(
+        self, session_id: str,
+    ) -> List[Tuple[str, str]]:
+        """Get the given user agents / IPs used during the ui auth process
+
+        Returns:
+            List of user_agent/ip pairs
+        """
+        rows = await self.db_pool.simple_select_list(
+            table="ui_auth_sessions_ips",
+            keyvalues={"session_id": session_id},
+            retcols=("user_agent", "ip"),
+            desc="get_user_agents_ips_to_ui_auth_session",
+        )
+        return [(row["user_agent"], row["ip"]) for row in rows]
+
 
 class UIAuthStore(UIAuthWorkerStore):
     def delete_old_ui_auth_sessions(self, expiration_time: int):
@@ -285,6 +313,15 @@ class UIAuthStore(UIAuthWorkerStore):
         txn.execute(sql, [expiration_time])
         session_ids = [r[0] for r in txn.fetchall()]
 
+        # Delete the corresponding IP/user agents.
+        self.db_pool.simple_delete_many_txn(
+            txn,
+            table="ui_auth_sessions_ips",
+            column="session_id",
+            iterable=session_ids,
+            keyvalues={},
+        )
+
         # Delete the corresponding completed credentials.
         self.db_pool.simple_delete_many_txn(
             txn,
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index 1bb25ab684..f92f3b8c15 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -374,12 +374,16 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
         self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
         self.handler._auth_handler.complete_sso_login = simple_async_mock()
-        request = Mock(spec=["args", "getCookie", "addCookie"])
+        request = Mock(
+            spec=["args", "getCookie", "addCookie", "requestHeaders", "getClientIP"]
+        )
 
         code = "code"
         state = "state"
         nonce = "nonce"
         client_redirect_url = "http://client/redirect"
+        user_agent = "Browser"
+        ip_address = "10.0.0.1"
         session = self.handler._generate_oidc_session_token(
             state=state,
             nonce=nonce,
@@ -392,6 +396,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
         request.args[b"code"] = [code.encode("utf-8")]
         request.args[b"state"] = [state.encode("utf-8")]
 
+        request.requestHeaders = Mock(spec=["getRawHeaders"])
+        request.requestHeaders.getRawHeaders.return_value = [user_agent.encode("ascii")]
+        request.getClientIP.return_value = ip_address
+
         yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
 
         self.handler._auth_handler.complete_sso_login.assert_called_once_with(
@@ -399,7 +407,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
         )
         self.handler._exchange_code.assert_called_once_with(code)
         self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce)
-        self.handler._map_userinfo_to_user.assert_called_once_with(userinfo, token)
+        self.handler._map_userinfo_to_user.assert_called_once_with(
+            userinfo, token, user_agent, ip_address
+        )
         self.handler._fetch_userinfo.assert_not_called()
         self.handler._render_error.assert_not_called()
 
@@ -431,7 +441,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
         )
         self.handler._exchange_code.assert_called_once_with(code)
         self.handler._parse_id_token.assert_not_called()
-        self.handler._map_userinfo_to_user.assert_called_once_with(userinfo, token)
+        self.handler._map_userinfo_to_user.assert_called_once_with(
+            userinfo, token, user_agent, ip_address
+        )
         self.handler._fetch_userinfo.assert_called_once_with(token)
         self.handler._render_error.assert_not_called()
 
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index e364b1bd62..5c92d0e8c9 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -17,18 +17,21 @@ from mock import Mock
 
 from twisted.internet import defer
 
+from synapse.api.auth import Auth
 from synapse.api.constants import UserTypes
 from synapse.api.errors import Codes, ResourceLimitError, SynapseError
 from synapse.handlers.register import RegistrationHandler
+from synapse.spam_checker_api import RegistrationBehaviour
 from synapse.types import RoomAlias, UserID, create_requester
 
 from tests.test_utils import make_awaitable
 from tests.unittest import override_config
+from tests.utils import mock_getRawHeaders
 
 from .. import unittest
 
 
-class RegistrationHandlers(object):
+class RegistrationHandlers:
     def __init__(self, hs):
         self.registration_handler = RegistrationHandler(hs)
 
@@ -475,6 +478,53 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
             self.handler.register_user(localpart=invalid_user_id), SynapseError
         )
 
+    def test_spam_checker_deny(self):
+        """A spam checker can deny registration, which results in an error."""
+
+        class DenyAll:
+            def check_registration_for_spam(
+                self, email_threepid, username, request_info
+            ):
+                return RegistrationBehaviour.DENY
+
+        # Configure a spam checker that denies all users.
+        spam_checker = self.hs.get_spam_checker()
+        spam_checker.spam_checkers = [DenyAll()]
+
+        self.get_failure(self.handler.register_user(localpart="user"), SynapseError)
+
+    def test_spam_checker_shadow_ban(self):
+        """A spam checker can choose to shadow-ban a user, which allows registration to succeed."""
+
+        class BanAll:
+            def check_registration_for_spam(
+                self, email_threepid, username, request_info
+            ):
+                return RegistrationBehaviour.SHADOW_BAN
+
+        # Configure a spam checker that denies all users.
+        spam_checker = self.hs.get_spam_checker()
+        spam_checker.spam_checkers = [BanAll()]
+
+        user_id = self.get_success(self.handler.register_user(localpart="user"))
+
+        # Get an access token.
+        token = self.macaroon_generator.generate_access_token(user_id)
+        self.get_success(
+            self.store.add_access_token_to_user(
+                user_id=user_id, token=token, device_id=None, valid_until_ms=None
+            )
+        )
+
+        # Ensure the user was marked as shadow-banned.
+        request = Mock(args={})
+        request.args[b"access_token"] = [token.encode("ascii")]
+        request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+        auth = Auth(self.hs)
+        requester = self.get_success(auth.get_user_by_req(request))
+
+        self.assertTrue(requester.shadow_banned)
+
     async def get_or_create_user(
         self, requester, localpart, displayname, password_hash=None
     ):
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 31ed89a5cd..87be94111f 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -238,7 +238,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
 
     def test_spam_checker(self):
         """
-        A user which fails to the spam checks will not appear in search results.
+        A user which fails the spam checks will not appear in search results.
         """
         u1 = self.register_user("user1", "pass")
         u1_token = self.login(u1, "pass")
@@ -269,7 +269,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
         # Configure a spam checker that does not filter any users.
         spam_checker = self.hs.get_spam_checker()
 
-        class AllowAll(object):
+        class AllowAll:
             def check_username_for_spam(self, user_profile):
                 # Allow all users.
                 return False
@@ -282,7 +282,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
         self.assertEqual(len(s["results"]), 1)
 
         # Configure a spam checker that filters all users.
-        class BlockAll(object):
+        class BlockAll:
             def check_username_for_spam(self, user_profile):
                 # All users are spammy.
                 return True