summary refs log tree commit diff
path: root/synapse/events/spamcheck.py
diff options
context:
space:
mode:
authorDavid Teller <D.O.Teller@gmail.com>2020-12-11 20:05:15 +0100
committerGitHub <noreply@github.com>2020-12-11 14:05:15 -0500
commitf14428b25c37e44675edac4a80d7bd1e47112586 (patch)
tree7565992e70db2c48c7008b2e3fdfe122d315308e /synapse/events/spamcheck.py
parentAdd type hints to the push module. (#8901) (diff)
downloadsynapse-f14428b25c37e44675edac4a80d7bd1e47112586.tar.xz
Allow spam-checker modules to be provide async methods. (#8890)
Spam checker modules can now provide async methods. This is implemented
in a backwards-compatible manner.
Diffstat (limited to 'synapse/events/spamcheck.py')
-rw-r--r--synapse/events/spamcheck.py55
1 files changed, 39 insertions, 16 deletions
diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py
index 936896656a..e7e3a7b9a4 100644
--- a/synapse/events/spamcheck.py
+++ b/synapse/events/spamcheck.py
@@ -15,10 +15,11 @@
 # limitations under the License.
 
 import inspect
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
 
 from synapse.spam_checker_api import RegistrationBehaviour
 from synapse.types import Collection
+from synapse.util.async_helpers import maybe_awaitable
 
 if TYPE_CHECKING:
     import synapse.events
@@ -39,7 +40,9 @@ class SpamChecker:
             else:
                 self.spam_checkers.append(module(config=config))
 
-    def check_event_for_spam(self, event: "synapse.events.EventBase") -> bool:
+    async def check_event_for_spam(
+        self, event: "synapse.events.EventBase"
+    ) -> Union[bool, str]:
         """Checks if a given event is considered "spammy" by this server.
 
         If the server considers an event spammy, then it will be rejected if
@@ -50,15 +53,16 @@ class SpamChecker:
             event: the event to be checked
 
         Returns:
-            True if the event is spammy.
+            True or a string if the event is spammy. If a string is returned it
+            will be used as the error message returned to the user.
         """
         for spam_checker in self.spam_checkers:
-            if spam_checker.check_event_for_spam(event):
+            if await maybe_awaitable(spam_checker.check_event_for_spam(event)):
                 return True
 
         return False
 
-    def user_may_invite(
+    async def user_may_invite(
         self, inviter_userid: str, invitee_userid: str, room_id: str
     ) -> bool:
         """Checks if a given user may send an invite
@@ -75,14 +79,18 @@ class SpamChecker:
         """
         for spam_checker in self.spam_checkers:
             if (
-                spam_checker.user_may_invite(inviter_userid, invitee_userid, room_id)
+                await maybe_awaitable(
+                    spam_checker.user_may_invite(
+                        inviter_userid, invitee_userid, room_id
+                    )
+                )
                 is False
             ):
                 return False
 
         return True
 
-    def user_may_create_room(self, userid: str) -> bool:
+    async def user_may_create_room(self, userid: str) -> bool:
         """Checks if a given user may create a room
 
         If this method returns false, the creation request will be rejected.
@@ -94,12 +102,15 @@ class SpamChecker:
             True if the user may create a room, otherwise False
         """
         for spam_checker in self.spam_checkers:
-            if spam_checker.user_may_create_room(userid) is False:
+            if (
+                await maybe_awaitable(spam_checker.user_may_create_room(userid))
+                is False
+            ):
                 return False
 
         return True
 
-    def user_may_create_room_alias(self, userid: str, room_alias: str) -> bool:
+    async def user_may_create_room_alias(self, userid: str, room_alias: str) -> bool:
         """Checks if a given user may create a room alias
 
         If this method returns false, the association request will be rejected.
@@ -112,12 +123,17 @@ class SpamChecker:
             True if the user may create a room alias, otherwise False
         """
         for spam_checker in self.spam_checkers:
-            if spam_checker.user_may_create_room_alias(userid, room_alias) is False:
+            if (
+                await maybe_awaitable(
+                    spam_checker.user_may_create_room_alias(userid, room_alias)
+                )
+                is False
+            ):
                 return False
 
         return True
 
-    def user_may_publish_room(self, userid: str, room_id: str) -> bool:
+    async def user_may_publish_room(self, userid: str, room_id: str) -> bool:
         """Checks if a given user may publish a room to the directory
 
         If this method returns false, the publish request will be rejected.
@@ -130,12 +146,17 @@ class SpamChecker:
             True if the user may publish the room, otherwise False
         """
         for spam_checker in self.spam_checkers:
-            if spam_checker.user_may_publish_room(userid, room_id) is False:
+            if (
+                await maybe_awaitable(
+                    spam_checker.user_may_publish_room(userid, room_id)
+                )
+                is False
+            ):
                 return False
 
         return True
 
-    def check_username_for_spam(self, user_profile: Dict[str, str]) -> bool:
+    async def check_username_for_spam(self, user_profile: Dict[str, str]) -> bool:
         """Checks if a user ID or display name are considered "spammy" by this server.
 
         If the server considers a username spammy, then it will not be included in
@@ -157,12 +178,12 @@ class SpamChecker:
             if checker:
                 # Make a copy of the user profile object to ensure the spam checker
                 # cannot modify it.
-                if checker(user_profile.copy()):
+                if await maybe_awaitable(checker(user_profile.copy())):
                     return True
 
         return False
 
-    def check_registration_for_spam(
+    async def check_registration_for_spam(
         self,
         email_threepid: Optional[dict],
         username: Optional[str],
@@ -185,7 +206,9 @@ class SpamChecker:
             # spam checker
             checker = getattr(spam_checker, "check_registration_for_spam", None)
             if checker:
-                behaviour = checker(email_threepid, username, request_info)
+                behaviour = await maybe_awaitable(
+                    checker(email_threepid, username, request_info)
+                )
                 assert isinstance(behaviour, RegistrationBehaviour)
                 if behaviour != RegistrationBehaviour.ALLOW:
                     return behaviour