diff --git a/synapse/_scripts/register_new_matrix_user.py b/synapse/_scripts/register_new_matrix_user.py
index da0996edbc..dfe26dea6d 100644
--- a/synapse/_scripts/register_new_matrix_user.py
+++ b/synapse/_scripts/register_new_matrix_user.py
@@ -37,7 +37,7 @@ def request_registration(
exit=sys.exit,
):
- url = "%s/_matrix/client/r0/admin/register" % (server_location,)
+ url = "%s/_synapse/admin/v1/register" % (server_location.rstrip("/"),)
# Get the nonce
r = requests.get(url, verify=False)
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index bfcaf68b2a..bc73982dc2 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -76,7 +76,7 @@ class Auth:
self._auth_blocking = AuthBlocking(self.hs)
- self._account_validity = hs.config.account_validity
+ self._account_validity_enabled = hs.config.account_validity_enabled
self._track_appservice_user_ips = hs.config.track_appservice_user_ips
self._macaroon_secret_key = hs.config.macaroon_secret_key
@@ -189,7 +189,7 @@ class Auth:
access_token = self.get_access_token_from_request(request)
- user_id, app_service = await self._get_appservice_user_id(request)
+ user_id, app_service = self._get_appservice_user_id(request)
if user_id:
if ip_addr and self._track_appservice_user_ips:
await self.store.insert_client_ip(
@@ -219,7 +219,7 @@ class Auth:
shadow_banned = user_info.shadow_banned
# Deny the request if the user account has expired.
- if self._account_validity.enabled and not allow_expired:
+ if self._account_validity_enabled and not allow_expired:
if await self.store.is_account_expired(
user_info.user_id, self.clock.time_msec()
):
@@ -265,10 +265,11 @@ class Auth:
except KeyError:
raise MissingClientTokenError()
- async def _get_appservice_user_id(self, request):
+ def _get_appservice_user_id(self, request):
app_service = self.store.get_app_service_by_token(
self.get_access_token_from_request(request)
)
+
if app_service is None:
return None, None
@@ -286,8 +287,12 @@ class Auth:
if not app_service.is_interested_in_user(user_id):
raise AuthError(403, "Application service cannot masquerade as this user.")
- if not (await self.store.get_user_by_id(user_id)):
- raise AuthError(403, "Application service has not registered this user")
+ # Let ASes manipulate nonexistent users (e.g. to shadow-register them)
+ # if not (yield self.store.get_user_by_id(user_id)):
+ # raise AuthError(
+ # 403,
+ # "Application service has not registered this user"
+ # )
return user_id, app_service
async def get_user_by_access_token(
diff --git a/synapse/api/auth_blocking.py b/synapse/api/auth_blocking.py
index d8fafd7cb8..9c227218e0 100644
--- a/synapse/api/auth_blocking.py
+++ b/synapse/api/auth_blocking.py
@@ -14,10 +14,12 @@
# limitations under the License.
import logging
+from typing import Optional
from synapse.api.constants import LimitBlockingTypes, UserTypes
from synapse.api.errors import Codes, ResourceLimitError
from synapse.config.server import is_threepid_reserved
+from synapse.types import Requester
logger = logging.getLogger(__name__)
@@ -33,24 +35,47 @@ class AuthBlocking:
self._max_mau_value = hs.config.max_mau_value
self._limit_usage_by_mau = hs.config.limit_usage_by_mau
self._mau_limits_reserved_threepids = hs.config.mau_limits_reserved_threepids
+ self._server_name = hs.hostname
- async def check_auth_blocking(self, user_id=None, threepid=None, user_type=None):
+ async def check_auth_blocking(
+ self,
+ user_id: Optional[str] = None,
+ threepid: Optional[dict] = None,
+ user_type: Optional[str] = None,
+ requester: Optional[Requester] = None,
+ ):
"""Checks if the user should be rejected for some external reason,
such as monthly active user limiting or global disable flag
Args:
- user_id(str|None): If present, checks for presence against existing
+ user_id: If present, checks for presence against existing
MAU cohort
- threepid(dict|None): If present, checks for presence against configured
+ threepid: If present, checks for presence against configured
reserved threepid. Used in cases where the user is trying register
with a MAU blocked server, normally they would be rejected but their
threepid is on the reserved list. user_id and
threepid should never be set at the same time.
- user_type(str|None): If present, is used to decide whether to check against
+ user_type: If present, is used to decide whether to check against
certain blocking reasons like MAU.
+
+ requester: If present, and the authenticated entity is a user, checks for
+ presence against existing MAU cohort. Passing in both a `user_id` and
+ `requester` is an error.
"""
+ if requester and user_id:
+ raise Exception(
+ "Passed in both 'user_id' and 'requester' to 'check_auth_blocking'"
+ )
+
+ if requester:
+ if requester.authenticated_entity.startswith("@"):
+ user_id = requester.authenticated_entity
+ elif requester.authenticated_entity == self._server_name:
+ # We never block the server from doing actions on behalf of
+ # users.
+ return
# Never fail an auth check for the server notices users or support user
# This can be a problem where event creation is prohibited due to blocking
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index cd6670d0a2..90bb01f414 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
-# Copyright 2018 New Vector Ltd
+# Copyright 2017-2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 9c8dc785c6..895b38ae76 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -32,6 +32,7 @@ from synapse.app.phone_stats_home import start_phone_stats_home
from synapse.config.server import ListenerConfig
from synapse.crypto import context_factory
from synapse.logging.context import PreserveLoggingContext
+from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.util.async_helpers import Linearizer
from synapse.util.daemonize import daemonize_process
from synapse.util.rlimit import change_resource_limit
@@ -244,6 +245,7 @@ def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerConfig]):
# Set up the SIGHUP machinery.
if hasattr(signal, "SIGHUP"):
+ @wrap_as_background_process("sighup")
def handle_sighup(*args, **kwargs):
# Tell systemd our state, if we're using it. This will silently fail if
# we're not using systemd.
@@ -254,7 +256,13 @@ def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerConfig]):
sdnotify(b"READY=1")
- signal.signal(signal.SIGHUP, handle_sighup)
+ # We defer running the sighup handlers until next reactor tick. This
+ # is so that we're in a sane state, e.g. flushing the logs may fail
+ # if the sighup happens in the middle of writing a log entry.
+ def run_sighup(*args, **kwargs):
+ hs.get_clock().call_later(0, handle_sighup, *args, **kwargs)
+
+ signal.signal(signal.SIGHUP, run_sighup)
register_sighup(refresh_certificate, hs)
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index 1b511890aa..e3595a19d1 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -161,7 +161,7 @@ class PresenceStatusStubServlet(RestServlet):
async def on_GET(self, request, user_id):
await self.auth.get_user_by_req(request)
- return 200, {"presence": "offline"}
+ return 200, {"presence": "offline", "user_id": user_id}
async def on_PUT(self, request, user_id):
await self.auth.get_user_by_req(request)
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index 85f65da4d9..33a917a035 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -22,6 +22,7 @@ import time
import urllib.parse
from collections import OrderedDict
from hashlib import sha256
+from io import open as io_open
from textwrap import dedent
from typing import Any, Callable, List, MutableMapping, Optional
@@ -190,7 +191,7 @@ class Config:
@classmethod
def read_file(cls, file_path, config_name):
cls.check_file(file_path, config_name)
- with open(file_path) as file_stream:
+ with io_open(file_path, encoding="utf-8") as file_stream:
return file_stream.read()
def read_templates(
diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi
index b8faafa9bd..842a96a478 100644
--- a/synapse/config/_base.pyi
+++ b/synapse/config/_base.pyi
@@ -1,6 +1,7 @@
from typing import Any, List, Optional
from synapse.config import (
+ account_validity,
api,
appservice,
captcha,
@@ -53,6 +54,7 @@ class RootConfig:
captcha: captcha.CaptchaConfig
voip: voip.VoipConfig
registration: registration.RegistrationConfig
+ account_validity: account_validity.AccountValidityConfig
metrics: metrics.MetricsConfig
api: api.ApiConfig
appservice: appservice.AppServiceConfig
diff --git a/synapse/config/account_validity.py b/synapse/config/account_validity.py
new file mode 100644
index 0000000000..ea3e55c2e8
--- /dev/null
+++ b/synapse/config/account_validity.py
@@ -0,0 +1,146 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from synapse.config._base import Config, ConfigError
+
+
+class AccountValidityConfig(Config):
+ section = "account_validity"
+
+ def read_config(self, config, **kwargs):
+ account_validity_config = config.get("account_validity") or {}
+ self.account_validity_enabled = account_validity_config.get("enabled", False)
+ self.account_validity_renew_by_email_enabled = (
+ "renew_at" in account_validity_config
+ )
+
+ if self.account_validity_enabled:
+ if "period" in account_validity_config:
+ self.account_validity_period = self.parse_duration(
+ account_validity_config["period"]
+ )
+ else:
+ raise ConfigError("'period' is required when using account validity")
+
+ if "renew_at" in account_validity_config:
+ self.account_validity_renew_at = self.parse_duration(
+ account_validity_config["renew_at"]
+ )
+
+ if "renew_email_subject" in account_validity_config:
+ self.account_validity_renew_email_subject = account_validity_config[
+ "renew_email_subject"
+ ]
+ else:
+ self.account_validity_renew_email_subject = "Renew your %(app)s account"
+
+ self.account_validity_startup_job_max_delta = (
+ self.account_validity_period * 10.0 / 100.0
+ )
+
+ if self.account_validity_renew_by_email_enabled:
+ if not self.public_baseurl:
+ raise ConfigError("Can't send renewal emails without 'public_baseurl'")
+
+ # Load account validity templates.
+ # We do this here instead of in AccountValidityConfig as read_templates
+ # relies on state that hasn't been initialised in AccountValidityConfig
+ account_renewed_template_filename = account_validity_config.get(
+ "account_renewed_html_path", "account_renewed.html"
+ )
+ account_previously_renewed_template_filename = account_validity_config.get(
+ "account_previously_renewed_html_path", "account_previously_renewed.html"
+ )
+ invalid_token_template_filename = account_validity_config.get(
+ "invalid_token_html_path", "invalid_token.html"
+ )
+ (
+ self.account_validity_account_renewed_template,
+ self.account_validity_account_previously_renewed_template,
+ self.account_validity_invalid_token_template,
+ ) = self.read_templates(
+ [
+ account_renewed_template_filename,
+ account_previously_renewed_template_filename,
+ invalid_token_template_filename,
+ ]
+ )
+
+ def generate_config_section(self, **kwargs):
+ return """\
+ ## Account Validity ##
+ #
+ # Optional account validity configuration. This allows for accounts to be denied
+ # any request after a given period.
+ #
+ # Once this feature is enabled, Synapse will look for registered users without an
+ # expiration date at startup and will add one to every account it found using the
+ # current settings at that time.
+ # This means that, if a validity period is set, and Synapse is restarted (it will
+ # then derive an expiration date from the current validity period), and some time
+ # after that the validity period changes and Synapse is restarted, the users'
+ # expiration dates won't be updated unless their account is manually renewed. This
+ # date will be randomly selected within a range [now + period - d ; now + period],
+ # where d is equal to 10% of the validity period.
+ #
+ account_validity:
+ # The account validity feature is disabled by default. Uncomment the
+ # following line to enable it.
+ #
+ #enabled: true
+
+ # The period after which an account is valid after its registration. When
+ # renewing the account, its validity period will be extended by this amount
+ # of time. This parameter is required when using the account validity
+ # feature.
+ #
+ #period: 6w
+
+ # The amount of time before an account's expiry date at which Synapse will
+ # send an email to the account's email address with a renewal link. By
+ # default, no such emails are sent.
+ #
+ # If you enable this setting, you will also need to fill out the 'email' and
+ # 'public_baseurl' configuration sections.
+ #
+ #renew_at: 1w
+
+ # The subject of the email sent out with the renewal link. '%(app)s' can be
+ # used as a placeholder for the 'app_name' parameter from the 'email'
+ # section.
+ #
+ # Note that the placeholder must be written '%(app)s', including the
+ # trailing 's'.
+ #
+ # If this is not set, a default value is used.
+ #
+ #renew_email_subject: "Renew your %(app)s account"
+
+ # Directory in which Synapse will try to find templates for the HTML files to
+ # serve to the user when trying to renew an account. If not set, default
+ # templates from within the Synapse package will be used.
+ #
+ #template_dir: "res/templates"
+
+ # File within 'template_dir' giving the HTML to be displayed to the user after
+ # they successfully renewed their account. If not set, default text is used.
+ #
+ #account_renewed_html_path: "account_renewed.html"
+
+ # File within 'template_dir' giving the HTML to be displayed when the user
+ # tries to renew an account with an invalid renewal token. If not set,
+ # default text is used.
+ #
+ #invalid_token_html_path: "invalid_token.html"
+ """
diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py
index cceffbfee2..ad96c55798 100644
--- a/synapse/config/emailconfig.py
+++ b/synapse/config/emailconfig.py
@@ -299,7 +299,7 @@ class EmailConfig(Config):
"client_base_url", email_config.get("riot_base_url", None)
)
- if self.account_validity.renew_by_email_enabled:
+ if self.account_validity_renew_by_email_enabled:
expiry_template_html = email_config.get(
"expiry_template_html", "notice_expiry.html"
)
diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py
index be65554524..fe8f2dd524 100644
--- a/synapse/config/homeserver.py
+++ b/synapse/config/homeserver.py
@@ -13,8 +13,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
from ._base import RootConfig
+from .account_validity import AccountValidityConfig
from .api import ApiConfig
from .appservice import AppServiceConfig
from .cache import CacheConfig
@@ -67,6 +67,7 @@ class HomeServerConfig(RootConfig):
CaptchaConfig,
VoipConfig,
RegistrationConfig,
+ AccountValidityConfig,
MetricsConfig,
ApiConfig,
AppServiceConfig,
diff --git a/synapse/config/password.py b/synapse/config/password.py
index 9c0ea8c30a..6b2dae78b0 100644
--- a/synapse/config/password.py
+++ b/synapse/config/password.py
@@ -1,5 +1,7 @@
# -*- coding: utf-8 -*-
-# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2015-2016 OpenMarket Ltd
+# Copyright 2017-2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/synapse/config/push.py b/synapse/config/push.py
index a1f3752c8a..3adbfb73e6 100644
--- a/synapse/config/push.py
+++ b/synapse/config/push.py
@@ -21,8 +21,11 @@ class PushConfig(Config):
section = "push"
def read_config(self, config, **kwargs):
- push_config = config.get("push", {})
+ push_config = config.get("push") or {}
self.push_include_content = push_config.get("include_content", True)
+ self.push_group_unread_count_by_room = push_config.get(
+ "group_unread_count_by_room", True
+ )
pusher_instances = config.get("pusher_instances") or []
self.pusher_shard_config = ShardedWorkerHandlingConfig(pusher_instances)
@@ -49,18 +52,33 @@ class PushConfig(Config):
def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """
- # Clients requesting push notifications can either have the body of
- # the message sent in the notification poke along with other details
- # like the sender, or just the event ID and room ID (`event_id_only`).
- # If clients choose the former, this option controls whether the
- # notification request includes the content of the event (other details
- # like the sender are still included). For `event_id_only` push, it
- # has no effect.
- #
- # For modern android devices the notification content will still appear
- # because it is loaded by the app. iPhone, however will send a
- # notification saying only that a message arrived and who it came from.
- #
- #push:
- # include_content: true
+ ## Push ##
+
+ push:
+ # Clients requesting push notifications can either have the body of
+ # the message sent in the notification poke along with other details
+ # like the sender, or just the event ID and room ID (`event_id_only`).
+ # If clients choose the former, this option controls whether the
+ # notification request includes the content of the event (other details
+ # like the sender are still included). For `event_id_only` push, it
+ # has no effect.
+ #
+ # For modern android devices the notification content will still appear
+ # because it is loaded by the app. iPhone, however will send a
+ # notification saying only that a message arrived and who it came from.
+ #
+ # The default value is "true" to include message details. Uncomment to only
+ # include the event ID and room ID in push notification payloads.
+ #
+ #include_content: false
+
+ # When a push notification is received, an unread count is also sent.
+ # This number can either be calculated as the number of unread messages
+ # for the user, or the number of *rooms* the user has unread messages in.
+ #
+ # The default value is "true", meaning push clients will see the number of
+ # rooms with unread messages in them. Uncomment to instead send the number
+ # of unread messages.
+ #
+ #group_unread_count_by_room: false
"""
diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py
index 14b8836197..4fca5b6d96 100644
--- a/synapse/config/ratelimiting.py
+++ b/synapse/config/ratelimiting.py
@@ -76,6 +76,9 @@ class RatelimitConfig(Config):
)
self.rc_registration = RateLimitConfig(config.get("rc_registration", {}))
+ self.rc_third_party_invite = RateLimitConfig(
+ config.get("rc_third_party_invite", {})
+ )
rc_login_config = config.get("rc_login", {})
self.rc_login_address = RateLimitConfig(rc_login_config.get("address", {}))
@@ -124,6 +127,8 @@ class RatelimitConfig(Config):
# - one for login that ratelimits login requests based on the account the
# client is attempting to log into, based on the amount of failed login
# attempts for this account.
+ # - one that ratelimits third-party invites requests based on the account
+ # that's making the requests.
# - one for ratelimiting redactions by room admins. If this is not explicitly
# set then it uses the same ratelimiting as per rc_message. This is useful
# to allow room admins to deal with abuse quickly.
@@ -153,6 +158,10 @@ class RatelimitConfig(Config):
# per_second: 0.17
# burst_count: 3
#
+ #rc_third_party_invite:
+ # per_second: 0.2
+ # burst_count: 10
+ #
#rc_admin_redaction:
# per_second: 1
# burst_count: 50
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index b0a77a2e43..770b921aff 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -13,75 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import os
from distutils.util import strtobool
-import pkg_resources
-
from synapse.api.constants import RoomCreationPreset
from synapse.config._base import Config, ConfigError
from synapse.types import RoomAlias, UserID
from synapse.util.stringutils import random_string_with_symbols
-class AccountValidityConfig(Config):
- section = "accountvalidity"
-
- def __init__(self, config, synapse_config):
- if config is None:
- return
- super().__init__()
- self.enabled = config.get("enabled", False)
- self.renew_by_email_enabled = "renew_at" in config
-
- if self.enabled:
- if "period" in config:
- self.period = self.parse_duration(config["period"])
- else:
- raise ConfigError("'period' is required when using account validity")
-
- if "renew_at" in config:
- self.renew_at = self.parse_duration(config["renew_at"])
-
- if "renew_email_subject" in config:
- self.renew_email_subject = config["renew_email_subject"]
- else:
- self.renew_email_subject = "Renew your %(app)s account"
-
- self.startup_job_max_delta = self.period * 10.0 / 100.0
-
- if self.renew_by_email_enabled:
- if "public_baseurl" not in synapse_config:
- raise ConfigError("Can't send renewal emails without 'public_baseurl'")
-
- template_dir = config.get("template_dir")
-
- if not template_dir:
- template_dir = pkg_resources.resource_filename("synapse", "res/templates")
-
- if "account_renewed_html_path" in config:
- file_path = os.path.join(template_dir, config["account_renewed_html_path"])
-
- self.account_renewed_html_content = self.read_file(
- file_path, "account_validity.account_renewed_html_path"
- )
- else:
- self.account_renewed_html_content = (
- "<html><body>Your account has been successfully renewed.</body><html>"
- )
-
- if "invalid_token_html_path" in config:
- file_path = os.path.join(template_dir, config["invalid_token_html_path"])
-
- self.invalid_token_html_content = self.read_file(
- file_path, "account_validity.invalid_token_html_path"
- )
- else:
- self.invalid_token_html_content = (
- "<html><body>Invalid renewal token.</body><html>"
- )
-
-
class RegistrationConfig(Config):
section = "registration"
@@ -94,14 +33,21 @@ class RegistrationConfig(Config):
strtobool(str(config["disable_registration"]))
)
- self.account_validity = AccountValidityConfig(
- config.get("account_validity") or {}, config
- )
-
self.registrations_require_3pid = config.get("registrations_require_3pid", [])
self.allowed_local_3pids = config.get("allowed_local_3pids", [])
+ self.check_is_for_allowed_local_3pids = config.get(
+ "check_is_for_allowed_local_3pids", None
+ )
+ self.allow_invited_3pids = config.get("allow_invited_3pids", False)
+
+ self.disable_3pid_changes = config.get("disable_3pid_changes", False)
+
self.enable_3pid_lookup = config.get("enable_3pid_lookup", True)
self.registration_shared_secret = config.get("registration_shared_secret")
+ self.register_mxid_from_3pid = config.get("register_mxid_from_3pid")
+ self.register_just_use_email_for_display_name = config.get(
+ "register_just_use_email_for_display_name", False
+ )
self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
self.trusted_third_party_id_servers = config.get(
@@ -109,7 +55,21 @@ class RegistrationConfig(Config):
)
account_threepid_delegates = config.get("account_threepid_delegates") or {}
self.account_threepid_delegate_email = account_threepid_delegates.get("email")
+ if (
+ self.account_threepid_delegate_email
+ and not self.account_threepid_delegate_email.startswith("http")
+ ):
+ raise ConfigError(
+ "account_threepid_delegates.email must begin with http:// or https://"
+ )
self.account_threepid_delegate_msisdn = account_threepid_delegates.get("msisdn")
+ if (
+ self.account_threepid_delegate_msisdn
+ and not self.account_threepid_delegate_msisdn.startswith("http")
+ ):
+ raise ConfigError(
+ "account_threepid_delegates.msisdn must begin with http:// or https://"
+ )
if self.account_threepid_delegate_msisdn and not self.public_baseurl:
raise ConfigError(
"The configuration option `public_baseurl` is required if "
@@ -178,6 +138,15 @@ class RegistrationConfig(Config):
self.enable_set_avatar_url = config.get("enable_set_avatar_url", True)
self.enable_3pid_changes = config.get("enable_3pid_changes", True)
+ self.replicate_user_profiles_to = config.get("replicate_user_profiles_to", [])
+ if not isinstance(self.replicate_user_profiles_to, list):
+ self.replicate_user_profiles_to = [self.replicate_user_profiles_to]
+
+ self.shadow_server = config.get("shadow_server", None)
+ self.rewrite_identity_server_urls = (
+ config.get("rewrite_identity_server_urls") or {}
+ )
+
self.disable_msisdn_registration = config.get(
"disable_msisdn_registration", False
)
@@ -192,6 +161,23 @@ class RegistrationConfig(Config):
["auth_success.html"], autoescape=True
)[0]
+ self.bind_new_user_emails_to_sydent = config.get(
+ "bind_new_user_emails_to_sydent"
+ )
+
+ if self.bind_new_user_emails_to_sydent:
+ if not isinstance(
+ self.bind_new_user_emails_to_sydent, str
+ ) or not self.bind_new_user_emails_to_sydent.startswith("http"):
+ raise ConfigError(
+ "Option bind_new_user_emails_to_sydent has invalid value"
+ )
+
+ # Remove trailing slashes
+ self.bind_new_user_emails_to_sydent = self.bind_new_user_emails_to_sydent.strip(
+ "/"
+ )
+
def generate_config_section(self, generate_secrets=False, **kwargs):
if generate_secrets:
registration_shared_secret = 'registration_shared_secret: "%s"' % (
@@ -211,69 +197,6 @@ class RegistrationConfig(Config):
#
#enable_registration: false
- # Optional account validity configuration. This allows for accounts to be denied
- # any request after a given period.
- #
- # Once this feature is enabled, Synapse will look for registered users without an
- # expiration date at startup and will add one to every account it found using the
- # current settings at that time.
- # This means that, if a validity period is set, and Synapse is restarted (it will
- # then derive an expiration date from the current validity period), and some time
- # after that the validity period changes and Synapse is restarted, the users'
- # expiration dates won't be updated unless their account is manually renewed. This
- # date will be randomly selected within a range [now + period - d ; now + period],
- # where d is equal to 10%% of the validity period.
- #
- account_validity:
- # The account validity feature is disabled by default. Uncomment the
- # following line to enable it.
- #
- #enabled: true
-
- # The period after which an account is valid after its registration. When
- # renewing the account, its validity period will be extended by this amount
- # of time. This parameter is required when using the account validity
- # feature.
- #
- #period: 6w
-
- # The amount of time before an account's expiry date at which Synapse will
- # send an email to the account's email address with a renewal link. By
- # default, no such emails are sent.
- #
- # If you enable this setting, you will also need to fill out the 'email' and
- # 'public_baseurl' configuration sections.
- #
- #renew_at: 1w
-
- # The subject of the email sent out with the renewal link. '%%(app)s' can be
- # used as a placeholder for the 'app_name' parameter from the 'email'
- # section.
- #
- # Note that the placeholder must be written '%%(app)s', including the
- # trailing 's'.
- #
- # If this is not set, a default value is used.
- #
- #renew_email_subject: "Renew your %%(app)s account"
-
- # Directory in which Synapse will try to find templates for the HTML files to
- # serve to the user when trying to renew an account. If not set, default
- # templates from within the Synapse package will be used.
- #
- #template_dir: "res/templates"
-
- # File within 'template_dir' giving the HTML to be displayed to the user after
- # they successfully renewed their account. If not set, default text is used.
- #
- #account_renewed_html_path: "account_renewed.html"
-
- # File within 'template_dir' giving the HTML to be displayed when the user
- # tries to renew an account with an invalid renewal token. If not set,
- # default text is used.
- #
- #invalid_token_html_path: "invalid_token.html"
-
# Time that a user's session remains valid for, after they log in.
#
# Note that this is not currently compatible with guest logins.
@@ -296,9 +219,32 @@ class RegistrationConfig(Config):
#
#disable_msisdn_registration: true
+ # Derive the user's matrix ID from a type of 3PID used when registering.
+ # This overrides any matrix ID the user proposes when calling /register
+ # The 3PID type should be present in registrations_require_3pid to avoid
+ # users failing to register if they don't specify the right kind of 3pid.
+ #
+ #register_mxid_from_3pid: email
+
+ # Uncomment to set the display name of new users to their email address,
+ # rather than using the default heuristic.
+ #
+ #register_just_use_email_for_display_name: true
+
# Mandate that users are only allowed to associate certain formats of
# 3PIDs with accounts on this server.
#
+ # Use an Identity Server to establish which 3PIDs are allowed to register?
+ # Overrides allowed_local_3pids below.
+ #
+ #check_is_for_allowed_local_3pids: matrix.org
+ #
+ # If you are using an IS you can also check whether that IS registers
+ # pending invites for the given 3PID (and then allow it to sign up on
+ # the platform):
+ #
+ #allow_invited_3pids: false
+ #
#allowed_local_3pids:
# - medium: email
# pattern: '.*@matrix\\.org'
@@ -307,6 +253,11 @@ class RegistrationConfig(Config):
# - medium: msisdn
# pattern: '\\+44'
+ # If true, stop users from trying to change the 3PIDs associated with
+ # their accounts.
+ #
+ #disable_3pid_changes: false
+
# Enable 3PIDs lookup requests to identity servers from this server.
#
#enable_3pid_lookup: true
@@ -338,6 +289,30 @@ class RegistrationConfig(Config):
#
#default_identity_server: https://matrix.org
+ # If enabled, user IDs, display names and avatar URLs will be replicated
+ # to this server whenever they change.
+ # This is an experimental API currently implemented by sydent to support
+ # cross-homeserver user directories.
+ #
+ #replicate_user_profiles_to: example.com
+
+ # If specified, attempt to replay registrations, profile changes & 3pid
+ # bindings on the given target homeserver via the AS API. The HS is authed
+ # via a given AS token.
+ #
+ #shadow_server:
+ # hs_url: https://shadow.example.com
+ # hs: shadow.example.com
+ # as_token: 12u394refgbdhivsia
+
+ # If enabled, don't let users set their own display names/avatars
+ # other than for the very first time (unless they are a server admin).
+ # Useful when provisioning users based on the contents of a 3rd party
+ # directory and to avoid ambiguities.
+ #
+ #disable_set_displayname: false
+ #disable_set_avatar_url: false
+
# Handle threepid (email/phone etc) registration and password resets through a set of
# *trusted* identity servers. Note that this allows the configured identity server to
# reset passwords for accounts!
@@ -347,8 +322,9 @@ class RegistrationConfig(Config):
# email will be globally disabled.
#
# Additionally, if `msisdn` is not set, registration and password resets via msisdn
- # will be disabled regardless. This is due to Synapse currently not supporting any
- # method of sending SMS messages on its own.
+ # will be disabled regardless, and users will not be able to associate an msisdn
+ # identifier to their account. This is due to Synapse currently not supporting
+ # any method of sending SMS messages on its own.
#
# To enable using an identity server for operations regarding a particular third-party
# identifier type, set the value to the URL of that identity server as shown in the
@@ -463,6 +439,31 @@ class RegistrationConfig(Config):
# Defaults to true.
#
#auto_join_rooms_for_guests: false
+
+ # Rewrite identity server URLs with a map from one URL to another. Applies to URLs
+ # provided by clients (which have https:// prepended) and those specified
+ # in `account_threepid_delegates`. URLs should not feature a trailing slash.
+ #
+ #rewrite_identity_server_urls:
+ # "https://somewhere.example.com": "https://somewhereelse.example.com"
+
+ # When a user registers an account with an email address, it can be useful to
+ # bind that email address to their mxid on an identity server. Typically, this
+ # requires the user to validate their email address with the identity server.
+ # However if Synapse itself is handling email validation on registration, the
+ # user ends up needing to validate their email twice, which leads to poor UX.
+ #
+ # It is possible to force Sydent, one identity server implementation, to bind
+ # threepids using its internal, unauthenticated bind API:
+ # https://github.com/matrix-org/sydent/#internal-bind-and-unbind-api
+ #
+ # Configure the address of a Sydent server here to have Synapse attempt
+ # to automatically bind users' emails following registration. The
+ # internal bind API must be reachable from Synapse, but should NOT be
+ # exposed to any third party, as it allows the creation of bindings
+ # without validation.
+ #
+ #bind_new_user_emails_to_sydent: https://example.com:8091
"""
% locals()
)
diff --git a/synapse/config/repository.py b/synapse/config/repository.py
index ba1e9d2361..bc02754a33 100644
--- a/synapse/config/repository.py
+++ b/synapse/config/repository.py
@@ -104,6 +104,12 @@ class ContentRepositoryConfig(Config):
self.max_image_pixels = self.parse_size(config.get("max_image_pixels", "32M"))
self.max_spider_size = self.parse_size(config.get("max_spider_size", "10M"))
+ self.max_avatar_size = config.get("max_avatar_size")
+ if self.max_avatar_size:
+ self.max_avatar_size = self.parse_size(self.max_avatar_size)
+
+ self.allowed_avatar_mimetypes = config.get("allowed_avatar_mimetypes", [])
+
self.media_store_path = self.ensure_directory(
config.get("media_store_path", "media_store")
)
@@ -244,6 +250,30 @@ class ContentRepositoryConfig(Config):
#
#max_upload_size: 50M
+ # The largest allowed size for a user avatar. If not defined, no
+ # restriction will be imposed.
+ #
+ # Note that this only applies when an avatar is changed globally.
+ # Per-room avatar changes are not affected. See allow_per_room_profiles
+ # for disabling that functionality.
+ #
+ # Note that user avatar changes will not work if this is set without
+ # using Synapse's local media repo.
+ #
+ #max_avatar_size: 10M
+
+ # Allow mimetypes for a user avatar. If not defined, no restriction will
+ # be imposed.
+ #
+ # Note that this only applies when an avatar is changed globally.
+ # Per-room avatar changes are not affected. See allow_per_room_profiles
+ # for disabling that functionality.
+ #
+ # Note that user avatar changes will not work if this is set without
+ # using Synapse's local media repo.
+ #
+ #allowed_avatar_mimetypes: ["image/png", "image/jpeg", "image/gif"]
+
# Maximum number of pixels that will be thumbnailed
#
#max_image_pixels: 32M
diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py
index 2ff7dfb311..c1b8e98ae0 100644
--- a/synapse/config/saml2_config.py
+++ b/synapse/config/saml2_config.py
@@ -90,6 +90,8 @@ class SAML2Config(Config):
"grandfathered_mxid_source_attribute", "uid"
)
+ self.saml2_idp_entityid = saml2_config.get("idp_entityid", None)
+
# user_mapping_provider may be None if the key is present but has no value
ump_dict = saml2_config.get("user_mapping_provider") or {}
@@ -256,6 +258,12 @@ class SAML2Config(Config):
# remote:
# - url: https://our_idp/metadata.xml
+ # Allowed clock difference in seconds between the homeserver and IdP.
+ #
+ # Uncomment the below to increase the accepted time difference from 0 to 3 seconds.
+ #
+ #accepted_time_diff: 3
+
# By default, the user has to go to our login page first. If you'd like
# to allow IdP-initiated login, set 'allow_unsolicited: true' in a
# 'service.sp' section:
@@ -377,6 +385,14 @@ class SAML2Config(Config):
# value: "staff"
# - attribute: department
# value: "sales"
+
+ # If the metadata XML contains multiple IdP entities then the `idp_entityid`
+ # option must be set to the entity to redirect users to.
+ #
+ # Most deployments only have a single IdP entity and so should omit this
+ # option.
+ #
+ #idp_entityid: 'https://our_idp/entityid'
""" % {
"config_dir_path": config_dir_path
}
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 85aa49c02d..61335595ff 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -276,6 +276,12 @@ class ServerConfig(Config):
# events with profile information that differ from the target's global profile.
self.allow_per_room_profiles = config.get("allow_per_room_profiles", True)
+ # Whether to show the users on this homeserver in the user directory. Defaults to
+ # True.
+ self.show_users_in_user_directory = config.get(
+ "show_users_in_user_directory", True
+ )
+
retention_config = config.get("retention")
if retention_config is None:
retention_config = {}
@@ -949,6 +955,74 @@ class ServerConfig(Config):
#
#allow_per_room_profiles: false
+ # Whether to show the users on this homeserver in the user directory. Defaults to
+ # 'true'.
+ #
+ #show_users_in_user_directory: false
+
+ # Message retention policy at the server level.
+ #
+ # Room admins and mods can define a retention period for their rooms using the
+ # 'm.room.retention' state event, and server admins can cap this period by setting
+ # the 'allowed_lifetime_min' and 'allowed_lifetime_max' config options.
+ #
+ # If this feature is enabled, Synapse will regularly look for and purge events
+ # which are older than the room's maximum retention period. Synapse will also
+ # filter events received over federation so that events that should have been
+ # purged are ignored and not stored again.
+ #
+ retention:
+ # The message retention policies feature is disabled by default. Uncomment the
+ # following line to enable it.
+ #
+ #enabled: true
+
+ # Default retention policy. If set, Synapse will apply it to rooms that lack the
+ # 'm.room.retention' state event. Currently, the value of 'min_lifetime' doesn't
+ # matter much because Synapse doesn't take it into account yet.
+ #
+ #default_policy:
+ # min_lifetime: 1d
+ # max_lifetime: 1y
+
+ # Retention policy limits. If set, a user won't be able to send a
+ # 'm.room.retention' event which features a 'min_lifetime' or a 'max_lifetime'
+ # that's not within this range. This is especially useful in closed federations,
+ # in which server admins can make sure every federating server applies the same
+ # rules.
+ #
+ #allowed_lifetime_min: 1d
+ #allowed_lifetime_max: 1y
+
+ # Server admins can define the settings of the background jobs purging the
+ # events which lifetime has expired under the 'purge_jobs' section.
+ #
+ # If no configuration is provided, a single job will be set up to delete expired
+ # events in every room daily.
+ #
+ # Each job's configuration defines which range of message lifetimes the job
+ # takes care of. For example, if 'shortest_max_lifetime' is '2d' and
+ # 'longest_max_lifetime' is '3d', the job will handle purging expired events in
+ # rooms whose state defines a 'max_lifetime' that's both higher than 2 days, and
+ # lower than or equal to 3 days. Both the minimum and the maximum value of a
+ # range are optional, e.g. a job with no 'shortest_max_lifetime' and a
+ # 'longest_max_lifetime' of '3d' will handle every room with a retention policy
+ # which 'max_lifetime' is lower than or equal to three days.
+ #
+ # The rationale for this per-job configuration is that some rooms might have a
+ # retention policy with a low 'max_lifetime', where history needs to be purged
+ # of outdated messages on a very frequent basis (e.g. every 5min), but not want
+ # that purge to be performed by a job that's iterating over every room it knows,
+ # which would be quite heavy on the server.
+ #
+ #purge_jobs:
+ # - shortest_max_lifetime: 1d
+ # longest_max_lifetime: 3d
+ # interval: 5m:
+ # - shortest_max_lifetime: 3d
+ # longest_max_lifetime: 1y
+ # interval: 24h
+
# How long to keep redacted events in unredacted form in the database. After
# this period redacted events get replaced with their redacted form in the DB.
#
diff --git a/synapse/config/user_directory.py b/synapse/config/user_directory.py
index c8d19c5d6b..43b6c40456 100644
--- a/synapse/config/user_directory.py
+++ b/synapse/config/user_directory.py
@@ -26,6 +26,7 @@ class UserDirectoryConfig(Config):
def read_config(self, config, **kwargs):
self.user_directory_search_enabled = True
self.user_directory_search_all_users = False
+ self.user_directory_defer_to_id_server = None
user_directory_config = config.get("user_directory", None)
if user_directory_config:
self.user_directory_search_enabled = user_directory_config.get(
@@ -34,6 +35,9 @@ class UserDirectoryConfig(Config):
self.user_directory_search_all_users = user_directory_config.get(
"search_all_users", False
)
+ self.user_directory_defer_to_id_server = user_directory_config.get(
+ "defer_to_id_server", None
+ )
def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """
@@ -52,4 +56,9 @@ class UserDirectoryConfig(Config):
#user_directory:
# enabled: true
# search_all_users: false
+ #
+ # # If this is set, user search will be delegated to this ID server instead
+ # # of synapse performing the search itself.
+ # # This is an experimental API.
+ # defer_to_id_server: https://id.example.com
"""
diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py
index bad18f7fdf..4404ffffe9 100644
--- a/synapse/events/spamcheck.py
+++ b/synapse/events/spamcheck.py
@@ -15,13 +15,12 @@
# limitations under the License.
import inspect
-from typing import Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from synapse.spam_checker_api import RegistrationBehaviour
from synapse.types import Collection
-MYPY = False
-if MYPY:
+if TYPE_CHECKING:
import synapse.events
import synapse.server
@@ -60,42 +59,82 @@ class SpamChecker:
return False
def user_may_invite(
- self, inviter_userid: str, invitee_userid: str, room_id: str
+ self,
+ inviter_userid: str,
+ invitee_userid: str,
+ third_party_invite: Optional[Dict],
+ room_id: str,
+ new_room: bool,
+ published_room: bool,
) -> bool:
"""Checks if a given user may send an invite
If this method returns false, the invite will be rejected.
Args:
- inviter_userid: The user ID of the sender of the invitation
- invitee_userid: The user ID targeted in the invitation
- room_id: The room ID
+ inviter_userid:
+ invitee_userid: The user ID of the invitee. Is None
+ if this is a third party invite and the 3PID is not bound to a
+ user ID.
+ third_party_invite: If a third party invite then is a
+ dict containing the medium and address of the invitee.
+ room_id:
+ new_room: Whether the user is being invited to the room as
+ part of a room creation, if so the invitee would have been
+ included in the call to `user_may_create_room`.
+ published_room: Whether the room the user is being invited
+ to has been published in the local homeserver's public room
+ directory.
Returns:
True if the user may send an invite, otherwise False
"""
for spam_checker in self.spam_checkers:
if (
- spam_checker.user_may_invite(inviter_userid, invitee_userid, room_id)
+ spam_checker.user_may_invite(
+ inviter_userid,
+ invitee_userid,
+ third_party_invite,
+ room_id,
+ new_room,
+ published_room,
+ )
is False
):
return False
return True
- def user_may_create_room(self, userid: str) -> bool:
+ def user_may_create_room(
+ self,
+ userid: str,
+ invite_list: List[str],
+ third_party_invite_list: List[Dict],
+ cloning: bool,
+ ) -> bool:
"""Checks if a given user may create a room
If this method returns false, the creation request will be rejected.
Args:
userid: The ID of the user attempting to create a room
+ invite_list: List of user IDs that would be invited to
+ the new room.
+ third_party_invite_list: List of third party invites
+ for the new room.
+ cloning: Whether the user is cloning an existing room, e.g.
+ upgrading a room.
Returns:
True if the user may create a room, otherwise False
"""
for spam_checker in self.spam_checkers:
- if spam_checker.user_may_create_room(userid) is False:
+ if (
+ spam_checker.user_may_create_room(
+ userid, invite_list, third_party_invite_list, cloning
+ )
+ is False
+ ):
return False
return True
@@ -136,6 +175,25 @@ class SpamChecker:
return True
+ def user_may_join_room(self, userid: str, room_id: str, is_invited: bool):
+ """Checks if a given users is allowed to join a room.
+
+ Not called when a user creates a room.
+
+ Args:
+ userid:
+ room_id:
+ is_invited: Whether the user is invited into the room
+
+ Returns:
+ bool: Whether the user may join the room
+ """
+ for spam_checker in self.spam_checkers:
+ if spam_checker.user_may_join_room(userid, room_id, is_invited) is False:
+ return False
+
+ return True
+
def check_username_for_spam(self, user_profile: Dict[str, str]) -> bool:
"""Checks if a user ID or display name are considered "spammy" by this server.
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 17a10f622e..4f7996f947 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -16,7 +16,7 @@
import logging
import urllib
-from typing import Any, Dict, Optional
+from typing import Any, Dict, List, Optional
from synapse.api.constants import Membership
from synapse.api.errors import Codes, HttpResponseException, SynapseError
@@ -1004,6 +1004,20 @@ class TransportLayerClient:
return self.client.get_json(destination=destination, path=path)
+ def get_info_of_users(self, destination: str, user_ids: List[str]):
+ """
+ Args:
+ destination: The remote server
+ user_ids: A list of user IDs to query info about
+
+ Returns:
+ Deferred[List]: A dictionary of User ID to information about that user.
+ """
+ path = _create_path(FEDERATION_UNSTABLE_PREFIX, "/users/info")
+ data = {"user_ids": user_ids}
+
+ return self.client.post_json(destination=destination, path=path, data=data)
+
def _create_path(federation_prefix, path, *args):
"""
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index b53e7a20ec..a9b73d713a 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -31,6 +31,7 @@ from synapse.api.urls import (
from synapse.http.endpoint import parse_and_validate_server_name
from synapse.http.server import JsonResource
from synapse.http.servlet import (
+ assert_params_in_dict,
parse_boolean_from_args,
parse_integer_from_args,
parse_json_object_from_request,
@@ -842,6 +843,57 @@ class PublicRoomList(BaseFederationServlet):
return 200, data
+class FederationUserInfoServlet(BaseFederationServlet):
+ """
+ Return information about a set of users.
+
+ This API returns expiration and deactivation information about a set of
+ users. Requested users not local to this homeserver will be ignored.
+
+ Example request:
+ POST /users/info
+
+ {
+ "user_ids": [
+ "@alice:example.com",
+ "@bob:example.com"
+ ]
+ }
+
+ Example response
+ {
+ "@alice:example.com": {
+ "expired": false,
+ "deactivated": true
+ }
+ }
+ """
+
+ PATH = "/users/info"
+ PREFIX = FEDERATION_UNSTABLE_PREFIX
+
+ def __init__(self, handler, authenticator, ratelimiter, server_name):
+ super(FederationUserInfoServlet, self).__init__(
+ handler, authenticator, ratelimiter, server_name
+ )
+ self.handler = handler
+
+ async def on_POST(self, origin, content, query):
+ assert_params_in_dict(content, required=["user_ids"])
+
+ user_ids = content.get("user_ids", [])
+
+ if not isinstance(user_ids, list):
+ raise SynapseError(
+ 400,
+ "'user_ids' must be a list of user ID strings",
+ errcode=Codes.INVALID_PARAM,
+ )
+
+ data = await self.handler.store.get_info_for_users(user_ids)
+ return 200, data
+
+
class FederationVersionServlet(BaseFederationServlet):
PATH = "/version"
@@ -1403,6 +1455,7 @@ FEDERATION_SERVLET_CLASSES = (
On3pidBindServlet,
FederationVersionServlet,
RoomComplexityServlet,
+ FederationUserInfoServlet,
) # type: Tuple[Type[BaseFederationServlet], ...]
OPENID_SERVLET_CLASSES = (
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index bd8e71ae56..bb81c0e81d 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -169,7 +169,9 @@ class BaseHandler:
# and having homeservers have their own users leave keeps more
# of that decision-making and control local to the guest-having
# homeserver.
- requester = synapse.types.create_requester(target_user, is_guest=True)
+ requester = synapse.types.create_requester(
+ target_user, is_guest=True, authenticated_entity=self.server_name
+ )
handler = self.hs.get_room_member_handler()
await handler.update_membership(
requester,
diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py
index 664d09da1c..ce97fa70d7 100644
--- a/synapse/handlers/account_validity.py
+++ b/synapse/handlers/account_validity.py
@@ -18,11 +18,14 @@ import email.utils
import logging
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING, List, Optional, Tuple
from synapse.api.errors import StoreError, SynapseError
from synapse.logging.context import make_deferred_yieldable
-from synapse.metrics.background_process_metrics import wrap_as_background_process
+from synapse.metrics.background_process_metrics import (
+ run_as_background_process,
+ wrap_as_background_process,
+)
from synapse.types import UserID
from synapse.util import stringutils
@@ -40,27 +43,37 @@ class AccountValidityHandler:
self.sendmail = self.hs.get_sendmail()
self.clock = self.hs.get_clock()
- self._account_validity = self.hs.config.account_validity
+ self._account_validity_enabled = self.hs.config.account_validity_enabled
+ self._account_validity_renew_by_email_enabled = (
+ self.hs.config.account_validity_renew_by_email_enabled
+ )
+ self._show_users_in_user_directory = self.hs.config.show_users_in_user_directory
+ self.profile_handler = self.hs.get_profile_handler()
+
+ self._account_validity_period = None
+ if self._account_validity_enabled:
+ self._account_validity_period = self.hs.config.account_validity_period
if (
- self._account_validity.enabled
- and self._account_validity.renew_by_email_enabled
+ self._account_validity_enabled
+ and self._account_validity_renew_by_email_enabled
):
# Don't do email-specific configuration if renewal by email is disabled.
self._template_html = self.config.account_validity_template_html
self._template_text = self.config.account_validity_template_text
+ account_validity_renew_email_subject = (
+ self.hs.config.account_validity_renew_email_subject
+ )
try:
app_name = self.hs.config.email_app_name
- self._subject = self._account_validity.renew_email_subject % {
- "app": app_name
- }
+ self._subject = account_validity_renew_email_subject % {"app": app_name}
self._from_string = self.hs.config.email_notif_from % {"app": app_name}
except Exception:
# If substitution failed, fall back to the bare strings.
- self._subject = self._account_validity.renew_email_subject
+ self._subject = account_validity_renew_email_subject
self._from_string = self.hs.config.email_notif_from
self._raw_from = email.utils.parseaddr(self._from_string)[1]
@@ -69,6 +82,18 @@ class AccountValidityHandler:
if hs.config.run_background_tasks:
self.clock.looping_call(self._send_renewal_emails, 30 * 60 * 1000)
+ # Mark users as inactive when they expired. Check once every hour
+ if self._account_validity_enabled:
+
+ def mark_expired_users_as_inactive():
+ # run as a background process to allow async functions to work
+ return run_as_background_process(
+ "_mark_expired_users_as_inactive",
+ self._mark_expired_users_as_inactive,
+ )
+
+ self.clock.looping_call(mark_expired_users_as_inactive, 60 * 60 * 1000)
+
@wrap_as_background_process("send_renewals")
async def _send_renewal_emails(self) -> None:
"""Gets the list of users whose account is expiring in the amount of time
@@ -221,47 +246,107 @@ class AccountValidityHandler:
attempts += 1
raise StoreError(500, "Couldn't generate a unique string as refresh string.")
- async def renew_account(self, renewal_token: str) -> bool:
+ async def renew_account(self, renewal_token: str) -> Tuple[bool, bool, int]:
"""Renews the account attached to a given renewal token by pushing back the
expiration date by the current validity period in the server's configuration.
+ If it turns out that the token is valid but has already been used, then the
+ token is considered stale. A token is stale if the 'token_used_ts_ms' db column
+ is non-null.
+
Args:
renewal_token: Token sent with the renewal request.
Returns:
- Whether the provided token is valid.
+ A tuple containing:
+ * A bool representing whether the token is valid and unused.
+ * A bool representing whether the token is stale.
+ * An int representing the user's expiry timestamp as milliseconds since the
+ epoch, or 0 if the token was invalid.
"""
try:
- user_id = await self.store.get_user_from_renewal_token(renewal_token)
+ (
+ user_id,
+ current_expiration_ts,
+ token_used_ts,
+ ) = await self.store.get_user_from_renewal_token(renewal_token)
except StoreError:
- return False
+ return False, False, 0
+
+ # Check whether this token has already been used.
+ if token_used_ts:
+ logger.info(
+ "User '%s' attempted to use previously used token '%s' to renew account",
+ user_id,
+ renewal_token,
+ )
+ return False, True, current_expiration_ts
logger.debug("Renewing an account for user %s", user_id)
- await self.renew_account_for_user(user_id)
- return True
+ # Renew the account. Pass the renewal_token here so that it is not cleared.
+ # We want to keep the token around in case the user attempts to renew their
+ # account with the same token twice (clicking the email link twice).
+ #
+ # In that case, the token will be accepted, but the account's expiration ts
+ # will remain unchanged.
+ new_expiration_ts = await self.renew_account_for_user(
+ user_id, renewal_token=renewal_token
+ )
+
+ return True, False, new_expiration_ts
async def renew_account_for_user(
- self, user_id: str, expiration_ts: int = None, email_sent: bool = False
+ self,
+ user_id: str,
+ expiration_ts: Optional[int] = None,
+ email_sent: bool = False,
+ renewal_token: Optional[str] = None,
) -> int:
"""Renews the account attached to a given user by pushing back the
expiration date by the current validity period in the server's
configuration.
Args:
- renewal_token: Token sent with the renewal request.
+ user_id: The ID of the user to renew.
expiration_ts: New expiration date. Defaults to now + validity period.
- email_sen: Whether an email has been sent for this validity period.
- Defaults to False.
+ email_sent: Whether an email has been sent for this validity period.
+ renewal_token: Token sent with the renewal request. The user's token
+ will be cleared if this is None.
Returns:
New expiration date for this account, as a timestamp in
milliseconds since epoch.
"""
+ now = self.clock.time_msec()
if expiration_ts is None:
- expiration_ts = self.clock.time_msec() + self._account_validity.period
+ expiration_ts = now + self._account_validity_period
await self.store.set_account_validity_for_user(
- user_id=user_id, expiration_ts=expiration_ts, email_sent=email_sent
+ user_id=user_id,
+ expiration_ts=expiration_ts,
+ email_sent=email_sent,
+ renewal_token=renewal_token,
+ token_used_ts=now,
)
+ # Check if renewed users should be reintroduced to the user directory
+ if self._show_users_in_user_directory:
+ # Show the user in the directory again by setting them to active
+ await self.profile_handler.set_active(
+ [UserID.from_string(user_id)], True, True
+ )
+
return expiration_ts
+
+ async def _mark_expired_users_as_inactive(self):
+ """Iterate over active, expired users. Mark them as inactive in order to hide them
+ from the user directory.
+
+ Returns:
+ Deferred
+ """
+ # Get active, expired users
+ active_expired_users = await self.store.get_expired_users()
+
+ # Mark each as non-active
+ await self.profile_handler.set_active(active_expired_users, False, True)
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 9fc8444228..5c6458eb52 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -226,7 +226,7 @@ class ApplicationServicesHandler:
new_token: Optional[int],
users: Collection[Union[str, UserID]],
):
- logger.info("Checking interested services for %s" % (stream_key))
+ logger.debug("Checking interested services for %s" % (stream_key))
with Measure(self.clock, "notify_interested_services_ephemeral"):
for service in services:
# Only handle typing if we have the latest token
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 213baea2e3..c7dc07008a 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2014 - 2016 OpenMarket Ltd
# Copyright 2017 Vector Creations Ltd
+# Copyright 2019 - 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -25,6 +26,7 @@ from typing import (
Dict,
Iterable,
List,
+ Mapping,
Optional,
Tuple,
Union,
@@ -181,17 +183,12 @@ class AuthHandler(BaseHandler):
# better way to break the loop
account_handler = ModuleApi(hs, self)
- self.password_providers = []
- for module, config in hs.config.password_providers:
- try:
- self.password_providers.append(
- module(config=config, account_handler=account_handler)
- )
- except Exception as e:
- logger.error("Error while initializing %r: %s", module, e)
- raise
+ self.password_providers = [
+ PasswordProvider.load(module, config, account_handler)
+ for module, config in hs.config.password_providers
+ ]
- logger.info("Extra password_providers: %r", self.password_providers)
+ logger.info("Extra password_providers: %s", self.password_providers)
self.hs = hs # FIXME better possibility to access registrationHandler later?
self.macaroon_gen = hs.get_macaroon_generator()
@@ -205,15 +202,23 @@ class AuthHandler(BaseHandler):
# type in the list. (NB that the spec doesn't require us to do so and
# clients which favour types that they don't understand over those that
# they do are technically broken)
+
+ # start out by assuming PASSWORD is enabled; we will remove it later if not.
login_types = []
- if self._password_enabled:
+ if hs.config.password_localdb_enabled:
login_types.append(LoginType.PASSWORD)
+
for provider in self.password_providers:
if hasattr(provider, "get_supported_login_types"):
for t in provider.get_supported_login_types().keys():
if t not in login_types:
login_types.append(t)
+
+ if not self._password_enabled:
+ login_types.remove(LoginType.PASSWORD)
+
self._supported_login_types = login_types
+
# Login types and UI Auth types have a heavy overlap, but are not
# necessarily identical. Login types have SSO (and other login types)
# added in the rest layer, see synapse.rest.client.v1.login.LoginRestServerlet.on_GET.
@@ -230,6 +235,13 @@ class AuthHandler(BaseHandler):
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
)
+ # Ratelimitier for failed /login attempts
+ self._failed_login_attempts_ratelimiter = Ratelimiter(
+ clock=hs.get_clock(),
+ rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
+ burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
+ )
+
self._clock = self.hs.get_clock()
# Expire old UI auth sessions after a period of time.
@@ -642,14 +654,8 @@ class AuthHandler(BaseHandler):
res = await checker.check_auth(authdict, clientip=clientip)
return res
- # build a v1-login-style dict out of the authdict and fall back to the
- # v1 code
- user_id = authdict.get("user")
-
- if user_id is None:
- raise SynapseError(400, "", Codes.MISSING_PARAM)
-
- (canonical_id, callback) = await self.validate_login(user_id, authdict)
+ # fall back to the v1 login flow
+ canonical_id, _ = await self.validate_login(authdict)
return canonical_id
def _get_params_recaptcha(self) -> dict:
@@ -698,8 +704,12 @@ class AuthHandler(BaseHandler):
}
async def get_access_token_for_user_id(
- self, user_id: str, device_id: Optional[str], valid_until_ms: Optional[int]
- ):
+ self,
+ user_id: str,
+ device_id: Optional[str],
+ valid_until_ms: Optional[int],
+ puppets_user_id: Optional[str] = None,
+ ) -> str:
"""
Creates a new access token for the user with the given user ID.
@@ -725,13 +735,25 @@ class AuthHandler(BaseHandler):
fmt_expiry = time.strftime(
" until %Y-%m-%d %H:%M:%S", time.localtime(valid_until_ms / 1000.0)
)
- logger.info("Logging in user %s on device %s%s", user_id, device_id, fmt_expiry)
+
+ if puppets_user_id:
+ logger.info(
+ "Logging in user %s as %s%s", user_id, puppets_user_id, fmt_expiry
+ )
+ else:
+ logger.info(
+ "Logging in user %s on device %s%s", user_id, device_id, fmt_expiry
+ )
await self.auth.check_auth_blocking(user_id)
access_token = self.macaroon_gen.generate_access_token(user_id)
await self.store.add_access_token_to_user(
- user_id, access_token, device_id, valid_until_ms
+ user_id=user_id,
+ token=access_token,
+ device_id=device_id,
+ valid_until_ms=valid_until_ms,
+ puppets_user_id=puppets_user_id,
)
# the device *should* have been registered before we got here; however,
@@ -808,17 +830,17 @@ class AuthHandler(BaseHandler):
return self._supported_login_types
async def validate_login(
- self, username: str, login_submission: Dict[str, Any]
+ self, login_submission: Dict[str, Any], ratelimit: bool = False,
) -> Tuple[str, Optional[Callable[[Dict[str, str]], None]]]:
"""Authenticates the user for the /login API
- Also used by the user-interactive auth flow to validate
- m.login.password auth types.
+ Also used by the user-interactive auth flow to validate auth types which don't
+ have an explicit UIA handler, including m.password.auth.
Args:
- username: username supplied by the user
login_submission: the whole of the login submission
(including 'type' and other relevant fields)
+ ratelimit: whether to apply the failed_login_attempt ratelimiter
Returns:
A tuple of the canonical user id, and optional callback
to be called once the access token and device id are issued
@@ -827,38 +849,160 @@ class AuthHandler(BaseHandler):
SynapseError if there was a problem with the request
LoginError if there was an authentication problem.
"""
-
- if username.startswith("@"):
- qualified_user_id = username
- else:
- qualified_user_id = UserID(username, self.hs.hostname).to_string()
-
login_type = login_submission.get("type")
- known_login_type = False
+ if not isinstance(login_type, str):
+ raise SynapseError(400, "Bad parameter: type", Codes.INVALID_PARAM)
+
+ # ideally, we wouldn't be checking the identifier unless we know we have a login
+ # method which uses it (https://github.com/matrix-org/synapse/issues/8836)
+ #
+ # But the auth providers' check_auth interface requires a username, so in
+ # practice we can only support login methods which we can map to a username
+ # anyway.
# special case to check for "password" for the check_password interface
# for the auth providers
password = login_submission.get("password")
-
if login_type == LoginType.PASSWORD:
if not self._password_enabled:
raise SynapseError(400, "Password login has been disabled.")
- if not password:
- raise SynapseError(400, "Missing parameter: password")
+ if not isinstance(password, str):
+ raise SynapseError(400, "Bad parameter: password", Codes.INVALID_PARAM)
- for provider in self.password_providers:
- if hasattr(provider, "check_password") and login_type == LoginType.PASSWORD:
- known_login_type = True
- is_valid = await provider.check_password(qualified_user_id, password)
- if is_valid:
- return qualified_user_id, None
+ # map old-school login fields into new-school "identifier" fields.
+ identifier_dict = convert_client_dict_legacy_fields_to_identifier(
+ login_submission
+ )
- if not hasattr(provider, "get_supported_login_types") or not hasattr(
- provider, "check_auth"
- ):
- # this password provider doesn't understand custom login types
- continue
+ # convert phone type identifiers to generic threepids
+ if identifier_dict["type"] == "m.id.phone":
+ identifier_dict = login_id_phone_to_thirdparty(identifier_dict)
+
+ # convert threepid identifiers to user IDs
+ if identifier_dict["type"] == "m.id.thirdparty":
+ address = identifier_dict.get("address")
+ medium = identifier_dict.get("medium")
+
+ if medium is None or address is None:
+ raise SynapseError(400, "Invalid thirdparty identifier")
+
+ # For emails, canonicalise the address.
+ # We store all email addresses canonicalised in the DB.
+ # (See add_threepid in synapse/handlers/auth.py)
+ if medium == "email":
+ try:
+ address = canonicalise_email(address)
+ except ValueError as e:
+ raise SynapseError(400, str(e))
+
+ # We also apply account rate limiting using the 3PID as a key, as
+ # otherwise using 3PID bypasses the ratelimiting based on user ID.
+ if ratelimit:
+ self._failed_login_attempts_ratelimiter.ratelimit(
+ (medium, address), update=False
+ )
+
+ # Check for login providers that support 3pid login types
+ if login_type == LoginType.PASSWORD:
+ # we've already checked that there is a (valid) password field
+ assert isinstance(password, str)
+ (
+ canonical_user_id,
+ callback_3pid,
+ ) = await self.check_password_provider_3pid(medium, address, password)
+ if canonical_user_id:
+ # Authentication through password provider and 3pid succeeded
+ return canonical_user_id, callback_3pid
+
+ # No password providers were able to handle this 3pid
+ # Check local store
+ user_id = await self.hs.get_datastore().get_user_id_by_threepid(
+ medium, address
+ )
+ if not user_id:
+ logger.warning(
+ "unknown 3pid identifier medium %s, address %r", medium, address
+ )
+ # We mark that we've failed to log in here, as
+ # `check_password_provider_3pid` might have returned `None` due
+ # to an incorrect password, rather than the account not
+ # existing.
+ #
+ # If it returned None but the 3PID was bound then we won't hit
+ # this code path, which is fine as then the per-user ratelimit
+ # will kick in below.
+ if ratelimit:
+ self._failed_login_attempts_ratelimiter.can_do_action(
+ (medium, address)
+ )
+ raise LoginError(403, "", errcode=Codes.FORBIDDEN)
+
+ identifier_dict = {"type": "m.id.user", "user": user_id}
+ # by this point, the identifier should be an m.id.user: if it's anything
+ # else, we haven't understood it.
+ if identifier_dict["type"] != "m.id.user":
+ raise SynapseError(400, "Unknown login identifier type")
+
+ username = identifier_dict.get("user")
+ if not username:
+ raise SynapseError(400, "User identifier is missing 'user' key")
+
+ if username.startswith("@"):
+ qualified_user_id = username
+ else:
+ qualified_user_id = UserID(username, self.hs.hostname).to_string()
+
+ # Check if we've hit the failed ratelimit (but don't update it)
+ if ratelimit:
+ self._failed_login_attempts_ratelimiter.ratelimit(
+ qualified_user_id.lower(), update=False
+ )
+
+ try:
+ return await self._validate_userid_login(username, login_submission)
+ except LoginError:
+ # The user has failed to log in, so we need to update the rate
+ # limiter. Using `can_do_action` avoids us raising a ratelimit
+ # exception and masking the LoginError. The actual ratelimiting
+ # should have happened above.
+ if ratelimit:
+ self._failed_login_attempts_ratelimiter.can_do_action(
+ qualified_user_id.lower()
+ )
+ raise
+
+ async def _validate_userid_login(
+ self, username: str, login_submission: Dict[str, Any],
+ ) -> Tuple[str, Optional[Callable[[Dict[str, str]], None]]]:
+ """Helper for validate_login
+
+ Handles login, once we've mapped 3pids onto userids
+
+ Args:
+ username: the username, from the identifier dict
+ login_submission: the whole of the login submission
+ (including 'type' and other relevant fields)
+ Returns:
+ A tuple of the canonical user id, and optional callback
+ to be called once the access token and device id are issued
+ Raises:
+ StoreError if there was a problem accessing the database
+ SynapseError if there was a problem with the request
+ LoginError if there was an authentication problem.
+ """
+ if username.startswith("@"):
+ qualified_user_id = username
+ else:
+ qualified_user_id = UserID(username, self.hs.hostname).to_string()
+
+ login_type = login_submission.get("type")
+ # we already checked that we have a valid login type
+ assert isinstance(login_type, str)
+
+ known_login_type = False
+
+ for provider in self.password_providers:
supported_login_types = provider.get_supported_login_types()
if login_type not in supported_login_types:
# this password provider doesn't understand this login type
@@ -883,15 +1027,17 @@ class AuthHandler(BaseHandler):
result = await provider.check_auth(username, login_type, login_dict)
if result:
- if isinstance(result, str):
- result = (result, None)
return result
if login_type == LoginType.PASSWORD and self.hs.config.password_localdb_enabled:
known_login_type = True
+ # we've already checked that there is a (valid) password field
+ password = login_submission["password"]
+ assert isinstance(password, str)
+
canonical_user_id = await self._check_local_password(
- qualified_user_id, password # type: ignore
+ qualified_user_id, password
)
if canonical_user_id:
@@ -922,19 +1068,9 @@ class AuthHandler(BaseHandler):
unsuccessful, `user_id` and `callback` are both `None`.
"""
for provider in self.password_providers:
- if hasattr(provider, "check_3pid_auth"):
- # This function is able to return a deferred that either
- # resolves None, meaning authentication failure, or upon
- # success, to a str (which is the user_id) or a tuple of
- # (user_id, callback_func), where callback_func should be run
- # after we've finished everything else
- result = await provider.check_3pid_auth(medium, address, password)
- if result:
- # Check if the return value is a str or a tuple
- if isinstance(result, str):
- # If it's a str, set callback function to None
- result = (result, None)
- return result
+ result = await provider.check_3pid_auth(medium, address, password)
+ if result:
+ return result
return None, None
@@ -992,16 +1128,11 @@ class AuthHandler(BaseHandler):
# see if any of our auth providers want to know about this
for provider in self.password_providers:
- if hasattr(provider, "on_logged_out"):
- # This might return an awaitable, if it does block the log out
- # until it completes.
- result = provider.on_logged_out(
- user_id=user_info.user_id,
- device_id=user_info.device_id,
- access_token=access_token,
- )
- if inspect.isawaitable(result):
- await result
+ await provider.on_logged_out(
+ user_id=user_info.user_id,
+ device_id=user_info.device_id,
+ access_token=access_token,
+ )
# delete pushers associated with this access token
if user_info.token_id is not None:
@@ -1030,11 +1161,10 @@ class AuthHandler(BaseHandler):
# see if any of our auth providers want to know about this
for provider in self.password_providers:
- if hasattr(provider, "on_logged_out"):
- for token, token_id, device_id in tokens_and_devices:
- await provider.on_logged_out(
- user_id=user_id, device_id=device_id, access_token=token
- )
+ for token, token_id, device_id in tokens_and_devices:
+ await provider.on_logged_out(
+ user_id=user_id, device_id=device_id, access_token=token
+ )
# delete pushers associated with the access tokens
await self.hs.get_pusherpool().remove_pushers_by_access_token(
@@ -1358,3 +1488,127 @@ class MacaroonGenerator:
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
return macaroon
+
+
+class PasswordProvider:
+ """Wrapper for a password auth provider module
+
+ This class abstracts out all of the backwards-compatibility hacks for
+ password providers, to provide a consistent interface.
+ """
+
+ @classmethod
+ def load(cls, module, config, module_api: ModuleApi) -> "PasswordProvider":
+ try:
+ pp = module(config=config, account_handler=module_api)
+ except Exception as e:
+ logger.error("Error while initializing %r: %s", module, e)
+ raise
+ return cls(pp, module_api)
+
+ def __init__(self, pp, module_api: ModuleApi):
+ self._pp = pp
+ self._module_api = module_api
+
+ self._supported_login_types = {}
+
+ # grandfather in check_password support
+ if hasattr(self._pp, "check_password"):
+ self._supported_login_types[LoginType.PASSWORD] = ("password",)
+
+ g = getattr(self._pp, "get_supported_login_types", None)
+ if g:
+ self._supported_login_types.update(g())
+
+ def __str__(self):
+ return str(self._pp)
+
+ def get_supported_login_types(self) -> Mapping[str, Iterable[str]]:
+ """Get the login types supported by this password provider
+
+ Returns a map from a login type identifier (such as m.login.password) to an
+ iterable giving the fields which must be provided by the user in the submission
+ to the /login API.
+
+ This wrapper adds m.login.password to the list if the underlying password
+ provider supports the check_password() api.
+ """
+ return self._supported_login_types
+
+ async def check_auth(
+ self, username: str, login_type: str, login_dict: JsonDict
+ ) -> Optional[Tuple[str, Optional[Callable]]]:
+ """Check if the user has presented valid login credentials
+
+ This wrapper also calls check_password() if the underlying password provider
+ supports the check_password() api and the login type is m.login.password.
+
+ Args:
+ username: user id presented by the client. Either an MXID or an unqualified
+ username.
+
+ login_type: the login type being attempted - one of the types returned by
+ get_supported_login_types()
+
+ login_dict: the dictionary of login secrets passed by the client.
+
+ Returns: (user_id, callback) where `user_id` is the fully-qualified mxid of the
+ user, and `callback` is an optional callback which will be called with the
+ result from the /login call (including access_token, device_id, etc.)
+ """
+ # first grandfather in a call to check_password
+ if login_type == LoginType.PASSWORD:
+ g = getattr(self._pp, "check_password", None)
+ if g:
+ qualified_user_id = self._module_api.get_qualified_user_id(username)
+ is_valid = await self._pp.check_password(
+ qualified_user_id, login_dict["password"]
+ )
+ if is_valid:
+ return qualified_user_id, None
+
+ g = getattr(self._pp, "check_auth", None)
+ if not g:
+ return None
+ result = await g(username, login_type, login_dict)
+
+ # Check if the return value is a str or a tuple
+ if isinstance(result, str):
+ # If it's a str, set callback function to None
+ return result, None
+
+ return result
+
+ async def check_3pid_auth(
+ self, medium: str, address: str, password: str
+ ) -> Optional[Tuple[str, Optional[Callable]]]:
+ g = getattr(self._pp, "check_3pid_auth", None)
+ if not g:
+ return None
+
+ # This function is able to return a deferred that either
+ # resolves None, meaning authentication failure, or upon
+ # success, to a str (which is the user_id) or a tuple of
+ # (user_id, callback_func), where callback_func should be run
+ # after we've finished everything else
+ result = await g(medium, address, password)
+
+ # Check if the return value is a str or a tuple
+ if isinstance(result, str):
+ # If it's a str, set callback function to None
+ return result, None
+
+ return result
+
+ async def on_logged_out(
+ self, user_id: str, device_id: Optional[str], access_token: str
+ ) -> None:
+ g = getattr(self._pp, "on_logged_out", None)
+ if not g:
+ return
+
+ # This might return an awaitable, if it does block the log out
+ # until it completes.
+ result = g(user_id=user_id, device_id=device_id, access_token=access_token,)
+ if inspect.isawaitable(result):
+ await result
diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py
index 048a3b3c0b..f4ea0a9767 100644
--- a/synapse/handlers/cas_handler.py
+++ b/synapse/handlers/cas_handler.py
@@ -14,7 +14,7 @@
# limitations under the License.
import logging
import urllib
-from typing import Dict, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, Optional, Tuple
from xml.etree import ElementTree as ET
from twisted.web.client import PartialDownloadError
@@ -23,6 +23,9 @@ from synapse.api.errors import Codes, LoginError
from synapse.http.site import SynapseRequest
from synapse.types import UserID, map_username_to_mxid_localpart
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
logger = logging.getLogger(__name__)
@@ -31,10 +34,10 @@ class CasHandler:
Utility class for to handle the response from a CAS SSO service.
Args:
- hs (synapse.server.HomeServer)
+ hs
"""
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self._hostname = hs.hostname
self._auth_handler = hs.get_auth_handler()
@@ -200,27 +203,57 @@ class CasHandler:
args["session"] = session
username, user_display_name = await self._validate_ticket(ticket, args)
- localpart = map_username_to_mxid_localpart(username)
- user_id = UserID(localpart, self._hostname).to_string()
- registered_user_id = await self._auth_handler.check_user_exists(user_id)
+ # Pull out the user-agent and IP from the request.
+ user_agent = request.get_user_agent("")
+ ip_address = self.hs.get_ip_from_request(request)
+
+ # Get the matrix ID from the CAS username.
+ user_id = await self._map_cas_user_to_matrix_user(
+ username, user_display_name, user_agent, ip_address
+ )
if session:
await self._auth_handler.complete_sso_ui_auth(
- registered_user_id, session, request,
+ user_id, session, request,
)
-
else:
- if not registered_user_id:
- # Pull out the user-agent and IP from the request.
- user_agent = request.get_user_agent("")
- ip_address = self.hs.get_ip_from_request(request)
-
- registered_user_id = await self._registration_handler.register_user(
- localpart=localpart,
- default_display_name=user_display_name,
- user_agent_ips=(user_agent, ip_address),
- )
+ # If this not a UI auth request than there must be a redirect URL.
+ assert client_redirect_url
await self._auth_handler.complete_sso_login(
- registered_user_id, request, client_redirect_url
+ user_id, request, client_redirect_url
)
+
+ async def _map_cas_user_to_matrix_user(
+ self,
+ remote_user_id: str,
+ display_name: Optional[str],
+ user_agent: str,
+ ip_address: str,
+ ) -> str:
+ """
+ Given a CAS username, retrieve the user ID for it and possibly register the user.
+
+ Args:
+ remote_user_id: The username from the CAS response.
+ display_name: The display name from the CAS response.
+ user_agent: The user agent of the client making the request.
+ ip_address: The IP address of the client making the request.
+
+ Returns:
+ The user ID associated with this response.
+ """
+
+ localpart = map_username_to_mxid_localpart(remote_user_id)
+ user_id = UserID(localpart, self._hostname).to_string()
+ registered_user_id = await self._auth_handler.check_user_exists(user_id)
+
+ # If the user does not exist, register it.
+ if not registered_user_id:
+ registered_user_id = await self._registration_handler.register_user(
+ localpart=localpart,
+ default_display_name=display_name,
+ user_agent_ips=[(user_agent, ip_address)],
+ )
+
+ return registered_user_id
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index 4efe6c530a..366787ca33 100644
--- a/synapse/handlers/deactivate_account.py
+++ b/synapse/handlers/deactivate_account.py
@@ -38,7 +38,9 @@ class DeactivateAccountHandler(BaseHandler):
self._device_handler = hs.get_device_handler()
self._room_member_handler = hs.get_room_member_handler()
self._identity_handler = hs.get_identity_handler()
+ self._profile_handler = hs.get_profile_handler()
self.user_directory_handler = hs.get_user_directory_handler()
+ self._server_name = hs.hostname
# Flag that indicates whether the process to part users from rooms is running
self._user_parter_running = False
@@ -48,7 +50,7 @@ class DeactivateAccountHandler(BaseHandler):
if hs.config.run_background_tasks:
hs.get_reactor().callWhenRunning(self._start_user_parting)
- self._account_validity_enabled = hs.config.account_validity.enabled
+ self._account_validity_enabled = hs.config.account_validity_enabled
async def deactivate_account(
self, user_id: str, erase_data: bool, id_server: Optional[str] = None
@@ -111,6 +113,9 @@ class DeactivateAccountHandler(BaseHandler):
await self.store.user_set_password_hash(user_id, None)
+ user = UserID.from_string(user_id)
+ await self._profile_handler.set_active([user], False, False)
+
# Add the user to a table of users pending deactivation (ie.
# removal from all the rooms they're a member of)
await self.store.add_user_pending_deactivation(user_id)
@@ -152,7 +157,7 @@ class DeactivateAccountHandler(BaseHandler):
for room in pending_invites:
try:
await self._room_member_handler.update_membership(
- create_requester(user),
+ create_requester(user, authenticated_entity=self._server_name),
user,
room.room_id,
"leave",
@@ -208,7 +213,7 @@ class DeactivateAccountHandler(BaseHandler):
logger.info("User parter parting %r from %r", user_id, room_id)
try:
await self._room_member_handler.update_membership(
- create_requester(user),
+ create_requester(user, authenticated_entity=self._server_name),
user,
room_id,
"leave",
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index e03caea251..f27dd84213 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -68,7 +68,7 @@ from synapse.replication.http.devices import ReplicationUserDevicesResyncRestSer
from synapse.replication.http.federation import (
ReplicationCleanRoomRestServlet,
ReplicationFederationSendEventsRestServlet,
- ReplicationStoreRoomOnInviteRestServlet,
+ ReplicationStoreRoomOnOutlierMembershipRestServlet,
)
from synapse.state import StateResolutionStore
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
@@ -153,12 +153,14 @@ class FederationHandler(BaseHandler):
self._user_device_resync = ReplicationUserDevicesResyncRestServlet.make_client(
hs
)
- self._maybe_store_room_on_invite = ReplicationStoreRoomOnInviteRestServlet.make_client(
+ self._maybe_store_room_on_outlier_membership = ReplicationStoreRoomOnOutlierMembershipRestServlet.make_client(
hs
)
else:
self._device_list_updater = hs.get_device_handler().device_list_updater
- self._maybe_store_room_on_invite = self.store.maybe_store_room_on_invite
+ self._maybe_store_room_on_outlier_membership = (
+ self.store.maybe_store_room_on_outlier_membership
+ )
# When joining a room we need to queue any events for that room up.
# For each room, a list of (pdu, origin) tuples.
@@ -184,7 +186,7 @@ class FederationHandler(BaseHandler):
room_id = pdu.room_id
event_id = pdu.event_id
- logger.info("handling received PDU: %s", pdu)
+ logger.info("[%s %s] handling received PDU: %s", room_id, event_id, pdu)
# We reprocess pdus when we have seen them only as outliers
existing = await self.store.get_event(
@@ -299,6 +301,14 @@ class FederationHandler(BaseHandler):
room_id,
event_id,
)
+ elif missing_prevs:
+ logger.info(
+ "[%s %s] Not recursively fetching %d missing prev_events: %s",
+ room_id,
+ event_id,
+ len(missing_prevs),
+ shortstr(missing_prevs),
+ )
if prevs - seen:
# We've still not been able to get all of the prev_events for this event.
@@ -343,12 +353,6 @@ class FederationHandler(BaseHandler):
affected=pdu.event_id,
)
- logger.info(
- "Event %s is missing prev_events: calculating state for a "
- "backwards extremity",
- event_id,
- )
-
# Calculate the state after each of the previous events, and
# resolve them to find the correct state at the current event.
event_map = {event_id: pdu}
@@ -366,7 +370,10 @@ class FederationHandler(BaseHandler):
# know about
for p in prevs - seen:
logger.info(
- "Requesting state at missing prev_event %s", event_id,
+ "[%s %s] Requesting state at missing prev_event %s",
+ room_id,
+ event_id,
+ p,
)
with nested_logging_context(p):
@@ -400,9 +407,7 @@ class FederationHandler(BaseHandler):
# First though we need to fetch all the events that are in
# state_map, so we can build up the state below.
evs = await self.store.get_events(
- list(state_map.values()),
- get_prev_content=False,
- redact_behaviour=EventRedactBehaviour.AS_IS,
+ list(state_map.values()), get_prev_content=False,
)
event_map.update(evs)
@@ -1591,8 +1596,15 @@ class FederationHandler(BaseHandler):
if self.hs.config.block_non_admin_invites:
raise SynapseError(403, "This server does not accept room invites")
+ is_published = await self.store.is_room_published(event.room_id)
+
if not self.spam_checker.user_may_invite(
- event.sender, event.state_key, event.room_id
+ event.sender,
+ event.state_key,
+ None,
+ room_id=event.room_id,
+ new_room=False,
+ published_room=is_published,
):
raise SynapseError(
403, "This user is not permitted to send invites to this server/user"
@@ -1618,7 +1630,7 @@ class FederationHandler(BaseHandler):
# keep a record of the room version, if we don't yet know it.
# (this may get overwritten if we later get a different room version in a
# join dance).
- await self._maybe_store_room_on_invite(
+ await self._maybe_store_room_on_outlier_membership(
room_id=event.room_id, room_version=room_version
)
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index bc3e9607ca..b8a6b1d491 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2017 Vector Creations Ltd
-# Copyright 2018 New Vector Ltd
+# Copyright 2018, 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -22,9 +22,11 @@ import urllib.parse
from typing import Awaitable, Callable, Dict, List, Optional, Tuple
from synapse.api.errors import (
+ AuthError,
CodeMessageException,
Codes,
HttpResponseException,
+ ProxiedRequestError,
SynapseError,
)
from synapse.config.emailconfig import ThreepidBehaviour
@@ -39,31 +41,36 @@ from ._base import BaseHandler
logger = logging.getLogger(__name__)
-id_server_scheme = "https://"
-
class IdentityHandler(BaseHandler):
def __init__(self, hs):
super().__init__(hs)
- self.http_client = SimpleHttpClient(hs)
+ self.hs = hs
+ self.http_client = hs.get_simple_http_client()
# We create a blacklisting instance of SimpleHttpClient for contacting identity
# servers specified by clients
self.blacklisting_http_client = SimpleHttpClient(
hs, ip_blacklist=hs.config.federation_ip_range_blacklist
)
self.federation_http_client = hs.get_http_client()
- self.hs = hs
+
+ self.trusted_id_servers = set(hs.config.trusted_third_party_id_servers)
+ self.trust_any_id_server_just_for_testing_do_not_use = (
+ hs.config.use_insecure_ssl_client_just_for_testing_do_not_use
+ )
+ self.rewrite_identity_server_urls = hs.config.rewrite_identity_server_urls
+ self._enable_lookup = hs.config.enable_3pid_lookup
async def threepid_from_creds(
- self, id_server: str, creds: Dict[str, str]
+ self, id_server_url: str, creds: Dict[str, str]
) -> Optional[JsonDict]:
"""
Retrieve and validate a threepid identifier from a "credentials" dictionary against a
given identity server
Args:
- id_server: The identity server to validate 3PIDs against. Must be a
+ id_server_url: The identity server to validate 3PIDs against. Must be a
complete URL including the protocol (http(s)://)
creds: Dictionary containing the following keys:
* client_secret|clientSecret: A unique secret str provided by the client
@@ -88,7 +95,14 @@ class IdentityHandler(BaseHandler):
query_params = {"sid": session_id, "client_secret": client_secret}
- url = id_server + "/_matrix/identity/api/v1/3pid/getValidated3pid"
+ # if we have a rewrite rule set for the identity server,
+ # apply it now.
+ id_server_url = self.rewrite_id_server_url(id_server_url)
+
+ url = "%s%s" % (
+ id_server_url,
+ "/_matrix/identity/api/v1/3pid/getValidated3pid",
+ )
try:
data = await self.http_client.get_json(url, query_params)
@@ -97,7 +111,7 @@ class IdentityHandler(BaseHandler):
except HttpResponseException as e:
logger.info(
"%s returned %i for threepid validation for: %s",
- id_server,
+ id_server_url,
e.code,
creds,
)
@@ -111,7 +125,7 @@ class IdentityHandler(BaseHandler):
if "medium" in data:
return data
- logger.info("%s reported non-validated threepid: %s", id_server, creds)
+ logger.info("%s reported non-validated threepid: %s", id_server_url, creds)
return None
async def bind_threepid(
@@ -143,14 +157,19 @@ class IdentityHandler(BaseHandler):
if id_access_token is None:
use_v2 = False
+ # if we have a rewrite rule set for the identity server,
+ # apply it now, but only for sending the request (not
+ # storing in the database).
+ id_server_url = self.rewrite_id_server_url(id_server, add_https=True)
+
# Decide which API endpoint URLs to use
headers = {}
bind_data = {"sid": sid, "client_secret": client_secret, "mxid": mxid}
if use_v2:
- bind_url = "https://%s/_matrix/identity/v2/3pid/bind" % (id_server,)
+ bind_url = "%s/_matrix/identity/v2/3pid/bind" % (id_server_url,)
headers["Authorization"] = create_id_access_token_header(id_access_token) # type: ignore
else:
- bind_url = "https://%s/_matrix/identity/api/v1/3pid/bind" % (id_server,)
+ bind_url = "%s/_matrix/identity/api/v1/3pid/bind" % (id_server_url,)
try:
# Use the blacklisting http client as this call is only to identity servers
@@ -237,9 +256,6 @@ class IdentityHandler(BaseHandler):
True on success, otherwise False if the identity
server doesn't support unbinding
"""
- url = "https://%s/_matrix/identity/api/v1/3pid/unbind" % (id_server,)
- url_bytes = "/_matrix/identity/api/v1/3pid/unbind".encode("ascii")
-
content = {
"mxid": mxid,
"threepid": {"medium": threepid["medium"], "address": threepid["address"]},
@@ -248,6 +264,7 @@ class IdentityHandler(BaseHandler):
# we abuse the federation http client to sign the request, but we have to send it
# using the normal http client since we don't want the SRV lookup and want normal
# 'browser-like' HTTPS.
+ url_bytes = "/_matrix/identity/api/v1/3pid/unbind".encode("ascii")
auth_headers = self.federation_http_client.build_auth_headers(
destination=None,
method=b"POST",
@@ -257,6 +274,15 @@ class IdentityHandler(BaseHandler):
)
headers = {b"Authorization": auth_headers}
+ # if we have a rewrite rule set for the identity server,
+ # apply it now.
+ #
+ # Note that destination_is has to be the real id_server, not
+ # the server we connect to.
+ id_server_url = self.rewrite_id_server_url(id_server, add_https=True)
+
+ url = "%s/_matrix/identity/api/v1/3pid/unbind" % (id_server_url,)
+
try:
# Use the blacklisting http client as this call is only to identity servers
# provided by a client
@@ -354,7 +380,8 @@ class IdentityHandler(BaseHandler):
raise SynapseError(500, "An error was encountered when sending the email")
token_expires = (
- self.hs.clock.time_msec() + self.hs.config.email_validation_token_lifetime
+ self.hs.get_clock().time_msec()
+ + self.hs.config.email_validation_token_lifetime
)
await self.store.start_or_continue_validation_session(
@@ -370,9 +397,28 @@ class IdentityHandler(BaseHandler):
return session_id
+ def rewrite_id_server_url(self, url: str, add_https=False) -> str:
+ """Given an identity server URL, optionally add a protocol scheme
+ before rewriting it according to the rewrite_identity_server_urls
+ config option
+
+ Adds https:// to the URL if specified, then tries to rewrite the
+ url. Returns either the rewritten URL or the URL with optional
+ protocol scheme additions.
+ """
+ rewritten_url = url
+ if add_https:
+ rewritten_url = "https://" + rewritten_url
+
+ rewritten_url = self.rewrite_identity_server_urls.get(
+ rewritten_url, rewritten_url
+ )
+ logger.debug("Rewriting identity server rule from %s to %s", url, rewritten_url)
+ return rewritten_url
+
async def requestEmailToken(
self,
- id_server: str,
+ id_server_url: str,
email: str,
client_secret: str,
send_attempt: int,
@@ -383,7 +429,7 @@ class IdentityHandler(BaseHandler):
validation.
Args:
- id_server: The identity server to proxy to
+ id_server_url: The identity server to proxy to
email: The email to send the message to
client_secret: The unique client_secret sends by the user
send_attempt: Which attempt this is
@@ -397,6 +443,11 @@ class IdentityHandler(BaseHandler):
"client_secret": client_secret,
"send_attempt": send_attempt,
}
+
+ # if we have a rewrite rule set for the identity server,
+ # apply it now.
+ id_server_url = self.rewrite_id_server_url(id_server_url)
+
if next_link:
params["next_link"] = next_link
@@ -411,7 +462,8 @@ class IdentityHandler(BaseHandler):
try:
data = await self.http_client.post_json_get_json(
- id_server + "/_matrix/identity/api/v1/validate/email/requestToken",
+ "%s/_matrix/identity/api/v1/validate/email/requestToken"
+ % (id_server_url,),
params,
)
return data
@@ -423,7 +475,7 @@ class IdentityHandler(BaseHandler):
async def requestMsisdnToken(
self,
- id_server: str,
+ id_server_url: str,
country: str,
phone_number: str,
client_secret: str,
@@ -434,7 +486,7 @@ class IdentityHandler(BaseHandler):
Request an external server send an SMS message on our behalf for the purposes of
threepid validation.
Args:
- id_server: The identity server to proxy to
+ id_server_url: The identity server to proxy to
country: The country code of the phone number
phone_number: The number to send the message to
client_secret: The unique client_secret sends by the user
@@ -462,9 +514,13 @@ class IdentityHandler(BaseHandler):
"details and update your config file."
)
+ # if we have a rewrite rule set for the identity server,
+ # apply it now.
+ id_server_url = self.rewrite_id_server_url(id_server_url)
try:
data = await self.http_client.post_json_get_json(
- id_server + "/_matrix/identity/api/v1/validate/msisdn/requestToken",
+ "%s/_matrix/identity/api/v1/validate/msisdn/requestToken"
+ % (id_server_url,),
params,
)
except HttpResponseException as e:
@@ -558,6 +614,86 @@ class IdentityHandler(BaseHandler):
logger.warning("Error contacting msisdn account_threepid_delegate: %s", e)
raise SynapseError(400, "Error contacting the identity server")
+ # TODO: The following two methods are used for proxying IS requests using
+ # the CS API. They should be consolidated with those in RoomMemberHandler
+ # https://github.com/matrix-org/synapse-dinsic/issues/25
+
+ async def proxy_lookup_3pid(
+ self, id_server: str, medium: str, address: str
+ ) -> JsonDict:
+ """Looks up a 3pid in the passed identity server.
+
+ Args:
+ id_server: The server name (including port, if required)
+ of the identity server to use.
+ medium: The type of the third party identifier (e.g. "email").
+ address: The third party identifier (e.g. "foo@example.com").
+
+ Returns:
+ The result of the lookup. See
+ https://matrix.org/docs/spec/identity_service/r0.1.0.html#association-lookup
+ for details
+ """
+ if not self._enable_lookup:
+ raise AuthError(
+ 403, "Looking up third-party identifiers is denied from this server"
+ )
+
+ id_server_url = self.rewrite_id_server_url(id_server, add_https=True)
+
+ try:
+ data = await self.http_client.get_json(
+ "%s/_matrix/identity/api/v1/lookup" % (id_server_url,),
+ {"medium": medium, "address": address},
+ )
+
+ except HttpResponseException as e:
+ logger.info("Proxied lookup failed: %r", e)
+ raise e.to_synapse_error()
+ except IOError as e:
+ logger.info("Failed to contact %s: %s", id_server, e)
+ raise ProxiedRequestError(503, "Failed to contact identity server")
+
+ return data
+
+ async def proxy_bulk_lookup_3pid(
+ self, id_server: str, threepids: List[List[str]]
+ ) -> JsonDict:
+ """Looks up given 3pids in the passed identity server.
+
+ Args:
+ id_server: The server name (including port, if required)
+ of the identity server to use.
+ threepids: The third party identifiers to lookup, as
+ a list of 2-string sized lists ([medium, address]).
+
+ Returns:
+ The result of the lookup. See
+ https://matrix.org/docs/spec/identity_service/r0.1.0.html#association-lookup
+ for details
+ """
+ if not self._enable_lookup:
+ raise AuthError(
+ 403, "Looking up third-party identifiers is denied from this server"
+ )
+
+ id_server_url = self.rewrite_id_server_url(id_server, add_https=True)
+
+ try:
+ data = await self.http_client.post_json_get_json(
+ "%s/_matrix/identity/api/v1/bulk_lookup" % (id_server_url,),
+ {"threepids": threepids},
+ )
+
+ except HttpResponseException as e:
+ logger.info("Proxied lookup failed: %r", e)
+ raise e.to_synapse_error()
+ except IOError as e:
+ logger.info("Failed to contact %s: %s", id_server, e)
+ raise ProxiedRequestError(503, "Failed to contact identity server")
+
+ return data
+
async def lookup_3pid(
self,
id_server: str,
@@ -578,10 +714,13 @@ class IdentityHandler(BaseHandler):
Returns:
the matrix ID of the 3pid, or None if it is not recognized.
"""
+ # Rewrite id_server URL if necessary
+ id_server_url = self.rewrite_id_server_url(id_server, add_https=True)
+
if id_access_token is not None:
try:
results = await self._lookup_3pid_v2(
- id_server, id_access_token, medium, address
+ id_server_url, id_access_token, medium, address
)
return results
@@ -599,16 +738,17 @@ class IdentityHandler(BaseHandler):
logger.warning("Error when looking up hashing details: %s", e)
return None
- return await self._lookup_3pid_v1(id_server, medium, address)
+ return await self._lookup_3pid_v1(id_server, id_server_url, medium, address)
async def _lookup_3pid_v1(
- self, id_server: str, medium: str, address: str
+ self, id_server: str, id_server_url: str, medium: str, address: str
) -> Optional[str]:
"""Looks up a 3pid in the passed identity server using v1 lookup.
Args:
id_server: The server name (including port, if required)
of the identity server to use.
+ id_server_url: The actual, reachable domain of the id server
medium: The type of the third party identifier (e.g. "email").
address: The third party identifier (e.g. "foo@example.com").
@@ -616,8 +756,8 @@ class IdentityHandler(BaseHandler):
the matrix ID of the 3pid, or None if it is not recognized.
"""
try:
- data = await self.blacklisting_http_client.get_json(
- "%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server),
+ data = await self.http_client.get_json(
+ "%s/_matrix/identity/api/v1/lookup" % (id_server_url,),
{"medium": medium, "address": address},
)
@@ -634,13 +774,12 @@ class IdentityHandler(BaseHandler):
return None
async def _lookup_3pid_v2(
- self, id_server: str, id_access_token: str, medium: str, address: str
+ self, id_server_url: str, id_access_token: str, medium: str, address: str
) -> Optional[str]:
"""Looks up a 3pid in the passed identity server using v2 lookup.
Args:
- id_server: The server name (including port, if required)
- of the identity server to use.
+ id_server_url: The protocol scheme and domain of the id server
id_access_token: The access token to authenticate to the identity server with
medium: The type of the third party identifier (e.g. "email").
address: The third party identifier (e.g. "foo@example.com").
@@ -650,8 +789,8 @@ class IdentityHandler(BaseHandler):
"""
# Check what hashing details are supported by this identity server
try:
- hash_details = await self.blacklisting_http_client.get_json(
- "%s%s/_matrix/identity/v2/hash_details" % (id_server_scheme, id_server),
+ hash_details = await self.http_client.get_json(
+ "%s/_matrix/identity/v2/hash_details" % (id_server_url,),
{"access_token": id_access_token},
)
except RequestTimedOutError:
@@ -659,15 +798,14 @@ class IdentityHandler(BaseHandler):
if not isinstance(hash_details, dict):
logger.warning(
- "Got non-dict object when checking hash details of %s%s: %s",
- id_server_scheme,
- id_server,
+ "Got non-dict object when checking hash details of %s: %s",
+ id_server_url,
hash_details,
)
raise SynapseError(
400,
- "Non-dict object from %s%s during v2 hash_details request: %s"
- % (id_server_scheme, id_server, hash_details),
+ "Non-dict object from %s during v2 hash_details request: %s"
+ % (id_server_url, hash_details),
)
# Extract information from hash_details
@@ -681,8 +819,8 @@ class IdentityHandler(BaseHandler):
):
raise SynapseError(
400,
- "Invalid hash details received from identity server %s%s: %s"
- % (id_server_scheme, id_server, hash_details),
+ "Invalid hash details received from identity server %s: %s"
+ % (id_server_url, hash_details),
)
# Check if any of the supported lookup algorithms are present
@@ -704,7 +842,7 @@ class IdentityHandler(BaseHandler):
else:
logger.warning(
"None of the provided lookup algorithms of %s are supported: %s",
- id_server,
+ id_server_url,
supported_lookup_algorithms,
)
raise SynapseError(
@@ -717,8 +855,8 @@ class IdentityHandler(BaseHandler):
headers = {"Authorization": create_id_access_token_header(id_access_token)}
try:
- lookup_results = await self.blacklisting_http_client.post_json_get_json(
- "%s%s/_matrix/identity/v2/lookup" % (id_server_scheme, id_server),
+ lookup_results = await self.http_client.post_json_get_json(
+ "%s/_matrix/identity/v2/lookup" % (id_server_url,),
{
"addresses": [lookup_value],
"algorithm": lookup_algorithm,
@@ -803,15 +941,17 @@ class IdentityHandler(BaseHandler):
"sender_avatar_url": inviter_avatar_url,
}
+ # Rewrite the identity server URL if necessary
+ id_server_url = self.rewrite_id_server_url(id_server, add_https=True)
+
# Add the identity service access token to the JSON body and use the v2
# Identity Service endpoints if id_access_token is present
data = None
- base_url = "%s%s/_matrix/identity" % (id_server_scheme, id_server)
+ base_url = "%s/_matrix/identity" % (id_server_url,)
if id_access_token:
- key_validity_url = "%s%s/_matrix/identity/v2/pubkey/isvalid" % (
- id_server_scheme,
- id_server,
+ key_validity_url = "%s/_matrix/identity/v2/pubkey/isvalid" % (
+ id_server_url,
)
# Attempt a v2 lookup
@@ -830,9 +970,8 @@ class IdentityHandler(BaseHandler):
raise e
if data is None:
- key_validity_url = "%s%s/_matrix/identity/api/v1/pubkey/isvalid" % (
- id_server_scheme,
- id_server,
+ key_validity_url = "%s/_matrix/identity/api/v1/pubkey/isvalid" % (
+ id_server_url,
)
url = base_url + "/api/v1/store-invite"
@@ -844,10 +983,7 @@ class IdentityHandler(BaseHandler):
raise SynapseError(500, "Timed out contacting identity server")
except HttpResponseException as e:
logger.warning(
- "Error trying to call /store-invite on %s%s: %s",
- id_server_scheme,
- id_server,
- e,
+ "Error trying to call /store-invite on %s: %s", id_server_url, e,
)
if data is None:
@@ -860,10 +996,9 @@ class IdentityHandler(BaseHandler):
)
except HttpResponseException as e:
logger.warning(
- "Error calling /store-invite on %s%s with fallback "
+ "Error calling /store-invite on %s with fallback "
"encoding: %s",
- id_server_scheme,
- id_server,
+ id_server_url,
e,
)
raise e
@@ -884,6 +1019,42 @@ class IdentityHandler(BaseHandler):
display_name = data["display_name"]
return token, public_keys, fallback_public_key, display_name
+ async def bind_email_using_internal_sydent_api(
+ self, id_server_url: str, email: str, user_id: str,
+ ):
+ """Bind an email to a fully qualified user ID using the internal API of an
+ instance of Sydent.
+
+ Args:
+ id_server_url: The URL of the Sydent instance
+ email: The email address to bind
+ user_id: The user ID to bind the email to
+
+ Raises:
+ HTTPResponseException: On a non-2xx HTTP response.
+ """
+ # Extract the domain name from the IS URL as we store IS domains instead of URLs
+ id_server = urllib.parse.urlparse(id_server_url).hostname
+ if not id_server:
+ # We were unable to determine the hostname, bail out
+ return
+
+ # id_server_url is assumed to have no trailing slashes
+ url = id_server_url + "/_matrix/identity/internal/bind"
+ body = {
+ "address": email,
+ "medium": "email",
+ "mxid": user_id,
+ }
+
+ # Bind the threepid
+ await self.http_client.post_json_get_json(url, body)
+
+ # Remember where we bound the threepid
+ await self.store.add_user_bound_threepid(
+ user_id=user_id, medium="email", address=email, id_server=id_server,
+ )
+
def create_id_access_token_header(id_access_token: str) -> List[str]:
"""Create an Authorization header for passing to SimpleHttpClient as the header value
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index c6791fb912..96843338ae 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -472,7 +472,7 @@ class EventCreationHandler:
Returns:
Tuple of created event, Context
"""
- await self.auth.check_auth_blocking(requester.user.to_string())
+ await self.auth.check_auth_blocking(requester=requester)
if event_dict["type"] == EventTypes.Create and event_dict["state_key"] == "":
room_version = event_dict["content"]["room_version"]
@@ -619,7 +619,13 @@ class EventCreationHandler:
if requester.app_service is not None:
return
- user_id = requester.user.to_string()
+ user_id = requester.authenticated_entity
+ if not user_id.startswith("@"):
+ # The authenticated entity might not be a user, e.g. if it's the
+ # server puppetting the user.
+ return
+
+ user = UserID.from_string(user_id)
# exempt the system notices user
if (
@@ -639,9 +645,7 @@ class EventCreationHandler:
if u["consent_version"] == self.config.user_consent_version:
return
- consent_uri = self._consent_uri_builder.build_user_consent_uri(
- requester.user.localpart
- )
+ consent_uri = self._consent_uri_builder.build_user_consent_uri(user.localpart)
msg = self._block_events_without_consent_error % {"consent_uri": consent_uri}
raise ConsentNotGivenError(msg=msg, consent_uri=consent_uri)
@@ -1252,7 +1256,7 @@ class EventCreationHandler:
for user_id in members:
if not self.hs.is_mine_id(user_id):
continue
- requester = create_requester(user_id)
+ requester = create_requester(user_id, authenticated_entity=self.server_name)
try:
event, context = await self.create_event(
requester,
@@ -1273,11 +1277,6 @@ class EventCreationHandler:
requester, event, context, ratelimit=False, ignore_shadow_ban=True,
)
return True
- except ConsentNotGivenError:
- logger.info(
- "Failed to send dummy event into room %s for user %s due to "
- "lack of consent. Will try another user" % (room_id, user_id)
- )
except AuthError:
logger.info(
"Failed to send dummy event into room %s for user %s due to "
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index 331d4e7e96..c605f7082a 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import inspect
import logging
from typing import TYPE_CHECKING, Dict, Generic, List, Optional, Tuple, TypeVar
from urllib.parse import urlencode
@@ -34,7 +35,8 @@ from typing_extensions import TypedDict
from twisted.web.client import readBody
from synapse.config import ConfigError
-from synapse.http.server import respond_with_html
+from synapse.handlers._base import BaseHandler
+from synapse.handlers.sso import MappingException, UserAttributes
from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable
from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
@@ -83,17 +85,12 @@ class OidcError(Exception):
return self.error
-class MappingException(Exception):
- """Used to catch errors when mapping the UserInfo object
- """
-
-
-class OidcHandler:
+class OidcHandler(BaseHandler):
"""Handles requests related to the OpenID Connect login flow.
"""
def __init__(self, hs: "HomeServer"):
- self.hs = hs
+ super().__init__(hs)
self._callback_url = hs.config.oidc_callback_url # type: str
self._scopes = hs.config.oidc_scopes # type: List[str]
self._user_profile_method = hs.config.oidc_user_profile_method # type: str
@@ -120,36 +117,13 @@ class OidcHandler:
self._http_client = hs.get_proxied_http_client()
self._auth_handler = hs.get_auth_handler()
self._registration_handler = hs.get_registration_handler()
- self._datastore = hs.get_datastore()
- self._clock = hs.get_clock()
- self._hostname = hs.hostname # type: str
self._server_name = hs.config.server_name # type: str
self._macaroon_secret_key = hs.config.macaroon_secret_key
- self._error_template = hs.config.sso_error_template
# identifier for the external_ids table
self._auth_provider_id = "oidc"
- def _render_error(
- self, request, error: str, error_description: Optional[str] = None
- ) -> None:
- """Render the error template and respond to the request with it.
-
- This is used to show errors to the user. The template of this page can
- be found under `synapse/res/templates/sso_error.html`.
-
- Args:
- request: The incoming request from the browser.
- We'll respond with an HTML page describing the error.
- error: A technical identifier for this error. Those include
- well-known OAuth2/OIDC error types like invalid_request or
- access_denied.
- error_description: A human-readable description of the error.
- """
- html = self._error_template.render(
- error=error, error_description=error_description
- )
- respond_with_html(request, 400, html)
+ self._sso_handler = hs.get_sso_handler()
def _validate_metadata(self):
"""Verifies the provider metadata.
@@ -571,7 +545,7 @@ class OidcHandler:
Since we might want to display OIDC-related errors in a user-friendly
way, we don't raise SynapseError from here. Instead, we call
- ``self._render_error`` which displays an HTML page for the error.
+ ``self._sso_handler.render_error`` which displays an HTML page for the error.
Most of the OpenID Connect logic happens here:
@@ -609,7 +583,7 @@ class OidcHandler:
if error != "access_denied":
logger.error("Error from the OIDC provider: %s %s", error, description)
- self._render_error(request, error, description)
+ self._sso_handler.render_error(request, error, description)
return
# otherwise, it is presumably a successful response. see:
@@ -619,7 +593,9 @@ class OidcHandler:
session = request.getCookie(SESSION_COOKIE_NAME) # type: Optional[bytes]
if session is None:
logger.info("No session cookie found")
- self._render_error(request, "missing_session", "No session cookie found")
+ self._sso_handler.render_error(
+ request, "missing_session", "No session cookie found"
+ )
return
# Remove the cookie. There is a good chance that if the callback failed
@@ -637,7 +613,9 @@ class OidcHandler:
# Check for the state query parameter
if b"state" not in request.args:
logger.info("State parameter is missing")
- self._render_error(request, "invalid_request", "State parameter is missing")
+ self._sso_handler.render_error(
+ request, "invalid_request", "State parameter is missing"
+ )
return
state = request.args[b"state"][0].decode()
@@ -651,17 +629,19 @@ class OidcHandler:
) = self._verify_oidc_session_token(session, state)
except MacaroonDeserializationException as e:
logger.exception("Invalid session")
- self._render_error(request, "invalid_session", str(e))
+ self._sso_handler.render_error(request, "invalid_session", str(e))
return
except MacaroonInvalidSignatureException as e:
logger.exception("Could not verify session")
- self._render_error(request, "mismatching_session", str(e))
+ self._sso_handler.render_error(request, "mismatching_session", str(e))
return
# Exchange the code with the provider
if b"code" not in request.args:
logger.info("Code parameter is missing")
- self._render_error(request, "invalid_request", "Code parameter is missing")
+ self._sso_handler.render_error(
+ request, "invalid_request", "Code parameter is missing"
+ )
return
logger.debug("Exchanging code")
@@ -670,7 +650,7 @@ class OidcHandler:
token = await self._exchange_code(code)
except OidcError as e:
logger.exception("Could not exchange code")
- self._render_error(request, e.error, e.error_description)
+ self._sso_handler.render_error(request, e.error, e.error_description)
return
logger.debug("Successfully obtained OAuth2 access token")
@@ -683,7 +663,7 @@ class OidcHandler:
userinfo = await self._fetch_userinfo(token)
except Exception as e:
logger.exception("Could not fetch userinfo")
- self._render_error(request, "fetch_error", str(e))
+ self._sso_handler.render_error(request, "fetch_error", str(e))
return
else:
logger.debug("Extracting userinfo from id_token")
@@ -691,7 +671,7 @@ class OidcHandler:
userinfo = await self._parse_id_token(token, nonce=nonce)
except Exception as e:
logger.exception("Invalid id_token")
- self._render_error(request, "invalid_token", str(e))
+ self._sso_handler.render_error(request, "invalid_token", str(e))
return
# Pull out the user-agent and IP from the request.
@@ -705,7 +685,7 @@ class OidcHandler:
)
except MappingException as e:
logger.exception("Could not map user")
- self._render_error(request, "mapping_error", str(e))
+ self._sso_handler.render_error(request, "mapping_error", str(e))
return
# Mapping providers might not have get_extra_attributes: only call this
@@ -770,7 +750,7 @@ class OidcHandler:
macaroon.add_first_party_caveat(
"ui_auth_session_id = %s" % (ui_auth_session_id,)
)
- now = self._clock.time_msec()
+ now = self.clock.time_msec()
expiry = now + duration_in_ms
macaroon.add_first_party_caveat("time < %d" % (expiry,))
@@ -845,7 +825,7 @@ class OidcHandler:
if not caveat.startswith(prefix):
return False
expiry = int(caveat[len(prefix) :])
- now = self._clock.time_msec()
+ now = self.clock.time_msec()
return now < expiry
async def _map_userinfo_to_user(
@@ -885,71 +865,77 @@ class OidcHandler:
# to be strings.
remote_user_id = str(remote_user_id)
- logger.info(
- "Looking for existing mapping for user %s:%s",
- self._auth_provider_id,
- remote_user_id,
- )
-
- registered_user_id = await self._datastore.get_user_by_external_id(
- self._auth_provider_id, remote_user_id,
+ # Older mapping providers don't accept the `failures` argument, so we
+ # try and detect support.
+ mapper_signature = inspect.signature(
+ self._user_mapping_provider.map_user_attributes
)
+ supports_failures = "failures" in mapper_signature.parameters
- if registered_user_id is not None:
- logger.info("Found existing mapping %s", registered_user_id)
- return registered_user_id
-
- try:
- attributes = await self._user_mapping_provider.map_user_attributes(
- userinfo, token
- )
- except Exception as e:
- raise MappingException(
- "Could not extract user attributes from OIDC response: " + str(e)
- )
+ async def oidc_response_to_user_attributes(failures: int) -> UserAttributes:
+ """
+ Call the mapping provider to map the OIDC userinfo and token to user attributes.
- logger.debug(
- "Retrieved user attributes from user mapping provider: %r", attributes
- )
+ This is backwards compatibility for abstraction for the SSO handler.
+ """
+ if supports_failures:
+ attributes = await self._user_mapping_provider.map_user_attributes(
+ userinfo, token, failures
+ )
+ else:
+ # If the mapping provider does not support processing failures,
+ # do not continually generate the same Matrix ID since it will
+ # continue to already be in use. Note that the error raised is
+ # arbitrary and will get turned into a MappingException.
+ if failures:
+ raise MappingException(
+ "Mapping provider does not support de-duplicating Matrix IDs"
+ )
- if not attributes["localpart"]:
- raise MappingException("localpart is empty")
+ attributes = await self._user_mapping_provider.map_user_attributes( # type: ignore
+ userinfo, token
+ )
- localpart = map_username_to_mxid_localpart(attributes["localpart"])
+ return UserAttributes(**attributes)
- user_id = UserID(localpart, self._hostname).to_string()
- users = await self._datastore.get_users_by_id_case_insensitive(user_id)
- if users:
+ async def grandfather_existing_users() -> Optional[str]:
if self._allow_existing_users:
- if len(users) == 1:
- registered_user_id = next(iter(users))
- elif user_id in users:
- registered_user_id = user_id
- else:
- raise MappingException(
- "Attempted to login as '{}' but it matches more than one user inexactly: {}".format(
- user_id, list(users.keys())
+ # If allowing existing users we want to generate a single localpart
+ # and attempt to match it.
+ attributes = await oidc_response_to_user_attributes(failures=0)
+
+ user_id = UserID(attributes.localpart, self.server_name).to_string()
+ users = await self.store.get_users_by_id_case_insensitive(user_id)
+ if users:
+ # If an existing matrix ID is returned, then use it.
+ if len(users) == 1:
+ previously_registered_user_id = next(iter(users))
+ elif user_id in users:
+ previously_registered_user_id = user_id
+ else:
+ # Do not attempt to continue generating Matrix IDs.
+ raise MappingException(
+ "Attempted to login as '{}' but it matches more than one user inexactly: {}".format(
+ user_id, users
+ )
)
- )
- else:
- # This mxid is taken
- raise MappingException("mxid '{}' is already taken".format(user_id))
- else:
- # It's the first time this user is logging in and the mapped mxid was
- # not taken, register the user
- registered_user_id = await self._registration_handler.register_user(
- localpart=localpart,
- default_display_name=attributes["display_name"],
- user_agent_ips=(user_agent, ip_address),
- )
- await self._datastore.record_user_external_id(
- self._auth_provider_id, remote_user_id, registered_user_id,
+
+ return previously_registered_user_id
+
+ return None
+
+ return await self._sso_handler.get_mxid_from_sso(
+ self._auth_provider_id,
+ remote_user_id,
+ user_agent,
+ ip_address,
+ oidc_response_to_user_attributes,
+ grandfather_existing_users,
)
- return registered_user_id
-UserAttribute = TypedDict(
- "UserAttribute", {"localpart": str, "display_name": Optional[str]}
+UserAttributeDict = TypedDict(
+ "UserAttributeDict", {"localpart": str, "display_name": Optional[str]}
)
C = TypeVar("C")
@@ -992,13 +978,15 @@ class OidcMappingProvider(Generic[C]):
raise NotImplementedError()
async def map_user_attributes(
- self, userinfo: UserInfo, token: Token
- ) -> UserAttribute:
+ self, userinfo: UserInfo, token: Token, failures: int
+ ) -> UserAttributeDict:
"""Map a `UserInfo` object into user attributes.
Args:
userinfo: An object representing the user given by the OIDC provider
token: A dict with the tokens returned by the provider
+ failures: How many times a call to this function with this
+ UserInfo has resulted in a failure.
Returns:
A dict containing the ``localpart`` and (optionally) the ``display_name``
@@ -1098,10 +1086,17 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
return userinfo[self._config.subject_claim]
async def map_user_attributes(
- self, userinfo: UserInfo, token: Token
- ) -> UserAttribute:
+ self, userinfo: UserInfo, token: Token, failures: int
+ ) -> UserAttributeDict:
localpart = self._config.localpart_template.render(user=userinfo).strip()
+ # Ensure only valid characters are included in the MXID.
+ localpart = map_username_to_mxid_localpart(localpart)
+
+ # Append suffix integer if last call to this function failed to produce
+ # a usable mxid.
+ localpart += str(failures) if failures else ""
+
display_name = None # type: Optional[str]
if self._config.display_name_template is not None:
display_name = self._config.display_name_template.render(
@@ -1111,7 +1106,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
if display_name == "":
display_name = None
- return UserAttribute(localpart=localpart, display_name=display_name)
+ return UserAttributeDict(localpart=localpart, display_name=display_name)
async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict:
extras = {} # type: Dict[str, str]
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 426b58da9e..5372753707 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -299,17 +299,22 @@ class PaginationHandler:
"""
return self._purges_by_id.get(purge_id)
- async def purge_room(self, room_id: str) -> None:
- """Purge the given room from the database"""
+ async def purge_room(self, room_id: str, force: bool = False) -> None:
+ """Purge the given room from the database.
+
+ Args:
+ room_id: room to be purged
+ force: set true to skip checking for joined users.
+ """
with await self.pagination_lock.write(room_id):
# check we know about the room
await self.store.get_room_version_id(room_id)
# first check that we have no users in this room
- joined = await self.store.is_host_joined(room_id, self._server_name)
-
- if joined:
- raise SynapseError(400, "Users are still joined to this room")
+ if not force:
+ joined = await self.store.is_host_joined(room_id, self._server_name)
+ if joined:
+ raise SynapseError(400, "Users are still joined to this room")
await self.storage.purge_events.purge_room(room_id)
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 8e014c9bb5..22d1e9d35c 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -25,7 +25,7 @@ The methods that define policy are:
import abc
import logging
from contextlib import contextmanager
-from typing import Dict, Iterable, List, Set, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, List, Set, Tuple
from prometheus_client import Counter
from typing_extensions import ContextManager
@@ -46,8 +46,7 @@ from synapse.util.caches.descriptors import cached
from synapse.util.metrics import Measure
from synapse.util.wheel_timer import WheelTimer
-MYPY = False
-if MYPY:
+if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 74a1ddd780..40b1c25a36 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,7 +15,11 @@
# limitations under the License.
import logging
import random
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, List, Optional
+
+from signedjson.sign import sign_json
+
+from twisted.internet import reactor
from synapse.api.errors import (
AuthError,
@@ -24,7 +29,11 @@ from synapse.api.errors import (
StoreError,
SynapseError,
)
-from synapse.metrics.background_process_metrics import wrap_as_background_process
+from synapse.logging.context import run_in_background
+from synapse.metrics.background_process_metrics import (
+ run_as_background_process,
+ wrap_as_background_process,
+)
from synapse.types import (
JsonDict,
Requester,
@@ -54,6 +63,8 @@ class ProfileHandler(BaseHandler):
PROFILE_UPDATE_MS = 60 * 1000
PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000
+ PROFILE_REPLICATE_INTERVAL = 2 * 60 * 1000
+
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
@@ -64,11 +75,98 @@ class ProfileHandler(BaseHandler):
self.user_directory_handler = hs.get_user_directory_handler()
+ self.http_client = hs.get_simple_http_client()
+
+ self.max_avatar_size = hs.config.max_avatar_size
+ self.allowed_avatar_mimetypes = hs.config.allowed_avatar_mimetypes
+ self.replicate_user_profiles_to = hs.config.replicate_user_profiles_to
+
if hs.config.run_background_tasks:
self.clock.looping_call(
self._update_remote_profile_cache, self.PROFILE_UPDATE_MS
)
+ if len(self.hs.config.replicate_user_profiles_to) > 0:
+ reactor.callWhenRunning(self._do_assign_profile_replication_batches)
+ reactor.callWhenRunning(self._start_replicate_profiles)
+ # Add a looping call to replicate_profiles: this handles retries
+ # if the replication is unsuccessful when the user updated their
+ # profile.
+ self.clock.looping_call(
+ self._start_replicate_profiles, self.PROFILE_REPLICATE_INTERVAL
+ )
+
+ def _do_assign_profile_replication_batches(self):
+ return run_as_background_process(
+ "_assign_profile_replication_batches",
+ self._assign_profile_replication_batches,
+ )
+
+ def _start_replicate_profiles(self):
+ return run_as_background_process(
+ "_replicate_profiles", self._replicate_profiles
+ )
+
+ async def _assign_profile_replication_batches(self):
+ """If no profile replication has been done yet, allocate replication batch
+ numbers to each profile to start the replication process.
+ """
+ logger.info("Assigning profile batch numbers...")
+ total = 0
+ while True:
+ assigned = await self.store.assign_profile_batch()
+ total += assigned
+ if assigned == 0:
+ break
+ logger.info("Assigned %d profile batch numbers", total)
+
+ async def _replicate_profiles(self):
+ """If any profile data has been updated and not pushed to the replication targets,
+ replicate it.
+ """
+ host_batches = await self.store.get_replication_hosts()
+ latest_batch = await self.store.get_latest_profile_replication_batch_number()
+ if latest_batch is None:
+ latest_batch = -1
+ for repl_host in self.hs.config.replicate_user_profiles_to:
+ if repl_host not in host_batches:
+ host_batches[repl_host] = -1
+ try:
+ for i in range(host_batches[repl_host] + 1, latest_batch + 1):
+ await self._replicate_host_profile_batch(repl_host, i)
+ except Exception:
+ logger.exception(
+ "Exception while replicating to %s: aborting for now", repl_host
+ )
+
+ async def _replicate_host_profile_batch(self, host, batchnum):
+ logger.info("Replicating profile batch %d to %s", batchnum, host)
+ batch_rows = await self.store.get_profile_batch(batchnum)
+ batch = {
+ UserID(r["user_id"], self.hs.hostname).to_string(): (
+ {"display_name": r["displayname"], "avatar_url": r["avatar_url"]}
+ if r["active"]
+ else None
+ )
+ for r in batch_rows
+ }
+
+ url = "https://%s/_matrix/identity/api/v1/replicate_profiles" % (host,)
+ body = {"batchnum": batchnum, "batch": batch, "origin_server": self.hs.hostname}
+ signed_body = sign_json(body, self.hs.hostname, self.hs.config.signing_key[0])
+ try:
+ await self.http_client.post_json_get_json(url, signed_body)
+ await self.store.update_replication_batch_for_host(host, batchnum)
+ logger.info(
+ "Successfully replicated profile batch %d to %s", batchnum, host
+ )
+ except Exception:
+ # This will get retried when the looping call next comes around
+ logger.exception(
+ "Failed to replicate profile batch %d to %s", batchnum, host
+ )
+ raise
+
async def get_profile(self, user_id: str) -> JsonDict:
target_user = UserID.from_string(user_id)
@@ -206,10 +304,20 @@ class ProfileHandler(BaseHandler):
# the join event to update the displayname in the rooms.
# This must be done by the target user himself.
if by_admin:
- requester = create_requester(target_user)
+ requester = create_requester(
+ target_user, authenticated_entity=requester.authenticated_entity,
+ )
+
+ if len(self.hs.config.replicate_user_profiles_to) > 0:
+ cur_batchnum = (
+ await self.store.get_latest_profile_replication_batch_number()
+ )
+ new_batchnum = 0 if cur_batchnum is None else cur_batchnum + 1
+ else:
+ new_batchnum = None
await self.store.set_profile_displayname(
- target_user.localpart, displayname_to_set
+ target_user.localpart, displayname_to_set, new_batchnum
)
if self.hs.config.user_directory_search_all_users:
@@ -220,6 +328,46 @@ class ProfileHandler(BaseHandler):
await self._update_join_states(requester, target_user)
+ # start a profile replication push
+ run_in_background(self._replicate_profiles)
+
+ async def set_active(
+ self, users: List[UserID], active: bool, hide: bool,
+ ):
+ """
+ Sets the 'active' flag on a set of user profiles. If set to false, the
+ accounts are considered deactivated or hidden.
+
+ If 'hide' is true, then we interpret active=False as a request to try to
+ hide the users rather than deactivating them. This means withholding the
+ profiles from replication (and mark it as inactive) rather than clearing
+ the profile from the HS DB.
+
+ Note that unlike set_displayname and set_avatar_url, this does *not*
+ perform authorization checks! This is because the only place it's used
+ currently is in account deactivation where we've already done these
+ checks anyway.
+
+ Args:
+ users: The users to modify
+ active: Whether to set the user to active or inactive
+ hide: Whether to hide the user (withold from replication). If
+ False and active is False, user will have their profile
+ erased
+ """
+ if len(self.replicate_user_profiles_to) > 0:
+ cur_batchnum = (
+ await self.store.get_latest_profile_replication_batch_number()
+ )
+ new_batchnum = 0 if cur_batchnum is None else cur_batchnum + 1
+ else:
+ new_batchnum = None
+
+ await self.store.set_profiles_active(users, active, hide, new_batchnum)
+
+ # start a profile replication push
+ run_in_background(self._replicate_profiles)
+
async def get_avatar_url(self, target_user: UserID) -> Optional[str]:
if self.hs.is_mine(target_user):
try:
@@ -284,11 +432,53 @@ class ProfileHandler(BaseHandler):
400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN,)
)
+ # Enforce a max avatar size if one is defined
+ if self.max_avatar_size or self.allowed_avatar_mimetypes:
+ media_id = self._validate_and_parse_media_id_from_avatar_url(new_avatar_url)
+
+ # Check that this media exists locally
+ media_info = await self.store.get_local_media(media_id)
+ if not media_info:
+ raise SynapseError(
+ 400, "Unknown media id supplied", errcode=Codes.NOT_FOUND
+ )
+
+ # Ensure avatar does not exceed max allowed avatar size
+ media_size = media_info["media_length"]
+ if self.max_avatar_size and media_size > self.max_avatar_size:
+ raise SynapseError(
+ 400,
+ "Avatars must be less than %s bytes in size"
+ % (self.max_avatar_size,),
+ errcode=Codes.TOO_LARGE,
+ )
+
+ # Ensure the avatar's file type is allowed
+ if (
+ self.allowed_avatar_mimetypes
+ and media_info["media_type"] not in self.allowed_avatar_mimetypes
+ ):
+ raise SynapseError(
+ 400, "Avatar file type '%s' not allowed" % media_info["media_type"]
+ )
+
# Same like set_displayname
if by_admin:
- requester = create_requester(target_user)
+ requester = create_requester(
+ target_user, authenticated_entity=requester.authenticated_entity
+ )
+
+ if len(self.hs.config.replicate_user_profiles_to) > 0:
+ cur_batchnum = (
+ await self.store.get_latest_profile_replication_batch_number()
+ )
+ new_batchnum = 0 if cur_batchnum is None else cur_batchnum + 1
+ else:
+ new_batchnum = None
- await self.store.set_profile_avatar_url(target_user.localpart, new_avatar_url)
+ await self.store.set_profile_avatar_url(
+ target_user.localpart, new_avatar_url, new_batchnum
+ )
if self.hs.config.user_directory_search_all_users:
profile = await self.store.get_profileinfo(target_user.localpart)
@@ -298,6 +488,23 @@ class ProfileHandler(BaseHandler):
await self._update_join_states(requester, target_user)
+ # start a profile replication push
+ run_in_background(self._replicate_profiles)
+
+ def _validate_and_parse_media_id_from_avatar_url(self, mxc):
+ """Validate and parse a provided avatar url and return the local media id
+
+ Args:
+ mxc (str): A mxc URL
+
+ Returns:
+ str: The ID of the media
+ """
+ avatar_pieces = mxc.split("/")
+ if len(avatar_pieces) != 4 or avatar_pieces[0] != "mxc:":
+ raise SynapseError(400, "Invalid avatar URL '%s' supplied" % mxc)
+ return avatar_pieces[-1]
+
async def on_profile_query(self, args: JsonDict) -> JsonDict:
user = UserID.from_string(args["user_id"])
if not self.hs.is_mine(user):
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index c242c409cf..153cbae7b9 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -158,7 +158,8 @@ class ReceiptEventSource:
if from_key == to_key:
return [], to_key
- # We first need to fetch all new receipts
+ # Fetch all read receipts for all rooms, up to a limit of 100. This is ordered
+ # by most recent.
rooms_to_events = await self.store.get_linearized_receipts_for_all_rooms(
from_key=from_key, to_key=to_key
)
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index ed1ff62599..c227c4fe91 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -15,10 +15,12 @@
"""Contains functions for registering clients."""
import logging
+from typing import TYPE_CHECKING, List, Optional, Tuple
from synapse import types
from synapse.api.constants import MAX_USERID_LENGTH, EventTypes, JoinRules, LoginType
from synapse.api.errors import AuthError, Codes, ConsentNotGivenError, SynapseError
+from synapse.appservice import ApplicationService
from synapse.config.server import is_threepid_reserved
from synapse.http.servlet import assert_params_in_dict
from synapse.replication.http.login import RegisterDeviceReplicationServlet
@@ -32,29 +34,31 @@ from synapse.types import RoomAlias, UserID, create_requester
from ._base import BaseHandler
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
logger = logging.getLogger(__name__)
class RegistrationHandler(BaseHandler):
- def __init__(self, hs):
- """
-
- Args:
- hs (synapse.server.HomeServer):
- """
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.hs = hs
self.auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler()
self.profile_handler = hs.get_profile_handler()
self.user_directory_handler = hs.get_user_directory_handler()
+ self.http_client = hs.get_simple_http_client()
self.identity_handler = self.hs.get_identity_handler()
self.ratelimiter = hs.get_registration_ratelimiter()
self.macaroon_gen = hs.get_macaroon_generator()
self._server_notices_mxid = hs.config.server_notices_mxid
+ self._server_name = hs.hostname
self.spam_checker = hs.get_spam_checker()
+ self._show_in_user_directory = self.hs.config.show_users_in_user_directory
+
if hs.config.worker_app:
self._register_client = ReplicationRegisterServlet.make_client(hs)
self._register_device_client = RegisterDeviceReplicationServlet.make_client(
@@ -70,8 +74,21 @@ class RegistrationHandler(BaseHandler):
self.session_lifetime = hs.config.session_lifetime
async def check_username(
- self, localpart, guest_access_token=None, assigned_user_id=None
+ self,
+ localpart: str,
+ guest_access_token: Optional[str] = None,
+ assigned_user_id: Optional[str] = None,
):
+ """
+
+ Args:
+ localpart (str|None): The user's localpart
+ guest_access_token (str|None): A guest's access token
+ assigned_user_id (str|None): An existing User ID for this user if pre-calculated
+
+ Returns:
+ Deferred
+ """
if types.contains_invalid_mxid_characters(localpart):
raise SynapseError(
400,
@@ -114,6 +131,8 @@ class RegistrationHandler(BaseHandler):
raise SynapseError(
400, "User ID already taken.", errcode=Codes.USER_IN_USE
)
+
+ # Retrieve guest user information from provided access token
user_data = await self.auth.get_user_by_access_token(guest_access_token)
if (
not user_data.is_guest
@@ -139,39 +158,45 @@ class RegistrationHandler(BaseHandler):
async def register_user(
self,
- localpart=None,
- password_hash=None,
- guest_access_token=None,
- make_guest=False,
- admin=False,
- threepid=None,
- user_type=None,
- default_display_name=None,
- address=None,
- bind_emails=[],
- by_admin=False,
- user_agent_ips=None,
- ):
+ localpart: Optional[str] = None,
+ password_hash: Optional[str] = None,
+ guest_access_token: Optional[str] = None,
+ make_guest: bool = False,
+ admin: bool = False,
+ threepid: Optional[dict] = None,
+ user_type: Optional[str] = None,
+ default_display_name: Optional[str] = None,
+ address: Optional[str] = None,
+ bind_emails: List[str] = [],
+ by_admin: bool = False,
+ user_agent_ips: Optional[List[Tuple[str, str]]] = None,
+ ) -> str:
"""Registers a new client on the server.
Args:
localpart: The local part of the user ID to register. If None,
one will be generated.
- password_hash (str|None): The hashed password to assign to this user so they can
+ password_hash: The hashed password to assign to this user so they can
login again. This can be None which means they cannot login again
via a password (e.g. the user is an application service user).
- user_type (str|None): type of user. One of the values from
+ guest_access_token: The access token used when this was a guest
+ account.
+ make_guest: True if the the new user should be guest,
+ false to add a regular user account.
+ admin: True if the user should be registered as a server admin.
+ threepid: The threepid used for registering, if any.
+ user_type: type of user. One of the values from
api.constants.UserTypes, or None for a normal user.
- default_display_name (unicode|None): if set, the new user's displayname
+ default_display_name: if set, the new user's displayname
will be set to this. Defaults to 'localpart'.
- address (str|None): the IP address used to perform the registration.
- bind_emails (List[str]): list of emails to bind to this account.
- by_admin (bool): True if this registration is being made via the
+ address: the IP address used to perform the registration.
+ bind_emails: list of emails to bind to this account.
+ by_admin: True if this registration is being made via the
admin api, otherwise False.
- user_agent_ips (List[(str, str)]): Tuples of IP addresses and user-agents used
+ user_agent_ips: Tuples of IP addresses and user-agents used
during the registration process.
Returns:
- str: user_id
+ The registere user_id.
Raises:
SynapseError if there was a problem registering.
"""
@@ -226,6 +251,12 @@ class RegistrationHandler(BaseHandler):
shadow_banned=shadow_banned,
)
+ if default_display_name:
+ requester = create_requester(user)
+ await self.profile_handler.set_displayname(
+ user, requester, default_display_name, by_admin=True
+ )
+
if self.hs.config.user_directory_search_all_users:
profile = await self.store.get_profileinfo(localpart)
await self.user_directory_handler.handle_local_profile_change(
@@ -235,8 +266,10 @@ class RegistrationHandler(BaseHandler):
else:
# autogen a sequential user ID
fail_count = 0
- user = None
- while not user:
+ # If a default display name is not given, generate one.
+ generate_display_name = default_display_name is None
+ # This breaks on successful registration *or* errors after 10 failures.
+ while True:
# Fail after being unable to find a suitable ID a few times
if fail_count > 10:
raise SynapseError(500, "Unable to find a suitable guest user ID")
@@ -245,7 +278,7 @@ class RegistrationHandler(BaseHandler):
user = UserID(localpart, self.hs.hostname)
user_id = user.to_string()
self.check_user_id_not_appservice_exclusive(user_id)
- if default_display_name is None:
+ if generate_display_name:
default_display_name = localpart
try:
await self.register_with_store(
@@ -257,12 +290,15 @@ class RegistrationHandler(BaseHandler):
shadow_banned=shadow_banned,
)
+ requester = create_requester(user)
+ await self.profile_handler.set_displayname(
+ user, requester, default_display_name, by_admin=True
+ )
+
# Successfully registered
break
except SynapseError:
# if user id is taken, just generate another
- user = None
- user_id = None
fail_count += 1
if not self.hs.config.user_consent_at_registration:
@@ -290,11 +326,19 @@ class RegistrationHandler(BaseHandler):
}
# Bind email to new account
- await self._register_email_threepid(user_id, threepid_dict, None)
+ await self.register_email_threepid(user_id, threepid_dict, None)
+
+ # Prevent the new user from showing up in the user directory if the server
+ # mandates it.
+ if not self._show_in_user_directory:
+ await self.store.add_account_data_for_user(
+ user_id, "im.vector.hide_profile", {"hide_profile": True}
+ )
+ await self.profile_handler.set_active([user], False, True)
return user_id
- async def _create_and_join_rooms(self, user_id: str):
+ async def _create_and_join_rooms(self, user_id: str) -> None:
"""
Create the auto-join rooms and join or invite the user to them.
@@ -317,7 +361,8 @@ class RegistrationHandler(BaseHandler):
requires_join = False
if self.hs.config.registration.auto_join_user_id:
fake_requester = create_requester(
- self.hs.config.registration.auto_join_user_id
+ self.hs.config.registration.auto_join_user_id,
+ authenticated_entity=self._server_name,
)
# If the room requires an invite, add the user to the list of invites.
@@ -329,7 +374,9 @@ class RegistrationHandler(BaseHandler):
# being necessary this will occur after the invite was sent.
requires_join = True
else:
- fake_requester = create_requester(user_id)
+ fake_requester = create_requester(
+ user_id, authenticated_entity=self._server_name
+ )
# Choose whether to federate the new room.
if not self.hs.config.registration.autocreate_auto_join_rooms_federated:
@@ -362,7 +409,9 @@ class RegistrationHandler(BaseHandler):
# created it, then ensure the first user joins it.
if requires_join:
await room_member_handler.update_membership(
- requester=create_requester(user_id),
+ requester=create_requester(
+ user_id, authenticated_entity=self._server_name
+ ),
target=UserID.from_string(user_id),
room_id=info["room_id"],
# Since it was just created, there are no remote hosts.
@@ -370,15 +419,10 @@ class RegistrationHandler(BaseHandler):
action="join",
ratelimit=False,
)
-
- except ConsentNotGivenError as e:
- # Technically not necessary to pull out this error though
- # moving away from bare excepts is a good thing to do.
- logger.error("Failed to join new user to %r: %r", r, e)
except Exception as e:
logger.error("Failed to join new user to %r: %r", r, e)
- async def _join_rooms(self, user_id: str):
+ async def _join_rooms(self, user_id: str) -> None:
"""
Join or invite the user to the auto-join rooms.
@@ -424,9 +468,13 @@ class RegistrationHandler(BaseHandler):
# Send the invite, if necessary.
if requires_invite:
+ # If an invite is required, there must be a auto-join user ID.
+ assert self.hs.config.registration.auto_join_user_id
+
await room_member_handler.update_membership(
requester=create_requester(
- self.hs.config.registration.auto_join_user_id
+ self.hs.config.registration.auto_join_user_id,
+ authenticated_entity=self._server_name,
),
target=UserID.from_string(user_id),
room_id=room_id,
@@ -437,7 +485,9 @@ class RegistrationHandler(BaseHandler):
# Send the join.
await room_member_handler.update_membership(
- requester=create_requester(user_id),
+ requester=create_requester(
+ user_id, authenticated_entity=self._server_name
+ ),
target=UserID.from_string(user_id),
room_id=room_id,
remote_room_hosts=remote_room_hosts,
@@ -452,7 +502,7 @@ class RegistrationHandler(BaseHandler):
except Exception as e:
logger.error("Failed to join new user to %r: %r", r, e)
- async def _auto_join_rooms(self, user_id: str):
+ async def _auto_join_rooms(self, user_id: str) -> None:
"""Automatically joins users to auto join rooms - creating the room in the first place
if the user is the first to be created.
@@ -475,16 +525,19 @@ class RegistrationHandler(BaseHandler):
else:
await self._join_rooms(user_id)
- async def post_consent_actions(self, user_id):
+ async def post_consent_actions(self, user_id: str) -> None:
"""A series of registration actions that can only be carried out once consent
has been granted
Args:
- user_id (str): The user to join
+ user_id: The user to join
"""
await self._auto_join_rooms(user_id)
- async def appservice_register(self, user_localpart, as_token):
+ async def appservice_register(
+ self, user_localpart: str, as_token: str, password_hash: str, display_name: str
+ ):
+ # FIXME: this should be factored out and merged with normal register()
user = UserID(user_localpart, self.hs.hostname)
user_id = user.to_string()
service = self.store.get_app_service_by_token(as_token)
@@ -501,15 +554,31 @@ class RegistrationHandler(BaseHandler):
self.check_user_id_not_appservice_exclusive(user_id, allowed_appservice=service)
+ display_name = display_name or user.localpart
+
await self.register_with_store(
user_id=user_id,
- password_hash="",
+ password_hash=password_hash,
appservice_id=service_id,
- create_profile_with_displayname=user.localpart,
+ create_profile_with_displayname=display_name,
)
+
+ requester = create_requester(user)
+ await self.profile_handler.set_displayname(
+ user, requester, display_name, by_admin=True
+ )
+
+ if self.hs.config.user_directory_search_all_users:
+ profile = await self.store.get_profileinfo(user_localpart)
+ await self.user_directory_handler.handle_local_profile_change(
+ user_id, profile
+ )
+
return user_id
- def check_user_id_not_appservice_exclusive(self, user_id, allowed_appservice=None):
+ def check_user_id_not_appservice_exclusive(
+ self, user_id: str, allowed_appservice: Optional[ApplicationService] = None
+ ) -> None:
# don't allow people to register the server notices mxid
if self._server_notices_mxid is not None:
if user_id == self._server_notices_mxid:
@@ -533,12 +602,43 @@ class RegistrationHandler(BaseHandler):
errcode=Codes.EXCLUSIVE,
)
- def check_registration_ratelimit(self, address):
+ async def shadow_register(self, localpart, display_name, auth_result, params):
+ """Invokes the current registration on another server, using
+ shared secret registration, passing in any auth_results from
+ other registration UI auth flows (e.g. validated 3pids)
+ Useful for setting up shadow/backup accounts on a parallel deployment.
+ """
+
+ # TODO: retries
+ shadow_hs_url = self.hs.config.shadow_server.get("hs_url")
+ as_token = self.hs.config.shadow_server.get("as_token")
+
+ await self.http_client.post_json_get_json(
+ "%s/_matrix/client/r0/register?access_token=%s" % (shadow_hs_url, as_token),
+ {
+ # XXX: auth_result is an unspecified extension for shadow registration
+ "auth_result": auth_result,
+ # XXX: another unspecified extension for shadow registration to ensure
+ # that the displayname is correctly set by the masters erver
+ "display_name": display_name,
+ "username": localpart,
+ "password": params.get("password"),
+ "bind_msisdn": params.get("bind_msisdn"),
+ "device_id": params.get("device_id"),
+ "initial_device_display_name": params.get(
+ "initial_device_display_name"
+ ),
+ "inhibit_login": False,
+ "access_token": as_token,
+ },
+ )
+
+ def check_registration_ratelimit(self, address: Optional[str]) -> None:
"""A simple helper method to check whether the registration rate limit has been hit
for a given IP address
Args:
- address (str|None): the IP address used to perform the registration. If this is
+ address: the IP address used to perform the registration. If this is
None, no ratelimiting will be performed.
Raises:
@@ -549,42 +649,39 @@ class RegistrationHandler(BaseHandler):
self.ratelimiter.ratelimit(address)
- def register_with_store(
+ async def register_with_store(
self,
- user_id,
- password_hash=None,
- was_guest=False,
- make_guest=False,
- appservice_id=None,
- create_profile_with_displayname=None,
- admin=False,
- user_type=None,
- address=None,
- shadow_banned=False,
- ):
+ user_id: str,
+ password_hash: Optional[str] = None,
+ was_guest: bool = False,
+ make_guest: bool = False,
+ appservice_id: Optional[str] = None,
+ create_profile_with_displayname: Optional[str] = None,
+ admin: bool = False,
+ user_type: Optional[str] = None,
+ address: Optional[str] = None,
+ shadow_banned: bool = False,
+ ) -> None:
"""Register user in the datastore.
Args:
- user_id (str): The desired user ID to register.
- password_hash (str|None): Optional. The password hash for this user.
- was_guest (bool): Optional. Whether this is a guest account being
+ user_id: The desired user ID to register.
+ password_hash: Optional. The password hash for this user.
+ was_guest: Optional. Whether this is a guest account being
upgraded to a non-guest account.
- make_guest (boolean): True if the the new user should be guest,
+ make_guest: True if the the new user should be guest,
false to add a regular user account.
- appservice_id (str|None): The ID of the appservice registering the user.
- create_profile_with_displayname (unicode|None): Optionally create a
+ appservice_id: The ID of the appservice registering the user.
+ create_profile_with_displayname: Optionally create a
profile for the user, setting their displayname to the given value
- admin (boolean): is an admin user?
- user_type (str|None): type of user. One of the values from
+ admin: is an admin user?
+ user_type: type of user. One of the values from
api.constants.UserTypes, or None for a normal user.
- address (str|None): the IP address used to perform the registration.
- shadow_banned (bool): Whether to shadow-ban the user
-
- Returns:
- Awaitable
+ address: the IP address used to perform the registration.
+ shadow_banned: Whether to shadow-ban the user
"""
if self.hs.config.worker_app:
- return self._register_client(
+ await self._register_client(
user_id=user_id,
password_hash=password_hash,
was_guest=was_guest,
@@ -597,7 +694,7 @@ class RegistrationHandler(BaseHandler):
shadow_banned=shadow_banned,
)
else:
- return self.store.register_user(
+ await self.store.register_user(
user_id=user_id,
password_hash=password_hash,
was_guest=was_guest,
@@ -610,22 +707,24 @@ class RegistrationHandler(BaseHandler):
)
async def register_device(
- self, user_id, device_id, initial_display_name, is_guest=False
- ):
+ self,
+ user_id: str,
+ device_id: Optional[str],
+ initial_display_name: Optional[str],
+ is_guest: bool = False,
+ ) -> Tuple[str, str]:
"""Register a device for a user and generate an access token.
The access token will be limited by the homeserver's session_lifetime config.
Args:
- user_id (str): full canonical @user:id
- device_id (str|None): The device ID to check, or None to generate
- a new one.
- initial_display_name (str|None): An optional display name for the
- device.
- is_guest (bool): Whether this is a guest account
+ user_id: full canonical @user:id
+ device_id: The device ID to check, or None to generate a new one.
+ initial_display_name: An optional display name for the device.
+ is_guest: Whether this is a guest account
Returns:
- tuple[str, str]: Tuple of device ID and access token
+ Tuple of device ID and access token
"""
if self.hs.config.worker_app:
@@ -645,7 +744,7 @@ class RegistrationHandler(BaseHandler):
)
valid_until_ms = self.clock.time_msec() + self.session_lifetime
- device_id = await self.device_handler.check_device_registered(
+ registered_device_id = await self.device_handler.check_device_registered(
user_id, device_id, initial_display_name
)
if is_guest:
@@ -655,20 +754,21 @@ class RegistrationHandler(BaseHandler):
)
else:
access_token = await self._auth_handler.get_access_token_for_user_id(
- user_id, device_id=device_id, valid_until_ms=valid_until_ms
+ user_id, device_id=registered_device_id, valid_until_ms=valid_until_ms
)
- return (device_id, access_token)
+ return (registered_device_id, access_token)
- async def post_registration_actions(self, user_id, auth_result, access_token):
+ async def post_registration_actions(
+ self, user_id: str, auth_result: dict, access_token: Optional[str]
+ ) -> None:
"""A user has completed registration
Args:
- user_id (str): The user ID that consented
- auth_result (dict): The authenticated credentials of the newly
- registered user.
- access_token (str|None): The access token of the newly logged in
- device, or None if `inhibit_login` enabled.
+ user_id: The user ID that consented
+ auth_result: The authenticated credentials of the newly registered user.
+ access_token: The access token of the newly logged in device, or
+ None if `inhibit_login` enabled.
"""
if self.hs.config.worker_app:
await self._post_registration_client(
@@ -678,6 +778,7 @@ class RegistrationHandler(BaseHandler):
if auth_result and LoginType.EMAIL_IDENTITY in auth_result:
threepid = auth_result[LoginType.EMAIL_IDENTITY]
+
# Necessary due to auth checks prior to the threepid being
# written to the db
if is_threepid_reserved(
@@ -685,7 +786,32 @@ class RegistrationHandler(BaseHandler):
):
await self.store.upsert_monthly_active_user(user_id)
- await self._register_email_threepid(user_id, threepid, access_token)
+ await self.register_email_threepid(user_id, threepid, access_token)
+
+ if self.hs.config.bind_new_user_emails_to_sydent:
+ # Attempt to call Sydent's internal bind API on the given identity server
+ # to bind this threepid
+ id_server_url = self.hs.config.bind_new_user_emails_to_sydent
+
+ logger.debug(
+ "Attempting the bind email of %s to identity server: %s using "
+ "internal Sydent bind API.",
+ user_id,
+ self.hs.config.bind_new_user_emails_to_sydent,
+ )
+
+ try:
+ await self.identity_handler.bind_email_using_internal_sydent_api(
+ id_server_url, threepid["address"], user_id
+ )
+ except Exception as e:
+ logger.warning(
+ "Failed to bind email of '%s' to Sydent instance '%s' ",
+ "using Sydent internal bind API: %s",
+ user_id,
+ id_server_url,
+ e,
+ )
if auth_result and LoginType.MSISDN in auth_result:
threepid = auth_result[LoginType.MSISDN]
@@ -694,19 +820,20 @@ class RegistrationHandler(BaseHandler):
if auth_result and LoginType.TERMS in auth_result:
await self._on_user_consented(user_id, self.hs.config.user_consent_version)
- async def _on_user_consented(self, user_id, consent_version):
+ async def _on_user_consented(self, user_id: str, consent_version: str) -> None:
"""A user consented to the terms on registration
Args:
- user_id (str): The user ID that consented.
- consent_version (str): version of the policy the user has
- consented to.
+ user_id: The user ID that consented.
+ consent_version: version of the policy the user has consented to.
"""
logger.info("%s has consented to the privacy policy", user_id)
await self.store.user_set_consent_version(user_id, consent_version)
await self.post_consent_actions(user_id)
- async def _register_email_threepid(self, user_id, threepid, token):
+ async def register_email_threepid(
+ self, user_id: str, threepid: dict, token: Optional[str]
+ ) -> None:
"""Add an email address as a 3pid identifier
Also adds an email pusher for the email address, if configured in the
@@ -715,10 +842,9 @@ class RegistrationHandler(BaseHandler):
Must be called on master.
Args:
- user_id (str): id of user
- threepid (object): m.login.email.identity auth response
- token (str|None): access_token for the user, or None if not logged
- in.
+ user_id: id of user
+ threepid: m.login.email.identity auth response
+ token: access_token for the user, or None if not logged in.
"""
reqd = ("medium", "address", "validated_at")
if any(x not in threepid for x in reqd):
@@ -744,6 +870,8 @@ class RegistrationHandler(BaseHandler):
# up when the access token is saved, but that's quite an
# invasive change I'd rather do separately.
user_tuple = await self.store.get_user_by_access_token(token)
+ # The token better still exist.
+ assert user_tuple
token_id = user_tuple.token_id
await self.pusher_pool.add_pusher(
@@ -758,14 +886,14 @@ class RegistrationHandler(BaseHandler):
data={},
)
- async def _register_msisdn_threepid(self, user_id, threepid):
+ async def _register_msisdn_threepid(self, user_id: str, threepid: dict) -> None:
"""Add a phone number as a 3pid identifier
Must be called on master.
Args:
- user_id (str): id of user
- threepid (object): m.login.msisdn auth response
+ user_id: id of user
+ threepid: m.login.msisdn auth response
"""
try:
assert_params_in_dict(threepid, ["medium", "address", "validated_at"])
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index e73031475f..469f6cf10f 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -358,7 +358,19 @@ class RoomCreationHandler(BaseHandler):
"""
user_id = requester.user.to_string()
- if not self.spam_checker.user_may_create_room(user_id):
+ if (
+ self._server_notices_mxid is not None
+ and requester.user.to_string() == self._server_notices_mxid
+ ):
+ # allow the server notices mxid to create rooms
+ is_requester_admin = True
+
+ else:
+ is_requester_admin = await self.auth.is_server_admin(requester.user)
+
+ if not is_requester_admin and not self.spam_checker.user_may_create_room(
+ user_id, invite_list=[], third_party_invite_list=[], cloning=True
+ ):
raise SynapseError(403, "You are not permitted to create rooms")
creation_content = {
@@ -587,7 +599,7 @@ class RoomCreationHandler(BaseHandler):
"""
user_id = requester.user.to_string()
- await self.auth.check_auth_blocking(user_id)
+ await self.auth.check_auth_blocking(requester=requester)
if (
self._server_notices_mxid is not None
@@ -608,8 +620,14 @@ class RoomCreationHandler(BaseHandler):
403, "You are not permitted to create rooms", Codes.FORBIDDEN
)
+ invite_list = config.get("invite", [])
+ invite_3pid_list = config.get("invite_3pid", [])
+
if not is_requester_admin and not self.spam_checker.user_may_create_room(
- user_id
+ user_id,
+ invite_list=invite_list,
+ third_party_invite_list=invite_3pid_list,
+ cloning=False,
):
raise SynapseError(403, "You are not permitted to create rooms")
@@ -793,6 +811,7 @@ class RoomCreationHandler(BaseHandler):
"invite",
ratelimit=False,
content=content,
+ new_room=True,
)
for invite_3pid in invite_3pid_list:
@@ -810,6 +829,7 @@ class RoomCreationHandler(BaseHandler):
id_server,
requester,
txn_id=None,
+ new_room=True,
id_access_token=id_access_token,
)
@@ -886,6 +906,7 @@ class RoomCreationHandler(BaseHandler):
"join",
ratelimit=False,
content=creator_join_profile,
+ new_room=True,
)
# We treat the power levels override specially as this needs to be one
@@ -1257,7 +1278,9 @@ class RoomShutdownHandler:
400, "User must be our own: %s" % (new_room_user_id,)
)
- room_creator_requester = create_requester(new_room_user_id)
+ room_creator_requester = create_requester(
+ new_room_user_id, authenticated_entity=requester_user_id
+ )
info, stream_id = await self._room_creation_handler.create_room(
room_creator_requester,
@@ -1297,7 +1320,9 @@ class RoomShutdownHandler:
try:
# Kick users from room
- target_requester = create_requester(user_id)
+ target_requester = create_requester(
+ user_id, authenticated_entity=requester_user_id
+ )
_, stream_id = await self.room_member_handler.update_membership(
requester=target_requester,
target=target_requester.user,
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 7cd858b7db..9b15dd9951 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -31,7 +31,6 @@ from synapse.api.errors import (
from synapse.api.ratelimiting import Ratelimiter
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
-from synapse.storage.roommember import RoomsForUser
from synapse.types import JsonDict, Requester, RoomAlias, RoomID, StateMap, UserID
from synapse.util.async_helpers import Linearizer
from synapse.util.distributor import user_left_room
@@ -277,6 +276,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
third_party_signed: Optional[dict] = None,
ratelimit: bool = True,
content: Optional[dict] = None,
+ new_room: bool = False,
require_consent: bool = True,
) -> Tuple[str, int]:
"""Update a user's membership in a room.
@@ -317,6 +317,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
third_party_signed=third_party_signed,
ratelimit=ratelimit,
content=content,
+ new_room=new_room,
require_consent=require_consent,
)
@@ -333,6 +334,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
third_party_signed: Optional[dict] = None,
ratelimit: bool = True,
content: Optional[dict] = None,
+ new_room: bool = False,
require_consent: bool = True,
) -> Tuple[str, int]:
"""Helper for update_membership.
@@ -347,7 +349,15 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# later on.
content = dict(content)
- if not self.allow_per_room_profiles or requester.shadow_banned:
+ # allow the server notices mxid to set room-level profile
+ is_requester_server_notices_user = (
+ self._server_notices_mxid is not None
+ and requester.user.to_string() == self._server_notices_mxid
+ )
+
+ if (
+ not self.allow_per_room_profiles and not is_requester_server_notices_user
+ ) or requester.shadow_banned:
# Strip profile data, knowing that new profile data will be added to the
# event's content in event_creation_handler.create_event() using the target's
# global profile.
@@ -401,8 +411,15 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
)
block_invite = True
+ is_published = await self.store.is_room_published(room_id)
+
if not self.spam_checker.user_may_invite(
- requester.user.to_string(), target.to_string(), room_id
+ requester.user.to_string(),
+ target.to_string(),
+ third_party_invite=None,
+ room_id=room_id,
+ new_room=new_room,
+ published_room=is_published,
):
logger.info("Blocking invite due to spam checker")
block_invite = True
@@ -480,6 +497,25 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# so don't really fit into the general auth process.
raise AuthError(403, "Guest access not allowed")
+ if (
+ self._server_notices_mxid is not None
+ and requester.user.to_string() == self._server_notices_mxid
+ ):
+ # allow the server notices mxid to join rooms
+ is_requester_admin = True
+
+ else:
+ is_requester_admin = await self.auth.is_server_admin(requester.user)
+
+ inviter = await self._get_inviter(target.to_string(), room_id)
+ if not is_requester_admin:
+ # We assume that if the spam checker allowed the user to create
+ # a room then they're allowed to join it.
+ if not new_room and not self.spam_checker.user_may_join_room(
+ target.to_string(), room_id, is_invited=inviter is not None
+ ):
+ raise SynapseError(403, "Not allowed to join this room")
+
if not is_host_in_room:
time_now_s = self.clock.time()
(
@@ -515,10 +551,16 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
elif effective_membership_state == Membership.LEAVE:
if not is_host_in_room:
# perhaps we've been invited
- invite = await self.store.get_invite_for_local_user_in_room(
- user_id=target.to_string(), room_id=room_id
- ) # type: Optional[RoomsForUser]
- if not invite:
+ (
+ current_membership_type,
+ current_membership_event_id,
+ ) = await self.store.get_local_current_membership_for_user_in_room(
+ target.to_string(), room_id
+ )
+ if (
+ current_membership_type != Membership.INVITE
+ or not current_membership_event_id
+ ):
logger.info(
"%s sent a leave request to %s, but that is not an active room "
"on this server, and there is no pending invite",
@@ -528,6 +570,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
raise SynapseError(404, "Not a known room")
+ invite = await self.store.get_event(current_membership_event_id)
logger.info(
"%s rejects invite to %s from %s", target, room_id, invite.sender
)
@@ -766,6 +809,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
id_server: str,
requester: Requester,
txn_id: Optional[str],
+ new_room: bool = False,
id_access_token: Optional[str] = None,
) -> int:
"""Invite a 3PID to a room.
@@ -813,6 +857,16 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
Codes.FORBIDDEN,
)
+ can_invite = await self.third_party_event_rules.check_threepid_can_be_invited(
+ medium, address, room_id
+ )
+ if not can_invite:
+ raise SynapseError(
+ 403,
+ "This third-party identifier can not be invited in this room",
+ Codes.FORBIDDEN,
+ )
+
if not self._enable_lookup:
raise SynapseError(
403, "Looking up third-party identifiers is denied from this server"
@@ -822,6 +876,19 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
id_server, medium, address, id_access_token
)
+ is_published = await self.store.is_room_published(room_id)
+
+ if not self.spam_checker.user_may_invite(
+ requester.user.to_string(),
+ invitee,
+ third_party_invite={"medium": medium, "address": address},
+ room_id=room_id,
+ new_room=new_room,
+ published_room=is_published,
+ ):
+ logger.info("Blocking invite due to spam checker")
+ raise SynapseError(403, "Invites have been disabled on this server")
+
if invitee:
# Note that update_membership with an action of "invite" can raise
# a ShadowBanError, but this was done above already.
@@ -965,6 +1032,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
self.distributor = hs.get_distributor()
self.distributor.declare("user_left_room")
+ self._server_name = hs.hostname
async def _is_remote_room_too_complex(
self, room_id: str, remote_room_hosts: List[str]
@@ -1059,7 +1127,9 @@ class RoomMemberMasterHandler(RoomMemberHandler):
return event_id, stream_id
# The room is too large. Leave.
- requester = types.create_requester(user, None, False, False, None)
+ requester = types.create_requester(
+ user, authenticated_entity=self._server_name
+ )
await self.update_membership(
requester=requester, target=user, room_id=room_id, action="leave"
)
@@ -1104,32 +1174,34 @@ class RoomMemberMasterHandler(RoomMemberHandler):
#
logger.warning("Failed to reject invite: %s", e)
- return await self._locally_reject_invite(
+ return await self._generate_local_out_of_band_leave(
invite_event, txn_id, requester, content
)
- async def _locally_reject_invite(
+ async def _generate_local_out_of_band_leave(
self,
- invite_event: EventBase,
+ previous_membership_event: EventBase,
txn_id: Optional[str],
requester: Requester,
content: JsonDict,
) -> Tuple[str, int]:
- """Generate a local invite rejection
+ """Generate a local leave event for a room
- This is called after we fail to reject an invite via a remote server. It
- generates an out-of-band membership event locally.
+ This can be called after we e.g fail to reject an invite via a remote server.
+ It generates an out-of-band membership event locally.
Args:
- invite_event: the invite to be rejected
+ previous_membership_event: the previous membership event for this user
txn_id: optional transaction ID supplied by the client
- requester: user making the rejection request, according to the access token
- content: additional content to include in the rejection event.
+ requester: user making the request, according to the access token
+ content: additional content to include in the leave event.
Normally an empty dict.
- """
- room_id = invite_event.room_id
- target_user = invite_event.state_key
+ Returns:
+ A tuple containing (event_id, stream_id of the leave event)
+ """
+ room_id = previous_membership_event.room_id
+ target_user = previous_membership_event.state_key
content["membership"] = Membership.LEAVE
@@ -1141,12 +1213,12 @@ class RoomMemberMasterHandler(RoomMemberHandler):
"state_key": target_user,
}
- # the auth events for the new event are the same as that of the invite, plus
- # the invite itself.
+ # the auth events for the new event are the same as that of the previous event, plus
+ # the event itself.
#
- # the prev_events are just the invite.
- prev_event_ids = [invite_event.event_id]
- auth_event_ids = invite_event.auth_event_ids() + prev_event_ids
+ # the prev_events consist solely of the previous membership event.
+ prev_event_ids = [previous_membership_event.event_id]
+ auth_event_ids = previous_membership_event.auth_event_ids() + prev_event_ids
event, context = await self.event_creation_handler.create_event(
requester,
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index fd6c5e9ea8..76d4169fe2 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -24,7 +24,8 @@ from saml2.client import Saml2Client
from synapse.api.errors import SynapseError
from synapse.config import ConfigError
from synapse.config.saml2_config import SamlAttributeRequirement
-from synapse.http.server import respond_with_html
+from synapse.handlers._base import BaseHandler
+from synapse.handlers.sso import MappingException, UserAttributes
from synapse.http.servlet import parse_string
from synapse.http.site import SynapseRequest
from synapse.module_api import ModuleApi
@@ -37,15 +38,11 @@ from synapse.util.async_helpers import Linearizer
from synapse.util.iterutils import chunk_seq
if TYPE_CHECKING:
- import synapse.server
+ from synapse.server import HomeServer
logger = logging.getLogger(__name__)
-class MappingException(Exception):
- """Used to catch errors when mapping the SAML2 response to a user."""
-
-
@attr.s(slots=True)
class Saml2SessionData:
"""Data we track about SAML2 sessions"""
@@ -57,17 +54,14 @@ class Saml2SessionData:
ui_auth_session_id = attr.ib(type=Optional[str], default=None)
-class SamlHandler:
- def __init__(self, hs: "synapse.server.HomeServer"):
- self.hs = hs
+class SamlHandler(BaseHandler):
+ def __init__(self, hs: "HomeServer"):
+ super().__init__(hs)
self._saml_client = Saml2Client(hs.config.saml2_sp_config)
- self._auth = hs.get_auth()
+ self._saml_idp_entityid = hs.config.saml2_idp_entityid
self._auth_handler = hs.get_auth_handler()
self._registration_handler = hs.get_registration_handler()
- self._clock = hs.get_clock()
- self._datastore = hs.get_datastore()
- self._hostname = hs.hostname
self._saml2_session_lifetime = hs.config.saml2_session_lifetime
self._grandfathered_mxid_source_attribute = (
hs.config.saml2_grandfathered_mxid_source_attribute
@@ -88,26 +82,9 @@ class SamlHandler:
self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData]
# a lock on the mappings
- self._mapping_lock = Linearizer(name="saml_mapping", clock=self._clock)
-
- def _render_error(
- self, request, error: str, error_description: Optional[str] = None
- ) -> None:
- """Render the error template and respond to the request with it.
-
- This is used to show errors to the user. The template of this page can
- be found under `synapse/res/templates/sso_error.html`.
+ self._mapping_lock = Linearizer(name="saml_mapping", clock=self.clock)
- Args:
- request: The incoming request from the browser.
- We'll respond with an HTML page describing the error.
- error: A technical identifier for this error.
- error_description: A human-readable description of the error.
- """
- html = self._error_template.render(
- error=error, error_description=error_description
- )
- respond_with_html(request, 400, html)
+ self._sso_handler = hs.get_sso_handler()
def handle_redirect_request(
self, client_redirect_url: bytes, ui_auth_session_id: Optional[str] = None
@@ -124,13 +101,13 @@ class SamlHandler:
URL to redirect to
"""
reqid, info = self._saml_client.prepare_for_authenticate(
- relay_state=client_redirect_url
+ entityid=self._saml_idp_entityid, relay_state=client_redirect_url
)
# Since SAML sessions timeout it is useful to log when they were created.
logger.info("Initiating a new SAML session: %s" % (reqid,))
- now = self._clock.time_msec()
+ now = self.clock.time_msec()
self._outstanding_requests_dict[reqid] = Saml2SessionData(
creation_time=now, ui_auth_session_id=ui_auth_session_id,
)
@@ -171,12 +148,12 @@ class SamlHandler:
# in the (user-visible) exception message, so let's log the exception here
# so we can track down the session IDs later.
logger.warning(str(e))
- self._render_error(
+ self._sso_handler.render_error(
request, "unsolicited_response", "Unexpected SAML2 login."
)
return
except Exception as e:
- self._render_error(
+ self._sso_handler.render_error(
request,
"invalid_response",
"Unable to parse SAML2 response: %s." % (e,),
@@ -184,7 +161,7 @@ class SamlHandler:
return
if saml2_auth.not_signed:
- self._render_error(
+ self._sso_handler.render_error(
request, "unsigned_respond", "SAML2 response was not signed."
)
return
@@ -210,7 +187,7 @@ class SamlHandler:
# attributes.
for requirement in self._saml2_attribute_requirements:
if not _check_attribute_requirement(saml2_auth.ava, requirement):
- self._render_error(
+ self._sso_handler.render_error(
request, "unauthorised", "You are not authorised to log in here."
)
return
@@ -226,7 +203,7 @@ class SamlHandler:
)
except MappingException as e:
logger.exception("Could not map user")
- self._render_error(request, "mapping_error", str(e))
+ self._sso_handler.render_error(request, "mapping_error", str(e))
return
# Complete the interactive auth session or the login.
@@ -272,20 +249,26 @@ class SamlHandler:
"Failed to extract remote user id from SAML response"
)
- with (await self._mapping_lock.queue(self._auth_provider_id)):
- # first of all, check if we already have a mapping for this user
- logger.info(
- "Looking for existing mapping for user %s:%s",
- self._auth_provider_id,
- remote_user_id,
+ async def saml_response_to_remapped_user_attributes(
+ failures: int,
+ ) -> UserAttributes:
+ """
+ Call the mapping provider to map a SAML response to user attributes and coerce the result into the standard form.
+
+ This is backwards compatibility for abstraction for the SSO handler.
+ """
+ # Call the mapping provider.
+ result = self._user_mapping_provider.saml_response_to_user_attributes(
+ saml2_auth, failures, client_redirect_url
)
- registered_user_id = await self._datastore.get_user_by_external_id(
- self._auth_provider_id, remote_user_id
+ # Remap some of the results.
+ return UserAttributes(
+ localpart=result.get("mxid_localpart"),
+ display_name=result.get("displayname"),
+ emails=result.get("emails", []),
)
- if registered_user_id is not None:
- logger.info("Found existing mapping %s", registered_user_id)
- return registered_user_id
+ async def grandfather_existing_users() -> Optional[str]:
# backwards-compatibility hack: see if there is an existing user with a
# suitable mapping from the uid
if (
@@ -294,75 +277,35 @@ class SamlHandler:
):
attrval = saml2_auth.ava[self._grandfathered_mxid_source_attribute][0]
user_id = UserID(
- map_username_to_mxid_localpart(attrval), self._hostname
+ map_username_to_mxid_localpart(attrval), self.server_name
).to_string()
- logger.info(
+
+ logger.debug(
"Looking for existing account based on mapped %s %s",
self._grandfathered_mxid_source_attribute,
user_id,
)
- users = await self._datastore.get_users_by_id_case_insensitive(user_id)
+ users = await self.store.get_users_by_id_case_insensitive(user_id)
if users:
registered_user_id = list(users.keys())[0]
logger.info("Grandfathering mapping to %s", registered_user_id)
- await self._datastore.record_user_external_id(
- self._auth_provider_id, remote_user_id, registered_user_id
- )
return registered_user_id
- # Map saml response to user attributes using the configured mapping provider
- for i in range(1000):
- attribute_dict = self._user_mapping_provider.saml_response_to_user_attributes(
- saml2_auth, i, client_redirect_url=client_redirect_url,
- )
-
- logger.debug(
- "Retrieved SAML attributes from user mapping provider: %s "
- "(attempt %d)",
- attribute_dict,
- i,
- )
-
- localpart = attribute_dict.get("mxid_localpart")
- if not localpart:
- raise MappingException(
- "Error parsing SAML2 response: SAML mapping provider plugin "
- "did not return a mxid_localpart value"
- )
-
- displayname = attribute_dict.get("displayname")
- emails = attribute_dict.get("emails", [])
-
- # Check if this mxid already exists
- if not await self._datastore.get_users_by_id_case_insensitive(
- UserID(localpart, self._hostname).to_string()
- ):
- # This mxid is free
- break
- else:
- # Unable to generate a username in 1000 iterations
- # Break and return error to the user
- raise MappingException(
- "Unable to generate a Matrix ID from the SAML response"
- )
+ return None
- logger.info("Mapped SAML user to local part %s", localpart)
-
- registered_user_id = await self._registration_handler.register_user(
- localpart=localpart,
- default_display_name=displayname,
- bind_emails=emails,
- user_agent_ips=(user_agent, ip_address),
- )
-
- await self._datastore.record_user_external_id(
- self._auth_provider_id, remote_user_id, registered_user_id
+ with (await self._mapping_lock.queue(self._auth_provider_id)):
+ return await self._sso_handler.get_mxid_from_sso(
+ self._auth_provider_id,
+ remote_user_id,
+ user_agent,
+ ip_address,
+ saml_response_to_remapped_user_attributes,
+ grandfather_existing_users,
)
- return registered_user_id
def expire_sessions(self):
- expire_before = self._clock.time_msec() - self._saml2_session_lifetime
+ expire_before = self.clock.time_msec() - self._saml2_session_lifetime
to_expire = set()
for reqid, data in self._outstanding_requests_dict.items():
if data.creation_time < expire_before:
@@ -474,11 +417,11 @@ class DefaultSamlMappingProvider:
)
# Use the configured mapper for this mxid_source
- base_mxid_localpart = self._mxid_mapper(mxid_source)
+ localpart = self._mxid_mapper(mxid_source)
# Append suffix integer if last call to this function failed to produce
- # a usable mxid
- localpart = base_mxid_localpart + (str(failures) if failures else "")
+ # a usable mxid.
+ localpart += str(failures) if failures else ""
# Retrieve the display name from the saml response
# If displayname is None, the mxid_localpart will be used instead
diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py
index a5d67f828f..7713c3cf91 100644
--- a/synapse/handlers/set_password.py
+++ b/synapse/handlers/set_password.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
-# Copyright 2017 New Vector Ltd
+# Copyright 2017-2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
new file mode 100644
index 0000000000..47ad96f97e
--- /dev/null
+++ b/synapse/handlers/sso.py
@@ -0,0 +1,244 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional
+
+import attr
+
+from synapse.api.errors import RedirectException
+from synapse.handlers._base import BaseHandler
+from synapse.http.server import respond_with_html
+from synapse.types import UserID, contains_invalid_mxid_characters
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+class MappingException(Exception):
+ """Used to catch errors when mapping an SSO response to user attributes.
+
+ Note that the msg that is raised is shown to end-users.
+ """
+
+
+@attr.s
+class UserAttributes:
+ localpart = attr.ib(type=str)
+ display_name = attr.ib(type=Optional[str], default=None)
+ emails = attr.ib(type=List[str], default=attr.Factory(list))
+
+
+class SsoHandler(BaseHandler):
+ # The number of attempts to ask the mapping provider for when generating an MXID.
+ _MAP_USERNAME_RETRIES = 1000
+
+ def __init__(self, hs: "HomeServer"):
+ super().__init__(hs)
+ self._registration_handler = hs.get_registration_handler()
+ self._error_template = hs.config.sso_error_template
+
+ def render_error(
+ self, request, error: str, error_description: Optional[str] = None
+ ) -> None:
+ """Renders the error template and responds with it.
+
+ This is used to show errors to the user. The template of this page can
+ be found under `synapse/res/templates/sso_error.html`.
+
+ Args:
+ request: The incoming request from the browser.
+ We'll respond with an HTML page describing the error.
+ error: A technical identifier for this error.
+ error_description: A human-readable description of the error.
+ """
+ html = self._error_template.render(
+ error=error, error_description=error_description
+ )
+ respond_with_html(request, 400, html)
+
+ async def get_sso_user_by_remote_user_id(
+ self, auth_provider_id: str, remote_user_id: str
+ ) -> Optional[str]:
+ """
+ Maps the user ID of a remote IdP to a mxid for a previously seen user.
+
+ If the user has not been seen yet, this will return None.
+
+ Args:
+ auth_provider_id: A unique identifier for this SSO provider, e.g.
+ "oidc" or "saml".
+ remote_user_id: The user ID according to the remote IdP. This might
+ be an e-mail address, a GUID, or some other form. It must be
+ unique and immutable.
+
+ Returns:
+ The mxid of a previously seen user.
+ """
+ logger.debug(
+ "Looking for existing mapping for user %s:%s",
+ auth_provider_id,
+ remote_user_id,
+ )
+
+ # Check if we already have a mapping for this user.
+ previously_registered_user_id = await self.store.get_user_by_external_id(
+ auth_provider_id, remote_user_id,
+ )
+
+ # A match was found, return the user ID.
+ if previously_registered_user_id is not None:
+ logger.info(
+ "Found existing mapping for IdP '%s' and remote_user_id '%s': %s",
+ auth_provider_id,
+ remote_user_id,
+ previously_registered_user_id,
+ )
+ return previously_registered_user_id
+
+ # No match.
+ return None
+
+ async def get_mxid_from_sso(
+ self,
+ auth_provider_id: str,
+ remote_user_id: str,
+ user_agent: str,
+ ip_address: str,
+ sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],
+ grandfather_existing_users: Optional[Callable[[], Awaitable[Optional[str]]]],
+ ) -> str:
+ """
+ Given an SSO ID, retrieve the user ID for it and possibly register the user.
+
+ This first checks if the SSO ID has previously been linked to a matrix ID,
+ if it has that matrix ID is returned regardless of the current mapping
+ logic.
+
+ If a callable is provided for grandfathering users, it is called and can
+ potentially return a matrix ID to use. If it does, the SSO ID is linked to
+ this matrix ID for subsequent calls.
+
+ The mapping function is called (potentially multiple times) to generate
+ a localpart for the user.
+
+ If an unused localpart is generated, the user is registered from the
+ given user-agent and IP address and the SSO ID is linked to this matrix
+ ID for subsequent calls.
+
+ Args:
+ auth_provider_id: A unique identifier for this SSO provider, e.g.
+ "oidc" or "saml".
+ remote_user_id: The unique identifier from the SSO provider.
+ user_agent: The user agent of the client making the request.
+ ip_address: The IP address of the client making the request.
+ sso_to_matrix_id_mapper: A callable to generate the user attributes.
+ The only parameter is an integer which represents the amount of
+ times the returned mxid localpart mapping has failed.
+
+ It is expected that the mapper can raise two exceptions, which
+ will get passed through to the caller:
+
+ MappingException if there was a problem mapping the response
+ to the user.
+ RedirectException to redirect to an additional page (e.g.
+ to prompt the user for more information).
+ grandfather_existing_users: A callable which can return an previously
+ existing matrix ID. The SSO ID is then linked to the returned
+ matrix ID.
+
+ Returns:
+ The user ID associated with the SSO response.
+
+ Raises:
+ MappingException if there was a problem mapping the response to a user.
+ RedirectException: if the mapping provider needs to redirect the user
+ to an additional page. (e.g. to prompt for more information)
+
+ """
+ # first of all, check if we already have a mapping for this user
+ previously_registered_user_id = await self.get_sso_user_by_remote_user_id(
+ auth_provider_id, remote_user_id,
+ )
+ if previously_registered_user_id:
+ return previously_registered_user_id
+
+ # Check for grandfathering of users.
+ if grandfather_existing_users:
+ previously_registered_user_id = await grandfather_existing_users()
+ if previously_registered_user_id:
+ # Future logins should also match this user ID.
+ await self.store.record_user_external_id(
+ auth_provider_id, remote_user_id, previously_registered_user_id
+ )
+ return previously_registered_user_id
+
+ # Otherwise, generate a new user.
+ for i in range(self._MAP_USERNAME_RETRIES):
+ try:
+ attributes = await sso_to_matrix_id_mapper(i)
+ except (RedirectException, MappingException):
+ # Mapping providers are allowed to issue a redirect (e.g. to ask
+ # the user for more information) and can issue a mapping exception
+ # if a name cannot be generated.
+ raise
+ except Exception as e:
+ # Any other exception is unexpected.
+ raise MappingException(
+ "Could not extract user attributes from SSO response."
+ ) from e
+
+ logger.debug(
+ "Retrieved user attributes from user mapping provider: %r (attempt %d)",
+ attributes,
+ i,
+ )
+
+ if not attributes.localpart:
+ raise MappingException(
+ "Error parsing SSO response: SSO mapping provider plugin "
+ "did not return a localpart value"
+ )
+
+ # Check if this mxid already exists
+ user_id = UserID(attributes.localpart, self.server_name).to_string()
+ if not await self.store.get_users_by_id_case_insensitive(user_id):
+ # This mxid is free
+ break
+ else:
+ # Unable to generate a username in 1000 iterations
+ # Break and return error to the user
+ raise MappingException(
+ "Unable to generate a Matrix ID from the SSO response"
+ )
+
+ # Since the localpart is provided via a potentially untrusted module,
+ # ensure the MXID is valid before registering.
+ if contains_invalid_mxid_characters(attributes.localpart):
+ raise MappingException("localpart is invalid: %s" % (attributes.localpart,))
+
+ logger.debug("Mapped SSO user to local part %s", attributes.localpart)
+ registered_user_id = await self._registration_handler.register_user(
+ localpart=attributes.localpart,
+ default_display_name=attributes.display_name,
+ bind_emails=attributes.emails,
+ user_agent_ips=[(user_agent, ip_address)],
+ )
+
+ await self.store.record_user_external_id(
+ auth_provider_id, remote_user_id, registered_user_id
+ )
+ return registered_user_id
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 32e53c2d25..9827c7eb8d 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -31,6 +31,7 @@ from synapse.types import (
Collection,
JsonDict,
MutableStateMap,
+ Requester,
RoomStreamToken,
StateMap,
StreamToken,
@@ -260,6 +261,7 @@ class SyncHandler:
async def wait_for_sync_for_user(
self,
+ requester: Requester,
sync_config: SyncConfig,
since_token: Optional[StreamToken] = None,
timeout: int = 0,
@@ -273,7 +275,7 @@ class SyncHandler:
# not been exceeded (if not part of the group by this point, almost certain
# auth_blocking will occur)
user_id = sync_config.user.to_string()
- await self.auth.check_auth_blocking(user_id)
+ await self.auth.check_auth_blocking(requester=requester)
res = await self.response_cache.wrap(
sync_config.request_key,
diff --git a/synapse/http/client.py b/synapse/http/client.py
index f409368802..e5b13593f2 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -14,9 +14,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-import urllib
+import urllib.parse
from io import BytesIO
from typing import (
+ TYPE_CHECKING,
Any,
BinaryIO,
Dict,
@@ -31,7 +32,7 @@ from typing import (
import treq
from canonicaljson import encode_canonical_json
-from netaddr import IPAddress
+from netaddr import IPAddress, IPSet
from prometheus_client import Counter
from zope.interface import implementer, provider
@@ -39,6 +40,8 @@ from OpenSSL import SSL
from OpenSSL.SSL import VERIFY_NONE
from twisted.internet import defer, error as twisted_error, protocol, ssl
from twisted.internet.interfaces import (
+ IAddress,
+ IHostResolution,
IReactorPluggableNameResolver,
IResolutionReceiver,
)
@@ -53,7 +56,7 @@ from twisted.web.client import (
)
from twisted.web.http import PotentialDataLoss
from twisted.web.http_headers import Headers
-from twisted.web.iweb import IResponse
+from twisted.web.iweb import IAgent, IBodyProducer, IResponse
from synapse.api.errors import Codes, HttpResponseException, SynapseError
from synapse.http import QuieterFileBodyProducer, RequestTimedOutError, redact_uri
@@ -63,6 +66,9 @@ from synapse.logging.opentracing import set_tag, start_active_span, tags
from synapse.util import json_decoder
from synapse.util.async_helpers import timeout_deferred
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
logger = logging.getLogger(__name__)
outgoing_requests_counter = Counter("synapse_http_client_requests", "", ["method"])
@@ -84,12 +90,19 @@ QueryParamValue = Union[str, bytes, Iterable[Union[str, bytes]]]
QueryParams = Union[Mapping[str, QueryParamValue], Mapping[bytes, QueryParamValue]]
-def check_against_blacklist(ip_address, ip_whitelist, ip_blacklist):
+def check_against_blacklist(
+ ip_address: IPAddress, ip_whitelist: Optional[IPSet], ip_blacklist: IPSet
+) -> bool:
"""
+ Compares an IP address to allowed and disallowed IP sets.
+
Args:
- ip_address (netaddr.IPAddress)
- ip_whitelist (netaddr.IPSet)
- ip_blacklist (netaddr.IPSet)
+ ip_address: The IP address to check
+ ip_whitelist: Allowed IP addresses.
+ ip_blacklist: Disallowed IP addresses.
+
+ Returns:
+ True if the IP address is in the blacklist and not in the whitelist.
"""
if ip_address in ip_blacklist:
if ip_whitelist is None or ip_address not in ip_whitelist:
@@ -118,23 +131,30 @@ class IPBlacklistingResolver:
addresses, preventing DNS rebinding attacks on URL preview.
"""
- def __init__(self, reactor, ip_whitelist, ip_blacklist):
+ def __init__(
+ self,
+ reactor: IReactorPluggableNameResolver,
+ ip_whitelist: Optional[IPSet],
+ ip_blacklist: IPSet,
+ ):
"""
Args:
- reactor (twisted.internet.reactor)
- ip_whitelist (netaddr.IPSet)
- ip_blacklist (netaddr.IPSet)
+ reactor: The twisted reactor.
+ ip_whitelist: IP addresses to allow.
+ ip_blacklist: IP addresses to disallow.
"""
self._reactor = reactor
self._ip_whitelist = ip_whitelist
self._ip_blacklist = ip_blacklist
- def resolveHostName(self, recv, hostname, portNumber=0):
+ def resolveHostName(
+ self, recv: IResolutionReceiver, hostname: str, portNumber: int = 0
+ ) -> IResolutionReceiver:
r = recv()
- addresses = []
+ addresses = [] # type: List[IAddress]
- def _callback():
+ def _callback() -> None:
r.resolutionBegan(None)
has_bad_ip = False
@@ -161,15 +181,15 @@ class IPBlacklistingResolver:
@provider(IResolutionReceiver)
class EndpointReceiver:
@staticmethod
- def resolutionBegan(resolutionInProgress):
+ def resolutionBegan(resolutionInProgress: IHostResolution) -> None:
pass
@staticmethod
- def addressResolved(address):
+ def addressResolved(address: IAddress) -> None:
addresses.append(address)
@staticmethod
- def resolutionComplete():
+ def resolutionComplete() -> None:
_callback()
self._reactor.nameResolver.resolveHostName(
@@ -185,19 +205,29 @@ class BlacklistingAgentWrapper(Agent):
directly (without an IP address lookup).
"""
- def __init__(self, agent, reactor, ip_whitelist=None, ip_blacklist=None):
+ def __init__(
+ self,
+ agent: IAgent,
+ ip_whitelist: Optional[IPSet] = None,
+ ip_blacklist: Optional[IPSet] = None,
+ ):
"""
Args:
- agent (twisted.web.client.Agent): The Agent to wrap.
- reactor (twisted.internet.reactor)
- ip_whitelist (netaddr.IPSet)
- ip_blacklist (netaddr.IPSet)
+ agent: The Agent to wrap.
+ ip_whitelist: IP addresses to allow.
+ ip_blacklist: IP addresses to disallow.
"""
self._agent = agent
self._ip_whitelist = ip_whitelist
self._ip_blacklist = ip_blacklist
- def request(self, method, uri, headers=None, bodyProducer=None):
+ def request(
+ self,
+ method: bytes,
+ uri: bytes,
+ headers: Optional[Headers] = None,
+ bodyProducer: Optional[IBodyProducer] = None,
+ ) -> defer.Deferred:
h = urllib.parse.urlparse(uri.decode("ascii"))
try:
@@ -226,23 +256,23 @@ class SimpleHttpClient:
def __init__(
self,
- hs,
- treq_args={},
- ip_whitelist=None,
- ip_blacklist=None,
- http_proxy=None,
- https_proxy=None,
+ hs: "HomeServer",
+ treq_args: Dict[str, Any] = {},
+ ip_whitelist: Optional[IPSet] = None,
+ ip_blacklist: Optional[IPSet] = None,
+ http_proxy: Optional[bytes] = None,
+ https_proxy: Optional[bytes] = None,
):
"""
Args:
- hs (synapse.server.HomeServer)
- treq_args (dict): Extra keyword arguments to be given to treq.request.
- ip_blacklist (netaddr.IPSet): The IP addresses that are blacklisted that
+ hs
+ treq_args: Extra keyword arguments to be given to treq.request.
+ ip_blacklist: The IP addresses that are blacklisted that
we may not request.
- ip_whitelist (netaddr.IPSet): The whitelisted IP addresses, that we can
+ ip_whitelist: The whitelisted IP addresses, that we can
request if it were otherwise caught in a blacklist.
- http_proxy (bytes): proxy server to use for http connections. host[:port]
- https_proxy (bytes): proxy server to use for https connections. host[:port]
+ http_proxy: proxy server to use for http connections. host[:port]
+ https_proxy: proxy server to use for https connections. host[:port]
"""
self.hs = hs
@@ -306,7 +336,6 @@ class SimpleHttpClient:
# by the DNS resolution.
self.agent = BlacklistingAgentWrapper(
self.agent,
- self.reactor,
ip_whitelist=self._ip_whitelist,
ip_blacklist=self._ip_blacklist,
)
@@ -397,7 +426,7 @@ class SimpleHttpClient:
async def post_urlencoded_get_json(
self,
uri: str,
- args: Mapping[str, Union[str, List[str]]] = {},
+ args: Optional[Mapping[str, Union[str, List[str]]]] = None,
headers: Optional[RawHeaders] = None,
) -> Any:
"""
@@ -422,9 +451,7 @@ class SimpleHttpClient:
# TODO: Do we ever want to log message contents?
logger.debug("post_urlencoded_get_json args: %s", args)
- query_bytes = urllib.parse.urlencode(encode_urlencode_args(args), True).encode(
- "utf8"
- )
+ query_bytes = encode_query_args(args)
actual_headers = {
b"Content-Type": [b"application/x-www-form-urlencoded"],
@@ -432,7 +459,7 @@ class SimpleHttpClient:
b"Accept": [b"application/json"],
}
if headers:
- actual_headers.update(headers)
+ actual_headers.update(headers) # type: ignore
response = await self.request(
"POST", uri, headers=Headers(actual_headers), data=query_bytes
@@ -479,7 +506,7 @@ class SimpleHttpClient:
b"Accept": [b"application/json"],
}
if headers:
- actual_headers.update(headers)
+ actual_headers.update(headers) # type: ignore
response = await self.request(
"POST", uri, headers=Headers(actual_headers), data=json_str
@@ -495,7 +522,10 @@ class SimpleHttpClient:
)
async def get_json(
- self, uri: str, args: QueryParams = {}, headers: Optional[RawHeaders] = None,
+ self,
+ uri: str,
+ args: Optional[QueryParams] = None,
+ headers: Optional[RawHeaders] = None,
) -> Any:
"""Gets some json from the given URI.
@@ -516,7 +546,7 @@ class SimpleHttpClient:
"""
actual_headers = {b"Accept": [b"application/json"]}
if headers:
- actual_headers.update(headers)
+ actual_headers.update(headers) # type: ignore
body = await self.get_raw(uri, args, headers=headers)
return json_decoder.decode(body.decode("utf-8"))
@@ -525,7 +555,7 @@ class SimpleHttpClient:
self,
uri: str,
json_body: Any,
- args: QueryParams = {},
+ args: Optional[QueryParams] = None,
headers: RawHeaders = None,
) -> Any:
"""Puts some json to the given URI.
@@ -546,9 +576,9 @@ class SimpleHttpClient:
ValueError: if the response was not JSON
"""
- if len(args):
- query_bytes = urllib.parse.urlencode(args, True)
- uri = "%s?%s" % (uri, query_bytes)
+ if args:
+ query_str = urllib.parse.urlencode(args, True)
+ uri = "%s?%s" % (uri, query_str)
json_str = encode_canonical_json(json_body)
@@ -558,7 +588,7 @@ class SimpleHttpClient:
b"Accept": [b"application/json"],
}
if headers:
- actual_headers.update(headers)
+ actual_headers.update(headers) # type: ignore
response = await self.request(
"PUT", uri, headers=Headers(actual_headers), data=json_str
@@ -574,7 +604,10 @@ class SimpleHttpClient:
)
async def get_raw(
- self, uri: str, args: QueryParams = {}, headers: Optional[RawHeaders] = None
+ self,
+ uri: str,
+ args: Optional[QueryParams] = None,
+ headers: Optional[RawHeaders] = None,
) -> bytes:
"""Gets raw text from the given URI.
@@ -592,13 +625,13 @@ class SimpleHttpClient:
HttpResponseException on a non-2xx HTTP response.
"""
- if len(args):
- query_bytes = urllib.parse.urlencode(args, True)
- uri = "%s?%s" % (uri, query_bytes)
+ if args:
+ query_str = urllib.parse.urlencode(args, True)
+ uri = "%s?%s" % (uri, query_str)
actual_headers = {b"User-Agent": [self.user_agent]}
if headers:
- actual_headers.update(headers)
+ actual_headers.update(headers) # type: ignore
response = await self.request("GET", uri, headers=Headers(actual_headers))
@@ -641,7 +674,7 @@ class SimpleHttpClient:
actual_headers = {b"User-Agent": [self.user_agent]}
if headers:
- actual_headers.update(headers)
+ actual_headers.update(headers) # type: ignore
response = await self.request("GET", url, headers=Headers(actual_headers))
@@ -649,12 +682,13 @@ class SimpleHttpClient:
if (
b"Content-Length" in resp_headers
+ and max_size
and int(resp_headers[b"Content-Length"][0]) > max_size
):
- logger.warning("Requested URL is too large > %r bytes" % (self.max_size,))
+ logger.warning("Requested URL is too large > %r bytes" % (max_size,))
raise SynapseError(
502,
- "Requested file is too large > %r bytes" % (self.max_size,),
+ "Requested file is too large > %r bytes" % (max_size,),
Codes.TOO_LARGE,
)
@@ -668,7 +702,7 @@ class SimpleHttpClient:
try:
length = await make_deferred_yieldable(
- _readBodyToFile(response, output_stream, max_size)
+ readBodyToFile(response, output_stream, max_size)
)
except SynapseError:
# This can happen e.g. because the body is too large.
@@ -696,18 +730,16 @@ def _timeout_to_request_timed_out_error(f: Failure):
return f
-# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
-# The two should be factored out.
-
-
class _ReadBodyToFileProtocol(protocol.Protocol):
- def __init__(self, stream, deferred, max_size):
+ def __init__(
+ self, stream: BinaryIO, deferred: defer.Deferred, max_size: Optional[int]
+ ):
self.stream = stream
self.deferred = deferred
self.length = 0
self.max_size = max_size
- def dataReceived(self, data):
+ def dataReceived(self, data: bytes) -> None:
self.stream.write(data)
self.length += len(data)
if self.max_size is not None and self.length >= self.max_size:
@@ -721,7 +753,7 @@ class _ReadBodyToFileProtocol(protocol.Protocol):
self.deferred = defer.Deferred()
self.transport.loseConnection()
- def connectionLost(self, reason):
+ def connectionLost(self, reason: Failure) -> None:
if reason.check(ResponseDone):
self.deferred.callback(self.length)
elif reason.check(PotentialDataLoss):
@@ -732,35 +764,48 @@ class _ReadBodyToFileProtocol(protocol.Protocol):
self.deferred.errback(reason)
-# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
-# The two should be factored out.
+def readBodyToFile(
+ response: IResponse, stream: BinaryIO, max_size: Optional[int]
+) -> defer.Deferred:
+ """
+ Read a HTTP response body to a file-object. Optionally enforcing a maximum file size.
+ Args:
+ response: The HTTP response to read from.
+ stream: The file-object to write to.
+ max_size: The maximum file size to allow.
+
+ Returns:
+ A Deferred which resolves to the length of the read body.
+ """
-def _readBodyToFile(response, stream, max_size):
d = defer.Deferred()
response.deliverBody(_ReadBodyToFileProtocol(stream, d, max_size))
return d
-def encode_urlencode_args(args):
- return {k: encode_urlencode_arg(v) for k, v in args.items()}
+def encode_query_args(args: Optional[Mapping[str, Union[str, List[str]]]]) -> bytes:
+ """
+ Encodes a map of query arguments to bytes which can be appended to a URL.
+ Args:
+ args: The query arguments, a mapping of string to string or list of strings.
+
+ Returns:
+ The query arguments encoded as bytes.
+ """
+ if args is None:
+ return b""
-def encode_urlencode_arg(arg):
- if isinstance(arg, str):
- return arg.encode("utf-8")
- elif isinstance(arg, list):
- return [encode_urlencode_arg(i) for i in arg]
- else:
- return arg
+ encoded_args = {}
+ for k, vs in args.items():
+ if isinstance(vs, str):
+ vs = [vs]
+ encoded_args[k] = [v.encode("utf8") for v in vs]
+ query_str = urllib.parse.urlencode(encoded_args, True)
-def _print_ex(e):
- if hasattr(e, "reasons") and e.reasons:
- for ex in e.reasons:
- _print_ex(ex)
- else:
- logger.exception(e)
+ return query_str.encode("utf8")
class InsecureInterceptableContextFactory(ssl.ContextFactory):
diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
index 83d6196d4a..e77f9587d0 100644
--- a/synapse/http/federation/matrix_federation_agent.py
+++ b/synapse/http/federation/matrix_federation_agent.py
@@ -12,21 +12,25 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
-import urllib
-from typing import List
+import urllib.parse
+from typing import List, Optional
from netaddr import AddrFormatError, IPAddress
from zope.interface import implementer
from twisted.internet import defer
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
-from twisted.internet.interfaces import IStreamClientEndpoint
-from twisted.web.client import Agent, HTTPConnectionPool
+from twisted.internet.interfaces import (
+ IProtocolFactory,
+ IReactorCore,
+ IStreamClientEndpoint,
+)
+from twisted.web.client import URI, Agent, HTTPConnectionPool
from twisted.web.http_headers import Headers
-from twisted.web.iweb import IAgent, IAgentEndpointFactory
+from twisted.web.iweb import IAgent, IAgentEndpointFactory, IBodyProducer
+from synapse.crypto.context_factory import FederationPolicyForHTTPS
from synapse.http.federation.srv_resolver import Server, SrvResolver
from synapse.http.federation.well_known_resolver import WellKnownResolver
from synapse.logging.context import make_deferred_yieldable, run_in_background
@@ -44,30 +48,30 @@ class MatrixFederationAgent:
Doesn't implement any retries. (Those are done in MatrixFederationHttpClient.)
Args:
- reactor (IReactor): twisted reactor to use for underlying requests
+ reactor: twisted reactor to use for underlying requests
- tls_client_options_factory (FederationPolicyForHTTPS|None):
+ tls_client_options_factory:
factory to use for fetching client tls options, or none to disable TLS.
- user_agent (bytes):
+ user_agent:
The user agent header to use for federation requests.
- _srv_resolver (SrvResolver|None):
- SRVResolver impl to use for looking up SRV records. None to use a default
- implementation.
+ _srv_resolver:
+ SrvResolver implementation to use for looking up SRV records. None
+ to use a default implementation.
- _well_known_resolver (WellKnownResolver|None):
+ _well_known_resolver:
WellKnownResolver to use to perform well-known lookups. None to use a
default implementation.
"""
def __init__(
self,
- reactor,
- tls_client_options_factory,
- user_agent,
- _srv_resolver=None,
- _well_known_resolver=None,
+ reactor: IReactorCore,
+ tls_client_options_factory: Optional[FederationPolicyForHTTPS],
+ user_agent: bytes,
+ _srv_resolver: Optional[SrvResolver] = None,
+ _well_known_resolver: Optional[WellKnownResolver] = None,
):
self._reactor = reactor
self._clock = Clock(reactor)
@@ -99,15 +103,20 @@ class MatrixFederationAgent:
self._well_known_resolver = _well_known_resolver
@defer.inlineCallbacks
- def request(self, method, uri, headers=None, bodyProducer=None):
+ def request(
+ self,
+ method: bytes,
+ uri: bytes,
+ headers: Optional[Headers] = None,
+ bodyProducer: Optional[IBodyProducer] = None,
+ ) -> defer.Deferred:
"""
Args:
- method (bytes): HTTP method: GET/POST/etc
- uri (bytes): Absolute URI to be retrieved
- headers (twisted.web.http_headers.Headers|None):
- HTTP headers to send with the request, or None to
- send no extra headers.
- bodyProducer (twisted.web.iweb.IBodyProducer|None):
+ method: HTTP method: GET/POST/etc
+ uri: Absolute URI to be retrieved
+ headers:
+ HTTP headers to send with the request, or None to send no extra headers.
+ bodyProducer:
An object which can generate bytes to make up the
body of this request (for example, the properly encoded contents of
a file for a file upload). Or None if the request is to have
@@ -123,6 +132,9 @@ class MatrixFederationAgent:
# explicit port.
parsed_uri = urllib.parse.urlparse(uri)
+ # There must be a valid hostname.
+ assert parsed_uri.hostname
+
# If this is a matrix:// URI check if the server has delegated matrix
# traffic using well-known delegation.
#
@@ -179,7 +191,12 @@ class MatrixHostnameEndpointFactory:
"""Factory for MatrixHostnameEndpoint for parsing to an Agent.
"""
- def __init__(self, reactor, tls_client_options_factory, srv_resolver):
+ def __init__(
+ self,
+ reactor: IReactorCore,
+ tls_client_options_factory: Optional[FederationPolicyForHTTPS],
+ srv_resolver: Optional[SrvResolver],
+ ):
self._reactor = reactor
self._tls_client_options_factory = tls_client_options_factory
@@ -203,15 +220,20 @@ class MatrixHostnameEndpoint:
resolution (i.e. via SRV). Does not check for well-known delegation.
Args:
- reactor (IReactor)
- tls_client_options_factory (ClientTLSOptionsFactory|None):
+ reactor: twisted reactor to use for underlying requests
+ tls_client_options_factory:
factory to use for fetching client tls options, or none to disable TLS.
- srv_resolver (SrvResolver): The SRV resolver to use
- parsed_uri (twisted.web.client.URI): The parsed URI that we're wanting
- to connect to.
+ srv_resolver: The SRV resolver to use
+ parsed_uri: The parsed URI that we're wanting to connect to.
"""
- def __init__(self, reactor, tls_client_options_factory, srv_resolver, parsed_uri):
+ def __init__(
+ self,
+ reactor: IReactorCore,
+ tls_client_options_factory: Optional[FederationPolicyForHTTPS],
+ srv_resolver: SrvResolver,
+ parsed_uri: URI,
+ ):
self._reactor = reactor
self._parsed_uri = parsed_uri
@@ -231,13 +253,13 @@ class MatrixHostnameEndpoint:
self._srv_resolver = srv_resolver
- def connect(self, protocol_factory):
+ def connect(self, protocol_factory: IProtocolFactory) -> defer.Deferred:
"""Implements IStreamClientEndpoint interface
"""
return run_in_background(self._do_connect, protocol_factory)
- async def _do_connect(self, protocol_factory):
+ async def _do_connect(self, protocol_factory: IProtocolFactory) -> None:
first_exception = None
server_list = await self._resolve_server()
@@ -303,20 +325,20 @@ class MatrixHostnameEndpoint:
return [Server(host, 8448)]
-def _is_ip_literal(host):
+def _is_ip_literal(host: bytes) -> bool:
"""Test if the given host name is either an IPv4 or IPv6 literal.
Args:
- host (bytes)
+ host: The host name to check
Returns:
- bool
+ True if the hostname is an IP address literal.
"""
- host = host.decode("ascii")
+ host_str = host.decode("ascii")
try:
- IPAddress(host)
+ IPAddress(host_str)
return True
except AddrFormatError:
return False
diff --git a/synapse/http/federation/well_known_resolver.py b/synapse/http/federation/well_known_resolver.py
index 1cc666fbf6..5e08ef1664 100644
--- a/synapse/http/federation/well_known_resolver.py
+++ b/synapse/http/federation/well_known_resolver.py
@@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
import random
import time
@@ -21,10 +20,11 @@ from typing import Callable, Dict, Optional, Tuple
import attr
from twisted.internet import defer
+from twisted.internet.interfaces import IReactorTime
from twisted.web.client import RedirectAgent, readBody
from twisted.web.http import stringToDatetime
from twisted.web.http_headers import Headers
-from twisted.web.iweb import IResponse
+from twisted.web.iweb import IAgent, IResponse
from synapse.logging.context import make_deferred_yieldable
from synapse.util import Clock, json_decoder
@@ -81,11 +81,11 @@ class WellKnownResolver:
def __init__(
self,
- reactor,
- agent,
- user_agent,
- well_known_cache=None,
- had_well_known_cache=None,
+ reactor: IReactorTime,
+ agent: IAgent,
+ user_agent: bytes,
+ well_known_cache: Optional[TTLCache] = None,
+ had_well_known_cache: Optional[TTLCache] = None,
):
self._reactor = reactor
self._clock = Clock(reactor)
@@ -127,7 +127,7 @@ class WellKnownResolver:
with Measure(self._clock, "get_well_known"):
result, cache_period = await self._fetch_well_known(
server_name
- ) # type: Tuple[Optional[bytes], float]
+ ) # type: Optional[bytes], float
except _FetchWellKnownFailure as e:
if prev_result and e.temporary:
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 7e17cdb73e..4e27f93b7a 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -17,8 +17,9 @@ import cgi
import logging
import random
import sys
-import urllib
+import urllib.parse
from io import BytesIO
+from typing import Callable, Dict, List, Optional, Tuple, Union
import attr
import treq
@@ -27,25 +28,27 @@ from prometheus_client import Counter
from signedjson.sign import sign_json
from zope.interface import implementer
-from twisted.internet import defer, protocol
+from twisted.internet import defer
from twisted.internet.error import DNSLookupError
from twisted.internet.interfaces import IReactorPluggableNameResolver, IReactorTime
from twisted.internet.task import _EPSILON, Cooperator
-from twisted.web._newclient import ResponseDone
from twisted.web.http_headers import Headers
-from twisted.web.iweb import IResponse
+from twisted.web.iweb import IBodyProducer, IResponse
import synapse.metrics
import synapse.util.retryutils
from synapse.api.errors import (
- Codes,
FederationDeniedError,
HttpResponseException,
RequestSendFailed,
- SynapseError,
)
from synapse.http import QuieterFileBodyProducer
-from synapse.http.client import BlacklistingAgentWrapper, IPBlacklistingResolver
+from synapse.http.client import (
+ BlacklistingAgentWrapper,
+ IPBlacklistingResolver,
+ encode_query_args,
+ readBodyToFile,
+)
from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
from synapse.logging.context import make_deferred_yieldable
from synapse.logging.opentracing import (
@@ -54,6 +57,7 @@ from synapse.logging.opentracing import (
start_active_span,
tags,
)
+from synapse.types import JsonDict
from synapse.util import json_decoder
from synapse.util.async_helpers import timeout_deferred
from synapse.util.metrics import Measure
@@ -76,47 +80,44 @@ MAXINT = sys.maxsize
_next_id = 1
+QueryArgs = Dict[str, Union[str, List[str]]]
+
+
@attr.s(slots=True, frozen=True)
class MatrixFederationRequest:
- method = attr.ib()
+ method = attr.ib(type=str)
"""HTTP method
- :type: str
"""
- path = attr.ib()
+ path = attr.ib(type=str)
"""HTTP path
- :type: str
"""
- destination = attr.ib()
+ destination = attr.ib(type=str)
"""The remote server to send the HTTP request to.
- :type: str"""
+ """
- json = attr.ib(default=None)
+ json = attr.ib(default=None, type=Optional[JsonDict])
"""JSON to send in the body.
- :type: dict|None
"""
- json_callback = attr.ib(default=None)
+ json_callback = attr.ib(default=None, type=Optional[Callable[[], JsonDict]])
"""A callback to generate the JSON.
- :type: func|None
"""
- query = attr.ib(default=None)
+ query = attr.ib(default=None, type=Optional[dict])
"""Query arguments.
- :type: dict|None
"""
- txn_id = attr.ib(default=None)
+ txn_id = attr.ib(default=None, type=Optional[str])
"""Unique ID for this request (for logging)
- :type: str|None
"""
uri = attr.ib(init=False, type=bytes)
"""The URI of this request
"""
- def __attrs_post_init__(self):
+ def __attrs_post_init__(self) -> None:
global _next_id
txn_id = "%s-O-%s" % (self.method, _next_id)
_next_id = (_next_id + 1) % (MAXINT - 1)
@@ -136,7 +137,7 @@ class MatrixFederationRequest:
)
object.__setattr__(self, "uri", uri)
- def get_json(self):
+ def get_json(self) -> Optional[JsonDict]:
if self.json_callback:
return self.json_callback()
return self.json
@@ -148,7 +149,7 @@ async def _handle_json_response(
request: MatrixFederationRequest,
response: IResponse,
start_ms: int,
-):
+) -> JsonDict:
"""
Reads the JSON body of a response, with a timeout
@@ -160,7 +161,7 @@ async def _handle_json_response(
start_ms: Timestamp when request was made
Returns:
- dict: parsed JSON response
+ The parsed JSON response
"""
try:
check_content_type_is_json(response.headers)
@@ -250,9 +251,7 @@ class MatrixFederationHttpClient:
# Use a BlacklistingAgentWrapper to prevent circumventing the IP
# blacklist via IP literals in server names
self.agent = BlacklistingAgentWrapper(
- self.agent,
- self.reactor,
- ip_blacklist=hs.config.federation_ip_range_blacklist,
+ self.agent, ip_blacklist=hs.config.federation_ip_range_blacklist,
)
self.clock = hs.get_clock()
@@ -266,27 +265,29 @@ class MatrixFederationHttpClient:
self._cooperator = Cooperator(scheduler=schedule)
async def _send_request_with_optional_trailing_slash(
- self, request, try_trailing_slash_on_400=False, **send_request_args
- ):
+ self,
+ request: MatrixFederationRequest,
+ try_trailing_slash_on_400: bool = False,
+ **send_request_args
+ ) -> IResponse:
"""Wrapper for _send_request which can optionally retry the request
upon receiving a combination of a 400 HTTP response code and a
'M_UNRECOGNIZED' errcode. This is a workaround for Synapse <= v0.99.3
due to #3622.
Args:
- request (MatrixFederationRequest): details of request to be sent
- try_trailing_slash_on_400 (bool): Whether on receiving a 400
+ request: details of request to be sent
+ try_trailing_slash_on_400: Whether on receiving a 400
'M_UNRECOGNIZED' from the server to retry the request with a
trailing slash appended to the request path.
- send_request_args (Dict): A dictionary of arguments to pass to
- `_send_request()`.
+ send_request_args: A dictionary of arguments to pass to `_send_request()`.
Raises:
HttpResponseException: If we get an HTTP response code >= 300
(except 429).
Returns:
- Dict: Parsed JSON response body.
+ Parsed JSON response body.
"""
try:
response = await self._send_request(request, **send_request_args)
@@ -313,24 +314,26 @@ class MatrixFederationHttpClient:
async def _send_request(
self,
- request,
- retry_on_dns_fail=True,
- timeout=None,
- long_retries=False,
- ignore_backoff=False,
- backoff_on_404=False,
- ):
+ request: MatrixFederationRequest,
+ retry_on_dns_fail: bool = True,
+ timeout: Optional[int] = None,
+ long_retries: bool = False,
+ ignore_backoff: bool = False,
+ backoff_on_404: bool = False,
+ ) -> IResponse:
"""
Sends a request to the given server.
Args:
- request (MatrixFederationRequest): details of request to be sent
+ request: details of request to be sent
+
+ retry_on_dns_fail: true if the request should be retied on DNS failures
- timeout (int|None): number of milliseconds to wait for the response headers
+ timeout: number of milliseconds to wait for the response headers
(including connecting to the server), *for each attempt*.
60s by default.
- long_retries (bool): whether to use the long retry algorithm.
+ long_retries: whether to use the long retry algorithm.
The regular retry algorithm makes 4 attempts, with intervals
[0.5s, 1s, 2s].
@@ -346,14 +349,13 @@ class MatrixFederationHttpClient:
NB: the long retry algorithm takes over 20 minutes to complete, with
a default timeout of 60s!
- ignore_backoff (bool): true to ignore the historical backoff data
+ ignore_backoff: true to ignore the historical backoff data
and try the request anyway.
- backoff_on_404 (bool): Back off if we get a 404
+ backoff_on_404: Back off if we get a 404
Returns:
- twisted.web.client.Response: resolves with the HTTP
- response object on success.
+ Resolves with the HTTP response object on success.
Raises:
HttpResponseException: If we get an HTTP response code >= 300
@@ -404,7 +406,7 @@ class MatrixFederationHttpClient:
)
# Inject the span into the headers
- headers_dict = {}
+ headers_dict = {} # type: Dict[bytes, List[bytes]]
inject_active_span_byte_dict(headers_dict, request.destination)
headers_dict[b"User-Agent"] = [self.version_string_bytes]
@@ -435,7 +437,7 @@ class MatrixFederationHttpClient:
data = encode_canonical_json(json)
producer = QuieterFileBodyProducer(
BytesIO(data), cooperator=self._cooperator
- )
+ ) # type: Optional[IBodyProducer]
else:
producer = None
auth_headers = self.build_auth_headers(
@@ -524,14 +526,16 @@ class MatrixFederationHttpClient:
)
body = None
- e = HttpResponseException(response.code, response_phrase, body)
+ exc = HttpResponseException(
+ response.code, response_phrase, body
+ )
# Retry if the error is a 429 (Too Many Requests),
# otherwise just raise a standard HttpResponseException
if response.code == 429:
- raise RequestSendFailed(e, can_retry=True) from e
+ raise RequestSendFailed(exc, can_retry=True) from exc
else:
- raise e
+ raise exc
break
except RequestSendFailed as e:
@@ -582,22 +586,27 @@ class MatrixFederationHttpClient:
return response
def build_auth_headers(
- self, destination, method, url_bytes, content=None, destination_is=None
- ):
+ self,
+ destination: Optional[bytes],
+ method: bytes,
+ url_bytes: bytes,
+ content: Optional[JsonDict] = None,
+ destination_is: Optional[bytes] = None,
+ ) -> List[bytes]:
"""
Builds the Authorization headers for a federation request
Args:
- destination (bytes|None): The destination homeserver of the request.
+ destination: The destination homeserver of the request.
May be None if the destination is an identity server, in which case
destination_is must be non-None.
- method (bytes): The HTTP method of the request
- url_bytes (bytes): The URI path of the request
- content (object): The body of the request
- destination_is (bytes): As 'destination', but if the destination is an
+ method: The HTTP method of the request
+ url_bytes: The URI path of the request
+ content: The body of the request
+ destination_is: As 'destination', but if the destination is an
identity server
Returns:
- list[bytes]: a list of headers to be added as "Authorization:" headers
+ A list of headers to be added as "Authorization:" headers
"""
request = {
"method": method.decode("ascii"),
@@ -629,33 +638,32 @@ class MatrixFederationHttpClient:
async def put_json(
self,
- destination,
- path,
- args={},
- data={},
- json_data_callback=None,
- long_retries=False,
- timeout=None,
- ignore_backoff=False,
- backoff_on_404=False,
- try_trailing_slash_on_400=False,
- ):
+ destination: str,
+ path: str,
+ args: Optional[QueryArgs] = None,
+ data: Optional[JsonDict] = None,
+ json_data_callback: Optional[Callable[[], JsonDict]] = None,
+ long_retries: bool = False,
+ timeout: Optional[int] = None,
+ ignore_backoff: bool = False,
+ backoff_on_404: bool = False,
+ try_trailing_slash_on_400: bool = False,
+ ) -> Union[JsonDict, list]:
""" Sends the specified json data using PUT
Args:
- destination (str): The remote server to send the HTTP request
- to.
- path (str): The HTTP path.
- args (dict): query params
- data (dict): A dict containing the data that will be used as
+ destination: The remote server to send the HTTP request to.
+ path: The HTTP path.
+ args: query params
+ data: A dict containing the data that will be used as
the request body. This will be encoded as JSON.
- json_data_callback (callable): A callable returning the dict to
+ json_data_callback: A callable returning the dict to
use as the request body.
- long_retries (bool): whether to use the long retry algorithm. See
+ long_retries: whether to use the long retry algorithm. See
docs on _send_request for details.
- timeout (int|None): number of milliseconds to wait for the response.
+ timeout: number of milliseconds to wait for the response.
self._default_timeout (60s) by default.
Note that we may make several attempts to send the request; this
@@ -663,19 +671,19 @@ class MatrixFederationHttpClient:
*each* attempt (including connection time) as well as the time spent
reading the response body after a 200 response.
- ignore_backoff (bool): true to ignore the historical backoff data
+ ignore_backoff: true to ignore the historical backoff data
and try the request anyway.
- backoff_on_404 (bool): True if we should count a 404 response as
+ backoff_on_404: True if we should count a 404 response as
a failure of the server (and should therefore back off future
requests).
- try_trailing_slash_on_400 (bool): True if on a 400 M_UNRECOGNIZED
+ try_trailing_slash_on_400: True if on a 400 M_UNRECOGNIZED
response we should try appending a trailing slash to the end
of the request. Workaround for #3622 in Synapse <= v0.99.3. This
will be attempted before backing off if backing off has been
enabled.
Returns:
- dict|list: Succeeds when we get a 2xx HTTP response. The
+ Succeeds when we get a 2xx HTTP response. The
result will be the decoded JSON body.
Raises:
@@ -721,29 +729,28 @@ class MatrixFederationHttpClient:
async def post_json(
self,
- destination,
- path,
- data={},
- long_retries=False,
- timeout=None,
- ignore_backoff=False,
- args={},
- ):
+ destination: str,
+ path: str,
+ data: Optional[JsonDict] = None,
+ long_retries: bool = False,
+ timeout: Optional[int] = None,
+ ignore_backoff: bool = False,
+ args: Optional[QueryArgs] = None,
+ ) -> Union[JsonDict, list]:
""" Sends the specified json data using POST
Args:
- destination (str): The remote server to send the HTTP request
- to.
+ destination: The remote server to send the HTTP request to.
- path (str): The HTTP path.
+ path: The HTTP path.
- data (dict): A dict containing the data that will be used as
+ data: A dict containing the data that will be used as
the request body. This will be encoded as JSON.
- long_retries (bool): whether to use the long retry algorithm. See
+ long_retries: whether to use the long retry algorithm. See
docs on _send_request for details.
- timeout (int|None): number of milliseconds to wait for the response.
+ timeout: number of milliseconds to wait for the response.
self._default_timeout (60s) by default.
Note that we may make several attempts to send the request; this
@@ -751,10 +758,10 @@ class MatrixFederationHttpClient:
*each* attempt (including connection time) as well as the time spent
reading the response body after a 200 response.
- ignore_backoff (bool): true to ignore the historical backoff data and
+ ignore_backoff: true to ignore the historical backoff data and
try the request anyway.
- args (dict): query params
+ args: query params
Returns:
dict|list: Succeeds when we get a 2xx HTTP response. The
result will be the decoded JSON body.
@@ -795,26 +802,25 @@ class MatrixFederationHttpClient:
async def get_json(
self,
- destination,
- path,
- args=None,
- retry_on_dns_fail=True,
- timeout=None,
- ignore_backoff=False,
- try_trailing_slash_on_400=False,
- ):
+ destination: str,
+ path: str,
+ args: Optional[QueryArgs] = None,
+ retry_on_dns_fail: bool = True,
+ timeout: Optional[int] = None,
+ ignore_backoff: bool = False,
+ try_trailing_slash_on_400: bool = False,
+ ) -> Union[JsonDict, list]:
""" GETs some json from the given host homeserver and path
Args:
- destination (str): The remote server to send the HTTP request
- to.
+ destination: The remote server to send the HTTP request to.
- path (str): The HTTP path.
+ path: The HTTP path.
- args (dict|None): A dictionary used to create query strings, defaults to
+ args: A dictionary used to create query strings, defaults to
None.
- timeout (int|None): number of milliseconds to wait for the response.
+ timeout: number of milliseconds to wait for the response.
self._default_timeout (60s) by default.
Note that we may make several attempts to send the request; this
@@ -822,14 +828,14 @@ class MatrixFederationHttpClient:
*each* attempt (including connection time) as well as the time spent
reading the response body after a 200 response.
- ignore_backoff (bool): true to ignore the historical backoff data
+ ignore_backoff: true to ignore the historical backoff data
and try the request anyway.
- try_trailing_slash_on_400 (bool): True if on a 400 M_UNRECOGNIZED
+ try_trailing_slash_on_400: True if on a 400 M_UNRECOGNIZED
response we should try appending a trailing slash to the end of
the request. Workaround for #3622 in Synapse <= v0.99.3.
Returns:
- dict|list: Succeeds when we get a 2xx HTTP response. The
+ Succeeds when we get a 2xx HTTP response. The
result will be the decoded JSON body.
Raises:
@@ -870,24 +876,23 @@ class MatrixFederationHttpClient:
async def delete_json(
self,
- destination,
- path,
- long_retries=False,
- timeout=None,
- ignore_backoff=False,
- args={},
- ):
+ destination: str,
+ path: str,
+ long_retries: bool = False,
+ timeout: Optional[int] = None,
+ ignore_backoff: bool = False,
+ args: Optional[QueryArgs] = None,
+ ) -> Union[JsonDict, list]:
"""Send a DELETE request to the remote expecting some json response
Args:
- destination (str): The remote server to send the HTTP request
- to.
- path (str): The HTTP path.
+ destination: The remote server to send the HTTP request to.
+ path: The HTTP path.
- long_retries (bool): whether to use the long retry algorithm. See
+ long_retries: whether to use the long retry algorithm. See
docs on _send_request for details.
- timeout (int|None): number of milliseconds to wait for the response.
+ timeout: number of milliseconds to wait for the response.
self._default_timeout (60s) by default.
Note that we may make several attempts to send the request; this
@@ -895,12 +900,12 @@ class MatrixFederationHttpClient:
*each* attempt (including connection time) as well as the time spent
reading the response body after a 200 response.
- ignore_backoff (bool): true to ignore the historical backoff data and
+ ignore_backoff: true to ignore the historical backoff data and
try the request anyway.
- args (dict): query params
+ args: query params
Returns:
- dict|list: Succeeds when we get a 2xx HTTP response. The
+ Succeeds when we get a 2xx HTTP response. The
result will be the decoded JSON body.
Raises:
@@ -938,25 +943,25 @@ class MatrixFederationHttpClient:
async def get_file(
self,
- destination,
- path,
+ destination: str,
+ path: str,
output_stream,
- args={},
- retry_on_dns_fail=True,
- max_size=None,
- ignore_backoff=False,
- ):
+ args: Optional[QueryArgs] = None,
+ retry_on_dns_fail: bool = True,
+ max_size: Optional[int] = None,
+ ignore_backoff: bool = False,
+ ) -> Tuple[int, Dict[bytes, List[bytes]]]:
"""GETs a file from a given homeserver
Args:
- destination (str): The remote server to send the HTTP request to.
- path (str): The HTTP path to GET.
- output_stream (file): File to write the response body to.
- args (dict): Optional dictionary used to create the query string.
- ignore_backoff (bool): true to ignore the historical backoff data
+ destination: The remote server to send the HTTP request to.
+ path: The HTTP path to GET.
+ output_stream: File to write the response body to.
+ args: Optional dictionary used to create the query string.
+ ignore_backoff: true to ignore the historical backoff data
and try the request anyway.
Returns:
- tuple[int, dict]: Resolves with an (int,dict) tuple of
+ Resolves with an (int,dict) tuple of
the file length and a dict of the response headers.
Raises:
@@ -980,7 +985,7 @@ class MatrixFederationHttpClient:
headers = dict(response.headers.getAllRawHeaders())
try:
- d = _readBodyToFile(response, output_stream, max_size)
+ d = readBodyToFile(response, output_stream, max_size)
d.addTimeout(self.default_timeout, self.reactor)
length = await make_deferred_yieldable(d)
except Exception as e:
@@ -1004,40 +1009,6 @@ class MatrixFederationHttpClient:
return (length, headers)
-class _ReadBodyToFileProtocol(protocol.Protocol):
- def __init__(self, stream, deferred, max_size):
- self.stream = stream
- self.deferred = deferred
- self.length = 0
- self.max_size = max_size
-
- def dataReceived(self, data):
- self.stream.write(data)
- self.length += len(data)
- if self.max_size is not None and self.length >= self.max_size:
- self.deferred.errback(
- SynapseError(
- 502,
- "Requested file is too large > %r bytes" % (self.max_size,),
- Codes.TOO_LARGE,
- )
- )
- self.deferred = defer.Deferred()
- self.transport.loseConnection()
-
- def connectionLost(self, reason):
- if reason.check(ResponseDone):
- self.deferred.callback(self.length)
- else:
- self.deferred.errback(reason)
-
-
-def _readBodyToFile(response, stream, max_size):
- d = defer.Deferred()
- response.deliverBody(_ReadBodyToFileProtocol(stream, d, max_size))
- return d
-
-
def _flatten_response_never_received(e):
if hasattr(e, "reasons"):
reasons = ", ".join(
@@ -1049,13 +1020,13 @@ def _flatten_response_never_received(e):
return repr(e)
-def check_content_type_is_json(headers):
+def check_content_type_is_json(headers: Headers) -> None:
"""
Check that a set of HTTP headers have a Content-Type header, and that it
is application/json.
Args:
- headers (twisted.web.http_headers.Headers): headers to check
+ headers: headers to check
Raises:
RequestSendFailed: if the Content-Type header is missing or isn't JSON
@@ -1078,18 +1049,3 @@ def check_content_type_is_json(headers):
),
can_retry=False,
)
-
-
-def encode_query_args(args):
- if args is None:
- return b""
-
- encoded_args = {}
- for k, vs in args.items():
- if isinstance(vs, str):
- vs = [vs]
- encoded_args[k] = [v.encode("UTF-8") for v in vs]
-
- query_bytes = urllib.parse.urlencode(encoded_args, True)
-
- return query_bytes.encode("utf8")
diff --git a/synapse/http/server.py b/synapse/http/server.py
index c0919f8cb7..6a4e429a6c 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -25,7 +25,7 @@ from io import BytesIO
from typing import Any, Callable, Dict, Iterator, List, Tuple, Union
import jinja2
-from canonicaljson import iterencode_canonical_json, iterencode_pretty_printed_json
+from canonicaljson import iterencode_canonical_json
from zope.interface import implementer
from twisted.internet import defer, interfaces
@@ -94,11 +94,7 @@ def return_json_error(f: failure.Failure, request: SynapseRequest) -> None:
pass
else:
respond_with_json(
- request,
- error_code,
- error_dict,
- send_cors=True,
- pretty_print=_request_user_agent_is_curl(request),
+ request, error_code, error_dict, send_cors=True,
)
@@ -290,7 +286,6 @@ class DirectServeJsonResource(_AsyncResource):
code,
response_object,
send_cors=True,
- pretty_print=_request_user_agent_is_curl(request),
canonical_json=self.canonical_json,
)
@@ -587,7 +582,6 @@ def respond_with_json(
code: int,
json_object: Any,
send_cors: bool = False,
- pretty_print: bool = False,
canonical_json: bool = True,
):
"""Sends encoded JSON in response to the given request.
@@ -598,8 +592,6 @@ def respond_with_json(
json_object: The object to serialize to JSON.
send_cors: Whether to send Cross-Origin Resource Sharing headers
https://fetch.spec.whatwg.org/#http-cors-protocol
- pretty_print: Whether to include indentation and line-breaks in the
- resulting JSON bytes.
canonical_json: Whether to use the canonicaljson algorithm when encoding
the JSON bytes.
@@ -615,13 +607,10 @@ def respond_with_json(
)
return None
- if pretty_print:
- encoder = iterencode_pretty_printed_json
+ if canonical_json:
+ encoder = iterencode_canonical_json
else:
- if canonical_json:
- encoder = iterencode_canonical_json
- else:
- encoder = _encode_json_bytes
+ encoder = _encode_json_bytes
request.setResponseCode(code)
request.setHeader(b"Content-Type", b"application/json")
@@ -685,7 +674,7 @@ def set_cors_headers(request: Request):
)
request.setHeader(
b"Access-Control-Allow-Headers",
- b"Origin, X-Requested-With, Content-Type, Accept, Authorization",
+ b"Origin, X-Requested-With, Content-Type, Accept, Authorization, Date",
)
@@ -759,11 +748,3 @@ def finish_request(request: Request):
request.finish()
except RuntimeError as e:
logger.info("Connection disconnected before response was written: %r", e)
-
-
-def _request_user_agent_is_curl(request: Request) -> bool:
- user_agents = request.requestHeaders.getRawHeaders(b"User-Agent", default=[])
- for user_agent in user_agents:
- if b"curl" in user_agent:
- return True
- return False
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 0142542852..72ab5750cc 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -49,6 +49,7 @@ class ModuleApi:
self._store = hs.get_datastore()
self._auth = hs.get_auth()
self._auth_handler = auth_handler
+ self._server_name = hs.hostname
# We expose these as properties below in order to attach a helpful docstring.
self._http_client = hs.get_simple_http_client() # type: SimpleHttpClient
@@ -336,7 +337,9 @@ class ModuleApi:
SynapseError if the event was not allowed.
"""
# Create a requester object
- requester = create_requester(event_dict["sender"])
+ requester = create_requester(
+ event_dict["sender"], authenticated_entity=self._server_name
+ )
# Create and send the event
(
diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py
index f5788c1de7..f5c6c977c7 100644
--- a/synapse/push/baserules.py
+++ b/synapse/push/baserules.py
@@ -482,7 +482,11 @@ BASE_APPEND_UNDERRIDE_RULES = [
"_id": "_message",
}
],
- "actions": ["notify", {"set_tweak": "highlight", "value": False}],
+ "actions": [
+ "notify",
+ {"set_tweak": "sound", "value": "default"},
+ {"set_tweak": "highlight", "value": False},
+ ],
},
# XXX: this is going to fire for events which aren't m.room.messages
# but are encrypted (e.g. m.call.*)...
@@ -496,7 +500,11 @@ BASE_APPEND_UNDERRIDE_RULES = [
"_id": "_encrypted",
}
],
- "actions": ["notify", {"set_tweak": "highlight", "value": False}],
+ "actions": [
+ "notify",
+ {"set_tweak": "sound", "value": "default"},
+ {"set_tweak": "highlight", "value": False},
+ ],
},
{
"rule_id": "global/underride/.im.vector.jitsi",
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index 793d0db2d9..eff0975b6a 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -75,6 +75,7 @@ class HttpPusher:
self.failing_since = pusherdict["failing_since"]
self.timed_call = None
self._is_processing = False
+ self._group_unread_count_by_room = hs.config.push_group_unread_count_by_room
# This is the highest stream ordering we know it's safe to process.
# When new events arrive, we'll be given a window of new events: we
@@ -136,7 +137,11 @@ class HttpPusher:
async def _update_badge(self):
# XXX as per https://github.com/matrix-org/matrix-doc/issues/2627, this seems
# to be largely redundant. perhaps we can remove it.
- badge = await push_tools.get_badge_count(self.hs.get_datastore(), self.user_id)
+ badge = await push_tools.get_badge_count(
+ self.hs.get_datastore(),
+ self.user_id,
+ group_by_room=self._group_unread_count_by_room,
+ )
await self._send_badge(badge)
def on_timer(self):
@@ -283,7 +288,11 @@ class HttpPusher:
return True
tweaks = push_rule_evaluator.tweaks_for_actions(push_action["actions"])
- badge = await push_tools.get_badge_count(self.hs.get_datastore(), self.user_id)
+ badge = await push_tools.get_badge_count(
+ self.hs.get_datastore(),
+ self.user_id,
+ group_by_room=self._group_unread_count_by_room,
+ )
event = await self.store.get_event(push_action["event_id"], allow_none=True)
if event is None:
diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py
index d0145666bf..6e7c880dc0 100644
--- a/synapse/push/push_tools.py
+++ b/synapse/push/push_tools.py
@@ -12,12 +12,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
from synapse.push.presentable_names import calculate_room_name, name_from_member_event
from synapse.storage import Storage
+from synapse.storage.databases.main import DataStore
-async def get_badge_count(store, user_id):
+async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -> int:
invites = await store.get_invited_rooms_for_local_user(user_id)
joins = await store.get_rooms_for_user(user_id)
@@ -34,9 +34,15 @@ async def get_badge_count(store, user_id):
room_id, user_id, last_unread_event_id
)
)
- # return one badge count per conversation, as count per
- # message is so noisy as to be almost useless
- badge += 1 if notifs["notify_count"] else 0
+ if notifs["notify_count"] == 0:
+ continue
+
+ if group_by_room:
+ # return one badge count per conversation
+ badge += 1
+ else:
+ # increment the badge count by the number of unread messages in the room
+ badge += notifs["notify_count"]
return badge
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index f325964983..32794af6ca 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -64,7 +64,7 @@ class PusherPool:
self.store = self.hs.get_datastore()
self.clock = self.hs.get_clock()
- self._account_validity = hs.config.account_validity
+ self._account_validity_enabled = hs.config.account_validity_enabled
# We shard the handling of push notifications by user ID.
self._pusher_shard_config = hs.config.push.pusher_shard_config
@@ -225,7 +225,7 @@ class PusherPool:
for u in users_affected:
# Don't push if the user account has expired
- if self._account_validity.enabled:
+ if self._account_validity_enabled:
expired = await self.store.is_account_expired(
u, self.clock.time_msec()
)
@@ -253,7 +253,7 @@ class PusherPool:
for u in users_affected:
# Don't push if the user account has expired
- if self._account_validity.enabled:
+ if self._account_validity_enabled:
expired = await self.store.is_account_expired(
u, self.clock.time_msec()
)
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index 040f76bb53..8e5d3e3c52 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -40,6 +40,10 @@ logger = logging.getLogger(__name__)
# Note that these both represent runtime dependencies (and the versions
# installed are checked at runtime).
#
+# Also note that we replicate these constraints in the Synapse Dockerfile while
+# pre-installing dependencies. If these constraints are updated here, the same
+# change should be made in the Dockerfile.
+#
# [1] https://pip.pypa.io/en/stable/reference/pip_install/#requirement-specifiers.
REQUIREMENTS = [
@@ -69,14 +73,7 @@ REQUIREMENTS = [
"msgpack>=0.5.2",
"phonenumbers>=8.2.0",
# we use GaugeHistogramMetric, which was added in prom-client 0.4.0.
- # prom-client has a history of breaking backwards compatibility between
- # minor versions (https://github.com/prometheus/client_python/issues/317),
- # so we also pin the minor version.
- #
- # Note that we replicate these constraints in the Synapse Dockerfile while
- # pre-installing dependencies. If these constraints are updated here, the
- # same change should be made in the Dockerfile.
- "prometheus_client>=0.4.0,<0.9.0",
+ "prometheus_client>=0.4.0",
# we use attr.validators.deep_iterable, which arrived in 19.1.0 (Note:
# Fedora 31 only has 19.1, so if we want to upgrade we should wait until 33
# is out in November.)
@@ -115,6 +112,8 @@ CONDITIONAL_REQUIREMENTS = {
"redis": ["txredisapi>=1.4.7", "hiredis"],
}
+CONDITIONAL_REQUIREMENTS["mypy"] = ["mypy==0.790", "mypy-zope==0.2.8"]
+
ALL_OPTIONAL_REQUIREMENTS = set() # type: Set[str]
for name, optional_deps in CONDITIONAL_REQUIREMENTS.items():
diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py
index b4f4a68b5c..7a0dbb5b1a 100644
--- a/synapse/replication/http/federation.py
+++ b/synapse/replication/http/federation.py
@@ -254,20 +254,20 @@ class ReplicationCleanRoomRestServlet(ReplicationEndpoint):
return 200, {}
-class ReplicationStoreRoomOnInviteRestServlet(ReplicationEndpoint):
+class ReplicationStoreRoomOnOutlierMembershipRestServlet(ReplicationEndpoint):
"""Called to clean up any data in DB for a given room, ready for the
server to join the room.
Request format:
- POST /_synapse/replication/store_room_on_invite/:room_id/:txn_id
+ POST /_synapse/replication/store_room_on_outlier_membership/:room_id/:txn_id
{
"room_version": "1",
}
"""
- NAME = "store_room_on_invite"
+ NAME = "store_room_on_outlier_membership"
PATH_ARGS = ("room_id",)
def __init__(self, hs):
@@ -282,7 +282,7 @@ class ReplicationStoreRoomOnInviteRestServlet(ReplicationEndpoint):
async def _handle_request(self, request, room_id):
content = parse_json_object_from_request(request)
room_version = KNOWN_ROOM_VERSIONS[content["room_version"]]
- await self.store.maybe_store_room_on_invite(room_id, room_version)
+ await self.store.maybe_store_room_on_outlier_membership(room_id, room_version)
return 200, {}
@@ -291,4 +291,4 @@ def register_servlets(hs, http_server):
ReplicationFederationSendEduRestServlet(hs).register(http_server)
ReplicationGetQueryRestServlet(hs).register(http_server)
ReplicationCleanRoomRestServlet(hs).register(http_server)
- ReplicationStoreRoomOnInviteRestServlet(hs).register(http_server)
+ ReplicationStoreRoomOnOutlierMembershipRestServlet(hs).register(http_server)
diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py
index f0c37eaf5e..84e002f934 100644
--- a/synapse/replication/http/membership.py
+++ b/synapse/replication/http/membership.py
@@ -12,9 +12,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, List, Optional, Tuple
+
+from twisted.web.http import Request
from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint
@@ -52,16 +53,23 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
self.clock = hs.get_clock()
@staticmethod
- async def _serialize_payload(
- requester, room_id, user_id, remote_room_hosts, content
- ):
+ async def _serialize_payload( # type: ignore
+ requester: Requester,
+ room_id: str,
+ user_id: str,
+ remote_room_hosts: List[str],
+ content: JsonDict,
+ ) -> JsonDict:
"""
Args:
- requester(Requester)
- room_id (str)
- user_id (str)
- remote_room_hosts (list[str]): Servers to try and join via
- content(dict): The event content to use for the join event
+ requester: The user making the request according to the access token
+ room_id: The ID of the room.
+ user_id: The ID of the user.
+ remote_room_hosts: Servers to try and join via
+ content: The event content to use for the join event
+
+ Returns:
+ A dict representing the payload of the request.
"""
return {
"requester": requester.serialize(),
@@ -69,7 +77,9 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
"content": content,
}
- async def _handle_request(self, request, room_id, user_id):
+ async def _handle_request( # type: ignore
+ self, request: Request, room_id: str, user_id: str
+ ) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request)
remote_room_hosts = content["remote_room_hosts"]
@@ -118,14 +128,17 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
txn_id: Optional[str],
requester: Requester,
content: JsonDict,
- ):
+ ) -> JsonDict:
"""
Args:
- invite_event_id: ID of the invite to be rejected
- txn_id: optional transaction ID supplied by the client
- requester: user making the rejection request, according to the access token
- content: additional content to include in the rejection event.
+ invite_event_id: The ID of the invite to be rejected.
+ txn_id: Optional transaction ID supplied by the client
+ requester: User making the rejection request, according to the access token
+ content: Additional content to include in the rejection event.
Normally an empty dict.
+
+ Returns:
+ A dict representing the payload of the request.
"""
return {
"txn_id": txn_id,
@@ -133,7 +146,9 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
"content": content,
}
- async def _handle_request(self, request, invite_event_id):
+ async def _handle_request( # type: ignore
+ self, request: Request, invite_event_id: str
+ ) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request)
txn_id = content["txn_id"]
@@ -174,18 +189,25 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
self.distributor = hs.get_distributor()
@staticmethod
- async def _serialize_payload(room_id, user_id, change):
+ async def _serialize_payload( # type: ignore
+ room_id: str, user_id: str, change: str
+ ) -> JsonDict:
"""
Args:
- room_id (str)
- user_id (str)
- change (str): "left"
+ room_id: The ID of the room.
+ user_id: The ID of the user.
+ change: "left"
+
+ Returns:
+ A dict representing the payload of the request.
"""
assert change == "left"
return {}
- def _handle_request(self, request, room_id, user_id, change):
+ def _handle_request( # type: ignore
+ self, request: Request, room_id: str, user_id: str, change: str
+ ) -> Tuple[int, JsonDict]:
logger.info("user membership change: %s in %s", user_id, room_id)
user = UserID.from_string(user_id)
diff --git a/synapse/res/templates/account_previously_renewed.html b/synapse/res/templates/account_previously_renewed.html
new file mode 100644
index 0000000000..b751359bdf
--- /dev/null
+++ b/synapse/res/templates/account_previously_renewed.html
@@ -0,0 +1 @@
+<html><body>Your account is valid until {{ expiration_ts|format_ts("%d-%m-%Y") }}.</body><html>
diff --git a/synapse/res/templates/account_renewed.html b/synapse/res/templates/account_renewed.html
index 894da030af..e8c0f52f05 100644
--- a/synapse/res/templates/account_renewed.html
+++ b/synapse/res/templates/account_renewed.html
@@ -1 +1 @@
-<html><body>Your account has been successfully renewed.</body><html>
+<html><body>Your account has been successfully renewed and is valid until {{ expiration_ts|format_ts("%d-%m-%Y") }}.</body><html>
diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py
index 40f5c32db2..5d9cdf4bde 100644
--- a/synapse/rest/__init__.py
+++ b/synapse/rest/__init__.py
@@ -119,6 +119,7 @@ class ClientRestResource(JsonResource):
room_upgrade_rest_servlet.register_servlets(hs, client_resource)
capabilities.register_servlets(hs, client_resource)
account_validity.register_servlets(hs, client_resource)
+ password_policy.register_servlets(hs, client_resource)
relations.register_servlets(hs, client_resource)
password_policy.register_servlets(hs, client_resource)
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index 2a4f7a1740..55ddebb4fe 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -21,11 +21,7 @@ import synapse
from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.http.server import JsonResource
from synapse.http.servlet import RestServlet, parse_json_object_from_request
-from synapse.rest.admin._base import (
- admin_patterns,
- assert_requester_is_admin,
- historical_admin_path_patterns,
-)
+from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
from synapse.rest.admin.devices import (
DeleteDevicesRestServlet,
DeviceRestServlet,
@@ -61,6 +57,7 @@ from synapse.rest.admin.users import (
UserRestServletV2,
UsersRestServlet,
UsersRestServletV2,
+ UserTokenRestServlet,
WhoisRestServlet,
)
from synapse.types import RoomStreamToken
@@ -83,7 +80,7 @@ class VersionServlet(RestServlet):
class PurgeHistoryRestServlet(RestServlet):
- PATTERNS = historical_admin_path_patterns(
+ PATTERNS = admin_patterns(
"/purge_history/(?P<room_id>[^/]*)(/(?P<event_id>[^/]+))?"
)
@@ -168,9 +165,7 @@ class PurgeHistoryRestServlet(RestServlet):
class PurgeHistoryStatusRestServlet(RestServlet):
- PATTERNS = historical_admin_path_patterns(
- "/purge_history_status/(?P<purge_id>[^/]+)"
- )
+ PATTERNS = admin_patterns("/purge_history_status/(?P<purge_id>[^/]+)")
def __init__(self, hs):
"""
@@ -223,6 +218,7 @@ def register_servlets(hs, http_server):
UserAdminServlet(hs).register(http_server)
UserMediaRestServlet(hs).register(http_server)
UserMembershipRestServlet(hs).register(http_server)
+ UserTokenRestServlet(hs).register(http_server)
UserRestServletV2(hs).register(http_server)
UsersRestServletV2(hs).register(http_server)
DeviceRestServlet(hs).register(http_server)
diff --git a/synapse/rest/admin/_base.py b/synapse/rest/admin/_base.py
index db9fea263a..e09234c644 100644
--- a/synapse/rest/admin/_base.py
+++ b/synapse/rest/admin/_base.py
@@ -22,28 +22,6 @@ from synapse.api.errors import AuthError
from synapse.types import UserID
-def historical_admin_path_patterns(path_regex):
- """Returns the list of patterns for an admin endpoint, including historical ones
-
- This is a backwards-compatibility hack. Previously, the Admin API was exposed at
- various paths under /_matrix/client. This function returns a list of patterns
- matching those paths (as well as the new one), so that existing scripts which rely
- on the endpoints being available there are not broken.
-
- Note that this should only be used for existing endpoints: new ones should just
- register for the /_synapse/admin path.
- """
- return [
- re.compile(prefix + path_regex)
- for prefix in (
- "^/_synapse/admin/v1",
- "^/_matrix/client/api/v1/admin",
- "^/_matrix/client/unstable/admin",
- "^/_matrix/client/r0/admin",
- )
- ]
-
-
def admin_patterns(path_regex: str, version: str = "v1"):
"""Returns the list of patterns for an admin endpoint
diff --git a/synapse/rest/admin/groups.py b/synapse/rest/admin/groups.py
index 0b54ca09f4..d0c86b204a 100644
--- a/synapse/rest/admin/groups.py
+++ b/synapse/rest/admin/groups.py
@@ -16,10 +16,7 @@ import logging
from synapse.api.errors import SynapseError
from synapse.http.servlet import RestServlet
-from synapse.rest.admin._base import (
- assert_user_is_admin,
- historical_admin_path_patterns,
-)
+from synapse.rest.admin._base import admin_patterns, assert_user_is_admin
logger = logging.getLogger(__name__)
@@ -28,7 +25,7 @@ class DeleteGroupAdminRestServlet(RestServlet):
"""Allows deleting of local groups
"""
- PATTERNS = historical_admin_path_patterns("/delete_group/(?P<group_id>[^/]*)")
+ PATTERNS = admin_patterns("/delete_group/(?P<group_id>[^/]*)")
def __init__(self, hs):
self.group_server = hs.get_groups_server_handler()
diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py
index ba50cb876d..c82b4f87d6 100644
--- a/synapse/rest/admin/media.py
+++ b/synapse/rest/admin/media.py
@@ -22,7 +22,6 @@ from synapse.rest.admin._base import (
admin_patterns,
assert_requester_is_admin,
assert_user_is_admin,
- historical_admin_path_patterns,
)
logger = logging.getLogger(__name__)
@@ -34,10 +33,10 @@ class QuarantineMediaInRoom(RestServlet):
"""
PATTERNS = (
- historical_admin_path_patterns("/room/(?P<room_id>[^/]+)/media/quarantine")
+ admin_patterns("/room/(?P<room_id>[^/]+)/media/quarantine")
+
# This path kept around for legacy reasons
- historical_admin_path_patterns("/quarantine_media/(?P<room_id>[^/]+)")
+ admin_patterns("/quarantine_media/(?P<room_id>[^/]+)")
)
def __init__(self, hs):
@@ -63,9 +62,7 @@ class QuarantineMediaByUser(RestServlet):
this server.
"""
- PATTERNS = historical_admin_path_patterns(
- "/user/(?P<user_id>[^/]+)/media/quarantine"
- )
+ PATTERNS = admin_patterns("/user/(?P<user_id>[^/]+)/media/quarantine")
def __init__(self, hs):
self.store = hs.get_datastore()
@@ -90,7 +87,7 @@ class QuarantineMediaByID(RestServlet):
it via this server.
"""
- PATTERNS = historical_admin_path_patterns(
+ PATTERNS = admin_patterns(
"/media/quarantine/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)"
)
@@ -116,7 +113,7 @@ class ListMediaInRoom(RestServlet):
"""Lists all of the media in a given room.
"""
- PATTERNS = historical_admin_path_patterns("/room/(?P<room_id>[^/]+)/media")
+ PATTERNS = admin_patterns("/room/(?P<room_id>[^/]+)/media")
def __init__(self, hs):
self.store = hs.get_datastore()
@@ -134,7 +131,7 @@ class ListMediaInRoom(RestServlet):
class PurgeMediaCacheRestServlet(RestServlet):
- PATTERNS = historical_admin_path_patterns("/purge_media_cache")
+ PATTERNS = admin_patterns("/purge_media_cache")
def __init__(self, hs):
self.media_repository = hs.get_media_repository()
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index f5304ff43d..25f89e4685 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -29,7 +29,6 @@ from synapse.rest.admin._base import (
admin_patterns,
assert_requester_is_admin,
assert_user_is_admin,
- historical_admin_path_patterns,
)
from synapse.storage.databases.main.room import RoomSortOrder
from synapse.types import RoomAlias, RoomID, UserID, create_requester
@@ -44,7 +43,7 @@ class ShutdownRoomRestServlet(RestServlet):
joined to the new room.
"""
- PATTERNS = historical_admin_path_patterns("/shutdown_room/(?P<room_id>[^/]+)")
+ PATTERNS = admin_patterns("/shutdown_room/(?P<room_id>[^/]+)")
def __init__(self, hs):
self.hs = hs
@@ -71,14 +70,18 @@ class ShutdownRoomRestServlet(RestServlet):
class DeleteRoomRestServlet(RestServlet):
- """Delete a room from server. It is a combination and improvement of
- shut down and purge room.
+ """Delete a room from server.
+
+ It is a combination and improvement of shutdown and purge room.
+
Shuts down a room by removing all local users from the room.
Blocking all future invites and joins to the room is optional.
+
If desired any local aliases will be repointed to a new room
- created by `new_room_user_id` and kicked users will be auto
+ created by `new_room_user_id` and kicked users will be auto-
joined to the new room.
- It will remove all trace of a room from the database.
+
+ If 'purge' is true, it will remove all traces of a room from the database.
"""
PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/delete$")
@@ -111,6 +114,14 @@ class DeleteRoomRestServlet(RestServlet):
Codes.BAD_JSON,
)
+ force_purge = content.get("force_purge", False)
+ if not isinstance(force_purge, bool):
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST,
+ "Param 'force_purge' must be a boolean, if given",
+ Codes.BAD_JSON,
+ )
+
ret = await self.room_shutdown_handler.shutdown_room(
room_id=room_id,
new_room_user_id=content.get("new_room_user_id"),
@@ -122,7 +133,7 @@ class DeleteRoomRestServlet(RestServlet):
# Purge room
if purge:
- await self.pagination_handler.purge_room(room_id)
+ await self.pagination_handler.purge_room(room_id, force=force_purge)
return (200, ret)
@@ -309,7 +320,9 @@ class JoinRoomAliasServlet(RestServlet):
400, "%s was not legal room ID or room alias" % (room_identifier,)
)
- fake_requester = create_requester(target_user)
+ fake_requester = create_requester(
+ target_user, authenticated_entity=requester.authenticated_entity
+ )
# send invite if room has "JoinRules.INVITE"
room_state = await self.state_handler.get_current_state(room_id)
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index 3638e219f2..b0ff5e1ead 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -16,7 +16,7 @@ import hashlib
import hmac
import logging
from http import HTTPStatus
-from typing import Tuple
+from typing import TYPE_CHECKING, Tuple
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, NotFoundError, SynapseError
@@ -33,10 +33,13 @@ from synapse.rest.admin._base import (
admin_patterns,
assert_requester_is_admin,
assert_user_is_admin,
- historical_admin_path_patterns,
)
+from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.types import JsonDict, UserID
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
_GET_PUSHERS_ALLOWED_KEYS = {
@@ -52,7 +55,7 @@ _GET_PUSHERS_ALLOWED_KEYS = {
class UsersRestServlet(RestServlet):
- PATTERNS = historical_admin_path_patterns("/users/(?P<user_id>[^/]*)$")
+ PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)$")
def __init__(self, hs):
self.hs = hs
@@ -335,7 +338,7 @@ class UserRegisterServlet(RestServlet):
nonce to the time it was generated, in int seconds.
"""
- PATTERNS = historical_admin_path_patterns("/register")
+ PATTERNS = admin_patterns("/register")
NONCE_TIMEOUT = 60
def __init__(self, hs):
@@ -458,7 +461,14 @@ class UserRegisterServlet(RestServlet):
class WhoisRestServlet(RestServlet):
- PATTERNS = historical_admin_path_patterns("/whois/(?P<user_id>[^/]*)")
+ path_regex = "/whois/(?P<user_id>[^/]*)$"
+ PATTERNS = (
+ admin_patterns(path_regex)
+ +
+ # URL for spec reason
+ # https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-admin-whois-userid
+ client_patterns("/admin" + path_regex, v1=True)
+ )
def __init__(self, hs):
self.hs = hs
@@ -482,7 +492,7 @@ class WhoisRestServlet(RestServlet):
class DeactivateAccountRestServlet(RestServlet):
- PATTERNS = historical_admin_path_patterns("/deactivate/(?P<target_user_id>[^/]*)")
+ PATTERNS = admin_patterns("/deactivate/(?P<target_user_id>[^/]*)")
def __init__(self, hs):
self._deactivate_account_handler = hs.get_deactivate_account_handler()
@@ -513,7 +523,7 @@ class DeactivateAccountRestServlet(RestServlet):
class AccountValidityRenewServlet(RestServlet):
- PATTERNS = historical_admin_path_patterns("/account_validity/validity$")
+ PATTERNS = admin_patterns("/account_validity/validity$")
def __init__(self, hs):
"""
@@ -556,9 +566,7 @@ class ResetPasswordRestServlet(RestServlet):
200 OK with empty object if success otherwise an error.
"""
- PATTERNS = historical_admin_path_patterns(
- "/reset_password/(?P<target_user_id>[^/]*)"
- )
+ PATTERNS = admin_patterns("/reset_password/(?P<target_user_id>[^/]*)")
def __init__(self, hs):
self.store = hs.get_datastore()
@@ -600,7 +608,7 @@ class SearchUsersRestServlet(RestServlet):
200 OK with json object {list[dict[str, Any]], count} or empty object.
"""
- PATTERNS = historical_admin_path_patterns("/search_users/(?P<target_user_id>[^/]*)")
+ PATTERNS = admin_patterns("/search_users/(?P<target_user_id>[^/]*)")
def __init__(self, hs):
self.hs = hs
@@ -828,3 +836,52 @@ class UserMediaRestServlet(RestServlet):
ret["next_token"] = start + len(media)
return 200, ret
+
+
+class UserTokenRestServlet(RestServlet):
+ """An admin API for logging in as a user.
+
+ Example:
+
+ POST /_synapse/admin/v1/users/@test:example.com/login
+ {}
+
+ 200 OK
+ {
+ "access_token": "<some_token>"
+ }
+ """
+
+ PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/login$")
+
+ def __init__(self, hs: "HomeServer"):
+ self.hs = hs
+ self.store = hs.get_datastore()
+ self.auth = hs.get_auth()
+ self.auth_handler = hs.get_auth_handler()
+
+ async def on_POST(self, request, user_id):
+ requester = await self.auth.get_user_by_req(request)
+ await assert_user_is_admin(self.auth, requester.user)
+ auth_user = requester.user
+
+ if not self.hs.is_mine_id(user_id):
+ raise SynapseError(400, "Only local users can be logged in as")
+
+ body = parse_json_object_from_request(request, allow_empty_body=True)
+
+ valid_until_ms = body.get("valid_until_ms")
+ if valid_until_ms and not isinstance(valid_until_ms, int):
+ raise SynapseError(400, "'valid_until_ms' parameter must be an int")
+
+ if auth_user.to_string() == user_id:
+ raise SynapseError(400, "Cannot use admin API to login as self")
+
+ token = await self.auth_handler.get_access_token_for_user_id(
+ user_id=auth_user.to_string(),
+ device_id=None,
+ valid_until_ms=valid_until_ms,
+ puppets_user_id=user_id,
+ )
+
+ return 200, {"access_token": token}
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 94452fcbf5..d7ae148214 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -19,10 +19,6 @@ from typing import Awaitable, Callable, Dict, Optional
from synapse.api.errors import Codes, LoginError, SynapseError
from synapse.api.ratelimiting import Ratelimiter
from synapse.appservice import ApplicationService
-from synapse.handlers.auth import (
- convert_client_dict_legacy_fields_to_identifier,
- login_id_phone_to_thirdparty,
-)
from synapse.http.server import finish_request
from synapse.http.servlet import (
RestServlet,
@@ -33,7 +29,6 @@ from synapse.http.site import SynapseRequest
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.rest.well_known import WellKnownBuilder
from synapse.types import JsonDict, UserID
-from synapse.util.threepids import canonicalise_email
logger = logging.getLogger(__name__)
@@ -78,11 +73,6 @@ class LoginRestServlet(RestServlet):
rate_hz=self.hs.config.rc_login_account.per_second,
burst_count=self.hs.config.rc_login_account.burst_count,
)
- self._failed_attempts_ratelimiter = Ratelimiter(
- clock=hs.get_clock(),
- rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
- burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
- )
def on_GET(self, request: SynapseRequest):
flows = []
@@ -140,27 +130,31 @@ class LoginRestServlet(RestServlet):
result["well_known"] = well_known_data
return 200, result
- def _get_qualified_user_id(self, identifier):
- if identifier["type"] != "m.id.user":
- raise SynapseError(400, "Unknown login identifier type")
- if "user" not in identifier:
- raise SynapseError(400, "User identifier is missing 'user' key")
-
- if identifier["user"].startswith("@"):
- return identifier["user"]
- else:
- return UserID(identifier["user"], self.hs.hostname).to_string()
-
async def _do_appservice_login(
self, login_submission: JsonDict, appservice: ApplicationService
):
- logger.info(
- "Got appservice login request with identifier: %r",
- login_submission.get("identifier"),
- )
+ identifier = login_submission.get("identifier")
+ logger.info("Got appservice login request with identifier: %r", identifier)
+
+ if not isinstance(identifier, dict):
+ raise SynapseError(
+ 400, "Invalid identifier in login submission", Codes.INVALID_PARAM
+ )
+
+ # this login flow only supports identifiers of type "m.id.user".
+ if identifier.get("type") != "m.id.user":
+ raise SynapseError(
+ 400, "Unknown login identifier type", Codes.INVALID_PARAM
+ )
- identifier = convert_client_dict_legacy_fields_to_identifier(login_submission)
- qualified_user_id = self._get_qualified_user_id(identifier)
+ user = identifier.get("user")
+ if not isinstance(user, str):
+ raise SynapseError(400, "Invalid user in identifier", Codes.INVALID_PARAM)
+
+ if user.startswith("@"):
+ qualified_user_id = user
+ else:
+ qualified_user_id = UserID(user, self.hs.hostname).to_string()
if not appservice.is_interested_in_user(qualified_user_id):
raise LoginError(403, "Invalid access_token", errcode=Codes.FORBIDDEN)
@@ -186,91 +180,9 @@ class LoginRestServlet(RestServlet):
login_submission.get("address"),
login_submission.get("user"),
)
- identifier = convert_client_dict_legacy_fields_to_identifier(login_submission)
-
- # convert phone type identifiers to generic threepids
- if identifier["type"] == "m.id.phone":
- identifier = login_id_phone_to_thirdparty(identifier)
-
- # convert threepid identifiers to user IDs
- if identifier["type"] == "m.id.thirdparty":
- address = identifier.get("address")
- medium = identifier.get("medium")
-
- if medium is None or address is None:
- raise SynapseError(400, "Invalid thirdparty identifier")
-
- # For emails, canonicalise the address.
- # We store all email addresses canonicalised in the DB.
- # (See add_threepid in synapse/handlers/auth.py)
- if medium == "email":
- try:
- address = canonicalise_email(address)
- except ValueError as e:
- raise SynapseError(400, str(e))
-
- # We also apply account rate limiting using the 3PID as a key, as
- # otherwise using 3PID bypasses the ratelimiting based on user ID.
- self._failed_attempts_ratelimiter.ratelimit((medium, address), update=False)
-
- # Check for login providers that support 3pid login types
- (
- canonical_user_id,
- callback_3pid,
- ) = await self.auth_handler.check_password_provider_3pid(
- medium, address, login_submission["password"]
- )
- if canonical_user_id:
- # Authentication through password provider and 3pid succeeded
-
- result = await self._complete_login(
- canonical_user_id, login_submission, callback_3pid
- )
- return result
-
- # No password providers were able to handle this 3pid
- # Check local store
- user_id = await self.hs.get_datastore().get_user_id_by_threepid(
- medium, address
- )
- if not user_id:
- logger.warning(
- "unknown 3pid identifier medium %s, address %r", medium, address
- )
- # We mark that we've failed to log in here, as
- # `check_password_provider_3pid` might have returned `None` due
- # to an incorrect password, rather than the account not
- # existing.
- #
- # If it returned None but the 3PID was bound then we won't hit
- # this code path, which is fine as then the per-user ratelimit
- # will kick in below.
- self._failed_attempts_ratelimiter.can_do_action((medium, address))
- raise LoginError(403, "", errcode=Codes.FORBIDDEN)
-
- identifier = {"type": "m.id.user", "user": user_id}
-
- # by this point, the identifier should be an m.id.user: if it's anything
- # else, we haven't understood it.
- qualified_user_id = self._get_qualified_user_id(identifier)
-
- # Check if we've hit the failed ratelimit (but don't update it)
- self._failed_attempts_ratelimiter.ratelimit(
- qualified_user_id.lower(), update=False
+ canonical_user_id, callback = await self.auth_handler.validate_login(
+ login_submission, ratelimit=True
)
-
- try:
- canonical_user_id, callback = await self.auth_handler.validate_login(
- identifier["user"], login_submission
- )
- except LoginError:
- # The user has failed to log in, so we need to update the rate
- # limiter. Using `can_do_action` avoids us raising a ratelimit
- # exception and masking the LoginError. The actual ratelimiting
- # should have happened above.
- self._failed_attempts_ratelimiter.can_do_action(qualified_user_id.lower())
- raise
-
result = await self._complete_login(
canonical_user_id, login_submission, callback
)
diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py
index 23a529f8e3..94bfe2d1b0 100644
--- a/synapse/rest/client/v1/presence.py
+++ b/synapse/rest/client/v1/presence.py
@@ -49,9 +49,7 @@ class PresenceStatusRestServlet(RestServlet):
raise AuthError(403, "You are not allowed to see their presence.")
state = await self.presence_handler.get_state(target_user=user)
- state = format_user_presence_state(
- state, self.clock.time_msec(), include_user_id=False
- )
+ state = format_user_presence_state(state, self.clock.time_msec())
return 200, state
diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py
index 85a66458c5..b5fa1cc464 100644
--- a/synapse/rest/client/v1/profile.py
+++ b/synapse/rest/client/v1/profile.py
@@ -14,6 +14,7 @@
# limitations under the License.
""" This module contains REST servlets to do with profile: /profile/<paths> """
+from twisted.internet import defer
from synapse.api.errors import Codes, SynapseError
from synapse.http.servlet import RestServlet, parse_json_object_from_request
@@ -28,6 +29,7 @@ class ProfileDisplaynameRestServlet(RestServlet):
super().__init__()
self.hs = hs
self.profile_handler = hs.get_profile_handler()
+ self.http_client = hs.get_simple_http_client()
self.auth = hs.get_auth()
async def on_GET(self, request, user_id):
@@ -65,8 +67,27 @@ class ProfileDisplaynameRestServlet(RestServlet):
await self.profile_handler.set_displayname(user, requester, new_name, is_admin)
+ if self.hs.config.shadow_server:
+ shadow_user = UserID(user.localpart, self.hs.config.shadow_server.get("hs"))
+ self.shadow_displayname(shadow_user.to_string(), content)
+
+ return 200, {}
+
+ def on_OPTIONS(self, request, user_id):
return 200, {}
+ @defer.inlineCallbacks
+ def shadow_displayname(self, user_id, body):
+ # TODO: retries
+ shadow_hs_url = self.hs.config.shadow_server.get("hs_url")
+ as_token = self.hs.config.shadow_server.get("as_token")
+
+ yield self.http_client.put_json(
+ "%s/_matrix/client/r0/profile/%s/displayname?access_token=%s&user_id=%s"
+ % (shadow_hs_url, user_id, as_token, user_id),
+ body,
+ )
+
class ProfileAvatarURLRestServlet(RestServlet):
PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)/avatar_url", v1=True)
@@ -75,6 +96,7 @@ class ProfileAvatarURLRestServlet(RestServlet):
super().__init__()
self.hs = hs
self.profile_handler = hs.get_profile_handler()
+ self.http_client = hs.get_simple_http_client()
self.auth = hs.get_auth()
async def on_GET(self, request, user_id):
@@ -113,8 +135,27 @@ class ProfileAvatarURLRestServlet(RestServlet):
user, requester, new_avatar_url, is_admin
)
+ if self.hs.config.shadow_server:
+ shadow_user = UserID(user.localpart, self.hs.config.shadow_server.get("hs"))
+ self.shadow_avatar_url(shadow_user.to_string(), content)
+
+ return 200, {}
+
+ def on_OPTIONS(self, request, user_id):
return 200, {}
+ @defer.inlineCallbacks
+ def shadow_avatar_url(self, user_id, body):
+ # TODO: retries
+ shadow_hs_url = self.hs.config.shadow_server.get("hs_url")
+ as_token = self.hs.config.shadow_server.get("as_token")
+
+ yield self.http_client.put_json(
+ "%s/_matrix/client/r0/profile/%s/avatar_url?access_token=%s&user_id=%s"
+ % (shadow_hs_url, user_id, as_token, user_id),
+ body,
+ )
+
class ProfileRestServlet(RestServlet):
PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)", v1=True)
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 25d3cc6148..0c658f120e 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -18,7 +18,7 @@
import logging
import re
-from typing import List, Optional
+from typing import TYPE_CHECKING, List, Optional
from urllib import parse as urlparse
from synapse.api.constants import EventTypes, Membership
@@ -48,8 +48,7 @@ from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID,
from synapse.util import json_decoder
from synapse.util.stringutils import random_string
-MYPY = False
-if MYPY:
+if TYPE_CHECKING:
import synapse.server
logger = logging.getLogger(__name__)
@@ -726,7 +725,8 @@ class RoomMembershipRestServlet(TransactionRestServlet):
content["id_server"],
requester,
txn_id,
- content.get("id_access_token"),
+ new_room=False,
+ id_access_token=content.get("id_access_token"),
)
except ShadowBanError:
# Pretend the request succeeded.
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index a54e1011f7..8e4a01a7e8 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2017 Vector Creations Ltd
-# Copyright 2018 New Vector Ltd
+# Copyright 2018, 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,6 +16,7 @@
# limitations under the License.
import logging
import random
+import re
from http import HTTPStatus
from typing import TYPE_CHECKING
from urllib.parse import urlparse
@@ -40,6 +41,7 @@ from synapse.http.servlet import (
)
from synapse.metrics import threepid_send_requests
from synapse.push.mailer import Mailer
+from synapse.types import UserID
from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.stringutils import assert_valid_client_secret, random_string
from synapse.util.threepids import canonicalise_email, check_3pid_allowed
@@ -115,7 +117,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
# comments for request_token_inhibit_3pid_errors.
# Also wait for some random amount of time between 100ms and 1s to make it
# look like we did something.
- await self.hs.clock.sleep(random.randint(1, 10) / 10)
+ await self.hs.get_clock().sleep(random.randint(1, 10) / 10)
return 200, {"sid": random_string(16)}
raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND)
@@ -162,6 +164,7 @@ class PasswordRestServlet(RestServlet):
self.datastore = self.hs.get_datastore()
self.password_policy_handler = hs.get_password_policy_handler()
self._set_password_handler = hs.get_set_password_handler()
+ self.http_client = hs.get_simple_http_client()
@interactive_auth_handler
async def on_POST(self, request):
@@ -187,26 +190,33 @@ class PasswordRestServlet(RestServlet):
if self.auth.has_access_token(request):
requester = await self.auth.get_user_by_req(request)
- try:
- params, session_id = await self.auth_handler.validate_user_via_ui_auth(
- requester,
- request,
- body,
- self.hs.get_ip_from_request(request),
- "modify your account password",
- )
- except InteractiveAuthIncompleteError as e:
- # The user needs to provide more steps to complete auth, but
- # they're not required to provide the password again.
- #
- # If a password is available now, hash the provided password and
- # store it for later.
- if new_password:
- password_hash = await self.auth_handler.hash(new_password)
- await self.auth_handler.set_session_data(
- e.session_id, "password_hash", password_hash
+ # blindly trust ASes without UI-authing them
+ if requester.app_service:
+ params = body
+ else:
+ try:
+ (
+ params,
+ session_id,
+ ) = await self.auth_handler.validate_user_via_ui_auth(
+ requester,
+ request,
+ body,
+ self.hs.get_ip_from_request(request),
+ "modify your account password",
)
- raise
+ except InteractiveAuthIncompleteError as e:
+ # The user needs to provide more steps to complete auth, but
+ # they're not required to provide the password again.
+ #
+ # If a password is available now, hash the provided password and
+ # store it for later.
+ if new_password:
+ password_hash = await self.auth_handler.hash(new_password)
+ await self.auth_handler.set_session_data(
+ e.session_id, "password_hash", password_hash
+ )
+ raise
user_id = requester.user.to_string()
else:
requester = None
@@ -271,8 +281,28 @@ class PasswordRestServlet(RestServlet):
user_id, password_hash, logout_devices, requester
)
+ if self.hs.config.shadow_server:
+ shadow_user = UserID(
+ requester.user.localpart, self.hs.config.shadow_server.get("hs")
+ )
+ await self.shadow_password(params, shadow_user.to_string())
+
+ return 200, {}
+
+ def on_OPTIONS(self, _):
return 200, {}
+ async def shadow_password(self, body, user_id):
+ # TODO: retries
+ shadow_hs_url = self.hs.config.shadow_server.get("hs_url")
+ as_token = self.hs.config.shadow_server.get("as_token")
+
+ await self.http_client.post_json_get_json(
+ "%s/_matrix/client/r0/account/password?access_token=%s&user_id=%s"
+ % (shadow_hs_url, as_token, user_id),
+ body,
+ )
+
class DeactivateAccountRestServlet(RestServlet):
PATTERNS = client_patterns("/account/deactivate$")
@@ -368,10 +398,10 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
send_attempt = body["send_attempt"]
next_link = body.get("next_link") # Optional param
- if not check_3pid_allowed(self.hs, "email", email):
+ if not (await check_3pid_allowed(self.hs, "email", email)):
raise SynapseError(
403,
- "Your email domain is not authorized on this server",
+ "Your email is not authorized on this server",
Codes.THREEPID_DENIED,
)
@@ -387,7 +417,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
# comments for request_token_inhibit_3pid_errors.
# Also wait for some random amount of time between 100ms and 1s to make it
# look like we did something.
- await self.hs.clock.sleep(random.randint(1, 10) / 10)
+ await self.hs.get_clock().sleep(random.randint(1, 10) / 10)
return 200, {"sid": random_string(16)}
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
@@ -447,7 +477,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
msisdn = phone_number_to_msisdn(country, phone_number)
- if not check_3pid_allowed(self.hs, "msisdn", msisdn):
+ if not (await check_3pid_allowed(self.hs, "msisdn", msisdn)):
raise SynapseError(
403,
"Account phone numbers are not authorized on this server",
@@ -466,7 +496,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
# comments for request_token_inhibit_3pid_errors.
# Also wait for some random amount of time between 100ms and 1s to make it
# look like we did something.
- await self.hs.clock.sleep(random.randint(1, 10) / 10)
+ await self.hs.get_clock().sleep(random.randint(1, 10) / 10)
return 200, {"sid": random_string(16)}
raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE)
@@ -617,7 +647,8 @@ class ThreepidRestServlet(RestServlet):
self.identity_handler = hs.get_identity_handler()
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
- self.datastore = self.hs.get_datastore()
+ self.datastore = hs.get_datastore()
+ self.http_client = hs.get_simple_http_client()
async def on_GET(self, request):
requester = await self.auth.get_user_by_req(request)
@@ -636,6 +667,29 @@ class ThreepidRestServlet(RestServlet):
user_id = requester.user.to_string()
body = parse_json_object_from_request(request)
+ # skip validation if this is a shadow 3PID from an AS
+ if requester.app_service:
+ # XXX: ASes pass in a validated threepid directly to bypass the IS.
+ # This makes the API entirely change shape when we have an AS token;
+ # it really should be an entirely separate API - perhaps
+ # /account/3pid/replicate or something.
+ threepid = body.get("threepid")
+
+ await self.auth_handler.add_threepid(
+ user_id,
+ threepid["medium"],
+ threepid["address"],
+ threepid["validated_at"],
+ )
+
+ if self.hs.config.shadow_server:
+ shadow_user = UserID(
+ requester.user.localpart, self.hs.config.shadow_server.get("hs")
+ )
+ await self.shadow_3pid({"threepid": threepid}, shadow_user.to_string())
+
+ return 200, {}
+
threepid_creds = body.get("threePidCreds") or body.get("three_pid_creds")
if threepid_creds is None:
raise SynapseError(
@@ -657,12 +711,35 @@ class ThreepidRestServlet(RestServlet):
validation_session["address"],
validation_session["validated_at"],
)
+
+ if self.hs.config.shadow_server:
+ shadow_user = UserID(
+ requester.user.localpart, self.hs.config.shadow_server.get("hs")
+ )
+ threepid = {
+ "medium": validation_session["medium"],
+ "address": validation_session["address"],
+ "validated_at": validation_session["validated_at"],
+ }
+ await self.shadow_3pid({"threepid": threepid}, shadow_user.to_string())
+
return 200, {}
raise SynapseError(
400, "No validated 3pid session found", Codes.THREEPID_AUTH_FAILED
)
+ async def shadow_3pid(self, body, user_id):
+ # TODO: retries
+ shadow_hs_url = self.hs.config.shadow_server.get("hs_url")
+ as_token = self.hs.config.shadow_server.get("as_token")
+
+ await self.http_client.post_json_get_json(
+ "%s/_matrix/client/r0/account/3pid?access_token=%s&user_id=%s"
+ % (shadow_hs_url, as_token, user_id),
+ body,
+ )
+
class ThreepidAddRestServlet(RestServlet):
PATTERNS = client_patterns("/account/3pid/add$")
@@ -673,6 +750,7 @@ class ThreepidAddRestServlet(RestServlet):
self.identity_handler = hs.get_identity_handler()
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
+ self.http_client = hs.get_simple_http_client()
@interactive_auth_handler
async def on_POST(self, request):
@@ -708,12 +786,33 @@ class ThreepidAddRestServlet(RestServlet):
validation_session["address"],
validation_session["validated_at"],
)
+ if self.hs.config.shadow_server:
+ shadow_user = UserID(
+ requester.user.localpart, self.hs.config.shadow_server.get("hs")
+ )
+ threepid = {
+ "medium": validation_session["medium"],
+ "address": validation_session["address"],
+ "validated_at": validation_session["validated_at"],
+ }
+ await self.shadow_3pid({"threepid": threepid}, shadow_user.to_string())
return 200, {}
raise SynapseError(
400, "No validated 3pid session found", Codes.THREEPID_AUTH_FAILED
)
+ async def shadow_3pid(self, body, user_id):
+ # TODO: retries
+ shadow_hs_url = self.hs.config.shadow_server.get("hs_url")
+ as_token = self.hs.config.shadow_server.get("as_token")
+
+ await self.http_client.post_json_get_json(
+ "%s/_matrix/client/r0/account/3pid?access_token=%s&user_id=%s"
+ % (shadow_hs_url, as_token, user_id),
+ body,
+ )
+
class ThreepidBindRestServlet(RestServlet):
PATTERNS = client_patterns("/account/3pid/bind$")
@@ -783,6 +882,7 @@ class ThreepidDeleteRestServlet(RestServlet):
self.hs = hs
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
+ self.http_client = hs.get_simple_http_client()
async def on_POST(self, request):
if not self.hs.config.enable_3pid_changes:
@@ -807,6 +907,12 @@ class ThreepidDeleteRestServlet(RestServlet):
logger.exception("Failed to remove threepid")
raise SynapseError(500, "Failed to remove threepid")
+ if self.hs.config.shadow_server:
+ shadow_user = UserID(
+ requester.user.localpart, self.hs.config.shadow_server.get("hs")
+ )
+ await self.shadow_3pid_delete(body, shadow_user.to_string())
+
if ret:
id_server_unbind_result = "success"
else:
@@ -814,6 +920,74 @@ class ThreepidDeleteRestServlet(RestServlet):
return 200, {"id_server_unbind_result": id_server_unbind_result}
+ async def shadow_3pid_delete(self, body, user_id):
+ # TODO: retries
+ shadow_hs_url = self.hs.config.shadow_server.get("hs_url")
+ as_token = self.hs.config.shadow_server.get("as_token")
+
+ await self.http_client.post_json_get_json(
+ "%s/_matrix/client/r0/account/3pid/delete?access_token=%s&user_id=%s"
+ % (shadow_hs_url, as_token, user_id),
+ body,
+ )
+
+
+class ThreepidLookupRestServlet(RestServlet):
+ PATTERNS = [re.compile("^/_matrix/client/unstable/account/3pid/lookup$")]
+
+ def __init__(self, hs):
+ super(ThreepidLookupRestServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.identity_handler = hs.get_identity_handler()
+
+ async def on_GET(self, request):
+ """Proxy a /_matrix/identity/api/v1/lookup request to an identity
+ server
+ """
+ await self.auth.get_user_by_req(request)
+
+ # Verify query parameters
+ query_params = request.args
+ assert_params_in_dict(query_params, [b"medium", b"address", b"id_server"])
+
+ # Retrieve needed information from query parameters
+ medium = parse_string(request, "medium")
+ address = parse_string(request, "address")
+ id_server = parse_string(request, "id_server")
+
+ # Proxy the request to the identity server. lookup_3pid handles checking
+ # if the lookup is allowed so we don't need to do it here.
+ ret = await self.identity_handler.proxy_lookup_3pid(id_server, medium, address)
+
+ return 200, ret
+
+
+class ThreepidBulkLookupRestServlet(RestServlet):
+ PATTERNS = [re.compile("^/_matrix/client/unstable/account/3pid/bulk_lookup$")]
+
+ def __init__(self, hs):
+ super(ThreepidBulkLookupRestServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.identity_handler = hs.get_identity_handler()
+
+ async def on_POST(self, request):
+ """Proxy a /_matrix/identity/api/v1/bulk_lookup request to an identity
+ server
+ """
+ await self.auth.get_user_by_req(request)
+
+ body = parse_json_object_from_request(request)
+
+ assert_params_in_dict(body, ["threepids", "id_server"])
+
+ # Proxy the request to the identity server. lookup_3pid handles checking
+ # if the lookup is allowed so we don't need to do it here.
+ ret = await self.identity_handler.proxy_bulk_lookup_3pid(
+ body["id_server"], body["threepids"]
+ )
+
+ return 200, ret
+
def assert_valid_next_link(hs: "HomeServer", next_link: str):
"""
@@ -880,4 +1054,6 @@ def register_servlets(hs, http_server):
ThreepidBindRestServlet(hs).register(http_server)
ThreepidUnbindRestServlet(hs).register(http_server)
ThreepidDeleteRestServlet(hs).register(http_server)
+ ThreepidLookupRestServlet(hs).register(http_server)
+ ThreepidBulkLookupRestServlet(hs).register(http_server)
WhoamiRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/account_data.py b/synapse/rest/client/v2_alpha/account_data.py
index 87a5b1b86b..617ee6d62a 100644
--- a/synapse/rest/client/v2_alpha/account_data.py
+++ b/synapse/rest/client/v2_alpha/account_data.py
@@ -17,6 +17,7 @@ import logging
from synapse.api.errors import AuthError, NotFoundError, SynapseError
from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.types import UserID
from ._base import client_patterns
@@ -39,6 +40,7 @@ class AccountDataServlet(RestServlet):
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
self._is_worker = hs.config.worker_app is not None
+ self._profile_handler = hs.get_profile_handler()
async def on_PUT(self, request, user_id, account_data_type):
if self._is_worker:
@@ -50,6 +52,11 @@ class AccountDataServlet(RestServlet):
body = parse_json_object_from_request(request)
+ if account_data_type == "im.vector.hide_profile":
+ user = UserID.from_string(user_id)
+ hide_profile = body.get("hide_profile")
+ await self._profile_handler.set_active([user], not hide_profile, True)
+
max_id = await self.store.add_account_data_for_user(
user_id, account_data_type, body
)
diff --git a/synapse/rest/client/v2_alpha/account_validity.py b/synapse/rest/client/v2_alpha/account_validity.py
index bd7f9ae203..40c5bd4d8c 100644
--- a/synapse/rest/client/v2_alpha/account_validity.py
+++ b/synapse/rest/client/v2_alpha/account_validity.py
@@ -37,24 +37,38 @@ class AccountValidityRenewServlet(RestServlet):
self.hs = hs
self.account_activity_handler = hs.get_account_validity_handler()
self.auth = hs.get_auth()
- self.success_html = hs.config.account_validity.account_renewed_html_content
- self.failure_html = hs.config.account_validity.invalid_token_html_content
+ self.account_renewed_template = (
+ hs.config.account_validity_account_renewed_template
+ )
+ self.account_previously_renewed_template = (
+ hs.config.account_validity_account_previously_renewed_template
+ )
+ self.invalid_token_template = hs.config.account_validity_invalid_token_template
async def on_GET(self, request):
if b"token" not in request.args:
raise SynapseError(400, "Missing renewal token")
renewal_token = request.args[b"token"][0]
- token_valid = await self.account_activity_handler.renew_account(
+ (
+ token_valid,
+ token_stale,
+ expiration_ts,
+ ) = await self.account_activity_handler.renew_account(
renewal_token.decode("utf8")
)
if token_valid:
status_code = 200
- response = self.success_html
+ response = self.account_renewed_template.render(expiration_ts=expiration_ts)
+ elif token_stale:
+ status_code = 200
+ response = self.account_previously_renewed_template.render(
+ expiration_ts=expiration_ts
+ )
else:
status_code = 404
- response = self.failure_html
+ response = self.invalid_token_template.render(expiration_ts=expiration_ts)
respond_with_html(request, status_code, response)
@@ -72,10 +86,12 @@ class AccountValiditySendMailServlet(RestServlet):
self.hs = hs
self.account_activity_handler = hs.get_account_validity_handler()
self.auth = hs.get_auth()
- self.account_validity = self.hs.config.account_validity
+ self.account_validity_renew_by_email_enabled = (
+ self.hs.config.account_validity_renew_by_email_enabled
+ )
async def on_POST(self, request):
- if not self.account_validity.renew_by_email_enabled:
+ if not self.account_validity_renew_by_email_enabled:
raise AuthError(
403, "Account renewal via email is disabled on this server."
)
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index ea68114026..36807458de 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
-# Copyright 2015 - 2016 OpenMarket Ltd
-# Copyright 2017 Vector Creations Ltd
+# Copyright 2015-2016 OpenMarket Ltd
+# Copyright 2017-2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,6 +18,7 @@
import hmac
import logging
import random
+import re
from typing import List, Union
import synapse
@@ -118,10 +120,10 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
send_attempt = body["send_attempt"]
next_link = body.get("next_link") # Optional param
- if not check_3pid_allowed(self.hs, "email", email):
+ if not (await check_3pid_allowed(self.hs, "email", body["email"])):
raise SynapseError(
403,
- "Your email domain is not authorized to register on this server",
+ "Your email is not authorized to register on this server",
Codes.THREEPID_DENIED,
)
@@ -135,7 +137,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
# comments for request_token_inhibit_3pid_errors.
# Also wait for some random amount of time between 100ms and 1s to make it
# look like we did something.
- await self.hs.clock.sleep(random.randint(1, 10) / 10)
+ await self.hs.get_clock().sleep(random.randint(1, 10) / 10)
return 200, {"sid": random_string(16)}
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
@@ -197,7 +199,9 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
msisdn = phone_number_to_msisdn(country, phone_number)
- if not check_3pid_allowed(self.hs, "msisdn", msisdn):
+ assert_valid_client_secret(body["client_secret"])
+
+ if not (await check_3pid_allowed(self.hs, "msisdn", msisdn)):
raise SynapseError(
403,
"Phone numbers are not authorized to register on this server",
@@ -214,7 +218,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
# comments for request_token_inhibit_3pid_errors.
# Also wait for some random amount of time between 100ms and 1s to make it
# look like we did something.
- await self.hs.clock.sleep(random.randint(1, 10) / 10)
+ await self.hs.get_clock().sleep(random.randint(1, 10) / 10)
return 200, {"sid": random_string(16)}
raise SynapseError(
@@ -353,15 +357,9 @@ class UsernameAvailabilityRestServlet(RestServlet):
403, "Registration has been disabled", errcode=Codes.FORBIDDEN
)
- ip = self.hs.get_ip_from_request(request)
- with self.ratelimiter.ratelimit(ip) as wait_deferred:
- await wait_deferred
-
- username = parse_string(request, "username", required=True)
-
- await self.registration_handler.check_username(username)
-
- return 200, {"available": True}
+ # We are not interested in logging in via a username in this deployment.
+ # Simply allow anything here as it won't be used later.
+ return 200, {"available": True}
class RegisterRestServlet(RestServlet):
@@ -411,18 +409,27 @@ class RegisterRestServlet(RestServlet):
"Do not understand membership kind: %s" % (kind.decode("utf8"),)
)
- # Pull out the provided username and do basic sanity checks early since
- # the auth layer will store these in sessions.
+ # We don't care about usernames for this deployment. In fact, the act
+ # of checking whether they exist already can leak metadata about
+ # which users are already registered.
+ #
+ # Usernames are already derived via the provided email.
+ # So, if they're not necessary, just ignore them.
+ #
+ # (we do still allow appservices to set them below)
desired_username = None
- if "username" in body:
- if not isinstance(body["username"], str) or len(body["username"]) > 512:
- raise SynapseError(400, "Invalid username")
- desired_username = body["username"]
+
+ desired_display_name = body.get("display_name")
appservice = None
if self.auth.has_access_token(request):
appservice = self.auth.get_appservice_by_req(request)
+ # We need to retrieve the password early in order to pass it to
+ # application service registration
+ # This is specific to shadow server registration of users via an AS
+ password = body.pop("password", None)
+
# fork off as soon as possible for ASes which have completely
# different registration flows to normal users
@@ -431,7 +438,7 @@ class RegisterRestServlet(RestServlet):
# Set the desired user according to the AS API (which uses the
# 'user' key not 'username'). Since this is a new addition, we'll
# fallback to 'username' if they gave one.
- desired_username = body.get("user", desired_username)
+ desired_username = body.get("user", body.get("username"))
# XXX we should check that desired_username is valid. Currently
# we give appservices carte blanche for any insanity in mxids,
@@ -444,7 +451,7 @@ class RegisterRestServlet(RestServlet):
raise SynapseError(400, "Desired Username is missing or not a string")
result = await self._do_appservice_registration(
- desired_username, access_token, body
+ desired_username, password, desired_display_name, access_token, body
)
return 200, result
@@ -453,16 +460,6 @@ class RegisterRestServlet(RestServlet):
if not self._registration_enabled:
raise SynapseError(403, "Registration has been disabled")
- # For regular registration, convert the provided username to lowercase
- # before attempting to register it. This should mean that people who try
- # to register with upper-case in their usernames don't get a nasty surprise.
- #
- # Note that we treat usernames case-insensitively in login, so they are
- # free to carry on imagining that their username is CrAzYh4cKeR if that
- # keeps them happy.
- if desired_username is not None:
- desired_username = desired_username.lower()
-
# Check if this account is upgrading from a guest account.
guest_access_token = body.get("guest_access_token", None)
@@ -471,7 +468,6 @@ class RegisterRestServlet(RestServlet):
# Note that we remove the password from the body since the auth layer
# will store the body in the session and we don't want a plaintext
# password store there.
- password = body.pop("password", None)
if password is not None:
if not isinstance(password, str) or len(password) > 512:
raise SynapseError(400, "Invalid password")
@@ -501,17 +497,10 @@ class RegisterRestServlet(RestServlet):
session_id, "password_hash", None
)
- # Ensure that the username is valid.
- if desired_username is not None:
- await self.registration_handler.check_username(
- desired_username,
- guest_access_token=guest_access_token,
- assigned_user_id=registered_user_id,
- )
-
# Check if the user-interactive authentication flows are complete, if
# not this will raise a user-interactive auth error.
try:
+ print("Trying")
auth_result, params, session_id = await self.auth_handler.check_ui_auth(
self._registration_flows,
request,
@@ -547,7 +536,7 @@ class RegisterRestServlet(RestServlet):
medium = auth_result[login_type]["medium"]
address = auth_result[login_type]["address"]
- if not check_3pid_allowed(self.hs, medium, address):
+ if not (await check_3pid_allowed(self.hs, medium, address)):
raise SynapseError(
403,
"Third party identifiers (email/phone numbers)"
@@ -555,6 +544,80 @@ class RegisterRestServlet(RestServlet):
Codes.THREEPID_DENIED,
)
+ existingUid = await self.store.get_user_id_by_threepid(
+ medium, address
+ )
+
+ if existingUid is not None:
+ raise SynapseError(
+ 400, "%s is already in use" % medium, Codes.THREEPID_IN_USE
+ )
+
+ if self.hs.config.register_mxid_from_3pid:
+ # override the desired_username based on the 3PID if any.
+ # reset it first to avoid folks picking their own username.
+ desired_username = None
+
+ # we should have an auth_result at this point if we're going to progress
+ # to register the user (i.e. we haven't picked up a registered_user_id
+ # from our session store), in which case get ready and gen the
+ # desired_username
+ if auth_result:
+ if (
+ self.hs.config.register_mxid_from_3pid == "email"
+ and LoginType.EMAIL_IDENTITY in auth_result
+ ):
+ address = auth_result[LoginType.EMAIL_IDENTITY]["address"]
+ desired_username = synapse.types.strip_invalid_mxid_characters(
+ address.replace("@", "-").lower()
+ )
+
+ # find a unique mxid for the account, suffixing numbers
+ # if needed
+ while True:
+ try:
+ await self.registration_handler.check_username(
+ desired_username,
+ guest_access_token=guest_access_token,
+ assigned_user_id=registered_user_id,
+ )
+ # if we got this far we passed the check.
+ break
+ except SynapseError as e:
+ if e.errcode == Codes.USER_IN_USE:
+ m = re.match(r"^(.*?)(\d+)$", desired_username)
+ if m:
+ desired_username = m.group(1) + str(
+ int(m.group(2)) + 1
+ )
+ else:
+ desired_username += "1"
+ else:
+ # something else went wrong.
+ break
+
+ if self.hs.config.register_just_use_email_for_display_name:
+ desired_display_name = address
+ else:
+ # Custom mapping between email address and display name
+ desired_display_name = _map_email_to_displayname(address)
+ elif (
+ self.hs.config.register_mxid_from_3pid == "msisdn"
+ and LoginType.MSISDN in auth_result
+ ):
+ desired_username = auth_result[LoginType.MSISDN]["address"]
+ else:
+ raise SynapseError(
+ 400, "Cannot derive mxid from 3pid; no recognised 3pid"
+ )
+
+ if desired_username is not None:
+ await self.registration_handler.check_username(
+ desired_username,
+ guest_access_token=guest_access_token,
+ assigned_user_id=registered_user_id,
+ )
+
if registered_user_id is not None:
logger.info(
"Already registered user ID %r for this session", registered_user_id
@@ -569,7 +632,12 @@ class RegisterRestServlet(RestServlet):
if not password_hash:
raise SynapseError(400, "Missing params: password", Codes.MISSING_PARAM)
- desired_username = params.get("username", None)
+ if not self.hs.config.register_mxid_from_3pid:
+ desired_username = params.get("username", None)
+ else:
+ # we keep the original desired_username derived from the 3pid above
+ pass
+
guest_access_token = params.get("guest_access_token", None)
if desired_username is not None:
@@ -618,6 +686,7 @@ class RegisterRestServlet(RestServlet):
localpart=desired_username,
password_hash=password_hash,
guest_access_token=guest_access_token,
+ default_display_name=desired_display_name,
threepid=threepid,
address=client_addr,
user_agent_ips=entries,
@@ -630,6 +699,14 @@ class RegisterRestServlet(RestServlet):
):
await self.store.upsert_monthly_active_user(registered_user_id)
+ if self.hs.config.shadow_server:
+ await self.registration_handler.shadow_register(
+ localpart=desired_username,
+ display_name=desired_display_name,
+ auth_result=auth_result,
+ params=params,
+ )
+
# Remember that the user account has been registered (and the user
# ID it was registered with, since it might not have been specified).
await self.auth_handler.set_session_data(
@@ -651,11 +728,37 @@ class RegisterRestServlet(RestServlet):
return 200, return_dict
- async def _do_appservice_registration(self, username, as_token, body):
+ async def _do_appservice_registration(
+ self, username, password, display_name, as_token, body
+ ):
+ # FIXME: appservice_register() is horribly duplicated with register()
+ # and they should probably just be combined together with a config flag.
+
+ if password:
+ # Hash the password
+ #
+ # In mainline hashing of the password was moved further on in the registration
+ # flow, but we need it here for the AS use-case of shadow servers
+ password = await self.auth_handler.hash(password)
user_id = await self.registration_handler.appservice_register(
- username, as_token
+ username, as_token, password, display_name
)
- return await self._create_registration_details(user_id, body)
+ result = await self._create_registration_details(user_id, body)
+
+ auth_result = body.get("auth_result")
+ if auth_result and LoginType.EMAIL_IDENTITY in auth_result:
+ threepid = auth_result[LoginType.EMAIL_IDENTITY]
+ await self.registration_handler.register_email_threepid(
+ user_id, threepid, result["access_token"]
+ )
+
+ if auth_result and LoginType.MSISDN in auth_result:
+ threepid = auth_result[LoginType.MSISDN]
+ await self.registration_handler.register_msisdn_threepid(
+ user_id, threepid, result["access_token"]
+ )
+
+ return result
async def _create_registration_details(self, user_id, params):
"""Complete registration of newly-registered user
@@ -706,6 +809,60 @@ class RegisterRestServlet(RestServlet):
)
+def cap(name):
+ """Capitalise parts of a name containing different words, including those
+ separated by hyphens.
+ For example, 'John-Doe'
+
+ Args:
+ name (str): The name to parse
+ """
+ if not name:
+ return name
+
+ # Split the name by whitespace then hyphens, capitalizing each part then
+ # joining it back together.
+ capatilized_name = " ".join(
+ "-".join(part.capitalize() for part in space_part.split("-"))
+ for space_part in name.split()
+ )
+ return capatilized_name
+
+
+def _map_email_to_displayname(address):
+ """Custom mapping from an email address to a user displayname
+
+ Args:
+ address (str): The email address to process
+ Returns:
+ str: The new displayname
+ """
+ # Split the part before and after the @ in the email.
+ # Replace all . with spaces in the first part
+ parts = address.replace(".", " ").split("@")
+
+ # Figure out which org this email address belongs to
+ org_parts = parts[1].split(" ")
+
+ # If this is a ...matrix.org email, mark them as an Admin
+ if org_parts[-2] == "matrix" and org_parts[-1] == "org":
+ org = "Tchap Admin"
+
+ # Is this is a ...gouv.fr address, set the org to whatever is before
+ # gouv.fr. If there isn't anything (a @gouv.fr email) simply mark their
+ # org as "gouv"
+ elif org_parts[-2] == "gouv" and org_parts[-1] == "fr":
+ org = org_parts[-3] if len(org_parts) > 2 else org_parts[-2]
+
+ # Otherwise, mark their org as the email's second-level domain name
+ else:
+ org = org_parts[-2]
+
+ desired_display_name = cap(parts[0]) + " [" + cap(org) + "]"
+
+ return desired_display_name
+
+
def _calculate_registration_flows(
# technically `config` has to provide *all* of these interfaces, not just one
config: Union[RegistrationConfig, ConsentConfig, CaptchaConfig],
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index 2b84eb89c0..8e52e4cca4 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -171,6 +171,7 @@ class SyncRestServlet(RestServlet):
)
with context:
sync_result = await self.sync_handler.wait_for_sync_for_user(
+ requester,
sync_config,
since_token=since_token,
timeout=timeout,
diff --git a/synapse/rest/client/v2_alpha/user_directory.py b/synapse/rest/client/v2_alpha/user_directory.py
index ad598cefe0..eeddfa31f8 100644
--- a/synapse/rest/client/v2_alpha/user_directory.py
+++ b/synapse/rest/client/v2_alpha/user_directory.py
@@ -14,9 +14,17 @@
# limitations under the License.
import logging
+from typing import Dict
-from synapse.api.errors import SynapseError
-from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from signedjson.sign import sign_json
+
+from synapse.api.errors import Codes, SynapseError
+from synapse.http.servlet import (
+ RestServlet,
+ assert_params_in_dict,
+ parse_json_object_from_request,
+)
+from synapse.types import UserID
from ._base import client_patterns
@@ -35,6 +43,7 @@ class UserDirectorySearchRestServlet(RestServlet):
self.hs = hs
self.auth = hs.get_auth()
self.user_directory_handler = hs.get_user_directory_handler()
+ self.http_client = hs.get_simple_http_client()
async def on_POST(self, request):
"""Searches for users in directory
@@ -61,6 +70,16 @@ class UserDirectorySearchRestServlet(RestServlet):
body = parse_json_object_from_request(request)
+ if self.hs.config.user_directory_defer_to_id_server:
+ signed_body = sign_json(
+ body, self.hs.hostname, self.hs.config.signing_key[0]
+ )
+ url = "%s/_matrix/identity/api/v1/user_directory/search" % (
+ self.hs.config.user_directory_defer_to_id_server,
+ )
+ resp = await self.http_client.post_json_get_json(url, signed_body)
+ return 200, resp
+
limit = body.get("limit", 10)
limit = min(limit, 50)
@@ -76,5 +95,124 @@ class UserDirectorySearchRestServlet(RestServlet):
return 200, results
+class SingleUserInfoServlet(RestServlet):
+ """
+ Deprecated and replaced by `/users/info`
+
+ GET /user/{user_id}/info HTTP/1.1
+ """
+
+ PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/info$")
+
+ def __init__(self, hs):
+ super(SingleUserInfoServlet, self).__init__()
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
+ self.transport_layer = hs.get_federation_transport_client()
+ registry = hs.get_federation_registry()
+
+ if not registry.query_handlers.get("user_info"):
+ registry.register_query_handler("user_info", self._on_federation_query)
+
+ async def on_GET(self, request, user_id):
+ # Ensure the user is authenticated
+ await self.auth.get_user_by_req(request)
+
+ user = UserID.from_string(user_id)
+ if not self.hs.is_mine(user):
+ # Attempt to make a federation request to the server that owns this user
+ args = {"user_id": user_id}
+ res = await self.transport_layer.make_query(
+ user.domain, "user_info", args, retry_on_dns_fail=True
+ )
+ return 200, res
+
+ user_id_to_info = await self.store.get_info_for_users([user_id])
+ return 200, user_id_to_info[user_id]
+
+ async def _on_federation_query(self, args):
+ """Called when a request for user information appears over federation
+
+ Args:
+ args (dict): Dictionary of query arguments provided by the request
+
+ Returns:
+ Deferred[dict]: Deactivation and expiration information for a given user
+ """
+ user_id = args.get("user_id")
+ if not user_id:
+ raise SynapseError(400, "user_id not provided")
+
+ user = UserID.from_string(user_id)
+ if not self.hs.is_mine(user):
+ raise SynapseError(400, "User is not hosted on this homeserver")
+
+ user_ids_to_info_dict = await self.store.get_info_for_users([user_id])
+ return user_ids_to_info_dict[user_id]
+
+
+class UserInfoServlet(RestServlet):
+ """Bulk version of `/user/{user_id}/info` endpoint
+
+ GET /users/info HTTP/1.1
+
+ Returns a dictionary of user_id to info dictionary. Supports remote users
+ """
+
+ PATTERNS = client_patterns("/users/info$", unstable=True, releases=())
+
+ def __init__(self, hs):
+ super(UserInfoServlet, self).__init__()
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
+ self.transport_layer = hs.get_federation_transport_client()
+
+ async def on_POST(self, request):
+ # Ensure the user is authenticated
+ await self.auth.get_user_by_req(request)
+
+ # Extract the user_ids from the request
+ body = parse_json_object_from_request(request)
+ assert_params_in_dict(body, required=["user_ids"])
+
+ user_ids = body["user_ids"]
+ if not isinstance(user_ids, list):
+ raise SynapseError(
+ 400,
+ "'user_ids' must be a list of user ID strings",
+ errcode=Codes.INVALID_PARAM,
+ )
+
+ # Separate local and remote users
+ local_user_ids = set()
+ remote_server_to_user_ids = {} # type: Dict[str, set]
+ for user_id in user_ids:
+ user = UserID.from_string(user_id)
+
+ if self.hs.is_mine(user):
+ local_user_ids.add(user_id)
+ else:
+ remote_server_to_user_ids.setdefault(user.domain, set())
+ remote_server_to_user_ids[user.domain].add(user_id)
+
+ # Retrieve info of all local users
+ user_id_to_info_dict = await self.store.get_info_for_users(local_user_ids)
+
+ # Request info of each remote user from their remote homeserver
+ for server_name, user_id_set in remote_server_to_user_ids.items():
+ # Make a request to the given server about their own users
+ res = await self.transport_layer.get_info_of_users(
+ server_name, list(user_id_set)
+ )
+
+ user_id_to_info_dict.update(res)
+
+ return 200, user_id_to_info_dict
+
+
def register_servlets(hs, http_server):
UserDirectorySearchRestServlet(hs).register(http_server)
+ SingleUserInfoServlet(hs).register(http_server)
+ UserInfoServlet(hs).register(http_server)
diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index d24a199318..c9b9e7f5ff 100644
--- a/synapse/rest/client/versions.py
+++ b/synapse/rest/client/versions.py
@@ -72,9 +72,12 @@ class VersionsRestServlet(RestServlet):
# MSC2326.
"org.matrix.label_based_filtering": True,
# Implements support for cross signing as described in MSC1756
- "org.matrix.e2e_cross_signing": True,
+ # "org.matrix.e2e_cross_signing": True,
# Implements additional endpoints as described in MSC2432
"org.matrix.msc2432": True,
+ # Tchap does not currently assume this rule for r0.5.0
+ # XXX: Remove this when it does
+ "m.lazy_load_members": True,
# Implements additional endpoints as described in MSC2666
"uk.half-shot.msc2666": True,
# Whether new rooms will be set to encrypted or not (based on presets).
diff --git a/synapse/rest/key/v2/local_key_resource.py b/synapse/rest/key/v2/local_key_resource.py
index c16280f668..d8e8e48c1c 100644
--- a/synapse/rest/key/v2/local_key_resource.py
+++ b/synapse/rest/key/v2/local_key_resource.py
@@ -66,7 +66,7 @@ class LocalKey(Resource):
def __init__(self, hs):
self.config = hs.config
- self.clock = hs.clock
+ self.clock = hs.get_clock()
self.update_response_body(self.clock.time_msec())
Resource.__init__(self)
diff --git a/synapse/rulecheck/__init__.py b/synapse/rulecheck/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/synapse/rulecheck/__init__.py
diff --git a/synapse/rulecheck/domain_rule_checker.py b/synapse/rulecheck/domain_rule_checker.py
new file mode 100644
index 0000000000..6f2a1931c5
--- /dev/null
+++ b/synapse/rulecheck/domain_rule_checker.py
@@ -0,0 +1,181 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from synapse.config._base import ConfigError
+
+logger = logging.getLogger(__name__)
+
+
+class DomainRuleChecker(object):
+ """
+ A re-implementation of the SpamChecker that prevents users in one domain from
+ inviting users in other domains to rooms, based on a configuration.
+
+ Takes a config in the format:
+
+ spam_checker:
+ module: "rulecheck.DomainRuleChecker"
+ config:
+ domain_mapping:
+ "inviter_domain": [ "invitee_domain_permitted", "other_domain_permitted" ]
+ "other_inviter_domain": [ "invitee_domain_permitted" ]
+ default: False
+
+ # Only let local users join rooms if they were explicitly invited.
+ can_only_join_rooms_with_invite: false
+
+ # Only let local users create rooms if they are inviting only one
+ # other user, and that user matches the rules above.
+ can_only_create_one_to_one_rooms: false
+
+ # Only let local users invite during room creation, regardless of the
+ # domain mapping rules above.
+ can_only_invite_during_room_creation: false
+
+ # Prevent local users from inviting users from certain domains to
+ # rooms published in the room directory.
+ domains_prevented_from_being_invited_to_published_rooms: []
+
+ # Allow third party invites
+ can_invite_by_third_party_id: true
+
+ Don't forget to consider if you can invite users from your own domain.
+ """
+
+ def __init__(self, config):
+ self.domain_mapping = config["domain_mapping"] or {}
+ self.default = config["default"]
+
+ self.can_only_join_rooms_with_invite = config.get(
+ "can_only_join_rooms_with_invite", False
+ )
+ self.can_only_create_one_to_one_rooms = config.get(
+ "can_only_create_one_to_one_rooms", False
+ )
+ self.can_only_invite_during_room_creation = config.get(
+ "can_only_invite_during_room_creation", False
+ )
+ self.can_invite_by_third_party_id = config.get(
+ "can_invite_by_third_party_id", True
+ )
+ self.domains_prevented_from_being_invited_to_published_rooms = config.get(
+ "domains_prevented_from_being_invited_to_published_rooms", []
+ )
+
+ def check_event_for_spam(self, event):
+ """Implements synapse.events.SpamChecker.check_event_for_spam
+ """
+ return False
+
+ def user_may_invite(
+ self,
+ inviter_userid,
+ invitee_userid,
+ third_party_invite,
+ room_id,
+ new_room,
+ published_room=False,
+ ):
+ """Implements synapse.events.SpamChecker.user_may_invite
+ """
+ if self.can_only_invite_during_room_creation and not new_room:
+ return False
+
+ if not self.can_invite_by_third_party_id and third_party_invite:
+ return False
+
+ # This is a third party invite (without a bound mxid), so unless we have
+ # banned all third party invites (above) we allow it.
+ if not invitee_userid:
+ return True
+
+ inviter_domain = self._get_domain_from_id(inviter_userid)
+ invitee_domain = self._get_domain_from_id(invitee_userid)
+
+ if inviter_domain not in self.domain_mapping:
+ return self.default
+
+ if (
+ published_room
+ and invitee_domain
+ in self.domains_prevented_from_being_invited_to_published_rooms
+ ):
+ return False
+
+ return invitee_domain in self.domain_mapping[inviter_domain]
+
+ def user_may_create_room(
+ self, userid, invite_list, third_party_invite_list, cloning
+ ):
+ """Implements synapse.events.SpamChecker.user_may_create_room
+ """
+
+ if cloning:
+ return True
+
+ if not self.can_invite_by_third_party_id and third_party_invite_list:
+ return False
+
+ number_of_invites = len(invite_list) + len(third_party_invite_list)
+
+ if self.can_only_create_one_to_one_rooms and number_of_invites != 1:
+ return False
+
+ return True
+
+ def user_may_create_room_alias(self, userid, room_alias):
+ """Implements synapse.events.SpamChecker.user_may_create_room_alias
+ """
+ return True
+
+ def user_may_publish_room(self, userid, room_id):
+ """Implements synapse.events.SpamChecker.user_may_publish_room
+ """
+ return True
+
+ def user_may_join_room(self, userid, room_id, is_invited):
+ """Implements synapse.events.SpamChecker.user_may_join_room
+ """
+ if self.can_only_join_rooms_with_invite and not is_invited:
+ return False
+
+ return True
+
+ @staticmethod
+ def parse_config(config):
+ """Implements synapse.events.SpamChecker.parse_config
+ """
+ if "default" in config:
+ return config
+ else:
+ raise ConfigError("No default set for spam_config DomainRuleChecker")
+
+ @staticmethod
+ def _get_domain_from_id(mxid):
+ """Parses a string and returns the domain part of the mxid.
+
+ Args:
+ mxid (str): a valid mxid
+
+ Returns:
+ str: the domain part of the mxid
+
+ """
+ idx = mxid.find(":")
+ if idx == -1:
+ raise Exception("Invalid ID: %r" % (mxid,))
+ return mxid[idx + 1 :]
diff --git a/synapse/server.py b/synapse/server.py
index 21a232bbd9..b017e3489f 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -27,7 +27,8 @@ import logging
import os
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, TypeVar, cast
-import twisted
+import twisted.internet.base
+import twisted.internet.tcp
from twisted.mail.smtp import sendmail
from twisted.web.iweb import IPolicyForHTTPS
@@ -89,6 +90,7 @@ from synapse.handlers.room_member import RoomMemberMasterHandler
from synapse.handlers.room_member_worker import RoomMemberWorkerHandler
from synapse.handlers.search import SearchHandler
from synapse.handlers.set_password import SetPasswordHandler
+from synapse.handlers.sso import SsoHandler
from synapse.handlers.stats import StatsHandler
from synapse.handlers.sync import SyncHandler
from synapse.handlers.typing import FollowerTypingHandler, TypingWriterHandler
@@ -145,7 +147,8 @@ def cache_in_self(builder: T) -> T:
"@cache_in_self can only be used on functions starting with `get_`"
)
- depname = builder.__name__[len("get_") :]
+ # get_attr -> _attr
+ depname = builder.__name__[len("get") :]
building = [False]
@@ -233,15 +236,6 @@ class HomeServer(metaclass=abc.ABCMeta):
self._instance_id = random_string(5)
self._instance_name = config.worker_name or "master"
- self.clock = Clock(reactor)
- self.distributor = Distributor()
-
- self.registration_ratelimiter = Ratelimiter(
- clock=self.clock,
- rate_hz=config.rc_registration.per_second,
- burst_count=config.rc_registration.burst_count,
- )
-
self.version_string = version_string
self.datastores = None # type: Optional[Databases]
@@ -299,8 +293,9 @@ class HomeServer(metaclass=abc.ABCMeta):
def is_mine_id(self, string: str) -> bool:
return string.split(":", 1)[1] == self.hostname
+ @cache_in_self
def get_clock(self) -> Clock:
- return self.clock
+ return Clock(self._reactor)
def get_datastore(self) -> DataStore:
if not self.datastores:
@@ -317,11 +312,17 @@ class HomeServer(metaclass=abc.ABCMeta):
def get_config(self) -> HomeServerConfig:
return self.config
+ @cache_in_self
def get_distributor(self) -> Distributor:
- return self.distributor
+ return Distributor()
+ @cache_in_self
def get_registration_ratelimiter(self) -> Ratelimiter:
- return self.registration_ratelimiter
+ return Ratelimiter(
+ clock=self.get_clock(),
+ rate_hz=self.config.rc_registration.per_second,
+ burst_count=self.config.rc_registration.burst_count,
+ )
@cache_in_self
def get_federation_client(self) -> FederationClient:
@@ -391,6 +392,10 @@ class HomeServer(metaclass=abc.ABCMeta):
return FollowerTypingHandler(self)
@cache_in_self
+ def get_sso_handler(self) -> SsoHandler:
+ return SsoHandler(self)
+
+ @cache_in_self
def get_sync_handler(self) -> SyncHandler:
return SyncHandler(self)
@@ -681,7 +686,7 @@ class HomeServer(metaclass=abc.ABCMeta):
@cache_in_self
def get_federation_ratelimiter(self) -> FederationRateLimiter:
- return FederationRateLimiter(self.clock, config=self.config.rc_federation)
+ return FederationRateLimiter(self.get_clock(), config=self.config.rc_federation)
@cache_in_self
def get_module_api(self) -> ModuleApi:
diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py
index d464c75c03..100dbd5e2c 100644
--- a/synapse/server_notices/server_notices_manager.py
+++ b/synapse/server_notices/server_notices_manager.py
@@ -39,6 +39,7 @@ class ServerNoticesManager:
self._room_member_handler = hs.get_room_member_handler()
self._event_creation_handler = hs.get_event_creation_handler()
self._is_mine_id = hs.is_mine_id
+ self._server_name = hs.hostname
self._notifier = hs.get_notifier()
self.server_notices_mxid = self._config.server_notices_mxid
@@ -72,7 +73,9 @@ class ServerNoticesManager:
await self.maybe_invite_user_to_room(user_id, room_id)
system_mxid = self._config.server_notices_mxid
- requester = create_requester(system_mxid)
+ requester = create_requester(
+ system_mxid, authenticated_entity=self._server_name
+ )
logger.info("Sending server notice to %s", user_id)
@@ -145,7 +148,9 @@ class ServerNoticesManager:
"avatar_url": self._config.server_notices_mxid_avatar_url,
}
- requester = create_requester(self.server_notices_mxid)
+ requester = create_requester(
+ self.server_notices_mxid, authenticated_entity=self._server_name
+ )
info, _ = await self._room_creation_handler.create_room(
requester,
config={
@@ -174,7 +179,9 @@ class ServerNoticesManager:
user_id: The ID of the user to invite.
room_id: The ID of the room to invite the user to.
"""
- requester = create_requester(self.server_notices_mxid)
+ requester = create_requester(
+ self.server_notices_mxid, authenticated_entity=self._server_name
+ )
# Check whether the user has already joined or been invited to this room. If
# that's the case, there is no need to re-invite them.
diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py
index 0e25ca3d7a..2570003457 100644
--- a/synapse/storage/databases/main/profile.py
+++ b/synapse/storage/databases/main/profile.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,11 +13,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Tuple
from synapse.api.errors import StoreError
from synapse.storage._base import SQLBaseStore
from synapse.storage.databases.main.roommember import ProfileInfo
+from synapse.types import UserID
+from synapse.util.caches.descriptors import cached
+
+BATCH_SIZE = 100
class ProfileWorkerStore(SQLBaseStore):
@@ -39,6 +44,7 @@ class ProfileWorkerStore(SQLBaseStore):
avatar_url=profile["avatar_url"], display_name=profile["displayname"]
)
+ @cached(max_entries=5000)
async def get_profile_displayname(self, user_localpart: str) -> Optional[str]:
return await self.db_pool.simple_select_one_onecol(
table="profiles",
@@ -47,6 +53,7 @@ class ProfileWorkerStore(SQLBaseStore):
desc="get_profile_displayname",
)
+ @cached(max_entries=5000)
async def get_profile_avatar_url(self, user_localpart: str) -> Optional[str]:
return await self.db_pool.simple_select_one_onecol(
table="profiles",
@@ -55,6 +62,58 @@ class ProfileWorkerStore(SQLBaseStore):
desc="get_profile_avatar_url",
)
+ async def get_latest_profile_replication_batch_number(self):
+ def f(txn):
+ txn.execute("SELECT MAX(batch) as maxbatch FROM profiles")
+ rows = self.db_pool.cursor_to_dict(txn)
+ return rows[0]["maxbatch"]
+
+ return await self.db_pool.runInteraction(
+ "get_latest_profile_replication_batch_number", f
+ )
+
+ async def get_profile_batch(self, batchnum):
+ return await self.db_pool.simple_select_list(
+ table="profiles",
+ keyvalues={"batch": batchnum},
+ retcols=("user_id", "displayname", "avatar_url", "active"),
+ desc="get_profile_batch",
+ )
+
+ async def assign_profile_batch(self):
+ def f(txn):
+ sql = (
+ "UPDATE profiles SET batch = "
+ "(SELECT COALESCE(MAX(batch), -1) + 1 FROM profiles) "
+ "WHERE user_id in ("
+ " SELECT user_id FROM profiles WHERE batch is NULL limit ?"
+ ")"
+ )
+ txn.execute(sql, (BATCH_SIZE,))
+ return txn.rowcount
+
+ return await self.db_pool.runInteraction("assign_profile_batch", f)
+
+ async def get_replication_hosts(self):
+ def f(txn):
+ txn.execute(
+ "SELECT host, last_synced_batch FROM profile_replication_status"
+ )
+ rows = self.db_pool.cursor_to_dict(txn)
+ return {r["host"]: r["last_synced_batch"] for r in rows}
+
+ return await self.db_pool.runInteraction("get_replication_hosts", f)
+
+ async def update_replication_batch_for_host(
+ self, host: str, last_synced_batch: int
+ ):
+ return await self.db_pool.simple_upsert(
+ table="profile_replication_status",
+ keyvalues={"host": host},
+ values={"last_synced_batch": last_synced_batch},
+ desc="update_replication_batch_for_host",
+ )
+
async def get_from_remote_profile_cache(
self, user_id: str
) -> Optional[Dict[str, Any]]:
@@ -72,32 +131,95 @@ class ProfileWorkerStore(SQLBaseStore):
)
async def set_profile_displayname(
- self, user_localpart: str, new_displayname: Optional[str]
+ self, user_localpart: str, new_displayname: Optional[str], batchnum: int
) -> None:
- await self.db_pool.simple_update_one(
+ # Invalidate the read cache for this user
+ self.get_profile_displayname.invalidate((user_localpart,))
+
+ await self.db_pool.simple_upsert(
table="profiles",
keyvalues={"user_id": user_localpart},
- updatevalues={"displayname": new_displayname},
+ values={"displayname": new_displayname, "batch": batchnum},
desc="set_profile_displayname",
+ lock=False, # we can do this because user_id has a unique index
)
async def set_profile_avatar_url(
- self, user_localpart: str, new_avatar_url: str
+ self, user_localpart: str, new_avatar_url: str, batchnum: int
) -> None:
- await self.db_pool.simple_update_one(
+ # Invalidate the read cache for this user
+ self.get_profile_avatar_url.invalidate((user_localpart,))
+
+ await self.db_pool.simple_upsert(
table="profiles",
keyvalues={"user_id": user_localpart},
- updatevalues={"avatar_url": new_avatar_url},
+ values={"avatar_url": new_avatar_url, "batch": batchnum},
desc="set_profile_avatar_url",
+ lock=False, # we can do this because user_id has a unique index
+ )
+
+ async def set_profiles_active(
+ self, users: List[UserID], active: bool, hide: bool, batchnum: int,
+ ) -> None:
+ """Given a set of users, set active and hidden flags on them.
+
+ Args:
+ users: A list of UserIDs
+ active: Whether to set the users to active or inactive
+ hide: Whether to hide the users (withold from replication). If
+ False and active is False, users will have their profiles
+ erased
+ batchnum: The batch number, used for profile replication
+ """
+ # Convert list of localparts to list of tuples containing localparts
+ user_localparts = [(user.localpart,) for user in users]
+
+ # Generate list of value tuples for each user
+ value_names = ("active", "batch")
+ values = [(int(active), batchnum) for _ in user_localparts] # type: List[Tuple]
+
+ if not active and not hide:
+ # we are deactivating for real (not in hide mode)
+ # so clear the profile information
+ value_names += ("avatar_url", "displayname")
+ values = [v + (None, None) for v in values]
+
+ return await self.db_pool.runInteraction(
+ "set_profiles_active",
+ self.db_pool.simple_upsert_many_txn,
+ table="profiles",
+ key_names=("user_id",),
+ key_values=user_localparts,
+ value_names=value_names,
+ value_values=values,
+ )
+
+ async def add_remote_profile_cache(
+ self, user_id: str, displayname: str, avatar_url: str
+ ) -> None:
+ """Ensure we are caching the remote user's profiles.
+
+ This should only be called when `is_subscribed_remote_profile_for_user`
+ would return true for the user.
+ """
+ await self.db_pool.simple_upsert(
+ table="remote_profile_cache",
+ keyvalues={"user_id": user_id},
+ values={
+ "displayname": displayname,
+ "avatar_url": avatar_url,
+ "last_check": self._clock.time_msec(),
+ },
+ desc="add_remote_profile_cache",
)
async def update_remote_profile_cache(
self, user_id: str, displayname: str, avatar_url: str
) -> int:
- return await self.db_pool.simple_update(
+ return await self.db_pool.simple_upsert(
table="remote_profile_cache",
keyvalues={"user_id": user_id},
- updatevalues={
+ values={
"displayname": displayname,
"avatar_url": avatar_url,
"last_check": self._clock.time_msec(),
@@ -166,6 +288,17 @@ class ProfileWorkerStore(SQLBaseStore):
class ProfileStore(ProfileWorkerStore):
+ def __init__(self, database, db_conn, hs):
+ super().__init__(database, db_conn, hs)
+
+ self.db_pool.updates.register_background_index_update(
+ "profile_replication_status_host_index",
+ index_name="profile_replication_status_idx",
+ table="profile_replication_status",
+ columns=["host"],
+ unique=True,
+ )
+
async def add_remote_profile_cache(
self, user_id: str, displayname: str, avatar_url: str
) -> None:
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index ecfc6717b3..5d668aadb2 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -314,6 +314,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
for table in (
"event_auth",
"event_edges",
+ "event_json",
"event_push_actions_staging",
"event_reference_hashes",
"event_relations",
@@ -340,7 +341,6 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
"destination_rooms",
"event_backward_extremities",
"event_forward_extremities",
- "event_json",
"event_push_actions",
"event_search",
"events",
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index ca7917c989..1e7949a323 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -278,7 +278,8 @@ class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
async def get_linearized_receipts_for_all_rooms(
self, to_key: int, from_key: Optional[int] = None
) -> Dict[str, JsonDict]:
- """Get receipts for all rooms between two stream_ids.
+ """Get receipts for all rooms between two stream_ids, up
+ to a limit of the latest 100 read receipts.
Args:
to_key: Max stream id to fetch receipts upto.
@@ -294,12 +295,16 @@ class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
sql = """
SELECT * FROM receipts_linearized WHERE
stream_id > ? AND stream_id <= ?
+ ORDER BY stream_id DESC
+ LIMIT 100
"""
txn.execute(sql, [from_key, to_key])
else:
sql = """
SELECT * FROM receipts_linearized WHERE
stream_id <= ?
+ ORDER BY stream_id DESC
+ LIMIT 100
"""
txn.execute(sql, [to_key])
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index e5d07ce72a..1624738409 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -82,12 +82,18 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
database.engine, find_max_generated_user_id_localpart, "user_id_seq",
)
- self._account_validity = hs.config.account_validity
- if hs.config.run_background_tasks and self._account_validity.enabled:
- self._clock.call_later(
- 0.0, self._set_expiration_date_when_missing,
+ self._account_validity_enabled = hs.config.account_validity_enabled
+ if self._account_validity_enabled:
+ self._account_validity_period = hs.config.account_validity_period
+ self._account_validity_startup_job_max_delta = (
+ hs.config.account_validity_startup_job_max_delta
)
+ if hs.config.run_background_tasks:
+ self._clock.call_later(
+ 0.0, self._set_expiration_date_when_missing,
+ )
+
# Create a background job for culling expired 3PID validity tokens
if hs.config.run_background_tasks:
self._clock.looping_call(
@@ -183,6 +189,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
expiration_ts: int,
email_sent: bool,
renewal_token: Optional[str] = None,
+ token_used_ts: Optional[int] = None,
) -> None:
"""Updates the account validity properties of the given account, with the
given values.
@@ -196,6 +203,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
period.
renewal_token: Renewal token the user can use to extend the validity
of their account. Defaults to no token.
+ token_used_ts: A timestamp of when the current token was used to renew
+ the account.
"""
def set_account_validity_for_user_txn(txn):
@@ -207,6 +216,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
"expiration_ts_ms": expiration_ts,
"email_sent": email_sent,
"renewal_token": renewal_token,
+ "token_used_ts_ms": token_used_ts,
},
)
self._invalidate_cache_and_stream(
@@ -217,10 +227,41 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
"set_account_validity_for_user", set_account_validity_for_user_txn
)
+ async def get_expired_users(self):
+ """Get UserIDs of all expired users.
+
+ Users who are not active, or do not have profile information, are
+ excluded from the results.
+
+ Returns:
+ Deferred[List[UserID]]: List of expired user IDs
+ """
+
+ def get_expired_users_txn(txn, now_ms):
+ # We need to use pattern matching as profiles.user_id is confusingly just the
+ # user's localpart, whereas account_validity.user_id is a full user ID
+ sql = """
+ SELECT av.user_id from account_validity AS av
+ LEFT JOIN profiles as p
+ ON av.user_id LIKE '%%' || p.user_id || ':%%'
+ WHERE expiration_ts_ms <= ?
+ AND p.active = 1
+ """
+ txn.execute(sql, (now_ms,))
+ rows = txn.fetchall()
+
+ return [UserID.from_string(row[0]) for row in rows]
+
+ res = await self.db_pool.runInteraction(
+ "get_expired_users", get_expired_users_txn, self._clock.time_msec()
+ )
+
+ return res
+
async def set_renewal_token_for_user(
self, user_id: str, renewal_token: str
) -> None:
- """Defines a renewal token for a given user.
+ """Defines a renewal token for a given user, and clears the token_used timestamp.
Args:
user_id: ID of the user to set the renewal token for.
@@ -233,26 +274,40 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
await self.db_pool.simple_update_one(
table="account_validity",
keyvalues={"user_id": user_id},
- updatevalues={"renewal_token": renewal_token},
+ updatevalues={"renewal_token": renewal_token, "token_used_ts_ms": None},
desc="set_renewal_token_for_user",
)
- async def get_user_from_renewal_token(self, renewal_token: str) -> str:
- """Get a user ID from a renewal token.
+ async def get_user_from_renewal_token(
+ self, renewal_token: str
+ ) -> Tuple[str, int, Optional[int]]:
+ """Get a user ID and renewal status from a renewal token.
Args:
renewal_token: The renewal token to perform the lookup with.
Returns:
- The ID of the user to which the token belongs.
+ A tuple of containing the following values:
+ * The ID of a user to which the token belongs.
+ * An int representing the user's expiry timestamp as milliseconds since the
+ epoch, or 0 if the token was invalid.
+ * An optional int representing the timestamp of when the user renewed their
+ account timestamp as milliseconds since the epoch. None if the account
+ has not been renewed using the current token yet.
"""
- return await self.db_pool.simple_select_one_onecol(
+ ret_dict = await self.db_pool.simple_select_one(
table="account_validity",
keyvalues={"renewal_token": renewal_token},
- retcol="user_id",
+ retcols=["user_id", "expiration_ts_ms", "token_used_ts_ms"],
desc="get_user_from_renewal_token",
)
+ return (
+ ret_dict["user_id"],
+ ret_dict["expiration_ts_ms"],
+ ret_dict["token_used_ts_ms"],
+ )
+
async def get_renewal_token_for_user(self, user_id: str) -> str:
"""Get the renewal token associated with a given user ID.
@@ -291,7 +346,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
"get_users_expiring_soon",
select_users_txn,
self._clock.time_msec(),
- self.config.account_validity.renew_at,
+ self.config.account_validity_renew_at,
)
async def set_renewal_mail_status(self, user_id: str, email_sent: bool) -> None:
@@ -323,6 +378,54 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
desc="delete_account_validity_for_user",
)
+ async def get_info_for_users(
+ self, user_ids: List[str],
+ ):
+ """Return the user info for a given set of users
+
+ Args:
+ user_ids: A list of users to return information about
+
+ Returns:
+ Deferred[Dict[str, bool]]: A dictionary mapping each user ID to
+ a dict with the following keys:
+ * expired - whether this is an expired user
+ * deactivated - whether this is a deactivated user
+ """
+ # Get information of all our local users
+ def _get_info_for_users_txn(txn):
+ rows = []
+
+ for user_id in user_ids:
+ sql = """
+ SELECT u.name, u.deactivated, av.expiration_ts_ms
+ FROM users as u
+ LEFT JOIN account_validity as av
+ ON av.user_id = u.name
+ WHERE u.name = ?
+ """
+
+ txn.execute(sql, (user_id,))
+ row = txn.fetchone()
+ if row:
+ rows.append(row)
+
+ return rows
+
+ info_rows = await self.db_pool.runInteraction(
+ "get_info_for_users", _get_info_for_users_txn
+ )
+
+ return {
+ user_id: {
+ "expired": (
+ expiration is not None and self._clock.time_msec() >= expiration
+ ),
+ "deactivated": deactivated == 1,
+ }
+ for user_id, deactivated, expiration in info_rows
+ }
+
async def is_server_admin(self, user: UserID) -> bool:
"""Determines if a user is an admin of this homeserver.
@@ -885,11 +988,11 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
delta equal to 10% of the validity period.
"""
now_ms = self._clock.time_msec()
- expiration_ts = now_ms + self._account_validity.period
+ expiration_ts = now_ms + self._account_validity_period
if use_delta:
expiration_ts = self.rand.randrange(
- expiration_ts - self._account_validity.startup_job_max_delta,
+ expiration_ts - self._account_validity_startup_job_max_delta,
expiration_ts,
)
@@ -1110,6 +1213,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
token: str,
device_id: Optional[str],
valid_until_ms: Optional[int],
+ puppets_user_id: Optional[str] = None,
) -> int:
"""Adds an access token for the given user.
@@ -1133,6 +1237,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
"token": token,
"device_id": device_id,
"valid_until_ms": valid_until_ms,
+ "puppets_user_id": puppets_user_id,
},
desc="add_access_token_to_user",
)
@@ -1279,7 +1384,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
except self.database_engine.module.IntegrityError:
raise StoreError(400, "User ID already taken.", errcode=Codes.USER_IN_USE)
- if self._account_validity.enabled:
+ if self._account_validity_enabled:
self.set_expiration_date_for_user_txn(txn, user_id)
if create_profile_with_displayname:
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index dc0c4b5499..4808bf6294 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -356,6 +356,23 @@ class RoomWorkerStore(SQLBaseStore):
desc="is_room_blocked",
)
+ async def is_room_published(self, room_id: str) -> bool:
+ """Check whether a room has been published in the local public room
+ directory.
+
+ Args:
+ room_id
+ Returns:
+ Whether the room is currently published in the room directory
+ """
+ # Get room information
+ room_info = await self.get_room(room_id)
+ if not room_info:
+ return False
+
+ # Check the is_public value
+ return room_info.get("is_public", False)
+
async def get_rooms_paginate(
self,
start: int,
@@ -564,6 +581,11 @@ class RoomWorkerStore(SQLBaseStore):
Returns:
dict[int, int]: "min_lifetime" and "max_lifetime" for this room.
"""
+ # If the room retention feature is disabled, return a policy with no minimum nor
+ # maximum, in order not to filter out events we should filter out when sending to
+ # the client.
+ if not self.config.retention_enabled:
+ return {"min_lifetime": None, "max_lifetime": None}
def get_retention_policy_for_room_txn(txn):
txn.execute(
@@ -1240,13 +1262,15 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
logger.error("store_room with room_id=%s failed: %s", room_id, e)
raise StoreError(500, "Problem creating room.")
- async def maybe_store_room_on_invite(self, room_id: str, room_version: RoomVersion):
+ async def maybe_store_room_on_outlier_membership(
+ self, room_id: str, room_version: RoomVersion
+ ):
"""
- When we receive an invite over federation, store the version of the room if we
- don't already know the room version.
+ When we receive an invite or any other event over federation that may relate to a room
+ we are not in, store the version of the room if we don't already know the room version.
"""
await self.db_pool.simple_upsert(
- desc="maybe_store_room_on_invite",
+ desc="maybe_store_room_on_outlier_membership",
table="rooms",
keyvalues={"room_id": room_id},
values={},
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 01d9dbb36f..dcdaf09682 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Set
+from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Set, Tuple
from synapse.api.constants import EventTypes, Membership
from synapse.events import EventBase
@@ -350,6 +350,38 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return results
+ async def get_local_current_membership_for_user_in_room(
+ self, user_id: str, room_id: str
+ ) -> Tuple[Optional[str], Optional[str]]:
+ """Retrieve the current local membership state and event ID for a user in a room.
+
+ Args:
+ user_id: The ID of the user.
+ room_id: The ID of the room.
+
+ Returns:
+ A tuple of (membership_type, event_id). Both will be None if a
+ room_id/user_id pair is not found.
+ """
+ # Paranoia check.
+ if not self.hs.is_mine_id(user_id):
+ raise Exception(
+ "Cannot call 'get_local_current_membership_for_user_in_room' on "
+ "non-local user %s" % (user_id,),
+ )
+
+ results_dict = await self.db_pool.simple_select_one(
+ "local_current_membership",
+ {"room_id": room_id, "user_id": user_id},
+ ("membership", "event_id"),
+ allow_none=True,
+ desc="get_local_current_membership_for_user_in_room",
+ )
+ if not results_dict:
+ return None, None
+
+ return results_dict.get("membership"), results_dict.get("event_id")
+
@cached(max_entries=500000, iterable=True)
async def get_rooms_for_user_with_stream_ordering(
self, user_id: str
diff --git a/synapse/storage/databases/main/schema/delta/48/profiles_batch.sql b/synapse/storage/databases/main/schema/delta/48/profiles_batch.sql
new file mode 100644
index 0000000000..e744c02fe8
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/48/profiles_batch.sql
@@ -0,0 +1,36 @@
+/* Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+ * Add a batch number to track changes to profiles and the
+ * order they're made in so we can replicate user profiles
+ * to other hosts as they change
+ */
+ALTER TABLE profiles ADD COLUMN batch BIGINT DEFAULT NULL;
+
+/*
+ * Index on the batch number so we can get profiles
+ * by their batch
+ */
+CREATE INDEX profiles_batch_idx ON profiles(batch);
+
+/*
+ * A table to track what batch of user profiles has been
+ * synced to what profile replication target.
+ */
+CREATE TABLE profile_replication_status (
+ host TEXT NOT NULL,
+ last_synced_batch BIGINT NOT NULL
+);
diff --git a/synapse/storage/databases/main/schema/delta/50/profiles_deactivated_users.sql b/synapse/storage/databases/main/schema/delta/50/profiles_deactivated_users.sql
new file mode 100644
index 0000000000..96051ac179
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/50/profiles_deactivated_users.sql
@@ -0,0 +1,23 @@
+/* Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+ * A flag saying whether the user owning the profile has been deactivated
+ * This really belongs on the users table, not here, but the users table
+ * stores users by their full user_id and profiles stores them by localpart,
+ * so we can't easily join between the two tables. Plus, the batch number
+ * realy ought to represent data in this table that has changed.
+ */
+ALTER TABLE profiles ADD COLUMN active SMALLINT DEFAULT 1 NOT NULL;
\ No newline at end of file
diff --git a/synapse/storage/databases/main/schema/delta/55/profile_replication_status_index.sql b/synapse/storage/databases/main/schema/delta/55/profile_replication_status_index.sql
new file mode 100644
index 0000000000..7542ab8cbd
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/55/profile_replication_status_index.sql
@@ -0,0 +1,16 @@
+/* Copyright 2019 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE UNIQUE INDEX profile_replication_status_idx ON profile_replication_status(host);
\ No newline at end of file
diff --git a/synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.postgres b/synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.postgres
index b64926e9c9..3275ae2b20 100644
--- a/synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.postgres
+++ b/synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.postgres
@@ -20,14 +20,14 @@
*/
-- add new index that includes method to local media
-INSERT INTO background_updates (update_name, progress_json) VALUES
- ('local_media_repository_thumbnails_method_idx', '{}');
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+ (5807, 'local_media_repository_thumbnails_method_idx', '{}');
-- add new index that includes method to remote media
-INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES
- ('remote_media_repository_thumbnails_method_idx', '{}', 'local_media_repository_thumbnails_method_idx');
+INSERT INTO background_updates (ordering, update_name, progress_json, depends_on) VALUES
+ (5807, 'remote_media_repository_thumbnails_method_idx', '{}', 'local_media_repository_thumbnails_method_idx');
-- drop old index
-INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES
- ('media_repository_drop_index_wo_method', '{}', 'remote_media_repository_thumbnails_method_idx');
+INSERT INTO background_updates (ordering, update_name, progress_json, depends_on) VALUES
+ (5807, 'media_repository_drop_index_wo_method', '{}', 'remote_media_repository_thumbnails_method_idx');
diff --git a/synapse/storage/databases/main/schema/delta/58/12room_stats.sql b/synapse/storage/databases/main/schema/delta/58/12room_stats.sql
index cade5dcca8..fd733adf13 100644
--- a/synapse/storage/databases/main/schema/delta/58/12room_stats.sql
+++ b/synapse/storage/databases/main/schema/delta/58/12room_stats.sql
@@ -28,5 +28,5 @@
-- functionality as the old one. This effectively restarts the background job
-- from the beginning, without running it twice in a row, supporting both
-- upgrade usecases.
-INSERT INTO background_updates (update_name, progress_json) VALUES
- ('populate_stats_process_rooms_2', '{}');
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+ (5812, 'populate_stats_process_rooms_2', '{}');
diff --git a/synapse/storage/databases/main/schema/delta/58/19account_validity_token_used_ts_ms.sql b/synapse/storage/databases/main/schema/delta/58/19account_validity_token_used_ts_ms.sql
new file mode 100644
index 0000000000..4836dac16e
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/19account_validity_token_used_ts_ms.sql
@@ -0,0 +1,18 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Track when users renew their account using the value of the 'renewal_token' column.
+-- This field should be set to NULL after a fresh token is generated.
+ALTER TABLE account_validity ADD token_used_ts_ms BIGINT;
diff --git a/synapse/storage/databases/main/schema/delta/58/22users_have_local_media.sql b/synapse/storage/databases/main/schema/delta/58/22users_have_local_media.sql
index a2842687f1..e1a35be831 100644
--- a/synapse/storage/databases/main/schema/delta/58/22users_have_local_media.sql
+++ b/synapse/storage/databases/main/schema/delta/58/22users_have_local_media.sql
@@ -1,2 +1,2 @@
-INSERT INTO background_updates (update_name, progress_json) VALUES
- ('users_have_local_media', '{}');
\ No newline at end of file
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+ (5822, 'users_have_local_media', '{}');
diff --git a/synapse/storage/databases/main/schema/delta/58/23e2e_cross_signing_keys_idx.sql b/synapse/storage/databases/main/schema/delta/58/23e2e_cross_signing_keys_idx.sql
index 61c558db77..75c3915a94 100644
--- a/synapse/storage/databases/main/schema/delta/58/23e2e_cross_signing_keys_idx.sql
+++ b/synapse/storage/databases/main/schema/delta/58/23e2e_cross_signing_keys_idx.sql
@@ -13,5 +13,5 @@
* limitations under the License.
*/
-INSERT INTO background_updates (update_name, progress_json) VALUES
- ('e2e_cross_signing_keys_idx', '{}');
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+ (5823, 'e2e_cross_signing_keys_idx', '{}');
diff --git a/synapse/storage/databases/main/schema/delta/58/24drop_event_json_index.sql b/synapse/storage/databases/main/schema/delta/58/24drop_event_json_index.sql
new file mode 100644
index 0000000000..8a39d54aed
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/24drop_event_json_index.sql
@@ -0,0 +1,19 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- this index is essentially redundant. The only time it was ever used was when purging
+-- rooms - and Synapse 1.24 will change that.
+
+DROP INDEX IF EXISTS event_json_room_id;
diff --git a/synapse/storage/databases/main/schema/full_schemas/54/full.sql.postgres b/synapse/storage/databases/main/schema/full_schemas/54/full.sql.postgres
index 889a9a0ce4..20c5af2eb7 100644
--- a/synapse/storage/databases/main/schema/full_schemas/54/full.sql.postgres
+++ b/synapse/storage/databases/main/schema/full_schemas/54/full.sql.postgres
@@ -658,10 +658,19 @@ CREATE TABLE presence_stream (
+CREATE TABLE profile_replication_status (
+ host text NOT NULL,
+ last_synced_batch bigint NOT NULL
+);
+
+
+
CREATE TABLE profiles (
user_id text NOT NULL,
displayname text,
- avatar_url text
+ avatar_url text,
+ batch bigint,
+ active smallint DEFAULT 1 NOT NULL
);
@@ -1788,6 +1797,10 @@ CREATE INDEX presence_stream_user_id ON presence_stream USING btree (user_id);
+CREATE INDEX profiles_batch_idx ON profiles USING btree (batch);
+
+
+
CREATE INDEX public_room_index ON rooms USING btree (is_public);
diff --git a/synapse/storage/databases/main/schema/full_schemas/54/full.sql.sqlite b/synapse/storage/databases/main/schema/full_schemas/54/full.sql.sqlite
index a0411ede7e..e28ec3fa45 100644
--- a/synapse/storage/databases/main/schema/full_schemas/54/full.sql.sqlite
+++ b/synapse/storage/databases/main/schema/full_schemas/54/full.sql.sqlite
@@ -6,7 +6,7 @@ CREATE TABLE presence_allow_inbound( observed_user_id TEXT NOT NULL, observer_us
CREATE TABLE users( name TEXT, password_hash TEXT, creation_ts BIGINT, admin SMALLINT DEFAULT 0 NOT NULL, upgrade_ts BIGINT, is_guest SMALLINT DEFAULT 0 NOT NULL, appservice_id TEXT, consent_version TEXT, consent_server_notice_sent TEXT, user_type TEXT DEFAULT NULL, UNIQUE(name) );
CREATE TABLE access_tokens( id BIGINT PRIMARY KEY, user_id TEXT NOT NULL, device_id TEXT, token TEXT NOT NULL, last_used BIGINT, UNIQUE(token) );
CREATE TABLE user_ips ( user_id TEXT NOT NULL, access_token TEXT NOT NULL, device_id TEXT, ip TEXT NOT NULL, user_agent TEXT NOT NULL, last_seen BIGINT NOT NULL );
-CREATE TABLE profiles( user_id TEXT NOT NULL, displayname TEXT, avatar_url TEXT, UNIQUE(user_id) );
+CREATE TABLE profiles( user_id TEXT NOT NULL, displayname TEXT, avatar_url TEXT, batch BIGINT DEFAULT NULL, active SMALLINT DEFAULT 1 NOT NULL, UNIQUE(user_id) );
CREATE TABLE received_transactions( transaction_id TEXT, origin TEXT, ts BIGINT, response_code INTEGER, response_json bytea, has_been_referenced smallint default 0, UNIQUE (transaction_id, origin) );
CREATE TABLE destinations( destination TEXT PRIMARY KEY, retry_last_ts BIGINT, retry_interval INTEGER );
CREATE TABLE events( stream_ordering INTEGER PRIMARY KEY, topological_ordering BIGINT NOT NULL, event_id TEXT NOT NULL, type TEXT NOT NULL, room_id TEXT NOT NULL, content TEXT, unrecognized_keys TEXT, processed BOOL NOT NULL, outlier BOOL NOT NULL, depth BIGINT DEFAULT 0 NOT NULL, origin_server_ts BIGINT, received_ts BIGINT, sender TEXT, contains_url BOOLEAN, UNIQUE (event_id) );
@@ -202,6 +202,8 @@ CREATE INDEX group_users_u_idx ON group_users(user_id);
CREATE INDEX group_invites_u_idx ON group_invites(user_id);
CREATE UNIQUE INDEX group_rooms_g_idx ON group_rooms(group_id, room_id);
CREATE INDEX group_rooms_r_idx ON group_rooms(room_id);
+CREATE INDEX profiles_batch_idx ON profiles(batch);
+CREATE TABLE profile_replication_status ( host TEXT NOT NULL, last_synced_batch BIGINT NOT NULL );
CREATE TABLE user_daily_visits ( user_id TEXT NOT NULL, device_id TEXT, timestamp BIGINT NOT NULL );
CREATE INDEX user_daily_visits_uts_idx ON user_daily_visits(user_id, timestamp);
CREATE INDEX user_daily_visits_ts_idx ON user_daily_visits(timestamp);
diff --git a/synapse/third_party_rules/__init__.py b/synapse/third_party_rules/__init__.py
new file mode 100644
index 0000000000..1453d04571
--- /dev/null
+++ b/synapse/third_party_rules/__init__.py
@@ -0,0 +1,14 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/synapse/third_party_rules/access_rules.py b/synapse/third_party_rules/access_rules.py
new file mode 100644
index 0000000000..f158ab6676
--- /dev/null
+++ b/synapse/third_party_rules/access_rules.py
@@ -0,0 +1,966 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import email.utils
+import logging
+from typing import Dict, List, Optional, Tuple
+
+from twisted.internet import defer
+
+from synapse.api.constants import EventTypes, JoinRules, Membership, RoomCreationPreset
+from synapse.api.errors import SynapseError
+from synapse.config._base import ConfigError
+from synapse.events import EventBase
+from synapse.module_api import ModuleApi
+from synapse.types import Requester, StateMap, UserID, get_domain_from_id
+
+logger = logging.getLogger(__name__)
+
+ACCESS_RULES_TYPE = "im.vector.room.access_rules"
+
+
+class AccessRules:
+ DIRECT = "direct"
+ RESTRICTED = "restricted"
+ UNRESTRICTED = "unrestricted"
+
+
+VALID_ACCESS_RULES = (
+ AccessRules.DIRECT,
+ AccessRules.RESTRICTED,
+ AccessRules.UNRESTRICTED,
+)
+
+# Rules to which we need to apply the power levels restrictions.
+#
+# These are all of the rules that neither:
+# * forbid users from joining based on a server blacklist (which means that there
+# is no need to apply power level restrictions), nor
+# * target direct chats (since we allow both users to be room admins in this case).
+#
+# The power-level restrictions, when they are applied, prevent the following:
+# * the default power level for users (users_default) being set to anything other than 0.
+# * a non-default power level being assigned to any user which would be forbidden from
+# joining a restricted room.
+RULES_WITH_RESTRICTED_POWER_LEVELS = (AccessRules.UNRESTRICTED,)
+
+
+class RoomAccessRules(object):
+ """Implementation of the ThirdPartyEventRules module API that allows federation admins
+ to define custom rules for specific events and actions.
+ Implements the custom behaviour for the "im.vector.room.access_rules" state event.
+
+ Takes a config in the format:
+
+ third_party_event_rules:
+ module: third_party_rules.RoomAccessRules
+ config:
+ # List of domains (server names) that can't be invited to rooms if the
+ # "restricted" rule is set. Defaults to an empty list.
+ domains_forbidden_when_restricted: []
+
+ # Identity server to use when checking the HS an email address belongs to
+ # using the /info endpoint. Required.
+ id_server: "vector.im"
+
+ Don't forget to consider if you can invite users from your own domain.
+ """
+
+ def __init__(
+ self, config: Dict, module_api: ModuleApi,
+ ):
+ self.id_server = config["id_server"]
+ self.module_api = module_api
+
+ self.domains_forbidden_when_restricted = config.get(
+ "domains_forbidden_when_restricted", []
+ )
+
+ @staticmethod
+ def parse_config(config: Dict) -> Dict:
+ """Parses and validates the options specified in the homeserver config.
+
+ Args:
+ config: The config dict.
+
+ Returns:
+ The config dict.
+
+ Raises:
+ ConfigError: If there was an issue with the provided module configuration.
+ """
+ if "id_server" not in config:
+ raise ConfigError("No IS for event rules TchapEventRules")
+
+ return config
+
+ async def on_create_room(
+ self, requester: Requester, config: Dict, is_requester_admin: bool,
+ ) -> bool:
+ """Implements synapse.events.ThirdPartyEventRules.on_create_room.
+
+ Checks if a im.vector.room.access_rules event is being set during room creation.
+ If yes, make sure the event is correct. Otherwise, append an event with the
+ default rule to the initial state.
+
+ Checks if a m.rooms.power_levels event is being set during room creation.
+ If yes, make sure the event is allowed. Otherwise, set power_level_content_override
+ in the config dict to our modified version of the default room power levels.
+
+ Args:
+ requester: The user who is making the createRoom request.
+ config: The createRoom config dict provided by the user.
+ is_requester_admin: Whether the requester is a Synapse admin.
+
+ Returns:
+ Whether the request is allowed.
+
+ Raises:
+ SynapseError: If the createRoom config dict is invalid or its contents blocked.
+ """
+ is_direct = config.get("is_direct")
+ preset = config.get("preset")
+ access_rule = None
+ join_rule = None
+
+ # If there's a rules event in the initial state, check if it complies with the
+ # spec for im.vector.room.access_rules and deny the request if not.
+ for event in config.get("initial_state", []):
+ if event["type"] == ACCESS_RULES_TYPE:
+ access_rule = event["content"].get("rule")
+
+ # Make sure the event has a valid content.
+ if access_rule is None:
+ raise SynapseError(400, "Invalid access rule")
+
+ # Make sure the rule name is valid.
+ if access_rule not in VALID_ACCESS_RULES:
+ raise SynapseError(400, "Invalid access rule")
+
+ if (is_direct and access_rule != AccessRules.DIRECT) or (
+ access_rule == AccessRules.DIRECT and not is_direct
+ ):
+ raise SynapseError(400, "Invalid access rule")
+
+ if event["type"] == EventTypes.JoinRules:
+ join_rule = event["content"].get("join_rule")
+
+ if access_rule is None:
+ # If there's no access rules event in the initial state, create one with the
+ # default setting.
+ if is_direct:
+ default_rule = AccessRules.DIRECT
+ else:
+ # If the default value for non-direct chat changes, we should make another
+ # case here for rooms created with either a "public" join_rule or the
+ # "public_chat" preset to make sure those keep defaulting to "restricted"
+ default_rule = AccessRules.RESTRICTED
+
+ if not config.get("initial_state"):
+ config["initial_state"] = []
+
+ config["initial_state"].append(
+ {
+ "type": ACCESS_RULES_TYPE,
+ "state_key": "",
+ "content": {"rule": default_rule},
+ }
+ )
+
+ access_rule = default_rule
+
+ # Check that the preset in use is compatible with the access rule, whether it's
+ # user-defined or the default.
+ #
+ # Direct rooms may not have their join_rules set to JoinRules.PUBLIC.
+ if (
+ join_rule == JoinRules.PUBLIC or preset == RoomCreationPreset.PUBLIC_CHAT
+ ) and access_rule == AccessRules.DIRECT:
+ raise SynapseError(400, "Invalid access rule")
+
+ default_power_levels = self._get_default_power_levels(
+ requester.user.to_string()
+ )
+
+ # Check if the creator can override values for the power levels.
+ allowed = self._is_power_level_content_allowed(
+ config.get("power_level_content_override", {}),
+ access_rule,
+ default_power_levels,
+ )
+ if not allowed:
+ raise SynapseError(400, "Invalid power levels content override")
+
+ custom_user_power_levels = config.get("power_level_content_override")
+
+ # Second loop for events we need to know the current rule to process.
+ for event in config.get("initial_state", []):
+ if event["type"] == EventTypes.PowerLevels:
+ allowed = self._is_power_level_content_allowed(
+ event["content"], access_rule, default_power_levels
+ )
+ if not allowed:
+ raise SynapseError(400, "Invalid power levels content")
+
+ custom_user_power_levels = event["content"]
+ if custom_user_power_levels:
+ # If the user is using their own power levels, but failed to provide an expected
+ # key in the power levels content dictionary, fill it in from the defaults instead
+ for key, value in default_power_levels.items():
+ custom_user_power_levels.setdefault(key, value)
+ else:
+ # If power levels were not overridden by the user, completely override with the
+ # defaults instead
+ config["power_level_content_override"] = default_power_levels
+
+ return True
+
+ # If power levels are not overridden by the user during room creation, the following
+ # rules are used instead. Changes from Synapse's default power levels are noted.
+ #
+ # The same power levels are currently applied regardless of room preset.
+ @staticmethod
+ def _get_default_power_levels(user_id: str) -> Dict:
+ return {
+ "users": {user_id: 100},
+ "users_default": 0,
+ "events": {
+ EventTypes.Name: 50,
+ EventTypes.PowerLevels: 100,
+ EventTypes.RoomHistoryVisibility: 100,
+ EventTypes.CanonicalAlias: 50,
+ EventTypes.RoomAvatar: 50,
+ EventTypes.Tombstone: 100,
+ EventTypes.ServerACL: 100,
+ EventTypes.RoomEncryption: 100,
+ },
+ "events_default": 0,
+ "state_default": 100, # Admins should be the only ones to perform other tasks
+ "ban": 50,
+ "kick": 50,
+ "redact": 50,
+ "invite": 50, # All rooms should require mod to invite, even private
+ }
+
+ @defer.inlineCallbacks
+ def check_threepid_can_be_invited(
+ self, medium: str, address: str, state_events: StateMap[EventBase],
+ ) -> bool:
+ """Implements synapse.events.ThirdPartyEventRules.check_threepid_can_be_invited.
+
+ Check if a threepid can be invited to the room via a 3PID invite given the current
+ rules and the threepid's address, by retrieving the HS it's mapped to from the
+ configured identity server, and checking if we can invite users from it.
+
+ Args:
+ medium: The medium of the threepid.
+ address: The address of the threepid.
+ state_events: A dict mapping (event type, state key) to state event.
+ State events in the room the threepid is being invited to.
+
+ Returns:
+ Whether the threepid invite is allowed.
+ """
+ rule = self._get_rule_from_state(state_events)
+
+ if medium != "email":
+ return False
+
+ if rule != AccessRules.RESTRICTED:
+ # Only "restricted" requires filtering 3PID invites. We don't need to do
+ # anything for "direct" here, because only "restricted" requires filtering
+ # based on the HS the address is mapped to.
+ return True
+
+ parsed_address = email.utils.parseaddr(address)[1]
+ if parsed_address != address:
+ # Avoid reproducing the security issue described here:
+ # https://matrix.org/blog/2019/04/18/security-update-sydent-1-0-2
+ # It's probably not worth it but let's just be overly safe here.
+ return False
+
+ # Get the HS this address belongs to from the identity server.
+ res = yield self.module_api.http_client.get_json(
+ "https://%s/_matrix/identity/api/v1/info" % (self.id_server,),
+ {"medium": medium, "address": address},
+ )
+
+ # Look for a domain that's not forbidden from being invited.
+ if not res.get("hs"):
+ return False
+ if res.get("hs") in self.domains_forbidden_when_restricted:
+ return False
+
+ return True
+
+ async def check_event_allowed(
+ self, event: EventBase, state_events: StateMap[EventBase],
+ ) -> bool:
+ """Implements synapse.events.ThirdPartyEventRules.check_event_allowed.
+
+ Checks the event's type and the current rule and calls the right function to
+ determine whether the event can be allowed.
+
+ Args:
+ event: The event to check.
+ state_events: A dict mapping (event type, state key) to state event.
+ State events in the room the event originated from.
+
+ Returns:
+ True if the event can be allowed, False otherwise.
+ """
+ if event.type == ACCESS_RULES_TYPE:
+ return await self._on_rules_change(event, state_events)
+
+ # We need to know the rule to apply when processing the event types below.
+ rule = self._get_rule_from_state(state_events)
+
+ if event.type == EventTypes.PowerLevels:
+ return self._is_power_level_content_allowed(
+ event.content, rule, on_room_creation=False
+ )
+
+ if event.type == EventTypes.Member or event.type == EventTypes.ThirdPartyInvite:
+ return await self._on_membership_or_invite(event, rule, state_events)
+
+ if event.type == EventTypes.JoinRules:
+ return self._on_join_rule_change(event, rule)
+
+ if event.type == EventTypes.RoomAvatar:
+ return self._on_room_avatar_change(event, rule)
+
+ if event.type == EventTypes.Name:
+ return self._on_room_name_change(event, rule)
+
+ if event.type == EventTypes.Topic:
+ return self._on_room_topic_change(event, rule)
+
+ return True
+
+ async def check_visibility_can_be_modified(
+ self, room_id: str, state_events: StateMap[EventBase], new_visibility: str
+ ) -> bool:
+ """Implements
+ synapse.events.ThirdPartyEventRules.check_visibility_can_be_modified
+
+ Determines whether a room can be published, or removed from, the public room
+ list. A room is published if its visibility is set to "public". Otherwise,
+ its visibility is "private". A room with access rule other than "restricted"
+ may not be published.
+
+ Args:
+ room_id: The ID of the room.
+ state_events: A dict mapping (event type, state key) to state event.
+ State events in the room.
+ new_visibility: The new visibility state. Either "public" or "private".
+
+ Returns:
+ Whether the room is allowed to be published to, or removed from, the public
+ rooms directory.
+ """
+ # We need to know the rule to apply when processing the event types below.
+ rule = self._get_rule_from_state(state_events)
+
+ # Allow adding a room to the public rooms list only if it is restricted
+ if new_visibility == "public":
+ return rule == AccessRules.RESTRICTED
+
+ # By default a room is created as "restricted", meaning it is allowed to be
+ # published to the public rooms directory.
+ return True
+
+ async def _on_rules_change(
+ self, event: EventBase, state_events: StateMap[EventBase]
+ ):
+ """Checks whether an im.vector.room.access_rules event is forbidden or allowed.
+
+ Args:
+ event: The im.vector.room.access_rules event.
+ state_events: A dict mapping (event type, state key) to state event.
+ State events in the room before the event was sent.
+ Returns:
+ True if the event can be allowed, False otherwise.
+ """
+ new_rule = event.content.get("rule")
+
+ # Check for invalid values.
+ if new_rule not in VALID_ACCESS_RULES:
+ return False
+
+ # Make sure we don't apply "direct" if the room has more than two members.
+ if new_rule == AccessRules.DIRECT:
+ existing_members, threepid_tokens = self._get_members_and_tokens_from_state(
+ state_events
+ )
+
+ if len(existing_members) > 2 or len(threepid_tokens) > 1:
+ return False
+
+ if new_rule != AccessRules.RESTRICTED:
+ # Block this change if this room is currently listed in the public rooms
+ # directory
+ if await self.module_api.public_room_list_manager.room_is_in_public_room_list(
+ event.room_id
+ ):
+ return False
+
+ prev_rules_event = state_events.get((ACCESS_RULES_TYPE, ""))
+
+ # Now that we know the new rule doesn't break the "direct" case, we can allow any
+ # new rule in rooms that had none before.
+ if prev_rules_event is None:
+ return True
+
+ prev_rule = prev_rules_event.content.get("rule")
+
+ # Currently, we can only go from "restricted" to "unrestricted".
+ return (
+ prev_rule == AccessRules.RESTRICTED and new_rule == AccessRules.UNRESTRICTED
+ )
+
+ async def _on_membership_or_invite(
+ self, event: EventBase, rule: str, state_events: StateMap[EventBase],
+ ) -> bool:
+ """Applies the correct rule for incoming m.room.member and
+ m.room.third_party_invite events.
+
+ Args:
+ event: The event to check.
+ rule: The name of the rule to apply.
+ state_events: A dict mapping (event type, state key) to state event.
+ The state of the room before the event was sent.
+
+ Returns:
+ True if the event can be allowed, False otherwise.
+ """
+ if rule == AccessRules.RESTRICTED:
+ ret = self._on_membership_or_invite_restricted(event)
+ elif rule == AccessRules.UNRESTRICTED:
+ ret = self._on_membership_or_invite_unrestricted(event, state_events)
+ elif rule == AccessRules.DIRECT:
+ ret = self._on_membership_or_invite_direct(event, state_events)
+ else:
+ # We currently apply the default (restricted) if we don't know the rule, we
+ # might want to change that in the future.
+ ret = self._on_membership_or_invite_restricted(event)
+
+ if event.type == "m.room.member":
+ # If this is an admin leaving, and they are the last admin in the room,
+ # raise the power levels of the room so that the room is 'frozen'.
+ #
+ # We have to freeze the room by puppeting an admin user, which we can
+ # only do for local users
+ if (
+ self._is_local_user(event.sender)
+ and event.membership == Membership.LEAVE
+ ):
+ await self._freeze_room_if_last_admin_is_leaving(event, state_events)
+
+ return ret
+
+ async def _freeze_room_if_last_admin_is_leaving(
+ self, event: EventBase, state_events: StateMap[EventBase]
+ ):
+ power_level_state_event = state_events.get(
+ (EventTypes.PowerLevels, "")
+ ) # type: EventBase
+ if not power_level_state_event:
+ return
+ power_level_content = power_level_state_event.content
+
+ # Do some validation checks on the power level state event
+ if (
+ not isinstance(power_level_content, dict)
+ or "users" not in power_level_content
+ or not isinstance(power_level_content["users"], dict)
+ ):
+ # We can't use this power level event to determine whether the room should be
+ # frozen. Bail out.
+ return
+
+ user_id = event.get("sender")
+ if not user_id:
+ return
+
+ # Get every admin user defined in the room's state
+ admin_users = {
+ user
+ for user, power_level in power_level_content["users"].items()
+ if power_level >= 100
+ }
+
+ if user_id not in admin_users:
+ # This user is not an admin, ignore them
+ return
+
+ if any(
+ event_type == EventTypes.Member
+ and event.membership in [Membership.JOIN, Membership.INVITE]
+ and state_key in admin_users
+ and state_key != user_id
+ for (event_type, state_key), event in state_events.items()
+ ):
+ # There's another admin user in, or invited to, the room
+ return
+
+ # Freeze the room by raising the required power level to send events to 100
+ logger.info("Freezing room '%s'", event.room_id)
+
+ # Modify the existing power levels to raise all required types to 100
+ #
+ # This changes a power level state event's content from something like:
+ # {
+ # "redact": 50,
+ # "state_default": 50,
+ # "ban": 50,
+ # "notifications": {
+ # "room": 50
+ # },
+ # "events": {
+ # "m.room.avatar": 50,
+ # "m.room.encryption": 50,
+ # "m.room.canonical_alias": 50,
+ # "m.room.name": 50,
+ # "im.vector.modular.widgets": 50,
+ # "m.room.topic": 50,
+ # "m.room.tombstone": 50,
+ # "m.room.history_visibility": 100,
+ # "m.room.power_levels": 100
+ # },
+ # "users_default": 0,
+ # "events_default": 0,
+ # "users": {
+ # "@admin:example.com": 100,
+ # },
+ # "kick": 50,
+ # "invite": 0
+ # }
+ #
+ # to
+ #
+ # {
+ # "redact": 100,
+ # "state_default": 100,
+ # "ban": 100,
+ # "notifications": {
+ # "room": 50
+ # },
+ # "events": {}
+ # "users_default": 0,
+ # "events_default": 100,
+ # "users": {
+ # "@admin:example.com": 100,
+ # },
+ # "kick": 100,
+ # "invite": 100
+ # }
+ new_content = {}
+ for key, value in power_level_content.items():
+ # Do not change "users_default", as that key specifies the default power
+ # level of new users
+ if isinstance(value, int) and key != "users_default":
+ value = 100
+ new_content[key] = value
+
+ # Set some values in case they are missing from the original
+ # power levels event content
+ new_content.update(
+ {
+ # Clear out any special-cased event keys
+ "events": {},
+ # Ensure state_default and events_default keys exist and are 100.
+ # Otherwise a lower PL user could potentially send state events that
+ # aren't explicitly mentioned elsewhere in the power level dict
+ "state_default": 100,
+ "events_default": 100,
+ # Membership events default to 50 if they aren't present. Set them
+ # to 100 here, as they would be set to 100 if they were present anyways
+ "ban": 100,
+ "kick": 100,
+ "invite": 100,
+ "redact": 100,
+ }
+ )
+
+ await self.module_api.create_and_send_event_into_room(
+ {
+ "room_id": event.room_id,
+ "sender": user_id,
+ "type": EventTypes.PowerLevels,
+ "content": new_content,
+ "state_key": "",
+ }
+ )
+
+ def _on_membership_or_invite_restricted(self, event: EventBase) -> bool:
+ """Implements the checks and behaviour specified for the "restricted" rule.
+
+ "restricted" currently means that users can only invite users if their server is
+ included in a limited list of domains.
+
+ Args:
+ event: The event to check.
+
+ Returns:
+ True if the event can be allowed, False otherwise.
+ """
+ # We're not applying the rules on m.room.third_party_member events here because
+ # the filtering on threepids is done in check_threepid_can_be_invited, which is
+ # called before check_event_allowed.
+ if event.type == EventTypes.ThirdPartyInvite:
+ return True
+
+ # We only need to process "join" and "invite" memberships, in order to be backward
+ # compatible, e.g. if a user from a blacklisted server joined a restricted room
+ # before the rules started being enforced on the server, that user must be able to
+ # leave it.
+ if event.membership not in [Membership.JOIN, Membership.INVITE]:
+ return True
+
+ invitee_domain = get_domain_from_id(event.state_key)
+ return invitee_domain not in self.domains_forbidden_when_restricted
+
+ def _on_membership_or_invite_unrestricted(
+ self, event: EventBase, state_events: StateMap[EventBase]
+ ) -> bool:
+ """Implements the checks and behaviour specified for the "unrestricted" rule.
+
+ "unrestricted" currently means that forbidden users cannot join without an invite.
+
+ Returns:
+ True if the event can be allowed, False otherwise.
+ """
+ # If this is a join from a forbidden user and they don't have an invite to the
+ # room, then deny it
+ if event.type == EventTypes.Member and event.membership == Membership.JOIN:
+ # Check if this user is from a forbidden server
+ target_domain = get_domain_from_id(event.state_key)
+ if target_domain in self.domains_forbidden_when_restricted:
+ # If so, they'll need an invite to join this room. Check if one exists
+ if not self._user_is_invited_to_room(event.state_key, state_events):
+ return False
+
+ return True
+
+ def _on_membership_or_invite_direct(
+ self, event: EventBase, state_events: StateMap[EventBase],
+ ) -> bool:
+ """Implements the checks and behaviour specified for the "direct" rule.
+
+ "direct" currently means that no member is allowed apart from the two initial
+ members the room was created for (i.e. the room's creator and their first invitee).
+
+ Args:
+ event: The event to check.
+ state_events: A dict mapping (event type, state key) to state event.
+ The state of the room before the event was sent.
+
+ Returns:
+ True if the event can be allowed, False otherwise.
+ """
+ # Get the room memberships and 3PID invite tokens from the room's state.
+ existing_members, threepid_tokens = self._get_members_and_tokens_from_state(
+ state_events
+ )
+
+ # There should never be more than one 3PID invite in the room state: if the second
+ # original user came and left, and we're inviting them using their email address,
+ # given we know they have a Matrix account binded to the address (so they could
+ # join the first time), Synapse will successfully look it up before attempting to
+ # store an invite on the IS.
+ if len(threepid_tokens) == 1 and event.type == EventTypes.ThirdPartyInvite:
+ # If we already have a 3PID invite in flight, don't accept another one, unless
+ # the new one has the same invite token as its state key. This is because 3PID
+ # invite revocations must be allowed, and a revocation is basically a new 3PID
+ # invite event with an empty content and the same token as the invite it
+ # revokes.
+ return event.state_key in threepid_tokens
+
+ if len(existing_members) == 2:
+ # If the user was within the two initial user of the room, Synapse would have
+ # looked it up successfully and thus sent a m.room.member here instead of
+ # m.room.third_party_invite.
+ if event.type == EventTypes.ThirdPartyInvite:
+ return False
+
+ # We can only have m.room.member events here. The rule in this case is to only
+ # allow the event if its target is one of the initial two members in the room,
+ # i.e. the state key of one of the two m.room.member states in the room.
+ return event.state_key in existing_members
+
+ # We're alone in the room (and always have been) and there's one 3PID invite in
+ # flight.
+ if len(existing_members) == 1 and len(threepid_tokens) == 1:
+ # We can only have m.room.member events here. In this case, we can only allow
+ # the event if it's either a m.room.member from the joined user (we can assume
+ # that the only m.room.member event is a join otherwise we wouldn't be able to
+ # send an event to the room) or an an invite event which target is the invited
+ # user.
+ target = event.state_key
+ is_from_threepid_invite = self._is_invite_from_threepid(
+ event, threepid_tokens[0]
+ )
+ return is_from_threepid_invite or target == existing_members[0]
+
+ return True
+
+ def _is_power_level_content_allowed(
+ self,
+ content: Dict,
+ access_rule: str,
+ default_power_levels: Optional[Dict] = None,
+ on_room_creation: bool = True,
+ ) -> bool:
+ """Check if a given power levels event is permitted under the given access rule.
+
+ It shouldn't be allowed if it either changes the default PL to a non-0 value or
+ gives a non-0 PL to a user that would have been forbidden from joining the room
+ under a more restrictive access rule.
+
+ Args:
+ content: The content of the m.room.power_levels event to check.
+ access_rule: The access rule in place in this room.
+ default_power_levels: The default power levels when a room is created with
+ the specified access rule. Required if on_room_creation is True.
+ on_room_creation: True if this call is happening during a room's
+ creation, False otherwise.
+
+ Returns:
+ Whether the content of the power levels event is valid.
+ """
+ # Only enforce these rules during room creation
+ #
+ # We want to allow admins to modify or fix the power levels in a room if they
+ # have a special circumstance, but still want to encourage a certain pattern during
+ # room creation.
+ if on_room_creation:
+ # We specifically don't fail if "invite" or "state_default" are None, as those
+ # values should be replaced with our "default" power level values anyways,
+ # which are compliant
+
+ invite = default_power_levels["invite"]
+ state_default = default_power_levels["state_default"]
+
+ # If invite requirements are less than our required defaults
+ if content.get("invite", invite) < invite:
+ return False
+
+ # If "other" state requirements are less than our required defaults
+ if content.get("state_default", state_default) < state_default:
+ return False
+
+ # Check if we need to apply the restrictions with the current rule.
+ if access_rule not in RULES_WITH_RESTRICTED_POWER_LEVELS:
+ return True
+
+ # If users_default is explicitly set to a non-0 value, deny the event.
+ users_default = content.get("users_default", 0)
+ if users_default:
+ return False
+
+ users = content.get("users", {})
+ for user_id, power_level in users.items():
+ server_name = get_domain_from_id(user_id)
+ # Check the domain against the blacklist. If found, and the PL isn't 0, deny
+ # the event.
+ if (
+ server_name in self.domains_forbidden_when_restricted
+ and power_level != 0
+ ):
+ return False
+
+ return True
+
+ def _on_join_rule_change(self, event: EventBase, rule: str) -> bool:
+ """Check whether a join rule change is allowed.
+
+ A join rule change is always allowed unless the new join rule is "public" and
+ the current access rule is "direct".
+
+ Args:
+ event: The event to check.
+ rule: The name of the rule to apply.
+
+ Returns:
+ Whether the change is allowed.
+ """
+ if event.content.get("join_rule") == JoinRules.PUBLIC:
+ return rule != AccessRules.DIRECT
+
+ return True
+
+ def _on_room_avatar_change(self, event: EventBase, rule: str) -> bool:
+ """Check whether a change of room avatar is allowed.
+ The current rule is to forbid such a change in direct chats but allow it
+ everywhere else.
+
+ Args:
+ event: The event to check.
+ rule: The name of the rule to apply.
+
+ Returns:
+ True if the event can be allowed, False otherwise.
+ """
+ return rule != AccessRules.DIRECT
+
+ def _on_room_name_change(self, event: EventBase, rule: str) -> bool:
+ """Check whether a change of room name is allowed.
+ The current rule is to forbid such a change in direct chats but allow it
+ everywhere else.
+
+ Args:
+ event: The event to check.
+ rule: The name of the rule to apply.
+
+ Returns:
+ True if the event can be allowed, False otherwise.
+ """
+ return rule != AccessRules.DIRECT
+
+ def _on_room_topic_change(self, event: EventBase, rule: str) -> bool:
+ """Check whether a change of room topic is allowed.
+ The current rule is to forbid such a change in direct chats but allow it
+ everywhere else.
+
+ Args:
+ event: The event to check.
+ rule: The name of the rule to apply.
+
+ Returns:
+ True if the event can be allowed, False otherwise.
+ """
+ return rule != AccessRules.DIRECT
+
+ @staticmethod
+ def _get_rule_from_state(state_events: StateMap[EventBase]) -> Optional[str]:
+ """Extract the rule to be applied from the given set of state events.
+
+ Args:
+ state_events: A dict mapping (event type, state key) to state event.
+
+ Returns:
+ The name of the rule (either "direct", "restricted" or "unrestricted") if found,
+ else None.
+ """
+ access_rules = state_events.get((ACCESS_RULES_TYPE, ""))
+ if access_rules is None:
+ return AccessRules.RESTRICTED
+
+ return access_rules.content.get("rule")
+
+ @staticmethod
+ def _get_join_rule_from_state(state_events: StateMap[EventBase]) -> Optional[str]:
+ """Extract the room's join rule from the given set of state events.
+
+ Args:
+ state_events (dict[tuple[event type, state key], EventBase]): The set of state
+ events.
+
+ Returns:
+ The name of the join rule (either "public", or "invite") if found, else None.
+ """
+ join_rule_event = state_events.get((EventTypes.JoinRules, ""))
+ if join_rule_event is None:
+ return None
+
+ return join_rule_event.content.get("join_rule")
+
+ @staticmethod
+ def _get_members_and_tokens_from_state(
+ state_events: StateMap[EventBase],
+ ) -> Tuple[List[str], List[str]]:
+ """Retrieves the list of users that have a m.room.member event in the room,
+ as well as 3PID invites tokens in the room.
+
+ Args:
+ state_events: A dict mapping (event type, state key) to state event.
+
+ Returns:
+ A tuple containing the:
+ * targets of the m.room.member events in the state.
+ * 3PID invite tokens in the state.
+ """
+ existing_members = []
+ threepid_invite_tokens = []
+ for key, state_event in state_events.items():
+ if key[0] == EventTypes.Member and state_event.content:
+ existing_members.append(state_event.state_key)
+ if key[0] == EventTypes.ThirdPartyInvite and state_event.content:
+ # Don't include revoked invites.
+ threepid_invite_tokens.append(state_event.state_key)
+
+ return existing_members, threepid_invite_tokens
+
+ @staticmethod
+ def _is_invite_from_threepid(invite: EventBase, threepid_invite_token: str) -> bool:
+ """Checks whether the given invite follows the given 3PID invite.
+
+ Args:
+ invite: The m.room.member event with "invite" membership.
+ threepid_invite_token: The state key from the 3PID invite.
+
+ Returns:
+ Whether the invite is due to the given 3PID invite.
+ """
+ token = (
+ invite.content.get("third_party_invite", {})
+ .get("signed", {})
+ .get("token", "")
+ )
+
+ return token == threepid_invite_token
+
+ def _is_local_user(self, user_id: str) -> bool:
+ """Checks whether a given user ID belongs to this homeserver, or a remote
+
+ Args:
+ user_id: A user ID to check.
+
+ Returns:
+ True if the user belongs to this homeserver, False otherwise.
+ """
+ user = UserID.from_string(user_id)
+
+ # Extract the localpart and ask the module API for a user ID from the localpart
+ # The module API will append the local homeserver's server_name
+ local_user_id = self.module_api.get_qualified_user_id(user.localpart)
+
+ # If the user ID we get based on the localpart is the same as the original user ID,
+ # then they were a local user
+ return user_id == local_user_id
+
+ def _user_is_invited_to_room(
+ self, user_id: str, state_events: StateMap[EventBase]
+ ) -> bool:
+ """Checks whether a given user has been invited to a room
+
+ A user has an invite for a room if its state contains a `m.room.member`
+ event with membership "invite" and their user ID as the state key.
+
+ Args:
+ user_id: The user to check.
+ state_events: The state events from the room.
+
+ Returns:
+ True if the user has been invited to the room, or False if they haven't.
+ """
+ for (event_type, state_key), state_event in state_events.items():
+ if (
+ event_type == EventTypes.Member
+ and state_key == user_id
+ and state_event.membership == Membership.INVITE
+ ):
+ return True
+
+ return False
diff --git a/synapse/types.py b/synapse/types.py
index 66bb5bac8d..7c9f716804 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -34,6 +34,7 @@ from typing import (
import attr
from signedjson.key import decode_verify_key_bytes
+from six.moves import filter
from unpaddedbase64 import decode_base64
from synapse.api.errors import Codes, SynapseError
@@ -317,18 +318,31 @@ mxid_localpart_allowed_characters = set(
)
-def contains_invalid_mxid_characters(localpart):
+def contains_invalid_mxid_characters(localpart: str) -> bool:
"""Check for characters not allowed in an mxid or groupid localpart
Args:
- localpart (basestring): the localpart to be checked
+ localpart: the localpart to be checked
Returns:
- bool: True if there are any naughty characters
+ True if there are any naughty characters
"""
return any(c not in mxid_localpart_allowed_characters for c in localpart)
+def strip_invalid_mxid_characters(localpart):
+ """Removes any invalid characters from an mxid
+
+ Args:
+ localpart (basestring): the localpart to be stripped
+
+ Returns:
+ localpart (basestring): the localpart having been stripped
+ """
+ filtered = filter(lambda c: c in mxid_localpart_allowed_characters, localpart)
+ return "".join(filtered)
+
+
UPPER_CASE_PATTERN = re.compile(b"[A-Z_]")
# the following is a pattern which matches '=', and bytes which are not allowed in a mxid
diff --git a/synapse/util/threepids.py b/synapse/util/threepids.py
index 43c2e0ac23..cfdaa1c5d9 100644
--- a/synapse/util/threepids.py
+++ b/synapse/util/threepids.py
@@ -16,11 +16,14 @@
import logging
import re
+from twisted.internet import defer
+
logger = logging.getLogger(__name__)
+@defer.inlineCallbacks
def check_3pid_allowed(hs, medium, address):
- """Checks whether a given format of 3PID is allowed to be used on this HS
+ """Checks whether a given 3PID is allowed to be used on this HS
Args:
hs (synapse.server.HomeServer): server
@@ -28,9 +31,36 @@ def check_3pid_allowed(hs, medium, address):
address (str): address within that medium (e.g. "wotan@matrix.org")
msisdns need to first have been canonicalised
Returns:
- bool: whether the 3PID medium/address is allowed to be added to this HS
+ defered bool: whether the 3PID medium/address is allowed to be added to this HS
"""
+ if hs.config.check_is_for_allowed_local_3pids:
+ data = yield hs.get_simple_http_client().get_json(
+ "https://%s%s"
+ % (
+ hs.config.check_is_for_allowed_local_3pids,
+ "/_matrix/identity/api/v1/internal-info",
+ ),
+ {"medium": medium, "address": address},
+ )
+
+ # Check for invalid response
+ if "hs" not in data and "shadow_hs" not in data:
+ defer.returnValue(False)
+
+ # Check if this user is intended to register for this homeserver
+ if (
+ data.get("hs") != hs.config.server_name
+ and data.get("shadow_hs") != hs.config.server_name
+ ):
+ defer.returnValue(False)
+
+ if data.get("requires_invite", False) and not data.get("invited", False):
+ # Requires an invite but hasn't been invited
+ defer.returnValue(False)
+
+ defer.returnValue(True)
+
if hs.config.allowed_local_3pids:
for constraint in hs.config.allowed_local_3pids:
logger.debug(
@@ -43,11 +73,11 @@ def check_3pid_allowed(hs, medium, address):
if medium == constraint["medium"] and re.match(
constraint["pattern"], address
):
- return True
+ defer.returnValue(True)
else:
- return True
+ defer.returnValue(True)
- return False
+ defer.returnValue(False)
def canonicalise_email(address: str) -> str:
|