diff --git a/changelog.d/8031.misc b/changelog.d/8031.misc
new file mode 100644
index 0000000000..dfe4c03171
--- /dev/null
+++ b/changelog.d/8031.misc
@@ -0,0 +1 @@
+Convert various parts of the codebase to async/await.
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 2178e623da..d8190f92ab 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -13,12 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import Optional
+from typing import List, Optional, Tuple
import pymacaroons
from netaddr import IPAddress
-from twisted.internet import defer
from twisted.web.server import Request
import synapse.types
@@ -80,13 +79,14 @@ class Auth(object):
self._track_appservice_user_ips = hs.config.track_appservice_user_ips
self._macaroon_secret_key = hs.config.macaroon_secret_key
- @defer.inlineCallbacks
- def check_from_context(self, room_version: str, event, context, do_sig_check=True):
- prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
- auth_events_ids = yield self.compute_auth_events(
+ async def check_from_context(
+ self, room_version: str, event, context, do_sig_check=True
+ ):
+ prev_state_ids = await context.get_prev_state_ids()
+ auth_events_ids = self.compute_auth_events(
event, prev_state_ids, for_verification=True
)
- auth_events = yield self.store.get_events(auth_events_ids)
+ auth_events = await self.store.get_events(auth_events_ids)
auth_events = {(e.type, e.state_key): e for e in auth_events.values()}
room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
@@ -94,14 +94,13 @@ class Auth(object):
room_version_obj, event, auth_events=auth_events, do_sig_check=do_sig_check
)
- @defer.inlineCallbacks
- def check_user_in_room(
+ async def check_user_in_room(
self,
room_id: str,
user_id: str,
current_state: Optional[StateMap[EventBase]] = None,
allow_departed_users: bool = False,
- ):
+ ) -> EventBase:
"""Check if the user is in the room, or was at some point.
Args:
room_id: The room to check.
@@ -119,37 +118,35 @@ class Auth(object):
Raises:
AuthError if the user is/was not in the room.
Returns:
- Deferred[Optional[EventBase]]:
- Membership event for the user if the user was in the
- room. This will be the join event if they are currently joined to
- the room. This will be the leave event if they have left the room.
+ Membership event for the user if the user was in the
+ room. This will be the join event if they are currently joined to
+ the room. This will be the leave event if they have left the room.
"""
if current_state:
member = current_state.get((EventTypes.Member, user_id), None)
else:
- member = yield defer.ensureDeferred(
- self.state.get_current_state(
- room_id=room_id, event_type=EventTypes.Member, state_key=user_id
- )
+ member = await self.state.get_current_state(
+ room_id=room_id, event_type=EventTypes.Member, state_key=user_id
)
- membership = member.membership if member else None
- if membership == Membership.JOIN:
- return member
+ if member:
+ membership = member.membership
- # XXX this looks totally bogus. Why do we not allow users who have been banned,
- # or those who were members previously and have been re-invited?
- if allow_departed_users and membership == Membership.LEAVE:
- forgot = yield self.store.did_forget(user_id, room_id)
- if not forgot:
+ if membership == Membership.JOIN:
return member
+ # XXX this looks totally bogus. Why do we not allow users who have been banned,
+ # or those who were members previously and have been re-invited?
+ if allow_departed_users and membership == Membership.LEAVE:
+ forgot = await self.store.did_forget(user_id, room_id)
+ if not forgot:
+ return member
+
raise AuthError(403, "User %s not in room %s" % (user_id, room_id))
- @defer.inlineCallbacks
- def check_host_in_room(self, room_id, host):
+ async def check_host_in_room(self, room_id, host):
with Measure(self.clock, "check_host_in_room"):
- latest_event_ids = yield self.store.is_host_joined(room_id, host)
+ latest_event_ids = await self.store.is_host_joined(room_id, host)
return latest_event_ids
def can_federate(self, event, auth_events):
@@ -160,14 +157,13 @@ class Auth(object):
def get_public_keys(self, invite_event):
return event_auth.get_public_keys(invite_event)
- @defer.inlineCallbacks
- def get_user_by_req(
+ async def get_user_by_req(
self,
request: Request,
allow_guest: bool = False,
rights: str = "access",
allow_expired: bool = False,
- ):
+ ) -> synapse.types.Requester:
""" Get a registered user's ID.
Args:
@@ -180,7 +176,7 @@ class Auth(object):
/login will deliver access tokens regardless of expiration.
Returns:
- defer.Deferred: resolves to a `synapse.types.Requester` object
+ Resolves to the requester
Raises:
InvalidClientCredentialsError if no user by that token exists or the token
is invalid.
@@ -194,14 +190,14 @@ class Auth(object):
access_token = self.get_access_token_from_request(request)
- user_id, app_service = yield self._get_appservice_user_id(request)
+ user_id, app_service = await self._get_appservice_user_id(request)
if user_id:
request.authenticated_entity = user_id
opentracing.set_tag("authenticated_entity", user_id)
opentracing.set_tag("appservice_id", app_service.id)
if ip_addr and self._track_appservice_user_ips:
- yield self.store.insert_client_ip(
+ await self.store.insert_client_ip(
user_id=user_id,
access_token=access_token,
ip=ip_addr,
@@ -211,7 +207,7 @@ class Auth(object):
return synapse.types.create_requester(user_id, app_service=app_service)
- user_info = yield self.get_user_by_access_token(
+ user_info = await self.get_user_by_access_token(
access_token, rights, allow_expired=allow_expired
)
user = user_info["user"]
@@ -221,7 +217,7 @@ class Auth(object):
# Deny the request if the user account has expired.
if self._account_validity.enabled and not allow_expired:
user_id = user.to_string()
- expiration_ts = yield self.store.get_expiration_ts_for_user(user_id)
+ expiration_ts = await self.store.get_expiration_ts_for_user(user_id)
if (
expiration_ts is not None
and self.clock.time_msec() >= expiration_ts
@@ -235,7 +231,7 @@ class Auth(object):
device_id = user_info.get("device_id")
if user and access_token and ip_addr:
- yield self.store.insert_client_ip(
+ await self.store.insert_client_ip(
user_id=user.to_string(),
access_token=access_token,
ip=ip_addr,
@@ -261,8 +257,7 @@ class Auth(object):
except KeyError:
raise MissingClientTokenError()
- @defer.inlineCallbacks
- def _get_appservice_user_id(self, request):
+ async def _get_appservice_user_id(self, request):
app_service = self.store.get_app_service_by_token(
self.get_access_token_from_request(request)
)
@@ -283,14 +278,13 @@ class Auth(object):
if not app_service.is_interested_in_user(user_id):
raise AuthError(403, "Application service cannot masquerade as this user.")
- if not (yield self.store.get_user_by_id(user_id)):
+ if not (await self.store.get_user_by_id(user_id)):
raise AuthError(403, "Application service has not registered this user")
return user_id, app_service
- @defer.inlineCallbacks
- def get_user_by_access_token(
+ async def get_user_by_access_token(
self, token: str, rights: str = "access", allow_expired: bool = False,
- ):
+ ) -> dict:
""" Validate access token and get user_id from it
Args:
@@ -300,7 +294,7 @@ class Auth(object):
allow_expired: If False, raises an InvalidClientTokenError
if the token is expired
Returns:
- Deferred[dict]: dict that includes:
+ dict that includes:
`user` (UserID)
`is_guest` (bool)
`token_id` (int|None): access token id. May be None if guest
@@ -314,7 +308,7 @@ class Auth(object):
if rights == "access":
# first look in the database
- r = yield self._look_up_user_by_access_token(token)
+ r = await self._look_up_user_by_access_token(token)
if r:
valid_until_ms = r["valid_until_ms"]
if (
@@ -352,7 +346,7 @@ class Auth(object):
# It would of course be much easier to store guest access
# tokens in the database as well, but that would break existing
# guest tokens.
- stored_user = yield self.store.get_user_by_id(user_id)
+ stored_user = await self.store.get_user_by_id(user_id)
if not stored_user:
raise InvalidClientTokenError("Unknown user_id %s" % user_id)
if not stored_user["is_guest"]:
@@ -482,9 +476,8 @@ class Auth(object):
now = self.hs.get_clock().time_msec()
return now < expiry
- @defer.inlineCallbacks
- def _look_up_user_by_access_token(self, token):
- ret = yield self.store.get_user_by_access_token(token)
+ async def _look_up_user_by_access_token(self, token):
+ ret = await self.store.get_user_by_access_token(token)
if not ret:
return None
@@ -507,7 +500,7 @@ class Auth(object):
logger.warning("Unrecognised appservice access token.")
raise InvalidClientTokenError()
request.authenticated_entity = service.sender
- return defer.succeed(service)
+ return service
async def is_server_admin(self, user: UserID) -> bool:
""" Check if the given user is a local server admin.
@@ -522,7 +515,7 @@ class Auth(object):
def compute_auth_events(
self, event, current_state_ids: StateMap[str], for_verification: bool = False,
- ):
+ ) -> List[str]:
"""Given an event and current state return the list of event IDs used
to auth an event.
@@ -530,11 +523,11 @@ class Auth(object):
should be added to the event's `auth_events`.
Returns:
- defer.Deferred(list[str]): List of event IDs.
+ List of event IDs.
"""
if event.type == EventTypes.Create:
- return defer.succeed([])
+ return []
# Currently we ignore the `for_verification` flag even though there are
# some situations where we can drop particular auth events when adding
@@ -553,7 +546,7 @@ class Auth(object):
if auth_ev_id:
auth_ids.append(auth_ev_id)
- return defer.succeed(auth_ids)
+ return auth_ids
async def check_can_change_room_list(self, room_id: str, user: UserID):
"""Determine whether the user is allowed to edit the room's entry in the
@@ -636,10 +629,9 @@ class Auth(object):
return query_params[0].decode("ascii")
- @defer.inlineCallbacks
- def check_user_in_room_or_world_readable(
+ async def check_user_in_room_or_world_readable(
self, room_id: str, user_id: str, allow_departed_users: bool = False
- ):
+ ) -> Tuple[str, Optional[str]]:
"""Checks that the user is or was in the room or the room is world
readable. If it isn't then an exception is raised.
@@ -650,10 +642,9 @@ class Auth(object):
members but have now departed
Returns:
- Deferred[tuple[str, str|None]]: Resolves to the current membership of
- the user in the room and the membership event ID of the user. If
- the user is not in the room and never has been, then
- `(Membership.JOIN, None)` is returned.
+ Resolves to the current membership of the user in the room and the
+ membership event ID of the user. If the user is not in the room and
+ never has been, then `(Membership.JOIN, None)` is returned.
"""
try:
@@ -662,15 +653,13 @@ class Auth(object):
# * The user is a non-guest user, and was ever in the room
# * The user is a guest user, and has joined the room
# else it will throw.
- member_event = yield self.check_user_in_room(
+ member_event = await self.check_user_in_room(
room_id, user_id, allow_departed_users=allow_departed_users
)
return member_event.membership, member_event.event_id
except AuthError:
- visibility = yield defer.ensureDeferred(
- self.state.get_current_state(
- room_id, EventTypes.RoomHistoryVisibility, ""
- )
+ visibility = await self.state.get_current_state(
+ room_id, EventTypes.RoomHistoryVisibility, ""
)
if (
visibility
diff --git a/synapse/api/auth_blocking.py b/synapse/api/auth_blocking.py
index 5c499b6b4e..49093bf181 100644
--- a/synapse/api/auth_blocking.py
+++ b/synapse/api/auth_blocking.py
@@ -15,8 +15,6 @@
import logging
-from twisted.internet import defer
-
from synapse.api.constants import LimitBlockingTypes, UserTypes
from synapse.api.errors import Codes, ResourceLimitError
from synapse.config.server import is_threepid_reserved
@@ -36,8 +34,7 @@ class AuthBlocking(object):
self._limit_usage_by_mau = hs.config.limit_usage_by_mau
self._mau_limits_reserved_threepids = hs.config.mau_limits_reserved_threepids
- @defer.inlineCallbacks
- def check_auth_blocking(self, user_id=None, threepid=None, user_type=None):
+ async def check_auth_blocking(self, user_id=None, threepid=None, user_type=None):
"""Checks if the user should be rejected for some external reason,
such as monthly active user limiting or global disable flag
@@ -60,7 +57,7 @@ class AuthBlocking(object):
if user_id is not None:
if user_id == self._server_notices_mxid:
return
- if (yield self.store.is_support_user(user_id)):
+ if await self.store.is_support_user(user_id):
return
if self._hs_disabled:
@@ -76,11 +73,11 @@ class AuthBlocking(object):
# If the user is already part of the MAU cohort or a trial user
if user_id:
- timestamp = yield self.store.user_last_seen_monthly_active(user_id)
+ timestamp = await self.store.user_last_seen_monthly_active(user_id)
if timestamp:
return
- is_trial = yield self.store.is_trial_user(user_id)
+ is_trial = await self.store.is_trial_user(user_id)
if is_trial:
return
elif threepid:
@@ -93,7 +90,7 @@ class AuthBlocking(object):
# allow registration. Support users are excluded from MAU checks.
return
# Else if there is no room in the MAU bucket, bail
- current_mau = yield self.store.get_monthly_active_count()
+ current_mau = await self.store.get_monthly_active_count()
if current_mau >= self._max_mau_value:
raise ResourceLimitError(
403,
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index f988f62a1e..7393d6cb74 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -21,8 +21,6 @@ import jsonschema
from canonicaljson import json
from jsonschema import FormatChecker
-from twisted.internet import defer
-
from synapse.api.constants import EventContentFields
from synapse.api.errors import SynapseError
from synapse.storage.presence import UserPresenceState
@@ -137,9 +135,8 @@ class Filtering(object):
super(Filtering, self).__init__()
self.store = hs.get_datastore()
- @defer.inlineCallbacks
- def get_user_filter(self, user_localpart, filter_id):
- result = yield self.store.get_user_filter(user_localpart, filter_id)
+ async def get_user_filter(self, user_localpart, filter_id):
+ result = await self.store.get_user_filter(user_localpart, filter_id)
return FilterCollection(result)
def add_user_filter(self, user_localpart, user_filter):
diff --git a/synapse/events/builder.py b/synapse/events/builder.py
index 69b53ca2bc..4e179d49b3 100644
--- a/synapse/events/builder.py
+++ b/synapse/events/builder.py
@@ -106,7 +106,7 @@ class EventBuilder(object):
state_ids = await self._state.get_current_state_ids(
self.room_id, prev_event_ids
)
- auth_ids = await self._auth.compute_auth_events(self, state_ids)
+ auth_ids = self._auth.compute_auth_events(self, state_ids)
format_version = self.room_version.event_format
if format_version == EventFormatVersions.V1:
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index b3764dedae..593932adb7 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -2064,7 +2064,7 @@ class FederationHandler(BaseHandler):
if not auth_events:
prev_state_ids = await context.get_prev_state_ids()
- auth_events_ids = await self.auth.compute_auth_events(
+ auth_events_ids = self.auth.compute_auth_events(
event, prev_state_ids, for_verification=True
)
auth_events_x = await self.store.get_events(auth_events_ids)
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 43901d0934..708533d4d1 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -1061,7 +1061,7 @@ class EventCreationHandler(object):
raise SynapseError(400, "Cannot redact event from a different room")
prev_state_ids = await context.get_prev_state_ids()
- auth_events_ids = await self.auth.compute_auth_events(
+ auth_events_ids = self.auth.compute_auth_events(
event, prev_state_ids, for_verification=True
)
auth_events = await self.store.get_events(auth_events_ids)
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 8201849951..c2fb757d9a 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -194,12 +194,16 @@ class ModuleApi(object):
synapse.api.errors.AuthError: the access token is invalid
"""
# see if the access token corresponds to a device
- user_info = yield self._auth.get_user_by_access_token(access_token)
+ user_info = yield defer.ensureDeferred(
+ self._auth.get_user_by_access_token(access_token)
+ )
device_id = user_info.get("device_id")
user_id = user_info["user"].to_string()
if device_id:
# delete the device, which will also delete its access tokens
- yield self._hs.get_device_handler().delete_device(user_id, device_id)
+ yield defer.ensureDeferred(
+ self._hs.get_device_handler().delete_device(user_id, device_id)
+ )
else:
# no associated device. Just delete the access token.
yield defer.ensureDeferred(
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 04b9d8ac82..e7fcee0e87 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -120,7 +120,7 @@ class BulkPushRuleEvaluator(object):
pl_event = await self.store.get_event(pl_event_id)
auth_events = {POWER_KEY: pl_event}
else:
- auth_events_ids = await self.auth.compute_auth_events(
+ auth_events_ids = self.auth.compute_auth_events(
event, prev_state_ids, for_verification=False
)
auth_events = await self.store.get_events(auth_events_ids)
diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py
index 60dd3f6701..a6fdedde63 100644
--- a/synapse/replication/slave/storage/client_ips.py
+++ b/synapse/replication/slave/storage/client_ips.py
@@ -28,7 +28,7 @@ class SlavedClientIpStore(BaseSlavedStore):
name="client_ip_last_seen", keylen=4, max_entries=50000
)
- def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id):
+ async def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id):
now = int(self._clock.time_msec())
key = (user_id, access_token, ip)
diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py
index 5934b1fe8b..b210015173 100644
--- a/synapse/rest/client/v1/directory.py
+++ b/synapse/rest/client/v1/directory.py
@@ -89,7 +89,7 @@ class ClientDirectoryServer(RestServlet):
dir_handler = self.handlers.directory_handler
try:
- service = await self.auth.get_appservice_by_req(request)
+ service = self.auth.get_appservice_by_req(request)
room_alias = RoomAlias.from_string(room_alias)
await dir_handler.delete_appservice_association(service, room_alias)
logger.info(
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index a4c079196d..c549c090b3 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -424,7 +424,7 @@ class RegisterRestServlet(RestServlet):
appservice = None
if self.auth.has_access_token(request):
- appservice = await self.auth.get_appservice_by_req(request)
+ appservice = self.auth.get_appservice_by_req(request)
# fork off as soon as possible for ASes which have completely
# different registration flows to normal users
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index 712c8d0264..50d71f5ebc 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -380,8 +380,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
if self.user_ips_max_age:
self._clock.looping_call(self._prune_old_user_ips, 5 * 1000)
- @defer.inlineCallbacks
- def insert_client_ip(
+ async def insert_client_ip(
self, user_id, access_token, ip, user_agent, device_id, now=None
):
if not now:
@@ -392,7 +391,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
last_seen = self.client_ip_last_seen.get(key)
except KeyError:
last_seen = None
- yield self.populate_monthly_active_users(user_id)
+ await self.populate_monthly_active_users(user_id)
# Rate-limited inserts
if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
return
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index 0bfb86bf1f..5d45689c8c 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -62,12 +62,15 @@ class AuthTestCase(unittest.TestCase):
# this is overridden for the appservice tests
self.store.get_app_service_by_token = Mock(return_value=None)
+ self.store.insert_client_ip = Mock(return_value=defer.succeed(None))
self.store.is_support_user = Mock(return_value=defer.succeed(False))
@defer.inlineCallbacks
def test_get_user_by_req_user_valid_token(self):
user_info = {"name": self.test_user, "token_id": "ditto", "device_id": "device"}
- self.store.get_user_by_access_token = Mock(return_value=user_info)
+ self.store.get_user_by_access_token = Mock(
+ return_value=defer.succeed(user_info)
+ )
request = Mock(args={})
request.args[b"access_token"] = [self.test_token]
@@ -76,23 +79,25 @@ class AuthTestCase(unittest.TestCase):
self.assertEquals(requester.user.to_string(), self.test_user)
def test_get_user_by_req_user_bad_token(self):
- self.store.get_user_by_access_token = Mock(return_value=None)
+ self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
request = Mock(args={})
request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
- d = self.auth.get_user_by_req(request)
+ d = defer.ensureDeferred(self.auth.get_user_by_req(request))
f = self.failureResultOf(d, InvalidClientTokenError).value
self.assertEqual(f.code, 401)
self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
def test_get_user_by_req_user_missing_token(self):
user_info = {"name": self.test_user, "token_id": "ditto"}
- self.store.get_user_by_access_token = Mock(return_value=user_info)
+ self.store.get_user_by_access_token = Mock(
+ return_value=defer.succeed(user_info)
+ )
request = Mock(args={})
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
- d = self.auth.get_user_by_req(request)
+ d = defer.ensureDeferred(self.auth.get_user_by_req(request))
f = self.failureResultOf(d, MissingClientTokenError).value
self.assertEqual(f.code, 401)
self.assertEqual(f.errcode, "M_MISSING_TOKEN")
@@ -103,7 +108,7 @@ class AuthTestCase(unittest.TestCase):
token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None
)
self.store.get_app_service_by_token = Mock(return_value=app_service)
- self.store.get_user_by_access_token = Mock(return_value=None)
+ self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
request = Mock(args={})
request.getClientIP.return_value = "127.0.0.1"
@@ -123,7 +128,7 @@ class AuthTestCase(unittest.TestCase):
ip_range_whitelist=IPSet(["192.168/16"]),
)
self.store.get_app_service_by_token = Mock(return_value=app_service)
- self.store.get_user_by_access_token = Mock(return_value=None)
+ self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
request = Mock(args={})
request.getClientIP.return_value = "192.168.10.10"
@@ -142,25 +147,25 @@ class AuthTestCase(unittest.TestCase):
ip_range_whitelist=IPSet(["192.168/16"]),
)
self.store.get_app_service_by_token = Mock(return_value=app_service)
- self.store.get_user_by_access_token = Mock(return_value=None)
+ self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
request = Mock(args={})
request.getClientIP.return_value = "131.111.8.42"
request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
- d = self.auth.get_user_by_req(request)
+ d = defer.ensureDeferred(self.auth.get_user_by_req(request))
f = self.failureResultOf(d, InvalidClientTokenError).value
self.assertEqual(f.code, 401)
self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
def test_get_user_by_req_appservice_bad_token(self):
self.store.get_app_service_by_token = Mock(return_value=None)
- self.store.get_user_by_access_token = Mock(return_value=None)
+ self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
request = Mock(args={})
request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
- d = self.auth.get_user_by_req(request)
+ d = defer.ensureDeferred(self.auth.get_user_by_req(request))
f = self.failureResultOf(d, InvalidClientTokenError).value
self.assertEqual(f.code, 401)
self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
@@ -168,11 +173,11 @@ class AuthTestCase(unittest.TestCase):
def test_get_user_by_req_appservice_missing_token(self):
app_service = Mock(token="foobar", url="a_url", sender=self.test_user)
self.store.get_app_service_by_token = Mock(return_value=app_service)
- self.store.get_user_by_access_token = Mock(return_value=None)
+ self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
request = Mock(args={})
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
- d = self.auth.get_user_by_req(request)
+ d = defer.ensureDeferred(self.auth.get_user_by_req(request))
f = self.failureResultOf(d, MissingClientTokenError).value
self.assertEqual(f.code, 401)
self.assertEqual(f.errcode, "M_MISSING_TOKEN")
@@ -185,7 +190,11 @@ class AuthTestCase(unittest.TestCase):
)
app_service.is_interested_in_user = Mock(return_value=True)
self.store.get_app_service_by_token = Mock(return_value=app_service)
- self.store.get_user_by_access_token = Mock(return_value=None)
+ # This just needs to return a truth-y value.
+ self.store.get_user_by_id = Mock(
+ return_value=defer.succeed({"is_guest": False})
+ )
+ self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
request = Mock(args={})
request.getClientIP.return_value = "127.0.0.1"
@@ -204,20 +213,22 @@ class AuthTestCase(unittest.TestCase):
)
app_service.is_interested_in_user = Mock(return_value=False)
self.store.get_app_service_by_token = Mock(return_value=app_service)
- self.store.get_user_by_access_token = Mock(return_value=None)
+ self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
request = Mock(args={})
request.getClientIP.return_value = "127.0.0.1"
request.args[b"access_token"] = [self.test_token]
request.args[b"user_id"] = [masquerading_user_id]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
- d = self.auth.get_user_by_req(request)
+ d = defer.ensureDeferred(self.auth.get_user_by_req(request))
self.failureResultOf(d, AuthError)
@defer.inlineCallbacks
def test_get_user_from_macaroon(self):
self.store.get_user_by_access_token = Mock(
- return_value={"name": "@baldrick:matrix.org", "device_id": "device"}
+ return_value=defer.succeed(
+ {"name": "@baldrick:matrix.org", "device_id": "device"}
+ )
)
user_id = "@baldrick:matrix.org"
@@ -241,8 +252,8 @@ class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_get_guest_user_from_macaroon(self):
- self.store.get_user_by_id = Mock(return_value={"is_guest": True})
- self.store.get_user_by_access_token = Mock(return_value=None)
+ self.store.get_user_by_id = Mock(return_value=defer.succeed({"is_guest": True}))
+ self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
user_id = "@baldrick:matrix.org"
macaroon = pymacaroons.Macaroon(
@@ -282,16 +293,20 @@ class AuthTestCase(unittest.TestCase):
def get_user(tok):
if token != tok:
- return None
- return {
- "name": USER_ID,
- "is_guest": False,
- "token_id": 1234,
- "device_id": "DEVICE",
- }
+ return defer.succeed(None)
+ return defer.succeed(
+ {
+ "name": USER_ID,
+ "is_guest": False,
+ "token_id": 1234,
+ "device_id": "DEVICE",
+ }
+ )
self.store.get_user_by_access_token = get_user
- self.store.get_user_by_id = Mock(return_value={"is_guest": False})
+ self.store.get_user_by_id = Mock(
+ return_value=defer.succeed({"is_guest": False})
+ )
# check the token works
request = Mock(args={})
diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py
index 4e67503cf0..1fab1d6b69 100644
--- a/tests/api/test_filtering.py
+++ b/tests/api/test_filtering.py
@@ -375,8 +375,10 @@ class FilteringTestCase(unittest.TestCase):
event = MockEvent(sender="@foo:bar", type="m.profile")
events = [event]
- user_filter = yield self.filtering.get_user_filter(
- user_localpart=user_localpart, filter_id=filter_id
+ user_filter = yield defer.ensureDeferred(
+ self.filtering.get_user_filter(
+ user_localpart=user_localpart, filter_id=filter_id
+ )
)
results = user_filter.filter_presence(events=events)
@@ -396,8 +398,10 @@ class FilteringTestCase(unittest.TestCase):
)
events = [event]
- user_filter = yield self.filtering.get_user_filter(
- user_localpart=user_localpart + "2", filter_id=filter_id
+ user_filter = yield defer.ensureDeferred(
+ self.filtering.get_user_filter(
+ user_localpart=user_localpart + "2", filter_id=filter_id
+ )
)
results = user_filter.filter_presence(events=events)
@@ -412,8 +416,10 @@ class FilteringTestCase(unittest.TestCase):
event = MockEvent(sender="@foo:bar", type="m.room.topic", room_id="!foo:bar")
events = [event]
- user_filter = yield self.filtering.get_user_filter(
- user_localpart=user_localpart, filter_id=filter_id
+ user_filter = yield defer.ensureDeferred(
+ self.filtering.get_user_filter(
+ user_localpart=user_localpart, filter_id=filter_id
+ )
)
results = user_filter.filter_room_state(events=events)
@@ -430,8 +436,10 @@ class FilteringTestCase(unittest.TestCase):
)
events = [event]
- user_filter = yield self.filtering.get_user_filter(
- user_localpart=user_localpart, filter_id=filter_id
+ user_filter = yield defer.ensureDeferred(
+ self.filtering.get_user_filter(
+ user_localpart=user_localpart, filter_id=filter_id
+ )
)
results = user_filter.filter_room_state(events)
@@ -465,8 +473,10 @@ class FilteringTestCase(unittest.TestCase):
self.assertEquals(
user_filter_json,
(
- yield self.datastore.get_user_filter(
- user_localpart=user_localpart, filter_id=0
+ yield defer.ensureDeferred(
+ self.datastore.get_user_filter(
+ user_localpart=user_localpart, filter_id=0
+ )
)
),
)
@@ -479,8 +489,10 @@ class FilteringTestCase(unittest.TestCase):
user_localpart=user_localpart, user_filter=user_filter_json
)
- filter = yield self.filtering.get_user_filter(
- user_localpart=user_localpart, filter_id=filter_id
+ filter = yield defer.ensureDeferred(
+ self.filtering.get_user_filter(
+ user_localpart=user_localpart, filter_id=filter_id
+ )
)
self.assertEquals(filter.get_filter_json(), user_filter_json)
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 5878f74175..b7d0adb10e 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -126,10 +126,10 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.room_members = []
- def check_user_in_room(room_id, user_id):
+ async def check_user_in_room(room_id, user_id):
if user_id not in [u.to_string() for u in self.room_members]:
raise AuthError(401, "User is not in the room")
- return defer.succeed(None)
+ return None
hs.get_auth().check_user_in_room = check_user_in_room
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index f16eef15f7..17d0aae2e9 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -20,6 +20,8 @@ import urllib.parse
from mock import Mock
+from twisted.internet import defer
+
import synapse.rest.admin
from synapse.api.constants import UserTypes
from synapse.api.errors import HttpResponseException, ResourceLimitError
@@ -335,7 +337,9 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
store = self.hs.get_datastore()
# Set monthly active users to the limit
- store.get_monthly_active_count = Mock(return_value=self.hs.config.max_mau_value)
+ store.get_monthly_active_count = Mock(
+ return_value=defer.succeed(self.hs.config.max_mau_value)
+ )
# Check that the blocking of monthly active users is working as expected
# The registration of a new user fails due to the limit
self.get_failure(
@@ -588,7 +592,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# Set monthly active users to the limit
self.store.get_monthly_active_count = Mock(
- return_value=self.hs.config.max_mau_value
+ return_value=defer.succeed(self.hs.config.max_mau_value)
)
# Check that the blocking of monthly active users is working as expected
# The registration of a new user fails due to the limit
@@ -628,7 +632,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# Set monthly active users to the limit
self.store.get_monthly_active_count = Mock(
- return_value=self.hs.config.max_mau_value
+ return_value=defer.succeed(self.hs.config.max_mau_value)
)
# Check that the blocking of monthly active users is working as expected
# The registration of a new user fails due to the limit
diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py
index 8df58b4a63..ace0a3c08d 100644
--- a/tests/rest/client/v1/test_profile.py
+++ b/tests/rest/client/v1/test_profile.py
@@ -70,8 +70,8 @@ class MockHandlerProfileTestCase(unittest.TestCase):
profile_handler=self.mock_handler,
)
- def _get_user_by_req(request=None, allow_guest=False):
- return defer.succeed(synapse.types.create_requester(myid))
+ async def _get_user_by_req(request=None, allow_guest=False):
+ return synapse.types.create_requester(myid)
hs.get_auth().get_user_by_req = _get_user_by_req
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index 5ccda8b2bd..ef6b775ed2 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -23,8 +23,6 @@ from urllib import parse as urlparse
from mock import Mock
-from twisted.internet import defer
-
import synapse.rest.admin
from synapse.api.constants import EventContentFields, EventTypes, Membership
from synapse.handlers.pagination import PurgeStatus
@@ -51,8 +49,8 @@ class RoomBase(unittest.HomeserverTestCase):
self.hs.get_federation_handler = Mock(return_value=Mock())
- def _insert_client_ip(*args, **kwargs):
- return defer.succeed(None)
+ async def _insert_client_ip(*args, **kwargs):
+ return None
self.hs.get_datastore().insert_client_ip = _insert_client_ip
diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py
index 18260bb90e..94d2bf2eb1 100644
--- a/tests/rest/client/v1/test_typing.py
+++ b/tests/rest/client/v1/test_typing.py
@@ -46,7 +46,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
hs.get_handlers().federation_handler = Mock()
- def get_user_by_access_token(token=None, allow_guest=False):
+ async def get_user_by_access_token(token=None, allow_guest=False):
return {
"user": UserID.from_string(self.auth_user_id),
"token_id": 1,
@@ -55,8 +55,8 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
hs.get_auth().get_user_by_access_token = get_user_by_access_token
- def _insert_client_ip(*args, **kwargs):
- return defer.succeed(None)
+ async def _insert_client_ip(*args, **kwargs):
+ return None
hs.get_datastore().insert_client_ip = _insert_client_ip
diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py
index 7f70353b0d..3f88abe3d2 100644
--- a/tests/server_notices/test_resource_limits_server_notices.py
+++ b/tests/server_notices/test_resource_limits_server_notices.py
@@ -258,7 +258,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
self.user_id = "@user_id:test"
def test_server_notice_only_sent_once(self):
- self.store.get_monthly_active_count = Mock(return_value=1000)
+ self.store.get_monthly_active_count = Mock(return_value=defer.succeed(1000))
self.store.user_last_seen_monthly_active = Mock(
return_value=defer.succeed(1000)
diff --git a/tests/unittest.py b/tests/unittest.py
index 2152c693f2..d0bba3ddef 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -241,20 +241,16 @@ class HomeserverTestCase(TestCase):
if hasattr(self, "user_id"):
if self.hijack_auth:
- def get_user_by_access_token(token=None, allow_guest=False):
- return succeed(
- {
- "user": UserID.from_string(self.helper.auth_user_id),
- "token_id": 1,
- "is_guest": False,
- }
- )
-
- def get_user_by_req(request, allow_guest=False, rights="access"):
- return succeed(
- create_requester(
- UserID.from_string(self.helper.auth_user_id), 1, False, None
- )
+ async def get_user_by_access_token(token=None, allow_guest=False):
+ return {
+ "user": UserID.from_string(self.helper.auth_user_id),
+ "token_id": 1,
+ "is_guest": False,
+ }
+
+ async def get_user_by_req(request, allow_guest=False, rights="access"):
+ return create_requester(
+ UserID.from_string(self.helper.auth_user_id), 1, False, None
)
self.hs.get_auth().get_user_by_req = get_user_by_req
|