diff --git a/Cargo.lock b/Cargo.lock
index 51ff26ec1b..52f911277e 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -340,9 +340,9 @@ dependencies = [
[[package]]
name = "serde_json"
-version = "1.0.97"
+version = "1.0.99"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "bdf3bf93142acad5821c99197022e170842cdbc1c30482b98750c688c640842a"
+checksum = "46266871c240a00b8f503b877622fe33430b3c7d963bdc0f2adc511e54a1eae3"
dependencies = [
"itoa",
"ryu",
diff --git a/changelog.d/15817.bugfix b/changelog.d/15817.bugfix
new file mode 100644
index 0000000000..2b025730ad
--- /dev/null
+++ b/changelog.d/15817.bugfix
@@ -0,0 +1 @@
+Fix sqlite `user_filters` upgrade introduced in v1.86.0.
diff --git a/changelog.d/15838.feature b/changelog.d/15838.feature
new file mode 100644
index 0000000000..04c77bd723
--- /dev/null
+++ b/changelog.d/15838.feature
@@ -0,0 +1 @@
+Add spam checker module API for logins.
diff --git a/docs/modules/spam_checker_callbacks.md b/docs/modules/spam_checker_callbacks.md
index 1a0c6ec954..ffdfe6082e 100644
--- a/docs/modules/spam_checker_callbacks.md
+++ b/docs/modules/spam_checker_callbacks.md
@@ -348,6 +348,42 @@ callback returns `False`, Synapse falls through to the next one. The value of th
callback that does not return `False` will be used. If this happens, Synapse will not call
any of the subsequent implementations of this callback.
+
+### `check_login_for_spam`
+
+_First introduced in Synapse v1.87.0_
+
+```python
+async def check_login_for_spam(
+ user_id: str,
+ device_id: Optional[str],
+ initial_display_name: Optional[str],
+ request_info: Collection[Tuple[Optional[str], str]],
+ auth_provider_id: Optional[str] = None,
+) -> Union["synapse.module_api.NOT_SPAM", "synapse.module_api.errors.Codes"]
+```
+
+Called when a user logs in.
+
+The arguments passed to this callback are:
+
+* `user_id`: The user ID the user is logging in with
+* `device_id`: The device ID the user is re-logging into.
+* `initial_display_name`: The device display name, if any.
+* `request_info`: A collection of tuples, which first item is a user agent, and which
+ second item is an IP address. These user agents and IP addresses are the ones that were
+ used during the login process.
+* `auth_provider_id`: The identifier of the SSO authentication provider, if any.
+
+If multiple modules implement this callback, they will be considered in order. If a
+callback returns `synapse.module_api.NOT_SPAM`, Synapse falls through to the next one.
+The value of the first callback that does not return `synapse.module_api.NOT_SPAM` will
+be used. If this happens, Synapse will not call any of the subsequent implementations of
+this callback.
+
+*Note:* This will not be called when a user registers.
+
+
## Example
The example below is a module that implements the spam checker callback
diff --git a/poetry.lock b/poetry.lock
index d7b7a5aa2c..ee19c246f3 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -2245,28 +2245,28 @@ jupyter = ["ipywidgets (>=7.5.1,<9)"]
[[package]]
name = "ruff"
-version = "0.0.272"
+version = "0.0.275"
description = "An extremely fast Python linter, written in Rust."
optional = false
python-versions = ">=3.7"
files = [
- {file = "ruff-0.0.272-py3-none-macosx_10_7_x86_64.whl", hash = "sha256:ae9b57546e118660175d45d264b87e9b4c19405c75b587b6e4d21e6a17bf4fdf"},
- {file = "ruff-0.0.272-py3-none-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:1609b864a8d7ee75a8c07578bdea0a7db75a144404e75ef3162e0042bfdc100d"},
- {file = "ruff-0.0.272-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ee76b4f05fcfff37bd6ac209d1370520d509ea70b5a637bdf0a04d0c99e13dff"},
- {file = "ruff-0.0.272-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:48eccf225615e106341a641f826b15224b8a4240b84269ead62f0afd6d7e2d95"},
- {file = "ruff-0.0.272-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:677284430ac539bb23421a2b431b4ebc588097ef3ef918d0e0a8d8ed31fea216"},
- {file = "ruff-0.0.272-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:9c4bfb75456a8e1efe14c52fcefb89cfb8f2a0d31ed8d804b82c6cf2dc29c42c"},
- {file = "ruff-0.0.272-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:86bc788245361a8148ff98667da938a01e1606b28a45e50ac977b09d3ad2c538"},
- {file = "ruff-0.0.272-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:27b2ea68d2aa69fff1b20b67636b1e3e22a6a39e476c880da1282c3e4bf6ee5a"},
- {file = "ruff-0.0.272-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bd2bbe337a3f84958f796c77820d55ac2db1e6753f39d1d1baed44e07f13f96d"},
- {file = "ruff-0.0.272-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:d5a208f8ef0e51d4746930589f54f9f92f84bb69a7d15b1de34ce80a7681bc00"},
- {file = "ruff-0.0.272-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:905ff8f3d6206ad56fcd70674453527b9011c8b0dc73ead27618426feff6908e"},
- {file = "ruff-0.0.272-py3-none-musllinux_1_2_i686.whl", hash = "sha256:19643d448f76b1eb8a764719072e9c885968971bfba872e14e7257e08bc2f2b7"},
- {file = "ruff-0.0.272-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:691d72a00a99707a4e0b2846690961157aef7b17b6b884f6b4420a9f25cd39b5"},
- {file = "ruff-0.0.272-py3-none-win32.whl", hash = "sha256:dc406e5d756d932da95f3af082814d2467943631a587339ee65e5a4f4fbe83eb"},
- {file = "ruff-0.0.272-py3-none-win_amd64.whl", hash = "sha256:a37ec80e238ead2969b746d7d1b6b0d31aa799498e9ba4281ab505b93e1f4b28"},
- {file = "ruff-0.0.272-py3-none-win_arm64.whl", hash = "sha256:06b8ee4eb8711ab119db51028dd9f5384b44728c23586424fd6e241a5b9c4a3b"},
- {file = "ruff-0.0.272.tar.gz", hash = "sha256:273a01dc8c3c4fd4c2af7ea7a67c8d39bb09bce466e640dd170034da75d14cab"},
+ {file = "ruff-0.0.275-py3-none-macosx_10_7_x86_64.whl", hash = "sha256:5e6554a072e7ce81eb6f0bec1cebd3dcb0e358652c0f4900d7d630d61691e914"},
+ {file = "ruff-0.0.275-py3-none-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:1cc599022fe5ffb143a965b8d659eb64161ab8ab4433d208777eab018a1aab67"},
+ {file = "ruff-0.0.275-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5206fc1cd8c1c1deadd2e6360c0dbcd690f1c845da588ca9d32e4a764a402c60"},
+ {file = "ruff-0.0.275-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:0c4e6468da26f77b90cae35319d310999f471a8c352998e9b39937a23750149e"},
+ {file = "ruff-0.0.275-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0dbdea02942131dbc15dd45f431d152224f15e1dd1859fcd0c0487b658f60f1a"},
+ {file = "ruff-0.0.275-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:22efd9f41af27ef8fb9779462c46c35c89134d33e326c889971e10b2eaf50c63"},
+ {file = "ruff-0.0.275-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2c09662112cfa22d7467a19252a546291fd0eae4f423e52b75a7a2000a1894db"},
+ {file = "ruff-0.0.275-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:80043726662144876a381efaab88841c88e8df8baa69559f96b22d4fa216bef1"},
+ {file = "ruff-0.0.275-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5859ee543b01b7eb67835dfd505faa8bb7cc1550f0295c92c1401b45b42be399"},
+ {file = "ruff-0.0.275-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:c8ace4d40a57b5ea3c16555f25a6b16bc5d8b2779ae1912ce2633543d4e9b1da"},
+ {file = "ruff-0.0.275-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:8347fc16aa185aae275906c4ac5b770e00c896b6a0acd5ba521f158801911998"},
+ {file = "ruff-0.0.275-py3-none-musllinux_1_2_i686.whl", hash = "sha256:ec43658c64bfda44fd84bbea9da8c7a3b34f65448192d1c4dd63e9f4e7abfdd4"},
+ {file = "ruff-0.0.275-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:508b13f7ca37274cceaba4fb3ea5da6ca192356323d92acf39462337c33ad14e"},
+ {file = "ruff-0.0.275-py3-none-win32.whl", hash = "sha256:6afb1c4422f24f361e877937e2a44b3f8176774a476f5e33845ebfe887dd5ec2"},
+ {file = "ruff-0.0.275-py3-none-win_amd64.whl", hash = "sha256:d9b264d78621bf7b698b6755d4913ab52c19bd28bee1a16001f954d64c1a1220"},
+ {file = "ruff-0.0.275-py3-none-win_arm64.whl", hash = "sha256:a19ce3bea71023eee5f0f089dde4a4272d088d5ac0b675867e074983238ccc65"},
+ {file = "ruff-0.0.275.tar.gz", hash = "sha256:a63a0b645da699ae5c758fce19188e901b3033ec54d862d93fcd042addf7f38d"},
]
[[package]]
@@ -2711,21 +2711,21 @@ files = [
[[package]]
name = "towncrier"
-version = "22.12.0"
+version = "23.6.0"
description = "Building newsfiles for your project."
optional = false
python-versions = ">=3.7"
files = [
- {file = "towncrier-22.12.0-py3-none-any.whl", hash = "sha256:9767a899a4d6856950f3598acd9e8f08da2663c49fdcda5ea0f9e6ba2afc8eea"},
- {file = "towncrier-22.12.0.tar.gz", hash = "sha256:9c49d7e75f646a9aea02ae904c0bc1639c8fd14a01292d2b123b8d307564034d"},
+ {file = "towncrier-23.6.0-py3-none-any.whl", hash = "sha256:da552f29192b3c2b04d630133f194c98e9f14f0558669d427708e203fea4d0a5"},
+ {file = "towncrier-23.6.0.tar.gz", hash = "sha256:fc29bd5ab4727c8dacfbe636f7fb5dc53b99805b62da1c96b214836159ff70c1"},
]
[package.dependencies]
click = "*"
click-default-group = "*"
+importlib-resources = {version = ">=5", markers = "python_version < \"3.10\""}
incremental = "*"
jinja2 = "*"
-setuptools = "*"
tomli = {version = "*", markers = "python_version < \"3.11\""}
[package.extras]
@@ -2931,13 +2931,13 @@ files = [
[[package]]
name = "types-opentracing"
-version = "2.4.10.4"
+version = "2.4.10.5"
description = "Typing stubs for opentracing"
optional = false
python-versions = "*"
files = [
- {file = "types-opentracing-2.4.10.4.tar.gz", hash = "sha256:347040c9da4ada7d3c795659912c95d98c5651e242e8eaa0344815fee5bb97e2"},
- {file = "types_opentracing-2.4.10.4-py3-none-any.whl", hash = "sha256:73c9b958eea3df6c4906ebf3865608a562dd9981c1bbc75a373a583c613bed56"},
+ {file = "types-opentracing-2.4.10.5.tar.gz", hash = "sha256:852d13ab1324832835d50c00cfd58b9267f0e79ec3189e5664c2a90c26880fd4"},
+ {file = "types_opentracing-2.4.10.5-py3-none-any.whl", hash = "sha256:8f12ab4dce3e298a8e6655da9a6d52171e7a275357eae4cec22a1663d94023a7"},
]
[[package]]
@@ -3003,13 +3003,13 @@ types-urllib3 = "*"
[[package]]
name = "types-setuptools"
-version = "67.8.0.0"
+version = "68.0.0.0"
description = "Typing stubs for setuptools"
optional = false
python-versions = "*"
files = [
- {file = "types-setuptools-67.8.0.0.tar.gz", hash = "sha256:95c9ed61871d6c0e258433373a4e1753c0a7c3627a46f4d4058c7b5a08ab844f"},
- {file = "types_setuptools-67.8.0.0-py3-none-any.whl", hash = "sha256:6df73340d96b238a4188b7b7668814b37e8018168aef1eef94a3b1872e3f60ff"},
+ {file = "types-setuptools-68.0.0.0.tar.gz", hash = "sha256:fc958b4123b155ffc069a66d3af5fe6c1f9d0600c35c0c8444b2ab4147112641"},
+ {file = "types_setuptools-68.0.0.0-py3-none-any.whl", hash = "sha256:cc00e09ba8f535362cbe1ea8b8407d15d14b59c57f4190cceaf61a9e57616446"},
]
[[package]]
@@ -3294,4 +3294,4 @@ user-search = ["pyicu"]
[metadata]
lock-version = "2.0"
python-versions = "^3.7.1"
-content-hash = "090924370b17fd265407b5a3f9cbc00997308f575b455399b39a48e3ca1a5a8e"
+content-hash = "7f31754a1009d7b6c9a1bd7221a0b243ffd510f362c28f0da417aaac16757a87"
diff --git a/pyproject.toml b/pyproject.toml
index 90812de019..a44ecd65e4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -311,7 +311,7 @@ all = [
# We pin black so that our tests don't start failing on new releases.
isort = ">=5.10.1"
black = ">=22.3.0"
-ruff = "0.0.272"
+ruff = "0.0.275"
# Typechecking
lxml-stubs = ">=0.4.0"
diff --git a/synapse/http/site.py b/synapse/http/site.py
index c530966ef3..5b5a7c1e59 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -521,6 +521,11 @@ class SynapseRequest(Request):
else:
return self.getClientAddress().host
+ def request_info(self) -> "RequestInfo":
+ h = self.getHeader(b"User-Agent")
+ user_agent = h.decode("ascii", "replace") if h else None
+ return RequestInfo(user_agent=user_agent, ip=self.get_client_ip_if_available())
+
class XForwardedForRequest(SynapseRequest):
"""Request object which honours proxy headers
@@ -661,3 +666,9 @@ class SynapseSite(Site):
def log(self, request: SynapseRequest) -> None:
pass
+
+
+@attr.s(auto_attribs=True, frozen=True, slots=True)
+class RequestInfo:
+ user_agent: Optional[str]
+ ip: str
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 84b2aef620..95f7800111 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -80,6 +80,7 @@ from synapse.module_api.callbacks.account_validity_callbacks import (
)
from synapse.module_api.callbacks.spamchecker_callbacks import (
CHECK_EVENT_FOR_SPAM_CALLBACK,
+ CHECK_LOGIN_FOR_SPAM_CALLBACK,
CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK,
CHECK_REGISTRATION_FOR_SPAM_CALLBACK,
CHECK_USERNAME_FOR_SPAM_CALLBACK,
@@ -302,6 +303,7 @@ class ModuleApi:
CHECK_REGISTRATION_FOR_SPAM_CALLBACK
] = None,
check_media_file_for_spam: Optional[CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK] = None,
+ check_login_for_spam: Optional[CHECK_LOGIN_FOR_SPAM_CALLBACK] = None,
) -> None:
"""Registers callbacks for spam checking capabilities.
@@ -319,6 +321,7 @@ class ModuleApi:
check_username_for_spam=check_username_for_spam,
check_registration_for_spam=check_registration_for_spam,
check_media_file_for_spam=check_media_file_for_spam,
+ check_login_for_spam=check_login_for_spam,
)
def register_account_validity_callbacks(
diff --git a/synapse/module_api/callbacks/spamchecker_callbacks.py b/synapse/module_api/callbacks/spamchecker_callbacks.py
index 4456d1b81e..7cee442145 100644
--- a/synapse/module_api/callbacks/spamchecker_callbacks.py
+++ b/synapse/module_api/callbacks/spamchecker_callbacks.py
@@ -196,6 +196,26 @@ CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK = Callable[
]
],
]
+CHECK_LOGIN_FOR_SPAM_CALLBACK = Callable[
+ [
+ str,
+ Optional[str],
+ Optional[str],
+ Collection[Tuple[Optional[str], str]],
+ Optional[str],
+ ],
+ Awaitable[
+ Union[
+ Literal["NOT_SPAM"],
+ Codes,
+ # Highly experimental, not officially part of the spamchecker API, may
+ # disappear without warning depending on the results of ongoing
+ # experiments.
+ # Use this to return additional information as part of an error.
+ Tuple[Codes, JsonDict],
+ ]
+ ],
+]
def load_legacy_spam_checkers(hs: "synapse.server.HomeServer") -> None:
@@ -315,6 +335,7 @@ class SpamCheckerModuleApiCallbacks:
self._check_media_file_for_spam_callbacks: List[
CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK
] = []
+ self._check_login_for_spam_callbacks: List[CHECK_LOGIN_FOR_SPAM_CALLBACK] = []
def register_callbacks(
self,
@@ -335,6 +356,7 @@ class SpamCheckerModuleApiCallbacks:
CHECK_REGISTRATION_FOR_SPAM_CALLBACK
] = None,
check_media_file_for_spam: Optional[CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK] = None,
+ check_login_for_spam: Optional[CHECK_LOGIN_FOR_SPAM_CALLBACK] = None,
) -> None:
"""Register callbacks from module for each hook."""
if check_event_for_spam is not None:
@@ -378,6 +400,9 @@ class SpamCheckerModuleApiCallbacks:
if check_media_file_for_spam is not None:
self._check_media_file_for_spam_callbacks.append(check_media_file_for_spam)
+ if check_login_for_spam is not None:
+ self._check_login_for_spam_callbacks.append(check_login_for_spam)
+
@trace
async def check_event_for_spam(
self, event: "synapse.events.EventBase"
@@ -819,3 +844,58 @@ class SpamCheckerModuleApiCallbacks:
return synapse.api.errors.Codes.FORBIDDEN, {}
return self.NOT_SPAM
+
+ async def check_login_for_spam(
+ self,
+ user_id: str,
+ device_id: Optional[str],
+ initial_display_name: Optional[str],
+ request_info: Collection[Tuple[Optional[str], str]],
+ auth_provider_id: Optional[str] = None,
+ ) -> Union[Tuple[Codes, dict], Literal["NOT_SPAM"]]:
+ """Checks if we should allow the given registration request.
+
+ Args:
+ user_id: The request user ID
+ request_info: List of tuples of user agent and IP that
+ were used during the registration process.
+ auth_provider_id: The SSO IdP the user used, e.g "oidc", "saml",
+ "cas". If any. Note this does not include users registered
+ via a password provider.
+
+ Returns:
+ Enum for how the request should be handled
+ """
+
+ for callback in self._check_login_for_spam_callbacks:
+ with Measure(
+ self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
+ ):
+ res = await delay_cancellation(
+ callback(
+ user_id,
+ device_id,
+ initial_display_name,
+ request_info,
+ auth_provider_id,
+ )
+ )
+ # Normalize return values to `Codes` or `"NOT_SPAM"`.
+ if res is self.NOT_SPAM:
+ continue
+ elif isinstance(res, synapse.api.errors.Codes):
+ return res, {}
+ elif (
+ isinstance(res, tuple)
+ and len(res) == 2
+ and isinstance(res[0], synapse.api.errors.Codes)
+ and isinstance(res[1], dict)
+ ):
+ return res
+ else:
+ logger.warning(
+ "Module returned invalid value, rejecting login as spam"
+ )
+ return synapse.api.errors.Codes.FORBIDDEN, {}
+
+ return self.NOT_SPAM
diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py
index 6493b00bb8..d724c68920 100644
--- a/synapse/rest/client/login.py
+++ b/synapse/rest/client/login.py
@@ -50,7 +50,7 @@ from synapse.http.servlet import (
parse_json_object_from_request,
parse_string,
)
-from synapse.http.site import SynapseRequest
+from synapse.http.site import RequestInfo, SynapseRequest
from synapse.rest.client._base import client_patterns
from synapse.rest.well_known import WellKnownBuilder
from synapse.types import JsonDict, UserID
@@ -114,6 +114,7 @@ class LoginRestServlet(RestServlet):
self.auth_handler = self.hs.get_auth_handler()
self.registration_handler = hs.get_registration_handler()
self._sso_handler = hs.get_sso_handler()
+ self._spam_checker = hs.get_module_api_callbacks().spam_checker
self._well_known_builder = WellKnownBuilder(hs)
self._address_ratelimiter = Ratelimiter(
@@ -197,6 +198,8 @@ class LoginRestServlet(RestServlet):
self._refresh_tokens_enabled and client_requested_refresh_token
)
+ request_info = request.request_info()
+
try:
if login_submission["type"] == LoginRestServlet.APPSERVICE_TYPE:
requester = await self.auth.get_user_by_req(request)
@@ -216,6 +219,7 @@ class LoginRestServlet(RestServlet):
login_submission,
appservice,
should_issue_refresh_token=should_issue_refresh_token,
+ request_info=request_info,
)
elif (
self.jwt_enabled
@@ -227,6 +231,7 @@ class LoginRestServlet(RestServlet):
result = await self._do_jwt_login(
login_submission,
should_issue_refresh_token=should_issue_refresh_token,
+ request_info=request_info,
)
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
await self._address_ratelimiter.ratelimit(
@@ -235,6 +240,7 @@ class LoginRestServlet(RestServlet):
result = await self._do_token_login(
login_submission,
should_issue_refresh_token=should_issue_refresh_token,
+ request_info=request_info,
)
else:
await self._address_ratelimiter.ratelimit(
@@ -243,6 +249,7 @@ class LoginRestServlet(RestServlet):
result = await self._do_other_login(
login_submission,
should_issue_refresh_token=should_issue_refresh_token,
+ request_info=request_info,
)
except KeyError:
raise SynapseError(400, "Missing JSON keys.")
@@ -265,6 +272,8 @@ class LoginRestServlet(RestServlet):
login_submission: JsonDict,
appservice: ApplicationService,
should_issue_refresh_token: bool = False,
+ *,
+ request_info: RequestInfo,
) -> LoginResponse:
identifier = login_submission.get("identifier")
logger.info("Got appservice login request with identifier: %r", identifier)
@@ -300,10 +309,15 @@ class LoginRestServlet(RestServlet):
# The user represented by an appservice's configured sender_localpart
# is not actually created in Synapse.
should_check_deactivated=qualified_user_id != appservice.sender,
+ request_info=request_info,
)
async def _do_other_login(
- self, login_submission: JsonDict, should_issue_refresh_token: bool = False
+ self,
+ login_submission: JsonDict,
+ should_issue_refresh_token: bool = False,
+ *,
+ request_info: RequestInfo,
) -> LoginResponse:
"""Handle non-token/saml/jwt logins
@@ -333,6 +347,7 @@ class LoginRestServlet(RestServlet):
login_submission,
callback,
should_issue_refresh_token=should_issue_refresh_token,
+ request_info=request_info,
)
return result
@@ -347,6 +362,8 @@ class LoginRestServlet(RestServlet):
should_issue_refresh_token: bool = False,
auth_provider_session_id: Optional[str] = None,
should_check_deactivated: bool = True,
+ *,
+ request_info: RequestInfo,
) -> LoginResponse:
"""Called when we've successfully authed the user and now need to
actually login them in (e.g. create devices). This gets called on
@@ -371,6 +388,7 @@ class LoginRestServlet(RestServlet):
This exists purely for appservice's configured sender_localpart
which doesn't have an associated user in the database.
+ request_info: The user agent/IP address of the user.
Returns:
Dictionary of account information after successful login.
@@ -417,6 +435,22 @@ class LoginRestServlet(RestServlet):
)
initial_display_name = login_submission.get("initial_device_display_name")
+ spam_check = await self._spam_checker.check_login_for_spam(
+ user_id,
+ device_id=device_id,
+ initial_display_name=initial_display_name,
+ request_info=[(request_info.user_agent, request_info.ip)],
+ auth_provider_id=auth_provider_id,
+ )
+ if spam_check != self._spam_checker.NOT_SPAM:
+ logger.info("Blocking login due to spam checker")
+ raise SynapseError(
+ 403,
+ msg="Login was blocked by the server",
+ errcode=spam_check[0],
+ additional_fields=spam_check[1],
+ )
+
(
device_id,
access_token,
@@ -451,7 +485,11 @@ class LoginRestServlet(RestServlet):
return result
async def _do_token_login(
- self, login_submission: JsonDict, should_issue_refresh_token: bool = False
+ self,
+ login_submission: JsonDict,
+ should_issue_refresh_token: bool = False,
+ *,
+ request_info: RequestInfo,
) -> LoginResponse:
"""
Handle token login.
@@ -474,10 +512,15 @@ class LoginRestServlet(RestServlet):
auth_provider_id=res.auth_provider_id,
should_issue_refresh_token=should_issue_refresh_token,
auth_provider_session_id=res.auth_provider_session_id,
+ request_info=request_info,
)
async def _do_jwt_login(
- self, login_submission: JsonDict, should_issue_refresh_token: bool = False
+ self,
+ login_submission: JsonDict,
+ should_issue_refresh_token: bool = False,
+ *,
+ request_info: RequestInfo,
) -> LoginResponse:
"""
Handle the custom JWT login.
@@ -496,6 +539,7 @@ class LoginRestServlet(RestServlet):
login_submission,
create_non_existent_users=True,
should_issue_refresh_token=should_issue_refresh_token,
+ request_info=request_info,
)
diff --git a/synapse/storage/schema/main/delta/78/02_validate_and_update_user_filters.py b/synapse/storage/schema/main/delta/78/02_validate_and_update_user_filters.py
index 8ef63335e7..e148ed26f2 100644
--- a/synapse/storage/schema/main/delta/78/02_validate_and_update_user_filters.py
+++ b/synapse/storage/schema/main/delta/78/02_validate_and_update_user_filters.py
@@ -61,9 +61,7 @@ def run_upgrade(
full_user_id text NOT NULL,
user_id text NOT NULL,
filter_id bigint NOT NULL,
- filter_json bytea NOT NULL,
- UNIQUE (full_user_id),
- UNIQUE (user_id)
+ filter_json bytea NOT NULL
)
"""
cur.execute(create_sql)
diff --git a/synapse/storage/schema/main/delta/78/03_remove_unused_indexes_user_filters.py b/synapse/storage/schema/main/delta/78/03_remove_unused_indexes_user_filters.py
new file mode 100644
index 0000000000..f5ba1c3fd4
--- /dev/null
+++ b/synapse/storage/schema/main/delta/78/03_remove_unused_indexes_user_filters.py
@@ -0,0 +1,65 @@
+# Copyright 2023 The Matrix.org Foundation C.I.C
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from synapse.config.homeserver import HomeServerConfig
+from synapse.storage.database import LoggingTransaction
+from synapse.storage.engines import BaseDatabaseEngine, Sqlite3Engine
+
+
+def run_update(
+ cur: LoggingTransaction,
+ database_engine: BaseDatabaseEngine,
+ config: HomeServerConfig,
+) -> None:
+ """
+ Fix to drop unused indexes caused by incorrectly adding UNIQUE constraint to
+ columns `user_id` and `full_user_id` of table `user_filters` in previous migration.
+ """
+
+ if isinstance(database_engine, Sqlite3Engine):
+ cur.execute("DROP TABLE IF EXISTS temp_user_filters")
+ create_sql = """
+ CREATE TABLE temp_user_filters (
+ full_user_id text NOT NULL,
+ user_id text NOT NULL,
+ filter_id bigint NOT NULL,
+ filter_json bytea NOT NULL
+ )
+ """
+ cur.execute(create_sql)
+
+ copy_sql = """
+ INSERT INTO temp_user_filters (
+ user_id,
+ filter_id,
+ filter_json,
+ full_user_id)
+ SELECT user_id, filter_id, filter_json, full_user_id FROM user_filters
+ """
+ cur.execute(copy_sql)
+
+ drop_sql = """
+ DROP TABLE user_filters
+ """
+ cur.execute(drop_sql)
+
+ rename_sql = """
+ ALTER TABLE temp_user_filters RENAME to user_filters
+ """
+ cur.execute(rename_sql)
+
+ index_sql = """
+ CREATE UNIQUE INDEX IF NOT EXISTS user_filters_unique ON
+ user_filters (user_id, filter_id)
+ """
+ cur.execute(index_sql)
diff --git a/synapse/storage/schema/main/delta/78/04_add_full_user_id_index_user_filters.py b/synapse/storage/schema/main/delta/78/04_add_full_user_id_index_user_filters.py
new file mode 100644
index 0000000000..97fecc2bd9
--- /dev/null
+++ b/synapse/storage/schema/main/delta/78/04_add_full_user_id_index_user_filters.py
@@ -0,0 +1,25 @@
+# Copyright 2023 The Matrix.org Foundation C.I.C
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.storage.database import LoggingTransaction
+from synapse.storage.engines import BaseDatabaseEngine, Sqlite3Engine
+
+
+def run_create(cur: LoggingTransaction, database_engine: BaseDatabaseEngine) -> None:
+ if isinstance(database_engine, Sqlite3Engine):
+ idx_sql = """
+ CREATE UNIQUE INDEX IF NOT EXISTS user_filters_full_user_id_unique ON
+ user_filters (full_user_id, filter_id)
+ """
+ cur.execute(idx_sql)
diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py
index f3c3bc69a9..ffbc13bb8d 100644
--- a/tests/rest/client/test_login.py
+++ b/tests/rest/client/test_login.py
@@ -13,11 +13,12 @@
# limitations under the License.
import time
import urllib.parse
-from typing import Any, Dict, List, Optional
+from typing import Any, Collection, Dict, List, Optional, Tuple, Union
from unittest.mock import Mock
from urllib.parse import urlencode
import pymacaroons
+from typing_extensions import Literal
from twisted.test.proto_helpers import MemoryReactor
from twisted.web.resource import Resource
@@ -26,11 +27,12 @@ import synapse.rest.admin
from synapse.api.constants import ApprovalNoticeMedium, LoginType
from synapse.api.errors import Codes
from synapse.appservice import ApplicationService
+from synapse.module_api import ModuleApi
from synapse.rest.client import devices, login, logout, register
from synapse.rest.client.account import WhoamiRestServlet
from synapse.rest.synapse.client import build_synapse_client_resource_tree
from synapse.server import HomeServer
-from synapse.types import create_requester
+from synapse.types import JsonDict, create_requester
from synapse.util import Clock
from tests import unittest
@@ -88,6 +90,56 @@ ADDITIONAL_LOGIN_FLOWS = [
]
+class TestSpamChecker:
+ def __init__(self, config: None, api: ModuleApi):
+ api.register_spam_checker_callbacks(
+ check_login_for_spam=self.check_login_for_spam,
+ )
+
+ @staticmethod
+ def parse_config(config: JsonDict) -> None:
+ return None
+
+ async def check_login_for_spam(
+ self,
+ user_id: str,
+ device_id: Optional[str],
+ initial_display_name: Optional[str],
+ request_info: Collection[Tuple[Optional[str], str]],
+ auth_provider_id: Optional[str] = None,
+ ) -> Union[
+ Literal["NOT_SPAM"],
+ Tuple["synapse.module_api.errors.Codes", JsonDict],
+ ]:
+ return "NOT_SPAM"
+
+
+class DenyAllSpamChecker:
+ def __init__(self, config: None, api: ModuleApi):
+ api.register_spam_checker_callbacks(
+ check_login_for_spam=self.check_login_for_spam,
+ )
+
+ @staticmethod
+ def parse_config(config: JsonDict) -> None:
+ return None
+
+ async def check_login_for_spam(
+ self,
+ user_id: str,
+ device_id: Optional[str],
+ initial_display_name: Optional[str],
+ request_info: Collection[Tuple[Optional[str], str]],
+ auth_provider_id: Optional[str] = None,
+ ) -> Union[
+ Literal["NOT_SPAM"],
+ Tuple["synapse.module_api.errors.Codes", JsonDict],
+ ]:
+ # Return an odd set of values to ensure that they get correctly passed
+ # to the client.
+ return Codes.LIMIT_EXCEEDED, {"extra": "value"}
+
+
class LoginRestServletTestCase(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
@@ -469,6 +521,58 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
],
)
+ @override_config(
+ {
+ "modules": [
+ {
+ "module": TestSpamChecker.__module__
+ + "."
+ + TestSpamChecker.__qualname__
+ }
+ ]
+ }
+ )
+ def test_spam_checker_allow(self) -> None:
+ """Check that that adding a spam checker doesn't break login."""
+ self.register_user("kermit", "monkey")
+
+ body = {"type": "m.login.password", "user": "kermit", "password": "monkey"}
+
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/login",
+ body,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ @override_config(
+ {
+ "modules": [
+ {
+ "module": DenyAllSpamChecker.__module__
+ + "."
+ + DenyAllSpamChecker.__qualname__
+ }
+ ]
+ }
+ )
+ def test_spam_checker_deny(self) -> None:
+ """Check that login"""
+
+ self.register_user("kermit", "monkey")
+
+ body = {"type": "m.login.password", "user": "kermit", "password": "monkey"}
+
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/login",
+ body,
+ )
+ self.assertEqual(channel.code, 403, channel.result)
+ self.assertDictContainsSubset(
+ {"errcode": Codes.LIMIT_EXCEEDED, "extra": "value"}, channel.json_body
+ )
+
@skip_unless(has_saml2 and HAS_OIDC, "Requires SAML2 and OIDC")
class MultiSSOTestCase(unittest.HomeserverTestCase):
|