diff --git a/tests/config/test_oauth_delegation.py b/tests/config/test_oauth_delegation.py
index 5c91031746..b1a9db0210 100644
--- a/tests/config/test_oauth_delegation.py
+++ b/tests/config/test_oauth_delegation.py
@@ -22,15 +22,7 @@ from synapse.types import JsonDict
from tests.server import get_clock, setup_test_homeserver
from tests.unittest import TestCase, skip_unless
-from tests.utils import default_config
-
-try:
- import authlib # noqa: F401
-
- HAS_AUTHLIB = True
-except ImportError:
- HAS_AUTHLIB = False
-
+from tests.utils import HAS_AUTHLIB, default_config
# These are a few constants that are used as config parameters in the tests.
SERVER_NAME = "test"
diff --git a/tests/handlers/test_oauth_delegation.py b/tests/handlers/test_oauth_delegation.py
index a72ecfdc97..a2b8e3d562 100644
--- a/tests/handlers/test_oauth_delegation.py
+++ b/tests/handlers/test_oauth_delegation.py
@@ -13,7 +13,8 @@
# limitations under the License.
from http import HTTPStatus
-from typing import Any, Dict, Union
+from io import BytesIO
+from typing import Any, Dict, Optional, Union
from unittest.mock import ANY, AsyncMock, Mock
from urllib.parse import parse_qs
@@ -25,6 +26,8 @@ from signedjson.key import (
from signedjson.sign import sign_json
from twisted.test.proto_helpers import MemoryReactor
+from twisted.web.http_headers import Headers
+from twisted.web.iweb import IResponse
from synapse.api.errors import (
AuthError,
@@ -33,23 +36,17 @@ from synapse.api.errors import (
OAuthInsufficientScopeError,
SynapseError,
)
+from synapse.http.site import SynapseRequest
from synapse.rest import admin
from synapse.rest.client import account, devices, keys, login, logout, register
from synapse.server import HomeServer
-from synapse.types import JsonDict
+from synapse.types import JsonDict, UserID
from synapse.util import Clock
+from tests.server import FakeChannel
from tests.test_utils import FakeResponse, get_awaitable_result
-from tests.unittest import HomeserverTestCase, skip_unless
-from tests.utils import mock_getRawHeaders
-
-try:
- import authlib # noqa: F401
-
- HAS_AUTHLIB = True
-except ImportError:
- HAS_AUTHLIB = False
-
+from tests.unittest import HomeserverTestCase, override_config, skip_unless
+from tests.utils import HAS_AUTHLIB, checked_cast, mock_getRawHeaders
# These are a few constants that are used as config parameters in the tests.
SERVER_NAME = "test"
@@ -75,6 +72,7 @@ MATRIX_DEVICE_SCOPE = MATRIX_DEVICE_SCOPE_PREFIX + DEVICE
SUBJECT = "abc-def-ghi"
USERNAME = "test-user"
USER_ID = "@" + USERNAME + ":" + SERVER_NAME
+OIDC_ADMIN_USERID = f"@__oidc_admin:{SERVER_NAME}"
async def get_json(url: str) -> JsonDict:
@@ -134,7 +132,10 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
hs = self.setup_test_homeserver(proxied_http_client=self.http_client)
- self.auth = hs.get_auth()
+ # Import this here so that we've checked that authlib is available.
+ from synapse.api.auth.msc3861_delegated import MSC3861DelegatedAuth
+
+ self.auth = checked_cast(MSC3861DelegatedAuth, hs.get_auth())
return hs
@@ -675,7 +676,8 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = self.get_success(self.auth.get_user_by_req(request))
self.assertEqual(
- requester.user.to_string(), "@%s:%s" % ("__oidc_admin", SERVER_NAME)
+ requester.user.to_string(),
+ OIDC_ADMIN_USERID,
)
self.assertEqual(requester.is_guest, False)
self.assertEqual(requester.device_id, None)
@@ -685,3 +687,102 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
# There should be no call to the introspection endpoint
self.http_client.request.assert_not_called()
+
+ @override_config({"mau_stats_only": True})
+ def test_request_tracking(self) -> None:
+ """Using an access token should update the client_ips and MAU tables."""
+ # To start, there are no MAU users.
+ store = self.hs.get_datastores().main
+ mau = self.get_success(store.get_monthly_active_count())
+ self.assertEqual(mau, 0)
+
+ known_token = "token-token-GOOD-:)"
+
+ async def mock_http_client_request(
+ method: str,
+ uri: str,
+ data: Optional[bytes] = None,
+ headers: Optional[Headers] = None,
+ ) -> IResponse:
+ """Mocked auth provider response."""
+ assert method == "POST"
+ token = parse_qs(data)[b"token"][0].decode("utf-8")
+ if token == known_token:
+ return FakeResponse.json(
+ code=200,
+ payload={
+ "active": True,
+ "scope": MATRIX_USER_SCOPE,
+ "sub": SUBJECT,
+ "username": USERNAME,
+ },
+ )
+
+ return FakeResponse.json(code=200, payload={"active": False})
+
+ self.http_client.request = mock_http_client_request
+
+ EXAMPLE_IPV4_ADDR = "123.123.123.123"
+ EXAMPLE_USER_AGENT = "httprettygood"
+
+ # First test a known access token
+ channel = FakeChannel(self.site, self.reactor)
+ # type-ignore: FakeChannel is a mock of an HTTPChannel, not a proper HTTPChannel
+ req = SynapseRequest(channel, self.site) # type: ignore[arg-type]
+ req.client.host = EXAMPLE_IPV4_ADDR
+ req.requestHeaders.addRawHeader("Authorization", f"Bearer {known_token}")
+ req.requestHeaders.addRawHeader("User-Agent", EXAMPLE_USER_AGENT)
+ req.content = BytesIO(b"")
+ req.requestReceived(
+ b"GET",
+ b"/_matrix/client/v3/account/whoami",
+ b"1.1",
+ )
+ channel.await_result()
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
+ self.assertEqual(channel.json_body["user_id"], USER_ID, channel.json_body)
+
+ # Expect to see one MAU entry, from the first request
+ mau = self.get_success(store.get_monthly_active_count())
+ self.assertEqual(mau, 1)
+
+ conn_infos = self.get_success(
+ store.get_user_ip_and_agents(UserID.from_string(USER_ID))
+ )
+ self.assertEqual(len(conn_infos), 1, conn_infos)
+ conn_info = conn_infos[0]
+ self.assertEqual(conn_info["access_token"], known_token)
+ self.assertEqual(conn_info["ip"], EXAMPLE_IPV4_ADDR)
+ self.assertEqual(conn_info["user_agent"], EXAMPLE_USER_AGENT)
+
+ # Now test MAS making a request using the special __oidc_admin token
+ MAS_IPV4_ADDR = "127.0.0.1"
+ MAS_USER_AGENT = "masmasmas"
+
+ channel = FakeChannel(self.site, self.reactor)
+ req = SynapseRequest(channel, self.site) # type: ignore[arg-type]
+ req.client.host = MAS_IPV4_ADDR
+ req.requestHeaders.addRawHeader(
+ "Authorization", f"Bearer {self.auth._admin_token}"
+ )
+ req.requestHeaders.addRawHeader("User-Agent", MAS_USER_AGENT)
+ req.content = BytesIO(b"")
+ req.requestReceived(
+ b"GET",
+ b"/_matrix/client/v3/account/whoami",
+ b"1.1",
+ )
+ channel.await_result()
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
+ self.assertEqual(
+ channel.json_body["user_id"], OIDC_ADMIN_USERID, channel.json_body
+ )
+
+ # Still expect to see one MAU entry, from the first request
+ mau = self.get_success(store.get_monthly_active_count())
+ self.assertEqual(mau, 1)
+
+ conn_infos = self.get_success(
+ store.get_user_ip_and_agents(UserID.from_string(OIDC_ADMIN_USERID))
+ )
+ self.assertEqual(conn_infos, [])
diff --git a/tests/rest/media/test_media_retention.py b/tests/media/test_media_retention.py
index 27a663a23b..27a663a23b 100644
--- a/tests/rest/media/test_media_retention.py
+++ b/tests/media/test_media_retention.py
diff --git a/tests/media/test_media_storage.py b/tests/media/test_media_storage.py
index f262304c3d..f981d1c0d8 100644
--- a/tests/media/test_media_storage.py
+++ b/tests/media/test_media_storage.py
@@ -27,10 +27,11 @@ from typing_extensions import Literal
from twisted.internet import defer
from twisted.internet.defer import Deferred
+from twisted.python.failure import Failure
from twisted.test.proto_helpers import MemoryReactor
from twisted.web.resource import Resource
-from synapse.api.errors import Codes
+from synapse.api.errors import Codes, HttpResponseException
from synapse.events import EventBase
from synapse.http.types import QueryParams
from synapse.logging.context import make_deferred_yieldable
@@ -247,6 +248,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
retry_on_dns_fail: bool = True,
max_size: Optional[int] = None,
ignore_backoff: bool = False,
+ follow_redirects: bool = False,
) -> "Deferred[Tuple[int, Dict[bytes, List[bytes]]]]":
"""A mock for MatrixFederationHttpClient.get_file."""
@@ -257,10 +259,15 @@ class MediaRepoTests(unittest.HomeserverTestCase):
output_stream.write(data)
return response
+ def write_err(f: Failure) -> Failure:
+ f.trap(HttpResponseException)
+ output_stream.write(f.value.response)
+ return f
+
d: Deferred[Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]]] = Deferred()
self.fetches.append((d, destination, path, args))
# Note that this callback changes the value held by d.
- d_after_callback = d.addCallback(write_to)
+ d_after_callback = d.addCallbacks(write_to, write_err)
return make_deferred_yieldable(d_after_callback)
# Mock out the homeserver's MatrixFederationHttpClient
@@ -316,10 +323,11 @@ class MediaRepoTests(unittest.HomeserverTestCase):
self.assertEqual(len(self.fetches), 1)
self.assertEqual(self.fetches[0][1], "example.com")
self.assertEqual(
- self.fetches[0][2], "/_matrix/media/r0/download/" + self.media_id
+ self.fetches[0][2], "/_matrix/media/v3/download/" + self.media_id
)
self.assertEqual(
- self.fetches[0][3], {"allow_remote": "false", "timeout_ms": "20000"}
+ self.fetches[0][3],
+ {"allow_remote": "false", "timeout_ms": "20000", "allow_redirect": "true"},
)
headers = {
@@ -671,6 +679,52 @@ class MediaRepoTests(unittest.HomeserverTestCase):
[b"cross-origin"],
)
+ def test_unknown_v3_endpoint(self) -> None:
+ """
+ If the v3 endpoint fails, try the r0 one.
+ """
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/media/v3/download/{self.media_id}",
+ shorthand=False,
+ await_result=False,
+ )
+ self.pump()
+
+ # We've made one fetch, to example.com, using the media URL, and asking
+ # the other server not to do a remote fetch
+ self.assertEqual(len(self.fetches), 1)
+ self.assertEqual(self.fetches[0][1], "example.com")
+ self.assertEqual(
+ self.fetches[0][2], "/_matrix/media/v3/download/" + self.media_id
+ )
+
+ # The result which says the endpoint is unknown.
+ unknown_endpoint = b'{"errcode":"M_UNRECOGNIZED","error":"Unknown request"}'
+ self.fetches[0][0].errback(
+ HttpResponseException(404, "NOT FOUND", unknown_endpoint)
+ )
+
+ self.pump()
+
+ # There should now be another request to the r0 URL.
+ self.assertEqual(len(self.fetches), 2)
+ self.assertEqual(self.fetches[1][1], "example.com")
+ self.assertEqual(
+ self.fetches[1][2], f"/_matrix/media/r0/download/{self.media_id}"
+ )
+
+ headers = {
+ b"Content-Length": [b"%d" % (len(self.test_image.data))],
+ }
+
+ self.fetches[1][0].callback(
+ (self.test_image.data, (len(self.test_image.data), headers))
+ )
+
+ self.pump()
+ self.assertEqual(channel.code, 200)
+
class TestSpamCheckerLegacy:
"""A spam checker module that rejects all media that includes the bytes
diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py
index 1e9994cc0b..9a7b675f54 100644
--- a/tests/replication/test_multi_media_repo.py
+++ b/tests/replication/test_multi_media_repo.py
@@ -133,7 +133,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
self.assertEqual(request.method, b"GET")
self.assertEqual(
request.path,
- f"/_matrix/media/r0/download/{target}/{media_id}".encode(),
+ f"/_matrix/media/v3/download/{target}/{media_id}".encode(),
)
self.assertEqual(
request.requestHeaders.getRawHeaders(b"host"), [target.encode("utf-8")]
diff --git a/tests/rest/admin/test_jwks.py b/tests/rest/admin/test_jwks.py
index a9a6191c73..842e92c3d0 100644
--- a/tests/rest/admin/test_jwks.py
+++ b/tests/rest/admin/test_jwks.py
@@ -19,13 +19,7 @@ from twisted.web.resource import Resource
from synapse.rest.synapse.client import build_synapse_client_resource_tree
from tests.unittest import HomeserverTestCase, override_config, skip_unless
-
-try:
- import authlib # noqa: F401
-
- HAS_AUTHLIB = True
-except ImportError:
- HAS_AUTHLIB = False
+from tests.utils import HAS_AUTHLIB
@skip_unless(HAS_AUTHLIB, "requires authlib")
diff --git a/tests/rest/client/test_keys.py b/tests/rest/client/test_keys.py
index 9f81a695fa..a6023dff7a 100644
--- a/tests/rest/client/test_keys.py
+++ b/tests/rest/client/test_keys.py
@@ -30,13 +30,7 @@ from synapse.types import JsonDict, Requester, create_requester
from tests import unittest
from tests.http.server._base import make_request_with_cancellation_test
from tests.unittest import override_config
-
-try:
- import authlib # noqa: F401
-
- HAS_AUTHLIB = True
-except ImportError:
- HAS_AUTHLIB = False
+from tests.utils import HAS_AUTHLIB
class KeyQueryTestCase(unittest.HomeserverTestCase):
diff --git a/tests/rest/test_well_known.py b/tests/rest/test_well_known.py
index 377243a170..7931a70abb 100644
--- a/tests/rest/test_well_known.py
+++ b/tests/rest/test_well_known.py
@@ -16,13 +16,7 @@ from twisted.web.resource import Resource
from synapse.rest.well_known import well_known_resource
from tests import unittest
-
-try:
- import authlib # noqa: F401
-
- HAS_AUTHLIB = True
-except ImportError:
- HAS_AUTHLIB = False
+from tests.utils import HAS_AUTHLIB
class WellKnownTests(unittest.HomeserverTestCase):
diff --git a/tests/utils.py b/tests/utils.py
index a0c87ad628..e0066fe15a 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -30,6 +30,13 @@ from synapse.storage.database import LoggingDatabaseConnection
from synapse.storage.engines import create_engine
from synapse.storage.prepare_database import prepare_database
+try:
+ import authlib # noqa: F401
+
+ HAS_AUTHLIB = True
+except ImportError:
+ HAS_AUTHLIB = False
+
# set this to True to run the tests against postgres instead of sqlite.
#
# When running under postgres, we first create a base database with the name
|