summary refs log tree commit diff
diff options
context:
space:
mode:
authorDavid Teller <D.O.Teller@gmail.com>2020-12-11 20:05:15 +0100
committerGitHub <noreply@github.com>2020-12-11 14:05:15 -0500
commitf14428b25c37e44675edac4a80d7bd1e47112586 (patch)
tree7565992e70db2c48c7008b2e3fdfe122d315308e
parentAdd type hints to the push module. (#8901) (diff)
downloadsynapse-f14428b25c37e44675edac4a80d7bd1e47112586.tar.xz
Allow spam-checker modules to be provide async methods. (#8890)
Spam checker modules can now provide async methods. This is implemented
in a backwards-compatible manner.
-rw-r--r--changelog.d/8890.feature1
-rw-r--r--docs/spam_checker.md19
-rw-r--r--synapse/events/spamcheck.py55
-rw-r--r--synapse/federation/federation_base.py7
-rw-r--r--synapse/handlers/auth.py8
-rw-r--r--synapse/handlers/directory.py6
-rw-r--r--synapse/handlers/federation.py2
-rw-r--r--synapse/handlers/message.py2
-rw-r--r--synapse/handlers/receipts.py7
-rw-r--r--synapse/handlers/register.py2
-rw-r--r--synapse/handlers/room.py4
-rw-r--r--synapse/handlers/room_member.py2
-rw-r--r--synapse/handlers/user_directory.py10
-rw-r--r--synapse/metrics/background_process_metrics.py9
-rw-r--r--synapse/rest/media/v1/storage_provider.py16
-rw-r--r--synapse/server.py2
-rw-r--r--synapse/util/async_helpers.py8
-rw-r--r--synapse/util/distributor.py7
-rw-r--r--tests/handlers/test_user_directory.py4
19 files changed, 98 insertions, 73 deletions
diff --git a/changelog.d/8890.feature b/changelog.d/8890.feature
new file mode 100644
index 0000000000..97aa72a76e
--- /dev/null
+++ b/changelog.d/8890.feature
@@ -0,0 +1 @@
+Spam-checkers may now define their methods as `async`.
diff --git a/docs/spam_checker.md b/docs/spam_checker.md
index 7fc08f1b70..5b4f6428e6 100644
--- a/docs/spam_checker.md
+++ b/docs/spam_checker.md
@@ -22,6 +22,8 @@ well as some specific methods:
 * `user_may_create_room`
 * `user_may_create_room_alias`
 * `user_may_publish_room`
+* `check_username_for_spam`
+* `check_registration_for_spam`
 
 The details of the each of these methods (as well as their inputs and outputs)
 are documented in the `synapse.events.spamcheck.SpamChecker` class.
@@ -32,28 +34,33 @@ call back into the homeserver internals.
 ### Example
 
 ```python
+from synapse.spam_checker_api import RegistrationBehaviour
+
 class ExampleSpamChecker:
     def __init__(self, config, api):
         self.config = config
         self.api = api
 
-    def check_event_for_spam(self, foo):
+    async def check_event_for_spam(self, foo):
         return False  # allow all events
 
-    def user_may_invite(self, inviter_userid, invitee_userid, room_id):
+    async def user_may_invite(self, inviter_userid, invitee_userid, room_id):
         return True  # allow all invites
 
-    def user_may_create_room(self, userid):
+    async def user_may_create_room(self, userid):
         return True  # allow all room creations
 
-    def user_may_create_room_alias(self, userid, room_alias):
+    async def user_may_create_room_alias(self, userid, room_alias):
         return True  # allow all room aliases
 
-    def user_may_publish_room(self, userid, room_id):
+    async def user_may_publish_room(self, userid, room_id):
         return True  # allow publishing of all rooms
 
-    def check_username_for_spam(self, user_profile):
+    async def check_username_for_spam(self, user_profile):
         return False  # allow all usernames
+
+    async def check_registration_for_spam(self, email_threepid, username, request_info):
+        return RegistrationBehaviour.ALLOW  # allow all registrations
 ```
 
 ## Configuration
diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py
index 936896656a..e7e3a7b9a4 100644
--- a/synapse/events/spamcheck.py
+++ b/synapse/events/spamcheck.py
@@ -15,10 +15,11 @@
 # limitations under the License.
 
 import inspect
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
 
 from synapse.spam_checker_api import RegistrationBehaviour
 from synapse.types import Collection
+from synapse.util.async_helpers import maybe_awaitable
 
 if TYPE_CHECKING:
     import synapse.events
@@ -39,7 +40,9 @@ class SpamChecker:
             else:
                 self.spam_checkers.append(module(config=config))
 
-    def check_event_for_spam(self, event: "synapse.events.EventBase") -> bool:
+    async def check_event_for_spam(
+        self, event: "synapse.events.EventBase"
+    ) -> Union[bool, str]:
         """Checks if a given event is considered "spammy" by this server.
 
         If the server considers an event spammy, then it will be rejected if
@@ -50,15 +53,16 @@ class SpamChecker:
             event: the event to be checked
 
         Returns:
-            True if the event is spammy.
+            True or a string if the event is spammy. If a string is returned it
+            will be used as the error message returned to the user.
         """
         for spam_checker in self.spam_checkers:
-            if spam_checker.check_event_for_spam(event):
+            if await maybe_awaitable(spam_checker.check_event_for_spam(event)):
                 return True
 
         return False
 
-    def user_may_invite(
+    async def user_may_invite(
         self, inviter_userid: str, invitee_userid: str, room_id: str
     ) -> bool:
         """Checks if a given user may send an invite
@@ -75,14 +79,18 @@ class SpamChecker:
         """
         for spam_checker in self.spam_checkers:
             if (
-                spam_checker.user_may_invite(inviter_userid, invitee_userid, room_id)
+                await maybe_awaitable(
+                    spam_checker.user_may_invite(
+                        inviter_userid, invitee_userid, room_id
+                    )
+                )
                 is False
             ):
                 return False
 
         return True
 
-    def user_may_create_room(self, userid: str) -> bool:
+    async def user_may_create_room(self, userid: str) -> bool:
         """Checks if a given user may create a room
 
         If this method returns false, the creation request will be rejected.
@@ -94,12 +102,15 @@ class SpamChecker:
             True if the user may create a room, otherwise False
         """
         for spam_checker in self.spam_checkers:
-            if spam_checker.user_may_create_room(userid) is False:
+            if (
+                await maybe_awaitable(spam_checker.user_may_create_room(userid))
+                is False
+            ):
                 return False
 
         return True
 
-    def user_may_create_room_alias(self, userid: str, room_alias: str) -> bool:
+    async def user_may_create_room_alias(self, userid: str, room_alias: str) -> bool:
         """Checks if a given user may create a room alias
 
         If this method returns false, the association request will be rejected.
@@ -112,12 +123,17 @@ class SpamChecker:
             True if the user may create a room alias, otherwise False
         """
         for spam_checker in self.spam_checkers:
-            if spam_checker.user_may_create_room_alias(userid, room_alias) is False:
+            if (
+                await maybe_awaitable(
+                    spam_checker.user_may_create_room_alias(userid, room_alias)
+                )
+                is False
+            ):
                 return False
 
         return True
 
-    def user_may_publish_room(self, userid: str, room_id: str) -> bool:
+    async def user_may_publish_room(self, userid: str, room_id: str) -> bool:
         """Checks if a given user may publish a room to the directory
 
         If this method returns false, the publish request will be rejected.
@@ -130,12 +146,17 @@ class SpamChecker:
             True if the user may publish the room, otherwise False
         """
         for spam_checker in self.spam_checkers:
-            if spam_checker.user_may_publish_room(userid, room_id) is False:
+            if (
+                await maybe_awaitable(
+                    spam_checker.user_may_publish_room(userid, room_id)
+                )
+                is False
+            ):
                 return False
 
         return True
 
-    def check_username_for_spam(self, user_profile: Dict[str, str]) -> bool:
+    async def check_username_for_spam(self, user_profile: Dict[str, str]) -> bool:
         """Checks if a user ID or display name are considered "spammy" by this server.
 
         If the server considers a username spammy, then it will not be included in
@@ -157,12 +178,12 @@ class SpamChecker:
             if checker:
                 # Make a copy of the user profile object to ensure the spam checker
                 # cannot modify it.
-                if checker(user_profile.copy()):
+                if await maybe_awaitable(checker(user_profile.copy())):
                     return True
 
         return False
 
-    def check_registration_for_spam(
+    async def check_registration_for_spam(
         self,
         email_threepid: Optional[dict],
         username: Optional[str],
@@ -185,7 +206,9 @@ class SpamChecker:
             # spam checker
             checker = getattr(spam_checker, "check_registration_for_spam", None)
             if checker:
-                behaviour = checker(email_threepid, username, request_info)
+                behaviour = await maybe_awaitable(
+                    checker(email_threepid, username, request_info)
+                )
                 assert isinstance(behaviour, RegistrationBehaviour)
                 if behaviour != RegistrationBehaviour.ALLOW:
                     return behaviour
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index 38aa47963f..383737520a 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -78,6 +78,7 @@ class FederationBase:
 
         ctx = current_context()
 
+        @defer.inlineCallbacks
         def callback(_, pdu: EventBase):
             with PreserveLoggingContext(ctx):
                 if not check_event_content_hash(pdu):
@@ -105,7 +106,11 @@ class FederationBase:
                         )
                     return redacted_event
 
-                if self.spam_checker.check_event_for_spam(pdu):
+                result = yield defer.ensureDeferred(
+                    self.spam_checker.check_event_for_spam(pdu)
+                )
+
+                if result:
                     logger.warning(
                         "Event contains spam, redacting %s: %s",
                         pdu.event_id,
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 62f98dabc0..8deec4cd0c 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -14,7 +14,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import inspect
 import logging
 import time
 import unicodedata
@@ -59,6 +58,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.module_api import ModuleApi
 from synapse.types import JsonDict, Requester, UserID
 from synapse.util import stringutils as stringutils
+from synapse.util.async_helpers import maybe_awaitable
 from synapse.util.msisdn import phone_number_to_msisdn
 from synapse.util.threepids import canonicalise_email
 
@@ -1639,6 +1639,6 @@ class PasswordProvider:
 
         # This might return an awaitable, if it does block the log out
         # until it completes.
-        result = g(user_id=user_id, device_id=device_id, access_token=access_token,)
-        if inspect.isawaitable(result):
-            await result
+        await maybe_awaitable(
+            g(user_id=user_id, device_id=device_id, access_token=access_token,)
+        )
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index ad5683d251..abcf86352d 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -133,7 +133,9 @@ class DirectoryHandler(BaseHandler):
                         403, "You must be in the room to create an alias for it"
                     )
 
-            if not self.spam_checker.user_may_create_room_alias(user_id, room_alias):
+            if not await self.spam_checker.user_may_create_room_alias(
+                user_id, room_alias
+            ):
                 raise AuthError(403, "This user is not permitted to create this alias")
 
             if not self.config.is_alias_creation_allowed(
@@ -409,7 +411,7 @@ class DirectoryHandler(BaseHandler):
         """
         user_id = requester.user.to_string()
 
-        if not self.spam_checker.user_may_publish_room(user_id, room_id):
+        if not await self.spam_checker.user_may_publish_room(user_id, room_id):
             raise AuthError(
                 403, "This user is not permitted to publish rooms to the room list"
             )
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index df82e60b33..fd8de8696d 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -1593,7 +1593,7 @@ class FederationHandler(BaseHandler):
         if self.hs.config.block_non_admin_invites:
             raise SynapseError(403, "This server does not accept room invites")
 
-        if not self.spam_checker.user_may_invite(
+        if not await self.spam_checker.user_may_invite(
             event.sender, event.state_key, event.room_id
         ):
             raise SynapseError(
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 96843338ae..2b8aa9443d 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -744,7 +744,7 @@ class EventCreationHandler:
                 event.sender,
             )
 
-            spam_error = self.spam_checker.check_event_for_spam(event)
+            spam_error = await self.spam_checker.check_event_for_spam(event)
             if spam_error:
                 if not isinstance(spam_error, str):
                     spam_error = "Spam is not permitted here"
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index 153cbae7b9..e850e45e46 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -18,7 +18,6 @@ from typing import List, Tuple
 from synapse.appservice import ApplicationService
 from synapse.handlers._base import BaseHandler
 from synapse.types import JsonDict, ReadReceipt, get_domain_from_id
-from synapse.util.async_helpers import maybe_awaitable
 
 logger = logging.getLogger(__name__)
 
@@ -98,10 +97,8 @@ class ReceiptsHandler(BaseHandler):
 
         self.notifier.on_new_event("receipt_key", max_batch_id, rooms=affected_room_ids)
         # Note that the min here shouldn't be relied upon to be accurate.
-        await maybe_awaitable(
-            self.hs.get_pusherpool().on_new_receipts(
-                min_batch_id, max_batch_id, affected_room_ids
-            )
+        await self.hs.get_pusherpool().on_new_receipts(
+            min_batch_id, max_batch_id, affected_room_ids
         )
 
         return True
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 0d85fd0868..94b5610acd 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -187,7 +187,7 @@ class RegistrationHandler(BaseHandler):
         """
         self.check_registration_ratelimit(address)
 
-        result = self.spam_checker.check_registration_for_spam(
+        result = await self.spam_checker.check_registration_for_spam(
             threepid, localpart, user_agent_ips or [],
         )
 
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 82fb72b381..7583418946 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -358,7 +358,7 @@ class RoomCreationHandler(BaseHandler):
         """
         user_id = requester.user.to_string()
 
-        if not self.spam_checker.user_may_create_room(user_id):
+        if not await self.spam_checker.user_may_create_room(user_id):
             raise SynapseError(403, "You are not permitted to create rooms")
 
         creation_content = {
@@ -609,7 +609,7 @@ class RoomCreationHandler(BaseHandler):
                 403, "You are not permitted to create rooms", Codes.FORBIDDEN
             )
 
-        if not is_requester_admin and not self.spam_checker.user_may_create_room(
+        if not is_requester_admin and not await self.spam_checker.user_may_create_room(
             user_id
         ):
             raise SynapseError(403, "You are not permitted to create rooms")
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index d85110a35e..cb5a29bc7e 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -408,7 +408,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                     )
                     block_invite = True
 
-                if not self.spam_checker.user_may_invite(
+                if not await self.spam_checker.user_may_invite(
                     requester.user.to_string(), target.to_string(), room_id
                 ):
                     logger.info("Blocking invite due to spam checker")
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index afbebfc200..f263a638f8 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -81,11 +81,11 @@ class UserDirectoryHandler(StateDeltasHandler):
         results = await self.store.search_user_dir(user_id, search_term, limit)
 
         # Remove any spammy users from the results.
-        results["results"] = [
-            user
-            for user in results["results"]
-            if not self.spam_checker.check_username_for_spam(user)
-        ]
+        non_spammy_users = []
+        for user in results["results"]:
+            if not await self.spam_checker.check_username_for_spam(user):
+                non_spammy_users.append(user)
+        results["results"] = non_spammy_users
 
         return results
 
diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py
index 658f6ecd72..76b7decf26 100644
--- a/synapse/metrics/background_process_metrics.py
+++ b/synapse/metrics/background_process_metrics.py
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import inspect
 import logging
 import threading
 from functools import wraps
@@ -25,6 +24,7 @@ from twisted.internet import defer
 
 from synapse.logging.context import LoggingContext, PreserveLoggingContext
 from synapse.logging.opentracing import noop_context_manager, start_active_span
+from synapse.util.async_helpers import maybe_awaitable
 
 if TYPE_CHECKING:
     import resource
@@ -206,12 +206,7 @@ def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwar
                 if bg_start_span:
                     ctx = start_active_span(desc, tags={"request_id": context.request})
                 with ctx:
-                    result = func(*args, **kwargs)
-
-                    if inspect.isawaitable(result):
-                        result = await result
-
-                    return result
+                    return await maybe_awaitable(func(*args, **kwargs))
             except Exception:
                 logger.exception(
                     "Background process '%s' threw an exception", desc,
diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py
index 18c9ed48d6..67f67efde7 100644
--- a/synapse/rest/media/v1/storage_provider.py
+++ b/synapse/rest/media/v1/storage_provider.py
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import inspect
 import logging
 import os
 import shutil
@@ -21,6 +20,7 @@ from typing import Optional
 
 from synapse.config._base import Config
 from synapse.logging.context import defer_to_thread, run_in_background
+from synapse.util.async_helpers import maybe_awaitable
 
 from ._base import FileInfo, Responder
 from .media_storage import FileResponder
@@ -91,16 +91,14 @@ class StorageProviderWrapper(StorageProvider):
         if self.store_synchronous:
             # store_file is supposed to return an Awaitable, but guard
             # against improper implementations.
-            result = self.backend.store_file(path, file_info)
-            if inspect.isawaitable(result):
-                return await result
+            return await maybe_awaitable(self.backend.store_file(path, file_info))
         else:
             # TODO: Handle errors.
             async def store():
                 try:
-                    result = self.backend.store_file(path, file_info)
-                    if inspect.isawaitable(result):
-                        return await result
+                    return await maybe_awaitable(
+                        self.backend.store_file(path, file_info)
+                    )
                 except Exception:
                     logger.exception("Error storing file")
 
@@ -110,9 +108,7 @@ class StorageProviderWrapper(StorageProvider):
     async def fetch(self, path, file_info):
         # store_file is supposed to return an Awaitable, but guard
         # against improper implementations.
-        result = self.backend.fetch(path, file_info)
-        if inspect.isawaitable(result):
-            return await result
+        return await maybe_awaitable(self.backend.fetch(path, file_info))
 
 
 class FileStorageProviderBackend(StorageProvider):
diff --git a/synapse/server.py b/synapse/server.py
index 043810ad31..a198b0eb46 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -618,7 +618,7 @@ class HomeServer(metaclass=abc.ABCMeta):
         return StatsHandler(self)
 
     @cache_in_self
-    def get_spam_checker(self):
+    def get_spam_checker(self) -> SpamChecker:
         return SpamChecker(self)
 
     @cache_in_self
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 382f0cf3f0..9a873c8e8e 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -15,10 +15,12 @@
 # limitations under the License.
 
 import collections
+import inspect
 import logging
 from contextlib import contextmanager
 from typing import (
     Any,
+    Awaitable,
     Callable,
     Dict,
     Hashable,
@@ -542,11 +544,11 @@ class DoneAwaitable:
         raise StopIteration(self.value)
 
 
-def maybe_awaitable(value):
+def maybe_awaitable(value: Union[Awaitable[R], R]) -> Awaitable[R]:
     """Convert a value to an awaitable if not already an awaitable.
     """
-
-    if hasattr(value, "__await__"):
+    if inspect.isawaitable(value):
+        assert isinstance(value, Awaitable)
         return value
 
     return DoneAwaitable(value)
diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py
index f73e95393c..a6ee9edaec 100644
--- a/synapse/util/distributor.py
+++ b/synapse/util/distributor.py
@@ -12,13 +12,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import inspect
 import logging
 
 from twisted.internet import defer
 
 from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.util.async_helpers import maybe_awaitable
 
 logger = logging.getLogger(__name__)
 
@@ -105,10 +105,7 @@ class Signal:
 
         async def do(observer):
             try:
-                result = observer(*args, **kwargs)
-                if inspect.isawaitable(result):
-                    result = await result
-                return result
+                return await maybe_awaitable(observer(*args, **kwargs))
             except Exception as e:
                 logger.warning(
                     "%s signal observer %s failed: %r", self.name, observer, e,
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 98e5af2072..647a17cb90 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -270,7 +270,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
         spam_checker = self.hs.get_spam_checker()
 
         class AllowAll:
-            def check_username_for_spam(self, user_profile):
+            async def check_username_for_spam(self, user_profile):
                 # Allow all users.
                 return False
 
@@ -283,7 +283,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
 
         # Configure a spam checker that filters all users.
         class BlockAll:
-            def check_username_for_spam(self, user_profile):
+            async def check_username_for_spam(self, user_profile):
                 # All users are spammy.
                 return True