diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index 3b781d9836..61dc4beafe 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -19,7 +19,7 @@ from twisted.internet import defer
import synapse.types
from synapse.api.constants import EventTypes, Membership
-from synapse.api.errors import LimitExceededError
+from synapse.api.ratelimiting import Ratelimiter
from synapse.types import UserID
logger = logging.getLogger(__name__)
@@ -44,11 +44,26 @@ class BaseHandler(object):
self.notifier = hs.get_notifier()
self.state_handler = hs.get_state_handler()
self.distributor = hs.get_distributor()
- self.ratelimiter = hs.get_ratelimiter()
- self.admin_redaction_ratelimiter = hs.get_admin_redaction_ratelimiter()
self.clock = hs.get_clock()
self.hs = hs
+ # The rate_hz and burst_count are overridden on a per-user basis
+ self.request_ratelimiter = Ratelimiter(
+ clock=self.clock, rate_hz=0, burst_count=0
+ )
+ self._rc_message = self.hs.config.rc_message
+
+ # Check whether ratelimiting room admin message redaction is enabled
+ # by the presence of rate limits in the config
+ if self.hs.config.rc_admin_redaction:
+ self.admin_redaction_ratelimiter = Ratelimiter(
+ clock=self.clock,
+ rate_hz=self.hs.config.rc_admin_redaction.per_second,
+ burst_count=self.hs.config.rc_admin_redaction.burst_count,
+ )
+ else:
+ self.admin_redaction_ratelimiter = None
+
self.server_name = hs.hostname
self.event_builder_factory = hs.get_event_builder_factory()
@@ -70,7 +85,6 @@ class BaseHandler(object):
Raises:
LimitExceededError if the request should be ratelimited
"""
- time_now = self.clock.time()
user_id = requester.user.to_string()
# The AS user itself is never rate limited.
@@ -83,48 +97,32 @@ class BaseHandler(object):
if requester.app_service and not requester.app_service.is_rate_limited():
return
+ messages_per_second = self._rc_message.per_second
+ burst_count = self._rc_message.burst_count
+
# Check if there is a per user override in the DB.
override = yield self.store.get_ratelimit_for_user(user_id)
if override:
- # If overriden with a null Hz then ratelimiting has been entirely
+ # If overridden with a null Hz then ratelimiting has been entirely
# disabled for the user
if not override.messages_per_second:
return
messages_per_second = override.messages_per_second
burst_count = override.burst_count
+
+ if is_admin_redaction and self.admin_redaction_ratelimiter:
+ # If we have separate config for admin redactions, use a separate
+ # ratelimiter so as not to have user_ids clash
+ self.admin_redaction_ratelimiter.ratelimit(user_id, update=update)
else:
- # We default to different values if this is an admin redaction and
- # the config is set
- if is_admin_redaction and self.hs.config.rc_admin_redaction:
- messages_per_second = self.hs.config.rc_admin_redaction.per_second
- burst_count = self.hs.config.rc_admin_redaction.burst_count
- else:
- messages_per_second = self.hs.config.rc_message.per_second
- burst_count = self.hs.config.rc_message.burst_count
-
- if is_admin_redaction and self.hs.config.rc_admin_redaction:
- # If we have separate config for admin redactions we use a separate
- # ratelimiter
- allowed, time_allowed = self.admin_redaction_ratelimiter.can_do_action(
- user_id,
- time_now,
- rate_hz=messages_per_second,
- burst_count=burst_count,
- update=update,
- )
- else:
- allowed, time_allowed = self.ratelimiter.can_do_action(
+ # Override rate and burst count per-user
+ self.request_ratelimiter.ratelimit(
user_id,
- time_now,
rate_hz=messages_per_second,
burst_count=burst_count,
update=update,
)
- if not allowed:
- raise LimitExceededError(
- retry_after_ms=int(1000 * (time_allowed - time_now))
- )
async def maybe_kick_guest_users(self, event, context=None):
# Technically this function invalidates current_state by changing it.
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 75b39e878c..119678e67b 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -108,7 +108,11 @@ class AuthHandler(BaseHandler):
# Ratelimiter for failed auth during UIA. Uses same ratelimit config
# as per `rc_login.failed_attempts`.
- self._failed_uia_attempts_ratelimiter = Ratelimiter()
+ self._failed_uia_attempts_ratelimiter = Ratelimiter(
+ clock=self.clock,
+ rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
+ burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
+ )
self._clock = self.hs.get_clock()
@@ -196,13 +200,7 @@ class AuthHandler(BaseHandler):
user_id = requester.user.to_string()
# Check if we should be ratelimited due to too many previous failed attempts
- self._failed_uia_attempts_ratelimiter.ratelimit(
- user_id,
- time_now_s=self._clock.time(),
- rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
- burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
- update=False,
- )
+ self._failed_uia_attempts_ratelimiter.ratelimit(user_id, update=False)
# build a list of supported flows
flows = [[login_type] for login_type in self._supported_ui_auth_types]
@@ -212,14 +210,8 @@ class AuthHandler(BaseHandler):
flows, request, request_body, clientip, description
)
except LoginError:
- # Update the ratelimite to say we failed (`can_do_action` doesn't raise).
- self._failed_uia_attempts_ratelimiter.can_do_action(
- user_id,
- time_now_s=self._clock.time(),
- rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
- burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
- update=True,
- )
+ # Update the ratelimiter to say we failed (`can_do_action` doesn't raise).
+ self._failed_uia_attempts_ratelimiter.can_do_action(user_id)
raise
# find the completed login type
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 29a19b4572..230d170258 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -15,6 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+from typing import Any, Dict, Optional
from six import iteritems, itervalues
@@ -30,7 +31,11 @@ from synapse.api.errors import (
)
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.types import RoomStreamToken, get_domain_from_id
+from synapse.types import (
+ RoomStreamToken,
+ get_domain_from_id,
+ get_verify_key_from_cross_signing_key,
+)
from synapse.util import stringutils
from synapse.util.async_helpers import Linearizer
from synapse.util.caches.expiringcache import ExpiringCache
@@ -704,22 +709,27 @@ class DeviceListUpdater(object):
need_resync = yield self.store.get_user_ids_requiring_device_list_resync()
# Iterate over the set of user IDs.
for user_id in need_resync:
- # Try to resync the current user's devices list. Exception handling
- # isn't necessary here, since user_device_resync catches all instances
- # of "Exception" that might be raised from the federation request. This
- # means that if an exception is raised by this function, it must be
- # because of a database issue, which means _maybe_retry_device_resync
- # probably won't be able to go much further anyway.
- result = yield self.user_device_resync(
- user_id=user_id, mark_failed_as_stale=False,
- )
- # user_device_resync only returns a result if it managed to successfully
- # resync and update the database. Updating the table of users requiring
- # resync isn't necessary here as user_device_resync already does it
- # (through self.store.update_remote_device_list_cache).
- if result:
+ try:
+ # Try to resync the current user's devices list.
+ result = yield self.user_device_resync(
+ user_id=user_id, mark_failed_as_stale=False,
+ )
+
+ # user_device_resync only returns a result if it managed to
+ # successfully resync and update the database. Updating the table
+ # of users requiring resync isn't necessary here as
+ # user_device_resync already does it (through
+ # self.store.update_remote_device_list_cache).
+ if result:
+ logger.debug(
+ "Successfully resynced the device list for %s", user_id,
+ )
+ except Exception as e:
+ # If there was an issue resyncing this user, e.g. if the remote
+ # server sent a malformed result, just log the error instead of
+ # aborting all the subsequent resyncs.
logger.debug(
- "Successfully resynced the device list for %s" % user_id,
+ "Could not resync the device list for %s: %s", user_id, e,
)
finally:
# Allow future calls to retry resyncinc out of sync device lists.
@@ -738,6 +748,7 @@ class DeviceListUpdater(object):
request:
https://matrix.org/docs/spec/server_server/r0.1.2#get-matrix-federation-v1-user-devices-userid
"""
+ logger.debug("Attempting to resync the device list for %s", user_id)
log_kv({"message": "Doing resync to update device list."})
# Fetch all devices for the user.
origin = get_domain_from_id(user_id)
@@ -789,6 +800,13 @@ class DeviceListUpdater(object):
stream_id = result["stream_id"]
devices = result["devices"]
+ # Get the master key and the self-signing key for this user if provided in the
+ # response (None if not in the response).
+ # The response will not contain the user signing key, as this key is only used by
+ # its owner, thus it doesn't make sense to send it over federation.
+ master_key = result.get("master_key")
+ self_signing_key = result.get("self_signing_key")
+
# If the remote server has more than ~1000 devices for this user
# we assume that something is going horribly wrong (e.g. a bot
# that logs in and creates a new device every time it tries to
@@ -818,6 +836,13 @@ class DeviceListUpdater(object):
yield self.store.update_remote_device_list_cache(user_id, devices, stream_id)
device_ids = [device["device_id"] for device in devices]
+
+ # Handle cross-signing keys.
+ cross_signing_device_ids = yield self.process_cross_signing_key_update(
+ user_id, master_key, self_signing_key,
+ )
+ device_ids = device_ids + cross_signing_device_ids
+
yield self.device_handler.notify_device_update(user_id, device_ids)
# We clobber the seen updates since we've re-synced from a given
@@ -825,3 +850,40 @@ class DeviceListUpdater(object):
self._seen_updates[user_id] = {stream_id}
defer.returnValue(result)
+
+ @defer.inlineCallbacks
+ def process_cross_signing_key_update(
+ self,
+ user_id: str,
+ master_key: Optional[Dict[str, Any]],
+ self_signing_key: Optional[Dict[str, Any]],
+ ) -> list:
+ """Process the given new master and self-signing key for the given remote user.
+
+ Args:
+ user_id: The ID of the user these keys are for.
+ master_key: The dict of the cross-signing master key as returned by the
+ remote server.
+ self_signing_key: The dict of the cross-signing self-signing key as returned
+ by the remote server.
+
+ Returns:
+ The device IDs for the given keys.
+ """
+ device_ids = []
+
+ if master_key:
+ yield self.store.set_e2e_cross_signing_key(user_id, "master", master_key)
+ _, verify_key = get_verify_key_from_cross_signing_key(master_key)
+ # verify_key is a VerifyKey from signedjson, which uses
+ # .version to denote the portion of the key ID after the
+ # algorithm and colon, which is the device ID
+ device_ids.append(verify_key.version)
+ if self_signing_key:
+ yield self.store.set_e2e_cross_signing_key(
+ user_id, "self_signing", self_signing_key
+ )
+ _, verify_key = get_verify_key_from_cross_signing_key(self_signing_key)
+ device_ids.append(verify_key.version)
+
+ return device_ids
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 8f1bc0323c..774a252619 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -1291,6 +1291,7 @@ class SigningKeyEduUpdater(object):
"""
device_handler = self.e2e_keys_handler.device_handler
+ device_list_updater = device_handler.device_list_updater
with (yield self._remote_edu_linearizer.queue(user_id)):
pending_updates = self._pending_updates.pop(user_id, [])
@@ -1303,22 +1304,9 @@ class SigningKeyEduUpdater(object):
logger.info("pending updates: %r", pending_updates)
for master_key, self_signing_key in pending_updates:
- if master_key:
- yield self.store.set_e2e_cross_signing_key(
- user_id, "master", master_key
- )
- _, verify_key = get_verify_key_from_cross_signing_key(master_key)
- # verify_key is a VerifyKey from signedjson, which uses
- # .version to denote the portion of the key ID after the
- # algorithm and colon, which is the device ID
- device_ids.append(verify_key.version)
- if self_signing_key:
- yield self.store.set_e2e_cross_signing_key(
- user_id, "self_signing", self_signing_key
- )
- _, verify_key = get_verify_key_from_cross_signing_key(
- self_signing_key
- )
- device_ids.append(verify_key.version)
+ new_device_ids = yield device_list_updater.process_cross_signing_key_update(
+ user_id, master_key, self_signing_key,
+ )
+ device_ids = device_ids + new_device_ids
yield device_handler.notify_device_update(user_id, device_ids)
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index eec8066eeb..bbf23345e2 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -504,7 +504,7 @@ class FederationHandler(BaseHandler):
min_depth=min_depth,
timeout=60000,
)
- except RequestSendFailed as e:
+ except (RequestSendFailed, HttpResponseException, NotRetryingDestination) as e:
# We failed to get the missing events, but since we need to handle
# the case of `get_missing_events` not returning the necessary
# events anyway, it is safe to simply log the error and continue.
diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py
index ca5c83811a..ebe8d25bd8 100644
--- a/synapse/handlers/groups_local.py
+++ b/synapse/handlers/groups_local.py
@@ -18,8 +18,6 @@ import logging
from six import iteritems
-from twisted.internet import defer
-
from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError
from synapse.types import get_domain_from_id
@@ -92,19 +90,18 @@ class GroupsLocalWorkerHandler(object):
get_group_role = _create_rerouter("get_group_role")
get_group_roles = _create_rerouter("get_group_roles")
- @defer.inlineCallbacks
- def get_group_summary(self, group_id, requester_user_id):
+ async def get_group_summary(self, group_id, requester_user_id):
"""Get the group summary for a group.
If the group is remote we check that the users have valid attestations.
"""
if self.is_mine_id(group_id):
- res = yield self.groups_server_handler.get_group_summary(
+ res = await self.groups_server_handler.get_group_summary(
group_id, requester_user_id
)
else:
try:
- res = yield self.transport_client.get_group_summary(
+ res = await self.transport_client.get_group_summary(
get_domain_from_id(group_id), group_id, requester_user_id
)
except HttpResponseException as e:
@@ -122,7 +119,7 @@ class GroupsLocalWorkerHandler(object):
attestation = entry.pop("attestation", {})
try:
if get_domain_from_id(g_user_id) != group_server_name:
- yield self.attestations.verify_attestation(
+ await self.attestations.verify_attestation(
attestation,
group_id=group_id,
user_id=g_user_id,
@@ -139,19 +136,18 @@ class GroupsLocalWorkerHandler(object):
# Add `is_publicised` flag to indicate whether the user has publicised their
# membership of the group on their profile
- result = yield self.store.get_publicised_groups_for_user(requester_user_id)
+ result = await self.store.get_publicised_groups_for_user(requester_user_id)
is_publicised = group_id in result
res.setdefault("user", {})["is_publicised"] = is_publicised
return res
- @defer.inlineCallbacks
- def get_users_in_group(self, group_id, requester_user_id):
+ async def get_users_in_group(self, group_id, requester_user_id):
"""Get users in a group
"""
if self.is_mine_id(group_id):
- res = yield self.groups_server_handler.get_users_in_group(
+ res = await self.groups_server_handler.get_users_in_group(
group_id, requester_user_id
)
return res
@@ -159,7 +155,7 @@ class GroupsLocalWorkerHandler(object):
group_server_name = get_domain_from_id(group_id)
try:
- res = yield self.transport_client.get_users_in_group(
+ res = await self.transport_client.get_users_in_group(
get_domain_from_id(group_id), group_id, requester_user_id
)
except HttpResponseException as e:
@@ -174,7 +170,7 @@ class GroupsLocalWorkerHandler(object):
attestation = entry.pop("attestation", {})
try:
if get_domain_from_id(g_user_id) != group_server_name:
- yield self.attestations.verify_attestation(
+ await self.attestations.verify_attestation(
attestation,
group_id=group_id,
user_id=g_user_id,
@@ -188,15 +184,13 @@ class GroupsLocalWorkerHandler(object):
return res
- @defer.inlineCallbacks
- def get_joined_groups(self, user_id):
- group_ids = yield self.store.get_joined_groups(user_id)
+ async def get_joined_groups(self, user_id):
+ group_ids = await self.store.get_joined_groups(user_id)
return {"groups": group_ids}
- @defer.inlineCallbacks
- def get_publicised_groups_for_user(self, user_id):
+ async def get_publicised_groups_for_user(self, user_id):
if self.hs.is_mine_id(user_id):
- result = yield self.store.get_publicised_groups_for_user(user_id)
+ result = await self.store.get_publicised_groups_for_user(user_id)
# Check AS associated groups for this user - this depends on the
# RegExps in the AS registration file (under `users`)
@@ -206,7 +200,7 @@ class GroupsLocalWorkerHandler(object):
return {"groups": result}
else:
try:
- bulk_result = yield self.transport_client.bulk_get_publicised_groups(
+ bulk_result = await self.transport_client.bulk_get_publicised_groups(
get_domain_from_id(user_id), [user_id]
)
except HttpResponseException as e:
@@ -218,8 +212,7 @@ class GroupsLocalWorkerHandler(object):
# TODO: Verify attestations
return {"groups": result}
- @defer.inlineCallbacks
- def bulk_get_publicised_groups(self, user_ids, proxy=True):
+ async def bulk_get_publicised_groups(self, user_ids, proxy=True):
destinations = {}
local_users = set()
@@ -236,7 +229,7 @@ class GroupsLocalWorkerHandler(object):
failed_results = []
for destination, dest_user_ids in iteritems(destinations):
try:
- r = yield self.transport_client.bulk_get_publicised_groups(
+ r = await self.transport_client.bulk_get_publicised_groups(
destination, list(dest_user_ids)
)
results.update(r["users"])
@@ -244,7 +237,7 @@ class GroupsLocalWorkerHandler(object):
failed_results.extend(dest_user_ids)
for uid in local_users:
- results[uid] = yield self.store.get_publicised_groups_for_user(uid)
+ results[uid] = await self.store.get_publicised_groups_for_user(uid)
# Check AS associated groups for this user - this depends on the
# RegExps in the AS registration file (under `users`)
@@ -333,12 +326,11 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
return res
- @defer.inlineCallbacks
- def join_group(self, group_id, user_id, content):
+ async def join_group(self, group_id, user_id, content):
"""Request to join a group
"""
if self.is_mine_id(group_id):
- yield self.groups_server_handler.join_group(group_id, user_id, content)
+ await self.groups_server_handler.join_group(group_id, user_id, content)
local_attestation = None
remote_attestation = None
else:
@@ -346,7 +338,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
content["attestation"] = local_attestation
try:
- res = yield self.transport_client.join_group(
+ res = await self.transport_client.join_group(
get_domain_from_id(group_id), group_id, user_id, content
)
except HttpResponseException as e:
@@ -356,7 +348,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
remote_attestation = res["attestation"]
- yield self.attestations.verify_attestation(
+ await self.attestations.verify_attestation(
remote_attestation,
group_id=group_id,
user_id=user_id,
@@ -366,7 +358,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
# TODO: Check that the group is public and we're being added publically
is_publicised = content.get("publicise", False)
- token = yield self.store.register_user_group_membership(
+ token = await self.store.register_user_group_membership(
group_id,
user_id,
membership="join",
@@ -379,12 +371,11 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
return {}
- @defer.inlineCallbacks
- def accept_invite(self, group_id, user_id, content):
+ async def accept_invite(self, group_id, user_id, content):
"""Accept an invite to a group
"""
if self.is_mine_id(group_id):
- yield self.groups_server_handler.accept_invite(group_id, user_id, content)
+ await self.groups_server_handler.accept_invite(group_id, user_id, content)
local_attestation = None
remote_attestation = None
else:
@@ -392,7 +383,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
content["attestation"] = local_attestation
try:
- res = yield self.transport_client.accept_group_invite(
+ res = await self.transport_client.accept_group_invite(
get_domain_from_id(group_id), group_id, user_id, content
)
except HttpResponseException as e:
@@ -402,7 +393,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
remote_attestation = res["attestation"]
- yield self.attestations.verify_attestation(
+ await self.attestations.verify_attestation(
remote_attestation,
group_id=group_id,
user_id=user_id,
@@ -412,7 +403,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
# TODO: Check that the group is public and we're being added publically
is_publicised = content.get("publicise", False)
- token = yield self.store.register_user_group_membership(
+ token = await self.store.register_user_group_membership(
group_id,
user_id,
membership="join",
@@ -425,18 +416,17 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
return {}
- @defer.inlineCallbacks
- def invite(self, group_id, user_id, requester_user_id, config):
+ async def invite(self, group_id, user_id, requester_user_id, config):
"""Invite a user to a group
"""
content = {"requester_user_id": requester_user_id, "config": config}
if self.is_mine_id(group_id):
- res = yield self.groups_server_handler.invite_to_group(
+ res = await self.groups_server_handler.invite_to_group(
group_id, user_id, requester_user_id, content
)
else:
try:
- res = yield self.transport_client.invite_to_group(
+ res = await self.transport_client.invite_to_group(
get_domain_from_id(group_id),
group_id,
user_id,
@@ -450,8 +440,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
return res
- @defer.inlineCallbacks
- def on_invite(self, group_id, user_id, content):
+ async def on_invite(self, group_id, user_id, content):
"""One of our users were invited to a group
"""
# TODO: Support auto join and rejection
@@ -466,7 +455,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
if "avatar_url" in content["profile"]:
local_profile["avatar_url"] = content["profile"]["avatar_url"]
- token = yield self.store.register_user_group_membership(
+ token = await self.store.register_user_group_membership(
group_id,
user_id,
membership="invite",
@@ -474,7 +463,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
)
self.notifier.on_new_event("groups_key", token, users=[user_id])
try:
- user_profile = yield self.profile_handler.get_profile(user_id)
+ user_profile = await self.profile_handler.get_profile(user_id)
except Exception as e:
logger.warning("No profile for user %s: %s", user_id, e)
user_profile = {}
@@ -516,12 +505,11 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
return res
- @defer.inlineCallbacks
- def user_removed_from_group(self, group_id, user_id, content):
+ async def user_removed_from_group(self, group_id, user_id, content):
"""One of our users was removed/kicked from a group
"""
# TODO: Check if user in group
- token = yield self.store.register_user_group_membership(
+ token = await self.store.register_user_group_membership(
group_id, user_id, membership="leave"
)
self.notifier.on_new_event("groups_key", token, users=[user_id])
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index e0c426a13b..6039034c00 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -66,8 +66,7 @@ class IdentityHandler(BaseHandler):
self.rewrite_identity_server_urls = hs.config.rewrite_identity_server_urls
self._enable_lookup = hs.config.enable_3pid_lookup
- @defer.inlineCallbacks
- def threepid_from_creds(self, id_server_url, creds):
+ async def threepid_from_creds(self, id_server_url, creds):
"""
Retrieve and validate a threepid identifier from a "credentials" dictionary against a
given identity server
@@ -110,7 +109,7 @@ class IdentityHandler(BaseHandler):
)
try:
- data = yield self.http_client.get_json(url, query_params)
+ data = await self.http_client.get_json(url, query_params)
except TimeoutError:
raise SynapseError(500, "Timed out contacting identity server")
except HttpResponseException as e:
@@ -133,8 +132,7 @@ class IdentityHandler(BaseHandler):
logger.info("%s reported non-validated threepid: %s", id_server_url, creds)
return None
- @defer.inlineCallbacks
- def bind_threepid(
+ async def bind_threepid(
self, client_secret, sid, mxid, id_server, id_access_token=None, use_v2=True
):
"""Bind a 3PID to an identity server
@@ -179,12 +177,12 @@ class IdentityHandler(BaseHandler):
try:
# Use the blacklisting http client as this call is only to identity servers
# provided by a client
- data = yield self.blacklisting_http_client.post_json_get_json(
+ data = await self.blacklisting_http_client.post_json_get_json(
bind_url, bind_data, headers=headers
)
# Remember where we bound the threepid
- yield self.store.add_user_bound_threepid(
+ await self.store.add_user_bound_threepid(
user_id=mxid,
medium=data["medium"],
address=data["address"],
@@ -203,13 +201,12 @@ class IdentityHandler(BaseHandler):
return data
logger.info("Got 404 when POSTing JSON %s, falling back to v1 URL", bind_url)
- res = yield self.bind_threepid(
+ res = await self.bind_threepid(
client_secret, sid, mxid, id_server, id_access_token, use_v2=False
)
return res
- @defer.inlineCallbacks
- def try_unbind_threepid(self, mxid, threepid):
+ async def try_unbind_threepid(self, mxid, threepid):
"""Attempt to remove a 3PID from an identity server, or if one is not provided, all
identity servers we're aware the binding is present on
@@ -229,7 +226,7 @@ class IdentityHandler(BaseHandler):
if threepid.get("id_server"):
id_servers = [threepid["id_server"]]
else:
- id_servers = yield self.store.get_id_servers_user_bound(
+ id_servers = await self.store.get_id_servers_user_bound(
user_id=mxid, medium=threepid["medium"], address=threepid["address"]
)
@@ -239,14 +236,13 @@ class IdentityHandler(BaseHandler):
changed = True
for id_server in id_servers:
- changed &= yield self.try_unbind_threepid_with_id_server(
+ changed &= await self.try_unbind_threepid_with_id_server(
mxid, threepid, id_server
)
return changed
- @defer.inlineCallbacks
- def try_unbind_threepid_with_id_server(self, mxid, threepid, id_server):
+ async def try_unbind_threepid_with_id_server(self, mxid, threepid, id_server):
"""Removes a binding from an identity server
Args:
@@ -291,7 +287,7 @@ class IdentityHandler(BaseHandler):
try:
# Use the blacklisting http client as this call is only to identity servers
# provided by a client
- yield self.blacklisting_http_client.post_json_get_json(
+ await self.blacklisting_http_client.post_json_get_json(
url, content, headers
)
changed = True
@@ -306,7 +302,7 @@ class IdentityHandler(BaseHandler):
except TimeoutError:
raise SynapseError(500, "Timed out contacting identity server")
- yield self.store.remove_user_bound_threepid(
+ await self.store.remove_user_bound_threepid(
user_id=mxid,
medium=threepid["medium"],
address=threepid["address"],
@@ -420,8 +416,7 @@ class IdentityHandler(BaseHandler):
logger.debug("Rewriting identity server rule from %s to %s", url, rewritten_url)
return rewritten_url
- @defer.inlineCallbacks
- def requestEmailToken(
+ async def requestEmailToken(
self, id_server_url, email, client_secret, send_attempt, next_link=None
):
"""
@@ -461,7 +456,7 @@ class IdentityHandler(BaseHandler):
)
try:
- data = yield self.http_client.post_json_get_json(
+ data = await self.http_client.post_json_get_json(
"%s/_matrix/identity/api/v1/validate/email/requestToken"
% (id_server_url,),
params,
@@ -473,8 +468,7 @@ class IdentityHandler(BaseHandler):
except TimeoutError:
raise SynapseError(500, "Timed out contacting identity server")
- @defer.inlineCallbacks
- def requestMsisdnToken(
+ async def requestMsisdnToken(
self,
id_server_url,
country,
@@ -519,7 +513,7 @@ class IdentityHandler(BaseHandler):
# apply it now.
id_server_url = self.rewrite_id_server_url(id_server_url)
try:
- data = yield self.http_client.post_json_get_json(
+ data = await self.http_client.post_json_get_json(
"%s/_matrix/identity/api/v1/validate/msisdn/requestToken"
% (id_server_url,),
params,
@@ -541,8 +535,7 @@ class IdentityHandler(BaseHandler):
)
return data
- @defer.inlineCallbacks
- def validate_threepid_session(self, client_secret, sid):
+ async def validate_threepid_session(self, client_secret, sid):
"""Validates a threepid session with only the client secret and session ID
Tries validating against any configured account_threepid_delegates as well as locally.
@@ -564,12 +557,12 @@ class IdentityHandler(BaseHandler):
# Try to validate as email
if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
# Ask our delegated email identity server
- validation_session = yield self.threepid_from_creds(
+ validation_session = await self.threepid_from_creds(
self.hs.config.account_threepid_delegate_email, threepid_creds
)
elif self.hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
# Get a validated session matching these details
- validation_session = yield self.store.get_threepid_validation_session(
+ validation_session = await self.store.get_threepid_validation_session(
"email", client_secret, sid=sid, validated=True
)
@@ -579,14 +572,13 @@ class IdentityHandler(BaseHandler):
# Try to validate as msisdn
if self.hs.config.account_threepid_delegate_msisdn:
# Ask our delegated msisdn identity server
- validation_session = yield self.threepid_from_creds(
+ validation_session = await self.threepid_from_creds(
self.hs.config.account_threepid_delegate_msisdn, threepid_creds
)
return validation_session
- @defer.inlineCallbacks
- def proxy_msisdn_submit_token(self, id_server, client_secret, sid, token):
+ async def proxy_msisdn_submit_token(self, id_server, client_secret, sid, token):
"""Proxy a POST submitToken request to an identity server for verification purposes
Args:
@@ -607,11 +599,9 @@ class IdentityHandler(BaseHandler):
body = {"client_secret": client_secret, "sid": sid, "token": token}
try:
- return (
- yield self.http_client.post_json_get_json(
- id_server + "/_matrix/identity/api/v1/validate/msisdn/submitToken",
- body,
- )
+ return await self.http_client.post_json_get_json(
+ id_server + "/_matrix/identity/api/v1/validate/msisdn/submitToken",
+ body,
)
except TimeoutError:
raise SynapseError(500, "Timed out contacting identity server")
@@ -663,7 +653,7 @@ class IdentityHandler(BaseHandler):
logger.info("Failed to contact %s: %s", id_server, e)
raise ProxiedRequestError(503, "Failed to contact identity server")
- defer.returnValue(data)
+ return data
@defer.inlineCallbacks
def proxy_bulk_lookup_3pid(self, id_server, threepids):
@@ -702,8 +692,7 @@ class IdentityHandler(BaseHandler):
defer.returnValue(data)
- @defer.inlineCallbacks
- def lookup_3pid(self, id_server, medium, address, id_access_token=None):
+ async def lookup_3pid(self, id_server, medium, address, id_access_token=None):
"""Looks up a 3pid in the passed identity server.
Args:
@@ -722,7 +711,7 @@ class IdentityHandler(BaseHandler):
if id_access_token is not None:
try:
- results = yield self._lookup_3pid_v2(
+ results = await self._lookup_3pid_v2(
id_server_url, id_access_token, medium, address
)
return results
@@ -741,10 +730,9 @@ class IdentityHandler(BaseHandler):
logger.warning("Error when looking up hashing details: %s", e)
return None
- return (yield self._lookup_3pid_v1(id_server, id_server_url, medium, address))
+ return await self._lookup_3pid_v1(id_server, id_server_url, medium, address)
- @defer.inlineCallbacks
- def _lookup_3pid_v1(self, id_server, id_server_url, medium, address):
+ async def _lookup_3pid_v1(self, id_server, id_server_url, medium, address):
"""Looks up a 3pid in the passed identity server using v1 lookup.
Args:
@@ -758,7 +746,7 @@ class IdentityHandler(BaseHandler):
str: the matrix ID of the 3pid, or None if it is not recognized.
"""
try:
- data = yield self.http_client.get_json(
+ data = await self.http_client.get_json(
"%s/_matrix/identity/api/v1/lookup" % (id_server_url,),
{"medium": medium, "address": address},
)
@@ -766,7 +754,7 @@ class IdentityHandler(BaseHandler):
if "mxid" in data:
if "signatures" not in data:
raise AuthError(401, "No signatures on 3pid binding")
- yield self._verify_any_signature(data, id_server)
+ await self._verify_any_signature(data, id_server)
return data["mxid"]
except TimeoutError:
raise SynapseError(500, "Timed out contacting identity server")
@@ -775,8 +763,7 @@ class IdentityHandler(BaseHandler):
return None
- @defer.inlineCallbacks
- def _lookup_3pid_v2(self, id_server_url, id_access_token, medium, address):
+ async def _lookup_3pid_v2(self, id_server_url, id_access_token, medium, address):
"""Looks up a 3pid in the passed identity server using v2 lookup.
Args:
@@ -790,7 +777,7 @@ class IdentityHandler(BaseHandler):
"""
# Check what hashing details are supported by this identity server
try:
- hash_details = yield self.http_client.get_json(
+ hash_details = await self.http_client.get_json(
"%s/_matrix/identity/v2/hash_details" % (id_server_url,),
{"access_token": id_access_token},
)
@@ -856,7 +843,7 @@ class IdentityHandler(BaseHandler):
headers = {"Authorization": create_id_access_token_header(id_access_token)}
try:
- lookup_results = yield self.http_client.post_json_get_json(
+ lookup_results = await self.http_client.post_json_get_json(
"%s/_matrix/identity/v2/lookup" % (id_server_url,),
{
"addresses": [lookup_value],
@@ -884,15 +871,14 @@ class IdentityHandler(BaseHandler):
mxid = lookup_results["mappings"].get(lookup_value)
return mxid
- @defer.inlineCallbacks
- def _verify_any_signature(self, data, id_server):
+ async def _verify_any_signature(self, data, id_server):
if id_server not in data["signatures"]:
raise AuthError(401, "No signature from server %s" % (id_server,))
for key_name, signature in data["signatures"][id_server].items():
id_server_url = self.rewrite_id_server_url(id_server, add_https=True)
- key_data = yield self.http_client.get_json(
+ key_data = await self.http_client.get_json(
"%s/_matrix/identity/api/v1/pubkey/%s" % (id_server_url, key_name)
)
if "public_key" not in key_data:
@@ -910,8 +896,7 @@ class IdentityHandler(BaseHandler):
raise AuthError(401, "No signature from server %s" % (id_server,))
- @defer.inlineCallbacks
- def ask_id_server_for_third_party_invite(
+ async def ask_id_server_for_third_party_invite(
self,
requester,
id_server,
@@ -986,7 +971,7 @@ class IdentityHandler(BaseHandler):
# Attempt a v2 lookup
url = base_url + "/v2/store-invite"
try:
- data = yield self.blacklisting_http_client.post_json_get_json(
+ data = await self.blacklisting_http_client.post_json_get_json(
url,
invite_config,
{"Authorization": create_id_access_token_header(id_access_token)},
@@ -1005,7 +990,7 @@ class IdentityHandler(BaseHandler):
url = base_url + "/api/v1/store-invite"
try:
- data = yield self.blacklisting_http_client.post_json_get_json(
+ data = await self.blacklisting_http_client.post_json_get_json(
url, invite_config
)
except TimeoutError:
@@ -1020,7 +1005,7 @@ class IdentityHandler(BaseHandler):
# types. This is especially true with old instances of Sydent, see
# https://github.com/matrix-org/sydent/pull/170
try:
- data = yield self.blacklisting_http_client.post_urlencoded_get_json(
+ data = await self.blacklisting_http_client.post_urlencoded_get_json(
url, invite_config
)
except HttpResponseException as e:
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 681f92cafd..649ca1f08a 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -362,7 +362,6 @@ class EventCreationHandler(object):
self.profile_handler = hs.get_profile_handler()
self.event_builder_factory = hs.get_event_builder_factory()
self.server_name = hs.hostname
- self.ratelimiter = hs.get_ratelimiter()
self.notifier = hs.get_notifier()
self.config = hs.config
self.require_membership_for_aliases = hs.config.require_membership_for_aliases
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index 4ba8c7fda5..9c08eb5399 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -37,6 +37,7 @@ from twisted.web.client import readBody
from synapse.config import ConfigError
from synapse.http.server import finish_request
from synapse.http.site import SynapseRequest
+from synapse.logging.context import make_deferred_yieldable
from synapse.push.mailer import load_jinja2_templates
from synapse.server import HomeServer
from synapse.types import UserID, map_username_to_mxid_localpart
@@ -99,7 +100,6 @@ class OidcHandler:
hs.config.oidc_client_auth_method,
) # type: ClientAuth
self._client_auth_method = hs.config.oidc_client_auth_method # type: str
- self._subject_claim = hs.config.oidc_subject_claim
self._provider_metadata = OpenIDProviderMetadata(
issuer=hs.config.oidc_issuer,
authorization_endpoint=hs.config.oidc_authorization_endpoint,
@@ -310,6 +310,10 @@ class OidcHandler:
received in the callback to exchange it for a token. The call uses the
``ClientAuth`` to authenticate with the client with its ID and secret.
+ See:
+ https://tools.ietf.org/html/rfc6749#section-3.2
+ https://openid.net/specs/openid-connect-core-1_0.html#TokenEndpoint
+
Args:
code: The authorization code we got from the callback.
@@ -362,7 +366,7 @@ class OidcHandler:
code=response.code, phrase=response.phrase.decode("utf-8")
)
- resp_body = await readBody(response)
+ resp_body = await make_deferred_yieldable(readBody(response))
if response.code >= 500:
# In case of a server error, we should first try to decode the body
@@ -484,6 +488,7 @@ class OidcHandler:
claims_params=claims_params,
)
except ValueError:
+ logger.info("Reloading JWKS after decode error")
jwk_set = await self.load_jwks(force=True) # try reloading the jwks
claims = jwt.decode(
token["id_token"],
@@ -592,6 +597,9 @@ class OidcHandler:
# The provider might redirect with an error.
# In that case, just display it as-is.
if b"error" in request.args:
+ # error response from the auth server. see:
+ # https://tools.ietf.org/html/rfc6749#section-4.1.2.1
+ # https://openid.net/specs/openid-connect-core-1_0.html#AuthError
error = request.args[b"error"][0].decode()
description = request.args.get(b"error_description", [b""])[0].decode()
@@ -605,8 +613,11 @@ class OidcHandler:
self._render_error(request, error, description)
return
+ # otherwise, it is presumably a successful response. see:
+ # https://tools.ietf.org/html/rfc6749#section-4.1.2
+
# Fetch the session cookie
- session = request.getCookie(SESSION_COOKIE_NAME)
+ session = request.getCookie(SESSION_COOKIE_NAME) # type: Optional[bytes]
if session is None:
logger.info("No session cookie found")
self._render_error(request, "missing_session", "No session cookie found")
@@ -654,7 +665,7 @@ class OidcHandler:
self._render_error(request, "invalid_request", "Code parameter is missing")
return
- logger.info("Exchanging code")
+ logger.debug("Exchanging code")
code = request.args[b"code"][0].decode()
try:
token = await self._exchange_code(code)
@@ -663,10 +674,12 @@ class OidcHandler:
self._render_error(request, e.error, e.error_description)
return
+ logger.debug("Successfully obtained OAuth2 access token")
+
# Now that we have a token, get the userinfo, either by decoding the
# `id_token` or by fetching the `userinfo_endpoint`.
if self._uses_userinfo:
- logger.info("Fetching userinfo")
+ logger.debug("Fetching userinfo")
try:
userinfo = await self._fetch_userinfo(token)
except Exception as e:
@@ -674,7 +687,7 @@ class OidcHandler:
self._render_error(request, "fetch_error", str(e))
return
else:
- logger.info("Extracting userinfo from id_token")
+ logger.debug("Extracting userinfo from id_token")
try:
userinfo = await self._parse_id_token(token, nonce=nonce)
except Exception as e:
@@ -750,7 +763,7 @@ class OidcHandler:
return macaroon.serialize()
def _verify_oidc_session_token(
- self, session: str, state: str
+ self, session: bytes, state: str
) -> Tuple[str, str, Optional[str]]:
"""Verifies and extract an OIDC session token.
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 8c6f61d9d1..d5d44de8d0 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -16,8 +16,6 @@
"""Contains functions for registering clients."""
import logging
-from twisted.internet import defer
-
from synapse import types
from synapse.api.constants import MAX_USERID_LENGTH, LoginType
from synapse.api.errors import AuthError, Codes, ConsentNotGivenError, SynapseError
@@ -78,8 +76,7 @@ class RegistrationHandler(BaseHandler):
self.session_lifetime = hs.config.session_lifetime
- @defer.inlineCallbacks
- def check_username(
+ async def check_username(
self, localpart, guest_access_token=None, assigned_user_id=None,
):
"""
@@ -128,7 +125,7 @@ class RegistrationHandler(BaseHandler):
Codes.INVALID_USERNAME,
)
- users = yield self.store.get_users_by_id_case_insensitive(user_id)
+ users = await self.store.get_users_by_id_case_insensitive(user_id)
if users:
if not guest_access_token:
raise SynapseError(
@@ -136,7 +133,7 @@ class RegistrationHandler(BaseHandler):
)
# Retrieve guest user information from provided access token
- user_data = yield self.auth.get_user_by_access_token(guest_access_token)
+ user_data = await self.auth.get_user_by_access_token(guest_access_token)
if not user_data["is_guest"] or user_data["user"].localpart != localpart:
raise AuthError(
403,
@@ -145,8 +142,16 @@ class RegistrationHandler(BaseHandler):
errcode=Codes.FORBIDDEN,
)
- @defer.inlineCallbacks
- def register_user(
+ if guest_access_token is None:
+ try:
+ int(localpart)
+ raise SynapseError(
+ 400, "Numeric user IDs are reserved for guest users."
+ )
+ except ValueError:
+ pass
+
+ async def register_user(
self,
localpart=None,
password_hash=None,
@@ -158,6 +163,7 @@ class RegistrationHandler(BaseHandler):
default_display_name=None,
address=None,
bind_emails=[],
+ by_admin=False,
):
"""Registers a new client on the server.
@@ -173,29 +179,24 @@ class RegistrationHandler(BaseHandler):
will be set to this. Defaults to 'localpart'.
address (str|None): the IP address used to perform the registration.
bind_emails (List[str]): list of emails to bind to this account.
+ by_admin (bool): True if this registration is being made via the
+ admin api, otherwise False.
Returns:
- Deferred[str]: user_id
+ str: user_id
Raises:
SynapseError if there was a problem registering.
"""
- yield self.check_registration_ratelimit(address)
+ self.check_registration_ratelimit(address)
- yield self.auth.check_auth_blocking(threepid=threepid)
+ # do not check_auth_blocking if the call is coming through the Admin API
+ if not by_admin:
+ await self.auth.check_auth_blocking(threepid=threepid)
if localpart is not None:
- yield self.check_username(localpart, guest_access_token=guest_access_token)
+ await self.check_username(localpart, guest_access_token=guest_access_token)
was_guest = guest_access_token is not None
- if not was_guest:
- try:
- int(localpart)
- raise SynapseError(
- 400, "Numeric user IDs are reserved for guest users."
- )
- except ValueError:
- pass
-
user = UserID(localpart, self.hs.hostname)
user_id = user.to_string()
@@ -206,7 +207,7 @@ class RegistrationHandler(BaseHandler):
elif default_display_name is None:
default_display_name = localpart
- yield self.register_with_store(
+ await self.register_with_store(
user_id=user_id,
password_hash=password_hash,
was_guest=was_guest,
@@ -218,15 +219,13 @@ class RegistrationHandler(BaseHandler):
)
if default_display_name:
- yield defer.ensureDeferred(
- self.profile_handler.set_displayname(
- user, None, default_display_name, by_admin=True
- )
+ await self.profile_handler.set_displayname(
+ user, None, default_display_name, by_admin=True
)
if self.hs.config.user_directory_search_all_users:
- profile = yield self.store.get_profileinfo(localpart)
- yield self.user_directory_handler.handle_local_profile_change(
+ profile = await self.store.get_profileinfo(localpart)
+ await self.user_directory_handler.handle_local_profile_change(
user_id, profile
)
@@ -239,14 +238,14 @@ class RegistrationHandler(BaseHandler):
if fail_count > 10:
raise SynapseError(500, "Unable to find a suitable guest user ID")
- localpart = yield self._generate_user_id()
+ localpart = await self._generate_user_id()
user = UserID(localpart, self.hs.hostname)
user_id = user.to_string()
- yield self.check_user_id_not_appservice_exclusive(user_id)
+ self.check_user_id_not_appservice_exclusive(user_id)
if default_display_name is None:
default_display_name = localpart
try:
- yield self.register_with_store(
+ await self.register_with_store(
user_id=user_id,
password_hash=password_hash,
make_guest=make_guest,
@@ -254,10 +253,8 @@ class RegistrationHandler(BaseHandler):
address=address,
)
- yield defer.ensureDeferred(
- self.profile_handler.set_displayname(
- user, None, default_display_name, by_admin=True
- )
+ await self.profile_handler.set_displayname(
+ user, None, default_display_name, by_admin=True
)
# Successfully registered
@@ -269,7 +266,13 @@ class RegistrationHandler(BaseHandler):
fail_count += 1
if not self.hs.config.user_consent_at_registration:
- yield defer.ensureDeferred(self._auto_join_rooms(user_id))
+ if not self.hs.config.auto_join_rooms_for_guests and make_guest:
+ logger.info(
+ "Skipping auto-join for %s because auto-join for guests is disabled",
+ user_id,
+ )
+ else:
+ await self._auto_join_rooms(user_id)
else:
logger.info(
"Skipping auto-join for %s because consent is required at registration",
@@ -287,15 +290,15 @@ class RegistrationHandler(BaseHandler):
}
# Bind email to new account
- yield self.register_email_threepid(user_id, threepid_dict, None)
+ await self.register_email_threepid(user_id, threepid_dict, None)
# Prevent the new user from showing up in the user directory if the server
# mandates it.
if not self._show_in_user_directory:
- yield self.store.add_account_data_for_user(
+ await self.store.add_account_data_for_user(
user_id, "im.vector.hide_profile", {"hide_profile": True}
)
- yield self.profile_handler.set_active([user], False, True)
+ await self.profile_handler.set_active([user], False, True)
return user_id
@@ -360,12 +363,10 @@ class RegistrationHandler(BaseHandler):
"""
await self._auto_join_rooms(user_id)
- @defer.inlineCallbacks
- def appservice_register(
+ async def appservice_register(
self, user_localpart, as_token, password_hash, display_name
):
# FIXME: this should be factored out and merged with normal register()
-
user = UserID(user_localpart, self.hs.hostname)
user_id = user.to_string()
service = self.store.get_app_service_by_token(as_token)
@@ -380,28 +381,24 @@ class RegistrationHandler(BaseHandler):
service_id = service.id if service.is_exclusive_user(user_id) else None
- yield self.check_user_id_not_appservice_exclusive(
- user_id, allowed_appservice=service
- )
+ self.check_user_id_not_appservice_exclusive(user_id, allowed_appservice=service)
display_name = display_name or user.localpart
- yield self.register_with_store(
+ await self.register_with_store(
user_id=user_id,
password_hash=password_hash,
appservice_id=service_id,
create_profile_with_displayname=display_name,
)
- yield defer.ensureDeferred(
- self.profile_handler.set_displayname(
- user, None, display_name, by_admin=True
- )
+ await self.profile_handler.set_displayname(
+ user, None, display_name, by_admin=True
)
if self.hs.config.user_directory_search_all_users:
- profile = yield self.store.get_profileinfo(user_localpart)
- yield self.user_directory_handler.handle_local_profile_change(
+ profile = await self.store.get_profileinfo(user_localpart)
+ await self.user_directory_handler.handle_local_profile_change(
user_id, profile
)
@@ -431,8 +428,7 @@ class RegistrationHandler(BaseHandler):
errcode=Codes.EXCLUSIVE,
)
- @defer.inlineCallbacks
- def shadow_register(self, localpart, display_name, auth_result, params):
+ async def shadow_register(self, localpart, display_name, auth_result, params):
"""Invokes the current registration on another server, using
shared secret registration, passing in any auth_results from
other registration UI auth flows (e.g. validated 3pids)
@@ -443,7 +439,7 @@ class RegistrationHandler(BaseHandler):
shadow_hs_url = self.hs.config.shadow_server.get("hs_url")
as_token = self.hs.config.shadow_server.get("as_token")
- yield self.http_client.post_json_get_json(
+ await self.http_client.post_json_get_json(
"%s/_matrix/client/r0/register?access_token=%s" % (shadow_hs_url, as_token),
{
# XXX: auth_result is an unspecified extension for shadow registration
@@ -463,13 +459,12 @@ class RegistrationHandler(BaseHandler):
},
)
- @defer.inlineCallbacks
- def _generate_user_id(self):
+ async def _generate_user_id(self):
if self._next_generated_user_id is None:
- with (yield self._generate_user_id_linearizer.queue(())):
+ with await self._generate_user_id_linearizer.queue(()):
if self._next_generated_user_id is None:
self._next_generated_user_id = (
- yield self.store.find_next_generated_user_id_localpart()
+ await self.store.find_next_generated_user_id_localpart()
)
id = self._next_generated_user_id
@@ -514,14 +509,7 @@ class RegistrationHandler(BaseHandler):
if not address:
return
- time_now = self.clock.time()
-
- self.ratelimiter.ratelimit(
- address,
- time_now_s=time_now,
- rate_hz=self.hs.config.rc_registration.per_second,
- burst_count=self.hs.config.rc_registration.burst_count,
- )
+ self.ratelimiter.ratelimit(address)
def register_with_store(
self,
@@ -579,8 +567,9 @@ class RegistrationHandler(BaseHandler):
user_type=user_type,
)
- @defer.inlineCallbacks
- def register_device(self, user_id, device_id, initial_display_name, is_guest=False):
+ async def register_device(
+ self, user_id, device_id, initial_display_name, is_guest=False
+ ):
"""Register a device for a user and generate an access token.
The access token will be limited by the homeserver's session_lifetime config.
@@ -594,11 +583,11 @@ class RegistrationHandler(BaseHandler):
is_guest (bool): Whether this is a guest account
Returns:
- defer.Deferred[tuple[str, str]]: Tuple of device ID and access token
+ tuple[str, str]: Tuple of device ID and access token
"""
if self.hs.config.worker_app:
- r = yield self._register_device_client(
+ r = await self._register_device_client(
user_id=user_id,
device_id=device_id,
initial_display_name=initial_display_name,
@@ -614,7 +603,7 @@ class RegistrationHandler(BaseHandler):
)
valid_until_ms = self.clock.time_msec() + self.session_lifetime
- device_id = yield self.device_handler.check_device_registered(
+ device_id = await self.device_handler.check_device_registered(
user_id, device_id, initial_display_name
)
if is_guest:
@@ -623,10 +612,8 @@ class RegistrationHandler(BaseHandler):
user_id, ["guest = true"]
)
else:
- access_token = yield defer.ensureDeferred(
- self._auth_handler.get_access_token_for_user_id(
- user_id, device_id=device_id, valid_until_ms=valid_until_ms
- )
+ access_token = await self._auth_handler.get_access_token_for_user_id(
+ user_id, device_id=device_id, valid_until_ms=valid_until_ms
)
return (device_id, access_token)
@@ -706,8 +693,7 @@ class RegistrationHandler(BaseHandler):
await self.store.user_set_consent_version(user_id, consent_version)
await self.post_consent_actions(user_id)
- @defer.inlineCallbacks
- def register_email_threepid(self, user_id, threepid, token):
+ async def register_email_threepid(self, user_id, threepid, token):
"""Add an email address as a 3pid identifier
Also adds an email pusher for the email address, if configured in the
@@ -720,8 +706,6 @@ class RegistrationHandler(BaseHandler):
threepid (object): m.login.email.identity auth response
token (str|None): access_token for the user, or None if not logged
in.
- Returns:
- defer.Deferred:
"""
reqd = ("medium", "address", "validated_at")
if any(x not in threepid for x in reqd):
@@ -729,13 +713,8 @@ class RegistrationHandler(BaseHandler):
logger.info("Can't add incomplete 3pid")
return
- yield defer.ensureDeferred(
- self._auth_handler.add_threepid(
- user_id,
- threepid["medium"],
- threepid["address"],
- threepid["validated_at"],
- )
+ await self._auth_handler.add_threepid(
+ user_id, threepid["medium"], threepid["address"], threepid["validated_at"],
)
# And we add an email pusher for them by default, but only
@@ -751,10 +730,10 @@ class RegistrationHandler(BaseHandler):
# It would really make more sense for this to be passed
# up when the access token is saved, but that's quite an
# invasive change I'd rather do separately.
- user_tuple = yield self.store.get_user_by_access_token(token)
+ user_tuple = await self.store.get_user_by_access_token(token)
token_id = user_tuple["token_id"]
- yield self.pusher_pool.add_pusher(
+ await self.pusher_pool.add_pusher(
user_id=user_id,
access_token=token_id,
kind="email",
@@ -766,8 +745,7 @@ class RegistrationHandler(BaseHandler):
data={},
)
- @defer.inlineCallbacks
- def _register_msisdn_threepid(self, user_id, threepid):
+ async def _register_msisdn_threepid(self, user_id, threepid):
"""Add a phone number as a 3pid identifier
Must be called on master.
@@ -775,8 +753,6 @@ class RegistrationHandler(BaseHandler):
Args:
user_id (str): id of user
threepid (object): m.login.msisdn auth response
- Returns:
- defer.Deferred:
"""
try:
assert_params_in_dict(threepid, ["medium", "address", "validated_at"])
@@ -787,11 +763,6 @@ class RegistrationHandler(BaseHandler):
return None
raise
- yield defer.ensureDeferred(
- self._auth_handler.add_threepid(
- user_id,
- threepid["medium"],
- threepid["address"],
- threepid["validated_at"],
- )
+ await self._auth_handler.add_threepid(
+ user_id, threepid["medium"], threepid["address"], threepid["validated_at"],
)
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index e75dabcd77..4cbc02b0d0 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -253,10 +253,21 @@ class RoomListHandler(BaseHandler):
"""
result = {"room_id": room_id, "num_joined_members": num_joined_users}
+ if with_alias:
+ aliases = yield self.store.get_aliases_for_room(
+ room_id, on_invalidate=cache_context.invalidate
+ )
+ if aliases:
+ result["aliases"] = aliases
+
current_state_ids = yield self.store.get_current_state_ids(
room_id, on_invalidate=cache_context.invalidate
)
+ if not current_state_ids:
+ # We're not in the room, so may as well bail out here.
+ return result
+
event_map = yield self.store.get_events(
[
event_id
@@ -289,14 +300,7 @@ class RoomListHandler(BaseHandler):
create_event = current_state.get((EventTypes.Create, ""))
result["m.federate"] = create_event.content.get("m.federate", True)
- if with_alias:
- aliases = yield self.store.get_aliases_for_room(
- room_id, on_invalidate=cache_context.invalidate
- )
- if aliases:
- result["aliases"] = aliases
-
- name_event = yield current_state.get((EventTypes.Name, ""))
+ name_event = current_state.get((EventTypes.Name, ""))
if name_event:
name = name_event.content.get("name", None)
if name:
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index e7015c704f..abecaa8313 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -23,11 +23,9 @@ from saml2.client import Saml2Client
from synapse.api.errors import SynapseError
from synapse.config import ConfigError
-from synapse.http.server import finish_request
from synapse.http.servlet import parse_string
from synapse.http.site import SynapseRequest
from synapse.module_api import ModuleApi
-from synapse.module_api.errors import RedirectException
from synapse.types import (
UserID,
map_username_to_mxid_localpart,
@@ -80,8 +78,6 @@ class SamlHandler:
# a lock on the mappings
self._mapping_lock = Linearizer(name="saml_mapping", clock=self._clock)
- self._error_html_content = hs.config.saml2_error_html_content
-
def handle_redirect_request(
self, client_redirect_url: bytes, ui_auth_session_id: Optional[str] = None
) -> bytes:
@@ -129,26 +125,9 @@ class SamlHandler:
# the dict.
self.expire_sessions()
- try:
- user_id, current_session = await self._map_saml_response_to_user(
- resp_bytes, relay_state
- )
- except RedirectException:
- # Raise the exception as per the wishes of the SAML module response
- raise
- except Exception as e:
- # If decoding the response or mapping it to a user failed, then log the
- # error and tell the user that something went wrong.
- logger.error(e)
-
- request.setResponseCode(400)
- request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
- request.setHeader(
- b"Content-Length", b"%d" % (len(self._error_html_content),)
- )
- request.write(self._error_html_content.encode("utf8"))
- finish_request(request)
- return
+ user_id, current_session = await self._map_saml_response_to_user(
+ resp_bytes, relay_state
+ )
# Complete the interactive auth session or the login.
if current_session and current_session.ui_auth_session_id:
@@ -171,6 +150,11 @@ class SamlHandler:
Returns:
Tuple of the user ID and SAML session associated with this response.
+
+ Raises:
+ SynapseError if there was a problem with the response.
+ RedirectException: some mapping providers may raise this if they need
+ to redirect to an interstitial page.
"""
try:
saml2_auth = self._saml_client.parse_authn_request_response(
@@ -179,11 +163,9 @@ class SamlHandler:
outstanding=self._outstanding_requests_dict,
)
except Exception as e:
- logger.warning("Exception parsing SAML2 response: %s", e)
raise SynapseError(400, "Unable to parse SAML2 response: %s" % (e,))
if saml2_auth.not_signed:
- logger.warning("SAML2 response was not signed")
raise SynapseError(400, "SAML2 response was not signed")
logger.debug("SAML2 response: %s", saml2_auth.origxml)
@@ -264,13 +246,13 @@ class SamlHandler:
localpart = attribute_dict.get("mxid_localpart")
if not localpart:
- logger.error(
- "SAML mapping provider plugin did not return a "
- "mxid_localpart object"
+ raise Exception(
+ "Error parsing SAML2 response: SAML mapping provider plugin "
+ "did not return a mxid_localpart value"
)
- raise SynapseError(500, "Error parsing SAML2 response")
displayname = attribute_dict.get("displayname")
+ emails = attribute_dict.get("emails", [])
# Check if this mxid already exists
if not await self._datastore.get_users_by_id_case_insensitive(
@@ -288,7 +270,9 @@ class SamlHandler:
logger.info("Mapped SAML user to local part %s", localpart)
registered_user_id = await self._registration_handler.register_user(
- localpart=localpart, default_display_name=displayname
+ localpart=localpart,
+ default_display_name=displayname,
+ bind_emails=emails,
)
await self._datastore.record_user_external_id(
@@ -381,6 +365,7 @@ class DefaultSamlMappingProvider(object):
dict: A dict containing new user attributes. Possible keys:
* mxid_localpart (str): Required. The localpart of the user's mxid
* displayname (str): The displayname of the user
+ * emails (list[str]): Any emails for the user
"""
try:
mxid_source = saml_response.ava[self._mxid_source_attribute][0]
@@ -403,9 +388,13 @@ class DefaultSamlMappingProvider(object):
# If displayname is None, the mxid_localpart will be used instead
displayname = saml_response.ava.get("displayName", [None])[0]
+ # Retrieve any emails present in the saml response
+ emails = saml_response.ava.get("email", [])
+
return {
"mxid_localpart": localpart,
"displayname": displayname,
+ "emails": emails,
}
@staticmethod
@@ -444,4 +433,4 @@ class DefaultSamlMappingProvider(object):
second set consists of those attributes which can be used if
available, but are not necessary
"""
- return {"uid", config.mxid_source_attribute}, {"displayName"}
+ return {"uid", config.mxid_source_attribute}, {"displayName", "email"}
diff --git a/synapse/handlers/state_deltas.py b/synapse/handlers/state_deltas.py
index f065970c40..8590c1eff4 100644
--- a/synapse/handlers/state_deltas.py
+++ b/synapse/handlers/state_deltas.py
@@ -15,8 +15,6 @@
import logging
-from twisted.internet import defer
-
logger = logging.getLogger(__name__)
@@ -24,8 +22,7 @@ class StateDeltasHandler(object):
def __init__(self, hs):
self.store = hs.get_datastore()
- @defer.inlineCallbacks
- def _get_key_change(self, prev_event_id, event_id, key_name, public_value):
+ async def _get_key_change(self, prev_event_id, event_id, key_name, public_value):
"""Given two events check if the `key_name` field in content changed
from not matching `public_value` to doing so.
@@ -41,10 +38,10 @@ class StateDeltasHandler(object):
prev_event = None
event = None
if prev_event_id:
- prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
+ prev_event = await self.store.get_event(prev_event_id, allow_none=True)
if event_id:
- event = yield self.store.get_event(event_id, allow_none=True)
+ event = await self.store.get_event(event_id, allow_none=True)
if not event and not prev_event:
logger.debug("Neither event exists: %r %r", prev_event_id, event_id)
diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py
index d93a276693..149f861239 100644
--- a/synapse/handlers/stats.py
+++ b/synapse/handlers/stats.py
@@ -16,17 +16,14 @@
import logging
from collections import Counter
-from twisted.internet import defer
-
from synapse.api.constants import EventTypes, Membership
-from synapse.handlers.state_deltas import StateDeltasHandler
from synapse.metrics import event_processing_positions
from synapse.metrics.background_process_metrics import run_as_background_process
logger = logging.getLogger(__name__)
-class StatsHandler(StateDeltasHandler):
+class StatsHandler:
"""Handles keeping the *_stats tables updated with a simple time-series of
information about the users, rooms and media on the server, such that admins
have some idea of who is consuming their resources.
@@ -35,7 +32,6 @@ class StatsHandler(StateDeltasHandler):
"""
def __init__(self, hs):
- super(StatsHandler, self).__init__(hs)
self.hs = hs
self.store = hs.get_datastore()
self.state = hs.get_state_handler()
@@ -68,20 +64,18 @@ class StatsHandler(StateDeltasHandler):
self._is_processing = True
- @defer.inlineCallbacks
- def process():
+ async def process():
try:
- yield self._unsafe_process()
+ await self._unsafe_process()
finally:
self._is_processing = False
run_as_background_process("stats.notify_new_event", process)
- @defer.inlineCallbacks
- def _unsafe_process(self):
+ async def _unsafe_process(self):
# If self.pos is None then means we haven't fetched it from DB
if self.pos is None:
- self.pos = yield self.store.get_stats_positions()
+ self.pos = await self.store.get_stats_positions()
# Loop round handling deltas until we're up to date
@@ -96,13 +90,13 @@ class StatsHandler(StateDeltasHandler):
logger.debug(
"Processing room stats %s->%s", self.pos, room_max_stream_ordering
)
- max_pos, deltas = yield self.store.get_current_state_deltas(
+ max_pos, deltas = await self.store.get_current_state_deltas(
self.pos, room_max_stream_ordering
)
if deltas:
logger.debug("Handling %d state deltas", len(deltas))
- room_deltas, user_deltas = yield self._handle_deltas(deltas)
+ room_deltas, user_deltas = await self._handle_deltas(deltas)
else:
room_deltas = {}
user_deltas = {}
@@ -111,7 +105,7 @@ class StatsHandler(StateDeltasHandler):
(
room_count,
user_count,
- ) = yield self.store.get_changes_room_total_events_and_bytes(
+ ) = await self.store.get_changes_room_total_events_and_bytes(
self.pos, max_pos
)
@@ -125,7 +119,7 @@ class StatsHandler(StateDeltasHandler):
logger.debug("user_deltas: %s", user_deltas)
# Always call this so that we update the stats position.
- yield self.store.bulk_update_stats_delta(
+ await self.store.bulk_update_stats_delta(
self.clock.time_msec(),
updates={"room": room_deltas, "user": user_deltas},
stream_id=max_pos,
@@ -137,13 +131,12 @@ class StatsHandler(StateDeltasHandler):
self.pos = max_pos
- @defer.inlineCallbacks
- def _handle_deltas(self, deltas):
+ async def _handle_deltas(self, deltas):
"""Called with the state deltas to process
Returns:
- Deferred[tuple[dict[str, Counter], dict[str, counter]]]
- Resovles to two dicts, the room deltas and the user deltas,
+ tuple[dict[str, Counter], dict[str, Counter]]
+ Two dicts: the room deltas and the user deltas,
mapping from room/user ID to changes in the various fields.
"""
@@ -162,7 +155,7 @@ class StatsHandler(StateDeltasHandler):
logger.debug("Handling: %r, %r %r, %s", room_id, typ, state_key, event_id)
- token = yield self.store.get_earliest_token_for_stats("room", room_id)
+ token = await self.store.get_earliest_token_for_stats("room", room_id)
# If the earliest token to begin from is larger than our current
# stream ID, skip processing this delta.
@@ -184,7 +177,7 @@ class StatsHandler(StateDeltasHandler):
sender = None
if event_id is not None:
- event = yield self.store.get_event(event_id, allow_none=True)
+ event = await self.store.get_event(event_id, allow_none=True)
if event:
event_content = event.content or {}
sender = event.sender
@@ -200,16 +193,16 @@ class StatsHandler(StateDeltasHandler):
room_stats_delta["current_state_events"] += 1
if typ == EventTypes.Member:
- # we could use _get_key_change here but it's a bit inefficient
- # given we're not testing for a specific result; might as well
- # just grab the prev_membership and membership strings and
- # compare them.
+ # we could use StateDeltasHandler._get_key_change here but it's
+ # a bit inefficient given we're not testing for a specific
+ # result; might as well just grab the prev_membership and
+ # membership strings and compare them.
# We take None rather than leave as a previous membership
# in the absence of a previous event because we do not want to
# reduce the leave count when a new-to-the-room user joins.
prev_membership = None
if prev_event_id is not None:
- prev_event = yield self.store.get_event(
+ prev_event = await self.store.get_event(
prev_event_id, allow_none=True
)
if prev_event:
@@ -301,6 +294,6 @@ class StatsHandler(StateDeltasHandler):
for room_id, state in room_to_state_updates.items():
logger.debug("Updating room_stats_state for %s: %s", room_id, state)
- yield self.store.update_room_state(room_id, state)
+ await self.store.update_room_state(room_id, state)
return room_to_stats_deltas, user_to_stats_deltas
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 00718d7f2d..6bdb24baff 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -1370,7 +1370,7 @@ class SyncHandler(object):
sync_result_builder.now_token = now_token
# We check up front if anything has changed, if it hasn't then there is
- # no point in going futher.
+ # no point in going further.
since_token = sync_result_builder.since_token
if not sync_result_builder.full_state:
if since_token and not ephemeral_by_room and not account_data_by_room:
diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py
index 8363d887a9..8b24a73319 100644
--- a/synapse/handlers/ui_auth/checkers.py
+++ b/synapse/handlers/ui_auth/checkers.py
@@ -138,8 +138,7 @@ class _BaseThreepidAuthChecker:
self.hs = hs
self.store = hs.get_datastore()
- @defer.inlineCallbacks
- def _check_threepid(self, medium, authdict):
+ async def _check_threepid(self, medium, authdict):
if "threepid_creds" not in authdict:
raise LoginError(400, "Missing threepid_creds", Codes.MISSING_PARAM)
@@ -155,18 +154,18 @@ class _BaseThreepidAuthChecker:
raise SynapseError(
400, "Phone number verification is not enabled on this homeserver"
)
- threepid = yield identity_handler.threepid_from_creds(
+ threepid = await identity_handler.threepid_from_creds(
self.hs.config.account_threepid_delegate_msisdn, threepid_creds
)
elif medium == "email":
if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
assert self.hs.config.account_threepid_delegate_email
- threepid = yield identity_handler.threepid_from_creds(
+ threepid = await identity_handler.threepid_from_creds(
self.hs.config.account_threepid_delegate_email, threepid_creds
)
elif self.hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
threepid = None
- row = yield self.store.get_threepid_validation_session(
+ row = await self.store.get_threepid_validation_session(
medium,
threepid_creds["client_secret"],
sid=threepid_creds["sid"],
@@ -181,7 +180,7 @@ class _BaseThreepidAuthChecker:
}
# Valid threepid returned, delete from the db
- yield self.store.delete_threepid_session(threepid_creds["sid"])
+ await self.store.delete_threepid_session(threepid_creds["sid"])
else:
raise SynapseError(
400, "Email address verification is not enabled on this homeserver"
@@ -220,7 +219,7 @@ class EmailIdentityAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChec
)
def check_auth(self, authdict, clientip):
- return self._check_threepid("email", authdict)
+ return defer.ensureDeferred(self._check_threepid("email", authdict))
class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker):
@@ -234,7 +233,7 @@ class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker):
return bool(self.hs.config.account_threepid_delegate_msisdn)
def check_auth(self, authdict, clientip):
- return self._check_threepid("msisdn", authdict)
+ return defer.ensureDeferred(self._check_threepid("msisdn", authdict))
INTERACTIVE_AUTH_CHECKERS = [
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index 722760c59d..12423b909a 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -17,14 +17,11 @@ import logging
from six import iteritems, iterkeys
-from twisted.internet import defer
-
import synapse.metrics
from synapse.api.constants import EventTypes, JoinRules, Membership
from synapse.handlers.state_deltas import StateDeltasHandler
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.roommember import ProfileInfo
-from synapse.types import get_localpart_from_id
from synapse.util.metrics import Measure
logger = logging.getLogger(__name__)
@@ -103,43 +100,39 @@ class UserDirectoryHandler(StateDeltasHandler):
if self._is_processing:
return
- @defer.inlineCallbacks
- def process():
+ async def process():
try:
- yield self._unsafe_process()
+ await self._unsafe_process()
finally:
self._is_processing = False
self._is_processing = True
run_as_background_process("user_directory.notify_new_event", process)
- @defer.inlineCallbacks
- def handle_local_profile_change(self, user_id, profile):
+ async def handle_local_profile_change(self, user_id, profile):
"""Called to update index of our local user profiles when they change
irrespective of any rooms the user may be in.
"""
# FIXME(#3714): We should probably do this in the same worker as all
# the other changes.
- is_support = yield self.store.is_support_user(user_id)
+ is_support = await self.store.is_support_user(user_id)
# Support users are for diagnostics and should not appear in the user directory.
if not is_support:
- yield self.store.update_profile_in_user_dir(
+ await self.store.update_profile_in_user_dir(
user_id, profile.display_name, profile.avatar_url
)
- @defer.inlineCallbacks
- def handle_user_deactivated(self, user_id):
+ async def handle_user_deactivated(self, user_id):
"""Called when a user ID is deactivated
"""
# FIXME(#3714): We should probably do this in the same worker as all
# the other changes.
- yield self.store.remove_from_user_dir(user_id)
+ await self.store.remove_from_user_dir(user_id)
- @defer.inlineCallbacks
- def _unsafe_process(self):
+ async def _unsafe_process(self):
# If self.pos is None then means we haven't fetched it from DB
if self.pos is None:
- self.pos = yield self.store.get_user_directory_stream_pos()
+ self.pos = await self.store.get_user_directory_stream_pos()
# If still None then the initial background update hasn't happened yet
if self.pos is None:
@@ -155,12 +148,12 @@ class UserDirectoryHandler(StateDeltasHandler):
logger.debug(
"Processing user stats %s->%s", self.pos, room_max_stream_ordering
)
- max_pos, deltas = yield self.store.get_current_state_deltas(
+ max_pos, deltas = await self.store.get_current_state_deltas(
self.pos, room_max_stream_ordering
)
logger.debug("Handling %d state deltas", len(deltas))
- yield self._handle_deltas(deltas)
+ await self._handle_deltas(deltas)
self.pos = max_pos
@@ -169,10 +162,9 @@ class UserDirectoryHandler(StateDeltasHandler):
max_pos
)
- yield self.store.update_user_directory_stream_pos(max_pos)
+ await self.store.update_user_directory_stream_pos(max_pos)
- @defer.inlineCallbacks
- def _handle_deltas(self, deltas):
+ async def _handle_deltas(self, deltas):
"""Called with the state deltas to process
"""
for delta in deltas:
@@ -187,11 +179,11 @@ class UserDirectoryHandler(StateDeltasHandler):
# For join rule and visibility changes we need to check if the room
# may have become public or not and add/remove the users in said room
if typ in (EventTypes.RoomHistoryVisibility, EventTypes.JoinRules):
- yield self._handle_room_publicity_change(
+ await self._handle_room_publicity_change(
room_id, prev_event_id, event_id, typ
)
elif typ == EventTypes.Member:
- change = yield self._get_key_change(
+ change = await self._get_key_change(
prev_event_id,
event_id,
key_name="membership",
@@ -201,7 +193,7 @@ class UserDirectoryHandler(StateDeltasHandler):
if change is False:
# Need to check if the server left the room entirely, if so
# we might need to remove all the users in that room
- is_in_room = yield self.store.is_host_joined(
+ is_in_room = await self.store.is_host_joined(
room_id, self.server_name
)
if not is_in_room:
@@ -209,40 +201,41 @@ class UserDirectoryHandler(StateDeltasHandler):
# Fetch all the users that we marked as being in user
# directory due to being in the room and then check if
# need to remove those users or not
- user_ids = yield self.store.get_users_in_dir_due_to_room(
+ user_ids = await self.store.get_users_in_dir_due_to_room(
room_id
)
for user_id in user_ids:
- yield self._handle_remove_user(room_id, user_id)
+ await self._handle_remove_user(room_id, user_id)
return
else:
logger.debug("Server is still in room: %r", room_id)
- is_support = yield self.store.is_support_user(state_key)
+ is_support = await self.store.is_support_user(state_key)
if not is_support:
if change is None:
# Handle any profile changes
- yield self._handle_profile_change(
+ await self._handle_profile_change(
state_key, room_id, prev_event_id, event_id
)
continue
if change: # The user joined
- event = yield self.store.get_event(event_id, allow_none=True)
+ event = await self.store.get_event(event_id, allow_none=True)
profile = ProfileInfo(
avatar_url=event.content.get("avatar_url"),
display_name=event.content.get("displayname"),
)
- yield self._handle_new_user(room_id, state_key, profile)
+ await self._handle_new_user(room_id, state_key, profile)
else: # The user left
- yield self._handle_remove_user(room_id, state_key)
+ await self._handle_remove_user(room_id, state_key)
else:
logger.debug("Ignoring irrelevant type: %r", typ)
- @defer.inlineCallbacks
- def _handle_room_publicity_change(self, room_id, prev_event_id, event_id, typ):
+ async def _handle_room_publicity_change(
+ self, room_id, prev_event_id, event_id, typ
+ ):
"""Handle a room having potentially changed from/to world_readable/publically
joinable.
@@ -255,14 +248,14 @@ class UserDirectoryHandler(StateDeltasHandler):
logger.debug("Handling change for %s: %s", typ, room_id)
if typ == EventTypes.RoomHistoryVisibility:
- change = yield self._get_key_change(
+ change = await self._get_key_change(
prev_event_id,
event_id,
key_name="history_visibility",
public_value="world_readable",
)
elif typ == EventTypes.JoinRules:
- change = yield self._get_key_change(
+ change = await self._get_key_change(
prev_event_id,
event_id,
key_name="join_rule",
@@ -278,7 +271,7 @@ class UserDirectoryHandler(StateDeltasHandler):
# There's been a change to or from being world readable.
- is_public = yield self.store.is_room_world_readable_or_publicly_joinable(
+ is_public = await self.store.is_room_world_readable_or_publicly_joinable(
room_id
)
@@ -293,11 +286,11 @@ class UserDirectoryHandler(StateDeltasHandler):
# ignore the change
return
- users_with_profile = yield self.state.get_current_users_in_room(room_id)
+ users_with_profile = await self.state.get_current_users_in_room(room_id)
# Remove every user from the sharing tables for that room.
for user_id in iterkeys(users_with_profile):
- yield self.store.remove_user_who_share_room(user_id, room_id)
+ await self.store.remove_user_who_share_room(user_id, room_id)
# Then, re-add them to the tables.
# NOTE: this is not the most efficient method, as handle_new_user sets
@@ -306,26 +299,9 @@ class UserDirectoryHandler(StateDeltasHandler):
# being added multiple times. The batching upserts shouldn't make this
# too bad, though.
for user_id, profile in iteritems(users_with_profile):
- yield self._handle_new_user(room_id, user_id, profile)
-
- @defer.inlineCallbacks
- def _handle_local_user(self, user_id):
- """Adds a new local roomless user into the user_directory_search table.
- Used to populate up the user index when we have an
- user_directory_search_all_users specified.
- """
- logger.debug("Adding new local user to dir, %r", user_id)
-
- profile = yield self.store.get_profileinfo(get_localpart_from_id(user_id))
-
- row = yield self.store.get_user_in_directory(user_id)
- if not row:
- yield self.store.update_profile_in_user_dir(
- user_id, profile.display_name, profile.avatar_url
- )
+ await self._handle_new_user(room_id, user_id, profile)
- @defer.inlineCallbacks
- def _handle_new_user(self, room_id, user_id, profile):
+ async def _handle_new_user(self, room_id, user_id, profile):
"""Called when we might need to add user to directory
Args:
@@ -334,18 +310,18 @@ class UserDirectoryHandler(StateDeltasHandler):
"""
logger.debug("Adding new user to dir, %r", user_id)
- yield self.store.update_profile_in_user_dir(
+ await self.store.update_profile_in_user_dir(
user_id, profile.display_name, profile.avatar_url
)
- is_public = yield self.store.is_room_world_readable_or_publicly_joinable(
+ is_public = await self.store.is_room_world_readable_or_publicly_joinable(
room_id
)
# Now we update users who share rooms with users.
- users_with_profile = yield self.state.get_current_users_in_room(room_id)
+ users_with_profile = await self.state.get_current_users_in_room(room_id)
if is_public:
- yield self.store.add_users_in_public_rooms(room_id, (user_id,))
+ await self.store.add_users_in_public_rooms(room_id, (user_id,))
else:
to_insert = set()
@@ -376,10 +352,9 @@ class UserDirectoryHandler(StateDeltasHandler):
to_insert.add((other_user_id, user_id))
if to_insert:
- yield self.store.add_users_who_share_private_room(room_id, to_insert)
+ await self.store.add_users_who_share_private_room(room_id, to_insert)
- @defer.inlineCallbacks
- def _handle_remove_user(self, room_id, user_id):
+ async def _handle_remove_user(self, room_id, user_id):
"""Called when we might need to remove user from directory
Args:
@@ -389,24 +364,23 @@ class UserDirectoryHandler(StateDeltasHandler):
logger.debug("Removing user %r", user_id)
# Remove user from sharing tables
- yield self.store.remove_user_who_share_room(user_id, room_id)
+ await self.store.remove_user_who_share_room(user_id, room_id)
# Are they still in any rooms? If not, remove them entirely.
- rooms_user_is_in = yield self.store.get_user_dir_rooms_user_is_in(user_id)
+ rooms_user_is_in = await self.store.get_user_dir_rooms_user_is_in(user_id)
if len(rooms_user_is_in) == 0:
- yield self.store.remove_from_user_dir(user_id)
+ await self.store.remove_from_user_dir(user_id)
- @defer.inlineCallbacks
- def _handle_profile_change(self, user_id, room_id, prev_event_id, event_id):
+ async def _handle_profile_change(self, user_id, room_id, prev_event_id, event_id):
"""Check member event changes for any profile changes and update the
database if there are.
"""
if not prev_event_id or not event_id:
return
- prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
- event = yield self.store.get_event(event_id, allow_none=True)
+ prev_event = await self.store.get_event(prev_event_id, allow_none=True)
+ event = await self.store.get_event(event_id, allow_none=True)
if not prev_event or not event:
return
@@ -421,4 +395,4 @@ class UserDirectoryHandler(StateDeltasHandler):
new_avatar = event.content.get("avatar_url")
if prev_name != new_name or prev_avatar != new_avatar:
- yield self.store.update_profile_in_user_dir(user_id, new_name, new_avatar)
+ await self.store.update_profile_in_user_dir(user_id, new_name, new_avatar)
|