diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py
index 936896656a..e7e3a7b9a4 100644
--- a/synapse/events/spamcheck.py
+++ b/synapse/events/spamcheck.py
@@ -15,10 +15,11 @@
# limitations under the License.
import inspect
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from synapse.spam_checker_api import RegistrationBehaviour
from synapse.types import Collection
+from synapse.util.async_helpers import maybe_awaitable
if TYPE_CHECKING:
import synapse.events
@@ -39,7 +40,9 @@ class SpamChecker:
else:
self.spam_checkers.append(module(config=config))
- def check_event_for_spam(self, event: "synapse.events.EventBase") -> bool:
+ async def check_event_for_spam(
+ self, event: "synapse.events.EventBase"
+ ) -> Union[bool, str]:
"""Checks if a given event is considered "spammy" by this server.
If the server considers an event spammy, then it will be rejected if
@@ -50,15 +53,16 @@ class SpamChecker:
event: the event to be checked
Returns:
- True if the event is spammy.
+ True or a string if the event is spammy. If a string is returned it
+ will be used as the error message returned to the user.
"""
for spam_checker in self.spam_checkers:
- if spam_checker.check_event_for_spam(event):
+ if await maybe_awaitable(spam_checker.check_event_for_spam(event)):
return True
return False
- def user_may_invite(
+ async def user_may_invite(
self, inviter_userid: str, invitee_userid: str, room_id: str
) -> bool:
"""Checks if a given user may send an invite
@@ -75,14 +79,18 @@ class SpamChecker:
"""
for spam_checker in self.spam_checkers:
if (
- spam_checker.user_may_invite(inviter_userid, invitee_userid, room_id)
+ await maybe_awaitable(
+ spam_checker.user_may_invite(
+ inviter_userid, invitee_userid, room_id
+ )
+ )
is False
):
return False
return True
- def user_may_create_room(self, userid: str) -> bool:
+ async def user_may_create_room(self, userid: str) -> bool:
"""Checks if a given user may create a room
If this method returns false, the creation request will be rejected.
@@ -94,12 +102,15 @@ class SpamChecker:
True if the user may create a room, otherwise False
"""
for spam_checker in self.spam_checkers:
- if spam_checker.user_may_create_room(userid) is False:
+ if (
+ await maybe_awaitable(spam_checker.user_may_create_room(userid))
+ is False
+ ):
return False
return True
- def user_may_create_room_alias(self, userid: str, room_alias: str) -> bool:
+ async def user_may_create_room_alias(self, userid: str, room_alias: str) -> bool:
"""Checks if a given user may create a room alias
If this method returns false, the association request will be rejected.
@@ -112,12 +123,17 @@ class SpamChecker:
True if the user may create a room alias, otherwise False
"""
for spam_checker in self.spam_checkers:
- if spam_checker.user_may_create_room_alias(userid, room_alias) is False:
+ if (
+ await maybe_awaitable(
+ spam_checker.user_may_create_room_alias(userid, room_alias)
+ )
+ is False
+ ):
return False
return True
- def user_may_publish_room(self, userid: str, room_id: str) -> bool:
+ async def user_may_publish_room(self, userid: str, room_id: str) -> bool:
"""Checks if a given user may publish a room to the directory
If this method returns false, the publish request will be rejected.
@@ -130,12 +146,17 @@ class SpamChecker:
True if the user may publish the room, otherwise False
"""
for spam_checker in self.spam_checkers:
- if spam_checker.user_may_publish_room(userid, room_id) is False:
+ if (
+ await maybe_awaitable(
+ spam_checker.user_may_publish_room(userid, room_id)
+ )
+ is False
+ ):
return False
return True
- def check_username_for_spam(self, user_profile: Dict[str, str]) -> bool:
+ async def check_username_for_spam(self, user_profile: Dict[str, str]) -> bool:
"""Checks if a user ID or display name are considered "spammy" by this server.
If the server considers a username spammy, then it will not be included in
@@ -157,12 +178,12 @@ class SpamChecker:
if checker:
# Make a copy of the user profile object to ensure the spam checker
# cannot modify it.
- if checker(user_profile.copy()):
+ if await maybe_awaitable(checker(user_profile.copy())):
return True
return False
- def check_registration_for_spam(
+ async def check_registration_for_spam(
self,
email_threepid: Optional[dict],
username: Optional[str],
@@ -185,7 +206,9 @@ class SpamChecker:
# spam checker
checker = getattr(spam_checker, "check_registration_for_spam", None)
if checker:
- behaviour = checker(email_threepid, username, request_info)
+ behaviour = await maybe_awaitable(
+ checker(email_threepid, username, request_info)
+ )
assert isinstance(behaviour, RegistrationBehaviour)
if behaviour != RegistrationBehaviour.ALLOW:
return behaviour
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index 38aa47963f..383737520a 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -78,6 +78,7 @@ class FederationBase:
ctx = current_context()
+ @defer.inlineCallbacks
def callback(_, pdu: EventBase):
with PreserveLoggingContext(ctx):
if not check_event_content_hash(pdu):
@@ -105,7 +106,11 @@ class FederationBase:
)
return redacted_event
- if self.spam_checker.check_event_for_spam(pdu):
+ result = yield defer.ensureDeferred(
+ self.spam_checker.check_event_for_spam(pdu)
+ )
+
+ if result:
logger.warning(
"Event contains spam, redacting %s: %s",
pdu.event_id,
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 62f98dabc0..8deec4cd0c 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -14,7 +14,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import inspect
import logging
import time
import unicodedata
@@ -59,6 +58,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.module_api import ModuleApi
from synapse.types import JsonDict, Requester, UserID
from synapse.util import stringutils as stringutils
+from synapse.util.async_helpers import maybe_awaitable
from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.threepids import canonicalise_email
@@ -1639,6 +1639,6 @@ class PasswordProvider:
# This might return an awaitable, if it does block the log out
# until it completes.
- result = g(user_id=user_id, device_id=device_id, access_token=access_token,)
- if inspect.isawaitable(result):
- await result
+ await maybe_awaitable(
+ g(user_id=user_id, device_id=device_id, access_token=access_token,)
+ )
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index ad5683d251..abcf86352d 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -133,7 +133,9 @@ class DirectoryHandler(BaseHandler):
403, "You must be in the room to create an alias for it"
)
- if not self.spam_checker.user_may_create_room_alias(user_id, room_alias):
+ if not await self.spam_checker.user_may_create_room_alias(
+ user_id, room_alias
+ ):
raise AuthError(403, "This user is not permitted to create this alias")
if not self.config.is_alias_creation_allowed(
@@ -409,7 +411,7 @@ class DirectoryHandler(BaseHandler):
"""
user_id = requester.user.to_string()
- if not self.spam_checker.user_may_publish_room(user_id, room_id):
+ if not await self.spam_checker.user_may_publish_room(user_id, room_id):
raise AuthError(
403, "This user is not permitted to publish rooms to the room list"
)
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index df82e60b33..fd8de8696d 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -1593,7 +1593,7 @@ class FederationHandler(BaseHandler):
if self.hs.config.block_non_admin_invites:
raise SynapseError(403, "This server does not accept room invites")
- if not self.spam_checker.user_may_invite(
+ if not await self.spam_checker.user_may_invite(
event.sender, event.state_key, event.room_id
):
raise SynapseError(
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 96843338ae..2b8aa9443d 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -744,7 +744,7 @@ class EventCreationHandler:
event.sender,
)
- spam_error = self.spam_checker.check_event_for_spam(event)
+ spam_error = await self.spam_checker.check_event_for_spam(event)
if spam_error:
if not isinstance(spam_error, str):
spam_error = "Spam is not permitted here"
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index 153cbae7b9..e850e45e46 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -18,7 +18,6 @@ from typing import List, Tuple
from synapse.appservice import ApplicationService
from synapse.handlers._base import BaseHandler
from synapse.types import JsonDict, ReadReceipt, get_domain_from_id
-from synapse.util.async_helpers import maybe_awaitable
logger = logging.getLogger(__name__)
@@ -98,10 +97,8 @@ class ReceiptsHandler(BaseHandler):
self.notifier.on_new_event("receipt_key", max_batch_id, rooms=affected_room_ids)
# Note that the min here shouldn't be relied upon to be accurate.
- await maybe_awaitable(
- self.hs.get_pusherpool().on_new_receipts(
- min_batch_id, max_batch_id, affected_room_ids
- )
+ await self.hs.get_pusherpool().on_new_receipts(
+ min_batch_id, max_batch_id, affected_room_ids
)
return True
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 0d85fd0868..94b5610acd 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -187,7 +187,7 @@ class RegistrationHandler(BaseHandler):
"""
self.check_registration_ratelimit(address)
- result = self.spam_checker.check_registration_for_spam(
+ result = await self.spam_checker.check_registration_for_spam(
threepid, localpart, user_agent_ips or [],
)
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 82fb72b381..7583418946 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -358,7 +358,7 @@ class RoomCreationHandler(BaseHandler):
"""
user_id = requester.user.to_string()
- if not self.spam_checker.user_may_create_room(user_id):
+ if not await self.spam_checker.user_may_create_room(user_id):
raise SynapseError(403, "You are not permitted to create rooms")
creation_content = {
@@ -609,7 +609,7 @@ class RoomCreationHandler(BaseHandler):
403, "You are not permitted to create rooms", Codes.FORBIDDEN
)
- if not is_requester_admin and not self.spam_checker.user_may_create_room(
+ if not is_requester_admin and not await self.spam_checker.user_may_create_room(
user_id
):
raise SynapseError(403, "You are not permitted to create rooms")
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index d85110a35e..cb5a29bc7e 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -408,7 +408,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
)
block_invite = True
- if not self.spam_checker.user_may_invite(
+ if not await self.spam_checker.user_may_invite(
requester.user.to_string(), target.to_string(), room_id
):
logger.info("Blocking invite due to spam checker")
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index afbebfc200..f263a638f8 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -81,11 +81,11 @@ class UserDirectoryHandler(StateDeltasHandler):
results = await self.store.search_user_dir(user_id, search_term, limit)
# Remove any spammy users from the results.
- results["results"] = [
- user
- for user in results["results"]
- if not self.spam_checker.check_username_for_spam(user)
- ]
+ non_spammy_users = []
+ for user in results["results"]:
+ if not await self.spam_checker.check_username_for_spam(user):
+ non_spammy_users.append(user)
+ results["results"] = non_spammy_users
return results
diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py
index 658f6ecd72..76b7decf26 100644
--- a/synapse/metrics/background_process_metrics.py
+++ b/synapse/metrics/background_process_metrics.py
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import inspect
import logging
import threading
from functools import wraps
@@ -25,6 +24,7 @@ from twisted.internet import defer
from synapse.logging.context import LoggingContext, PreserveLoggingContext
from synapse.logging.opentracing import noop_context_manager, start_active_span
+from synapse.util.async_helpers import maybe_awaitable
if TYPE_CHECKING:
import resource
@@ -206,12 +206,7 @@ def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwar
if bg_start_span:
ctx = start_active_span(desc, tags={"request_id": context.request})
with ctx:
- result = func(*args, **kwargs)
-
- if inspect.isawaitable(result):
- result = await result
-
- return result
+ return await maybe_awaitable(func(*args, **kwargs))
except Exception:
logger.exception(
"Background process '%s' threw an exception", desc,
diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py
index 18c9ed48d6..67f67efde7 100644
--- a/synapse/rest/media/v1/storage_provider.py
+++ b/synapse/rest/media/v1/storage_provider.py
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import inspect
import logging
import os
import shutil
@@ -21,6 +20,7 @@ from typing import Optional
from synapse.config._base import Config
from synapse.logging.context import defer_to_thread, run_in_background
+from synapse.util.async_helpers import maybe_awaitable
from ._base import FileInfo, Responder
from .media_storage import FileResponder
@@ -91,16 +91,14 @@ class StorageProviderWrapper(StorageProvider):
if self.store_synchronous:
# store_file is supposed to return an Awaitable, but guard
# against improper implementations.
- result = self.backend.store_file(path, file_info)
- if inspect.isawaitable(result):
- return await result
+ return await maybe_awaitable(self.backend.store_file(path, file_info))
else:
# TODO: Handle errors.
async def store():
try:
- result = self.backend.store_file(path, file_info)
- if inspect.isawaitable(result):
- return await result
+ return await maybe_awaitable(
+ self.backend.store_file(path, file_info)
+ )
except Exception:
logger.exception("Error storing file")
@@ -110,9 +108,7 @@ class StorageProviderWrapper(StorageProvider):
async def fetch(self, path, file_info):
# store_file is supposed to return an Awaitable, but guard
# against improper implementations.
- result = self.backend.fetch(path, file_info)
- if inspect.isawaitable(result):
- return await result
+ return await maybe_awaitable(self.backend.fetch(path, file_info))
class FileStorageProviderBackend(StorageProvider):
diff --git a/synapse/server.py b/synapse/server.py
index 043810ad31..a198b0eb46 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -618,7 +618,7 @@ class HomeServer(metaclass=abc.ABCMeta):
return StatsHandler(self)
@cache_in_self
- def get_spam_checker(self):
+ def get_spam_checker(self) -> SpamChecker:
return SpamChecker(self)
@cache_in_self
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 382f0cf3f0..9a873c8e8e 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -15,10 +15,12 @@
# limitations under the License.
import collections
+import inspect
import logging
from contextlib import contextmanager
from typing import (
Any,
+ Awaitable,
Callable,
Dict,
Hashable,
@@ -542,11 +544,11 @@ class DoneAwaitable:
raise StopIteration(self.value)
-def maybe_awaitable(value):
+def maybe_awaitable(value: Union[Awaitable[R], R]) -> Awaitable[R]:
"""Convert a value to an awaitable if not already an awaitable.
"""
-
- if hasattr(value, "__await__"):
+ if inspect.isawaitable(value):
+ assert isinstance(value, Awaitable)
return value
return DoneAwaitable(value)
diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py
index f73e95393c..a6ee9edaec 100644
--- a/synapse/util/distributor.py
+++ b/synapse/util/distributor.py
@@ -12,13 +12,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import inspect
import logging
from twisted.internet import defer
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.util.async_helpers import maybe_awaitable
logger = logging.getLogger(__name__)
@@ -105,10 +105,7 @@ class Signal:
async def do(observer):
try:
- result = observer(*args, **kwargs)
- if inspect.isawaitable(result):
- result = await result
- return result
+ return await maybe_awaitable(observer(*args, **kwargs))
except Exception as e:
logger.warning(
"%s signal observer %s failed: %r", self.name, observer, e,
|