diff --git a/changelog.d/8034.feature b/changelog.d/8034.feature
new file mode 100644
index 0000000000..813e6d0903
--- /dev/null
+++ b/changelog.d/8034.feature
@@ -0,0 +1 @@
+Add support for shadow-banning users (ignoring any message send requests).
diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py
index 1ffc9525d1..a7cddac974 100644
--- a/synapse/events/spamcheck.py
+++ b/synapse/events/spamcheck.py
@@ -15,9 +15,10 @@
# limitations under the License.
import inspect
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional, Tuple
-from synapse.spam_checker_api import SpamCheckerApi
+from synapse.spam_checker_api import RegistrationBehaviour, SpamCheckerApi
+from synapse.types import Collection
MYPY = False
if MYPY:
@@ -160,3 +161,33 @@ class SpamChecker(object):
return True
return False
+
+ def check_registration_for_spam(
+ self,
+ email_threepid: Optional[dict],
+ username: Optional[str],
+ request_info: Collection[Tuple[str, str]],
+ ) -> RegistrationBehaviour:
+ """Checks if we should allow the given registration request.
+
+ Args:
+ email_threepid: The email threepid used for registering, if any
+ username: The request user name, if any
+ request_info: List of tuples of user agent and IP that
+ were used during the registration process.
+
+ Returns:
+ Enum for how the request should be handled
+ """
+
+ for spam_checker in self.spam_checkers:
+ # For backwards compatibility, only run if the method exists on the
+ # spam checker
+ checker = getattr(spam_checker, "check_registration_for_spam", None)
+ if checker:
+ behaviour = checker(email_threepid, username, request_info)
+ assert isinstance(behaviour, RegistrationBehaviour)
+ if behaviour != RegistrationBehaviour.ALLOW:
+ return behaviour
+
+ return RegistrationBehaviour.ALLOW
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 68d6870e40..654f58ddae 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -364,6 +364,14 @@ class AuthHandler(BaseHandler):
# authentication flow.
await self.store.set_ui_auth_clientdict(sid, clientdict)
+ user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[
+ 0
+ ].decode("ascii", "surrogateescape")
+
+ await self.store.add_user_agent_ip_to_ui_auth_session(
+ session.session_id, user_agent, clientip
+ )
+
if not authdict:
raise InteractiveAuthIncompleteError(
session.session_id, self._auth_dict_for_flows(flows, session.session_id)
diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py
index 786e608fa2..a4cc4b9a5a 100644
--- a/synapse/handlers/cas_handler.py
+++ b/synapse/handlers/cas_handler.py
@@ -35,6 +35,7 @@ class CasHandler:
"""
def __init__(self, hs):
+ self.hs = hs
self._hostname = hs.hostname
self._auth_handler = hs.get_auth_handler()
self._registration_handler = hs.get_registration_handler()
@@ -210,8 +211,16 @@ class CasHandler:
else:
if not registered_user_id:
+ # Pull out the user-agent and IP from the request.
+ user_agent = request.requestHeaders.getRawHeaders(
+ b"User-Agent", default=[b""]
+ )[0].decode("ascii", "surrogateescape")
+ ip_address = self.hs.get_ip_from_request(request)
+
registered_user_id = await self._registration_handler.register_user(
- localpart=localpart, default_display_name=user_display_name
+ localpart=localpart,
+ default_display_name=user_display_name,
+ user_agent_ips=(user_agent, ip_address),
)
await self._auth_handler.complete_sso_login(
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index dd3703cbd2..c5bd2fea68 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -93,6 +93,7 @@ class OidcHandler:
"""
def __init__(self, hs: "HomeServer"):
+ self.hs = hs
self._callback_url = hs.config.oidc_callback_url # type: str
self._scopes = hs.config.oidc_scopes # type: List[str]
self._client_auth = ClientAuth(
@@ -689,9 +690,17 @@ class OidcHandler:
self._render_error(request, "invalid_token", str(e))
return
+ # Pull out the user-agent and IP from the request.
+ user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[
+ 0
+ ].decode("ascii", "surrogateescape")
+ ip_address = self.hs.get_ip_from_request(request)
+
# Call the mapper to register/login the user
try:
- user_id = await self._map_userinfo_to_user(userinfo, token)
+ user_id = await self._map_userinfo_to_user(
+ userinfo, token, user_agent, ip_address
+ )
except MappingException as e:
logger.exception("Could not map user")
self._render_error(request, "mapping_error", str(e))
@@ -828,7 +837,9 @@ class OidcHandler:
now = self._clock.time_msec()
return now < expiry
- async def _map_userinfo_to_user(self, userinfo: UserInfo, token: Token) -> str:
+ async def _map_userinfo_to_user(
+ self, userinfo: UserInfo, token: Token, user_agent: str, ip_address: str
+ ) -> str:
"""Maps a UserInfo object to a mxid.
UserInfo should have a claim that uniquely identifies users. This claim
@@ -843,6 +854,8 @@ class OidcHandler:
Args:
userinfo: an object representing the user
token: a dict with the tokens obtained from the provider
+ user_agent: The user agent of the client making the request.
+ ip_address: The IP address of the client making the request.
Raises:
MappingException: if there was an error while mapping some properties
@@ -899,7 +912,9 @@ class OidcHandler:
# It's the first time this user is logging in and the mapped mxid was
# not taken, register the user
registered_user_id = await self._registration_handler.register_user(
- localpart=localpart, default_display_name=attributes["display_name"],
+ localpart=localpart,
+ default_display_name=attributes["display_name"],
+ user_agent_ips=(user_agent, ip_address),
)
await self._datastore.record_user_external_id(
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index ccd96e4626..cde2dbca92 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -26,6 +26,7 @@ from synapse.replication.http.register import (
ReplicationPostRegisterActionsServlet,
ReplicationRegisterServlet,
)
+from synapse.spam_checker_api import RegistrationBehaviour
from synapse.storage.state import StateFilter
from synapse.types import RoomAlias, UserID, create_requester
@@ -52,6 +53,8 @@ class RegistrationHandler(BaseHandler):
self.macaroon_gen = hs.get_macaroon_generator()
self._server_notices_mxid = hs.config.server_notices_mxid
+ self.spam_checker = hs.get_spam_checker()
+
if hs.config.worker_app:
self._register_client = ReplicationRegisterServlet.make_client(hs)
self._register_device_client = RegisterDeviceReplicationServlet.make_client(
@@ -144,7 +147,7 @@ class RegistrationHandler(BaseHandler):
address=None,
bind_emails=[],
by_admin=False,
- shadow_banned=False,
+ user_agent_ips=None,
):
"""Registers a new client on the server.
@@ -162,7 +165,8 @@ class RegistrationHandler(BaseHandler):
bind_emails (List[str]): list of emails to bind to this account.
by_admin (bool): True if this registration is being made via the
admin api, otherwise False.
- shadow_banned (bool): Shadow-ban the created user.
+ user_agent_ips (List[(str, str)]): Tuples of IP addresses and user-agents used
+ during the registration process.
Returns:
str: user_id
Raises:
@@ -170,6 +174,24 @@ class RegistrationHandler(BaseHandler):
"""
self.check_registration_ratelimit(address)
+ result = self.spam_checker.check_registration_for_spam(
+ threepid, localpart, user_agent_ips or [],
+ )
+
+ if result == RegistrationBehaviour.DENY:
+ logger.info(
+ "Blocked registration of %r", localpart,
+ )
+ # We return a 429 to make it not obvious that they've been
+ # denied.
+ raise SynapseError(429, "Rate limited")
+
+ shadow_banned = result == RegistrationBehaviour.SHADOW_BAN
+ if shadow_banned:
+ logger.info(
+ "Shadow banning registration of %r", localpart,
+ )
+
# do not check_auth_blocking if the call is coming through the Admin API
if not by_admin:
await self.auth.check_auth_blocking(threepid=threepid)
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index c1fcb98454..b426199aa6 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -54,6 +54,7 @@ class Saml2SessionData:
class SamlHandler:
def __init__(self, hs: "synapse.server.HomeServer"):
+ self.hs = hs
self._saml_client = Saml2Client(hs.config.saml2_sp_config)
self._auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler()
@@ -133,8 +134,14 @@ class SamlHandler:
# the dict.
self.expire_sessions()
+ # Pull out the user-agent and IP from the request.
+ user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[
+ 0
+ ].decode("ascii", "surrogateescape")
+ ip_address = self.hs.get_ip_from_request(request)
+
user_id, current_session = await self._map_saml_response_to_user(
- resp_bytes, relay_state
+ resp_bytes, relay_state, user_agent, ip_address
)
# Complete the interactive auth session or the login.
@@ -147,7 +154,11 @@ class SamlHandler:
await self._auth_handler.complete_sso_login(user_id, request, relay_state)
async def _map_saml_response_to_user(
- self, resp_bytes: str, client_redirect_url: str
+ self,
+ resp_bytes: str,
+ client_redirect_url: str,
+ user_agent: str,
+ ip_address: str,
) -> Tuple[str, Optional[Saml2SessionData]]:
"""
Given a sample response, retrieve the cached session and user for it.
@@ -155,6 +166,8 @@ class SamlHandler:
Args:
resp_bytes: The SAML response.
client_redirect_url: The redirect URL passed in by the client.
+ user_agent: The user agent of the client making the request.
+ ip_address: The IP address of the client making the request.
Returns:
Tuple of the user ID and SAML session associated with this response.
@@ -291,6 +304,7 @@ class SamlHandler:
localpart=localpart,
default_display_name=displayname,
bind_emails=emails,
+ user_agent_ips=(user_agent, ip_address),
)
await self._datastore.record_user_external_id(
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index 7290fd0756..be0e680ac5 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -591,12 +591,17 @@ class RegisterRestServlet(RestServlet):
Codes.THREEPID_IN_USE,
)
+ entries = await self.store.get_user_agents_ips_to_ui_auth_session(
+ session_id
+ )
+
registered_user_id = await self.registration_handler.register_user(
localpart=desired_username,
password_hash=password_hash,
guest_access_token=guest_access_token,
threepid=threepid,
address=client_addr,
+ user_agent_ips=entries,
)
# Necessary due to auth checks prior to the threepid being
# written to the db
diff --git a/synapse/spam_checker_api/__init__.py b/synapse/spam_checker_api/__init__.py
index 7f63f1bfa0..9be92e2565 100644
--- a/synapse/spam_checker_api/__init__.py
+++ b/synapse/spam_checker_api/__init__.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+from enum import Enum
from twisted.internet import defer
@@ -25,6 +26,16 @@ if MYPY:
logger = logging.getLogger(__name__)
+class RegistrationBehaviour(Enum):
+ """
+ Enum to define whether a registration request should allowed, denied, or shadow-banned.
+ """
+
+ ALLOW = "allow"
+ SHADOW_BAN = "shadow_ban"
+ DENY = "deny"
+
+
class SpamCheckerApi(object):
"""A proxy object that gets passed to spam checkers so they can get
access to rooms and other relevant information.
diff --git a/synapse/storage/databases/main/schema/delta/58/07persist_ui_auth_ips.sql b/synapse/storage/databases/main/schema/delta/58/07persist_ui_auth_ips.sql
new file mode 100644
index 0000000000..4cc96a5341
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/07persist_ui_auth_ips.sql
@@ -0,0 +1,25 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- A table of the IP address and user-agent used to complete each step of a
+-- user-interactive authentication session.
+CREATE TABLE IF NOT EXISTS ui_auth_sessions_ips(
+ session_id TEXT NOT NULL,
+ ip TEXT NOT NULL,
+ user_agent TEXT NOT NULL,
+ UNIQUE (session_id, ip, user_agent),
+ FOREIGN KEY (session_id)
+ REFERENCES ui_auth_sessions (session_id)
+);
diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py
index 6281a41a3d..9eef8e57c5 100644
--- a/synapse/storage/databases/main/ui_auth.py
+++ b/synapse/storage/databases/main/ui_auth.py
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
import attr
@@ -260,6 +260,34 @@ class UIAuthWorkerStore(SQLBaseStore):
return serverdict.get(key, default)
+ async def add_user_agent_ip_to_ui_auth_session(
+ self, session_id: str, user_agent: str, ip: str,
+ ):
+ """Add the given user agent / IP to the tracking table
+ """
+ await self.db_pool.simple_upsert(
+ table="ui_auth_sessions_ips",
+ keyvalues={"session_id": session_id, "user_agent": user_agent, "ip": ip},
+ values={},
+ desc="add_user_agent_ip_to_ui_auth_session",
+ )
+
+ async def get_user_agents_ips_to_ui_auth_session(
+ self, session_id: str,
+ ) -> List[Tuple[str, str]]:
+ """Get the given user agents / IPs used during the ui auth process
+
+ Returns:
+ List of user_agent/ip pairs
+ """
+ rows = await self.db_pool.simple_select_list(
+ table="ui_auth_sessions_ips",
+ keyvalues={"session_id": session_id},
+ retcols=("user_agent", "ip"),
+ desc="get_user_agents_ips_to_ui_auth_session",
+ )
+ return [(row["user_agent"], row["ip"]) for row in rows]
+
class UIAuthStore(UIAuthWorkerStore):
def delete_old_ui_auth_sessions(self, expiration_time: int):
@@ -285,6 +313,15 @@ class UIAuthStore(UIAuthWorkerStore):
txn.execute(sql, [expiration_time])
session_ids = [r[0] for r in txn.fetchall()]
+ # Delete the corresponding IP/user agents.
+ self.db_pool.simple_delete_many_txn(
+ txn,
+ table="ui_auth_sessions_ips",
+ column="session_id",
+ iterable=session_ids,
+ keyvalues={},
+ )
+
# Delete the corresponding completed credentials.
self.db_pool.simple_delete_many_txn(
txn,
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index 1bb25ab684..f92f3b8c15 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -374,12 +374,16 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
self.handler._auth_handler.complete_sso_login = simple_async_mock()
- request = Mock(spec=["args", "getCookie", "addCookie"])
+ request = Mock(
+ spec=["args", "getCookie", "addCookie", "requestHeaders", "getClientIP"]
+ )
code = "code"
state = "state"
nonce = "nonce"
client_redirect_url = "http://client/redirect"
+ user_agent = "Browser"
+ ip_address = "10.0.0.1"
session = self.handler._generate_oidc_session_token(
state=state,
nonce=nonce,
@@ -392,6 +396,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
request.args[b"code"] = [code.encode("utf-8")]
request.args[b"state"] = [state.encode("utf-8")]
+ request.requestHeaders = Mock(spec=["getRawHeaders"])
+ request.requestHeaders.getRawHeaders.return_value = [user_agent.encode("ascii")]
+ request.getClientIP.return_value = ip_address
+
yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
self.handler._auth_handler.complete_sso_login.assert_called_once_with(
@@ -399,7 +407,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
self.handler._exchange_code.assert_called_once_with(code)
self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce)
- self.handler._map_userinfo_to_user.assert_called_once_with(userinfo, token)
+ self.handler._map_userinfo_to_user.assert_called_once_with(
+ userinfo, token, user_agent, ip_address
+ )
self.handler._fetch_userinfo.assert_not_called()
self.handler._render_error.assert_not_called()
@@ -431,7 +441,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
self.handler._exchange_code.assert_called_once_with(code)
self.handler._parse_id_token.assert_not_called()
- self.handler._map_userinfo_to_user.assert_called_once_with(userinfo, token)
+ self.handler._map_userinfo_to_user.assert_called_once_with(
+ userinfo, token, user_agent, ip_address
+ )
self.handler._fetch_userinfo.assert_called_once_with(token)
self.handler._render_error.assert_not_called()
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index e364b1bd62..5c92d0e8c9 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -17,18 +17,21 @@ from mock import Mock
from twisted.internet import defer
+from synapse.api.auth import Auth
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, ResourceLimitError, SynapseError
from synapse.handlers.register import RegistrationHandler
+from synapse.spam_checker_api import RegistrationBehaviour
from synapse.types import RoomAlias, UserID, create_requester
from tests.test_utils import make_awaitable
from tests.unittest import override_config
+from tests.utils import mock_getRawHeaders
from .. import unittest
-class RegistrationHandlers(object):
+class RegistrationHandlers:
def __init__(self, hs):
self.registration_handler = RegistrationHandler(hs)
@@ -475,6 +478,53 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.handler.register_user(localpart=invalid_user_id), SynapseError
)
+ def test_spam_checker_deny(self):
+ """A spam checker can deny registration, which results in an error."""
+
+ class DenyAll:
+ def check_registration_for_spam(
+ self, email_threepid, username, request_info
+ ):
+ return RegistrationBehaviour.DENY
+
+ # Configure a spam checker that denies all users.
+ spam_checker = self.hs.get_spam_checker()
+ spam_checker.spam_checkers = [DenyAll()]
+
+ self.get_failure(self.handler.register_user(localpart="user"), SynapseError)
+
+ def test_spam_checker_shadow_ban(self):
+ """A spam checker can choose to shadow-ban a user, which allows registration to succeed."""
+
+ class BanAll:
+ def check_registration_for_spam(
+ self, email_threepid, username, request_info
+ ):
+ return RegistrationBehaviour.SHADOW_BAN
+
+ # Configure a spam checker that denies all users.
+ spam_checker = self.hs.get_spam_checker()
+ spam_checker.spam_checkers = [BanAll()]
+
+ user_id = self.get_success(self.handler.register_user(localpart="user"))
+
+ # Get an access token.
+ token = self.macaroon_generator.generate_access_token(user_id)
+ self.get_success(
+ self.store.add_access_token_to_user(
+ user_id=user_id, token=token, device_id=None, valid_until_ms=None
+ )
+ )
+
+ # Ensure the user was marked as shadow-banned.
+ request = Mock(args={})
+ request.args[b"access_token"] = [token.encode("ascii")]
+ request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+ auth = Auth(self.hs)
+ requester = self.get_success(auth.get_user_by_req(request))
+
+ self.assertTrue(requester.shadow_banned)
+
async def get_or_create_user(
self, requester, localpart, displayname, password_hash=None
):
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 31ed89a5cd..87be94111f 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -238,7 +238,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
def test_spam_checker(self):
"""
- A user which fails to the spam checks will not appear in search results.
+ A user which fails the spam checks will not appear in search results.
"""
u1 = self.register_user("user1", "pass")
u1_token = self.login(u1, "pass")
@@ -269,7 +269,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
# Configure a spam checker that does not filter any users.
spam_checker = self.hs.get_spam_checker()
- class AllowAll(object):
+ class AllowAll:
def check_username_for_spam(self, user_profile):
# Allow all users.
return False
@@ -282,7 +282,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
self.assertEqual(len(s["results"]), 1)
# Configure a spam checker that filters all users.
- class BlockAll(object):
+ class BlockAll:
def check_username_for_spam(self, user_profile):
# All users are spammy.
return True
|