summary refs log tree commit diff
diff options
context:
space:
mode:
authorQuentin Gliech <quenting@element.io>2022-06-14 10:51:15 +0200
committerGitHub <noreply@github.com>2022-06-14 09:51:15 +0100
commit92103cb2c8b8bff6b522a7bfa8a3a776b4821b11 (patch)
tree644b933f8bc9c80d39adda432529340d0d865d1c
parentUniformize spam-checker API, part 4: port other spam-checker callbacks to ret... (diff)
downloadsynapse-92103cb2c8b8bff6b522a7bfa8a3a776b4821b11.tar.xz
Decouple `synapse.api.auth_blocking.AuthBlocking` from `synapse.api.auth.Auth`. (#13021)
-rw-r--r--changelog.d/13021.misc1
-rw-r--r--synapse/api/auth.py14
-rw-r--r--synapse/handlers/auth.py5
-rw-r--r--synapse/handlers/message.py4
-rw-r--r--synapse/handlers/register.py3
-rw-r--r--synapse/handlers/room.py3
-rw-r--r--synapse/handlers/sync.py4
-rw-r--r--synapse/server.py5
-rw-r--r--synapse/server_notices/resource_limits_server_notices.py4
-rw-r--r--tests/api/test_auth.py42
-rw-r--r--tests/handlers/test_auth.py2
-rw-r--r--tests/handlers/test_register.py2
-rw-r--r--tests/handlers/test_sync.py2
-rw-r--r--tests/server_notices/test_resource_limits_server_notices.py22
14 files changed, 63 insertions, 50 deletions
diff --git a/changelog.d/13021.misc b/changelog.d/13021.misc
new file mode 100644
index 0000000000..84c41cdf59
--- /dev/null
+++ b/changelog.d/13021.misc
@@ -0,0 +1 @@
+Decouple `synapse.api.auth_blocking.AuthBlocking` from `synapse.api.auth.Auth`.
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 5a410f805a..c037ccb984 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -20,7 +20,6 @@ from netaddr import IPAddress
 from twisted.web.server import Request
 
 from synapse import event_auth
-from synapse.api.auth_blocking import AuthBlocking
 from synapse.api.constants import EventTypes, HistoryVisibility, Membership
 from synapse.api.errors import (
     AuthError,
@@ -67,8 +66,6 @@ class Auth:
             10000, "token_cache"
         )
 
-        self._auth_blocking = AuthBlocking(self.hs)
-
         self._track_appservice_user_ips = hs.config.appservice.track_appservice_user_ips
         self._track_puppeted_user_ips = hs.config.api.track_puppeted_user_ips
         self._macaroon_secret_key = hs.config.key.macaroon_secret_key
@@ -711,14 +708,3 @@ class Auth:
                 "User %s not in room %s, and room previews are disabled"
                 % (user_id, room_id),
             )
-
-    async def check_auth_blocking(
-        self,
-        user_id: Optional[str] = None,
-        threepid: Optional[dict] = None,
-        user_type: Optional[str] = None,
-        requester: Optional[Requester] = None,
-    ) -> None:
-        await self._auth_blocking.check_auth_blocking(
-            user_id=user_id, threepid=threepid, user_type=user_type, requester=requester
-        )
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 6e15028b0a..60d13040a2 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -199,6 +199,7 @@ class AuthHandler:
     def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastores().main
         self.auth = hs.get_auth()
+        self.auth_blocking = hs.get_auth_blocking()
         self.clock = hs.get_clock()
         self.checkers: Dict[str, UserInteractiveAuthChecker] = {}
         for auth_checker_class in INTERACTIVE_AUTH_CHECKERS:
@@ -985,7 +986,7 @@ class AuthHandler:
             not is_appservice_ghost
             or self.hs.config.appservice.track_appservice_user_ips
         ):
-            await self.auth.check_auth_blocking(user_id)
+            await self.auth_blocking.check_auth_blocking(user_id)
 
         access_token = self.generate_access_token(target_user_id_obj)
         await self.store.add_access_token_to_user(
@@ -1439,7 +1440,7 @@ class AuthHandler:
         except Exception:
             raise AuthError(403, "Invalid login token", errcode=Codes.FORBIDDEN)
 
-        await self.auth.check_auth_blocking(res.user_id)
+        await self.auth_blocking.check_auth_blocking(res.user_id)
         return res
 
     async def delete_access_token(self, access_token: str) -> None:
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index ad87c41782..189f52fe5a 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -444,7 +444,7 @@ _DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY = 7 * 24 * 60 * 60 * 1000
 class EventCreationHandler:
     def __init__(self, hs: "HomeServer"):
         self.hs = hs
-        self.auth = hs.get_auth()
+        self.auth_blocking = hs.get_auth_blocking()
         self._event_auth_handler = hs.get_event_auth_handler()
         self.store = hs.get_datastores().main
         self._storage_controllers = hs.get_storage_controllers()
@@ -605,7 +605,7 @@ class EventCreationHandler:
         Returns:
             Tuple of created event, Context
         """
-        await self.auth.check_auth_blocking(requester=requester)
+        await self.auth_blocking.check_auth_blocking(requester=requester)
 
         if event_dict["type"] == EventTypes.Create and event_dict["state_key"] == "":
             room_version_id = event_dict["content"]["room_version"]
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 338204287f..c77d181722 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -91,6 +91,7 @@ class RegistrationHandler:
         self.clock = hs.get_clock()
         self.hs = hs
         self.auth = hs.get_auth()
+        self.auth_blocking = hs.get_auth_blocking()
         self._auth_handler = hs.get_auth_handler()
         self.profile_handler = hs.get_profile_handler()
         self.user_directory_handler = hs.get_user_directory_handler()
@@ -276,7 +277,7 @@ class RegistrationHandler:
 
         # do not check_auth_blocking if the call is coming through the Admin API
         if not by_admin:
-            await self.auth.check_auth_blocking(threepid=threepid)
+            await self.auth_blocking.check_auth_blocking(threepid=threepid)
 
         if localpart is not None:
             await self.check_username(localpart, guest_access_token=guest_access_token)
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 42aae4a215..75c0be8c36 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -110,6 +110,7 @@ class RoomCreationHandler:
         self.store = hs.get_datastores().main
         self._storage_controllers = hs.get_storage_controllers()
         self.auth = hs.get_auth()
+        self.auth_blocking = hs.get_auth_blocking()
         self.clock = hs.get_clock()
         self.hs = hs
         self.spam_checker = hs.get_spam_checker()
@@ -706,7 +707,7 @@ class RoomCreationHandler:
         """
         user_id = requester.user.to_string()
 
-        await self.auth.check_auth_blocking(requester=requester)
+        await self.auth_blocking.check_auth_blocking(requester=requester)
 
         if (
             self._server_notices_mxid is not None
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index b4ead79f97..af19c513be 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -237,7 +237,7 @@ class SyncHandler:
         self.event_sources = hs.get_event_sources()
         self.clock = hs.get_clock()
         self.state = hs.get_state_handler()
-        self.auth = hs.get_auth()
+        self.auth_blocking = hs.get_auth_blocking()
         self._storage_controllers = hs.get_storage_controllers()
         self._state_storage_controller = self._storage_controllers.state
 
@@ -280,7 +280,7 @@ class SyncHandler:
         # not been exceeded (if not part of the group by this point, almost certain
         # auth_blocking will occur)
         user_id = sync_config.user.to_string()
-        await self.auth.check_auth_blocking(requester=requester)
+        await self.auth_blocking.check_auth_blocking(requester=requester)
 
         res = await self.response_cache.wrap(
             sync_config.request_key,
diff --git a/synapse/server.py b/synapse/server.py
index a66ec228db..a6a415aeab 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -29,6 +29,7 @@ from twisted.web.iweb import IPolicyForHTTPS
 from twisted.web.resource import Resource
 
 from synapse.api.auth import Auth
+from synapse.api.auth_blocking import AuthBlocking
 from synapse.api.filtering import Filtering
 from synapse.api.ratelimiting import Ratelimiter, RequestRatelimiter
 from synapse.appservice.api import ApplicationServiceApi
@@ -380,6 +381,10 @@ class HomeServer(metaclass=abc.ABCMeta):
         return Auth(self)
 
     @cache_in_self
+    def get_auth_blocking(self) -> AuthBlocking:
+        return AuthBlocking(self)
+
+    @cache_in_self
     def get_http_client_context_factory(self) -> IPolicyForHTTPS:
         if self.config.tls.use_insecure_ssl_client_just_for_testing_do_not_use:
             return InsecureInterceptableContextFactory()
diff --git a/synapse/server_notices/resource_limits_server_notices.py b/synapse/server_notices/resource_limits_server_notices.py
index 6863020778..3134cd2d3d 100644
--- a/synapse/server_notices/resource_limits_server_notices.py
+++ b/synapse/server_notices/resource_limits_server_notices.py
@@ -37,7 +37,7 @@ class ResourceLimitsServerNotices:
         self._server_notices_manager = hs.get_server_notices_manager()
         self._store = hs.get_datastores().main
         self._storage_controllers = hs.get_storage_controllers()
-        self._auth = hs.get_auth()
+        self._auth_blocking = hs.get_auth_blocking()
         self._config = hs.config
         self._resouce_limited = False
         self._account_data_handler = hs.get_account_data_handler()
@@ -91,7 +91,7 @@ class ResourceLimitsServerNotices:
             # Normally should always pass in user_id to check_auth_blocking
             # if you have it, but in this case are checking what would happen
             # to other users if they were to arrive.
-            await self._auth.check_auth_blocking()
+            await self._auth_blocking.check_auth_blocking()
         except ResourceLimitError as e:
             limit_msg = e.msg
             limit_type = e.limit_type
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index bc75ddd3e9..54af9089e9 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -19,6 +19,7 @@ import pymacaroons
 from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.api.auth import Auth
+from synapse.api.auth_blocking import AuthBlocking
 from synapse.api.constants import UserTypes
 from synapse.api.errors import (
     AuthError,
@@ -49,7 +50,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
 
         # AuthBlocking reads from the hs' config on initialization. We need to
         # modify its config instead of the hs'
-        self.auth_blocking = self.auth._auth_blocking
+        self.auth_blocking = AuthBlocking(hs)
 
         self.test_user = "@foo:bar"
         self.test_token = b"_test_token_"
@@ -362,20 +363,22 @@ class AuthTestCase(unittest.HomeserverTestCase):
         small_number_of_users = 1
 
         # Ensure no error thrown
-        self.get_success(self.auth.check_auth_blocking())
+        self.get_success(self.auth_blocking.check_auth_blocking())
 
         self.auth_blocking._limit_usage_by_mau = True
 
         self.store.get_monthly_active_count = simple_async_mock(lots_of_users)
 
-        e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError)
+        e = self.get_failure(
+            self.auth_blocking.check_auth_blocking(), ResourceLimitError
+        )
         self.assertEqual(e.value.admin_contact, self.hs.config.server.admin_contact)
         self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
         self.assertEqual(e.value.code, 403)
 
         # Ensure does not throw an error
         self.store.get_monthly_active_count = simple_async_mock(small_number_of_users)
-        self.get_success(self.auth.check_auth_blocking())
+        self.get_success(self.auth_blocking.check_auth_blocking())
 
     def test_blocking_mau__depending_on_user_type(self):
         self.auth_blocking._max_mau_value = 50
@@ -383,15 +386,18 @@ class AuthTestCase(unittest.HomeserverTestCase):
 
         self.store.get_monthly_active_count = simple_async_mock(100)
         # Support users allowed
-        self.get_success(self.auth.check_auth_blocking(user_type=UserTypes.SUPPORT))
+        self.get_success(
+            self.auth_blocking.check_auth_blocking(user_type=UserTypes.SUPPORT)
+        )
         self.store.get_monthly_active_count = simple_async_mock(100)
         # Bots not allowed
         self.get_failure(
-            self.auth.check_auth_blocking(user_type=UserTypes.BOT), ResourceLimitError
+            self.auth_blocking.check_auth_blocking(user_type=UserTypes.BOT),
+            ResourceLimitError,
         )
         self.store.get_monthly_active_count = simple_async_mock(100)
         # Real users not allowed
-        self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError)
+        self.get_failure(self.auth_blocking.check_auth_blocking(), ResourceLimitError)
 
     def test_blocking_mau__appservice_requester_allowed_when_not_tracking_ips(self):
         self.auth_blocking._max_mau_value = 50
@@ -419,7 +425,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
             app_service=appservice,
             authenticated_entity="@appservice:server",
         )
-        self.get_success(self.auth.check_auth_blocking(requester=requester))
+        self.get_success(self.auth_blocking.check_auth_blocking(requester=requester))
 
     def test_blocking_mau__appservice_requester_disallowed_when_tracking_ips(self):
         self.auth_blocking._max_mau_value = 50
@@ -448,7 +454,8 @@ class AuthTestCase(unittest.HomeserverTestCase):
             authenticated_entity="@appservice:server",
         )
         self.get_failure(
-            self.auth.check_auth_blocking(requester=requester), ResourceLimitError
+            self.auth_blocking.check_auth_blocking(requester=requester),
+            ResourceLimitError,
         )
 
     def test_reserved_threepid(self):
@@ -459,18 +466,21 @@ class AuthTestCase(unittest.HomeserverTestCase):
         unknown_threepid = {"medium": "email", "address": "unreserved@server.com"}
         self.auth_blocking._mau_limits_reserved_threepids = [threepid]
 
-        self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError)
+        self.get_failure(self.auth_blocking.check_auth_blocking(), ResourceLimitError)
 
         self.get_failure(
-            self.auth.check_auth_blocking(threepid=unknown_threepid), ResourceLimitError
+            self.auth_blocking.check_auth_blocking(threepid=unknown_threepid),
+            ResourceLimitError,
         )
 
-        self.get_success(self.auth.check_auth_blocking(threepid=threepid))
+        self.get_success(self.auth_blocking.check_auth_blocking(threepid=threepid))
 
     def test_hs_disabled(self):
         self.auth_blocking._hs_disabled = True
         self.auth_blocking._hs_disabled_message = "Reason for being disabled"
-        e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError)
+        e = self.get_failure(
+            self.auth_blocking.check_auth_blocking(), ResourceLimitError
+        )
         self.assertEqual(e.value.admin_contact, self.hs.config.server.admin_contact)
         self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
         self.assertEqual(e.value.code, 403)
@@ -485,7 +495,9 @@ class AuthTestCase(unittest.HomeserverTestCase):
 
         self.auth_blocking._hs_disabled = True
         self.auth_blocking._hs_disabled_message = "Reason for being disabled"
-        e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError)
+        e = self.get_failure(
+            self.auth_blocking.check_auth_blocking(), ResourceLimitError
+        )
         self.assertEqual(e.value.admin_contact, self.hs.config.server.admin_contact)
         self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
         self.assertEqual(e.value.code, 403)
@@ -495,4 +507,4 @@ class AuthTestCase(unittest.HomeserverTestCase):
         user = "@user:server"
         self.auth_blocking._server_notices_mxid = user
         self.auth_blocking._hs_disabled_message = "Reason for being disabled"
-        self.get_success(self.auth.check_auth_blocking(user))
+        self.get_success(self.auth_blocking.check_auth_blocking(user))
diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py
index 67a7829769..7106799d44 100644
--- a/tests/handlers/test_auth.py
+++ b/tests/handlers/test_auth.py
@@ -38,7 +38,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
         # MAU tests
         # AuthBlocking reads from the hs' config on initialization. We need to
         # modify its config instead of the hs'
-        self.auth_blocking = hs.get_auth()._auth_blocking
+        self.auth_blocking = hs.get_auth_blocking()
         self.auth_blocking._max_mau_value = 50
 
         self.small_number_of_users = 1
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index b6ba19c739..23f35d5bf5 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -699,7 +699,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
         """
         if localpart is None:
             raise SynapseError(400, "Request must include user id")
-        await self.hs.get_auth().check_auth_blocking()
+        await self.hs.get_auth_blocking().check_auth_blocking()
         need_register = True
 
         try:
diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py
index db3302a4c7..ecc7cc6461 100644
--- a/tests/handlers/test_sync.py
+++ b/tests/handlers/test_sync.py
@@ -45,7 +45,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
 
         # AuthBlocking reads from the hs' config on initialization. We need to
         # modify its config instead of the hs'
-        self.auth_blocking = self.hs.get_auth()._auth_blocking
+        self.auth_blocking = self.hs.get_auth_blocking()
 
     def test_wait_for_sync_for_user_auth_blocking(self):
         user_id1 = "@user1:test"
diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py
index 07e29788e5..e07ae78fc4 100644
--- a/tests/server_notices/test_resource_limits_server_notices.py
+++ b/tests/server_notices/test_resource_limits_server_notices.py
@@ -96,7 +96,9 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
     def test_maybe_send_server_notice_to_user_remove_blocked_notice(self):
         """Test when user has blocked notice, but should have it removed"""
 
-        self._rlsn._auth.check_auth_blocking = Mock(return_value=make_awaitable(None))
+        self._rlsn._auth_blocking.check_auth_blocking = Mock(
+            return_value=make_awaitable(None)
+        )
         mock_event = Mock(
             type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
         )
@@ -112,7 +114,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
         """
         Test when user has blocked notice, but notice ought to be there (NOOP)
         """
-        self._rlsn._auth.check_auth_blocking = Mock(
+        self._rlsn._auth_blocking.check_auth_blocking = Mock(
             return_value=make_awaitable(None),
             side_effect=ResourceLimitError(403, "foo"),
         )
@@ -132,7 +134,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
         """
         Test when user does not have blocked notice, but should have one
         """
-        self._rlsn._auth.check_auth_blocking = Mock(
+        self._rlsn._auth_blocking.check_auth_blocking = Mock(
             return_value=make_awaitable(None),
             side_effect=ResourceLimitError(403, "foo"),
         )
@@ -145,7 +147,9 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
         """
         Test when user does not have blocked notice, nor should they (NOOP)
         """
-        self._rlsn._auth.check_auth_blocking = Mock(return_value=make_awaitable(None))
+        self._rlsn._auth_blocking.check_auth_blocking = Mock(
+            return_value=make_awaitable(None)
+        )
 
         self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
 
@@ -156,7 +160,9 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
         Test when user is not part of the MAU cohort - this should not ever
         happen - but ...
         """
-        self._rlsn._auth.check_auth_blocking = Mock(return_value=make_awaitable(None))
+        self._rlsn._auth_blocking.check_auth_blocking = Mock(
+            return_value=make_awaitable(None)
+        )
         self._rlsn._store.user_last_seen_monthly_active = Mock(
             return_value=make_awaitable(None)
         )
@@ -170,7 +176,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
         Test that when server is over MAU limit and alerting is suppressed, then
         an alert message is not sent into the room
         """
-        self._rlsn._auth.check_auth_blocking = Mock(
+        self._rlsn._auth_blocking.check_auth_blocking = Mock(
             return_value=make_awaitable(None),
             side_effect=ResourceLimitError(
                 403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER
@@ -185,7 +191,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
         """
         Test that when a server is disabled, that MAU limit alerting is ignored.
         """
-        self._rlsn._auth.check_auth_blocking = Mock(
+        self._rlsn._auth_blocking.check_auth_blocking = Mock(
             return_value=make_awaitable(None),
             side_effect=ResourceLimitError(
                 403, "foo", limit_type=LimitBlockingTypes.HS_DISABLED
@@ -202,7 +208,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
         When the room is already in a blocked state, test that when alerting
         is suppressed that the room is returned to an unblocked state.
         """
-        self._rlsn._auth.check_auth_blocking = Mock(
+        self._rlsn._auth_blocking.check_auth_blocking = Mock(
             return_value=make_awaitable(None),
             side_effect=ResourceLimitError(
                 403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER