diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py
index 664d09da1c..ce97fa70d7 100644
--- a/synapse/handlers/account_validity.py
+++ b/synapse/handlers/account_validity.py
@@ -18,11 +18,14 @@ import email.utils
import logging
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING, List, Optional, Tuple
from synapse.api.errors import StoreError, SynapseError
from synapse.logging.context import make_deferred_yieldable
-from synapse.metrics.background_process_metrics import wrap_as_background_process
+from synapse.metrics.background_process_metrics import (
+ run_as_background_process,
+ wrap_as_background_process,
+)
from synapse.types import UserID
from synapse.util import stringutils
@@ -40,27 +43,37 @@ class AccountValidityHandler:
self.sendmail = self.hs.get_sendmail()
self.clock = self.hs.get_clock()
- self._account_validity = self.hs.config.account_validity
+ self._account_validity_enabled = self.hs.config.account_validity_enabled
+ self._account_validity_renew_by_email_enabled = (
+ self.hs.config.account_validity_renew_by_email_enabled
+ )
+ self._show_users_in_user_directory = self.hs.config.show_users_in_user_directory
+ self.profile_handler = self.hs.get_profile_handler()
+
+ self._account_validity_period = None
+ if self._account_validity_enabled:
+ self._account_validity_period = self.hs.config.account_validity_period
if (
- self._account_validity.enabled
- and self._account_validity.renew_by_email_enabled
+ self._account_validity_enabled
+ and self._account_validity_renew_by_email_enabled
):
# Don't do email-specific configuration if renewal by email is disabled.
self._template_html = self.config.account_validity_template_html
self._template_text = self.config.account_validity_template_text
+ account_validity_renew_email_subject = (
+ self.hs.config.account_validity_renew_email_subject
+ )
try:
app_name = self.hs.config.email_app_name
- self._subject = self._account_validity.renew_email_subject % {
- "app": app_name
- }
+ self._subject = account_validity_renew_email_subject % {"app": app_name}
self._from_string = self.hs.config.email_notif_from % {"app": app_name}
except Exception:
# If substitution failed, fall back to the bare strings.
- self._subject = self._account_validity.renew_email_subject
+ self._subject = account_validity_renew_email_subject
self._from_string = self.hs.config.email_notif_from
self._raw_from = email.utils.parseaddr(self._from_string)[1]
@@ -69,6 +82,18 @@ class AccountValidityHandler:
if hs.config.run_background_tasks:
self.clock.looping_call(self._send_renewal_emails, 30 * 60 * 1000)
+ # Mark users as inactive when they expired. Check once every hour
+ if self._account_validity_enabled:
+
+ def mark_expired_users_as_inactive():
+ # run as a background process to allow async functions to work
+ return run_as_background_process(
+ "_mark_expired_users_as_inactive",
+ self._mark_expired_users_as_inactive,
+ )
+
+ self.clock.looping_call(mark_expired_users_as_inactive, 60 * 60 * 1000)
+
@wrap_as_background_process("send_renewals")
async def _send_renewal_emails(self) -> None:
"""Gets the list of users whose account is expiring in the amount of time
@@ -221,47 +246,107 @@ class AccountValidityHandler:
attempts += 1
raise StoreError(500, "Couldn't generate a unique string as refresh string.")
- async def renew_account(self, renewal_token: str) -> bool:
+ async def renew_account(self, renewal_token: str) -> Tuple[bool, bool, int]:
"""Renews the account attached to a given renewal token by pushing back the
expiration date by the current validity period in the server's configuration.
+ If it turns out that the token is valid but has already been used, then the
+ token is considered stale. A token is stale if the 'token_used_ts_ms' db column
+ is non-null.
+
Args:
renewal_token: Token sent with the renewal request.
Returns:
- Whether the provided token is valid.
+ A tuple containing:
+ * A bool representing whether the token is valid and unused.
+ * A bool representing whether the token is stale.
+ * An int representing the user's expiry timestamp as milliseconds since the
+ epoch, or 0 if the token was invalid.
"""
try:
- user_id = await self.store.get_user_from_renewal_token(renewal_token)
+ (
+ user_id,
+ current_expiration_ts,
+ token_used_ts,
+ ) = await self.store.get_user_from_renewal_token(renewal_token)
except StoreError:
- return False
+ return False, False, 0
+
+ # Check whether this token has already been used.
+ if token_used_ts:
+ logger.info(
+ "User '%s' attempted to use previously used token '%s' to renew account",
+ user_id,
+ renewal_token,
+ )
+ return False, True, current_expiration_ts
logger.debug("Renewing an account for user %s", user_id)
- await self.renew_account_for_user(user_id)
- return True
+ # Renew the account. Pass the renewal_token here so that it is not cleared.
+ # We want to keep the token around in case the user attempts to renew their
+ # account with the same token twice (clicking the email link twice).
+ #
+ # In that case, the token will be accepted, but the account's expiration ts
+ # will remain unchanged.
+ new_expiration_ts = await self.renew_account_for_user(
+ user_id, renewal_token=renewal_token
+ )
+
+ return True, False, new_expiration_ts
async def renew_account_for_user(
- self, user_id: str, expiration_ts: int = None, email_sent: bool = False
+ self,
+ user_id: str,
+ expiration_ts: Optional[int] = None,
+ email_sent: bool = False,
+ renewal_token: Optional[str] = None,
) -> int:
"""Renews the account attached to a given user by pushing back the
expiration date by the current validity period in the server's
configuration.
Args:
- renewal_token: Token sent with the renewal request.
+ user_id: The ID of the user to renew.
expiration_ts: New expiration date. Defaults to now + validity period.
- email_sen: Whether an email has been sent for this validity period.
- Defaults to False.
+ email_sent: Whether an email has been sent for this validity period.
+ renewal_token: Token sent with the renewal request. The user's token
+ will be cleared if this is None.
Returns:
New expiration date for this account, as a timestamp in
milliseconds since epoch.
"""
+ now = self.clock.time_msec()
if expiration_ts is None:
- expiration_ts = self.clock.time_msec() + self._account_validity.period
+ expiration_ts = now + self._account_validity_period
await self.store.set_account_validity_for_user(
- user_id=user_id, expiration_ts=expiration_ts, email_sent=email_sent
+ user_id=user_id,
+ expiration_ts=expiration_ts,
+ email_sent=email_sent,
+ renewal_token=renewal_token,
+ token_used_ts=now,
)
+ # Check if renewed users should be reintroduced to the user directory
+ if self._show_users_in_user_directory:
+ # Show the user in the directory again by setting them to active
+ await self.profile_handler.set_active(
+ [UserID.from_string(user_id)], True, True
+ )
+
return expiration_ts
+
+ async def _mark_expired_users_as_inactive(self):
+ """Iterate over active, expired users. Mark them as inactive in order to hide them
+ from the user directory.
+
+        Returns:
+            None
+ """
+ # Get active, expired users
+ active_expired_users = await self.store.get_expired_users()
+
+ # Mark each as non-active
+ await self.profile_handler.set_active(active_expired_users, False, True)
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index 37e63da9b1..db68c94c50 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -203,13 +203,11 @@ class AdminHandler(BaseHandler):
class ExfiltrationWriter(metaclass=abc.ABCMeta):
- """Interface used to specify how to write exported data.
- """
+ """Interface used to specify how to write exported data."""
@abc.abstractmethod
def write_events(self, room_id: str, events: List[EventBase]) -> None:
- """Write a batch of events for a room.
- """
+ """Write a batch of events for a room."""
raise NotImplementedError()
@abc.abstractmethod
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 5c6458eb52..deab8ff2d0 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -290,7 +290,9 @@ class ApplicationServicesHandler:
if not interested:
continue
presence_events, _ = await presence_source.get_new_events(
- user=user, service=service, from_key=from_key,
+ user=user,
+ service=service,
+ from_key=from_key,
)
time_now = self.clock.time_msec()
events.extend(
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index a19c556437..9ba9f591d9 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -120,7 +120,9 @@ def convert_client_dict_legacy_fields_to_identifier(
# Ensure the identifier has a type
if "type" not in identifier:
raise SynapseError(
- 400, "'identifier' dict has no key 'type'", errcode=Codes.MISSING_PARAM,
+ 400,
+ "'identifier' dict has no key 'type'",
+ errcode=Codes.MISSING_PARAM,
)
return identifier
@@ -351,7 +353,11 @@ class AuthHandler(BaseHandler):
try:
result, params, session_id = await self.check_ui_auth(
- flows, request, request_body, description, get_new_session_data,
+ flows,
+ request,
+ request_body,
+ description,
+ get_new_session_data,
)
except LoginError:
# Update the ratelimiter to say we failed (`can_do_action` doesn't raise).
@@ -379,8 +385,7 @@ class AuthHandler(BaseHandler):
return params, session_id
async def _get_available_ui_auth_types(self, user: UserID) -> Iterable[str]:
- """Get a list of the authentication types this user can use
- """
+ """Get a list of the authentication types this user can use"""
ui_auth_types = set()
@@ -723,7 +728,9 @@ class AuthHandler(BaseHandler):
}
def _auth_dict_for_flows(
- self, flows: List[List[str]], session_id: str,
+ self,
+ flows: List[List[str]],
+ session_id: str,
) -> Dict[str, Any]:
public_flows = []
for f in flows:
@@ -880,7 +887,9 @@ class AuthHandler(BaseHandler):
return self._supported_login_types
async def validate_login(
- self, login_submission: Dict[str, Any], ratelimit: bool = False,
+ self,
+ login_submission: Dict[str, Any],
+ ratelimit: bool = False,
) -> Tuple[str, Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
"""Authenticates the user for the /login API
@@ -1023,7 +1032,9 @@ class AuthHandler(BaseHandler):
raise
async def _validate_userid_login(
- self, username: str, login_submission: Dict[str, Any],
+ self,
+ username: str,
+ login_submission: Dict[str, Any],
) -> Tuple[str, Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
"""Helper for validate_login
@@ -1446,7 +1457,8 @@ class AuthHandler(BaseHandler):
# is considered OK since the newest SSO attributes should be most valid.
if extra_attributes:
self._extra_attributes[registered_user_id] = SsoLoginExtraAttributes(
- self._clock.time_msec(), extra_attributes,
+ self._clock.time_msec(),
+ extra_attributes,
)
# Create a login token
@@ -1472,10 +1484,22 @@ class AuthHandler(BaseHandler):
# Remove the query parameters from the redirect URL to get a shorter version of
# it. This is only to display a human-readable URL in the template, but not the
# URL we redirect users to.
- redirect_url_no_params = client_redirect_url.split("?")[0]
+ url_parts = urllib.parse.urlsplit(client_redirect_url)
+
+ if url_parts.scheme == "https":
+ # for an https uri, just show the netloc (ie, the hostname. Specifically,
+ # the bit between "//" and "/"; this includes any potential
+ # "username:password@" prefix.)
+ display_url = url_parts.netloc
+ else:
+ # for other uris, strip the query-params (including the login token) and
+ # fragment.
+ display_url = urllib.parse.urlunsplit(
+ (url_parts.scheme, url_parts.netloc, url_parts.path, "", "")
+ )
html = self._sso_redirect_confirm_template.render(
- display_url=redirect_url_no_params,
+ display_url=display_url,
redirect_url=redirect_url,
server_name=self._server_name,
new_user=new_user,
@@ -1690,5 +1714,9 @@ class PasswordProvider:
# This might return an awaitable, if it does block the log out
# until it completes.
await maybe_awaitable(
- g(user_id=user_id, device_id=device_id, access_token=access_token,)
+ g(
+ user_id=user_id,
+ device_id=device_id,
+ access_token=access_token,
+ )
)
diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py
index bd35d1fb87..04972f9cf0 100644
--- a/synapse/handlers/cas_handler.py
+++ b/synapse/handlers/cas_handler.py
@@ -14,7 +14,7 @@
# limitations under the License.
import logging
import urllib.parse
-from typing import TYPE_CHECKING, Dict, Optional
+from typing import TYPE_CHECKING, Dict, List, Optional
from xml.etree import ElementTree as ET
import attr
@@ -33,8 +33,7 @@ logger = logging.getLogger(__name__)
class CasError(Exception):
- """Used to catch errors when validating the CAS ticket.
- """
+ """Used to catch errors when validating the CAS ticket."""
def __init__(self, error, error_description=None):
self.error = error
@@ -49,7 +48,7 @@ class CasError(Exception):
@attr.s(slots=True, frozen=True)
class CasResponse:
username = attr.ib(type=str)
- attributes = attr.ib(type=Dict[str, Optional[str]])
+ attributes = attr.ib(type=Dict[str, List[Optional[str]]])
class CasHandler:
@@ -100,7 +99,10 @@ class CasHandler:
Returns:
The URL to use as a "service" parameter.
"""
- return "%s?%s" % (self._cas_service_url, urllib.parse.urlencode(args),)
+ return "%s?%s" % (
+ self._cas_service_url,
+ urllib.parse.urlencode(args),
+ )
async def _validate_ticket(
self, ticket: str, service_args: Dict[str, str]
@@ -169,7 +171,7 @@ class CasHandler:
# Iterate through the nodes and pull out the user and any extra attributes.
user = None
- attributes = {}
+ attributes = {} # type: Dict[str, List[Optional[str]]]
for child in root[0]:
if child.tag.endswith("user"):
user = child.text
@@ -182,7 +184,7 @@ class CasHandler:
tag = attribute.tag
if "}" in tag:
tag = tag.split("}")[1]
- attributes[tag] = attribute.text
+ attributes.setdefault(tag, []).append(attribute.text)
# Ensure a user was found.
if user is None:
@@ -296,36 +298,20 @@ class CasHandler:
# first check if we're doing a UIA
if session:
return await self._sso_handler.complete_sso_ui_auth_request(
- self.idp_id, cas_response.username, session, request,
+ self.idp_id,
+ cas_response.username,
+ session,
+ request,
)
# otherwise, we're handling a login request.
# Ensure that the attributes of the logged in user meet the required
# attributes.
- for required_attribute, required_value in self._cas_required_attributes.items():
- # If required attribute was not in CAS Response - Forbidden
- if required_attribute not in cas_response.attributes:
- self._sso_handler.render_error(
- request,
- "unauthorised",
- "You are not authorised to log in here.",
- 401,
- )
- return
-
- # Also need to check value
- if required_value is not None:
- actual_value = cas_response.attributes[required_attribute]
- # If required attribute value does not match expected - Forbidden
- if required_value != actual_value:
- self._sso_handler.render_error(
- request,
- "unauthorised",
- "You are not authorised to log in here.",
- 401,
- )
- return
+ if not self._sso_handler.check_required_attributes(
+ request, cas_response.attributes, self._cas_required_attributes
+ ):
+ return
# Call the mapper to register/login the user
@@ -372,9 +358,10 @@ class CasHandler:
if failures:
raise RuntimeError("CAS is not expected to de-duplicate Matrix IDs")
+ # Arbitrarily use the first attribute found.
display_name = cas_response.attributes.get(
- self._cas_displayname_attribute, None
- )
+ self._cas_displayname_attribute, [None]
+ )[0]
return UserAttributes(localpart=localpart, display_name=display_name)
@@ -384,7 +371,8 @@ class CasHandler:
user_id = UserID(localpart, self._hostname).to_string()
logger.debug(
- "Looking for existing account based on mapped %s", user_id,
+ "Looking for existing account based on mapped %s",
+ user_id,
)
users = await self._store.get_users_by_id_case_insensitive(user_id)
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index c4a3b26a84..7911d126f5 100644
--- a/synapse/handlers/deactivate_account.py
+++ b/synapse/handlers/deactivate_account.py
@@ -50,7 +50,7 @@ class DeactivateAccountHandler(BaseHandler):
if hs.config.run_background_tasks:
hs.get_reactor().callWhenRunning(self._start_user_parting)
- self._account_validity_enabled = hs.config.account_validity.enabled
+ self._account_validity_enabled = hs.config.account_validity_enabled
async def deactivate_account(
self,
@@ -120,6 +120,9 @@ class DeactivateAccountHandler(BaseHandler):
await self.store.user_set_password_hash(user_id, None)
+ user = UserID.from_string(user_id)
+ await self._profile_handler.set_active([user], False, False)
+
# Add the user to a table of users pending deactivation (ie.
# removal from all the rooms they're a member of)
await self.store.add_user_pending_deactivation(user_id)
@@ -196,8 +199,7 @@ class DeactivateAccountHandler(BaseHandler):
run_as_background_process("user_parter_loop", self._user_parter_loop)
async def _user_parter_loop(self) -> None:
- """Loop that parts deactivated users from rooms
- """
+ """Loop that parts deactivated users from rooms"""
self._user_parter_running = True
logger.info("Starting user parter")
try:
@@ -214,8 +216,7 @@ class DeactivateAccountHandler(BaseHandler):
self._user_parter_running = False
async def _part_user(self, user_id: str) -> None:
- """Causes the given user_id to leave all the rooms they're joined to
- """
+ """Causes the given user_id to leave all the rooms they're joined to"""
user = UserID.from_string(user_id)
rooms_for_user = await self.store.get_rooms_for_user(user_id)
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 0863154f7a..df3cdc8fba 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -86,7 +86,7 @@ class DeviceWorkerHandler(BaseHandler):
@trace
async def get_device(self, user_id: str, device_id: str) -> JsonDict:
- """ Retrieve the given device
+ """Retrieve the given device
Args:
user_id: The user to get the device from
@@ -341,7 +341,7 @@ class DeviceHandler(DeviceWorkerHandler):
@trace
async def delete_device(self, user_id: str, device_id: str) -> None:
- """ Delete the given device
+ """Delete the given device
Args:
user_id: The user to delete the device from.
@@ -386,7 +386,7 @@ class DeviceHandler(DeviceWorkerHandler):
await self.delete_devices(user_id, device_ids)
async def delete_devices(self, user_id: str, device_ids: List[str]) -> None:
- """ Delete several devices
+ """Delete several devices
Args:
user_id: The user to delete devices from.
@@ -417,7 +417,7 @@ class DeviceHandler(DeviceWorkerHandler):
await self.notify_device_update(user_id, device_ids)
async def update_device(self, user_id: str, device_id: str, content: dict) -> None:
- """ Update the given device
+ """Update the given device
Args:
user_id: The user to update devices of.
@@ -534,7 +534,9 @@ class DeviceHandler(DeviceWorkerHandler):
device id of the dehydrated device
"""
device_id = await self.check_device_registered(
- user_id, None, initial_device_display_name,
+ user_id,
+ None,
+ initial_device_display_name,
)
old_device_id = await self.store.store_dehydrated_device(
user_id, device_id, device_data
@@ -803,7 +805,8 @@ class DeviceListUpdater:
try:
# Try to resync the current user's devices list.
result = await self.user_device_resync(
- user_id=user_id, mark_failed_as_stale=False,
+ user_id=user_id,
+ mark_failed_as_stale=False,
)
# user_device_resync only returns a result if it managed to
@@ -813,14 +816,17 @@ class DeviceListUpdater:
# self.store.update_remote_device_list_cache).
if result:
logger.debug(
- "Successfully resynced the device list for %s", user_id,
+ "Successfully resynced the device list for %s",
+ user_id,
)
except Exception as e:
# If there was an issue resyncing this user, e.g. if the remote
# server sent a malformed result, just log the error instead of
# aborting all the subsequent resyncs.
logger.debug(
- "Could not resync the device list for %s: %s", user_id, e,
+ "Could not resync the device list for %s: %s",
+ user_id,
+ e,
)
finally:
# Allow future calls to retry resyncinc out of sync device lists.
@@ -855,7 +861,9 @@ class DeviceListUpdater:
return None
except (RequestSendFailed, HttpResponseException) as e:
logger.warning(
- "Failed to handle device list update for %s: %s", user_id, e,
+ "Failed to handle device list update for %s: %s",
+ user_id,
+ e,
)
if mark_failed_as_stale:
@@ -931,7 +939,9 @@ class DeviceListUpdater:
# Handle cross-signing keys.
cross_signing_device_ids = await self.process_cross_signing_key_update(
- user_id, master_key, self_signing_key,
+ user_id,
+ master_key,
+ self_signing_key,
)
device_ids = device_ids + cross_signing_device_ids
diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index 0c7737e09d..1aa7d803b5 100644
--- a/synapse/handlers/devicemessage.py
+++ b/synapse/handlers/devicemessage.py
@@ -62,7 +62,8 @@ class DeviceMessageHandler:
)
else:
hs.get_federation_registry().register_instances_for_edu(
- "m.direct_to_device", hs.config.worker.writers.to_device,
+ "m.direct_to_device",
+ hs.config.worker.writers.to_device,
)
# The handler to call when we think a user's device list might be out of
@@ -73,8 +74,8 @@ class DeviceMessageHandler:
hs.get_device_handler().device_list_updater.user_device_resync
)
else:
- self._user_device_resync = ReplicationUserDevicesResyncRestServlet.make_client(
- hs
+ self._user_device_resync = (
+ ReplicationUserDevicesResyncRestServlet.make_client(hs)
)
async def on_direct_to_device_edu(self, origin: str, content: JsonDict) -> None:
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 8f3a6b35a4..9a946a3cfe 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -61,8 +61,8 @@ class E2eKeysHandler:
self._is_master = hs.config.worker_app is None
if not self._is_master:
- self._user_device_resync_client = ReplicationUserDevicesResyncRestServlet.make_client(
- hs
+ self._user_device_resync_client = (
+ ReplicationUserDevicesResyncRestServlet.make_client(hs)
)
else:
# Only register this edu handler on master as it requires writing
@@ -85,7 +85,7 @@ class E2eKeysHandler:
async def query_devices(
self, query_body: JsonDict, timeout: int, from_user_id: str
) -> JsonDict:
- """ Handle a device key query from a client
+ """Handle a device key query from a client
{
"device_keys": {
@@ -391,8 +391,7 @@ class E2eKeysHandler:
async def on_federation_query_client_keys(
self, query_body: Dict[str, Dict[str, Optional[List[str]]]]
) -> JsonDict:
- """ Handle a device key query from a federated server
- """
+ """Handle a device key query from a federated server"""
device_keys_query = query_body.get(
"device_keys", {}
) # type: Dict[str, Optional[List[str]]]
@@ -1065,7 +1064,9 @@ class E2eKeysHandler:
return key, key_id, verify_key
async def _retrieve_cross_signing_keys_for_remote_user(
- self, user: UserID, desired_key_type: str,
+ self,
+ user: UserID,
+ desired_key_type: str,
) -> Tuple[Optional[dict], Optional[str], Optional[VerifyKey]]:
"""Queries cross-signing keys for a remote user and saves them to the database
@@ -1269,8 +1270,7 @@ def _one_time_keys_match(old_key_json: str, new_key: JsonDict) -> bool:
@attr.s(slots=True)
class SignatureListItem:
- """An item in the signature list as used by upload_signatures_for_device_keys.
- """
+ """An item in the signature list as used by upload_signatures_for_device_keys."""
signing_key_id = attr.ib(type=str)
target_user_id = attr.ib(type=str)
@@ -1355,8 +1355,12 @@ class SigningKeyEduUpdater:
logger.info("pending updates: %r", pending_updates)
for master_key, self_signing_key in pending_updates:
- new_device_ids = await device_list_updater.process_cross_signing_key_update(
- user_id, master_key, self_signing_key,
+ new_device_ids = (
+ await device_list_updater.process_cross_signing_key_update(
+ user_id,
+ master_key,
+ self_signing_key,
+ )
)
device_ids = device_ids + new_device_ids
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index 539b4fc32e..3e23f82cf7 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -57,8 +57,7 @@ class EventStreamHandler(BaseHandler):
room_id: Optional[str] = None,
is_guest: bool = False,
) -> JsonDict:
- """Fetches the events stream for a given user.
- """
+ """Fetches the events stream for a given user."""
if room_id:
blocked = await self.store.is_room_blocked(room_id)
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index eddc7582d0..51bdf97920 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -1,7 +1,8 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2017-2018 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019-2020 The Matrix.org Foundation C.I.C.
+# Copyright 2020 Sorunome
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -111,13 +112,13 @@ class _NewEventInfo:
class FederationHandler(BaseHandler):
"""Handles events that originated from federation.
- Responsible for:
- a) handling received Pdus before handing them on as Events to the rest
- of the homeserver (including auth and state conflict resolutions)
- b) converting events that were produced by local clients that may need
- to be sent to remote homeservers.
- c) doing the necessary dances to invite remote users and join remote
- rooms.
+ Responsible for:
+ a) handling received Pdus before handing them on as Events to the rest
+ of the homeserver (including auth and state conflict resolutions)
+ b) converting events that were produced by local clients that may need
+ to be sent to remote homeservers.
+ c) doing the necessary dances to invite remote users and join remote
+ rooms.
"""
def __init__(self, hs: "HomeServer"):
@@ -150,11 +151,11 @@ class FederationHandler(BaseHandler):
)
if hs.config.worker_app:
- self._user_device_resync = ReplicationUserDevicesResyncRestServlet.make_client(
- hs
+ self._user_device_resync = (
+ ReplicationUserDevicesResyncRestServlet.make_client(hs)
)
- self._maybe_store_room_on_outlier_membership = ReplicationStoreRoomOnOutlierMembershipRestServlet.make_client(
- hs
+ self._maybe_store_room_on_outlier_membership = (
+ ReplicationStoreRoomOnOutlierMembershipRestServlet.make_client(hs)
)
else:
self._device_list_updater = hs.get_device_handler().device_list_updater
@@ -172,7 +173,7 @@ class FederationHandler(BaseHandler):
self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None:
- """ Process a PDU received via a federation /send/ transaction, or
+ """Process a PDU received via a federation /send/ transaction, or
via backfill of missing prev_events
Args:
@@ -186,7 +187,7 @@ class FederationHandler(BaseHandler):
room_id = pdu.room_id
event_id = pdu.event_id
- logger.info("handling received PDU: %s", pdu)
+ logger.info("[%s %s] handling received PDU: %s", room_id, event_id, pdu)
# We reprocess pdus when we have seen them only as outliers
existing = await self.store.get_event(
@@ -301,6 +302,14 @@ class FederationHandler(BaseHandler):
room_id,
event_id,
)
+ elif missing_prevs:
+ logger.info(
+ "[%s %s] Not recursively fetching %d missing prev_events: %s",
+ room_id,
+ event_id,
+ len(missing_prevs),
+ shortstr(missing_prevs),
+ )
if prevs - seen:
# We've still not been able to get all of the prev_events for this event.
@@ -345,12 +354,6 @@ class FederationHandler(BaseHandler):
affected=pdu.event_id,
)
- logger.info(
- "Event %s is missing prev_events: calculating state for a "
- "backwards extremity",
- event_id,
- )
-
# Calculate the state after each of the previous events, and
# resolve them to find the correct state at the current event.
event_map = {event_id: pdu}
@@ -368,7 +371,8 @@ class FederationHandler(BaseHandler):
# know about
for p in prevs - seen:
logger.info(
- "Requesting state at missing prev_event %s", event_id,
+ "Requesting state at missing prev_event %s",
+ event_id,
)
with nested_logging_context(p):
@@ -388,12 +392,14 @@ class FederationHandler(BaseHandler):
event_map[x.event_id] = x
room_version = await self.store.get_room_version_id(room_id)
- state_map = await self._state_resolution_handler.resolve_events_with_store(
- room_id,
- room_version,
- state_maps,
- event_map,
- state_res_store=StateResolutionStore(self.store),
+ state_map = (
+ await self._state_resolution_handler.resolve_events_with_store(
+ room_id,
+ room_version,
+ state_maps,
+ event_map,
+ state_res_store=StateResolutionStore(self.store),
+ )
)
# We need to give _process_received_pdu the actual state events
@@ -402,9 +408,7 @@ class FederationHandler(BaseHandler):
# First though we need to fetch all the events that are in
# state_map, so we can build up the state below.
evs = await self.store.get_events(
- list(state_map.values()),
- get_prev_content=False,
- redact_behaviour=EventRedactBehaviour.AS_IS,
+ list(state_map.values()), get_prev_content=False,
)
event_map.update(evs)
@@ -687,9 +691,12 @@ class FederationHandler(BaseHandler):
return fetched_events
async def _process_received_pdu(
- self, origin: str, event: EventBase, state: Optional[Iterable[EventBase]],
+ self,
+ origin: str,
+ event: EventBase,
+ state: Optional[Iterable[EventBase]],
):
- """ Called when we have a new pdu. We need to do auth checks and put it
+ """Called when we have a new pdu. We need to do auth checks and put it
through the StateHandler.
Args:
@@ -801,7 +808,7 @@ class FederationHandler(BaseHandler):
@log_function
async def backfill(self, dest, room_id, limit, extremities):
- """ Trigger a backfill request to `dest` for the given `room_id`
+ """Trigger a backfill request to `dest` for the given `room_id`
This will attempt to get more events from the remote. If the other side
has no new events to offer, this will return an empty list.
@@ -1204,11 +1211,16 @@ class FederationHandler(BaseHandler):
with nested_logging_context(event_id):
try:
event = await self.federation_client.get_pdu(
- [destination], event_id, room_version, outlier=True,
+ [destination],
+ event_id,
+ room_version,
+ outlier=True,
)
if event is None:
logger.warning(
- "Server %s didn't return event %s", destination, event_id,
+ "Server %s didn't return event %s",
+ destination,
+ event_id,
)
return
@@ -1235,7 +1247,8 @@ class FederationHandler(BaseHandler):
if aid not in event_map
]
persisted_events = await self.store.get_events(
- auth_events, allow_rejected=True,
+ auth_events,
+ allow_rejected=True,
)
event_infos = []
@@ -1251,7 +1264,9 @@ class FederationHandler(BaseHandler):
event_infos.append(_NewEventInfo(event, None, auth))
await self._handle_new_events(
- destination, room_id, event_infos,
+ destination,
+ room_id,
+ event_infos,
)
def _sanity_check_event(self, ev):
@@ -1287,7 +1302,7 @@ class FederationHandler(BaseHandler):
raise SynapseError(HTTPStatus.BAD_REQUEST, "Too many auth_events")
async def send_invite(self, target_host, event):
- """ Sends the invite to the remote server for signing.
+ """Sends the invite to the remote server for signing.
Invites must be signed by the invitee's server before distribution.
"""
@@ -1310,7 +1325,7 @@ class FederationHandler(BaseHandler):
async def do_invite_join(
self, target_hosts: Iterable[str], room_id: str, joinee: str, content: JsonDict
) -> Tuple[str, int]:
- """ Attempts to join the `joinee` to the room `room_id` via the
+ """Attempts to join the `joinee` to the room `room_id` via the
servers contained in `target_hosts`.
This first triggers a /make_join/ request that returns a partial
@@ -1354,8 +1369,6 @@ class FederationHandler(BaseHandler):
await self._clean_room_for_join(room_id)
- handled_events = set()
-
try:
# Try the host we successfully got a response to /make_join/
# request first.
@@ -1375,10 +1388,6 @@ class FederationHandler(BaseHandler):
auth_chain = ret["auth_chain"]
auth_chain.sort(key=lambda e: e.depth)
- handled_events.update([s.event_id for s in state])
- handled_events.update([a.event_id for a in auth_chain])
- handled_events.add(event.event_id)
-
logger.debug("do_invite_join auth_chain: %s", auth_chain)
logger.debug("do_invite_join state: %s", state)
@@ -1394,7 +1403,8 @@ class FederationHandler(BaseHandler):
# so we can rely on it now.
#
await self.store.upsert_room_on_join(
- room_id=room_id, room_version=room_version_obj,
+ room_id=room_id,
+ room_version=room_version_obj,
)
max_stream_id = await self._persist_auth_tree(
@@ -1439,6 +1449,73 @@ class FederationHandler(BaseHandler):
run_in_background(self._handle_queued_pdus, room_queue)
+ @log_function
+ async def do_knock(
+ self, target_hosts: List[str], room_id: str, knockee: str, content: JsonDict,
+ ) -> Tuple[str, int]:
+ """Sends the knock to the remote server.
+
+ This first triggers a make_knock request that returns a partial
+ event that we can fill out and sign. This is then sent to the
+ remote server via send_knock.
+
+ Knock events must be signed by the knockee's server before distributing.
+
+ Args:
+ target_hosts: A list of hosts that we want to try knocking through.
+ room_id: The ID of the room to knock on.
+ knockee: The ID of the user who is knocking.
+ content: The content of the knock event.
+
+ Returns:
+ A tuple of (event ID, stream ID).
+
+ Raises:
+ SynapseError: If the chosen remote server returns a 3xx/4xx code.
+ RuntimeError: If no servers were reachable.
+ """
+ logger.debug("Knocking on room %s on behalf of user %s", room_id, knockee)
+
+ # Inform the remote server of the room versions we support
+ supported_room_versions = list(KNOWN_ROOM_VERSIONS.keys())
+
+ # Ask the remote server to create a valid knock event for us. Once received,
+ # we sign the event
+ params = {"ver": supported_room_versions} # type: Dict[str, Iterable[str]]
+ origin, event, event_format_version = await self._make_and_verify_event(
+ target_hosts, room_id, knockee, Membership.KNOCK, content, params=params
+ )
+
+ # Record the room ID and its version so that we have a record of the room
+ await self._maybe_store_room_on_outlier_membership(
+ room_id=event.room_id, room_version=event_format_version
+ )
+
+ # Initially try the host that we successfully called /make_knock on
+ try:
+ target_hosts.remove(origin)
+ target_hosts.insert(0, origin)
+ except ValueError:
+ pass
+
+ # Send the signed event back to the room, and potentially receive some
+ # further information about the room in the form of partial state events
+ stripped_room_state = await self.federation_client.send_knock(
+ target_hosts, event
+ )
+
+ # Store any stripped room state events in the "unsigned" key of the event.
+ # This is a bit of a hack and is cribbing off of invites. Basically we
+ # store the room state here and retrieve it again when this event appears
+ # in the invitee's sync stream. It is stripped out for all other local users.
+ event.unsigned["knock_room_state"] = stripped_room_state["knock_state_events"]
+
+ context = await self.state_handler.compute_event_context(event)
+ stream_id = await self.persist_events_and_notify(
+ event.room_id, [(event, context)]
+ )
+ return event.event_id, stream_id
+
async def _handle_queued_pdus(self, room_queue):
"""Process PDUs which got queued up while we were busy send_joining.
@@ -1464,7 +1541,7 @@ class FederationHandler(BaseHandler):
async def on_make_join_request(
self, origin: str, room_id: str, user_id: str
) -> EventBase:
- """ We've received a /make_join/ request, so we create a partial
+ """We've received a /make_join/ request, so we create a partial
join event for the room and return that. We do *not* persist or
process it until the other server has signed it and sent it back.
@@ -1489,7 +1566,8 @@ class FederationHandler(BaseHandler):
is_in_room = await self.auth.check_host_in_room(room_id, self.server_name)
if not is_in_room:
logger.info(
- "Got /make_join request for room %s we are no longer in", room_id,
+ "Got /make_join request for room %s we are no longer in",
+ room_id,
)
raise NotFoundError("Not an active room on this server")
@@ -1523,7 +1601,7 @@ class FederationHandler(BaseHandler):
return event
async def on_send_join_request(self, origin, pdu):
- """ We have received a join event for a room. Fully process it and
+ """We have received a join event for a room. Fully process it and
respond with the current state and auth chains.
"""
event = pdu
@@ -1579,7 +1657,7 @@ class FederationHandler(BaseHandler):
async def on_invite_request(
self, origin: str, event: EventBase, room_version: RoomVersion
):
- """ We've got an invite event. Process and persist it. Sign it.
+ """We've got an invite event. Process and persist it. Sign it.
Respond with the now signed event.
"""
@@ -1593,8 +1671,15 @@ class FederationHandler(BaseHandler):
if self.hs.config.block_non_admin_invites:
raise SynapseError(403, "This server does not accept room invites")
+ is_published = await self.store.is_room_published(event.room_id)
+
if not await self.spam_checker.user_may_invite(
- event.sender, event.state_key, event.room_id
+ event.sender,
+ event.state_key,
+ None,
+ room_id=event.room_id,
+ new_room=False,
+ published_room=is_published,
):
raise SynapseError(
403, "This user is not permitted to send invites to this server/user"
@@ -1706,7 +1791,7 @@ class FederationHandler(BaseHandler):
async def on_make_leave_request(
self, origin: str, room_id: str, user_id: str
) -> EventBase:
- """ We've received a /make_leave/ request, so we create a partial
+ """We've received a /make_leave/ request, so we create a partial
leave event for the room and return that. We do *not* persist or
process it until the other server has signed it and sent it back.
@@ -1781,9 +1866,122 @@ class FederationHandler(BaseHandler):
return None
- async def get_state_for_pdu(self, room_id: str, event_id: str) -> List[EventBase]:
- """Returns the state at the event. i.e. not including said event.
+ @log_function
+ async def on_make_knock_request(
+ self, origin: str, room_id: str, user_id: str
+ ) -> EventBase:
+ """We've received a make_knock request, so we create a partial
+ knock event for the room and return that. We do *not* persist or
+ process it until the other server has signed it and sent it back.
+
+ Args:
+ origin: The (verified) server name of the requesting server.
+ room_id: The room to create the knock event in.
+ user_id: The user to create the knock for.
+
+ Returns:
+ The partial knock event.
"""
+ if get_domain_from_id(user_id) != origin:
+ logger.info(
+                "Get /xyz.amorgan.knock/make_knock request for user %r "
+                "from different origin %s, ignoring",
+ user_id,
+ origin,
+ )
+ raise SynapseError(403, "User not from origin", Codes.FORBIDDEN)
+
+ room_version = await self.store.get_room_version_id(room_id)
+
+ builder = self.event_builder_factory.new(
+ room_version,
+ {
+ "type": EventTypes.Member,
+ "content": {"membership": Membership.KNOCK},
+ "room_id": room_id,
+ "sender": user_id,
+ "state_key": user_id,
+ },
+ )
+
+ event, context = await self.event_creation_handler.create_new_client_event(
+ builder=builder
+ )
+
+ event_allowed = await self.third_party_event_rules.check_event_allowed(
+ event, context
+ )
+ if not event_allowed:
+ logger.warning("Creation of knock %s forbidden by third-party rules", event)
+ raise SynapseError(
+ 403, "This event is not allowed in this context", Codes.FORBIDDEN
+ )
+
+ try:
+ # The remote hasn't signed it yet, obviously. We'll do the full checks
+ # when we get the event back in `on_send_knock_request`
+ await self.auth.check_from_context(
+ room_version, event, context, do_sig_check=False
+ )
+ except AuthError as e:
+ logger.warning("Failed to create new knock %r because %s", event, e)
+ raise e
+
+ return event
+
+ @log_function
+ async def on_send_knock_request(
+ self, origin: str, event: EventBase
+ ) -> EventContext:
+ """
+ We have received a knock event for a room. Verify that event and send it into the room
+ on the knocking homeserver's behalf.
+
+ Args:
+ origin: The remote homeserver of the knocking user.
+ event: The knocking member event that has been signed by the remote homeserver.
+
+ Returns:
+ The context of the event after inserting it into the room graph.
+ """
+ logger.debug(
+ "on_send_knock_request: Got event: %s, signatures: %s",
+ event.event_id,
+ event.signatures,
+ )
+
+ if get_domain_from_id(event.sender) != origin:
+ logger.info(
+ "Got /xyz.amorgan.knock/send_knock request for user %r "
+ "from different origin %s",
+ event.sender,
+ origin,
+ )
+ raise SynapseError(403, "User not from origin", Codes.FORBIDDEN)
+
+ event.internal_metadata.outlier = False
+
+ context = await self._handle_new_event(origin, event)
+
+ event_allowed = await self.third_party_event_rules.check_event_allowed(
+ event, context
+ )
+ if not event_allowed:
+ logger.info("Sending of knock %s forbidden by third-party rules", event)
+ raise SynapseError(
+ 403, "This event is not allowed in this context", Codes.FORBIDDEN
+ )
+
+ logger.debug(
+ "on_send_knock_request: After _handle_new_event: %s, sigs: %s",
+ event.event_id,
+ event.signatures,
+ )
+
+ return context
+
+ async def get_state_for_pdu(self, room_id: str, event_id: str) -> List[EventBase]:
+ """Returns the state at the event. i.e. not including said event."""
event = await self.store.get_event(event_id, check_room_id=room_id)
@@ -1809,8 +2007,7 @@ class FederationHandler(BaseHandler):
return []
async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> List[str]:
- """Returns the state at the event. i.e. not including said event.
- """
+ """Returns the state at the event. i.e. not including said event."""
event = await self.store.get_event(event_id, check_room_id=room_id)
state_groups = await self.state_store.get_state_groups_ids(room_id, [event_id])
@@ -2016,7 +2213,11 @@ class FederationHandler(BaseHandler):
for e_id in missing_auth_events:
m_ev = await self.federation_client.get_pdu(
- [origin], e_id, room_version=room_version, outlier=True, timeout=10000,
+ [origin],
+ e_id,
+ room_version=room_version,
+ outlier=True,
+ timeout=10000,
)
if m_ev and m_ev.event_id == e_id:
event_map[e_id] = m_ev
@@ -2166,7 +2367,9 @@ class FederationHandler(BaseHandler):
)
logger.debug(
- "Doing soft-fail check for %s: state %s", event.event_id, current_state_ids,
+ "Doing soft-fail check for %s: state %s",
+ event.event_id,
+ current_state_ids,
)
# Now check if event pass auth against said current state
@@ -2519,7 +2722,7 @@ class FederationHandler(BaseHandler):
async def construct_auth_difference(
self, local_auth: Iterable[EventBase], remote_auth: Iterable[EventBase]
) -> Dict:
- """ Given a local and remote auth chain, find the differences. This
+ """Given a local and remote auth chain, find the differences. This
assumes that we have already processed all events in remote_auth
Params:
diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py
index 71f11ef94a..bfb95e3eee 100644
--- a/synapse/handlers/groups_local.py
+++ b/synapse/handlers/groups_local.py
@@ -146,8 +146,7 @@ class GroupsLocalWorkerHandler:
async def get_users_in_group(
self, group_id: str, requester_user_id: str
) -> JsonDict:
- """Get users in a group
- """
+ """Get users in a group"""
if self.is_mine_id(group_id):
return await self.groups_server_handler.get_users_in_group(
group_id, requester_user_id
@@ -283,8 +282,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
async def create_group(
self, group_id: str, user_id: str, content: JsonDict
) -> JsonDict:
- """Create a group
- """
+ """Create a group"""
logger.info("Asking to create group with ID: %r", group_id)
@@ -314,8 +312,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
async def join_group(
self, group_id: str, user_id: str, content: JsonDict
) -> JsonDict:
- """Request to join a group
- """
+ """Request to join a group"""
if self.is_mine_id(group_id):
await self.groups_server_handler.join_group(group_id, user_id, content)
local_attestation = None
@@ -361,8 +358,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
async def accept_invite(
self, group_id: str, user_id: str, content: JsonDict
) -> JsonDict:
- """Accept an invite to a group
- """
+ """Accept an invite to a group"""
if self.is_mine_id(group_id):
await self.groups_server_handler.accept_invite(group_id, user_id, content)
local_attestation = None
@@ -408,8 +404,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
async def invite(
self, group_id: str, user_id: str, requester_user_id: str, config: JsonDict
) -> JsonDict:
- """Invite a user to a group
- """
+ """Invite a user to a group"""
content = {"requester_user_id": requester_user_id, "config": config}
if self.is_mine_id(group_id):
res = await self.groups_server_handler.invite_to_group(
@@ -434,8 +429,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
async def on_invite(
self, group_id: str, user_id: str, content: JsonDict
) -> JsonDict:
- """One of our users were invited to a group
- """
+ """One of our users were invited to a group"""
# TODO: Support auto join and rejection
if not self.is_mine_id(user_id):
@@ -466,8 +460,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
async def remove_user_from_group(
self, group_id: str, user_id: str, requester_user_id: str, content: JsonDict
) -> JsonDict:
- """Remove a user from a group
- """
+ """Remove a user from a group"""
if user_id == requester_user_id:
token = await self.store.register_user_group_membership(
group_id, user_id, membership="leave"
@@ -501,8 +494,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
async def user_removed_from_group(
self, group_id: str, user_id: str, content: JsonDict
) -> None:
- """One of our users was removed/kicked from a group
- """
+ """One of our users was removed/kicked from a group"""
# TODO: Check if user in group
token = await self.store.register_user_group_membership(
group_id, user_id, membership="leave"
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 8fc1e8b91c..ac81fa3678 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2017 Vector Creations Ltd
-# Copyright 2018 New Vector Ltd
+# Copyright 2018, 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -22,9 +22,11 @@ import urllib.parse
from typing import Awaitable, Callable, Dict, List, Optional, Tuple
from synapse.api.errors import (
+ AuthError,
CodeMessageException,
Codes,
HttpResponseException,
+ ProxiedRequestError,
SynapseError,
)
from synapse.api.ratelimiting import Ratelimiter
@@ -41,8 +43,6 @@ from ._base import BaseHandler
logger = logging.getLogger(__name__)
-id_server_scheme = "https://"
-
class IdentityHandler(BaseHandler):
def __init__(self, hs):
@@ -57,6 +57,9 @@ class IdentityHandler(BaseHandler):
self.federation_http_client = hs.get_federation_http_client()
self.hs = hs
+ self.rewrite_identity_server_urls = hs.config.rewrite_identity_server_urls
+ self._enable_lookup = hs.config.enable_3pid_lookup
+
self._web_client_location = hs.config.invite_client_location
# Ratelimiters for `/requestToken` endpoints.
@@ -72,7 +75,10 @@ class IdentityHandler(BaseHandler):
)
def ratelimit_request_token_requests(
- self, request: SynapseRequest, medium: str, address: str,
+ self,
+ request: SynapseRequest,
+ medium: str,
+ address: str,
):
"""Used to ratelimit requests to `/requestToken` by IP and address.
@@ -86,14 +92,14 @@ class IdentityHandler(BaseHandler):
self._3pid_validation_ratelimiter_address.ratelimit((medium, address))
async def threepid_from_creds(
- self, id_server: str, creds: Dict[str, str]
+ self, id_server_url: str, creds: Dict[str, str]
) -> Optional[JsonDict]:
"""
Retrieve and validate a threepid identifier from a "credentials" dictionary against a
given identity server
Args:
- id_server: The identity server to validate 3PIDs against. Must be a
+ id_server_url: The identity server to validate 3PIDs against. Must be a
complete URL including the protocol (http(s)://)
creds: Dictionary containing the following keys:
* client_secret|clientSecret: A unique secret str provided by the client
@@ -118,7 +124,14 @@ class IdentityHandler(BaseHandler):
query_params = {"sid": session_id, "client_secret": client_secret}
- url = id_server + "/_matrix/identity/api/v1/3pid/getValidated3pid"
+ # if we have a rewrite rule set for the identity server,
+ # apply it now.
+ id_server_url = self.rewrite_id_server_url(id_server_url)
+
+ url = "%s%s" % (
+ id_server_url,
+ "/_matrix/identity/api/v1/3pid/getValidated3pid",
+ )
try:
data = await self.http_client.get_json(url, query_params)
@@ -127,7 +140,7 @@ class IdentityHandler(BaseHandler):
except HttpResponseException as e:
logger.info(
"%s returned %i for threepid validation for: %s",
- id_server,
+ id_server_url,
e.code,
creds,
)
@@ -141,7 +154,7 @@ class IdentityHandler(BaseHandler):
if "medium" in data:
return data
- logger.info("%s reported non-validated threepid: %s", id_server, creds)
+ logger.info("%s reported non-validated threepid: %s", id_server_url, creds)
return None
async def bind_threepid(
@@ -173,14 +186,19 @@ class IdentityHandler(BaseHandler):
if id_access_token is None:
use_v2 = False
+ # if we have a rewrite rule set for the identity server,
+ # apply it now, but only for sending the request (not
+ # storing in the database).
+ id_server_url = self.rewrite_id_server_url(id_server, add_https=True)
+
# Decide which API endpoint URLs to use
headers = {}
bind_data = {"sid": sid, "client_secret": client_secret, "mxid": mxid}
if use_v2:
- bind_url = "https://%s/_matrix/identity/v2/3pid/bind" % (id_server,)
+ bind_url = "%s/_matrix/identity/v2/3pid/bind" % (id_server_url,)
headers["Authorization"] = create_id_access_token_header(id_access_token) # type: ignore
else:
- bind_url = "https://%s/_matrix/identity/api/v1/3pid/bind" % (id_server,)
+ bind_url = "%s/_matrix/identity/api/v1/3pid/bind" % (id_server_url,)
try:
# Use the blacklisting http client as this call is only to identity servers
@@ -267,9 +285,6 @@ class IdentityHandler(BaseHandler):
True on success, otherwise False if the identity
server doesn't support unbinding
"""
- url = "https://%s/_matrix/identity/api/v1/3pid/unbind" % (id_server,)
- url_bytes = "/_matrix/identity/api/v1/3pid/unbind".encode("ascii")
-
content = {
"mxid": mxid,
"threepid": {"medium": threepid["medium"], "address": threepid["address"]},
@@ -278,6 +293,7 @@ class IdentityHandler(BaseHandler):
# we abuse the federation http client to sign the request, but we have to send it
# using the normal http client since we don't want the SRV lookup and want normal
# 'browser-like' HTTPS.
+ url_bytes = "/_matrix/identity/api/v1/3pid/unbind".encode("ascii")
auth_headers = self.federation_http_client.build_auth_headers(
destination=None,
method=b"POST",
@@ -287,6 +303,15 @@ class IdentityHandler(BaseHandler):
)
headers = {b"Authorization": auth_headers}
+ # if we have a rewrite rule set for the identity server,
+ # apply it now.
+ #
+ # Note that destination_is has to be the real id_server, not
+ # the server we connect to.
+ id_server_url = self.rewrite_id_server_url(id_server, add_https=True)
+
+ url = "%s/_matrix/identity/api/v1/3pid/unbind" % (id_server_url,)
+
try:
# Use the blacklisting http client as this call is only to identity servers
# provided by a client
@@ -401,9 +426,28 @@ class IdentityHandler(BaseHandler):
return session_id
+ def rewrite_id_server_url(self, url: str, add_https=False) -> str:
+ """Given an identity server URL, optionally add a protocol scheme
+ before rewriting it according to the rewrite_identity_server_urls
+ config option
+
+ Adds https:// to the URL if specified, then tries to rewrite the
+ url. Returns either the rewritten URL or the URL with optional
+ protocol scheme additions.
+ """
+ rewritten_url = url
+ if add_https:
+ rewritten_url = "https://" + rewritten_url
+
+ rewritten_url = self.rewrite_identity_server_urls.get(
+ rewritten_url, rewritten_url
+ )
+ logger.debug("Rewriting identity server rule from %s to %s", url, rewritten_url)
+ return rewritten_url
+
async def requestEmailToken(
self,
- id_server: str,
+ id_server_url: str,
email: str,
client_secret: str,
send_attempt: int,
@@ -414,7 +458,7 @@ class IdentityHandler(BaseHandler):
validation.
Args:
- id_server: The identity server to proxy to
+ id_server_url: The identity server to proxy to
email: The email to send the message to
client_secret: The unique client_secret sends by the user
send_attempt: Which attempt this is
@@ -428,6 +472,11 @@ class IdentityHandler(BaseHandler):
"client_secret": client_secret,
"send_attempt": send_attempt,
}
+
+ # if we have a rewrite rule set for the identity server,
+ # apply it now.
+ id_server_url = self.rewrite_id_server_url(id_server_url)
+
if next_link:
params["next_link"] = next_link
@@ -442,7 +491,8 @@ class IdentityHandler(BaseHandler):
try:
data = await self.http_client.post_json_get_json(
- id_server + "/_matrix/identity/api/v1/validate/email/requestToken",
+ "%s/_matrix/identity/api/v1/validate/email/requestToken"
+ % (id_server_url,),
params,
)
return data
@@ -454,7 +504,7 @@ class IdentityHandler(BaseHandler):
async def requestMsisdnToken(
self,
- id_server: str,
+ id_server_url: str,
country: str,
phone_number: str,
client_secret: str,
@@ -465,7 +515,7 @@ class IdentityHandler(BaseHandler):
Request an external server send an SMS message on our behalf for the purposes of
threepid validation.
Args:
- id_server: The identity server to proxy to
+ id_server_url: The identity server to proxy to
country: The country code of the phone number
phone_number: The number to send the message to
client_secret: The unique client_secret sends by the user
@@ -493,9 +543,13 @@ class IdentityHandler(BaseHandler):
"details and update your config file."
)
+ # if we have a rewrite rule set for the identity server,
+ # apply it now.
+ id_server_url = self.rewrite_id_server_url(id_server_url)
try:
data = await self.http_client.post_json_get_json(
- id_server + "/_matrix/identity/api/v1/validate/msisdn/requestToken",
+ "%s/_matrix/identity/api/v1/validate/msisdn/requestToken"
+ % (id_server_url,),
params,
)
except HttpResponseException as e:
@@ -591,6 +645,86 @@ class IdentityHandler(BaseHandler):
logger.warning("Error contacting msisdn account_threepid_delegate: %s", e)
raise SynapseError(400, "Error contacting the identity server")
+ # TODO: The following two methods are used for proxying IS requests using
+ # the CS API. They should be consolidated with those in RoomMemberHandler
+ # https://github.com/matrix-org/synapse-dinsic/issues/25
+
+ async def proxy_lookup_3pid(
+ self, id_server: str, medium: str, address: str
+ ) -> JsonDict:
+ """Looks up a 3pid in the passed identity server.
+
+ Args:
+ id_server: The server name (including port, if required)
+ of the identity server to use.
+ medium: The type of the third party identifier (e.g. "email").
+ address: The third party identifier (e.g. "foo@example.com").
+
+ Returns:
+ The result of the lookup. See
+ https://matrix.org/docs/spec/identity_service/r0.1.0.html#association-lookup
+ for details
+ """
+ if not self._enable_lookup:
+ raise AuthError(
+ 403, "Looking up third-party identifiers is denied from this server"
+ )
+
+ id_server_url = self.rewrite_id_server_url(id_server, add_https=True)
+
+ try:
+ data = await self.http_client.get_json(
+ "%s/_matrix/identity/api/v1/lookup" % (id_server_url,),
+ {"medium": medium, "address": address},
+ )
+
+ except HttpResponseException as e:
+ logger.info("Proxied lookup failed: %r", e)
+ raise e.to_synapse_error()
+ except IOError as e:
+ logger.info("Failed to contact %s: %s", id_server, e)
+ raise ProxiedRequestError(503, "Failed to contact identity server")
+
+ return data
+
+ async def proxy_bulk_lookup_3pid(
+ self, id_server: str, threepids: List[List[str]]
+ ) -> JsonDict:
+ """Looks up given 3pids in the passed identity server.
+
+ Args:
+ id_server: The server name (including port, if required)
+ of the identity server to use.
+ threepids: The third party identifiers to lookup, as
+ a list of 2-string sized lists ([medium, address]).
+
+ Returns:
+ The result of the lookup. See
+ https://matrix.org/docs/spec/identity_service/r0.1.0.html#association-lookup
+ for details
+ """
+ if not self._enable_lookup:
+ raise AuthError(
+ 403, "Looking up third-party identifiers is denied from this server"
+ )
+
+ id_server_url = self.rewrite_id_server_url(id_server, add_https=True)
+
+ try:
+ data = await self.http_client.post_json_get_json(
+ "%s/_matrix/identity/api/v1/bulk_lookup" % (id_server_url,),
+ {"threepids": threepids},
+ )
+
+ except HttpResponseException as e:
+ logger.info("Proxied lookup failed: %r", e)
+ raise e.to_synapse_error()
+ except IOError as e:
+ logger.info("Failed to contact %s: %s", id_server, e)
+ raise ProxiedRequestError(503, "Failed to contact identity server")
+
+ return data
+
async def lookup_3pid(
self,
id_server: str,
@@ -611,10 +745,13 @@ class IdentityHandler(BaseHandler):
Returns:
the matrix ID of the 3pid, or None if it is not recognized.
"""
+ # Rewrite id_server URL if necessary
+ id_server_url = self.rewrite_id_server_url(id_server, add_https=True)
+
if id_access_token is not None:
try:
results = await self._lookup_3pid_v2(
- id_server, id_access_token, medium, address
+ id_server_url, id_access_token, medium, address
)
return results
@@ -632,16 +769,17 @@ class IdentityHandler(BaseHandler):
logger.warning("Error when looking up hashing details: %s", e)
return None
- return await self._lookup_3pid_v1(id_server, medium, address)
+ return await self._lookup_3pid_v1(id_server, id_server_url, medium, address)
async def _lookup_3pid_v1(
- self, id_server: str, medium: str, address: str
+ self, id_server: str, id_server_url: str, medium: str, address: str
) -> Optional[str]:
"""Looks up a 3pid in the passed identity server using v1 lookup.
Args:
id_server: The server name (including port, if required)
of the identity server to use.
+ id_server_url: The actual, reachable domain of the id server
medium: The type of the third party identifier (e.g. "email").
address: The third party identifier (e.g. "foo@example.com").
@@ -649,8 +787,8 @@ class IdentityHandler(BaseHandler):
the matrix ID of the 3pid, or None if it is not recognized.
"""
try:
- data = await self.blacklisting_http_client.get_json(
- "%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server),
+ data = await self.http_client.get_json(
+ "%s/_matrix/identity/api/v1/lookup" % (id_server_url,),
{"medium": medium, "address": address},
)
@@ -667,13 +805,12 @@ class IdentityHandler(BaseHandler):
return None
async def _lookup_3pid_v2(
- self, id_server: str, id_access_token: str, medium: str, address: str
+ self, id_server_url: str, id_access_token: str, medium: str, address: str
) -> Optional[str]:
"""Looks up a 3pid in the passed identity server using v2 lookup.
Args:
- id_server: The server name (including port, if required)
- of the identity server to use.
+ id_server_url: The protocol scheme and domain of the id server
id_access_token: The access token to authenticate to the identity server with
medium: The type of the third party identifier (e.g. "email").
address: The third party identifier (e.g. "foo@example.com").
@@ -683,8 +820,8 @@ class IdentityHandler(BaseHandler):
"""
# Check what hashing details are supported by this identity server
try:
- hash_details = await self.blacklisting_http_client.get_json(
- "%s%s/_matrix/identity/v2/hash_details" % (id_server_scheme, id_server),
+ hash_details = await self.http_client.get_json(
+ "%s/_matrix/identity/v2/hash_details" % (id_server_url,),
{"access_token": id_access_token},
)
except RequestTimedOutError:
@@ -692,15 +829,14 @@ class IdentityHandler(BaseHandler):
if not isinstance(hash_details, dict):
logger.warning(
- "Got non-dict object when checking hash details of %s%s: %s",
- id_server_scheme,
- id_server,
+ "Got non-dict object when checking hash details of %s: %s",
+ id_server_url,
hash_details,
)
raise SynapseError(
400,
- "Non-dict object from %s%s during v2 hash_details request: %s"
- % (id_server_scheme, id_server, hash_details),
+ "Non-dict object from %s during v2 hash_details request: %s"
+ % (id_server_url, hash_details),
)
# Extract information from hash_details
@@ -714,8 +850,8 @@ class IdentityHandler(BaseHandler):
):
raise SynapseError(
400,
- "Invalid hash details received from identity server %s%s: %s"
- % (id_server_scheme, id_server, hash_details),
+ "Invalid hash details received from identity server %s: %s"
+ % (id_server_url, hash_details),
)
# Check if any of the supported lookup algorithms are present
@@ -737,7 +873,7 @@ class IdentityHandler(BaseHandler):
else:
logger.warning(
"None of the provided lookup algorithms of %s are supported: %s",
- id_server,
+ id_server_url,
supported_lookup_algorithms,
)
raise SynapseError(
@@ -750,8 +886,8 @@ class IdentityHandler(BaseHandler):
headers = {"Authorization": create_id_access_token_header(id_access_token)}
try:
- lookup_results = await self.blacklisting_http_client.post_json_get_json(
- "%s%s/_matrix/identity/v2/lookup" % (id_server_scheme, id_server),
+ lookup_results = await self.http_client.post_json_get_json(
+ "%s/_matrix/identity/v2/lookup" % (id_server_url,),
{
"addresses": [lookup_value],
"algorithm": lookup_algorithm,
@@ -839,15 +975,17 @@ class IdentityHandler(BaseHandler):
if self._web_client_location:
invite_config["org.matrix.web_client_location"] = self._web_client_location
+ # Rewrite the identity server URL if necessary
+ id_server_url = self.rewrite_id_server_url(id_server, add_https=True)
+
# Add the identity service access token to the JSON body and use the v2
# Identity Service endpoints if id_access_token is present
data = None
- base_url = "%s%s/_matrix/identity" % (id_server_scheme, id_server)
+ base_url = "%s/_matrix/identity" % (id_server_url,)
if id_access_token:
- key_validity_url = "%s%s/_matrix/identity/v2/pubkey/isvalid" % (
- id_server_scheme,
- id_server,
+ key_validity_url = "%s/_matrix/identity/v2/pubkey/isvalid" % (
+ id_server_url,
)
# Attempt a v2 lookup
@@ -866,9 +1004,8 @@ class IdentityHandler(BaseHandler):
raise e
if data is None:
- key_validity_url = "%s%s/_matrix/identity/api/v1/pubkey/isvalid" % (
- id_server_scheme,
- id_server,
+ key_validity_url = "%s/_matrix/identity/api/v1/pubkey/isvalid" % (
+ id_server_url,
)
url = base_url + "/api/v1/store-invite"
@@ -880,10 +1017,7 @@ class IdentityHandler(BaseHandler):
raise SynapseError(500, "Timed out contacting identity server")
except HttpResponseException as e:
logger.warning(
- "Error trying to call /store-invite on %s%s: %s",
- id_server_scheme,
- id_server,
- e,
+ "Error trying to call /store-invite on %s: %s", id_server_url, e,
)
if data is None:
@@ -896,10 +1030,9 @@ class IdentityHandler(BaseHandler):
)
except HttpResponseException as e:
logger.warning(
- "Error calling /store-invite on %s%s with fallback "
+ "Error calling /store-invite on %s with fallback "
"encoding: %s",
- id_server_scheme,
- id_server,
+ id_server_url,
e,
)
raise e
@@ -920,6 +1053,42 @@ class IdentityHandler(BaseHandler):
display_name = data["display_name"]
return token, public_keys, fallback_public_key, display_name
+ async def bind_email_using_internal_sydent_api(
+ self, id_server_url: str, email: str, user_id: str,
+ ):
+ """Bind an email to a fully qualified user ID using the internal API of an
+ instance of Sydent.
+
+ Args:
+ id_server_url: The URL of the Sydent instance
+ email: The email address to bind
+ user_id: The user ID to bind the email to
+
+ Raises:
+ HttpResponseException: On a non-2xx HTTP response.
+ """
+ # Extract the domain name from the IS URL as we store IS domains instead of URLs
+ id_server = urllib.parse.urlparse(id_server_url).hostname
+ if not id_server:
+ # We were unable to determine the hostname, bail out
+ return
+
+ # id_server_url is assumed to have no trailing slashes
+ url = id_server_url + "/_matrix/identity/internal/bind"
+ body = {
+ "address": email,
+ "medium": "email",
+ "mxid": user_id,
+ }
+
+ # Bind the threepid
+ await self.http_client.post_json_get_json(url, body)
+
+ # Remember where we bound the threepid
+ await self.store.add_user_bound_threepid(
+ user_id=user_id, medium="email", address=email, id_server=id_server,
+ )
+
def create_id_access_token_header(id_access_token: str) -> List[str]:
"""Create an Authorization header for passing to SimpleHttpClient as the header value
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index fbd8df9dcc..78c3e5a10b 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -124,7 +124,8 @@ class InitialSyncHandler(BaseHandler):
joined_rooms = [r.room_id for r in room_list if r.membership == Membership.JOIN]
receipt = await self.store.get_linearized_receipts_for_rooms(
- joined_rooms, to_key=int(now_token.receipt_key),
+ joined_rooms,
+ to_key=int(now_token.receipt_key),
)
tags_by_room = await self.store.get_tags_for_user(user_id)
@@ -169,7 +170,10 @@ class InitialSyncHandler(BaseHandler):
self.state_handler.get_current_state, event.room_id
)
elif event.membership == Membership.LEAVE:
- room_end_token = RoomStreamToken(None, event.stream_ordering,)
+ room_end_token = RoomStreamToken(
+ None,
+ event.stream_ordering,
+ )
deferred_room_state = run_in_background(
self.state_store.get_state_for_events, [event.event_id]
)
@@ -284,7 +288,9 @@ class InitialSyncHandler(BaseHandler):
membership,
member_event_id,
) = await self.auth.check_user_in_room_or_world_readable(
- room_id, user_id, allow_departed_users=True,
+ room_id,
+ user_id,
+ allow_departed_users=True,
)
is_peeking = member_event_id is None
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index a15336bf00..1aded280c7 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -1,7 +1,8 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2017-2018 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019-2020 The Matrix.org Foundation C.I.C.
+# Copyright 2020 Sorunome
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -40,6 +41,7 @@ from synapse.api.errors import (
)
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions
from synapse.api.urls import ConsentURIBuilder
+from synapse.config.api import DEFAULT_ROOM_STATE_TYPES
from synapse.events import EventBase
from synapse.events.builder import EventBuilder
from synapse.events.snapshot import EventContext
@@ -65,8 +67,7 @@ logger = logging.getLogger(__name__)
class MessageHandler:
- """Contains some read only APIs to get state about a room
- """
+ """Contains some read only APIs to get state about a room"""
def __init__(self, hs):
self.auth = hs.get_auth()
@@ -88,9 +89,13 @@ class MessageHandler:
)
async def get_room_data(
- self, user_id: str, room_id: str, event_type: str, state_key: str,
+ self,
+ user_id: str,
+ room_id: str,
+ event_type: str,
+ state_key: str,
) -> dict:
- """ Get data from a room.
+ """Get data from a room.
Args:
user_id
@@ -174,7 +179,10 @@ class MessageHandler:
raise NotFoundError("Can't find event for token %s" % (at_token,))
visible_events = await filter_events_for_client(
- self.storage, user_id, last_events, filter_send_to_client=False,
+ self.storage,
+ user_id,
+ last_events,
+ filter_send_to_client=False,
)
event = last_events[0]
@@ -494,7 +502,7 @@ class EventCreationHandler:
membership = builder.content.get("membership", None)
target = UserID.from_string(builder.state_key)
- if membership in {Membership.JOIN, Membership.INVITE}:
+ if membership in {Membership.JOIN, Membership.INVITE, Membership.KNOCK}:
# If event doesn't include a display name, add one.
profile = self.profile_handler
content = builder.content
@@ -571,7 +579,7 @@ class EventCreationHandler:
async def _is_exempt_from_privacy_policy(
self, builder: EventBuilder, requester: Requester
) -> bool:
- """"Determine if an event to be sent is exempt from having to consent
+ """ "Determine if an event to be sent is exempt from having to consent
to the privacy policy
Args:
@@ -793,9 +801,10 @@ class EventCreationHandler:
"""
if prev_event_ids is not None:
- assert len(prev_event_ids) <= 10, (
- "Attempting to create an event with %i prev_events"
- % (len(prev_event_ids),)
+ assert (
+ len(prev_event_ids) <= 10
+ ), "Attempting to create an event with %i prev_events" % (
+ len(prev_event_ids),
)
else:
prev_event_ids = await self.store.get_prev_events_for_room(builder.room_id)
@@ -821,7 +830,8 @@ class EventCreationHandler:
)
if not third_party_result:
logger.info(
- "Event %s forbidden by third-party rules", event,
+ "Event %s forbidden by third-party rules",
+ event,
)
raise SynapseError(
403, "This event is not allowed in this context", Codes.FORBIDDEN
@@ -920,8 +930,8 @@ class EventCreationHandler:
room_version = await self.store.get_room_version_id(event.room_id)
if event.internal_metadata.is_out_of_band_membership():
- # the only sort of out-of-band-membership events we expect to see here
- # are invite rejections we have generated ourselves.
+ # the only sort of out-of-band-membership events we expect to see here are
+ # invite rejections and rescinded knocks that we have generated ourselves.
assert event.type == EventTypes.Member
assert event.content["membership"] == Membership.LEAVE
else:
@@ -1167,6 +1177,13 @@ class EventCreationHandler:
# TODO: Make sure the signatures actually are correct.
event.signatures.update(returned_invite.signatures)
+ if event.content["membership"] == Membership.KNOCK:
+ event.unsigned[
+ "knock_room_state"
+ ] = await self.store.get_stripped_room_state_from_event_context(
+ context, DEFAULT_ROOM_STATE_TYPES,
+ )
+
if event.type == EventTypes.Redaction:
original_event = await self.store.get_event(
event.redacts,
@@ -1316,7 +1333,11 @@ class EventCreationHandler:
# Since this is a dummy-event it is OK if it is sent by a
# shadow-banned user.
await self.handle_new_client_event(
- requester, event, context, ratelimit=False, ignore_shadow_ban=True,
+ requester,
+ event,
+ context,
+ ratelimit=False,
+ ignore_shadow_ban=True,
)
return True
except AuthError:
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index 71008ec50d..07db1e31e4 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -41,13 +41,33 @@ from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable
from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
from synapse.util import json_decoder
+from synapse.util.caches.cached_call import RetryOnExceptionCachedCall
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
-SESSION_COOKIE_NAME = b"oidc_session"
+# we want the cookie to be returned to us even when the request is the POSTed
+# result of a form on another domain, as is used with `response_mode=form_post`.
+#
+# Modern browsers will not do so unless we set SameSite=None; however *older*
+# browsers (including all versions of Safari on iOS 12?) don't support
+# SameSite=None, and interpret it as SameSite=Strict:
+# https://bugs.webkit.org/show_bug.cgi?id=198181
+#
+# As a rather painful workaround, we set *two* cookies, one with SameSite=None
+# and one with no SameSite, in the hope that at least one of them will get
+# back to us.
+#
+# Secure is necessary for SameSite=None (and, empirically, also breaks things
+# on iOS 12.)
+#
+# Here we have the names of the cookies, and the options we use to set them.
+_SESSION_COOKIES = [
+ (b"oidc_session", b"Path=/_synapse/client/oidc; HttpOnly; Secure; SameSite=None"),
+ (b"oidc_session_no_samesite", b"Path=/_synapse/client/oidc; HttpOnly"),
+]
#: A token exchanged from the token endpoint, as per RFC6749 sec 5.1. and
#: OpenID.Core sec 3.1.3.3.
@@ -72,8 +92,7 @@ JWKS = TypedDict("JWKS", {"keys": List[JWK]})
class OidcHandler:
- """Handles requests related to the OpenID Connect login flow.
- """
+ """Handles requests related to the OpenID Connect login flow."""
def __init__(self, hs: "HomeServer"):
self._sso_handler = hs.get_sso_handler()
@@ -123,7 +142,6 @@ class OidcHandler:
Args:
request: the incoming request from the browser.
"""
-
# The provider might redirect with an error.
# In that case, just display it as-is.
if b"error" in request.args:
@@ -137,8 +155,12 @@ class OidcHandler:
# either the provider misbehaving or Synapse being misconfigured.
# The only exception of that is "access_denied", where the user
# probably cancelled the login flow. In other cases, log those errors.
- if error != "access_denied":
- logger.error("Error from the OIDC provider: %s %s", error, description)
+ logger.log(
+ logging.INFO if error == "access_denied" else logging.ERROR,
+ "Received OIDC callback with error: %s %s",
+ error,
+ description,
+ )
self._sso_handler.render_error(request, error, description)
return
@@ -146,30 +168,37 @@ class OidcHandler:
# otherwise, it is presumably a successful response. see:
# https://tools.ietf.org/html/rfc6749#section-4.1.2
- # Fetch the session cookie
- session = request.getCookie(SESSION_COOKIE_NAME) # type: Optional[bytes]
- if session is None:
- logger.info("No session cookie found")
+ # Fetch the session cookie. See the comments on _SESSION_COOKIES for why there
+ # are two.
+
+ for cookie_name, _ in _SESSION_COOKIES:
+ session = request.getCookie(cookie_name) # type: Optional[bytes]
+ if session is not None:
+ break
+ else:
+ logger.info("Received OIDC callback, with no session cookie")
self._sso_handler.render_error(
request, "missing_session", "No session cookie found"
)
return
- # Remove the cookie. There is a good chance that if the callback failed
+ # Remove the cookies. There is a good chance that if the callback failed
# once, it will fail next time and the code will already be exchanged.
- # Removing it early avoids spamming the provider with token requests.
- request.addCookie(
- SESSION_COOKIE_NAME,
- b"",
- path="/_synapse/oidc",
- expires="Thu, Jan 01 1970 00:00:00 UTC",
- httpOnly=True,
- sameSite="lax",
- )
+ # Removing the cookies early avoids spamming the provider with token requests.
+ #
+ # we have to build the header by hand rather than calling request.addCookie
+ # because the latter does not support SameSite=None
+ # (https://twistedmatrix.com/trac/ticket/10088)
+
+ for cookie_name, options in _SESSION_COOKIES:
+ request.cookies.append(
+ b"%s=; Expires=Thu, Jan 01 1970 00:00:00 UTC; %s"
+ % (cookie_name, options)
+ )
# Check for the state query parameter
if b"state" not in request.args:
- logger.info("State parameter is missing")
+ logger.info("Received OIDC callback, with no state parameter")
self._sso_handler.render_error(
request, "invalid_request", "State parameter is missing"
)
@@ -183,14 +212,16 @@ class OidcHandler:
session, state
)
except (MacaroonDeserializationException, ValueError) as e:
- logger.exception("Invalid session")
+ logger.exception("Invalid session for OIDC callback")
self._sso_handler.render_error(request, "invalid_session", str(e))
return
except MacaroonInvalidSignatureException as e:
- logger.exception("Could not verify session")
+ logger.exception("Could not verify session for OIDC callback")
self._sso_handler.render_error(request, "mismatching_session", str(e))
return
+ logger.info("Received OIDC callback for IdP %s", session_data.idp_id)
+
oidc_provider = self._providers.get(session_data.idp_id)
if not oidc_provider:
logger.error("OIDC session uses unknown IdP %r", oidc_provider)
@@ -210,8 +241,7 @@ class OidcHandler:
class OidcError(Exception):
- """Used to catch errors when calling the token_endpoint
- """
+ """Used to catch errors when calling the token_endpoint"""
def __init__(self, error, error_description=None):
self.error = error
@@ -240,22 +270,27 @@ class OidcProvider:
self._token_generator = token_generator
+ self._config = provider
self._callback_url = hs.config.oidc_callback_url # type: str
self._scopes = provider.scopes
self._user_profile_method = provider.user_profile_method
self._client_auth = ClientAuth(
- provider.client_id, provider.client_secret, provider.client_auth_method,
+ provider.client_id,
+ provider.client_secret,
+ provider.client_auth_method,
) # type: ClientAuth
self._client_auth_method = provider.client_auth_method
- self._provider_metadata = OpenIDProviderMetadata(
- issuer=provider.issuer,
- authorization_endpoint=provider.authorization_endpoint,
- token_endpoint=provider.token_endpoint,
- userinfo_endpoint=provider.userinfo_endpoint,
- jwks_uri=provider.jwks_uri,
- ) # type: OpenIDProviderMetadata
- self._provider_needs_discovery = provider.discover
+
+ # cache of metadata for the identity provider (endpoint uris, mostly). This is
+ # loaded on-demand from the discovery endpoint (if discovery is enabled), with
+ # possible overrides from the config. Access via `load_metadata`.
+ self._provider_metadata = RetryOnExceptionCachedCall(self._load_metadata)
+
+ # cache of JWKs used by the identity provider to sign tokens. Loaded on demand
+ # from the IdP's jwks_uri, if required.
+ self._jwks = RetryOnExceptionCachedCall(self._load_jwks)
+
self._user_mapping_provider = provider.user_mapping_provider_class(
provider.user_mapping_provider_config
)
@@ -281,7 +316,7 @@ class OidcProvider:
self._sso_handler.register_identity_provider(self)
- def _validate_metadata(self):
+ def _validate_metadata(self, m: OpenIDProviderMetadata) -> None:
"""Verifies the provider metadata.
This checks the validity of the currently loaded provider. Not
@@ -300,7 +335,6 @@ class OidcProvider:
if self._skip_verification is True:
return
- m = self._provider_metadata
m.validate_issuer()
m.validate_authorization_endpoint()
m.validate_token_endpoint()
@@ -335,11 +369,7 @@ class OidcProvider:
)
else:
# If we're not using userinfo, we need a valid jwks to validate the ID token
- if m.get("jwks") is None:
- if m.get("jwks_uri") is not None:
- m.validate_jwks_uri()
- else:
- raise ValueError('"jwks_uri" must be set')
+ m.validate_jwks_uri()
@property
def _uses_userinfo(self) -> bool:
@@ -356,11 +386,15 @@ class OidcProvider:
or self._user_profile_method == "userinfo_endpoint"
)
- async def load_metadata(self) -> OpenIDProviderMetadata:
- """Load and validate the provider metadata.
+ async def load_metadata(self, force: bool = False) -> OpenIDProviderMetadata:
+ """Return the provider metadata.
- The values metadatas are discovered if ``oidc_config.discovery`` is
- ``True`` and then cached.
+ If this is the first call, the metadata is built from the config and from the
+ metadata discovery endpoint (if enabled), and then validated. If the metadata
+ is successfully validated, it is then cached for future use.
+
+ Args:
+ force: If true, any cached metadata is discarded to force a reload.
Raises:
ValueError: if something in the provider is not valid
@@ -368,18 +402,41 @@ class OidcProvider:
Returns:
The provider's metadata.
"""
- # If we are using the OpenID Discovery documents, it needs to be loaded once
- # FIXME: should there be a lock here?
- if self._provider_needs_discovery:
- url = get_well_known_url(self._provider_metadata["issuer"], external=True)
+ if force:
+ # reset the cached call to ensure we get a new result
+ self._provider_metadata = RetryOnExceptionCachedCall(self._load_metadata)
+
+ return await self._provider_metadata.get()
+
+ async def _load_metadata(self) -> OpenIDProviderMetadata:
+ # start out with just the issuer (unlike the other settings, discovered issuer
+ # takes precedence over configured issuer, because configured issuer is
+ # required for discovery to take place.)
+ #
+ metadata = OpenIDProviderMetadata(issuer=self._config.issuer)
+
+ # load any data from the discovery endpoint, if enabled
+ if self._config.discover:
+ url = get_well_known_url(self._config.issuer, external=True)
metadata_response = await self._http_client.get_json(url)
- # TODO: maybe update the other way around to let user override some values?
- self._provider_metadata.update(metadata_response)
- self._provider_needs_discovery = False
+ metadata.update(metadata_response)
+
+ # override any discovered data with any settings in our config
+ if self._config.authorization_endpoint:
+ metadata["authorization_endpoint"] = self._config.authorization_endpoint
+
+ if self._config.token_endpoint:
+ metadata["token_endpoint"] = self._config.token_endpoint
- self._validate_metadata()
+ if self._config.userinfo_endpoint:
+ metadata["userinfo_endpoint"] = self._config.userinfo_endpoint
- return self._provider_metadata
+ if self._config.jwks_uri:
+ metadata["jwks_uri"] = self._config.jwks_uri
+
+ self._validate_metadata(metadata)
+
+ return metadata
async def load_jwks(self, force: bool = False) -> JWKS:
"""Load the JSON Web Key Set used to sign ID tokens.
@@ -409,27 +466,27 @@ class OidcProvider:
]
}
"""
+ if force:
+ # reset the cached call to ensure we get a new result
+ self._jwks = RetryOnExceptionCachedCall(self._load_jwks)
+ return await self._jwks.get()
+
+ async def _load_jwks(self) -> JWKS:
if self._uses_userinfo:
# We're not using jwt signing, return an empty jwk set
return {"keys": []}
- # First check if the JWKS are loaded in the provider metadata.
- # It can happen either if the provider gives its JWKS in the discovery
- # document directly or if it was already loaded once.
metadata = await self.load_metadata()
- jwk_set = metadata.get("jwks")
- if jwk_set is not None and not force:
- return jwk_set
- # Loading the JWKS using the `jwks_uri` metadata
+ # Load the JWKS using the `jwks_uri` metadata.
uri = metadata.get("jwks_uri")
if not uri:
+ # this should be unreachable: load_metadata validates that
+ # there is a jwks_uri in the metadata if _uses_userinfo is unset
raise RuntimeError('Missing "jwks_uri" in metadata')
jwk_set = await self._http_client.get_json(uri)
- # Caching the JWKS in the provider's metadata
- self._provider_metadata["jwks"] = jwk_set
return jwk_set
async def _exchange_code(self, code: str) -> Token:
@@ -487,7 +544,10 @@ class OidcProvider:
# We're not using the SimpleHttpClient util methods as we don't want to
# check the HTTP status code and we do the body encoding ourself.
response = await self._http_client.request(
- method="POST", uri=uri, data=body.encode("utf-8"), headers=headers,
+ method="POST",
+ uri=uri,
+ data=body.encode("utf-8"),
+ headers=headers,
)
# This is used in multiple error messages below
@@ -565,6 +625,7 @@ class OidcProvider:
Returns:
UserInfo: an object representing the user.
"""
+ logger.debug("Using the OAuth2 access_token to request userinfo")
metadata = await self.load_metadata()
resp = await self._http_client.get_json(
@@ -572,6 +633,8 @@ class OidcProvider:
headers={"Authorization": ["Bearer {}".format(token["access_token"])]},
)
+ logger.debug("Retrieved user info from userinfo endpoint: %r", resp)
+
return UserInfo(resp)
async def _parse_id_token(self, token: Token, nonce: str) -> UserInfo:
@@ -600,17 +663,19 @@ class OidcProvider:
claims_cls = ImplicitIDToken
alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"])
-
jwt = JsonWebToken(alg_values)
claim_options = {"iss": {"values": [metadata["issuer"]]}}
+ id_token = token["id_token"]
+ logger.debug("Attempting to decode JWT id_token %r", id_token)
+
# Try to decode the keys in cache first, then retry by forcing the keys
# to be reloaded
jwk_set = await self.load_jwks()
try:
claims = jwt.decode(
- token["id_token"],
+ id_token,
key=jwk_set,
claims_cls=claims_cls,
claims_options=claim_options,
@@ -620,13 +685,15 @@ class OidcProvider:
logger.info("Reloading JWKS after decode error")
jwk_set = await self.load_jwks(force=True) # try reloading the jwks
claims = jwt.decode(
- token["id_token"],
+ id_token,
key=jwk_set,
claims_cls=claims_cls,
claims_options=claim_options,
claims_params=claims_params,
)
+ logger.debug("Decoded id_token JWT %r; validating", claims)
+
claims.validate(leeway=120) # allows 2 min of clock skew
return UserInfo(claims)
@@ -681,14 +748,18 @@ class OidcProvider:
ui_auth_session_id=ui_auth_session_id,
),
)
- request.addCookie(
- SESSION_COOKIE_NAME,
- cookie,
- path="/_synapse/client/oidc",
- max_age="3600",
- httpOnly=True,
- sameSite="lax",
- )
+
+ # Set the cookies. See the comments on _SESSION_COOKIES for why there are two.
+ #
+ # we have to build the header by hand rather than calling request.addCookie
+ # because the latter does not support SameSite=None
+ # (https://twistedmatrix.com/trac/ticket/10088)
+
+ for cookie_name, options in _SESSION_COOKIES:
+ request.cookies.append(
+ b"%s=%s; Max-Age=3600; %s"
+ % (cookie_name, cookie.encode("utf-8"), options)
+ )
metadata = await self.load_metadata()
authorization_endpoint = metadata.get("authorization_endpoint")
@@ -726,19 +797,18 @@ class OidcProvider:
"""
# Exchange the code with the provider
try:
- logger.debug("Exchanging code")
+ logger.debug("Exchanging OAuth2 code for a token")
token = await self._exchange_code(code)
except OidcError as e:
- logger.exception("Could not exchange code")
+ logger.exception("Could not exchange OAuth2 code")
self._sso_handler.render_error(request, e.error, e.error_description)
return
- logger.debug("Successfully obtained OAuth2 access token")
+ logger.debug("Successfully obtained OAuth2 token data: %r", token)
# Now that we have a token, get the userinfo, either by decoding the
# `id_token` or by fetching the `userinfo_endpoint`.
if self._uses_userinfo:
- logger.debug("Fetching userinfo")
try:
userinfo = await self._fetch_userinfo(token)
except Exception as e:
@@ -746,7 +816,6 @@ class OidcProvider:
self._sso_handler.render_error(request, "fetch_error", str(e))
return
else:
- logger.debug("Extracting userinfo from id_token")
try:
userinfo = await self._parse_id_token(token, nonce=session_data.nonce)
except Exception as e:
@@ -939,7 +1008,9 @@ class OidcSessionTokenGenerator:
A signed macaroon token with the session information.
"""
macaroon = pymacaroons.Macaroon(
- location=self._server_name, identifier="key", key=self._macaroon_secret_key,
+ location=self._server_name,
+ identifier="key",
+ key=self._macaroon_secret_key,
)
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("type = session")
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 5372753707..059064a4eb 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -197,7 +197,8 @@ class PaginationHandler:
stream_ordering = await self.store.find_first_stream_ordering_after_ts(ts)
r = await self.store.get_room_event_before_stream_ordering(
- room_id, stream_ordering,
+ room_id,
+ stream_ordering,
)
if not r:
logger.warning(
@@ -223,7 +224,12 @@ class PaginationHandler:
# the background so that it's not blocking any other operation apart from
# other purges in the same room.
run_as_background_process(
- "_purge_history", self._purge_history, purge_id, room_id, token, True,
+ "_purge_history",
+ self._purge_history,
+ purge_id,
+ room_id,
+ token,
+ True,
)
def start_purge_history(
@@ -389,7 +395,9 @@ class PaginationHandler:
)
await self.hs.get_federation_handler().maybe_backfill(
- room_id, curr_topo, limit=pagin_config.limit,
+ room_id,
+ curr_topo,
+ limit=pagin_config.limit,
)
to_room_key = None
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 22d1e9d35c..fb85b19770 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -349,10 +349,13 @@ class PresenceHandler(BasePresenceHandler):
[self.user_to_current_state[user_id] for user_id in unpersisted]
)
- async def _update_states(self, new_states):
+ async def _update_states(self, new_states: Iterable[UserPresenceState]) -> None:
"""Updates presence of users. Sets the appropriate timeouts. Pokes
the notifier and federation if and only if the changed presence state
should be sent to clients/servers.
+
+ Args:
+ new_states: The new user presence state updates to process.
"""
now = self.clock.time_msec()
@@ -368,7 +371,7 @@ class PresenceHandler(BasePresenceHandler):
new_states_dict = {}
for new_state in new_states:
new_states_dict[new_state.user_id] = new_state
- new_state = new_states_dict.values()
+ new_states = new_states_dict.values()
for new_state in new_states:
user_id = new_state.user_id
@@ -635,8 +638,7 @@ class PresenceHandler(BasePresenceHandler):
self.external_process_last_updated_ms.pop(process_id, None)
async def current_state_for_user(self, user_id):
- """Get the current presence state for a user.
- """
+ """Get the current presence state for a user."""
res = await self.current_state_for_users([user_id])
return res[user_id]
@@ -658,17 +660,6 @@ class PresenceHandler(BasePresenceHandler):
self._push_to_remotes(states)
- async def notify_for_states(self, state, stream_id):
- parties = await get_interested_parties(self.store, [state])
- room_ids_to_states, users_to_states = parties
-
- self.notifier.on_new_event(
- "presence_key",
- stream_id,
- rooms=room_ids_to_states.keys(),
- users=[UserID.from_string(u) for u in users_to_states],
- )
-
def _push_to_remotes(self, states):
"""Sends state updates to remote servers.
@@ -678,8 +669,7 @@ class PresenceHandler(BasePresenceHandler):
self.federation.send_presence(states)
async def incoming_presence(self, origin, content):
- """Called when we receive a `m.presence` EDU from a remote server.
- """
+ """Called when we receive a `m.presence` EDU from a remote server."""
if not self._presence_enabled:
return
@@ -729,8 +719,7 @@ class PresenceHandler(BasePresenceHandler):
await self._update_states(updates)
async def set_state(self, target_user, state, ignore_status_msg=False):
- """Set the presence state of the user.
- """
+ """Set the presence state of the user."""
status_msg = state.get("status_msg", None)
presence = state["presence"]
@@ -758,8 +747,7 @@ class PresenceHandler(BasePresenceHandler):
await self._update_states([prev_state.copy_and_replace(**new_fields)])
async def is_visible(self, observed_user, observer_user):
- """Returns whether a user can see another user's presence.
- """
+ """Returns whether a user can see another user's presence."""
observer_room_ids = await self.store.get_rooms_for_user(
observer_user.to_string()
)
@@ -953,8 +941,7 @@ class PresenceHandler(BasePresenceHandler):
def should_notify(old_state, new_state):
- """Decides if a presence state change should be sent to interested parties.
- """
+ """Decides if a presence state change should be sent to interested parties."""
if old_state == new_state:
return False
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index c02b951031..b04ee5f430 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,7 +15,11 @@
# limitations under the License.
import logging
import random
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, List, Optional
+
+from signedjson.sign import sign_json
+
+from twisted.internet import reactor
from synapse.api.errors import (
AuthError,
@@ -24,7 +29,11 @@ from synapse.api.errors import (
StoreError,
SynapseError,
)
-from synapse.metrics.background_process_metrics import wrap_as_background_process
+from synapse.logging.context import run_in_background
+from synapse.metrics.background_process_metrics import (
+ run_as_background_process,
+ wrap_as_background_process,
+)
from synapse.types import (
JsonDict,
Requester,
@@ -54,6 +63,8 @@ class ProfileHandler(BaseHandler):
PROFILE_UPDATE_MS = 60 * 1000
PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000
+ PROFILE_REPLICATE_INTERVAL = 2 * 60 * 1000
+
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
@@ -64,11 +75,98 @@ class ProfileHandler(BaseHandler):
self.user_directory_handler = hs.get_user_directory_handler()
+ self.http_client = hs.get_simple_http_client()
+
+ self.max_avatar_size = hs.config.max_avatar_size
+ self.allowed_avatar_mimetypes = hs.config.allowed_avatar_mimetypes
+ self.replicate_user_profiles_to = hs.config.replicate_user_profiles_to
+
if hs.config.run_background_tasks:
self.clock.looping_call(
self._update_remote_profile_cache, self.PROFILE_UPDATE_MS
)
+ if len(self.hs.config.replicate_user_profiles_to) > 0:
+ reactor.callWhenRunning(self._do_assign_profile_replication_batches)
+ reactor.callWhenRunning(self._start_replicate_profiles)
+ # Add a looping call to replicate_profiles: this handles retries
+ # if the replication is unsuccessful when the user updated their
+ # profile.
+ self.clock.looping_call(
+ self._start_replicate_profiles, self.PROFILE_REPLICATE_INTERVAL
+ )
+
+ def _do_assign_profile_replication_batches(self):
+ return run_as_background_process(
+ "_assign_profile_replication_batches",
+ self._assign_profile_replication_batches,
+ )
+
+ def _start_replicate_profiles(self):
+ return run_as_background_process(
+ "_replicate_profiles", self._replicate_profiles
+ )
+
+ async def _assign_profile_replication_batches(self):
+ """If no profile replication has been done yet, allocate replication batch
+ numbers to each profile to start the replication process.
+ """
+ logger.info("Assigning profile batch numbers...")
+ total = 0
+ while True:
+ assigned = await self.store.assign_profile_batch()
+ total += assigned
+ if assigned == 0:
+ break
+ logger.info("Assigned %d profile batch numbers", total)
+
+ async def _replicate_profiles(self):
+ """If any profile data has been updated and not pushed to the replication targets,
+ replicate it.
+ """
+ host_batches = await self.store.get_replication_hosts()
+ latest_batch = await self.store.get_latest_profile_replication_batch_number()
+ if latest_batch is None:
+ latest_batch = -1
+ for repl_host in self.hs.config.replicate_user_profiles_to:
+ if repl_host not in host_batches:
+ host_batches[repl_host] = -1
+ try:
+ for i in range(host_batches[repl_host] + 1, latest_batch + 1):
+ await self._replicate_host_profile_batch(repl_host, i)
+ except Exception:
+ logger.exception(
+ "Exception while replicating to %s: aborting for now", repl_host
+ )
+
+ async def _replicate_host_profile_batch(self, host, batchnum):
+ logger.info("Replicating profile batch %d to %s", batchnum, host)
+ batch_rows = await self.store.get_profile_batch(batchnum)
+ batch = {
+ UserID(r["user_id"], self.hs.hostname).to_string(): (
+ {"display_name": r["displayname"], "avatar_url": r["avatar_url"]}
+ if r["active"]
+ else None
+ )
+ for r in batch_rows
+ }
+
+ url = "https://%s/_matrix/identity/api/v1/replicate_profiles" % (host,)
+ body = {"batchnum": batchnum, "batch": batch, "origin_server": self.hs.hostname}
+ signed_body = sign_json(body, self.hs.hostname, self.hs.config.signing_key[0])
+ try:
+ await self.http_client.post_json_get_json(url, signed_body)
+ await self.store.update_replication_batch_for_host(host, batchnum)
+ logger.info(
+ "Successfully replicated profile batch %d to %s", batchnum, host
+ )
+ except Exception:
+ # This will get retried when the looping call next comes around
+ logger.exception(
+ "Failed to replicate profile batch %d to %s", batchnum, host
+ )
+ raise
+
async def get_profile(self, user_id: str) -> JsonDict:
target_user = UserID.from_string(user_id)
@@ -207,11 +305,20 @@ class ProfileHandler(BaseHandler):
# This must be done by the target user himself.
if by_admin:
requester = create_requester(
- target_user, authenticated_entity=requester.authenticated_entity,
+ target_user,
+ authenticated_entity=requester.authenticated_entity,
+ )
+
+ if len(self.hs.config.replicate_user_profiles_to) > 0:
+ cur_batchnum = (
+ await self.store.get_latest_profile_replication_batch_number()
)
+ new_batchnum = 0 if cur_batchnum is None else cur_batchnum + 1
+ else:
+ new_batchnum = None
await self.store.set_profile_displayname(
- target_user.localpart, displayname_to_set
+ target_user.localpart, displayname_to_set, new_batchnum
)
if self.hs.config.user_directory_search_all_users:
@@ -222,6 +329,46 @@ class ProfileHandler(BaseHandler):
await self._update_join_states(requester, target_user)
+ # start a profile replication push
+ run_in_background(self._replicate_profiles)
+
+ async def set_active(
+ self, users: List[UserID], active: bool, hide: bool,
+ ):
+ """
+ Sets the 'active' flag on a set of user profiles. If set to false, the
+ accounts are considered deactivated or hidden.
+
+ If 'hide' is true, then we interpret active=False as a request to try to
+ hide the users rather than deactivating them. This means withholding the
+ profiles from replication (and mark it as inactive) rather than clearing
+ the profile from the HS DB.
+
+ Note that unlike set_displayname and set_avatar_url, this does *not*
+ perform authorization checks! This is because the only place it's used
+ currently is in account deactivation where we've already done these
+ checks anyway.
+
+ Args:
+ users: The users to modify
+ active: Whether to set the user to active or inactive
+            hide: Whether to hide the user (withhold from replication). If
+ False and active is False, user will have their profile
+ erased
+ """
+        if len(self.hs.config.replicate_user_profiles_to) > 0:
+ cur_batchnum = (
+ await self.store.get_latest_profile_replication_batch_number()
+ )
+ new_batchnum = 0 if cur_batchnum is None else cur_batchnum + 1
+ else:
+ new_batchnum = None
+
+ await self.store.set_profiles_active(users, active, hide, new_batchnum)
+
+ # start a profile replication push
+ run_in_background(self._replicate_profiles)
+
async def get_avatar_url(self, target_user: UserID) -> Optional[str]:
if self.hs.is_mine(target_user):
try:
@@ -290,14 +437,56 @@ class ProfileHandler(BaseHandler):
if new_avatar_url == "":
avatar_url_to_set = None
+ # Enforce a max avatar size if one is defined
+ if avatar_url_to_set and (
+ self.max_avatar_size or self.allowed_avatar_mimetypes
+ ):
+ media_id = self._validate_and_parse_media_id_from_avatar_url(
+ avatar_url_to_set
+ )
+
+ # Check that this media exists locally
+ media_info = await self.store.get_local_media(media_id)
+ if not media_info:
+ raise SynapseError(
+ 400, "Unknown media id supplied", errcode=Codes.NOT_FOUND
+ )
+
+ # Ensure avatar does not exceed max allowed avatar size
+ media_size = media_info["media_length"]
+ if self.max_avatar_size and media_size > self.max_avatar_size:
+ raise SynapseError(
+ 400,
+ "Avatars must be less than %s bytes in size"
+ % (self.max_avatar_size,),
+ errcode=Codes.TOO_LARGE,
+ )
+
+ # Ensure the avatar's file type is allowed
+ if (
+ self.allowed_avatar_mimetypes
+ and media_info["media_type"] not in self.allowed_avatar_mimetypes
+ ):
+ raise SynapseError(
+ 400, "Avatar file type '%s' not allowed" % media_info["media_type"]
+ )
+
# Same like set_displayname
if by_admin:
requester = create_requester(
target_user, authenticated_entity=requester.authenticated_entity
)
+ if len(self.hs.config.replicate_user_profiles_to) > 0:
+ cur_batchnum = (
+ await self.store.get_latest_profile_replication_batch_number()
+ )
+ new_batchnum = 0 if cur_batchnum is None else cur_batchnum + 1
+ else:
+ new_batchnum = None
+
await self.store.set_profile_avatar_url(
- target_user.localpart, avatar_url_to_set
+ target_user.localpart, avatar_url_to_set, new_batchnum
)
if self.hs.config.user_directory_search_all_users:
@@ -308,6 +497,23 @@ class ProfileHandler(BaseHandler):
await self._update_join_states(requester, target_user)
+ # start a profile replication push
+ run_in_background(self._replicate_profiles)
+
+ def _validate_and_parse_media_id_from_avatar_url(self, mxc):
+ """Validate and parse a provided avatar url and return the local media id
+
+ Args:
+ mxc (str): A mxc URL
+
+ Returns:
+ str: The ID of the media
+ """
+ avatar_pieces = mxc.split("/")
+ if len(avatar_pieces) != 4 or avatar_pieces[0] != "mxc:":
+ raise SynapseError(400, "Invalid avatar URL '%s' supplied" % mxc)
+ return avatar_pieces[-1]
+
async def on_profile_query(self, args: JsonDict) -> JsonDict:
user = UserID.from_string(args["user_id"])
if not self.hs.is_mine(user):
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index cc21fc2284..6a6c528849 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -49,15 +49,15 @@ class ReceiptsHandler(BaseHandler):
)
else:
hs.get_federation_registry().register_instances_for_edu(
- "m.receipt", hs.config.worker.writers.receipts,
+ "m.receipt",
+ hs.config.worker.writers.receipts,
)
self.clock = self.hs.get_clock()
self.state = hs.get_state_handler()
async def _received_remote_receipt(self, origin: str, content: JsonDict) -> None:
- """Called when we receive an EDU of type m.receipt from a remote HS.
- """
+ """Called when we receive an EDU of type m.receipt from a remote HS."""
receipts = []
for room_id, room_values in content.items():
for receipt_type, users in room_values.items():
@@ -83,8 +83,7 @@ class ReceiptsHandler(BaseHandler):
await self._handle_new_receipts(receipts)
async def _handle_new_receipts(self, receipts: List[ReadReceipt]) -> bool:
- """Takes a list of receipts, stores them and informs the notifier.
- """
+ """Takes a list of receipts, stores them and informs the notifier."""
min_batch_id = None # type: Optional[int]
max_batch_id = None # type: Optional[int]
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 49b085269b..553fcb5b66 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -49,6 +49,7 @@ class RegistrationHandler(BaseHandler):
self._auth_handler = hs.get_auth_handler()
self.profile_handler = hs.get_profile_handler()
self.user_directory_handler = hs.get_user_directory_handler()
+ self.http_client = hs.get_simple_http_client()
self.identity_handler = self.hs.get_identity_handler()
self.ratelimiter = hs.get_registration_ratelimiter()
self.macaroon_gen = hs.get_macaroon_generator()
@@ -57,13 +58,15 @@ class RegistrationHandler(BaseHandler):
self.spam_checker = hs.get_spam_checker()
+ self._show_in_user_directory = self.hs.config.show_users_in_user_directory
+
if hs.config.worker_app:
self._register_client = ReplicationRegisterServlet.make_client(hs)
self._register_device_client = RegisterDeviceReplicationServlet.make_client(
hs
)
- self._post_registration_client = ReplicationPostRegisterActionsServlet.make_client(
- hs
+ self._post_registration_client = (
+ ReplicationPostRegisterActionsServlet.make_client(hs)
)
else:
self.device_handler = hs.get_device_handler()
@@ -77,6 +80,16 @@ class RegistrationHandler(BaseHandler):
guest_access_token: Optional[str] = None,
assigned_user_id: Optional[str] = None,
):
+ """
+
+ Args:
+ localpart (str|None): The user's localpart
+ guest_access_token (str|None): A guest's access token
+ assigned_user_id (str|None): An existing User ID for this user if pre-calculated
+
+ Returns:
+ Deferred
+ """
if types.contains_invalid_mxid_characters(localpart):
raise SynapseError(
400,
@@ -119,6 +132,8 @@ class RegistrationHandler(BaseHandler):
raise SynapseError(
400, "User ID already taken.", errcode=Codes.USER_IN_USE
)
+
+ # Retrieve guest user information from provided access token
user_data = await self.auth.get_user_by_access_token(guest_access_token)
if (
not user_data.is_guest
@@ -189,12 +204,15 @@ class RegistrationHandler(BaseHandler):
self.check_registration_ratelimit(address)
result = await self.spam_checker.check_registration_for_spam(
- threepid, localpart, user_agent_ips or [],
+ threepid,
+ localpart,
+ user_agent_ips or [],
)
if result == RegistrationBehaviour.DENY:
logger.info(
- "Blocked registration of %r", localpart,
+ "Blocked registration of %r",
+ localpart,
)
# We return a 429 to make it not obvious that they've been
# denied.
@@ -203,7 +221,8 @@ class RegistrationHandler(BaseHandler):
shadow_banned = result == RegistrationBehaviour.SHADOW_BAN
if shadow_banned:
logger.info(
- "Shadow banning registration of %r", localpart,
+ "Shadow banning registration of %r",
+ localpart,
)
# do not check_auth_blocking if the call is coming through the Admin API
@@ -237,6 +256,12 @@ class RegistrationHandler(BaseHandler):
shadow_banned=shadow_banned,
)
+ if default_display_name:
+ requester = create_requester(user)
+ await self.profile_handler.set_displayname(
+ user, requester, default_display_name, by_admin=True
+ )
+
if self.hs.config.user_directory_search_all_users:
profile = await self.store.get_profileinfo(localpart)
await self.user_directory_handler.handle_local_profile_change(
@@ -246,8 +271,6 @@ class RegistrationHandler(BaseHandler):
else:
# autogen a sequential user ID
fail_count = 0
- # If a default display name is not given, generate one.
- generate_display_name = default_display_name is None
# This breaks on successful registration *or* errors after 10 failures.
while True:
# Fail after being unable to find a suitable ID a few times
@@ -258,7 +281,7 @@ class RegistrationHandler(BaseHandler):
user = UserID(localpart, self.hs.hostname)
user_id = user.to_string()
self.check_user_id_not_appservice_exclusive(user_id)
- if generate_display_name:
+ if default_display_name is None:
default_display_name = localpart
try:
await self.register_with_store(
@@ -270,6 +293,11 @@ class RegistrationHandler(BaseHandler):
shadow_banned=shadow_banned,
)
+ requester = create_requester(user)
+ await self.profile_handler.set_displayname(
+ user, requester, default_display_name, by_admin=True
+ )
+
# Successfully registered
break
except SynapseError:
@@ -301,7 +329,15 @@ class RegistrationHandler(BaseHandler):
}
# Bind email to new account
- await self._register_email_threepid(user_id, threepid_dict, None)
+ await self.register_email_threepid(user_id, threepid_dict, None)
+
+ # Prevent the new user from showing up in the user directory if the server
+ # mandates it.
+ if not self._show_in_user_directory:
+ await self.store.add_account_data_for_user(
+ user_id, "im.vector.hide_profile", {"hide_profile": True}
+ )
+ await self.profile_handler.set_active([user], False, True)
return user_id
@@ -369,7 +405,9 @@ class RegistrationHandler(BaseHandler):
config["room_alias_name"] = room_alias.localpart
info, _ = await room_creation_handler.create_room(
- fake_requester, config=config, ratelimit=False,
+ fake_requester,
+ config=config,
+ ratelimit=False,
)
# If the room does not require an invite, but another user
@@ -501,7 +539,10 @@ class RegistrationHandler(BaseHandler):
"""
await self._auto_join_rooms(user_id)
- async def appservice_register(self, user_localpart: str, as_token: str) -> str:
+ async def appservice_register(
+ self, user_localpart: str, as_token: str, password_hash: str, display_name: str
+ ):
+ # FIXME: this should be factored out and merged with normal register()
user = UserID(user_localpart, self.hs.hostname)
user_id = user.to_string()
service = self.store.get_app_service_by_token(as_token)
@@ -518,12 +559,26 @@ class RegistrationHandler(BaseHandler):
self.check_user_id_not_appservice_exclusive(user_id, allowed_appservice=service)
+ display_name = display_name or user.localpart
+
await self.register_with_store(
user_id=user_id,
- password_hash="",
+ password_hash=password_hash,
appservice_id=service_id,
- create_profile_with_displayname=user.localpart,
+ create_profile_with_displayname=display_name,
)
+
+ requester = create_requester(user)
+ await self.profile_handler.set_displayname(
+ user, requester, display_name, by_admin=True
+ )
+
+ if self.hs.config.user_directory_search_all_users:
+ profile = await self.store.get_profileinfo(user_localpart)
+ await self.user_directory_handler.handle_local_profile_change(
+ user_id, profile
+ )
+
return user_id
def check_user_id_not_appservice_exclusive(
@@ -552,6 +607,37 @@ class RegistrationHandler(BaseHandler):
errcode=Codes.EXCLUSIVE,
)
+ async def shadow_register(self, localpart, display_name, auth_result, params):
+ """Invokes the current registration on another server, using
+ shared secret registration, passing in any auth_results from
+ other registration UI auth flows (e.g. validated 3pids)
+ Useful for setting up shadow/backup accounts on a parallel deployment.
+ """
+
+ # TODO: retries
+ shadow_hs_url = self.hs.config.shadow_server.get("hs_url")
+ as_token = self.hs.config.shadow_server.get("as_token")
+
+ await self.http_client.post_json_get_json(
+ "%s/_matrix/client/r0/register?access_token=%s" % (shadow_hs_url, as_token),
+ {
+ # XXX: auth_result is an unspecified extension for shadow registration
+ "auth_result": auth_result,
+ # XXX: another unspecified extension for shadow registration to ensure
+                # that the displayname is correctly set by the master server
+ "display_name": display_name,
+ "username": localpart,
+ "password": params.get("password"),
+ "bind_msisdn": params.get("bind_msisdn"),
+ "device_id": params.get("device_id"),
+ "initial_device_display_name": params.get(
+ "initial_device_display_name"
+ ),
+ "inhibit_login": False,
+ "access_token": as_token,
+ },
+ )
+
def check_registration_ratelimit(self, address: Optional[str]) -> None:
"""A simple helper method to check whether the registration rate limit has been hit
for a given IP address
@@ -704,6 +790,7 @@ class RegistrationHandler(BaseHandler):
if auth_result and LoginType.EMAIL_IDENTITY in auth_result:
threepid = auth_result[LoginType.EMAIL_IDENTITY]
+
# Necessary due to auth checks prior to the threepid being
# written to the db
if is_threepid_reserved(
@@ -711,7 +798,32 @@ class RegistrationHandler(BaseHandler):
):
await self.store.upsert_monthly_active_user(user_id)
- await self._register_email_threepid(user_id, threepid, access_token)
+ await self.register_email_threepid(user_id, threepid, access_token)
+
+ if self.hs.config.bind_new_user_emails_to_sydent:
+ # Attempt to call Sydent's internal bind API on the given identity server
+ # to bind this threepid
+ id_server_url = self.hs.config.bind_new_user_emails_to_sydent
+
+ logger.debug(
+ "Attempting the bind email of %s to identity server: %s using "
+ "internal Sydent bind API.",
+ user_id,
+ self.hs.config.bind_new_user_emails_to_sydent,
+ )
+
+ try:
+ await self.identity_handler.bind_email_using_internal_sydent_api(
+ id_server_url, threepid["address"], user_id
+ )
+ except Exception as e:
+ logger.warning(
+                    "Failed to bind email of '%s' to Sydent instance '%s' "
+ "using Sydent internal bind API: %s",
+ user_id,
+ id_server_url,
+ e,
+ )
if auth_result and LoginType.MSISDN in auth_result:
threepid = auth_result[LoginType.MSISDN]
@@ -731,7 +843,7 @@ class RegistrationHandler(BaseHandler):
await self.store.user_set_consent_version(user_id, consent_version)
await self.post_consent_actions(user_id)
- async def _register_email_threepid(
+ async def register_email_threepid(
self, user_id: str, threepid: dict, token: Optional[str]
) -> None:
"""Add an email address as a 3pid identifier
@@ -753,7 +865,10 @@ class RegistrationHandler(BaseHandler):
return
await self._auth_handler.add_threepid(
- user_id, threepid["medium"], threepid["address"], threepid["validated_at"],
+ user_id,
+ threepid["medium"],
+ threepid["address"],
+ threepid["validated_at"],
)
# And we add an email pusher for them by default, but only
@@ -805,5 +920,8 @@ class RegistrationHandler(BaseHandler):
raise
await self._auth_handler.add_threepid(
- user_id, threepid["medium"], threepid["address"], threepid["validated_at"],
+ user_id,
+ threepid["medium"],
+ threepid["address"],
+ threepid["validated_at"],
)
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 07b2187eb1..2271c60afc 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -38,6 +38,7 @@ from synapse.api.filtering import Filter
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
from synapse.events import EventBase
from synapse.events.utils import copy_power_levels_contents
+from synapse.rest.admin._base import assert_user_is_admin
from synapse.storage.state import StateFilter
from synapse.types import (
JsonDict,
@@ -197,7 +198,9 @@ class RoomCreationHandler(BaseHandler):
if r is None:
raise NotFoundError("Unknown room id %s" % (old_room_id,))
new_room_id = await self._generate_room_id(
- creator_id=user_id, is_public=r["is_public"], room_version=new_version,
+ creator_id=user_id,
+ is_public=r["is_public"],
+ room_version=new_version,
)
logger.info("Creating new room %s to replace %s", new_room_id, old_room_id)
@@ -235,7 +238,9 @@ class RoomCreationHandler(BaseHandler):
# now send the tombstone
await self.event_creation_handler.handle_new_client_event(
- requester=requester, event=tombstone_event, context=tombstone_context,
+ requester=requester,
+ event=tombstone_event,
+ context=tombstone_context,
)
old_room_state = await tombstone_context.get_current_state_ids()
@@ -256,7 +261,10 @@ class RoomCreationHandler(BaseHandler):
# finally, shut down the PLs in the old room, and update them in the new
# room.
await self._update_upgraded_room_pls(
- requester, old_room_id, new_room_id, old_room_state,
+ requester,
+ old_room_id,
+ new_room_id,
+ old_room_state,
)
return new_room_id
@@ -363,7 +371,19 @@ class RoomCreationHandler(BaseHandler):
"""
user_id = requester.user.to_string()
- if not await self.spam_checker.user_may_create_room(user_id):
+ if (
+ self._server_notices_mxid is not None
+ and requester.user.to_string() == self._server_notices_mxid
+ ):
+ # allow the server notices mxid to create rooms
+ is_requester_admin = True
+
+ else:
+ is_requester_admin = await self.auth.is_server_admin(requester.user)
+
+ if not is_requester_admin and not await self.spam_checker.user_may_create_room(
+ user_id, invite_list=[], third_party_invite_list=[], cloning=True
+ ):
raise SynapseError(403, "You are not permitted to create rooms")
creation_content = {
@@ -424,17 +444,20 @@ class RoomCreationHandler(BaseHandler):
# Copy over user power levels now as this will not be possible with >100PL users once
# the room has been created
-
# Calculate the minimum power level needed to clone the room
event_power_levels = power_levels.get("events", {})
- state_default = power_levels.get("state_default", 0)
- ban = power_levels.get("ban")
+ state_default = power_levels.get("state_default", 50)
+ ban = power_levels.get("ban", 50)
needed_power_level = max(state_default, ban, max(event_power_levels.values()))
+ # Get the user's current power level, this matches the logic in get_user_power_level,
+ # but without the entire state map.
+ user_power_levels = power_levels.setdefault("users", {})
+ users_default = power_levels.get("users_default", 0)
+ current_power_level = user_power_levels.get(user_id, users_default)
# Raise the requester's power level in the new room if necessary
- current_power_level = power_levels["users"][user_id]
if current_power_level < needed_power_level:
- power_levels["users"][user_id] = needed_power_level
+ user_power_levels[user_id] = needed_power_level
await self._send_events_for_new_room(
requester,
@@ -566,7 +589,7 @@ class RoomCreationHandler(BaseHandler):
ratelimit: bool = True,
creator_join_profile: Optional[JsonDict] = None,
) -> Tuple[dict, int]:
- """ Creates a new room.
+ """Creates a new room.
Args:
requester:
@@ -614,8 +637,14 @@ class RoomCreationHandler(BaseHandler):
403, "You are not permitted to create rooms", Codes.FORBIDDEN
)
+ invite_list = config.get("invite", [])
+ invite_3pid_list = config.get("invite_3pid", [])
+
if not is_requester_admin and not await self.spam_checker.user_may_create_room(
- user_id
+ user_id,
+ invite_list=invite_list,
+ third_party_invite_list=invite_3pid_list,
+ cloning=False,
):
raise SynapseError(403, "You are not permitted to create rooms")
@@ -687,7 +716,9 @@ class RoomCreationHandler(BaseHandler):
is_public = visibility == "public"
room_id = await self._generate_room_id(
- creator_id=user_id, is_public=is_public, room_version=room_version,
+ creator_id=user_id,
+ is_public=is_public,
+ room_version=room_version,
)
# Check whether this visibility value is blocked by a third party module
@@ -803,6 +834,7 @@ class RoomCreationHandler(BaseHandler):
"invite",
ratelimit=False,
content=content,
+ new_room=True,
)
for invite_3pid in invite_3pid_list:
@@ -820,6 +852,7 @@ class RoomCreationHandler(BaseHandler):
id_server,
requester,
txn_id=None,
+ new_room=True,
id_access_token=id_access_token,
)
@@ -828,7 +861,7 @@ class RoomCreationHandler(BaseHandler):
if room_alias:
result["room_alias"] = room_alias.to_string()
- # Always wait for room creation to progate before returning
+ # Always wait for room creation to propagate before returning
await self._replication.wait_for_stream_position(
self.hs.config.worker.events_shard_config.get_instance(room_id),
"events",
@@ -880,7 +913,10 @@ class RoomCreationHandler(BaseHandler):
_,
last_stream_id,
) = await self.event_creation_handler.create_and_send_nonmember_event(
- creator, event, ratelimit=False, ignore_shadow_ban=True,
+ creator,
+ event,
+ ratelimit=False,
+ ignore_shadow_ban=True,
)
return last_stream_id
@@ -897,6 +933,7 @@ class RoomCreationHandler(BaseHandler):
"join",
ratelimit=ratelimit,
content=creator_join_profile,
+ new_room=True,
)
# We treat the power levels override specially as this needs to be one
@@ -980,7 +1017,10 @@ class RoomCreationHandler(BaseHandler):
return last_sent_stream_id
async def _generate_room_id(
- self, creator_id: str, is_public: bool, room_version: RoomVersion,
+ self,
+ creator_id: str,
+ is_public: bool,
+ room_version: RoomVersion,
):
# autogen room IDs and try to create it. We may clash, so just
# try a few times till one goes through, giving up eventually.
@@ -1004,41 +1044,51 @@ class RoomCreationHandler(BaseHandler):
class RoomContextHandler:
def __init__(self, hs: "HomeServer"):
self.hs = hs
+ self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.storage = hs.get_storage()
self.state_store = self.storage.state
async def get_event_context(
self,
- user: UserID,
+ requester: Requester,
room_id: str,
event_id: str,
limit: int,
event_filter: Optional[Filter],
+ use_admin_priviledge: bool = False,
) -> Optional[JsonDict]:
"""Retrieves events, pagination tokens and state around a given event
in a room.
Args:
- user
+ requester
room_id
event_id
limit: The maximum number of events to return in total
(excluding state).
event_filter: the filter to apply to the events returned
(excluding the target event_id)
-
+ use_admin_priviledge: if `True`, return all events, regardless
+ of whether `user` has access to them. To be used **ONLY**
+ from the admin API.
Returns:
dict, or None if the event isn't found
"""
+ user = requester.user
+ if use_admin_priviledge:
+ await assert_user_is_admin(self.auth, requester.user)
+
before_limit = math.floor(limit / 2.0)
after_limit = limit - before_limit
users = await self.store.get_users_in_room(room_id)
is_peeking = user.to_string() not in users
- def filter_evts(events):
- return filter_events_for_client(
+ async def filter_evts(events):
+ if use_admin_priviledge:
+ return events
+ return await filter_events_for_client(
self.storage, user.to_string(), events, is_peeking=is_peeking
)
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index 14f14db449..373b9dcd0d 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -170,6 +170,7 @@ class RoomListHandler(BaseHandler):
"world_readable": room["history_visibility"]
== HistoryVisibility.WORLD_READABLE,
"guest_can_join": room["guest_access"] == "can_join",
+ "join_rule": room["join_rules"],
}
# Filter out Nones – rather omit the field altogether
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index a5da97cfe0..312ebc139c 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2016-2020 The Matrix.org Foundation C.I.C.
+# Copyright 2020 Sorunome
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,7 +13,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import abc
import logging
import random
@@ -31,7 +31,15 @@ from synapse.api.errors import (
from synapse.api.ratelimiting import Ratelimiter
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
-from synapse.types import JsonDict, Requester, RoomAlias, RoomID, StateMap, UserID
+from synapse.types import (
+ JsonDict,
+ Requester,
+ RoomAlias,
+ RoomID,
+ StateMap,
+ UserID,
+ get_domain_from_id,
+)
from synapse.util.async_helpers import Linearizer
from synapse.util.distributor import user_left_room
@@ -122,6 +130,20 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
raise NotImplementedError()
@abc.abstractmethod
+ async def remote_knock(
+ self, remote_room_hosts: List[str], room_id: str, user: UserID, content: dict,
+ ) -> Tuple[str, int]:
+ """Try and knock on a room that this server is not in
+
+ Args:
+ remote_room_hosts: List of servers that can be used to knock via.
+ room_id: Room that we are trying to knock on.
+ user: User who is trying to knock.
+ content: A dict that should be used as the content of the knock event.
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
async def remote_reject_invite(
self,
invite_event_id: str,
@@ -145,6 +167,27 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
raise NotImplementedError()
@abc.abstractmethod
+ async def remote_rescind_knock(
+ self,
+ knock_event_id: str,
+ txn_id: Optional[str],
+ requester: Requester,
+ content: JsonDict,
+ ) -> Tuple[str, int]:
+ """Rescind a local knock made on a remote room.
+
+ Args:
+ knock_event_id: The ID of the knock event to rescind.
+ txn_id: An optional transaction ID supplied by the client.
+ requester: The user making the request, according to the access token.
+ content: The content of the generated leave event.
+
+ Returns:
+ A tuple containing (event_id, stream_id of the leave event).
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
async def _user_left_room(self, target: UserID, room_id: str) -> None:
"""Notifies distributor on master process that the user has left the
room.
@@ -191,7 +234,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# do it up front for efficiency.)
if txn_id and requester.access_token_id:
existing_event_id = await self.store.get_event_id_from_transaction_id(
- room_id, requester.user.to_string(), requester.access_token_id, txn_id,
+ room_id,
+ requester.user.to_string(),
+ requester.access_token_id,
+ txn_id,
)
if existing_event_id:
event_pos = await self.store.get_position_for_event(existing_event_id)
@@ -238,7 +284,11 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
)
result_event = await self.event_creation_handler.handle_new_client_event(
- requester, event, context, extra_users=[target], ratelimit=ratelimit,
+ requester,
+ event,
+ context,
+ extra_users=[target],
+ ratelimit=ratelimit,
)
if event.membership == Membership.LEAVE:
@@ -300,6 +350,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
third_party_signed: Optional[dict] = None,
ratelimit: bool = True,
content: Optional[dict] = None,
+ new_room: bool = False,
require_consent: bool = True,
) -> Tuple[str, int]:
"""Update a user's membership in a room.
@@ -340,6 +391,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
third_party_signed=third_party_signed,
ratelimit=ratelimit,
content=content,
+ new_room=new_room,
require_consent=require_consent,
)
@@ -356,6 +408,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
third_party_signed: Optional[dict] = None,
ratelimit: bool = True,
content: Optional[dict] = None,
+ new_room: bool = False,
require_consent: bool = True,
) -> Tuple[str, int]:
"""Helper for update_membership.
@@ -438,8 +491,15 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
)
block_invite = True
+ is_published = await self.store.is_room_published(room_id)
+
if not await self.spam_checker.user_may_invite(
- requester.user.to_string(), target_id, room_id
+ requester.user.to_string(),
+ target_id,
+ third_party_invite=None,
+ room_id=room_id,
+ new_room=new_room,
+ published_room=is_published,
):
logger.info("Blocking invite due to spam checker")
block_invite = True
@@ -517,6 +577,25 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# so don't really fit into the general auth process.
raise AuthError(403, "Guest access not allowed")
+ if (
+ self._server_notices_mxid is not None
+ and requester.user.to_string() == self._server_notices_mxid
+ ):
+ # allow the server notices mxid to join rooms
+ is_requester_admin = True
+
+ else:
+ is_requester_admin = await self.auth.is_server_admin(requester.user)
+
+ inviter = await self._get_inviter(target.to_string(), room_id)
+ if not is_requester_admin:
+ # We assume that if the spam checker allowed the user to create
+ # a room then they're allowed to join it.
+ if not new_room and not self.spam_checker.user_may_join_room(
+ target.to_string(), room_id, is_invited=inviter is not None
+ ):
+ raise SynapseError(403, "Not allowed to join this room")
+
if not is_host_in_room:
if ratelimit:
time_now_s = self.clock.time()
@@ -554,50 +633,79 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
elif effective_membership_state == Membership.LEAVE:
if not is_host_in_room:
- # perhaps we've been invited
+ # Figure out the user's current membership state for the room
(
current_membership_type,
current_membership_event_id,
) = await self.store.get_local_current_membership_for_user_in_room(
target.to_string(), room_id
)
- if (
- current_membership_type != Membership.INVITE
- or not current_membership_event_id
- ):
+ if not current_membership_type or not current_membership_event_id:
logger.info(
"%s sent a leave request to %s, but that is not an active room "
- "on this server, and there is no pending invite",
+ "on this server, or there is no pending invite or knock",
target,
room_id,
)
raise SynapseError(404, "Not a known room")
- invite = await self.store.get_event(current_membership_event_id)
- logger.info(
- "%s rejects invite to %s from %s", target, room_id, invite.sender
- )
+ # perhaps we've been invited
+ if current_membership_type == Membership.INVITE:
+ invite = await self.store.get_event(current_membership_event_id)
+ logger.info(
+ "%s rejects invite to %s from %s",
+ target,
+ room_id,
+ invite.sender,
+ )
- if not self.hs.is_mine_id(invite.sender):
- # send the rejection to the inviter's HS (with fallback to
- # local event)
- return await self.remote_reject_invite(
- invite.event_id, txn_id, requester, content,
+ if not self.hs.is_mine_id(invite.sender):
+ # send the rejection to the inviter's HS (with fallback to
+ # local event)
+ return await self.remote_reject_invite(
+ invite.event_id,
+ txn_id,
+ requester,
+ content,
+ )
+
+ # the inviter was on our server, but has now left. Carry on
+ # with the normal rejection codepath, which will also send the
+ # rejection out to any other servers we believe are still in the room.
+
+ # thanks to overzealous cleaning up of event_forward_extremities in
+ # `delete_old_current_state_events`, it's possible to end up with no
+ # forward extremities here. If that happens, let's just hang the
+ # rejection off the invite event.
+ #
+ # see: https://github.com/matrix-org/synapse/issues/7139
+ if len(latest_event_ids) == 0:
+ latest_event_ids = [invite.event_id]
+
+ # or perhaps this is a remote room that a local user has knocked on
+ elif current_membership_type == Membership.KNOCK:
+ knock = await self.store.get_event(current_membership_event_id)
+ return await self.remote_rescind_knock(
+ knock.event_id, txn_id, requester, content
)
- # the inviter was on our server, but has now left. Carry on
- # with the normal rejection codepath, which will also send the
- # rejection out to any other servers we believe are still in the room.
+ elif effective_membership_state == Membership.KNOCK:
+ if not is_host_in_room:
+ # The knock needs to be sent over federation instead
+ remote_room_hosts.append(get_domain_from_id(room_id))
+
+ content["membership"] = Membership.KNOCK
- # thanks to overzealous cleaning up of event_forward_extremities in
- # `delete_old_current_state_events`, it's possible to end up with no
- # forward extremities here. If that happens, let's just hang the
- # rejection off the invite event.
- #
- # see: https://github.com/matrix-org/synapse/issues/7139
- if len(latest_event_ids) == 0:
- latest_event_ids = [invite.event_id]
+ profile = self.profile_handler
+ if "displayname" not in content:
+ content["displayname"] = await profile.get_displayname(target)
+ if "avatar_url" not in content:
+ content["avatar_url"] = await profile.get_avatar_url(target)
+
+ return await self.remote_knock(
+ remote_room_hosts, room_id, target, content
+ )
return await self._local_membership_update(
requester=requester,
@@ -813,6 +921,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
id_server: str,
requester: Requester,
txn_id: Optional[str],
+ new_room: bool = False,
id_access_token: Optional[str] = None,
) -> int:
"""Invite a 3PID to a room.
@@ -860,6 +969,16 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
Codes.FORBIDDEN,
)
+ can_invite = await self.third_party_event_rules.check_threepid_can_be_invited(
+ medium, address, room_id
+ )
+ if not can_invite:
+ raise SynapseError(
+ 403,
+ "This third-party identifier can not be invited in this room",
+ Codes.FORBIDDEN,
+ )
+
if not self._enable_lookup:
raise SynapseError(
403, "Looking up third-party identifiers is denied from this server"
@@ -869,6 +988,19 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
id_server, medium, address, id_access_token
)
+ is_published = await self.store.is_room_published(room_id)
+
+ if not await self.spam_checker.user_may_invite(
+ requester.user.to_string(),
+ invitee,
+ third_party_invite={"medium": medium, "address": address},
+ room_id=room_id,
+ new_room=new_room,
+ published_room=is_published,
+ ):
+ logger.info("Blocking invite due to spam checker")
+ raise SynapseError(403, "Invites have been disabled on this server")
+
if invitee:
# Note that update_membership with an action of "invite" can raise
# a ShadowBanError, but this was done above already.
@@ -1056,8 +1188,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
user: UserID,
content: dict,
) -> Tuple[str, int]:
- """Implements RoomMemberHandler._remote_join
- """
+ """Implements RoomMemberHandler._remote_join"""
# filter ourselves out of remote_room_hosts: do_invite_join ignores it
# and if it is the only entry we'd like to return a 404 rather than a
# 500.
@@ -1158,6 +1289,35 @@ class RoomMemberMasterHandler(RoomMemberHandler):
invite_event, txn_id, requester, content
)
+ async def remote_rescind_knock(
+ self,
+ knock_event_id: str,
+ txn_id: Optional[str],
+ requester: Requester,
+ content: JsonDict,
+ ) -> Tuple[str, int]:
+ """
+ Rescinds a local knock made on a remote room
+
+ Args:
+ knock_event_id: The ID of the knock event to rescind.
+ txn_id: The transaction ID to use.
+ requester: The originator of the request.
+ content: The content of the leave event.
+
+ Implements RoomMemberHandler.remote_rescind_knock
+ """
+ # TODO: We don't yet support rescinding knocks over federation
+ # as we don't know which homeserver to send it to. An obvious
+ # candidate is the remote homeserver we originally knocked through,
+ # however we don't currently store that information.
+
+ # Just rescind the knock locally
+ knock_event = await self.store.get_event(knock_event_id)
+ return await self._generate_local_out_of_band_leave(
+ knock_event, txn_id, requester, content
+ )
+
async def _generate_local_out_of_band_leave(
self,
previous_membership_event: EventBase,
@@ -1211,16 +1371,44 @@ class RoomMemberMasterHandler(RoomMemberHandler):
event.internal_metadata.out_of_band_membership = True
result_event = await self.event_creation_handler.handle_new_client_event(
- requester, event, context, extra_users=[UserID.from_string(target_user)],
+ requester,
+ event,
+ context,
+ extra_users=[UserID.from_string(target_user)],
)
# we know it was persisted, so must have a stream ordering
assert result_event.internal_metadata.stream_ordering
return result_event.event_id, result_event.internal_metadata.stream_ordering
- async def _user_left_room(self, target: UserID, room_id: str) -> None:
- """Implements RoomMemberHandler._user_left_room
+ async def remote_knock(
+ self, remote_room_hosts: List[str], room_id: str, user: UserID, content: dict,
+ ) -> Tuple[str, int]:
+ """Sends a knock to a room. Attempts to do so via one remote out of a given list.
+
+ Args:
+ remote_room_hosts: A list of homeservers to try knocking through.
+ room_id: The ID of the room to knock on.
+ user: The user to knock on behalf of.
+ content: The content of the knock event.
+
+ Returns:
+ A tuple of (event ID, stream ID).
"""
+ # filter ourselves out of remote_room_hosts
+ remote_room_hosts = [
+ host for host in remote_room_hosts if host != self.hs.hostname
+ ]
+
+ if len(remote_room_hosts) == 0:
+ raise SynapseError(404, "No known servers")
+
+ return await self.federation_handler.do_knock(
+ remote_room_hosts, room_id, user.to_string(), content=content
+ )
+
+ async def _user_left_room(self, target: UserID, room_id: str) -> None:
+ """Implements RoomMemberHandler._user_left_room"""
user_left_room(self.distributor, target, room_id)
async def forget(self, user: UserID, room_id: str) -> None:
diff --git a/synapse/handlers/room_member_worker.py b/synapse/handlers/room_member_worker.py
index f2e88f6a5b..926d09f40c 100644
--- a/synapse/handlers/room_member_worker.py
+++ b/synapse/handlers/room_member_worker.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
+# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -20,10 +21,12 @@ from synapse.api.errors import SynapseError
from synapse.handlers.room_member import RoomMemberHandler
from synapse.replication.http.membership import (
ReplicationRemoteJoinRestServlet as ReplRemoteJoin,
+ ReplicationRemoteKnockRestServlet as ReplRemoteKnock,
ReplicationRemoteRejectInviteRestServlet as ReplRejectInvite,
+ ReplicationRemoteRescindKnockRestServlet as ReplRescindKnock,
ReplicationUserJoinedLeftRoomRestServlet as ReplJoinedLeft,
)
-from synapse.types import Requester, UserID
+from synapse.types import JsonDict, Requester, UserID
logger = logging.getLogger(__name__)
@@ -33,7 +36,9 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
super().__init__(hs)
self._remote_join_client = ReplRemoteJoin.make_client(hs)
+ self._remote_knock_client = ReplRemoteKnock.make_client(hs)
self._remote_reject_client = ReplRejectInvite.make_client(hs)
+ self._remote_rescind_client = ReplRescindKnock.make_client(hs)
self._notify_change_client = ReplJoinedLeft.make_client(hs)
async def _remote_join(
@@ -44,8 +49,7 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
user: UserID,
content: dict,
) -> Tuple[str, int]:
- """Implements RoomMemberHandler._remote_join
- """
+ """Implements RoomMemberHandler._remote_join"""
if len(remote_room_hosts) == 0:
raise SynapseError(404, "No known servers")
@@ -79,9 +83,51 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
)
return ret["event_id"], ret["stream_id"]
- async def _user_left_room(self, target: UserID, room_id: str) -> None:
- """Implements RoomMemberHandler._user_left_room
+ async def remote_rescind_knock(
+ self,
+ knock_event_id: str,
+ txn_id: Optional[str],
+ requester: Requester,
+ content: JsonDict,
+ ) -> Tuple[str, int]:
"""
+ Rescinds a local knock made on a remote room
+
+ Args:
+ knock_event_id: the knock event
+ txn_id: optional transaction ID supplied by the client
+ requester: user making the request, according to the access token
+ content: additional content to include in the leave event.
+ Normally an empty dict.
+
+ Returns:
+ A tuple containing (event_id, stream_id of the leave event)
+ """
+ ret = await self._remote_rescind_client(
+ knock_event_id=knock_event_id,
+ txn_id=txn_id,
+ requester=requester,
+ content=content,
+ )
+ return ret["event_id"], ret["stream_id"]
+
+ async def remote_knock(
+ self, remote_room_hosts: List[str], room_id: str, user: UserID, content: dict,
+ ) -> Tuple[str, int]:
+ """Sends a knock to a room.
+
+ Implements RoomMemberHandler.remote_knock
+ """
+ ret = await self._remote_knock_client(
+ remote_room_hosts=remote_room_hosts,
+ room_id=room_id,
+ user=user,
+ content=content,
+ )
+ return ret["event_id"], ret["stream_id"]
+
+ async def _user_left_room(self, target: UserID, room_id: str) -> None:
+ """Implements RoomMemberHandler._user_left_room"""
await self._notify_change_client(
user_id=target.to_string(), room_id=room_id, change="left"
)
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index e88fd59749..a9645b77d8 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -23,7 +23,6 @@ from saml2.client import Saml2Client
from synapse.api.errors import SynapseError
from synapse.config import ConfigError
-from synapse.config.saml2_config import SamlAttributeRequirement
from synapse.handlers._base import BaseHandler
from synapse.handlers.sso import MappingException, UserAttributes
from synapse.http.servlet import parse_string
@@ -122,7 +121,8 @@ class SamlHandler(BaseHandler):
now = self.clock.time_msec()
self._outstanding_requests_dict[reqid] = Saml2SessionData(
- creation_time=now, ui_auth_session_id=ui_auth_session_id,
+ creation_time=now,
+ ui_auth_session_id=ui_auth_session_id,
)
for key, value in info["headers"]:
@@ -239,12 +239,10 @@ class SamlHandler(BaseHandler):
# Ensure that the attributes of the logged in user meet the required
# attributes.
- for requirement in self._saml2_attribute_requirements:
- if not _check_attribute_requirement(saml2_auth.ava, requirement):
- self._sso_handler.render_error(
- request, "unauthorised", "You are not authorised to log in here."
- )
- return
+ if not self._sso_handler.check_required_attributes(
+ request, saml2_auth.ava, self._saml2_attribute_requirements
+ ):
+ return
# Call the mapper to register/login the user
try:
@@ -373,21 +371,6 @@ class SamlHandler(BaseHandler):
del self._outstanding_requests_dict[reqid]
-def _check_attribute_requirement(ava: dict, req: SamlAttributeRequirement) -> bool:
- values = ava.get(req.attribute, [])
- for v in values:
- if v == req.value:
- return True
-
- logger.info(
- "SAML2 attribute %s did not match required value '%s' (was '%s')",
- req.attribute,
- req.value,
- values,
- )
- return False
-
-
DOT_REPLACE_PATTERN = re.compile(
("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),))
)
@@ -468,7 +451,8 @@ class DefaultSamlMappingProvider:
mxid_source = saml_response.ava[self._mxid_source_attribute][0]
except KeyError:
logger.warning(
- "SAML2 response lacks a '%s' attestation", self._mxid_source_attribute,
+ "SAML2 response lacks a '%s' attestation",
+ self._mxid_source_attribute,
)
raise SynapseError(
400, "%s not in SAML2 response" % (self._mxid_source_attribute,)
diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py
index 84af2dde7e..cef6b3ae48 100644
--- a/synapse/handlers/set_password.py
+++ b/synapse/handlers/set_password.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
-# Copyright 2017 New Vector Ltd
+# Copyright 2017-2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index b450668f1c..514b1f69d8 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -16,10 +16,12 @@ import abc
import logging
from typing import (
TYPE_CHECKING,
+ Any,
Awaitable,
Callable,
Dict,
Iterable,
+ List,
Mapping,
Optional,
Set,
@@ -34,6 +36,7 @@ from twisted.web.iweb import IRequest
from synapse.api.constants import LoginType
from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError
+from synapse.config.sso import SsoAttributeRequirement
from synapse.handlers.ui_auth import UIAuthSessionDataConstants
from synapse.http import get_request_user_agent
from synapse.http.server import respond_with_html, respond_with_redirect
@@ -324,7 +327,8 @@ class SsoHandler:
# Check if we already have a mapping for this user.
previously_registered_user_id = await self._store.get_user_by_external_id(
- auth_provider_id, remote_user_id,
+ auth_provider_id,
+ remote_user_id,
)
# A match was found, return the user ID.
@@ -413,7 +417,8 @@ class SsoHandler:
with await self._mapping_lock.queue(auth_provider_id):
# first of all, check if we already have a mapping for this user
user_id = await self.get_sso_user_by_remote_user_id(
- auth_provider_id, remote_user_id,
+ auth_provider_id,
+ remote_user_id,
)
# Check for grandfathering of users.
@@ -458,7 +463,8 @@ class SsoHandler:
)
async def _call_attribute_mapper(
- self, sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],
+ self,
+ sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],
) -> UserAttributes:
"""Call the attribute mapper function in a loop, until we get a unique userid"""
for i in range(self._MAP_USERNAME_RETRIES):
@@ -629,7 +635,8 @@ class SsoHandler:
"""
user_id = await self.get_sso_user_by_remote_user_id(
- auth_provider_id, remote_user_id,
+ auth_provider_id,
+ remote_user_id,
)
user_id_to_verify = await self._auth_handler.get_session_data(
@@ -668,7 +675,8 @@ class SsoHandler:
# render an error page.
html = self._bad_user_template.render(
- server_name=self._server_name, user_id_to_verify=user_id_to_verify,
+ server_name=self._server_name,
+ user_id_to_verify=user_id_to_verify,
)
respond_with_html(request, 200, html)
@@ -692,7 +700,9 @@ class SsoHandler:
raise SynapseError(400, "unknown session")
async def check_username_availability(
- self, localpart: str, session_id: str,
+ self,
+ localpart: str,
+ session_id: str,
) -> bool:
"""Handle an "is username available" callback check
@@ -742,7 +752,11 @@ class SsoHandler:
use_display_name: whether the user wants to use the suggested display name
emails_to_use: emails that the user would like to use
"""
- session = self.get_mapping_session(session_id)
+ try:
+ session = self.get_mapping_session(session_id)
+ except SynapseError as e:
+ self.render_error(request, "bad_session", e.msg, code=e.code)
+ return
# update the session with the user's choices
session.chosen_localpart = localpart
@@ -793,7 +807,12 @@ class SsoHandler:
session_id,
terms_version,
)
- session = self.get_mapping_session(session_id)
+ try:
+ session = self.get_mapping_session(session_id)
+ except SynapseError as e:
+ self.render_error(request, "bad_session", e.msg, code=e.code)
+ return
+
session.terms_accepted_version = terms_version
# we're done; now we can register the user
@@ -808,7 +827,11 @@ class SsoHandler:
request: HTTP request
session_id: ID of the username mapping session, extracted from a cookie
"""
- session = self.get_mapping_session(session_id)
+ try:
+ session = self.get_mapping_session(session_id)
+ except SynapseError as e:
+ self.render_error(request, "bad_session", e.msg, code=e.code)
+ return
logger.info(
"[session %s] Registering localpart %s",
@@ -817,7 +840,8 @@ class SsoHandler:
)
attributes = UserAttributes(
- localpart=session.chosen_localpart, emails=session.emails_to_use,
+ localpart=session.chosen_localpart,
+ emails=session.emails_to_use,
)
if session.use_display_name:
@@ -880,6 +904,41 @@ class SsoHandler:
logger.info("Expiring mapping session %s", session_id)
del self._username_mapping_sessions[session_id]
+ def check_required_attributes(
+ self,
+ request: SynapseRequest,
+ attributes: Mapping[str, List[Any]],
+ attribute_requirements: Iterable[SsoAttributeRequirement],
+ ) -> bool:
+ """
+ Confirm that the required attributes were present in the SSO response.
+
+ If all requirements are met, this will return True.
+
+ If any requirement is not met, then the request will be finalized by
+ showing an error page to the user and False will be returned.
+
+ Args:
+ request: The request to (potentially) respond to.
+ attributes: The attributes from the SSO IdP.
+ attribute_requirements: The requirements that attributes must meet.
+
+ Returns:
+ True if all requirements are met, False if any attribute fails to
+ meet the requirement.
+
+ """
+ # Ensure that the attributes of the logged in user meet the required
+ # attributes.
+ for requirement in attribute_requirements:
+ if not _check_attribute_requirement(attributes, requirement):
+ self.render_error(
+ request, "unauthorised", "You are not authorised to log in here."
+ )
+ return False
+
+ return True
+
def get_username_mapping_session_cookie_from_request(request: IRequest) -> str:
"""Extract the session ID from the cookie
@@ -890,3 +949,36 @@ def get_username_mapping_session_cookie_from_request(request: IRequest) -> str:
if not session_id:
raise SynapseError(code=400, msg="missing session_id")
return session_id.decode("ascii", errors="replace")
+
+
+def _check_attribute_requirement(
+ attributes: Mapping[str, List[Any]], req: SsoAttributeRequirement
+) -> bool:
+ """Check if SSO attributes meet the proper requirements.
+
+ Args:
+ attributes: A mapping of attributes to an iterable of one or more values.
+        req: The configured requirement to check.
+
+ Returns:
+ True if the required attribute was found and had a proper value.
+ """
+ if req.attribute not in attributes:
+ logger.info("SSO attribute missing: %s", req.attribute)
+ return False
+
+ # If the requirement is None, the attribute existing is enough.
+ if req.value is None:
+ return True
+
+ values = attributes[req.attribute]
+ if req.value in values:
+ return True
+
+ logger.info(
+ "SSO attribute %s did not match required value '%s' (was '%s')",
+ req.attribute,
+ req.value,
+ values,
+ )
+ return False
diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py
index d261d7cd4e..388dec5831 100644
--- a/synapse/handlers/stats.py
+++ b/synapse/handlers/stats.py
@@ -1,5 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
+# Copyright 2020 Sorunome
+# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -63,8 +65,7 @@ class StatsHandler:
self.clock.call_later(0, self.notify_new_event)
def notify_new_event(self) -> None:
- """Called when there may be more deltas to process
- """
+ """Called when there may be more deltas to process"""
if not self.stats_enabled or self._is_processing:
return
@@ -232,6 +233,8 @@ class StatsHandler:
room_stats_delta["left_members"] -= 1
elif prev_membership == Membership.BAN:
room_stats_delta["banned_members"] -= 1
+ elif prev_membership == Membership.KNOCK:
+ room_stats_delta["knocked_members"] -= 1
else:
raise ValueError(
"%r is not a valid prev_membership" % (prev_membership,)
@@ -253,6 +256,8 @@ class StatsHandler:
room_stats_delta["left_members"] += 1
elif membership == Membership.BAN:
room_stats_delta["banned_members"] += 1
+ elif membership == Membership.KNOCK:
+ room_stats_delta["knocked_members"] += 1
else:
raise ValueError("%r is not a valid membership" % (membership,))
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 5c7590f38e..fa6794734b 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -151,6 +151,16 @@ class InvitedSyncResult:
@attr.s(slots=True, frozen=True)
+class KnockedSyncResult:
+ room_id = attr.ib(type=str)
+ knock = attr.ib(type=EventBase)
+
+ def __bool__(self) -> bool:
+ """Knocked rooms should always be reported to the client"""
+ return True
+
+
+@attr.s(slots=True, frozen=True)
class GroupsSyncResult:
join = attr.ib(type=JsonDict)
invite = attr.ib(type=JsonDict)
@@ -183,6 +193,7 @@ class _RoomChanges:
room_entries = attr.ib(type=List["RoomSyncResultBuilder"])
invited = attr.ib(type=List[InvitedSyncResult])
+ knocked = attr.ib(type=List[KnockedSyncResult])
newly_joined_rooms = attr.ib(type=List[str])
newly_left_rooms = attr.ib(type=List[str])
@@ -196,6 +207,7 @@ class SyncResult:
account_data: List of account_data events for the user.
joined: JoinedSyncResult for each joined room.
invited: InvitedSyncResult for each invited room.
+ knocked: KnockedSyncResult for each knocked on room.
archived: ArchivedSyncResult for each archived room.
to_device: List of direct messages for the device.
device_lists: List of user_ids whose devices have changed
@@ -211,6 +223,7 @@ class SyncResult:
account_data = attr.ib(type=List[JsonDict])
joined = attr.ib(type=List[JoinedSyncResult])
invited = attr.ib(type=List[InvitedSyncResult])
+ knocked = attr.ib(type=List[KnockedSyncResult])
archived = attr.ib(type=List[ArchivedSyncResult])
to_device = attr.ib(type=List[JsonDict])
device_lists = attr.ib(type=DeviceLists)
@@ -227,6 +240,7 @@ class SyncResult:
self.presence
or self.joined
or self.invited
+ or self.knocked
or self.archived
or self.account_data
or self.to_device
@@ -339,8 +353,7 @@ class SyncHandler:
since_token: Optional[StreamToken] = None,
full_state: bool = False,
) -> SyncResult:
- """Get the sync for client needed to match what the server has now.
- """
+ """Get the sync for client needed to match what the server has now."""
return await self.generate_sync_result(sync_config, since_token, full_state)
async def push_rules_for_user(self, user: UserID) -> JsonDict:
@@ -564,7 +577,7 @@ class SyncHandler:
stream_position: StreamToken,
state_filter: StateFilter = StateFilter.all(),
) -> StateMap[str]:
- """ Get the room state at a particular stream position
+ """Get the room state at a particular stream position
Args:
room_id: room for which to get state
@@ -598,7 +611,7 @@ class SyncHandler:
state: MutableStateMap[EventBase],
now_token: StreamToken,
) -> Optional[JsonDict]:
- """ Works out a room summary block for this room, summarising the number
+ """Works out a room summary block for this room, summarising the number
of joined members in the room, and providing the 'hero' members if the
room has no name so clients can consistently name rooms. Also adds
state events to 'state' if needed to describe the heroes.
@@ -743,7 +756,7 @@ class SyncHandler:
now_token: StreamToken,
full_state: bool,
) -> MutableStateMap[EventBase]:
- """ Works out the difference in state between the start of the timeline
+ """Works out the difference in state between the start of the timeline
and the previous sync.
Args:
@@ -820,8 +833,10 @@ class SyncHandler:
)
elif batch.limited:
if batch:
- state_at_timeline_start = await self.state_store.get_state_ids_for_event(
- batch.events[0].event_id, state_filter=state_filter
+ state_at_timeline_start = (
+ await self.state_store.get_state_ids_for_event(
+ batch.events[0].event_id, state_filter=state_filter
+ )
)
else:
# We can get here if the user has ignored the senders of all
@@ -955,8 +970,7 @@ class SyncHandler:
since_token: Optional[StreamToken] = None,
full_state: bool = False,
) -> SyncResult:
- """Generates a sync result.
- """
+ """Generates a sync result."""
# NB: The now_token gets changed by some of the generate_sync_* methods,
# this is due to some of the underlying streams not supporting the ability
# to query up to a given point.
@@ -999,7 +1013,7 @@ class SyncHandler:
res = await self._generate_sync_entry_for_rooms(
sync_result_builder, account_data_by_room
)
- newly_joined_rooms, newly_joined_or_invited_users, _, _ = res
+ newly_joined_rooms, newly_joined_or_invited_or_knocked_users, _, _ = res
_, _, newly_left_rooms, newly_left_users = res
block_all_presence_data = (
@@ -1008,7 +1022,9 @@ class SyncHandler:
if self.hs_config.use_presence and not block_all_presence_data:
logger.debug("Fetching presence data")
await self._generate_sync_entry_for_presence(
- sync_result_builder, newly_joined_rooms, newly_joined_or_invited_users
+ sync_result_builder,
+ newly_joined_rooms,
+ newly_joined_or_invited_or_knocked_users,
)
logger.debug("Fetching to-device data")
@@ -1017,7 +1033,7 @@ class SyncHandler:
device_lists = await self._generate_sync_entry_for_device_list(
sync_result_builder,
newly_joined_rooms=newly_joined_rooms,
- newly_joined_or_invited_users=newly_joined_or_invited_users,
+ newly_joined_or_invited_or_knocked_users=newly_joined_or_invited_or_knocked_users,
newly_left_rooms=newly_left_rooms,
newly_left_users=newly_left_users,
)
@@ -1030,8 +1046,8 @@ class SyncHandler:
one_time_key_counts = await self.store.count_e2e_one_time_keys(
user_id, device_id
)
- unused_fallback_key_types = await self.store.get_e2e_unused_fallback_key_types(
- user_id, device_id
+ unused_fallback_key_types = (
+ await self.store.get_e2e_unused_fallback_key_types(user_id, device_id)
)
logger.debug("Fetching group data")
@@ -1051,6 +1067,7 @@ class SyncHandler:
account_data=sync_result_builder.account_data,
joined=sync_result_builder.joined,
invited=sync_result_builder.invited,
+ knocked=sync_result_builder.knocked,
archived=sync_result_builder.archived,
to_device=sync_result_builder.to_device,
device_lists=device_lists,
@@ -1110,7 +1127,7 @@ class SyncHandler:
self,
sync_result_builder: "SyncResultBuilder",
newly_joined_rooms: Set[str],
- newly_joined_or_invited_users: Set[str],
+ newly_joined_or_invited_or_knocked_users: Set[str],
newly_left_rooms: Set[str],
newly_left_users: Set[str],
) -> DeviceLists:
@@ -1119,8 +1136,9 @@ class SyncHandler:
Args:
sync_result_builder
newly_joined_rooms: Set of rooms user has joined since previous sync
- newly_joined_or_invited_users: Set of users that have joined or
- been invited to a room since previous sync.
+ newly_joined_or_invited_or_knocked_users: Set of users that have joined,
+ been invited to a room or are knocking on a room since
+ previous sync.
newly_left_rooms: Set of rooms user has left since previous sync
newly_left_users: Set of users that have left a room we're in since
previous sync
@@ -1131,7 +1149,9 @@ class SyncHandler:
# We're going to mutate these fields, so lets copy them rather than
# assume they won't get used later.
- newly_joined_or_invited_users = set(newly_joined_or_invited_users)
+ newly_joined_or_invited_or_knocked_users = set(
+ newly_joined_or_invited_or_knocked_users
+ )
newly_left_users = set(newly_left_users)
if since_token and since_token.device_list_key:
@@ -1170,14 +1190,16 @@ class SyncHandler:
# Step 1b, check for newly joined rooms
for room_id in newly_joined_rooms:
joined_users = await self.state.get_current_users_in_room(room_id)
- newly_joined_or_invited_users.update(joined_users)
+ newly_joined_or_invited_or_knocked_users.update(joined_users)
# TODO: Check that these users are actually new, i.e. either they
# weren't in the previous sync *or* they left and rejoined.
- users_that_have_changed.update(newly_joined_or_invited_users)
+ users_that_have_changed.update(newly_joined_or_invited_or_knocked_users)
- user_signatures_changed = await self.store.get_users_whose_signatures_changed(
- user_id, since_token.device_list_key
+ user_signatures_changed = (
+ await self.store.get_users_whose_signatures_changed(
+ user_id, since_token.device_list_key
+ )
)
users_that_have_changed.update(user_signatures_changed)
@@ -1393,8 +1415,10 @@ class SyncHandler:
logger.debug("no-oping sync")
return set(), set(), set(), set()
- ignored_account_data = await self.store.get_global_account_data_by_type_for_user(
- AccountDataTypes.IGNORED_USER_LIST, user_id=user_id
+ ignored_account_data = (
+ await self.store.get_global_account_data_by_type_for_user(
+ AccountDataTypes.IGNORED_USER_LIST, user_id=user_id
+ )
)
# If there is ignored users account data and it matches the proper type,
@@ -1419,6 +1443,7 @@ class SyncHandler:
room_entries = room_changes.room_entries
invited = room_changes.invited
+ knocked = room_changes.knocked
newly_joined_rooms = room_changes.newly_joined_rooms
newly_left_rooms = room_changes.newly_left_rooms
@@ -1439,9 +1464,10 @@ class SyncHandler:
await concurrently_execute(handle_room_entries, room_entries, 10)
sync_result_builder.invited.extend(invited)
+ sync_result_builder.knocked.extend(knocked)
- # Now we want to get any newly joined or invited users
- newly_joined_or_invited_users = set()
+ # Now we want to get any newly joined, invited or knocking users
+ newly_joined_or_invited_or_knocked_users = set()
newly_left_users = set()
if since_token:
for joined_sync in sync_result_builder.joined:
@@ -1453,19 +1479,22 @@ class SyncHandler:
if (
event.membership == Membership.JOIN
or event.membership == Membership.INVITE
+ or event.membership == Membership.KNOCK
):
- newly_joined_or_invited_users.add(event.state_key)
+ newly_joined_or_invited_or_knocked_users.add(
+ event.state_key
+ )
else:
prev_content = event.unsigned.get("prev_content", {})
prev_membership = prev_content.get("membership", None)
if prev_membership == Membership.JOIN:
newly_left_users.add(event.state_key)
- newly_left_users -= newly_joined_or_invited_users
+ newly_left_users -= newly_joined_or_invited_or_knocked_users
return (
set(newly_joined_rooms),
- newly_joined_or_invited_users,
+ newly_joined_or_invited_or_knocked_users,
set(newly_left_rooms),
newly_left_users,
)
@@ -1499,8 +1528,7 @@ class SyncHandler:
async def _get_rooms_changed(
self, sync_result_builder: "SyncResultBuilder", ignored_users: FrozenSet[str]
) -> _RoomChanges:
- """Gets the the changes that have happened since the last sync.
- """
+        """Gets the changes that have happened since the last sync."""
user_id = sync_result_builder.sync_config.user.to_string()
since_token = sync_result_builder.since_token
now_token = sync_result_builder.now_token
@@ -1521,6 +1549,7 @@ class SyncHandler:
newly_left_rooms = []
room_entries = []
invited = []
+ knocked = []
for room_id, events in mem_change_events_by_room_id.items():
logger.debug(
"Membership changes in %s: [%s]",
@@ -1600,9 +1629,17 @@ class SyncHandler:
should_invite = non_joins[-1].membership == Membership.INVITE
if should_invite:
if event.sender not in ignored_users:
- room_sync = InvitedSyncResult(room_id, invite=non_joins[-1])
- if room_sync:
- invited.append(room_sync)
+ invite_room_sync = InvitedSyncResult(room_id, invite=non_joins[-1])
+ if invite_room_sync:
+ invited.append(invite_room_sync)
+
+ # Only bother if our latest membership in the room is knock (and we haven't
+ # been accepted/rejected in the meantime).
+ should_knock = non_joins[-1].membership == Membership.KNOCK
+ if should_knock:
+ knock_room_sync = KnockedSyncResult(room_id, knock=non_joins[-1])
+ if knock_room_sync:
+ knocked.append(knock_room_sync)
# Always include leave/ban events. Just take the last one.
# TODO: How do we handle ban -> leave in same batch?
@@ -1706,7 +1743,9 @@ class SyncHandler:
)
room_entries.append(entry)
- return _RoomChanges(room_entries, invited, newly_joined_rooms, newly_left_rooms)
+ return _RoomChanges(
+ room_entries, invited, knocked, newly_joined_rooms, newly_left_rooms,
+ )
async def _get_all_rooms(
self, sync_result_builder: "SyncResultBuilder", ignored_users: FrozenSet[str]
@@ -1726,6 +1765,7 @@ class SyncHandler:
membership_list = (
Membership.INVITE,
+ Membership.KNOCK,
Membership.JOIN,
Membership.LEAVE,
Membership.BAN,
@@ -1737,6 +1777,7 @@ class SyncHandler:
room_entries = []
invited = []
+ knocked = []
for event in room_list:
if event.membership == Membership.JOIN:
@@ -1756,8 +1797,11 @@ class SyncHandler:
continue
invite = await self.store.get_event(event.event_id)
invited.append(InvitedSyncResult(room_id=event.room_id, invite=invite))
+ elif event.membership == Membership.KNOCK:
+ knock = await self.store.get_event(event.event_id)
+ knocked.append(KnockedSyncResult(room_id=event.room_id, knock=knock))
elif event.membership in (Membership.LEAVE, Membership.BAN):
- # Always send down rooms we were banned or kicked from.
+ # Always send down rooms we were banned from or kicked from.
if not sync_config.filter_collection.include_leave:
if event.membership == Membership.LEAVE:
if user_id == event.sender:
@@ -1778,7 +1822,7 @@ class SyncHandler:
)
)
- return _RoomChanges(room_entries, invited, [], [])
+ return _RoomChanges(room_entries, invited, knocked, [], [])
async def _generate_room_entry(
self,
@@ -2067,6 +2111,7 @@ class SyncResultBuilder:
account_data (list)
joined (list[JoinedSyncResult])
invited (list[InvitedSyncResult])
+ knocked (list[KnockedSyncResult])
archived (list[ArchivedSyncResult])
groups (GroupsSyncResult|None)
to_device (list)
@@ -2082,6 +2127,7 @@ class SyncResultBuilder:
account_data = attr.ib(type=List[JsonDict], default=attr.Factory(list))
joined = attr.ib(type=List[JoinedSyncResult], default=attr.Factory(list))
invited = attr.ib(type=List[InvitedSyncResult], default=attr.Factory(list))
+ knocked = attr.ib(type=List[KnockedSyncResult], default=attr.Factory(list))
archived = attr.ib(type=List[ArchivedSyncResult], default=attr.Factory(list))
groups = attr.ib(type=Optional[GroupsSyncResult], default=None)
to_device = attr.ib(type=List[JsonDict], default=attr.Factory(list))
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 3f0dfc7a74..096d199f4c 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -61,7 +61,8 @@ class FollowerTypingHandler:
if hs.config.worker.writers.typing != hs.get_instance_name():
hs.get_federation_registry().register_instance_for_edu(
- "m.typing", hs.config.worker.writers.typing,
+ "m.typing",
+ hs.config.worker.writers.typing,
)
# map room IDs to serial numbers
@@ -76,8 +77,7 @@ class FollowerTypingHandler:
self.clock.looping_call(self._handle_timeouts, 5000)
def _reset(self) -> None:
- """Reset the typing handler's data caches.
- """
+ """Reset the typing handler's data caches."""
# map room IDs to serial numbers
self._room_serials = {}
# map room IDs to sets of users currently typing
@@ -149,8 +149,7 @@ class FollowerTypingHandler:
def process_replication_rows(
self, token: int, rows: List[TypingStream.TypingStreamRow]
) -> None:
- """Should be called whenever we receive updates for typing stream.
- """
+ """Should be called whenever we receive updates for typing stream."""
if self._latest_room_serial > token:
# The master has gone backwards. To prevent inconsistent data, just
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index 8aedf5072e..1a8340000a 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -97,8 +97,7 @@ class UserDirectoryHandler(StateDeltasHandler):
return results
def notify_new_event(self) -> None:
- """Called when there may be more deltas to process
- """
+ """Called when there may be more deltas to process"""
if not self.update_user_directory:
return
@@ -134,8 +133,7 @@ class UserDirectoryHandler(StateDeltasHandler):
)
async def handle_user_deactivated(self, user_id: str) -> None:
- """Called when a user ID is deactivated
- """
+ """Called when a user ID is deactivated"""
# FIXME(#3714): We should probably do this in the same worker as all
# the other changes.
await self.store.remove_from_user_dir(user_id)
@@ -145,6 +143,10 @@ class UserDirectoryHandler(StateDeltasHandler):
if self.pos is None:
self.pos = await self.store.get_user_directory_stream_pos()
+ # If still None then the initial background update hasn't happened yet.
+ if self.pos is None:
+ return None
+
# Loop round handling deltas until we're up to date
while True:
with Measure(self.clock, "user_dir_delta"):
@@ -172,8 +174,7 @@ class UserDirectoryHandler(StateDeltasHandler):
await self.store.update_user_directory_stream_pos(max_pos)
async def _handle_deltas(self, deltas: List[Dict[str, Any]]) -> None:
- """Called with the state deltas to process
- """
+ """Called with the state deltas to process"""
for delta in deltas:
typ = delta["type"]
state_key = delta["state_key"]
|