diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 06ba6604f3..768b840b6a 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
from typing import Optional
@@ -22,7 +21,6 @@ from netaddr import IPAddress
from twisted.internet import defer
from twisted.web.server import Request
-import synapse.logging.opentracing as opentracing
import synapse.types
from synapse import event_auth
from synapse.api.auth_blocking import AuthBlocking
@@ -35,6 +33,7 @@ from synapse.api.errors import (
)
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import EventBase
+from synapse.logging import opentracing as opentracing
from synapse.types import StateMap, UserID
from synapse.util.caches import register_cache
from synapse.util.caches.lrucache import LruCache
@@ -194,6 +193,7 @@ class Auth(object):
access_token = self.get_access_token_from_request(request)
user_id, app_service = yield self._get_appservice_user_id(request)
+
if user_id:
request.authenticated_entity = user_id
opentracing.set_tag("authenticated_entity", user_id)
@@ -260,11 +260,11 @@ class Auth(object):
except KeyError:
raise MissingClientTokenError()
- @defer.inlineCallbacks
def _get_appservice_user_id(self, request):
app_service = self.store.get_app_service_by_token(
self.get_access_token_from_request(request)
)
+
if app_service is None:
return None, None
@@ -282,8 +282,12 @@ class Auth(object):
if not app_service.is_interested_in_user(user_id):
raise AuthError(403, "Application service cannot masquerade as this user.")
- if not (yield self.store.get_user_by_id(user_id)):
- raise AuthError(403, "Application service has not registered this user")
+ # Let ASes manipulate nonexistent users (e.g. to shadow-register them)
+ # if not (yield self.store.get_user_by_id(user_id)):
+ # raise AuthError(
+ # 403,
+ # "Application service has not registered this user"
+ # )
return user_id, app_service
@defer.inlineCallbacks
@@ -538,7 +542,7 @@ class Auth(object):
# Currently we ignore the `for_verification` flag even though there are
# some situations where we can drop particular auth events when adding
# to the event's `auth_events` (e.g. joins pointing to previous joins
- # when room is publically joinable). Dropping event IDs has the
+ # when room is publicly joinable). Dropping event IDs has the
# advantage that the auth chain for the room grows slower, but we use
# the auth chain in state resolution v2 to order events, which means
# care must be taken if dropping events to ensure that it doesn't
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index 5305038c21..20aff11c81 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
-# Copyright 2018 New Vector Ltd
+# Copyright 2017-2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -72,6 +73,13 @@ class Codes(object):
INVALID_SIGNATURE = "M_INVALID_SIGNATURE"
USER_DEACTIVATED = "M_USER_DEACTIVATED"
BAD_ALIAS = "M_BAD_ALIAS"
+ PASSWORD_TOO_SHORT = "M_PASSWORD_TOO_SHORT"
+ PASSWORD_NO_DIGIT = "M_PASSWORD_NO_DIGIT"
+ PASSWORD_NO_UPPERCASE = "M_PASSWORD_NO_UPPERCASE"
+ PASSWORD_NO_LOWERCASE = "M_PASSWORD_NO_LOWERCASE"
+ PASSWORD_NO_SYMBOL = "M_PASSWORD_NO_SYMBOL"
+ PASSWORD_IN_DICTIONARY = "M_PASSWORD_IN_DICTIONARY"
+ WEAK_PASSWORD = "M_WEAK_PASSWORD"
class CodeMessageException(RuntimeError):
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index 27a3fc9ed6..f6792d9fc8 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -21,7 +21,7 @@ from typing import Dict, Iterable, Optional, Set
from typing_extensions import ContextManager
-from twisted.internet import defer, reactor
+from twisted.internet import address, defer, reactor
import synapse
import synapse.events
@@ -206,10 +206,30 @@ class KeyUploadServlet(RestServlet):
if body:
# They're actually trying to upload something, proxy to main synapse.
- # Pass through the auth headers, if any, in case the access token
- # is there.
- auth_headers = request.requestHeaders.getRawHeaders(b"Authorization", [])
- headers = {"Authorization": auth_headers}
+
+ # Proxy headers from the original request, such as the auth headers
+ # (in case the access token is there) and the original IP /
+ # User-Agent of the request.
+ headers = {
+ header: request.requestHeaders.getRawHeaders(header, [])
+ for header in (b"Authorization", b"User-Agent")
+ }
+ # Add the previous hop the the X-Forwarded-For header.
+ x_forwarded_for = request.requestHeaders.getRawHeaders(
+ b"X-Forwarded-For", []
+ )
+ if isinstance(request.client, (address.IPv4Address, address.IPv6Address)):
+ previous_host = request.client.host.encode("ascii")
+ # If the header exists, add to the comma-separated list of the first
+ # instance of the header. Otherwise, generate a new header.
+ if x_forwarded_for:
+ x_forwarded_for = [
+ x_forwarded_for[0] + b", " + previous_host
+ ] + x_forwarded_for[1:]
+ else:
+ x_forwarded_for = [previous_host]
+ headers[b"X-Forwarded-For"] = x_forwarded_for
+
try:
result = await self.http_client.post_json_get_json(
self.main_uri + request.uri.decode("ascii"), body, headers=headers
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index da9a5e86d4..f92bfb420b 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -98,7 +98,6 @@ class ApplicationServiceApi(SimpleHttpClient):
if service.url is None:
return False
uri = service.url + ("/users/%s" % urllib.parse.quote(user_id))
- response = None
try:
response = yield self.get_json(uri, {"access_token": service.hs_token})
if response is not None: # just an empty json object
diff --git a/synapse/config/__main__.py b/synapse/config/__main__.py
index fca35b008c..65043d5b5b 100644
--- a/synapse/config/__main__.py
+++ b/synapse/config/__main__.py
@@ -16,6 +16,7 @@ from synapse.config._base import ConfigError
if __name__ == "__main__":
import sys
+
from synapse.config.homeserver import HomeServerConfig
action = sys.argv[1]
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index 1391e5fc43..f2830c609d 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -19,6 +19,7 @@ import argparse
import errno
import os
from collections import OrderedDict
+from io import open as io_open
from textwrap import dedent
from typing import Any, MutableMapping, Optional
@@ -179,7 +180,7 @@ class Config(object):
@classmethod
def read_file(cls, file_path, config_name):
cls.check_file(file_path, config_name)
- with open(file_path) as file_stream:
+ with io_open(file_path, encoding="utf-8") as file_stream:
return file_stream.read()
diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py
index ca61214454..b1dc7ad502 100644
--- a/synapse/config/emailconfig.py
+++ b/synapse/config/emailconfig.py
@@ -14,7 +14,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
from __future__ import print_function
# This file can't be called email.py because if it is, we cannot:
@@ -73,7 +72,7 @@ class EmailConfig(Config):
template_dir = email_config.get("template_dir")
# we need an absolute path, because we change directory after starting (and
- # we don't yet know what auxilliary templates like mail.css we will need).
+ # we don't yet know what auxiliary templates like mail.css we will need).
# (Note that loading as package_resources with jinja.PackageLoader doesn't
# work for the same reason.)
if not template_dir:
@@ -145,8 +144,8 @@ class EmailConfig(Config):
or self.threepid_behaviour_email == ThreepidBehaviour.LOCAL
):
# make sure we can import the required deps
- import jinja2
import bleach
+ import jinja2
# prevent unused warnings
jinja2
diff --git a/synapse/config/jwt_config.py b/synapse/config/jwt_config.py
index a568726985..fce96b4acf 100644
--- a/synapse/config/jwt_config.py
+++ b/synapse/config/jwt_config.py
@@ -45,10 +45,37 @@ class JWTConfig(Config):
def generate_config_section(self, **kwargs):
return """\
- # The JWT needs to contain a globally unique "sub" (subject) claim.
+ # JSON web token integration. The following settings can be used to make
+ # Synapse JSON web tokens for authentication, instead of its internal
+ # password database.
+ #
+ # Each JSON Web Token needs to contain a "sub" (subject) claim, which is
+ # used as the localpart of the mxid.
+ #
+ # Note that this is a non-standard login type and client support is
+ # expected to be non-existant.
+ #
+ # See https://github.com/matrix-org/synapse/blob/master/docs/jwt.md.
#
#jwt_config:
- # enabled: true
- # secret: "a secret"
- # algorithm: "HS256"
+ # Uncomment the following to enable authorization using JSON web
+ # tokens. Defaults to false.
+ #
+ #enabled: true
+
+ # This is either the private shared secret or the public key used to
+ # decode the contents of the JSON web token.
+ #
+ # Required if 'enabled' is true.
+ #
+ #secret: "provided-by-your-issuer"
+
+ # The algorithm used to sign the JSON web token.
+ #
+ # Supported algorithms are listed at
+ # https://pyjwt.readthedocs.io/en/latest/algorithms.html
+ #
+ # Required if 'enabled' is true.
+ #
+ #algorithm: "provided-by-your-issuer"
"""
diff --git a/synapse/config/password.py b/synapse/config/password.py
index 9c0ea8c30a..6b2dae78b0 100644
--- a/synapse/config/password.py
+++ b/synapse/config/password.py
@@ -1,5 +1,7 @@
# -*- coding: utf-8 -*-
-# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2015-2016 OpenMarket Ltd
+# Copyright 2017-2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py
index 2dd94bae2b..b1981d4d15 100644
--- a/synapse/config/ratelimiting.py
+++ b/synapse/config/ratelimiting.py
@@ -76,6 +76,9 @@ class RatelimitConfig(Config):
)
self.rc_registration = RateLimitConfig(config.get("rc_registration", {}))
+ self.rc_third_party_invite = RateLimitConfig(
+ config.get("rc_third_party_invite", {})
+ )
rc_login_config = config.get("rc_login", {})
self.rc_login_address = RateLimitConfig(rc_login_config.get("address", {}))
@@ -115,6 +118,8 @@ class RatelimitConfig(Config):
# - one for login that ratelimits login requests based on the account the
# client is attempting to log into, based on the amount of failed login
# attempts for this account.
+ # - one that ratelimits third-party invites requests based on the account
+ # that's making the requests.
# - one for ratelimiting redactions by room admins. If this is not explicitly
# set then it uses the same ratelimiting as per rc_message. This is useful
# to allow room admins to deal with abuse quickly.
@@ -140,6 +145,10 @@ class RatelimitConfig(Config):
# per_second: 0.17
# burst_count: 3
#
+ #rc_third_party_invite:
+ # per_second: 0.2
+ # burst_count: 10
+ #
#rc_admin_redaction:
# per_second: 1
# burst_count: 50
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index 6badf4e75d..080b9bc445 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -100,8 +100,19 @@ class RegistrationConfig(Config):
self.registrations_require_3pid = config.get("registrations_require_3pid", [])
self.allowed_local_3pids = config.get("allowed_local_3pids", [])
+ self.check_is_for_allowed_local_3pids = config.get(
+ "check_is_for_allowed_local_3pids", None
+ )
+ self.allow_invited_3pids = config.get("allow_invited_3pids", False)
+
+ self.disable_3pid_changes = config.get("disable_3pid_changes", False)
+
self.enable_3pid_lookup = config.get("enable_3pid_lookup", True)
self.registration_shared_secret = config.get("registration_shared_secret")
+ self.register_mxid_from_3pid = config.get("register_mxid_from_3pid")
+ self.register_just_use_email_for_display_name = config.get(
+ "register_just_use_email_for_display_name", False
+ )
self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
self.trusted_third_party_id_servers = config.get(
@@ -109,7 +120,21 @@ class RegistrationConfig(Config):
)
account_threepid_delegates = config.get("account_threepid_delegates") or {}
self.account_threepid_delegate_email = account_threepid_delegates.get("email")
+ if (
+ self.account_threepid_delegate_email
+ and not self.account_threepid_delegate_email.startswith("http")
+ ):
+ raise ConfigError(
+ "account_threepid_delegates.email must begin with http:// or https://"
+ )
self.account_threepid_delegate_msisdn = account_threepid_delegates.get("msisdn")
+ if (
+ self.account_threepid_delegate_msisdn
+ and not self.account_threepid_delegate_msisdn.startswith("http")
+ ):
+ raise ConfigError(
+ "account_threepid_delegates.msisdn must begin with http:// or https://"
+ )
if self.account_threepid_delegate_msisdn and not self.public_baseurl:
raise ConfigError(
"The configuration option `public_baseurl` is required if "
@@ -178,6 +203,15 @@ class RegistrationConfig(Config):
self.enable_set_avatar_url = config.get("enable_set_avatar_url", True)
self.enable_3pid_changes = config.get("enable_3pid_changes", True)
+ self.replicate_user_profiles_to = config.get("replicate_user_profiles_to", [])
+ if not isinstance(self.replicate_user_profiles_to, list):
+ self.replicate_user_profiles_to = [self.replicate_user_profiles_to]
+
+ self.shadow_server = config.get("shadow_server", None)
+ self.rewrite_identity_server_urls = (
+ config.get("rewrite_identity_server_urls") or {}
+ )
+
self.disable_msisdn_registration = config.get(
"disable_msisdn_registration", False
)
@@ -187,6 +221,23 @@ class RegistrationConfig(Config):
session_lifetime = self.parse_duration(session_lifetime)
self.session_lifetime = session_lifetime
+ self.bind_new_user_emails_to_sydent = config.get(
+ "bind_new_user_emails_to_sydent"
+ )
+
+ if self.bind_new_user_emails_to_sydent:
+ if not isinstance(
+ self.bind_new_user_emails_to_sydent, str
+ ) or not self.bind_new_user_emails_to_sydent.startswith("http"):
+ raise ConfigError(
+ "Option bind_new_user_emails_to_sydent has invalid value"
+ )
+
+ # Remove trailing slashes
+ self.bind_new_user_emails_to_sydent = self.bind_new_user_emails_to_sydent.strip(
+ "/"
+ )
+
def generate_config_section(self, generate_secrets=False, **kwargs):
if generate_secrets:
registration_shared_secret = 'registration_shared_secret: "%s"' % (
@@ -291,9 +342,32 @@ class RegistrationConfig(Config):
#
#disable_msisdn_registration: true
+ # Derive the user's matrix ID from a type of 3PID used when registering.
+ # This overrides any matrix ID the user proposes when calling /register
+ # The 3PID type should be present in registrations_require_3pid to avoid
+ # users failing to register if they don't specify the right kind of 3pid.
+ #
+ #register_mxid_from_3pid: email
+
+ # Uncomment to set the display name of new users to their email address,
+ # rather than using the default heuristic.
+ #
+ #register_just_use_email_for_display_name: true
+
# Mandate that users are only allowed to associate certain formats of
# 3PIDs with accounts on this server.
#
+ # Use an Identity Server to establish which 3PIDs are allowed to register?
+ # Overrides allowed_local_3pids below.
+ #
+ #check_is_for_allowed_local_3pids: matrix.org
+ #
+ # If you are using an IS you can also check whether that IS registers
+ # pending invites for the given 3PID (and then allow it to sign up on
+ # the platform):
+ #
+ #allow_invited_3pids: false
+ #
#allowed_local_3pids:
# - medium: email
# pattern: '.*@matrix\\.org'
@@ -302,6 +376,11 @@ class RegistrationConfig(Config):
# - medium: msisdn
# pattern: '\\+44'
+ # If true, stop users from trying to change the 3PIDs associated with
+ # their accounts.
+ #
+ #disable_3pid_changes: false
+
# Enable 3PIDs lookup requests to identity servers from this server.
#
#enable_3pid_lookup: true
@@ -351,6 +430,30 @@ class RegistrationConfig(Config):
# - matrix.org
# - vector.im
+ # If enabled, user IDs, display names and avatar URLs will be replicated
+ # to this server whenever they change.
+ # This is an experimental API currently implemented by sydent to support
+ # cross-homeserver user directories.
+ #
+ #replicate_user_profiles_to: example.com
+
+ # If specified, attempt to replay registrations, profile changes & 3pid
+ # bindings on the given target homeserver via the AS API. The HS is authed
+ # via a given AS token.
+ #
+ #shadow_server:
+ # hs_url: https://shadow.example.com
+ # hs: shadow.example.com
+ # as_token: 12u394refgbdhivsia
+
+ # If enabled, don't let users set their own display names/avatars
+ # other than for the very first time (unless they are a server admin).
+ # Useful when provisioning users based on the contents of a 3rd party
+ # directory and to avoid ambiguities.
+ #
+ #disable_set_displayname: false
+ #disable_set_avatar_url: false
+
# Handle threepid (email/phone etc) registration and password resets through a set of
# *trusted* identity servers. Note that this allows the configured identity server to
# reset passwords for accounts!
@@ -476,6 +579,31 @@ class RegistrationConfig(Config):
# Defaults to true.
#
#auto_join_rooms_for_guests: false
+
+ # Rewrite identity server URLs with a map from one URL to another. Applies to URLs
+ # provided by clients (which have https:// prepended) and those specified
+ # in `account_threepid_delegates`. URLs should not feature a trailing slash.
+ #
+ #rewrite_identity_server_urls:
+ # "https://somewhere.example.com": "https://somewhereelse.example.com"
+
+ # When a user registers an account with an email address, it can be useful to
+ # bind that email address to their mxid on an identity server. Typically, this
+ # requires the user to validate their email address with the identity server.
+ # However if Synapse itself is handling email validation on registration, the
+ # user ends up needing to validate their email twice, which leads to poor UX.
+ #
+ # It is possible to force Sydent, one identity server implementation, to bind
+ # threepids using its internal, unauthenticated bind API:
+ # https://github.com/matrix-org/sydent/#internal-bind-and-unbind-api
+ #
+ # Configure the address of a Sydent server here to have Synapse attempt
+ # to automatically bind users' emails following registration. The
+ # internal bind API must be reachable from Synapse, but should NOT be
+ # exposed to any third party, as it allows the creation of bindings
+ # without validation.
+ #
+ #bind_new_user_emails_to_sydent: https://example.com:8091
"""
% locals()
)
diff --git a/synapse/config/repository.py b/synapse/config/repository.py
index 01009f3924..54f565ad5b 100644
--- a/synapse/config/repository.py
+++ b/synapse/config/repository.py
@@ -104,6 +104,12 @@ class ContentRepositoryConfig(Config):
self.max_image_pixels = self.parse_size(config.get("max_image_pixels", "32M"))
self.max_spider_size = self.parse_size(config.get("max_spider_size", "10M"))
+ self.max_avatar_size = config.get("max_avatar_size")
+ if self.max_avatar_size:
+ self.max_avatar_size = self.parse_size(self.max_avatar_size)
+
+ self.allowed_avatar_mimetypes = config.get("allowed_avatar_mimetypes", [])
+
self.media_store_path = self.ensure_directory(
config.get("media_store_path", "media_store")
)
@@ -244,6 +250,30 @@ class ContentRepositoryConfig(Config):
#
#max_upload_size: 10M
+ # The largest allowed size for a user avatar. If not defined, no
+ # restriction will be imposed.
+ #
+ # Note that this only applies when an avatar is changed globally.
+ # Per-room avatar changes are not affected. See allow_per_room_profiles
+ # for disabling that functionality.
+ #
+ # Note that user avatar changes will not work if this is set without
+ # using Synapse's local media repo.
+ #
+ #max_avatar_size: 10M
+
+ # Allow mimetypes for a user avatar. If not defined, no restriction will
+ # be imposed.
+ #
+ # Note that this only applies when an avatar is changed globally.
+ # Per-room avatar changes are not affected. See allow_per_room_profiles
+ # for disabling that functionality.
+ #
+ # Note that user avatar changes will not work if this is set without
+ # using Synapse's local media repo.
+ #
+ #allowed_avatar_mimetypes: ["image/png", "image/jpeg", "image/gif"]
+
# Maximum number of pixels that will be thumbnailed
#
#max_image_pixels: 32M
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 8204664883..43ab5d62d6 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -311,6 +311,12 @@ class ServerConfig(Config):
# events with profile information that differ from the target's global profile.
self.allow_per_room_profiles = config.get("allow_per_room_profiles", True)
+ # Whether to show the users on this homeserver in the user directory. Defaults to
+ # True.
+ self.show_users_in_user_directory = config.get(
+ "show_users_in_user_directory", True
+ )
+
retention_config = config.get("retention")
if retention_config is None:
retention_config = {}
@@ -968,6 +974,74 @@ class ServerConfig(Config):
#
#allow_per_room_profiles: false
+ # Whether to show the users on this homeserver in the user directory. Defaults to
+ # 'true'.
+ #
+ #show_users_in_user_directory: false
+
+ # Message retention policy at the server level.
+ #
+ # Room admins and mods can define a retention period for their rooms using the
+ # 'm.room.retention' state event, and server admins can cap this period by setting
+ # the 'allowed_lifetime_min' and 'allowed_lifetime_max' config options.
+ #
+ # If this feature is enabled, Synapse will regularly look for and purge events
+ # which are older than the room's maximum retention period. Synapse will also
+ # filter events received over federation so that events that should have been
+ # purged are ignored and not stored again.
+ #
+ retention:
+ # The message retention policies feature is disabled by default. Uncomment the
+ # following line to enable it.
+ #
+ #enabled: true
+
+ # Default retention policy. If set, Synapse will apply it to rooms that lack the
+ # 'm.room.retention' state event. Currently, the value of 'min_lifetime' doesn't
+ # matter much because Synapse doesn't take it into account yet.
+ #
+ #default_policy:
+ # min_lifetime: 1d
+ # max_lifetime: 1y
+
+ # Retention policy limits. If set, a user won't be able to send a
+ # 'm.room.retention' event which features a 'min_lifetime' or a 'max_lifetime'
+ # that's not within this range. This is especially useful in closed federations,
+ # in which server admins can make sure every federating server applies the same
+ # rules.
+ #
+ #allowed_lifetime_min: 1d
+ #allowed_lifetime_max: 1y
+
+ # Server admins can define the settings of the background jobs purging the
+ # events which lifetime has expired under the 'purge_jobs' section.
+ #
+ # If no configuration is provided, a single job will be set up to delete expired
+ # events in every room daily.
+ #
+ # Each job's configuration defines which range of message lifetimes the job
+ # takes care of. For example, if 'shortest_max_lifetime' is '2d' and
+ # 'longest_max_lifetime' is '3d', the job will handle purging expired events in
+ # rooms whose state defines a 'max_lifetime' that's both higher than 2 days, and
+ # lower than or equal to 3 days. Both the minimum and the maximum value of a
+ # range are optional, e.g. a job with no 'shortest_max_lifetime' and a
+ # 'longest_max_lifetime' of '3d' will handle every room with a retention policy
+ # which 'max_lifetime' is lower than or equal to three days.
+ #
+ # The rationale for this per-job configuration is that some rooms might have a
+ # retention policy with a low 'max_lifetime', where history needs to be purged
+ # of outdated messages on a very frequent basis (e.g. every 5min), but not want
+ # that purge to be performed by a job that's iterating over every room it knows,
+ # which would be quite heavy on the server.
+ #
+ #purge_jobs:
+ # - shortest_max_lifetime: 1d
+ # longest_max_lifetime: 3d
+ # interval: 5m:
+ # - shortest_max_lifetime: 3d
+ # longest_max_lifetime: 1y
+ # interval: 24h
+
# How long to keep redacted events in unredacted form in the database. After
# this period redacted events get replaced with their redacted form in the DB.
#
diff --git a/synapse/config/user_directory.py b/synapse/config/user_directory.py
index c8d19c5d6b..43b6c40456 100644
--- a/synapse/config/user_directory.py
+++ b/synapse/config/user_directory.py
@@ -26,6 +26,7 @@ class UserDirectoryConfig(Config):
def read_config(self, config, **kwargs):
self.user_directory_search_enabled = True
self.user_directory_search_all_users = False
+ self.user_directory_defer_to_id_server = None
user_directory_config = config.get("user_directory", None)
if user_directory_config:
self.user_directory_search_enabled = user_directory_config.get(
@@ -34,6 +35,9 @@ class UserDirectoryConfig(Config):
self.user_directory_search_all_users = user_directory_config.get(
"search_all_users", False
)
+ self.user_directory_defer_to_id_server = user_directory_config.get(
+ "defer_to_id_server", None
+ )
def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """
@@ -52,4 +56,9 @@ class UserDirectoryConfig(Config):
#user_directory:
# enabled: true
# search_all_users: false
+ #
+ # # If this is set, user search will be delegated to this ID server instead
+ # # of synapse performing the search itself.
+ # # This is an experimental API.
+ # defer_to_id_server: https://id.example.com
"""
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index c582355146..c0981eee62 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -65,14 +65,16 @@ def check(
room_id = event.room_id
- # I'm not really expecting to get auth events in the wrong room, but let's
- # sanity-check it
+ # We need to ensure that the auth events are actually for the same room, to
+ # stop people from using powers they've been granted in other rooms for
+ # example.
for auth_event in auth_events.values():
if auth_event.room_id != room_id:
- raise Exception(
+ raise AuthError(
+ 403,
"During auth for event %s in room %s, found event %s in the state "
"which is in room %s"
- % (event.event_id, room_id, auth_event.event_id, auth_event.room_id)
+ % (event.event_id, room_id, auth_event.event_id, auth_event.room_id),
)
if do_sig_check:
diff --git a/synapse/events/builder.py b/synapse/events/builder.py
index a0c4a40c27..92aadfe7ef 100644
--- a/synapse/events/builder.py
+++ b/synapse/events/builder.py
@@ -162,7 +162,7 @@ class EventBuilderFactory(object):
def __init__(self, hs):
self.clock = hs.get_clock()
self.hostname = hs.hostname
- self.signing_key = hs.config.signing_key[0]
+ self.signing_key = hs.signing_key
self.store = hs.get_datastore()
self.state = hs.get_state_handler()
diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py
index 1ffc9525d1..3c11e317fd 100644
--- a/synapse/events/spamcheck.py
+++ b/synapse/events/spamcheck.py
@@ -15,7 +15,7 @@
# limitations under the License.
import inspect
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional
from synapse.spam_checker_api import SpamCheckerApi
@@ -58,42 +58,82 @@ class SpamChecker(object):
return False
def user_may_invite(
- self, inviter_userid: str, invitee_userid: str, room_id: str
+ self,
+ inviter_userid: str,
+ invitee_userid: str,
+ third_party_invite: Optional[Dict],
+ room_id: str,
+ new_room: bool,
+ published_room: bool,
) -> bool:
"""Checks if a given user may send an invite
If this method returns false, the invite will be rejected.
Args:
- inviter_userid: The user ID of the sender of the invitation
- invitee_userid: The user ID targeted in the invitation
- room_id: The room ID
+ inviter_userid:
+ invitee_userid: The user ID of the invitee. Is None
+ if this is a third party invite and the 3PID is not bound to a
+ user ID.
+ third_party_invite: If a third party invite then is a
+ dict containing the medium and address of the invitee.
+ room_id:
+ new_room: Whether the user is being invited to the room as
+ part of a room creation, if so the invitee would have been
+ included in the call to `user_may_create_room`.
+ published_room: Whether the room the user is being invited
+ to has been published in the local homeserver's public room
+ directory.
Returns:
True if the user may send an invite, otherwise False
"""
for spam_checker in self.spam_checkers:
if (
- spam_checker.user_may_invite(inviter_userid, invitee_userid, room_id)
+ spam_checker.user_may_invite(
+ inviter_userid,
+ invitee_userid,
+ third_party_invite,
+ room_id,
+ new_room,
+ published_room,
+ )
is False
):
return False
return True
- def user_may_create_room(self, userid: str) -> bool:
+ def user_may_create_room(
+ self,
+ userid: str,
+ invite_list: List[str],
+ third_party_invite_list: List[Dict],
+ cloning: bool,
+ ) -> bool:
"""Checks if a given user may create a room
If this method returns false, the creation request will be rejected.
Args:
userid: The ID of the user attempting to create a room
+ invite_list: List of user IDs that would be invited to
+ the new room.
+ third_party_invite_list: List of third party invites
+ for the new room.
+ cloning: Whether the user is cloning an existing room, e.g.
+ upgrading a room.
Returns:
True if the user may create a room, otherwise False
"""
for spam_checker in self.spam_checkers:
- if spam_checker.user_may_create_room(userid) is False:
+ if (
+ spam_checker.user_may_create_room(
+ userid, invite_list, third_party_invite_list, cloning
+ )
+ is False
+ ):
return False
return True
@@ -134,6 +174,25 @@ class SpamChecker(object):
return True
+ def user_may_join_room(self, userid: str, room_id: str, is_invited: bool):
+ """Checks if a given users is allowed to join a room.
+
+ Not called when a user creates a room.
+
+ Args:
+ userid:
+ room_id:
+ is_invited: Whether the user is invited into the room
+
+ Returns:
+ bool: Whether the user may join the room
+ """
+ for spam_checker in self.spam_checkers:
+ if spam_checker.user_may_join_room(userid, room_id, is_invited) is False:
+ return False
+
+ return True
+
def check_username_for_spam(self, user_profile: Dict[str, str]) -> bool:
"""Checks if a user ID or display name are considered "spammy" by this server.
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 687cd841ac..a37cc9cb4a 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -87,7 +87,7 @@ class FederationClient(FederationBase):
self.transport_layer = hs.get_federation_transport_client()
self.hostname = hs.hostname
- self.signing_key = hs.config.signing_key[0]
+ self.signing_key = hs.signing_key
self._get_pdu_cache = ExpiringCache(
cache_name="get_pdu_cache",
@@ -245,7 +245,7 @@ class FederationClient(FederationBase):
event_id: event to fetch
room_version: version of the room
outlier: Indicates whether the PDU is an `outlier`, i.e. if
- it's from an arbitary point in the context as opposed to part
+ it's from an arbitrary point in the context as opposed to part
of the current block of PDUs. Defaults to `False`
timeout: How long to try (in ms) each destination for before
moving to the next destination. None indicates no timeout.
@@ -351,7 +351,7 @@ class FederationClient(FederationBase):
outlier: bool = False,
include_none: bool = False,
) -> List[EventBase]:
- """Takes a list of PDUs and checks the signatures and hashs of each
+ """Takes a list of PDUs and checks the signatures and hashes of each
one. If a PDU fails its signature check then we check if we have it in
the database and if not then request if from the originating server of
that PDU.
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index e704cf2f44..86051decd4 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -717,7 +717,7 @@ def server_matches_acl_event(server_name: str, acl_event: EventBase) -> bool:
# server name is a literal IP
allow_ip_literals = acl_event.content.get("allow_ip_literals", True)
if not isinstance(allow_ip_literals, bool):
- logger.warning("Ignorning non-bool allow_ip_literals flag")
+ logger.warning("Ignoring non-bool allow_ip_literals flag")
allow_ip_literals = True
if not allow_ip_literals:
# check for ipv6 literals. These start with '['.
@@ -731,7 +731,7 @@ def server_matches_acl_event(server_name: str, acl_event: EventBase) -> bool:
# next, check the deny list
deny = acl_event.content.get("deny", [])
if not isinstance(deny, (list, tuple)):
- logger.warning("Ignorning non-list deny ACL %s", deny)
+ logger.warning("Ignoring non-list deny ACL %s", deny)
deny = []
for e in deny:
if _acl_entry_matches(server_name, e):
@@ -741,7 +741,7 @@ def server_matches_acl_event(server_name: str, acl_event: EventBase) -> bool:
# then the allow list.
allow = acl_event.content.get("allow", [])
if not isinstance(allow, (list, tuple)):
- logger.warning("Ignorning non-list allow ACL %s", allow)
+ logger.warning("Ignoring non-list allow ACL %s", allow)
allow = []
for e in allow:
if _acl_entry_matches(server_name, e):
diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py
index 6bbd762681..860b03f7b9 100644
--- a/synapse/federation/send_queue.py
+++ b/synapse/federation/send_queue.py
@@ -359,7 +359,7 @@ class BaseFederationRow(object):
Specifies how to identify, serialize and deserialize the different types.
"""
- TypeId = "" # Unique string that ids the type. Must be overriden in sub classes.
+ TypeId = "" # Unique string that ids the type. Must be overridden in sub classes.
@staticmethod
def from_data(data):
diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index 4e698981a4..12966e239b 100644
--- a/synapse/federation/sender/per_destination_queue.py
+++ b/synapse/federation/sender/per_destination_queue.py
@@ -119,7 +119,7 @@ class PerDestinationQueue(object):
)
def send_pdu(self, pdu: EventBase, order: int) -> None:
- """Add a PDU to the queue, and start the transmission loop if neccessary
+ """Add a PDU to the queue, and start the transmission loop if necessary
Args:
pdu: pdu to send
@@ -129,7 +129,7 @@ class PerDestinationQueue(object):
self.attempt_new_transaction()
def send_presence(self, states: Iterable[UserPresenceState]) -> None:
- """Add presence updates to the queue. Start the transmission loop if neccessary.
+ """Add presence updates to the queue. Start the transmission loop if necessary.
Args:
states: presence to send
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 9f99311419..e502c12050 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -16,7 +16,7 @@
import logging
import urllib
-from typing import Any, Dict, Optional
+from typing import Any, Dict, List, Optional
from twisted.internet import defer
@@ -746,7 +746,7 @@ class TransportLayerClient(object):
def remove_user_from_group(
self, destination, group_id, requester_user_id, user_id, content
):
- """Remove a user fron a group
+ """Remove a user from a group
"""
path = _create_v1_path("/groups/%s/users/%s/remove", group_id, user_id)
@@ -1020,6 +1020,20 @@ class TransportLayerClient(object):
return self.client.get_json(destination=destination, path=path)
+ def get_info_of_users(self, destination: str, user_ids: List[str]):
+ """
+ Args:
+ destination: The remote server
+ user_ids: A list of user IDs to query info about
+
+ Returns:
+ Deferred[List]: A dictionary of User ID to information about that user.
+ """
+ path = _create_path(FEDERATION_UNSTABLE_PREFIX, "/users/info")
+ data = {"user_ids": user_ids}
+
+ return self.client.post_json(destination=destination, path=path, data=data)
+
def _create_path(federation_prefix, path, *args):
"""
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index af4595498c..1478ee03a5 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -33,6 +33,7 @@ from synapse.api.urls import (
from synapse.http.endpoint import parse_and_validate_server_name
from synapse.http.server import JsonResource
from synapse.http.servlet import (
+ assert_params_in_dict,
parse_boolean_from_args,
parse_integer_from_args,
parse_json_object_from_request,
@@ -109,7 +110,7 @@ class Authenticator(object):
self.server_name = hs.hostname
self.store = hs.get_datastore()
self.federation_domain_whitelist = hs.config.federation_domain_whitelist
- self.notifer = hs.get_notifier()
+ self.notifier = hs.get_notifier()
self.replication_client = None
if hs.config.worker.worker_app:
@@ -175,7 +176,7 @@ class Authenticator(object):
await self.store.set_destination_retry_timings(origin, None, 0, 0)
# Inform the relevant places that the remote server is back up.
- self.notifer.notify_remote_server_up(origin)
+ self.notifier.notify_remote_server_up(origin)
if self.replication_client:
# If we're on a worker we try and inform master about this. The
# replication client doesn't hook into the notifier to avoid
@@ -361,11 +362,7 @@ class BaseFederationServlet(object):
continue
server.register_paths(
- method,
- (pattern,),
- self._wrap(code),
- self.__class__.__name__,
- trace=False,
+ method, (pattern,), self._wrap(code), self.__class__.__name__,
)
@@ -849,6 +846,57 @@ class PublicRoomList(BaseFederationServlet):
return 200, data
+class FederationUserInfoServlet(BaseFederationServlet):
+ """
+ Return information about a set of users.
+
+ This API returns expiration and deactivation information about a set of
+ users. Requested users not local to this homeserver will be ignored.
+
+ Example request:
+ POST /users/info
+
+ {
+ "user_ids": [
+ "@alice:example.com",
+ "@bob:example.com"
+ ]
+ }
+
+ Example response
+ {
+ "@alice:example.com": {
+ "expired": false,
+ "deactivated": true
+ }
+ }
+ """
+
+ PATH = "/users/info"
+ PREFIX = FEDERATION_UNSTABLE_PREFIX
+
+ def __init__(self, handler, authenticator, ratelimiter, server_name):
+ super(FederationUserInfoServlet, self).__init__(
+ handler, authenticator, ratelimiter, server_name
+ )
+ self.handler = handler
+
+ async def on_POST(self, origin, content, query):
+ assert_params_in_dict(content, required=["user_ids"])
+
+ user_ids = content.get("user_ids", [])
+
+ if not isinstance(user_ids, list):
+ raise SynapseError(
+ 400,
+ "'user_ids' must be a list of user ID strings",
+ errcode=Codes.INVALID_PARAM,
+ )
+
+ data = await self.handler.store.get_info_for_users(user_ids)
+ return 200, data
+
+
class FederationVersionServlet(BaseFederationServlet):
PATH = "/version"
@@ -1410,6 +1458,7 @@ FEDERATION_SERVLET_CLASSES = (
On3pidBindServlet,
FederationVersionServlet,
RoomComplexityServlet,
+ FederationUserInfoServlet,
) # type: Tuple[Type[BaseFederationServlet], ...]
OPENID_SERVLET_CLASSES = (
diff --git a/synapse/groups/attestations.py b/synapse/groups/attestations.py
index 27b0c02655..dab13c243f 100644
--- a/synapse/groups/attestations.py
+++ b/synapse/groups/attestations.py
@@ -70,7 +70,7 @@ class GroupAttestationSigning(object):
self.keyring = hs.get_keyring()
self.clock = hs.get_clock()
self.server_name = hs.hostname
- self.signing_key = hs.config.signing_key[0]
+ self.signing_key = hs.signing_key
@defer.inlineCallbacks
def verify_attestation(self, attestation, group_id, user_id, server_name=None):
diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py
index 8db8ab1b7b..8cb922ddc7 100644
--- a/synapse/groups/groups_server.py
+++ b/synapse/groups/groups_server.py
@@ -41,7 +41,7 @@ class GroupsServerWorkerHandler(object):
self.clock = hs.get_clock()
self.keyring = hs.get_keyring()
self.is_mine_id = hs.is_mine_id
- self.signing_key = hs.config.signing_key[0]
+ self.signing_key = hs.signing_key
self.server_name = hs.hostname
self.attestations = hs.get_groups_attestation_signing()
self.transport_client = hs.get_federation_transport_client()
diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py
index 590135d19c..0c2bcda4d0 100644
--- a/synapse/handlers/account_validity.py
+++ b/synapse/handlers/account_validity.py
@@ -20,6 +20,8 @@ from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from typing import List
+from twisted.internet import defer
+
from synapse.api.errors import StoreError
from synapse.logging.context import make_deferred_yieldable
from synapse.metrics.background_process_metrics import run_as_background_process
@@ -43,6 +45,8 @@ class AccountValidityHandler(object):
self.clock = self.hs.get_clock()
self._account_validity = self.hs.config.account_validity
+ self._show_users_in_user_directory = self.hs.config.show_users_in_user_directory
+ self.profile_handler = self.hs.get_profile_handler()
if (
self._account_validity.enabled
@@ -86,6 +90,13 @@ class AccountValidityHandler(object):
self.clock.looping_call(send_emails, 30 * 60 * 1000)
+ # If account_validity is enabled,check every hour to remove expired users from
+ # the user directory
+ if self._account_validity.enabled:
+ self.clock.looping_call(
+ self._mark_expired_users_as_inactive, 60 * 60 * 1000
+ )
+
async def _send_renewal_emails(self):
"""Gets the list of users whose account is expiring in the amount of time
configured in the ``renew_at`` parameter from the ``account_validity``
@@ -266,4 +277,25 @@ class AccountValidityHandler(object):
user_id=user_id, expiration_ts=expiration_ts, email_sent=email_sent
)
+ # Check if renewed users should be reintroduced to the user directory
+ if self._show_users_in_user_directory:
+ # Show the user in the directory again by setting them to active
+ await self.profile_handler.set_active(
+ [UserID.from_string(user_id)], True, True
+ )
+
return expiration_ts
+
+ @defer.inlineCallbacks
+ def _mark_expired_users_as_inactive(self):
+ """Iterate over active, expired users. Mark them as inactive in order to hide them
+ from the user directory.
+
+ Returns:
+ Deferred
+ """
+ # Get active, expired users
+ active_expired_users = yield self.store.get_expired_users()
+
+ # Mark each as non-active
+ yield self.profile_handler.set_active(active_expired_users, False, True)
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 904c96eeec..92d4c6e16c 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -48,8 +48,7 @@ class ApplicationServicesHandler(object):
self.current_max = 0
self.is_processing = False
- @defer.inlineCallbacks
- def notify_interested_services(self, current_id):
+ async def notify_interested_services(self, current_id):
"""Notifies (pushes) all application services interested in this event.
Pushing is done asynchronously, so this method won't block for any
@@ -74,7 +73,7 @@ class ApplicationServicesHandler(object):
(
upper_bound,
events,
- ) = yield self.store.get_new_events_for_appservice(
+ ) = await self.store.get_new_events_for_appservice(
self.current_max, limit
)
@@ -85,10 +84,9 @@ class ApplicationServicesHandler(object):
for event in events:
events_by_room.setdefault(event.room_id, []).append(event)
- @defer.inlineCallbacks
- def handle_event(event):
+ async def handle_event(event):
# Gather interested services
- services = yield self._get_services_for_event(event)
+ services = await self._get_services_for_event(event)
if len(services) == 0:
return # no services need notifying
@@ -96,9 +94,9 @@ class ApplicationServicesHandler(object):
# query API for all services which match that user regex.
# This needs to block as these user queries need to be
# made BEFORE pushing the event.
- yield self._check_user_exists(event.sender)
+ await self._check_user_exists(event.sender)
if event.type == EventTypes.Member:
- yield self._check_user_exists(event.state_key)
+ await self._check_user_exists(event.state_key)
if not self.started_scheduler:
@@ -115,17 +113,16 @@ class ApplicationServicesHandler(object):
self.scheduler.submit_event_for_as(service, event)
now = self.clock.time_msec()
- ts = yield self.store.get_received_ts(event.event_id)
+ ts = await self.store.get_received_ts(event.event_id)
synapse.metrics.event_processing_lag_by_event.labels(
"appservice_sender"
).observe((now - ts) / 1000)
- @defer.inlineCallbacks
- def handle_room_events(events):
+ async def handle_room_events(events):
for event in events:
- yield handle_event(event)
+ await handle_event(event)
- yield make_deferred_yieldable(
+ await make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(handle_room_events, evs)
@@ -135,10 +132,10 @@ class ApplicationServicesHandler(object):
)
)
- yield self.store.set_appservice_last_pos(upper_bound)
+ await self.store.set_appservice_last_pos(upper_bound)
now = self.clock.time_msec()
- ts = yield self.store.get_received_ts(events[-1].event_id)
+ ts = await self.store.get_received_ts(events[-1].event_id)
synapse.metrics.event_processing_positions.labels(
"appservice_sender"
@@ -161,8 +158,7 @@ class ApplicationServicesHandler(object):
finally:
self.is_processing = False
- @defer.inlineCallbacks
- def query_user_exists(self, user_id):
+ async def query_user_exists(self, user_id):
"""Check if any application service knows this user_id exists.
Args:
@@ -170,15 +166,14 @@ class ApplicationServicesHandler(object):
Returns:
True if this user exists on at least one application service.
"""
- user_query_services = yield self._get_services_for_user(user_id=user_id)
+ user_query_services = self._get_services_for_user(user_id=user_id)
for user_service in user_query_services:
- is_known_user = yield self.appservice_api.query_user(user_service, user_id)
+ is_known_user = await self.appservice_api.query_user(user_service, user_id)
if is_known_user:
return True
return False
- @defer.inlineCallbacks
- def query_room_alias_exists(self, room_alias):
+ async def query_room_alias_exists(self, room_alias):
"""Check if an application service knows this room alias exists.
Args:
@@ -193,19 +188,18 @@ class ApplicationServicesHandler(object):
s for s in services if (s.is_interested_in_alias(room_alias_str))
]
for alias_service in alias_query_services:
- is_known_alias = yield self.appservice_api.query_alias(
+ is_known_alias = await self.appservice_api.query_alias(
alias_service, room_alias_str
)
if is_known_alias:
# the alias exists now so don't query more ASes.
- result = yield self.store.get_association_from_room_alias(room_alias)
+ result = await self.store.get_association_from_room_alias(room_alias)
return result
- @defer.inlineCallbacks
- def query_3pe(self, kind, protocol, fields):
- services = yield self._get_services_for_3pn(protocol)
+ async def query_3pe(self, kind, protocol, fields):
+ services = self._get_services_for_3pn(protocol)
- results = yield make_deferred_yieldable(
+ results = await make_deferred_yieldable(
defer.DeferredList(
[
run_in_background(
@@ -224,8 +218,7 @@ class ApplicationServicesHandler(object):
return ret
- @defer.inlineCallbacks
- def get_3pe_protocols(self, only_protocol=None):
+ async def get_3pe_protocols(self, only_protocol=None):
services = self.store.get_app_services()
protocols = {}
@@ -238,7 +231,7 @@ class ApplicationServicesHandler(object):
if p not in protocols:
protocols[p] = []
- info = yield self.appservice_api.get_3pe_protocol(s, p)
+ info = await self.appservice_api.get_3pe_protocol(s, p)
if info is not None:
protocols[p].append(info)
@@ -263,8 +256,7 @@ class ApplicationServicesHandler(object):
return protocols
- @defer.inlineCallbacks
- def _get_services_for_event(self, event):
+ async def _get_services_for_event(self, event):
"""Retrieve a list of application services interested in this event.
Args:
@@ -280,7 +272,7 @@ class ApplicationServicesHandler(object):
# inside of a list comprehension anymore.
interested_list = []
for s in services:
- if (yield s.is_interested(event, self.store)):
+ if await s.is_interested(event, self.store):
interested_list.append(s)
return interested_list
@@ -288,21 +280,20 @@ class ApplicationServicesHandler(object):
def _get_services_for_user(self, user_id):
services = self.store.get_app_services()
interested_list = [s for s in services if (s.is_interested_in_user(user_id))]
- return defer.succeed(interested_list)
+ return interested_list
def _get_services_for_3pn(self, protocol):
services = self.store.get_app_services()
interested_list = [s for s in services if s.is_interested_in_protocol(protocol)]
- return defer.succeed(interested_list)
+ return interested_list
- @defer.inlineCallbacks
- def _is_unknown_user(self, user_id):
+ async def _is_unknown_user(self, user_id):
if not self.is_mine_id(user_id):
# we don't know if they are unknown or not since it isn't one of our
# users. We can't poke ASes.
return False
- user_info = yield self.store.get_user_by_id(user_id)
+ user_info = await self.store.get_user_by_id(user_id)
if user_info:
return False
@@ -311,10 +302,9 @@ class ApplicationServicesHandler(object):
service_list = [s for s in services if s.sender == user_id]
return len(service_list) == 0
- @defer.inlineCallbacks
- def _check_user_exists(self, user_id):
- unknown_user = yield self._is_unknown_user(user_id)
+ async def _check_user_exists(self, user_id):
+ unknown_user = await self._is_unknown_user(user_id)
if unknown_user:
- exists = yield self.query_user_exists(user_id)
+ exists = await self.query_user_exists(user_id)
return exists
return True
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index c3f86e7414..a162392e4c 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -13,7 +13,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
import time
import unicodedata
@@ -24,7 +23,6 @@ import attr
import bcrypt # type: ignore[import]
import pymacaroons
-import synapse.util.stringutils as stringutils
from synapse.api.constants import LoginType
from synapse.api.errors import (
AuthError,
@@ -45,6 +43,8 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.module_api import ModuleApi
from synapse.push.mailer import load_jinja2_templates
from synapse.types import Requester, UserID
+from synapse.util import stringutils as stringutils
+from synapse.util.threepids import canonicalise_email
from ._base import BaseHandler
@@ -928,7 +928,7 @@ class AuthHandler(BaseHandler):
# for the presence of an email address during password reset was
# case sensitive).
if medium == "email":
- address = address.lower()
+ address = canonicalise_email(address)
await self.store.user_add_threepid(
user_id, medium, address, validated_at, self.hs.get_clock().time_msec()
@@ -956,7 +956,7 @@ class AuthHandler(BaseHandler):
# 'Canonicalise' email addresses as per above
if medium == "email":
- address = address.lower()
+ address = canonicalise_email(address)
identity_handler = self.hs.get_handlers().identity_handler
result = await identity_handler.try_unbind_threepid(
diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py
index 76f213723a..d79ffefdb5 100644
--- a/synapse/handlers/cas_handler.py
+++ b/synapse/handlers/cas_handler.py
@@ -12,11 +12,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
import urllib
-import xml.etree.ElementTree as ET
from typing import Dict, Optional, Tuple
+from xml.etree import ElementTree as ET
from twisted.web.client import PartialDownloadError
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index 2afb390a92..fe62c3f973 100644
--- a/synapse/handlers/deactivate_account.py
+++ b/synapse/handlers/deactivate_account.py
@@ -33,6 +33,7 @@ class DeactivateAccountHandler(BaseHandler):
self._device_handler = hs.get_device_handler()
self._room_member_handler = hs.get_room_member_handler()
self._identity_handler = hs.get_handlers().identity_handler
+ self._profile_handler = hs.get_profile_handler()
self.user_directory_handler = hs.get_user_directory_handler()
# Flag that indicates whether the process to part users from rooms is running
@@ -104,6 +105,9 @@ class DeactivateAccountHandler(BaseHandler):
await self.store.user_set_password_hash(user_id, None)
+ user = UserID.from_string(user_id)
+ await self._profile_handler.set_active([user], False, False)
+
# Add the user to a table of users pending deactivation (ie.
# removal from all the rooms they're a member of)
await self.store.add_user_pending_deactivation(user_id)
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 4dbd8e1d98..90623fcd4a 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -19,8 +19,9 @@
import itertools
import logging
+from collections import Container
from http import HTTPStatus
-from typing import Dict, Iterable, List, Optional, Sequence, Tuple
+from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union
import attr
from signedjson.key import decode_verify_key_bytes
@@ -173,7 +174,7 @@ class FederationHandler(BaseHandler):
room_id = pdu.room_id
event_id = pdu.event_id
- logger.info("handling received PDU: %s", pdu)
+ logger.info("[%s %s] handling received PDU: %s", room_id, event_id, pdu)
# We reprocess pdus when we have seen them only as outliers
existing = await self.store.get_event(
@@ -288,6 +289,14 @@ class FederationHandler(BaseHandler):
room_id,
event_id,
)
+ elif missing_prevs:
+ logger.info(
+ "[%s %s] Not recursively fetching %d missing prev_events: %s",
+ room_id,
+ event_id,
+ len(missing_prevs),
+ shortstr(missing_prevs),
+ )
if prevs - seen:
# We've still not been able to get all of the prev_events for this event.
@@ -332,12 +341,6 @@ class FederationHandler(BaseHandler):
affected=pdu.event_id,
)
- logger.info(
- "Event %s is missing prev_events: calculating state for a "
- "backwards extremity",
- event_id,
- )
-
# Calculate the state after each of the previous events, and
# resolve them to find the correct state at the current event.
event_map = {event_id: pdu}
@@ -355,7 +358,10 @@ class FederationHandler(BaseHandler):
# know about
for p in prevs - seen:
logger.info(
- "Requesting state at missing prev_event %s", event_id,
+ "[%s %s] Requesting state at missing prev_event %s",
+ room_id,
+ event_id,
+ p,
)
with nested_logging_context(p):
@@ -390,9 +396,7 @@ class FederationHandler(BaseHandler):
# First though we need to fetch all the events that are in
# state_map, so we can build up the state below.
evs = await self.store.get_events(
- list(state_map.values()),
- get_prev_content=False,
- redact_behaviour=EventRedactBehaviour.AS_IS,
+ list(state_map.values()), get_prev_content=False,
)
event_map.update(evs)
@@ -617,6 +621,11 @@ class FederationHandler(BaseHandler):
will be omitted from the result. Likewise, any events which turn out not to
be in the given room.
+ This function *does not* automatically get missing auth events of the
+ newly fetched events. Callers must include the full auth chain of
+ of the missing events in the `event_ids` argument, to ensure that any
+ missing auth events are correctly fetched.
+
Returns:
map from event_id to event
"""
@@ -742,6 +751,9 @@ class FederationHandler(BaseHandler):
# device and recognize the algorithm then we can work out the
# exact key to expect. Otherwise check it matches any key we
# have for that device.
+
+ current_keys = [] # type: Container[str]
+
if device:
keys = device.get("keys", {}).get("keys", {})
@@ -758,15 +770,15 @@ class FederationHandler(BaseHandler):
current_keys = keys.values()
elif device_id:
# We don't have any keys for the device ID.
- current_keys = []
+ pass
else:
# The event didn't include a device ID, so we just look for
# keys across all devices.
- current_keys = (
+ current_keys = [
key
for device in cached_devices
for key in device.get("keys", {}).get("keys", {}).values()
- )
+ ]
# We now check that the sender key matches (one of) the expected
# keys.
@@ -1011,7 +1023,7 @@ class FederationHandler(BaseHandler):
if e_type == EventTypes.Member and event.membership == Membership.JOIN
]
- joined_domains = {}
+ joined_domains = {} # type: Dict[str, int]
for u, d in joined_users:
try:
dom = get_domain_from_id(u)
@@ -1127,12 +1139,16 @@ class FederationHandler(BaseHandler):
):
"""Fetch the given events from a server, and persist them as outliers.
+ This function *does not* recursively get missing auth events of the
+ newly fetched events. Callers must include in the `events` argument
+ any missing events from the auth chain.
+
Logs a warning if we can't find the given event.
"""
room_version = await self.store.get_room_version(room_id)
- event_infos = []
+ event_map = {} # type: Dict[str, EventBase]
async def get_event(event_id: str):
with nested_logging_context(event_id):
@@ -1146,17 +1162,7 @@ class FederationHandler(BaseHandler):
)
return
- # recursively fetch the auth events for this event
- auth_events = await self._get_events_from_store_or_dest(
- destination, room_id, event.auth_event_ids()
- )
- auth = {}
- for auth_event_id in event.auth_event_ids():
- ae = auth_events.get(auth_event_id)
- if ae:
- auth[(ae.type, ae.state_key)] = ae
-
- event_infos.append(_NewEventInfo(event, None, auth))
+ event_map[event.event_id] = event
except Exception as e:
logger.warning(
@@ -1168,6 +1174,32 @@ class FederationHandler(BaseHandler):
await concurrently_execute(get_event, events, 5)
+ # Make a map of auth events for each event. We do this after fetching
+ # all the events as some of the events' auth events will be in the list
+ # of requested events.
+
+ auth_events = [
+ aid
+ for event in event_map.values()
+ for aid in event.auth_event_ids()
+ if aid not in event_map
+ ]
+ persisted_events = await self.store.get_events(
+ auth_events, allow_rejected=True,
+ )
+
+ event_infos = []
+ for event in event_map.values():
+ auth = {}
+ for auth_event_id in event.auth_event_ids():
+ ae = persisted_events.get(auth_event_id) or event_map.get(auth_event_id)
+ if ae:
+ auth[(ae.type, ae.state_key)] = ae
+ else:
+ logger.info("Missing auth event %s", auth_event_id)
+
+ event_infos.append(_NewEventInfo(event, None, auth))
+
await self._handle_new_events(
destination, event_infos,
)
@@ -1277,14 +1309,15 @@ class FederationHandler(BaseHandler):
try:
# Try the host we successfully got a response to /make_join/
# request first.
+ host_list = list(target_hosts)
try:
- target_hosts.remove(origin)
- target_hosts.insert(0, origin)
+ host_list.remove(origin)
+ host_list.insert(0, origin)
except ValueError:
pass
ret = await self.federation_client.send_join(
- target_hosts, event, room_version_obj
+ host_list, event, room_version_obj
)
origin = ret["origin"]
@@ -1523,8 +1556,15 @@ class FederationHandler(BaseHandler):
if self.hs.config.block_non_admin_invites:
raise SynapseError(403, "This server does not accept room invites")
+ is_published = await self.store.is_room_published(event.room_id)
+
if not self.spam_checker.user_may_invite(
- event.sender, event.state_key, event.room_id
+ event.sender,
+ event.state_key,
+ None,
+ room_id=event.room_id,
+ new_room=False,
+ published_room=is_published,
):
raise SynapseError(
403, "This user is not permitted to send invites to this server/user"
@@ -1562,7 +1602,7 @@ class FederationHandler(BaseHandler):
room_version,
event.get_pdu_json(),
self.hs.hostname,
- self.hs.config.signing_key[0],
+ self.hs.signing_key,
)
)
@@ -1584,13 +1624,14 @@ class FederationHandler(BaseHandler):
# Try the host that we succesfully called /make_leave/ on first for
# the /send_leave/ request.
+ host_list = list(target_hosts)
try:
- target_hosts.remove(origin)
- target_hosts.insert(0, origin)
+ host_list.remove(origin)
+ host_list.insert(0, origin)
except ValueError:
pass
- await self.federation_client.send_leave(target_hosts, event)
+ await self.federation_client.send_leave(host_list, event)
context = await self.state_handler.compute_event_context(event)
stream_id = await self.persist_events_and_notify([(event, context)])
@@ -1604,7 +1645,7 @@ class FederationHandler(BaseHandler):
user_id: str,
membership: str,
content: JsonDict = {},
- params: Optional[Dict[str, str]] = None,
+ params: Optional[Dict[str, Union[str, Iterable[str]]]] = None,
) -> Tuple[str, EventBase, RoomVersion]:
(
origin,
@@ -2018,8 +2059,8 @@ class FederationHandler(BaseHandler):
auth_events_ids = await self.auth.compute_auth_events(
event, prev_state_ids, for_verification=True
)
- auth_events = await self.store.get_events(auth_events_ids)
- auth_events = {(e.type, e.state_key): e for e in auth_events.values()}
+ auth_events_x = await self.store.get_events(auth_events_ids)
+ auth_events = {(e.type, e.state_key): e for e in auth_events_x.values()}
# This is a hack to fix some old rooms where the initial join event
# didn't reference the create event in its auth events.
@@ -2055,76 +2096,67 @@ class FederationHandler(BaseHandler):
# For new (non-backfilled and non-outlier) events we check if the event
# passes auth based on the current state. If it doesn't then we
# "soft-fail" the event.
- do_soft_fail_check = not backfilled and not event.internal_metadata.is_outlier()
- if do_soft_fail_check:
- extrem_ids = await self.store.get_latest_event_ids_in_room(event.room_id)
-
- extrem_ids = set(extrem_ids)
- prev_event_ids = set(event.prev_event_ids())
-
- if extrem_ids == prev_event_ids:
- # If they're the same then the current state is the same as the
- # state at the event, so no point rechecking auth for soft fail.
- do_soft_fail_check = False
-
- if do_soft_fail_check:
- room_version = await self.store.get_room_version_id(event.room_id)
- room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
-
- # Calculate the "current state".
- if state is not None:
- # If we're explicitly given the state then we won't have all the
- # prev events, and so we have a gap in the graph. In this case
- # we want to be a little careful as we might have been down for
- # a while and have an incorrect view of the current state,
- # however we still want to do checks as gaps are easy to
- # maliciously manufacture.
- #
- # So we use a "current state" that is actually a state
- # resolution across the current forward extremities and the
- # given state at the event. This should correctly handle cases
- # like bans, especially with state res v2.
+ if backfilled or event.internal_metadata.is_outlier():
+ return
- state_sets = await self.state_store.get_state_groups(
- event.room_id, extrem_ids
- )
- state_sets = list(state_sets.values())
- state_sets.append(state)
- current_state_ids = await self.state_handler.resolve_events(
- room_version, state_sets, event
- )
- current_state_ids = {
- k: e.event_id for k, e in current_state_ids.items()
- }
- else:
- current_state_ids = await self.state_handler.get_current_state_ids(
- event.room_id, latest_event_ids=extrem_ids
- )
+ extrem_ids = await self.store.get_latest_event_ids_in_room(event.room_id)
+ extrem_ids = set(extrem_ids)
+ prev_event_ids = set(event.prev_event_ids())
- logger.debug(
- "Doing soft-fail check for %s: state %s",
- event.event_id,
- current_state_ids,
+ if extrem_ids == prev_event_ids:
+ # If they're the same then the current state is the same as the
+ # state at the event, so no point rechecking auth for soft fail.
+ return
+
+ room_version = await self.store.get_room_version_id(event.room_id)
+ room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
+
+ # Calculate the "current state".
+ if state is not None:
+ # If we're explicitly given the state then we won't have all the
+ # prev events, and so we have a gap in the graph. In this case
+ # we want to be a little careful as we might have been down for
+ # a while and have an incorrect view of the current state,
+ # however we still want to do checks as gaps are easy to
+ # maliciously manufacture.
+ #
+ # So we use a "current state" that is actually a state
+ # resolution across the current forward extremities and the
+ # given state at the event. This should correctly handle cases
+ # like bans, especially with state res v2.
+
+ state_sets = await self.state_store.get_state_groups(
+ event.room_id, extrem_ids
+ )
+ state_sets = list(state_sets.values())
+ state_sets.append(state)
+ current_state_ids = await self.state_handler.resolve_events(
+ room_version, state_sets, event
+ )
+ current_state_ids = {k: e.event_id for k, e in current_state_ids.items()}
+ else:
+ current_state_ids = await self.state_handler.get_current_state_ids(
+ event.room_id, latest_event_ids=extrem_ids
)
- # Now check if event pass auth against said current state
- auth_types = auth_types_for_event(event)
- current_state_ids = [
- e for k, e in current_state_ids.items() if k in auth_types
- ]
+ logger.debug(
+ "Doing soft-fail check for %s: state %s", event.event_id, current_state_ids,
+ )
- current_auth_events = await self.store.get_events(current_state_ids)
- current_auth_events = {
- (e.type, e.state_key): e for e in current_auth_events.values()
- }
+ # Now check if event pass auth against said current state
+ auth_types = auth_types_for_event(event)
+ current_state_ids = [e for k, e in current_state_ids.items() if k in auth_types]
- try:
- event_auth.check(
- room_version_obj, event, auth_events=current_auth_events
- )
- except AuthError as e:
- logger.warning("Soft-failing %r because %s", event, e)
- event.internal_metadata.soft_failed = True
+ current_auth_events = await self.store.get_events(current_state_ids)
+ current_auth_events = {
+ (e.type, e.state_key): e for e in current_auth_events.values()
+ }
+
+ try:
+ event_auth.check(room_version_obj, event, auth_events=current_auth_events)
+ except AuthError as e:
+ logger.warning("Soft-failing %r because %s", event, e)
+ event.internal_metadata.soft_failed = True
async def on_query_auth(
self, origin, event_id, room_id, remote_auth_chain, rejects, missing
@@ -2293,10 +2325,10 @@ class FederationHandler(BaseHandler):
remote_auth_chain = await self.federation_client.get_event_auth(
origin, event.room_id, event.event_id
)
- except RequestSendFailed as e:
+ except RequestSendFailed as e1:
# The other side isn't around or doesn't implement the
# endpoint, so lets just bail out.
- logger.info("Failed to get event auth from remote: %s", e)
+ logger.info("Failed to get event auth from remote: %s", e1)
return context
seen_remotes = await self.store.have_seen_events(
@@ -2774,7 +2806,8 @@ class FederationHandler(BaseHandler):
logger.debug("Checking auth on event %r", event.content)
- last_exception = None
+ last_exception = None # type: Optional[Exception]
+
# for each public key in the 3pid invite event
for public_key_object in self.hs.get_auth().get_public_keys(invite_event):
try:
@@ -2828,6 +2861,12 @@ class FederationHandler(BaseHandler):
return
except Exception as e:
last_exception = e
+
+ if last_exception is None:
+ # we can only get here if get_public_keys() returned an empty list
+ # TODO: make this better
+ raise RuntimeError("no public key in invite event")
+
raise last_exception
async def _check_key_revocation(self, public_key, url):
diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py
index 7cb106e365..ecdb12a7bf 100644
--- a/synapse/handlers/groups_local.py
+++ b/synapse/handlers/groups_local.py
@@ -70,7 +70,7 @@ class GroupsLocalWorkerHandler(object):
self.clock = hs.get_clock()
self.keyring = hs.get_keyring()
self.is_mine_id = hs.is_mine_id
- self.signing_key = hs.config.signing_key[0]
+ self.signing_key = hs.signing_key
self.server_name = hs.hostname
self.notifier = hs.get_notifier()
self.attestations = hs.get_groups_attestation_signing()
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 4ba0042768..88537a5be3 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2017 Vector Creations Ltd
-# Copyright 2018 New Vector Ltd
+# Copyright 2018, 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -25,6 +25,7 @@ from signedjson.key import decode_verify_key_bytes
from signedjson.sign import verify_signed_json
from unpaddedbase64 import decode_base64
+from twisted.internet import defer
from twisted.internet.error import TimeoutError
from synapse.api.errors import (
@@ -32,6 +33,7 @@ from synapse.api.errors import (
CodeMessageException,
Codes,
HttpResponseException,
+ ProxiedRequestError,
SynapseError,
)
from synapse.config.emailconfig import ThreepidBehaviour
@@ -43,29 +45,34 @@ from ._base import BaseHandler
logger = logging.getLogger(__name__)
-id_server_scheme = "https://"
-
class IdentityHandler(BaseHandler):
def __init__(self, hs):
super(IdentityHandler, self).__init__(hs)
- self.http_client = SimpleHttpClient(hs)
+ self.hs = hs
+ self.http_client = hs.get_simple_http_client()
# We create a blacklisting instance of SimpleHttpClient for contacting identity
# servers specified by clients
self.blacklisting_http_client = SimpleHttpClient(
hs, ip_blacklist=hs.config.federation_ip_range_blacklist
)
self.federation_http_client = hs.get_http_client()
- self.hs = hs
- async def threepid_from_creds(self, id_server, creds):
+ self.trusted_id_servers = set(hs.config.trusted_third_party_id_servers)
+ self.trust_any_id_server_just_for_testing_do_not_use = (
+ hs.config.use_insecure_ssl_client_just_for_testing_do_not_use
+ )
+ self.rewrite_identity_server_urls = hs.config.rewrite_identity_server_urls
+ self._enable_lookup = hs.config.enable_3pid_lookup
+
+ async def threepid_from_creds(self, id_server_url, creds):
"""
Retrieve and validate a threepid identifier from a "credentials" dictionary against a
given identity server
Args:
- id_server (str): The identity server to validate 3PIDs against. Must be a
+ id_server_url (str): The identity server to validate 3PIDs against. Must be a
complete URL including the protocol (http(s)://)
creds (dict[str, str]): Dictionary containing the following keys:
@@ -92,7 +99,14 @@ class IdentityHandler(BaseHandler):
query_params = {"sid": session_id, "client_secret": client_secret}
- url = id_server + "/_matrix/identity/api/v1/3pid/getValidated3pid"
+ # if we have a rewrite rule set for the identity server,
+ # apply it now.
+ id_server_url = self.rewrite_id_server_url(id_server_url)
+
+ url = "%s%s" % (
+ id_server_url,
+ "/_matrix/identity/api/v1/3pid/getValidated3pid",
+ )
try:
data = await self.http_client.get_json(url, query_params)
@@ -101,7 +115,7 @@ class IdentityHandler(BaseHandler):
except HttpResponseException as e:
logger.info(
"%s returned %i for threepid validation for: %s",
- id_server,
+ id_server_url,
e.code,
creds,
)
@@ -115,7 +129,7 @@ class IdentityHandler(BaseHandler):
if "medium" in data:
return data
- logger.info("%s reported non-validated threepid: %s", id_server, creds)
+ logger.info("%s reported non-validated threepid: %s", id_server_url, creds)
return None
async def bind_threepid(
@@ -146,14 +160,19 @@ class IdentityHandler(BaseHandler):
if id_access_token is None:
use_v2 = False
+ # if we have a rewrite rule set for the identity server,
+ # apply it now, but only for sending the request (not
+ # storing in the database).
+ id_server_url = self.rewrite_id_server_url(id_server, add_https=True)
+
# Decide which API endpoint URLs to use
headers = {}
bind_data = {"sid": sid, "client_secret": client_secret, "mxid": mxid}
if use_v2:
- bind_url = "https://%s/_matrix/identity/v2/3pid/bind" % (id_server,)
+ bind_url = "%s/_matrix/identity/v2/3pid/bind" % (id_server_url,)
headers["Authorization"] = create_id_access_token_header(id_access_token)
else:
- bind_url = "https://%s/_matrix/identity/api/v1/3pid/bind" % (id_server,)
+ bind_url = "%s/_matrix/identity/api/v1/3pid/bind" % (id_server_url,)
try:
# Use the blacklisting http client as this call is only to identity servers
@@ -238,9 +257,6 @@ class IdentityHandler(BaseHandler):
Deferred[bool]: True on success, otherwise False if the identity
server doesn't support unbinding
"""
- url = "https://%s/_matrix/identity/api/v1/3pid/unbind" % (id_server,)
- url_bytes = "/_matrix/identity/api/v1/3pid/unbind".encode("ascii")
-
content = {
"mxid": mxid,
"threepid": {"medium": threepid["medium"], "address": threepid["address"]},
@@ -249,15 +265,25 @@ class IdentityHandler(BaseHandler):
# we abuse the federation http client to sign the request, but we have to send it
# using the normal http client since we don't want the SRV lookup and want normal
# 'browser-like' HTTPS.
+ url_bytes = "/_matrix/identity/api/v1/3pid/unbind".encode("ascii")
auth_headers = self.federation_http_client.build_auth_headers(
destination=None,
- method="POST",
+ method=b"POST",
url_bytes=url_bytes,
content=content,
- destination_is=id_server,
+ destination_is=id_server.encode("ascii"),
)
headers = {b"Authorization": auth_headers}
+ # if we have a rewrite rule set for the identity server,
+ # apply it now.
+ #
+ # Note that destination_is has to be the real id_server, not
+ # the server we connect to.
+ id_server_url = self.rewrite_id_server_url(id_server, add_https=True)
+
+ url = "%s/_matrix/identity/api/v1/3pid/unbind" % (id_server_url,)
+
try:
# Use the blacklisting http client as this call is only to identity servers
# provided by a client
@@ -371,15 +397,34 @@ class IdentityHandler(BaseHandler):
return session_id
+ def rewrite_id_server_url(self, url: str, add_https=False) -> str:
+ """Given an identity server URL, optionally add a protocol scheme
+ before rewriting it according to the rewrite_identity_server_urls
+ config option
+
+ Adds https:// to the URL if specified, then tries to rewrite the
+ url. Returns either the rewritten URL or the URL with optional
+ protocol scheme additions.
+ """
+ rewritten_url = url
+ if add_https:
+ rewritten_url = "https://" + rewritten_url
+
+ rewritten_url = self.rewrite_identity_server_urls.get(
+ rewritten_url, rewritten_url
+ )
+ logger.debug("Rewriting identity server rule from %s to %s", url, rewritten_url)
+ return rewritten_url
+
async def requestEmailToken(
- self, id_server, email, client_secret, send_attempt, next_link=None
+ self, id_server_url, email, client_secret, send_attempt, next_link=None
):
"""
Request an external server send an email on our behalf for the purposes of threepid
validation.
Args:
- id_server (str): The identity server to proxy to
+ id_server_url (str): The identity server to proxy to
email (str): The email to send the message to
client_secret (str): The unique client_secret sends by the user
send_attempt (int): Which attempt this is
@@ -393,6 +438,11 @@ class IdentityHandler(BaseHandler):
"client_secret": client_secret,
"send_attempt": send_attempt,
}
+
+ # if we have a rewrite rule set for the identity server,
+ # apply it now.
+ id_server_url = self.rewrite_id_server_url(id_server_url)
+
if next_link:
params["next_link"] = next_link
@@ -407,7 +457,8 @@ class IdentityHandler(BaseHandler):
try:
data = await self.http_client.post_json_get_json(
- id_server + "/_matrix/identity/api/v1/validate/email/requestToken",
+ "%s/_matrix/identity/api/v1/validate/email/requestToken"
+ % (id_server_url,),
params,
)
return data
@@ -419,7 +470,7 @@ class IdentityHandler(BaseHandler):
async def requestMsisdnToken(
self,
- id_server,
+ id_server_url,
country,
phone_number,
client_secret,
@@ -430,7 +481,7 @@ class IdentityHandler(BaseHandler):
Request an external server send an SMS message on our behalf for the purposes of
threepid validation.
Args:
- id_server (str): The identity server to proxy to
+ id_server_url (str): The identity server to proxy to
country (str): The country code of the phone number
phone_number (str): The number to send the message to
client_secret (str): The unique client_secret sends by the user
@@ -458,9 +509,13 @@ class IdentityHandler(BaseHandler):
"details and update your config file."
)
+ # if we have a rewrite rule set for the identity server,
+ # apply it now.
+ id_server_url = self.rewrite_id_server_url(id_server_url)
try:
data = await self.http_client.post_json_get_json(
- id_server + "/_matrix/identity/api/v1/validate/msisdn/requestToken",
+ "%s/_matrix/identity/api/v1/validate/msisdn/requestToken"
+ % (id_server_url,),
params,
)
except HttpResponseException as e:
@@ -554,6 +609,89 @@ class IdentityHandler(BaseHandler):
logger.warning("Error contacting msisdn account_threepid_delegate: %s", e)
raise SynapseError(400, "Error contacting the identity server")
+ # TODO: The following two methods are used for proxying IS requests using
+ # the CS API. They should be consolidated with those in RoomMemberHandler
+ # https://github.com/matrix-org/synapse-dinsic/issues/25
+
+ @defer.inlineCallbacks
+ def proxy_lookup_3pid(self, id_server, medium, address):
+ """Looks up a 3pid in the passed identity server.
+
+ Args:
+ id_server (str): The server name (including port, if required)
+ of the identity server to use.
+ medium (str): The type of the third party identifier (e.g. "email").
+ address (str): The third party identifier (e.g. "foo@example.com").
+
+ Returns:
+ Deferred[dict]: The result of the lookup. See
+ https://matrix.org/docs/spec/identity_service/r0.1.0.html#association-lookup
+ for details
+ """
+ if not self._enable_lookup:
+ raise AuthError(
+ 403, "Looking up third-party identifiers is denied from this server"
+ )
+
+ id_server_url = self.rewrite_id_server_url(id_server, add_https=True)
+
+ try:
+ data = yield self.http_client.get_json(
+ "%s/_matrix/identity/api/v1/lookup" % (id_server_url,),
+ {"medium": medium, "address": address},
+ )
+
+ if "mxid" in data:
+ if "signatures" not in data:
+ raise AuthError(401, "No signatures on 3pid binding")
+ yield self._verify_any_signature(data, id_server)
+
+ except HttpResponseException as e:
+ logger.info("Proxied lookup failed: %r", e)
+ raise e.to_synapse_error()
+ except IOError as e:
+ logger.info("Failed to contact %s: %s", id_server, e)
+ raise ProxiedRequestError(503, "Failed to contact identity server")
+
+ return data
+
+ @defer.inlineCallbacks
+ def proxy_bulk_lookup_3pid(self, id_server, threepids):
+ """Looks up given 3pids in the passed identity server.
+
+ Args:
+ id_server (str): The server name (including port, if required)
+ of the identity server to use.
+ threepids ([[str, str]]): The third party identifiers to lookup, as
+ a list of 2-string sized lists ([medium, address]).
+
+ Returns:
+ Deferred[dict]: The result of the lookup. See
+ https://matrix.org/docs/spec/identity_service/r0.1.0.html#association-lookup
+ for details
+ """
+ if not self._enable_lookup:
+ raise AuthError(
+ 403, "Looking up third-party identifiers is denied from this server"
+ )
+
+ id_server_url = self.rewrite_id_server_url(id_server, add_https=True)
+
+ try:
+ data = yield self.http_client.post_json_get_json(
+ "%s/_matrix/identity/api/v1/bulk_lookup" % (id_server_url,),
+ {"threepids": threepids},
+ )
+
+ except HttpResponseException as e:
+ logger.info("Proxied lookup failed: %r", e)
+ raise e.to_synapse_error()
+ except IOError as e:
+ logger.info("Failed to contact %s: %s", id_server, e)
+ raise ProxiedRequestError(503, "Failed to contact identity server")
+
+ defer.returnValue(data)
+
async def lookup_3pid(self, id_server, medium, address, id_access_token=None):
"""Looks up a 3pid in the passed identity server.
@@ -568,10 +706,13 @@ class IdentityHandler(BaseHandler):
Returns:
str|None: the matrix ID of the 3pid, or None if it is not recognized.
"""
+ # Rewrite id_server URL if necessary
+ id_server_url = self.rewrite_id_server_url(id_server, add_https=True)
+
if id_access_token is not None:
try:
results = await self._lookup_3pid_v2(
- id_server, id_access_token, medium, address
+ id_server_url, id_access_token, medium, address
)
return results
@@ -589,14 +730,15 @@ class IdentityHandler(BaseHandler):
logger.warning("Error when looking up hashing details: %s", e)
return None
- return await self._lookup_3pid_v1(id_server, medium, address)
+ return await self._lookup_3pid_v1(id_server, id_server_url, medium, address)
- async def _lookup_3pid_v1(self, id_server, medium, address):
+ async def _lookup_3pid_v1(self, id_server, id_server_url, medium, address):
"""Looks up a 3pid in the passed identity server using v1 lookup.
Args:
id_server (str): The server name (including port, if required)
of the identity server to use.
+ id_server_url (str): The actual, reachable domain of the id server
medium (str): The type of the third party identifier (e.g. "email").
address (str): The third party identifier (e.g. "foo@example.com").
@@ -604,8 +746,8 @@ class IdentityHandler(BaseHandler):
str: the matrix ID of the 3pid, or None if it is not recognized.
"""
try:
- data = await self.blacklisting_http_client.get_json(
- "%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server),
+ data = await self.http_client.get_json(
+ "%s/_matrix/identity/api/v1/lookup" % (id_server_url,),
{"medium": medium, "address": address},
)
@@ -621,12 +763,11 @@ class IdentityHandler(BaseHandler):
return None
- async def _lookup_3pid_v2(self, id_server, id_access_token, medium, address):
+ async def _lookup_3pid_v2(self, id_server_url, id_access_token, medium, address):
"""Looks up a 3pid in the passed identity server using v2 lookup.
Args:
- id_server (str): The server name (including port, if required)
- of the identity server to use.
+ id_server_url (str): The protocol scheme and domain of the id server
id_access_token (str): The access token to authenticate to the identity server with
medium (str): The type of the third party identifier (e.g. "email").
address (str): The third party identifier (e.g. "foo@example.com").
@@ -636,8 +777,8 @@ class IdentityHandler(BaseHandler):
"""
# Check what hashing details are supported by this identity server
try:
- hash_details = await self.blacklisting_http_client.get_json(
- "%s%s/_matrix/identity/v2/hash_details" % (id_server_scheme, id_server),
+ hash_details = await self.http_client.get_json(
+ "%s/_matrix/identity/v2/hash_details" % (id_server_url,),
{"access_token": id_access_token},
)
except TimeoutError:
@@ -645,15 +786,14 @@ class IdentityHandler(BaseHandler):
if not isinstance(hash_details, dict):
logger.warning(
- "Got non-dict object when checking hash details of %s%s: %s",
- id_server_scheme,
- id_server,
+ "Got non-dict object when checking hash details of %s: %s",
+ id_server_url,
hash_details,
)
raise SynapseError(
400,
- "Non-dict object from %s%s during v2 hash_details request: %s"
- % (id_server_scheme, id_server, hash_details),
+ "Non-dict object from %s during v2 hash_details request: %s"
+ % (id_server_url, hash_details),
)
# Extract information from hash_details
@@ -667,8 +807,8 @@ class IdentityHandler(BaseHandler):
):
raise SynapseError(
400,
- "Invalid hash details received from identity server %s%s: %s"
- % (id_server_scheme, id_server, hash_details),
+ "Invalid hash details received from identity server %s: %s"
+ % (id_server_url, hash_details),
)
# Check if any of the supported lookup algorithms are present
@@ -690,7 +830,7 @@ class IdentityHandler(BaseHandler):
else:
logger.warning(
"None of the provided lookup algorithms of %s are supported: %s",
- id_server,
+ id_server_url,
supported_lookup_algorithms,
)
raise SynapseError(
@@ -703,8 +843,8 @@ class IdentityHandler(BaseHandler):
headers = {"Authorization": create_id_access_token_header(id_access_token)}
try:
- lookup_results = await self.blacklisting_http_client.post_json_get_json(
- "%s%s/_matrix/identity/v2/lookup" % (id_server_scheme, id_server),
+ lookup_results = await self.http_client.post_json_get_json(
+ "%s/_matrix/identity/v2/lookup" % (id_server_url,),
{
"addresses": [lookup_value],
"algorithm": lookup_algorithm,
@@ -731,30 +871,31 @@ class IdentityHandler(BaseHandler):
mxid = lookup_results["mappings"].get(lookup_value)
return mxid
- async def _verify_any_signature(self, data, server_hostname):
- if server_hostname not in data["signatures"]:
- raise AuthError(401, "No signature from server %s" % (server_hostname,))
- for key_name, signature in data["signatures"][server_hostname].items():
- try:
- key_data = await self.blacklisting_http_client.get_json(
- "%s%s/_matrix/identity/api/v1/pubkey/%s"
- % (id_server_scheme, server_hostname, key_name)
- )
- except TimeoutError:
- raise SynapseError(500, "Timed out contacting identity server")
+ async def _verify_any_signature(self, data, id_server):
+ if id_server not in data["signatures"]:
+ raise AuthError(401, "No signature from server %s" % (id_server,))
+
+ for key_name, signature in data["signatures"][id_server].items():
+ id_server_url = self.rewrite_id_server_url(id_server, add_https=True)
+
+ key_data = await self.http_client.get_json(
+ "%s/_matrix/identity/api/v1/pubkey/%s" % (id_server_url, key_name)
+ )
if "public_key" not in key_data:
raise AuthError(
- 401, "No public key named %s from %s" % (key_name, server_hostname)
+ 401, "No public key named %s from %s" % (key_name, id_server)
)
verify_signed_json(
data,
- server_hostname,
+ id_server,
decode_verify_key_bytes(
key_name, decode_base64(key_data["public_key"])
),
)
return
+ raise AuthError(401, "No signature from server %s" % (id_server,))
+
async def ask_id_server_for_third_party_invite(
self,
requester,
@@ -814,15 +955,17 @@ class IdentityHandler(BaseHandler):
"sender_avatar_url": inviter_avatar_url,
}
+ # Rewrite the identity server URL if necessary
+ id_server_url = self.rewrite_id_server_url(id_server, add_https=True)
+
# Add the identity service access token to the JSON body and use the v2
# Identity Service endpoints if id_access_token is present
data = None
- base_url = "%s%s/_matrix/identity" % (id_server_scheme, id_server)
+ base_url = "%s/_matrix/identity" % (id_server_url,)
if id_access_token:
- key_validity_url = "%s%s/_matrix/identity/v2/pubkey/isvalid" % (
- id_server_scheme,
- id_server,
+ key_validity_url = "%s/_matrix/identity/v2/pubkey/isvalid" % (
+ id_server_url,
)
# Attempt a v2 lookup
@@ -841,9 +984,8 @@ class IdentityHandler(BaseHandler):
raise e
if data is None:
- key_validity_url = "%s%s/_matrix/identity/api/v1/pubkey/isvalid" % (
- id_server_scheme,
- id_server,
+ key_validity_url = "%s/_matrix/identity/api/v1/pubkey/isvalid" % (
+ id_server_url,
)
url = base_url + "/api/v1/store-invite"
@@ -855,10 +997,7 @@ class IdentityHandler(BaseHandler):
raise SynapseError(500, "Timed out contacting identity server")
except HttpResponseException as e:
logger.warning(
- "Error trying to call /store-invite on %s%s: %s",
- id_server_scheme,
- id_server,
- e,
+ "Error trying to call /store-invite on %s: %s", id_server_url, e,
)
if data is None:
@@ -871,10 +1010,9 @@ class IdentityHandler(BaseHandler):
)
except HttpResponseException as e:
logger.warning(
- "Error calling /store-invite on %s%s with fallback "
+ "Error calling /store-invite on %s with fallback "
"encoding: %s",
- id_server_scheme,
- id_server,
+ id_server_url,
e,
)
raise e
@@ -895,6 +1033,30 @@ class IdentityHandler(BaseHandler):
display_name = data["display_name"]
return token, public_keys, fallback_public_key, display_name
+ async def bind_email_using_internal_sydent_api(
+ self, id_server_url: str, email: str, user_id: str,
+ ):
+ """Bind an email to a fully qualified user ID using the internal API of an
+ instance of Sydent.
+
+ Args:
+ id_server_url: The URL of the Sydent instance
+ email: The email address to bind
+ user_id: The user ID to bind the email to
+
+ Raises:
+ HTTPResponseException: On a non-2xx HTTP response.
+ """
+ # id_server_url is assumed to have no trailing slashes
+ url = id_server_url + "/_matrix/identity/internal/bind"
+ body = {
+ "address": email,
+ "medium": "email",
+ "mxid": user_id,
+ }
+
+ await self.http_client.post_json_get_json(url, body)
+
def create_id_access_token_header(id_access_token):
"""Create an Authorization header for passing to SimpleHttpClient as the header value
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 665ad19b5d..da206e1ec1 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import Optional, Tuple
+from typing import TYPE_CHECKING, Optional, Tuple
from canonicaljson import encode_canonical_json, json
@@ -55,6 +55,9 @@ from synapse.visibility import filter_events_for_client
from ._base import BaseHandler
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -349,7 +352,7 @@ _DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY = 7 * 24 * 60 * 60 * 1000
class EventCreationHandler(object):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
@@ -814,11 +817,17 @@ class EventCreationHandler(object):
403, "This event is not allowed in this context", Codes.FORBIDDEN
)
- try:
- await self.auth.check_from_context(room_version, event, context)
- except AuthError as err:
- logger.warning("Denying new event %r because %s", event, err)
- raise err
+ if event.internal_metadata.is_out_of_band_membership():
+ # the only sort of out-of-band-membership events we expect to see here
+ # are invite rejections we have generated ourselves.
+ assert event.type == EventTypes.Member
+ assert event.content["membership"] == Membership.LEAVE
+ else:
+ try:
+ await self.auth.check_from_context(room_version, event, context)
+ except AuthError as err:
+ logger.warning("Denying new event %r because %s", event, err)
+ raise err
# Ensure that we can round trip before trying to persist in db
try:
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 4b1e3073a8..dd8979e750 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,8 +15,13 @@
# limitations under the License.
import logging
+from typing import List
-from twisted.internet import defer
+from six.moves import range
+
+from signedjson.sign import sign_json
+
+from twisted.internet import defer, reactor
from synapse.api.errors import (
AuthError,
@@ -25,6 +31,7 @@ from synapse.api.errors import (
StoreError,
SynapseError,
)
+from synapse.logging.context import run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import UserID, create_requester, get_domain_from_id
@@ -44,6 +51,8 @@ class BaseProfileHandler(BaseHandler):
subclass MasterProfileHandler
"""
+ PROFILE_REPLICATE_INTERVAL = 2 * 60 * 1000
+
def __init__(self, hs):
super(BaseProfileHandler, self).__init__(hs)
@@ -54,6 +63,88 @@ class BaseProfileHandler(BaseHandler):
self.user_directory_handler = hs.get_user_directory_handler()
+ self.http_client = hs.get_simple_http_client()
+
+ self.max_avatar_size = hs.config.max_avatar_size
+ self.allowed_avatar_mimetypes = hs.config.allowed_avatar_mimetypes
+ self.replicate_user_profiles_to = hs.config.replicate_user_profiles_to
+
+ if hs.config.worker_app is None:
+ self.clock.looping_call(
+ self._start_update_remote_profile_cache, self.PROFILE_UPDATE_MS
+ )
+
+ if len(self.hs.config.replicate_user_profiles_to) > 0:
+ reactor.callWhenRunning(self._assign_profile_replication_batches)
+ reactor.callWhenRunning(self._replicate_profiles)
+ # Add a looping call to replicate_profiles: this handles retries
+ # if the replication is unsuccessful when the user updated their
+ # profile.
+ self.clock.looping_call(
+ self._replicate_profiles, self.PROFILE_REPLICATE_INTERVAL
+ )
+
+ @defer.inlineCallbacks
+ def _assign_profile_replication_batches(self):
+ """If no profile replication has been done yet, allocate replication batch
+ numbers to each profile to start the replication process.
+ """
+ logger.info("Assigning profile batch numbers...")
+ total = 0
+ while True:
+ assigned = yield self.store.assign_profile_batch()
+ total += assigned
+ if assigned == 0:
+ break
+ logger.info("Assigned %d profile batch numbers", total)
+
+ @defer.inlineCallbacks
+ def _replicate_profiles(self):
+ """If any profile data has been updated and not pushed to the replication targets,
+ replicate it.
+ """
+ host_batches = yield self.store.get_replication_hosts()
+ latest_batch = yield self.store.get_latest_profile_replication_batch_number()
+ if latest_batch is None:
+ latest_batch = -1
+ for repl_host in self.hs.config.replicate_user_profiles_to:
+ if repl_host not in host_batches:
+ host_batches[repl_host] = -1
+ try:
+ for i in range(host_batches[repl_host] + 1, latest_batch + 1):
+ yield self._replicate_host_profile_batch(repl_host, i)
+ except Exception:
+ logger.exception(
+ "Exception while replicating to %s: aborting for now", repl_host
+ )
+
+ @defer.inlineCallbacks
+ def _replicate_host_profile_batch(self, host, batchnum):
+ logger.info("Replicating profile batch %d to %s", batchnum, host)
+ batch_rows = yield self.store.get_profile_batch(batchnum)
+ batch = {
+ UserID(r["user_id"], self.hs.hostname).to_string(): (
+ {"display_name": r["displayname"], "avatar_url": r["avatar_url"]}
+ if r["active"]
+ else None
+ )
+ for r in batch_rows
+ }
+
+ url = "https://%s/_matrix/identity/api/v1/replicate_profiles" % (host,)
+ body = {"batchnum": batchnum, "batch": batch, "origin_server": self.hs.hostname}
+ signed_body = sign_json(body, self.hs.hostname, self.hs.config.signing_key[0])
+ try:
+ yield self.http_client.post_json_get_json(url, signed_body)
+ yield self.store.update_replication_batch_for_host(host, batchnum)
+ logger.info("Sucessfully replicated profile batch %d to %s", batchnum, host)
+ except Exception:
+ # This will get retried when the looping call next comes around
+ logger.exception(
+ "Failed to replicate profile batch %d to %s", batchnum, host
+ )
+ raise
+
@defer.inlineCallbacks
def get_profile(self, user_id):
target_user = UserID.from_string(user_id)
@@ -153,7 +244,7 @@ class BaseProfileHandler(BaseHandler):
if not self.hs.is_mine(target_user):
raise SynapseError(400, "User is not hosted on this homeserver")
- if not by_admin and target_user != requester.user:
+ if not by_admin and requester and target_user != requester.user:
raise AuthError(400, "Cannot set another user's displayname")
if not by_admin and not self.hs.config.enable_set_displayname:
@@ -173,13 +264,23 @@ class BaseProfileHandler(BaseHandler):
if new_displayname == "":
new_displayname = None
+ if len(self.hs.config.replicate_user_profiles_to) > 0:
+ cur_batchnum = (
+ await self.store.get_latest_profile_replication_batch_number()
+ )
+ new_batchnum = 0 if cur_batchnum is None else cur_batchnum + 1
+ else:
+ new_batchnum = None
+
# If the admin changes the display name of a user, the requesting user cannot send
# the join event to update the displayname in the rooms.
# This must be done by the target user himself.
if by_admin:
requester = create_requester(target_user)
- await self.store.set_profile_displayname(target_user.localpart, new_displayname)
+ await self.store.set_profile_displayname(
+ target_user.localpart, new_displayname, new_batchnum
+ )
if self.hs.config.user_directory_search_all_users:
profile = await self.store.get_profileinfo(target_user.localpart)
@@ -189,6 +290,50 @@ class BaseProfileHandler(BaseHandler):
await self._update_join_states(requester, target_user)
+ # start a profile replication push
+ run_in_background(self._replicate_profiles)
+
+ @defer.inlineCallbacks
+ def set_active(
+ self, users: List[UserID], active: bool, hide: bool,
+ ):
+ """
+ Sets the 'active' flag on a set of user profiles. If set to false, the
+ accounts are considered deactivated or hidden.
+
+ If 'hide' is true, then we interpret active=False as a request to try to
+ hide the users rather than deactivating them. This means withholding the
+ profiles from replication (and mark it as inactive) rather than clearing
+ the profile from the HS DB.
+
+ Note that unlike set_displayname and set_avatar_url, this does *not*
+ perform authorization checks! This is because the only place it's used
+ currently is in account deactivation where we've already done these
+ checks anyway.
+
+ Args:
+ users: The users to modify
+ active: Whether to set the user to active or inactive
+ hide: Whether to hide the user (withold from replication). If
+ False and active is False, user will have their profile
+ erased
+
+ Returns:
+ Deferred
+ """
+ if len(self.replicate_user_profiles_to) > 0:
+ cur_batchnum = (
+ yield self.store.get_latest_profile_replication_batch_number()
+ )
+ new_batchnum = 0 if cur_batchnum is None else cur_batchnum + 1
+ else:
+ new_batchnum = None
+
+ yield self.store.set_profiles_active(users, active, hide, new_batchnum)
+
+ # start a profile replication push
+ run_in_background(self._replicate_profiles)
+
@defer.inlineCallbacks
def get_avatar_url(self, target_user):
if self.hs.is_mine(target_user):
@@ -239,11 +384,51 @@ class BaseProfileHandler(BaseHandler):
400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN,)
)
+ # Enforce a max avatar size if one is defined
+ if self.max_avatar_size or self.allowed_avatar_mimetypes:
+ media_id = self._validate_and_parse_media_id_from_avatar_url(new_avatar_url)
+
+ # Check that this media exists locally
+ media_info = await self.store.get_local_media(media_id)
+ if not media_info:
+ raise SynapseError(
+ 400, "Unknown media id supplied", errcode=Codes.NOT_FOUND
+ )
+
+ # Ensure avatar does not exceed max allowed avatar size
+ media_size = media_info["media_length"]
+ if self.max_avatar_size and media_size > self.max_avatar_size:
+ raise SynapseError(
+ 400,
+ "Avatars must be less than %s bytes in size"
+ % (self.max_avatar_size,),
+ errcode=Codes.TOO_LARGE,
+ )
+
+ # Ensure the avatar's file type is allowed
+ if (
+ self.allowed_avatar_mimetypes
+ and media_info["media_type"] not in self.allowed_avatar_mimetypes
+ ):
+ raise SynapseError(
+ 400, "Avatar file type '%s' not allowed" % media_info["media_type"]
+ )
+
# Same like set_displayname
if by_admin:
requester = create_requester(target_user)
- await self.store.set_profile_avatar_url(target_user.localpart, new_avatar_url)
+ if len(self.hs.config.replicate_user_profiles_to) > 0:
+ cur_batchnum = (
+ await self.store.get_latest_profile_replication_batch_number()
+ )
+ new_batchnum = 0 if cur_batchnum is None else cur_batchnum + 1
+ else:
+ new_batchnum = None
+
+ await self.store.set_profile_avatar_url(
+ target_user.localpart, new_avatar_url, new_batchnum
+ )
if self.hs.config.user_directory_search_all_users:
profile = await self.store.get_profileinfo(target_user.localpart)
@@ -253,6 +438,23 @@ class BaseProfileHandler(BaseHandler):
await self._update_join_states(requester, target_user)
+ # start a profile replication push
+ run_in_background(self._replicate_profiles)
+
+ def _validate_and_parse_media_id_from_avatar_url(self, mxc):
+ """Validate and parse a provided avatar url and return the local media id
+
+ Args:
+ mxc (str): A mxc URL
+
+ Returns:
+ str: The ID of the media
+ """
+ avatar_pieces = mxc.split("/")
+ if len(avatar_pieces) != 4 or avatar_pieces[0] != "mxc:":
+ raise SynapseError(400, "Invalid avatar URL '%s' supplied" % mxc)
+ return avatar_pieces[-1]
+
@defer.inlineCallbacks
def on_profile_query(self, args):
user = UserID.from_string(args["user_id"])
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 78c3772ac1..f223630d43 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -48,6 +48,7 @@ class RegistrationHandler(BaseHandler):
self._auth_handler = hs.get_auth_handler()
self.profile_handler = hs.get_profile_handler()
self.user_directory_handler = hs.get_user_directory_handler()
+ self.http_client = hs.get_simple_http_client()
self.identity_handler = self.hs.get_handlers().identity_handler
self.ratelimiter = hs.get_registration_ratelimiter()
@@ -60,6 +61,8 @@ class RegistrationHandler(BaseHandler):
)
self._server_notices_mxid = hs.config.server_notices_mxid
+ self._show_in_user_directory = self.hs.config.show_users_in_user_directory
+
if hs.config.worker_app:
self._register_client = ReplicationRegisterServlet.make_client(hs)
self._register_device_client = RegisterDeviceReplicationServlet.make_client(
@@ -75,8 +78,18 @@ class RegistrationHandler(BaseHandler):
self.session_lifetime = hs.config.session_lifetime
async def check_username(
- self, localpart, guest_access_token=None, assigned_user_id=None
+ self, localpart, guest_access_token=None, assigned_user_id=None,
):
+ """
+
+ Args:
+ localpart (str|None): The user's localpart
+ guest_access_token (str|None): A guest's access token
+ assigned_user_id (str|None): An existing User ID for this user if pre-calculated
+
+ Returns:
+ Deferred
+ """
if types.contains_invalid_mxid_characters(localpart):
raise SynapseError(
400,
@@ -119,6 +132,8 @@ class RegistrationHandler(BaseHandler):
raise SynapseError(
400, "User ID already taken.", errcode=Codes.USER_IN_USE
)
+
+ # Retrieve guest user information from provided access token
user_data = await self.auth.get_user_by_access_token(guest_access_token)
if not user_data["is_guest"] or user_data["user"].localpart != localpart:
raise AuthError(
@@ -204,6 +219,11 @@ class RegistrationHandler(BaseHandler):
address=address,
)
+ if default_display_name:
+ await self.profile_handler.set_displayname(
+ user, None, default_display_name, by_admin=True
+ )
+
if self.hs.config.user_directory_search_all_users:
profile = await self.store.get_profileinfo(localpart)
await self.user_directory_handler.handle_local_profile_change(
@@ -234,6 +254,10 @@ class RegistrationHandler(BaseHandler):
address=address,
)
+ await self.profile_handler.set_displayname(
+ user, None, default_display_name, by_admin=True
+ )
+
# Successfully registered
break
except SynapseError:
@@ -267,7 +291,15 @@ class RegistrationHandler(BaseHandler):
}
# Bind email to new account
- await self._register_email_threepid(user_id, threepid_dict, None)
+ await self.register_email_threepid(user_id, threepid_dict, None)
+
+ # Prevent the new user from showing up in the user directory if the server
+ # mandates it.
+ if not self._show_in_user_directory:
+ await self.store.add_account_data_for_user(
+ user_id, "im.vector.hide_profile", {"hide_profile": True}
+ )
+ await self.profile_handler.set_active([user], False, True)
return user_id
@@ -461,7 +493,10 @@ class RegistrationHandler(BaseHandler):
"""
await self._auto_join_rooms(user_id)
- async def appservice_register(self, user_localpart, as_token):
+ async def appservice_register(
+ self, user_localpart, as_token, password_hash, display_name
+ ):
+ # FIXME: this should be factored out and merged with normal register()
user = UserID(user_localpart, self.hs.hostname)
user_id = user.to_string()
service = self.store.get_app_service_by_token(as_token)
@@ -478,12 +513,25 @@ class RegistrationHandler(BaseHandler):
self.check_user_id_not_appservice_exclusive(user_id, allowed_appservice=service)
+ display_name = display_name or user.localpart
+
await self.register_with_store(
user_id=user_id,
- password_hash="",
+ password_hash=password_hash,
appservice_id=service_id,
- create_profile_with_displayname=user.localpart,
+ create_profile_with_displayname=display_name,
+ )
+
+ await self.profile_handler.set_displayname(
+ user, None, display_name, by_admin=True
)
+
+ if self.hs.config.user_directory_search_all_users:
+ profile = await self.store.get_profileinfo(user_localpart)
+ await self.user_directory_handler.handle_local_profile_change(
+ user_id, profile
+ )
+
return user_id
def check_user_id_not_appservice_exclusive(self, user_id, allowed_appservice=None):
@@ -510,6 +558,37 @@ class RegistrationHandler(BaseHandler):
errcode=Codes.EXCLUSIVE,
)
+ async def shadow_register(self, localpart, display_name, auth_result, params):
+ """Invokes the current registration on another server, using
+ shared secret registration, passing in any auth_results from
+ other registration UI auth flows (e.g. validated 3pids)
+ Useful for setting up shadow/backup accounts on a parallel deployment.
+ """
+
+ # TODO: retries
+ shadow_hs_url = self.hs.config.shadow_server.get("hs_url")
+ as_token = self.hs.config.shadow_server.get("as_token")
+
+ await self.http_client.post_json_get_json(
+ "%s/_matrix/client/r0/register?access_token=%s" % (shadow_hs_url, as_token),
+ {
+ # XXX: auth_result is an unspecified extension for shadow registration
+ "auth_result": auth_result,
+ # XXX: another unspecified extension for shadow registration to ensure
+ # that the displayname is correctly set by the masters erver
+ "display_name": display_name,
+ "username": localpart,
+ "password": params.get("password"),
+ "bind_msisdn": params.get("bind_msisdn"),
+ "device_id": params.get("device_id"),
+ "initial_device_display_name": params.get(
+ "initial_device_display_name"
+ ),
+ "inhibit_login": False,
+ "access_token": as_token,
+ },
+ )
+
async def _generate_user_id(self):
if self._next_generated_user_id is None:
with await self._generate_user_id_linearizer.queue(()):
@@ -663,6 +742,7 @@ class RegistrationHandler(BaseHandler):
if auth_result and LoginType.EMAIL_IDENTITY in auth_result:
threepid = auth_result[LoginType.EMAIL_IDENTITY]
+
# Necessary due to auth checks prior to the threepid being
# written to the db
if is_threepid_reserved(
@@ -670,7 +750,32 @@ class RegistrationHandler(BaseHandler):
):
await self.store.upsert_monthly_active_user(user_id)
- await self._register_email_threepid(user_id, threepid, access_token)
+ await self.register_email_threepid(user_id, threepid, access_token)
+
+ if self.hs.config.bind_new_user_emails_to_sydent:
+ # Attempt to call Sydent's internal bind API on the given identity server
+ # to bind this threepid
+ id_server_url = self.hs.config.bind_new_user_emails_to_sydent
+
+ logger.debug(
+ "Attempting the bind email of %s to identity server: %s using "
+ "internal Sydent bind API.",
+ user_id,
+ self.hs.config.bind_new_user_emails_to_sydent,
+ )
+
+ try:
+ await self.identity_handler.bind_email_using_internal_sydent_api(
+ id_server_url, threepid["address"], user_id
+ )
+ except Exception as e:
+ logger.warning(
+ "Failed to bind email of '%s' to Sydent instance '%s' ",
+ "using Sydent internal bind API: %s",
+ user_id,
+ id_server_url,
+ e,
+ )
if auth_result and LoginType.MSISDN in auth_result:
threepid = auth_result[LoginType.MSISDN]
@@ -691,7 +796,7 @@ class RegistrationHandler(BaseHandler):
await self.store.user_set_consent_version(user_id, consent_version)
await self.post_consent_actions(user_id)
- async def _register_email_threepid(self, user_id, threepid, token):
+ async def register_email_threepid(self, user_id, threepid, token):
"""Add an email address as a 3pid identifier
Also adds an email pusher for the email address, if configured in the
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 950a84acd0..ac9a58e766 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -335,7 +335,19 @@ class RoomCreationHandler(BaseHandler):
"""
user_id = requester.user.to_string()
- if not self.spam_checker.user_may_create_room(user_id):
+ if (
+ self._server_notices_mxid is not None
+ and requester.user.to_string() == self._server_notices_mxid
+ ):
+ # allow the server notices mxid to create rooms
+ is_requester_admin = True
+
+ else:
+ is_requester_admin = await self.auth.is_server_admin(requester.user)
+
+ if not is_requester_admin and not self.spam_checker.user_may_create_room(
+ user_id, invite_list=[], third_party_invite_list=[], cloning=True
+ ):
raise SynapseError(403, "You are not permitted to create rooms")
creation_content = {
@@ -581,8 +593,14 @@ class RoomCreationHandler(BaseHandler):
403, "You are not permitted to create rooms", Codes.FORBIDDEN
)
+ invite_list = config.get("invite", [])
+ invite_3pid_list = config.get("invite_3pid", [])
+
if not is_requester_admin and not self.spam_checker.user_may_create_room(
- user_id
+ user_id,
+ invite_list=invite_list,
+ third_party_invite_list=invite_3pid_list,
+ cloning=False,
):
raise SynapseError(403, "You are not permitted to create rooms")
@@ -617,7 +635,6 @@ class RoomCreationHandler(BaseHandler):
else:
room_alias = None
- invite_list = config.get("invite", [])
for i in invite_list:
try:
uid = UserID.from_string(i)
@@ -639,8 +656,6 @@ class RoomCreationHandler(BaseHandler):
% (user_id,),
)
- invite_3pid_list = config.get("invite_3pid", [])
-
visibility = config.get("visibility", None)
is_public = visibility == "public"
@@ -742,6 +757,7 @@ class RoomCreationHandler(BaseHandler):
"invite",
ratelimit=False,
content=content,
+ new_room=True,
)
for invite_3pid in invite_3pid_list:
@@ -757,6 +773,7 @@ class RoomCreationHandler(BaseHandler):
id_server,
requester,
txn_id=None,
+ new_room=True,
id_access_token=id_access_token,
)
@@ -826,6 +843,7 @@ class RoomCreationHandler(BaseHandler):
"join",
ratelimit=False,
content=creator_join_profile,
+ new_room=True,
)
# We treat the power levels override specially as this needs to be one
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 27c479da9e..c3fee116c2 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -1,7 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2016 OpenMarket Ltd
-# Copyright 2018 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2016-2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,17 +16,21 @@
import abc
import logging
from http import HTTPStatus
-from typing import Dict, Iterable, List, Optional, Tuple
+from typing import Dict, Iterable, List, Optional, Tuple, Union
+
+from unpaddedbase64 import encode_base64
from synapse import types
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import MAX_DEPTH, EventTypes, Membership
from synapse.api.errors import AuthError, Codes, SynapseError
+from synapse.api.room_versions import EventFormatVersions
+from synapse.crypto.event_signing import compute_event_reference_hash
from synapse.events import EventBase
+from synapse.events.builder import create_local_event_from_event_dict
from synapse.events.snapshot import EventContext
-from synapse.replication.http.membership import (
- ReplicationLocallyRejectInviteRestServlet,
-)
-from synapse.types import Collection, Requester, RoomAlias, RoomID, UserID
+from synapse.events.validator import EventValidator
+from synapse.storage.roommember import RoomsForUser
+from synapse.types import Collection, JsonDict, Requester, RoomAlias, RoomID, UserID
from synapse.util.async_helpers import Linearizer
from synapse.util.distributor import user_joined_room, user_left_room
@@ -58,6 +60,7 @@ class RoomMemberHandler(object):
self.registration_handler = hs.get_registration_handler()
self.profile_handler = hs.get_profile_handler()
self.event_creation_handler = hs.get_event_creation_handler()
+ self.identity_handler = hs.get_handlers().identity_handler
self.member_linearizer = Linearizer(name="member")
@@ -74,10 +77,6 @@ class RoomMemberHandler(object):
)
if self._is_on_event_persistence_instance:
self.persist_event_storage = hs.get_storage().persistence
- else:
- self._locally_reject_client = ReplicationLocallyRejectInviteRestServlet.make_client(
- hs
- )
# This is only used to get at ratelimit function, and
# maybe_kick_guest_users. It's fine there are multiple of these as
@@ -105,46 +104,28 @@ class RoomMemberHandler(object):
raise NotImplementedError()
@abc.abstractmethod
- async def _remote_reject_invite(
+ async def remote_reject_invite(
self,
+ invite_event_id: str,
+ txn_id: Optional[str],
requester: Requester,
- remote_room_hosts: List[str],
- room_id: str,
- target: UserID,
- content: dict,
- ) -> Tuple[Optional[str], int]:
- """Attempt to reject an invite for a room this server is not in. If we
- fail to do so we locally mark the invite as rejected.
+ content: JsonDict,
+ ) -> Tuple[str, int]:
+ """
+ Rejects an out-of-band invite we have received from a remote server
Args:
- requester
- remote_room_hosts: List of servers to use to try and reject invite
- room_id
- target: The user rejecting the invite
- content: The content for the rejection event
+ invite_event_id: ID of the invite to be rejected
+ txn_id: optional transaction ID supplied by the client
+ requester: user making the rejection request, according to the access token
+ content: additional content to include in the rejection event.
+ Normally an empty dict.
Returns:
- A dictionary to be returned to the client, may
- include event_id etc, or nothing if we locally rejected
+ event id, stream_id of the leave event
"""
raise NotImplementedError()
- async def locally_reject_invite(self, user_id: str, room_id: str) -> int:
- """Mark the invite has having been rejected even though we failed to
- create a leave event for it.
- """
- if self._is_on_event_persistence_instance:
- return await self.persist_event_storage.locally_reject_invite(
- user_id, room_id
- )
- else:
- result = await self._locally_reject_client(
- instance_name=self._event_stream_writer_instance,
- user_id=user_id,
- room_id=room_id,
- )
- return result["stream_id"]
-
@abc.abstractmethod
async def _user_joined_room(self, target: UserID, room_id: str) -> None:
"""Notifies distributor on master process that the user has joined the
@@ -287,8 +268,10 @@ class RoomMemberHandler(object):
third_party_signed: Optional[dict] = None,
ratelimit: bool = True,
content: Optional[dict] = None,
+ new_room: bool = False,
require_consent: bool = True,
- ) -> Tuple[Optional[str], int]:
+ ) -> Tuple[str, int]:
+ """Update a user's membership in a room"""
key = (room_id,)
with (await self.member_linearizer.queue(key)):
@@ -302,6 +285,7 @@ class RoomMemberHandler(object):
third_party_signed=third_party_signed,
ratelimit=ratelimit,
content=content,
+ new_room=new_room,
require_consent=require_consent,
)
@@ -318,8 +302,9 @@ class RoomMemberHandler(object):
third_party_signed: Optional[dict] = None,
ratelimit: bool = True,
content: Optional[dict] = None,
+ new_room: bool = False,
require_consent: bool = True,
- ) -> Tuple[Optional[str], int]:
+ ) -> Tuple[str, int]:
content_specified = bool(content)
if content is None:
content = {}
@@ -382,8 +367,15 @@ class RoomMemberHandler(object):
)
block_invite = True
+ is_published = await self.store.is_room_published(room_id)
+
if not self.spam_checker.user_may_invite(
- requester.user.to_string(), target.to_string(), room_id
+ requester.user.to_string(),
+ target.to_string(),
+ third_party_invite=None,
+ room_id=room_id,
+ new_room=new_room,
+ published_room=is_published,
):
logger.info("Blocking invite due to spam checker")
block_invite = True
@@ -461,6 +453,25 @@ class RoomMemberHandler(object):
# so don't really fit into the general auth process.
raise AuthError(403, "Guest access not allowed")
+ if (
+ self._server_notices_mxid is not None
+ and requester.user.to_string() == self._server_notices_mxid
+ ):
+ # allow the server notices mxid to join rooms
+ is_requester_admin = True
+
+ else:
+ is_requester_admin = await self.auth.is_server_admin(requester.user)
+
+ inviter = await self._get_inviter(target.to_string(), room_id)
+ if not is_requester_admin:
+ # We assume that if the spam checker allowed the user to create
+ # a room then they're allowed to join it.
+ if not new_room and not self.spam_checker.user_may_join_room(
+ target.to_string(), room_id, is_invited=inviter is not None
+ ):
+ raise SynapseError(403, "Not allowed to join this room")
+
if not is_host_in_room:
inviter = await self._get_inviter(target.to_string(), room_id)
if inviter and not self.hs.is_mine(inviter):
@@ -485,11 +496,17 @@ class RoomMemberHandler(object):
elif effective_membership_state == Membership.LEAVE:
if not is_host_in_room:
# perhaps we've been invited
- inviter = await self._get_inviter(target.to_string(), room_id)
- if not inviter:
+ invite = await self.store.get_invite_for_local_user_in_room(
+ user_id=target.to_string(), room_id=room_id
+ ) # type: Optional[RoomsForUser]
+ if not invite:
raise SynapseError(404, "Not a known room")
- if self.hs.is_mine(inviter):
+ logger.info(
+ "%s rejects invite to %s from %s", target, room_id, invite.sender
+ )
+
+ if self.hs.is_mine_id(invite.sender):
# the inviter was on our server, but has now left. Carry on
# with the normal rejection codepath.
#
@@ -497,10 +514,10 @@ class RoomMemberHandler(object):
# active on other servers.
pass
else:
- # send the rejection to the inviter's HS.
- remote_room_hosts = remote_room_hosts + [inviter.domain]
- return await self._remote_reject_invite(
- requester, remote_room_hosts, room_id, target, content,
+ # send the rejection to the inviter's HS (with fallback to
+ # local event)
+ return await self.remote_reject_invite(
+ invite.event_id, txn_id, requester, content,
)
return await self._local_membership_update(
@@ -735,6 +752,7 @@ class RoomMemberHandler(object):
id_server: str,
requester: Requester,
txn_id: Optional[str],
+ new_room: bool = False,
id_access_token: Optional[str] = None,
) -> int:
if self.config.block_non_admin_invites:
@@ -758,6 +776,16 @@ class RoomMemberHandler(object):
Codes.FORBIDDEN,
)
+ can_invite = await self.third_party_event_rules.check_threepid_can_be_invited(
+ medium, address, room_id
+ )
+ if not can_invite:
+ raise SynapseError(
+ 403,
+ "This third-party identifier can not be invited in this room",
+ Codes.FORBIDDEN,
+ )
+
if not self._enable_lookup:
raise SynapseError(
403, "Looking up third-party identifiers is denied from this server"
@@ -767,6 +795,19 @@ class RoomMemberHandler(object):
id_server, medium, address, id_access_token
)
+ is_published = await self.store.is_room_published(room_id)
+
+ if not self.spam_checker.user_may_invite(
+ requester.user.to_string(),
+ invitee,
+ third_party_invite={"medium": medium, "address": address},
+ room_id=room_id,
+ new_room=new_room,
+ published_room=is_published,
+ ):
+ logger.info("Blocking invite due to spam checker")
+ raise SynapseError(403, "Invites have been disabled on this server")
+
if invitee:
_, stream_id = await self.update_membership(
requester, UserID.from_string(invitee), room_id, "invite", txn_id=txn_id
@@ -1014,33 +1055,119 @@ class RoomMemberMasterHandler(RoomMemberHandler):
return event_id, stream_id
- async def _remote_reject_invite(
+ async def remote_reject_invite(
self,
+ invite_event_id: str,
+ txn_id: Optional[str],
requester: Requester,
- remote_room_hosts: List[str],
- room_id: str,
- target: UserID,
- content: dict,
- ) -> Tuple[Optional[str], int]:
- """Implements RoomMemberHandler._remote_reject_invite
+ content: JsonDict,
+ ) -> Tuple[str, int]:
+ """
+ Rejects an out-of-band invite received from a remote user
+
+ Implements RoomMemberHandler.remote_reject_invite
"""
+ invite_event = await self.store.get_event(invite_event_id)
+ room_id = invite_event.room_id
+ target_user = invite_event.state_key
+
+ # first of all, try doing a rejection via the inviting server
fed_handler = self.federation_handler
try:
+ inviter_id = UserID.from_string(invite_event.sender)
event, stream_id = await fed_handler.do_remotely_reject_invite(
- remote_room_hosts, room_id, target.to_string(), content=content,
+ [inviter_id.domain], room_id, target_user, content=content
)
return event.event_id, stream_id
except Exception as e:
- # if we were unable to reject the exception, just mark
- # it as rejected on our end and plough ahead.
+ # if we were unable to reject the invite, we will generate our own
+ # leave event.
#
# The 'except' clause is very broad, but we need to
# capture everything from DNS failures upwards
#
logger.warning("Failed to reject invite: %s", e)
- stream_id = await self.locally_reject_invite(target.to_string(), room_id)
- return None, stream_id
+ return await self._locally_reject_invite(
+ invite_event, txn_id, requester, content
+ )
+
+ async def _locally_reject_invite(
+ self,
+ invite_event: EventBase,
+ txn_id: Optional[str],
+ requester: Requester,
+ content: JsonDict,
+ ) -> Tuple[str, int]:
+ """Generate a local invite rejection
+
+ This is called after we fail to reject an invite via a remote server. It
+ generates an out-of-band membership event locally.
+
+ Args:
+ invite_event: the invite to be rejected
+ txn_id: optional transaction ID supplied by the client
+ requester: user making the rejection request, according to the access token
+ content: additional content to include in the rejection event.
+ Normally an empty dict.
+ """
+
+ room_id = invite_event.room_id
+ target_user = invite_event.state_key
+ room_version = await self.store.get_room_version(room_id)
+
+ content["membership"] = Membership.LEAVE
+
+ # the auth events for the new event are the same as that of the invite, plus
+ # the invite itself.
+ #
+ # the prev_events are just the invite.
+ invite_hash = invite_event.event_id # type: Union[str, Tuple]
+ if room_version.event_format == EventFormatVersions.V1:
+ alg, h = compute_event_reference_hash(invite_event)
+ invite_hash = (invite_event.event_id, {alg: encode_base64(h)})
+
+ auth_events = tuple(invite_event.auth_events) + (invite_hash,)
+ prev_events = (invite_hash,)
+
+ # we cap depth of generated events, to ensure that they are not
+ # rejected by other servers (and so that they can be persisted in
+ # the db)
+ depth = min(invite_event.depth + 1, MAX_DEPTH)
+
+ event_dict = {
+ "depth": depth,
+ "auth_events": auth_events,
+ "prev_events": prev_events,
+ "type": EventTypes.Member,
+ "room_id": room_id,
+ "sender": target_user,
+ "content": content,
+ "state_key": target_user,
+ }
+
+ event = create_local_event_from_event_dict(
+ clock=self.clock,
+ hostname=self.hs.hostname,
+ signing_key=self.hs.signing_key,
+ room_version=room_version,
+ event_dict=event_dict,
+ )
+ event.internal_metadata.outlier = True
+ event.internal_metadata.out_of_band_membership = True
+ if txn_id is not None:
+ event.internal_metadata.txn_id = txn_id
+ if requester.access_token_id is not None:
+ event.internal_metadata.token_id = requester.access_token_id
+
+ EventValidator().validate_new(event, self.config)
+
+ context = await self.state_handler.compute_event_context(event)
+ context.app_service = requester.app_service
+ stream_id = await self.event_creation_handler.handle_new_client_event(
+ requester, event, context, extra_users=[UserID.from_string(target_user)],
+ )
+ return event.event_id, stream_id
async def _user_joined_room(self, target: UserID, room_id: str) -> None:
"""Implements RoomMemberHandler._user_joined_room
diff --git a/synapse/handlers/room_member_worker.py b/synapse/handlers/room_member_worker.py
index 02e0c4103d..897338fd54 100644
--- a/synapse/handlers/room_member_worker.py
+++ b/synapse/handlers/room_member_worker.py
@@ -61,21 +61,22 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
return ret["event_id"], ret["stream_id"]
- async def _remote_reject_invite(
+ async def remote_reject_invite(
self,
+ invite_event_id: str,
+ txn_id: Optional[str],
requester: Requester,
- remote_room_hosts: List[str],
- room_id: str,
- target: UserID,
content: dict,
- ) -> Tuple[Optional[str], int]:
- """Implements RoomMemberHandler._remote_reject_invite
+ ) -> Tuple[str, int]:
+ """
+ Rejects an out-of-band invite received from a remote user
+
+ Implements RoomMemberHandler.remote_reject_invite
"""
ret = await self._remote_reject_client(
+ invite_event_id=invite_event_id,
+ txn_id=txn_id,
requester=requester,
- remote_room_hosts=remote_room_hosts,
- room_id=room_id,
- user_id=target.to_string(),
content=content,
)
return ret["event_id"], ret["stream_id"]
diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py
index 4d245b618b..5d34989f21 100644
--- a/synapse/handlers/set_password.py
+++ b/synapse/handlers/set_password.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
-# Copyright 2017 New Vector Ltd
+# Copyright 2017-2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 6c7abaa578..879c4c07c6 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -294,6 +294,9 @@ class TypingHandler(object):
rows.sort()
limited = False
+ # We, unusually, use a strict limit here as we have all the rows in
+ # memory rather than pulling them out of the database with a `LIMIT ?`
+ # clause.
if len(rows) > limit:
rows = rows[:limit]
current_id = rows[-1][0]
diff --git a/synapse/http/additional_resource.py b/synapse/http/additional_resource.py
index 096619a8c2..479746c9c5 100644
--- a/synapse/http/additional_resource.py
+++ b/synapse/http/additional_resource.py
@@ -13,13 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.web.resource import Resource
-from twisted.web.server import NOT_DONE_YET
+from synapse.http.server import DirectServeJsonResource
-from synapse.http.server import wrap_json_request_handler
-
-class AdditionalResource(Resource):
+class AdditionalResource(DirectServeJsonResource):
"""Resource wrapper for additional_resources
If the user has configured additional_resources, we need to wrap the
@@ -41,16 +38,10 @@ class AdditionalResource(Resource):
handler ((twisted.web.server.Request) -> twisted.internet.defer.Deferred):
function to be called to handle the request.
"""
- Resource.__init__(self)
+ super().__init__()
self._handler = handler
- # required by the request_handler wrapper
- self.clock = hs.get_clock()
-
- def render(self, request):
- self._async_render(request)
- return NOT_DONE_YET
-
- @wrap_json_request_handler
def _async_render(self, request):
+ # Cheekily pass the result straight through, so we don't need to worry
+ # if its an awaitable or not.
return self._handler(request)
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 18f6a8fd29..148eeb19dc 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -176,7 +176,7 @@ class MatrixFederationHttpClient(object):
def __init__(self, hs, tls_client_options_factory):
self.hs = hs
- self.signing_key = hs.config.signing_key[0]
+ self.signing_key = hs.signing_key
self.server_name = hs.hostname
real_reactor = hs.get_reactor()
@@ -562,13 +562,17 @@ class MatrixFederationHttpClient(object):
Returns:
list[bytes]: a list of headers to be added as "Authorization:" headers
"""
- request = {"method": method, "uri": url_bytes, "origin": self.server_name}
+ request = {
+ "method": method.decode("ascii"),
+ "uri": url_bytes.decode("ascii"),
+ "origin": self.server_name,
+ }
if destination is not None:
- request["destination"] = destination
+ request["destination"] = destination.decode("ascii")
if destination_is is not None:
- request["destination_is"] = destination_is
+ request["destination_is"] = destination_is.decode("ascii")
if content is not None:
request["content"] = content
diff --git a/synapse/http/server.py b/synapse/http/server.py
index d192de7923..2b35f86066 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import abc
import collections
import html
import logging
@@ -21,7 +22,7 @@ import types
import urllib
from http import HTTPStatus
from io import BytesIO
-from typing import Awaitable, Callable, TypeVar, Union
+from typing import Any, Callable, Dict, Tuple, Union
import jinja2
from canonicaljson import encode_canonical_json, encode_pretty_printed_json, json
@@ -62,99 +63,43 @@ HTML_ERROR_TEMPLATE = """<!DOCTYPE html>
"""
-def wrap_json_request_handler(h):
- """Wraps a request handler method with exception handling.
-
- Also does the wrapping with request.processing as per wrap_async_request_handler.
-
- The handler method must have a signature of "handle_foo(self, request)",
- where "request" must be a SynapseRequest.
-
- The handler must return a deferred or a coroutine. If the deferred succeeds
- we assume that a response has been sent. If the deferred fails with a SynapseError we use
- it to send a JSON response with the appropriate HTTP reponse code. If the
- deferred fails with any other type of error we send a 500 reponse.
+def return_json_error(f: failure.Failure, request: SynapseRequest) -> None:
+ """Sends a JSON error response to clients.
"""
- async def wrapped_request_handler(self, request):
- try:
- await h(self, request)
- except SynapseError as e:
- code = e.code
- logger.info("%s SynapseError: %s - %s", request, code, e.msg)
-
- # Only respond with an error response if we haven't already started
- # writing, otherwise lets just kill the connection
- if request.startedWriting:
- if request.transport:
- try:
- request.transport.abortConnection()
- except Exception:
- # abortConnection throws if the connection is already closed
- pass
- else:
- respond_with_json(
- request,
- code,
- e.error_dict(),
- send_cors=True,
- pretty_print=_request_user_agent_is_curl(request),
- )
-
- except Exception:
- # failure.Failure() fishes the original Failure out
- # of our stack, and thus gives us a sensible stack
- # trace.
- f = failure.Failure()
- logger.error(
- "Failed handle request via %r: %r",
- request.request_metrics.name,
- request,
- exc_info=(f.type, f.value, f.getTracebackObject()),
- )
- # Only respond with an error response if we haven't already started
- # writing, otherwise lets just kill the connection
- if request.startedWriting:
- if request.transport:
- try:
- request.transport.abortConnection()
- except Exception:
- # abortConnection throws if the connection is already closed
- pass
- else:
- respond_with_json(
- request,
- 500,
- {"error": "Internal server error", "errcode": Codes.UNKNOWN},
- send_cors=True,
- pretty_print=_request_user_agent_is_curl(request),
- )
-
- return wrap_async_request_handler(wrapped_request_handler)
-
-
-TV = TypeVar("TV")
-
-
-def wrap_html_request_handler(
- h: Callable[[TV, SynapseRequest], Awaitable]
-) -> Callable[[TV, SynapseRequest], Awaitable[None]]:
- """Wraps a request handler method with exception handling.
+ if f.check(SynapseError):
+ error_code = f.value.code
+ error_dict = f.value.error_dict()
- Also does the wrapping with request.processing as per wrap_async_request_handler.
-
- The handler method must have a signature of "handle_foo(self, request)",
- where "request" must be a SynapseRequest.
- """
+ logger.info("%s SynapseError: %s - %s", request, error_code, f.value.msg)
+ else:
+ error_code = 500
+ error_dict = {"error": "Internal server error", "errcode": Codes.UNKNOWN}
- async def wrapped_request_handler(self, request):
- try:
- await h(self, request)
- except Exception:
- f = failure.Failure()
- return_html_error(f, request, HTML_ERROR_TEMPLATE)
+ logger.error(
+ "Failed handle request via %r: %r",
+ request.request_metrics.name,
+ request,
+ exc_info=(f.type, f.value, f.getTracebackObject()),
+ )
- return wrap_async_request_handler(wrapped_request_handler)
+ # Only respond with an error response if we haven't already started writing,
+ # otherwise lets just kill the connection
+ if request.startedWriting:
+ if request.transport:
+ try:
+ request.transport.abortConnection()
+ except Exception:
+ # abortConnection throws if the connection is already closed
+ pass
+ else:
+ respond_with_json(
+ request,
+ error_code,
+ error_dict,
+ send_cors=True,
+ pretty_print=_request_user_agent_is_curl(request),
+ )
def return_html_error(
@@ -249,7 +194,113 @@ class HttpServer(object):
pass
-class JsonResource(HttpServer, resource.Resource):
+class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
+ """Base class for resources that have async handlers.
+
+ Sub classes can either implement `_async_render_<METHOD>` to handle
+ requests by method, or override `_async_render` to handle all requests.
+
+ Args:
+ extract_context: Whether to attempt to extract the opentracing
+ context from the request the servlet is handling.
+ """
+
+ def __init__(self, extract_context=False):
+ super().__init__()
+
+ self._extract_context = extract_context
+
+ def render(self, request):
+ """ This gets called by twisted every time someone sends us a request.
+ """
+ defer.ensureDeferred(self._async_render_wrapper(request))
+ return NOT_DONE_YET
+
+ @wrap_async_request_handler
+ async def _async_render_wrapper(self, request):
+ """This is a wrapper that delegates to `_async_render` and handles
+ exceptions, return values, metrics, etc.
+ """
+ try:
+ request.request_metrics.name = self.__class__.__name__
+
+ with trace_servlet(request, self._extract_context):
+ callback_return = await self._async_render(request)
+
+ if callback_return is not None:
+ code, response = callback_return
+ self._send_response(request, code, response)
+ except Exception:
+ # failure.Failure() fishes the original Failure out
+ # of our stack, and thus gives us a sensible stack
+ # trace.
+ f = failure.Failure()
+ self._send_error_response(f, request)
+
+ async def _async_render(self, request):
+ """Delegates to `_async_render_<METHOD>` methods, or returns a 400 if
+ no appropriate method exists. Can be overriden in sub classes for
+ different routing.
+ """
+
+ method_handler = getattr(
+ self, "_async_render_%s" % (request.method.decode("ascii"),), None
+ )
+ if method_handler:
+ raw_callback_return = method_handler(request)
+
+ # Is it synchronous? We'll allow this for now.
+ if isinstance(raw_callback_return, (defer.Deferred, types.CoroutineType)):
+ callback_return = await raw_callback_return
+ else:
+ callback_return = raw_callback_return
+
+ return callback_return
+
+ _unrecognised_request_handler(request)
+
+ @abc.abstractmethod
+ def _send_response(
+ self, request: SynapseRequest, code: int, response_object: Any,
+ ) -> None:
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def _send_error_response(
+ self, f: failure.Failure, request: SynapseRequest,
+ ) -> None:
+ raise NotImplementedError()
+
+
+class DirectServeJsonResource(_AsyncResource):
+ """A resource that will call `self._async_on_<METHOD>` on new requests,
+ formatting responses and errors as JSON.
+ """
+
+ def _send_response(
+ self, request, code, response_object,
+ ):
+ """Implements _AsyncResource._send_response
+ """
+ # TODO: Only enable CORS for the requests that need it.
+ respond_with_json(
+ request,
+ code,
+ response_object,
+ send_cors=True,
+ pretty_print=_request_user_agent_is_curl(request),
+ canonical_json=self.canonical_json,
+ )
+
+ def _send_error_response(
+ self, f: failure.Failure, request: SynapseRequest,
+ ) -> None:
+ """Implements _AsyncResource._send_error_response
+ """
+ return_json_error(f, request)
+
+
+class JsonResource(DirectServeJsonResource):
""" This implements the HttpServer interface and provides JSON support for
Resources.
@@ -269,17 +320,15 @@ class JsonResource(HttpServer, resource.Resource):
"_PathEntry", ["pattern", "callback", "servlet_classname"]
)
- def __init__(self, hs, canonical_json=True):
- resource.Resource.__init__(self)
+ def __init__(self, hs, canonical_json=True, extract_context=False):
+ super().__init__(extract_context)
self.canonical_json = canonical_json
self.clock = hs.get_clock()
self.path_regexs = {}
self.hs = hs
- def register_paths(
- self, method, path_patterns, callback, servlet_classname, trace=True
- ):
+ def register_paths(self, method, path_patterns, callback, servlet_classname):
"""
Registers a request handler against a regular expression. Later request URLs are
checked against these regular expressions in order to identify an appropriate
@@ -295,37 +344,42 @@ class JsonResource(HttpServer, resource.Resource):
servlet_classname (str): The name of the handler to be used in prometheus
and opentracing logs.
-
- trace (bool): Whether we should start a span to trace the servlet.
"""
method = method.encode("utf-8") # method is bytes on py3
- if trace:
- # We don't extract the context from the servlet because we can't
- # trust the sender
- callback = trace_servlet(servlet_classname)(callback)
-
for path_pattern in path_patterns:
logger.debug("Registering for %s %s", method, path_pattern.pattern)
self.path_regexs.setdefault(method, []).append(
self._PathEntry(path_pattern, callback, servlet_classname)
)
- def render(self, request):
- """ This gets called by twisted every time someone sends us a request.
+ def _get_handler_for_request(
+ self, request: SynapseRequest
+ ) -> Tuple[Callable, str, Dict[str, str]]:
+ """Finds a callback method to handle the given request.
+
+ Returns:
+ A tuple of the callback to use, the name of the servlet, and the
+ key word arguments to pass to the callback
"""
- defer.ensureDeferred(self._async_render(request))
- return NOT_DONE_YET
+ request_path = request.path.decode("ascii")
+
+ # Loop through all the registered callbacks to check if the method
+ # and path regex match
+ for path_entry in self.path_regexs.get(request.method, []):
+ m = path_entry.pattern.match(request_path)
+ if m:
+ # We found a match!
+ return path_entry.callback, path_entry.servlet_classname, m.groupdict()
+
+ # Huh. No one wanted to handle that? Fiiiiiine. Send 400.
+ return _unrecognised_request_handler, "unrecognised_request_handler", {}
- @wrap_json_request_handler
async def _async_render(self, request):
- """ This gets called from render() every time someone sends us a request.
- This checks if anyone has registered a callback for that method and
- path.
- """
callback, servlet_classname, group_dict = self._get_handler_for_request(request)
- # Make sure we have a name for this handler in prometheus.
+ # Make sure we have an appopriate name for this handler in prometheus
+ # (rather than the default of JsonResource).
request.request_metrics.name = servlet_classname
# Now trigger the callback. If it returns a response, we send it
@@ -338,81 +392,42 @@ class JsonResource(HttpServer, resource.Resource):
}
)
- callback_return = callback(request, **kwargs)
+ raw_callback_return = callback(request, **kwargs)
# Is it synchronous? We'll allow this for now.
- if isinstance(callback_return, (defer.Deferred, types.CoroutineType)):
- callback_return = await callback_return
+ if isinstance(raw_callback_return, (defer.Deferred, types.CoroutineType)):
+ callback_return = await raw_callback_return
+ else:
+ callback_return = raw_callback_return
- if callback_return is not None:
- code, response = callback_return
- self._send_response(request, code, response)
+ return callback_return
- def _get_handler_for_request(self, request):
- """Finds a callback method to handle the given request
- Args:
- request (twisted.web.http.Request):
+class DirectServeHtmlResource(_AsyncResource):
+ """A resource that will call `self._async_on_<METHOD>` on new requests,
+ formatting responses and errors as HTML.
+ """
- Returns:
- Tuple[Callable, str, dict[unicode, unicode]]: callback method, the
- label to use for that method in prometheus metrics, and the
- dict mapping keys to path components as specified in the
- handler's path match regexp.
-
- The callback will normally be a method registered via
- register_paths, so will return (possibly via Deferred) either
- None, or a tuple of (http code, response body).
- """
- request_path = request.path.decode("ascii")
-
- # Loop through all the registered callbacks to check if the method
- # and path regex match
- for path_entry in self.path_regexs.get(request.method, []):
- m = path_entry.pattern.match(request_path)
- if m:
- # We found a match!
- return path_entry.callback, path_entry.servlet_classname, m.groupdict()
-
- # Huh. No one wanted to handle that? Fiiiiiine. Send 400.
- return _unrecognised_request_handler, "unrecognised_request_handler", {}
+ # The error template to use for this resource
+ ERROR_TEMPLATE = HTML_ERROR_TEMPLATE
def _send_response(
- self, request, code, response_json_object, response_code_message=None
+ self, request: SynapseRequest, code: int, response_object: Any,
):
- # TODO: Only enable CORS for the requests that need it.
- respond_with_json(
- request,
- code,
- response_json_object,
- send_cors=True,
- response_code_message=response_code_message,
- pretty_print=_request_user_agent_is_curl(request),
- canonical_json=self.canonical_json,
- )
-
-
-class DirectServeResource(resource.Resource):
- def render(self, request):
+ """Implements _AsyncResource._send_response
"""
- Render the request, using an asynchronous render handler if it exists.
- """
- async_render_callback_name = "_async_render_" + request.method.decode("ascii")
-
- # Try and get the async renderer
- callback = getattr(self, async_render_callback_name, None)
+ # We expect to get bytes for us to write
+ assert isinstance(response_object, bytes)
+ html_bytes = response_object
- # No async renderer for this request method.
- if not callback:
- return super().render(request)
+ respond_with_html_bytes(request, 200, html_bytes)
- resp = trace_servlet(self.__class__.__name__)(callback)(request)
-
- # If it's a coroutine, turn it into a Deferred
- if isinstance(resp, types.CoroutineType):
- defer.ensureDeferred(resp)
-
- return NOT_DONE_YET
+ def _send_error_response(
+ self, f: failure.Failure, request: SynapseRequest,
+ ) -> None:
+ """Implements _AsyncResource._send_error_response
+ """
+ return_html_error(f, request, self.ERROR_TEMPLATE)
class StaticResource(File):
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index 73bef5e5ca..c6c0e623c1 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -164,12 +164,10 @@ Gotchas
than one caller? Will all of those calling functions have be in a context
with an active span?
"""
-
import contextlib
import inspect
import logging
import re
-import types
from functools import wraps
from typing import TYPE_CHECKING, Dict, Optional, Type
@@ -181,6 +179,7 @@ from twisted.internet import defer
from synapse.config import ConfigError
if TYPE_CHECKING:
+ from synapse.http.site import SynapseRequest
from synapse.server import HomeServer
# Helper class
@@ -227,6 +226,7 @@ except ImportError:
tags = _DummyTagNames
try:
from jaeger_client import Config as JaegerConfig
+
from synapse.logging.scopecontextmanager import LogContextScopeManager
except ImportError:
JaegerConfig = None # type: ignore
@@ -793,48 +793,42 @@ def tag_args(func):
return _tag_args_inner
-def trace_servlet(servlet_name, extract_context=False):
- """Decorator which traces a serlet. It starts a span with some servlet specific
- tags such as the servlet_name and request information
+@contextlib.contextmanager
+def trace_servlet(request: "SynapseRequest", extract_context: bool = False):
+ """Returns a context manager which traces a request. It starts a span
+ with some servlet specific tags such as the request metrics name and
+ request information.
Args:
- servlet_name (str): The name to be used for the span's operation_name
- extract_context (bool): Whether to attempt to extract the opentracing
+ request
+ extract_context: Whether to attempt to extract the opentracing
context from the request the servlet is handling.
-
"""
- def _trace_servlet_inner_1(func):
- if not opentracing:
- return func
-
- @wraps(func)
- async def _trace_servlet_inner(request, *args, **kwargs):
- request_tags = {
- "request_id": request.get_request_id(),
- tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER,
- tags.HTTP_METHOD: request.get_method(),
- tags.HTTP_URL: request.get_redacted_uri(),
- tags.PEER_HOST_IPV6: request.getClientIP(),
- }
-
- if extract_context:
- scope = start_active_span_from_request(
- request, servlet_name, tags=request_tags
- )
- else:
- scope = start_active_span(servlet_name, tags=request_tags)
-
- with scope:
- result = func(request, *args, **kwargs)
+ if opentracing is None:
+ yield
+ return
- if not isinstance(result, (types.CoroutineType, defer.Deferred)):
- # Some servlets aren't async and just return results
- # directly, so we handle that here.
- return result
+ request_tags = {
+ "request_id": request.get_request_id(),
+ tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER,
+ tags.HTTP_METHOD: request.get_method(),
+ tags.HTTP_URL: request.get_redacted_uri(),
+ tags.PEER_HOST_IPV6: request.getClientIP(),
+ }
- return await result
+ request_name = request.request_metrics.name
+ if extract_context:
+ scope = start_active_span_from_request(request, request_name, tags=request_tags)
+ else:
+ scope = start_active_span(request_name, tags=request_tags)
- return _trace_servlet_inner
+ with scope:
+ try:
+ yield
+ finally:
+ # We set the operation name again in case its changed (which happens
+ # with JsonResource).
+ scope.span.set_operation_name(request.request_metrics.name)
- return _trace_servlet_inner_1
+ scope.span.set_tag("request_tag", request.request_metrics.start_context.tag)
diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py
index 13785038ad..a9269196b3 100644
--- a/synapse/metrics/background_process_metrics.py
+++ b/synapse/metrics/background_process_metrics.py
@@ -22,6 +22,7 @@ from typing import TYPE_CHECKING, Dict, Optional, Set
from prometheus_client.core import REGISTRY, Counter, Gauge
from twisted.internet import defer
+from twisted.python.failure import Failure
from synapse.logging.context import LoggingContext, PreserveLoggingContext
@@ -212,7 +213,14 @@ def run_as_background_process(desc, func, *args, **kwargs):
return (yield result)
except Exception:
- logger.exception("Background process '%s' threw an exception", desc)
+ # failure.Failure() fishes the original Failure out of our stack, and
+ # thus gives us a sensible stack trace.
+ f = Failure()
+ logger.error(
+ "Background process '%s' threw an exception",
+ desc,
+ exc_info=(f.type, f.value, f.getTracebackObject()),
+ )
finally:
_background_process_in_flight_count.labels(desc).dec()
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 87c120a59c..bd41f77852 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -83,7 +83,7 @@ class _NotifierUserStream(object):
self.current_token = current_token
# The last token for which we should wake up any streams that have a
- # token that comes before it. This gets updated everytime we get poked.
+ # token that comes before it. This gets updated every time we get poked.
# We start it at the current token since if we get any streams
# that have a token from before we have no idea whether they should be
# woken up or not, so lets just wake them up.
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index ed60dbc1bf..2fac07593b 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -20,6 +20,7 @@ from prometheus_client import Counter
from twisted.internet import defer
from twisted.internet.error import AlreadyCalled, AlreadyCancelled
+from synapse.api.constants import EventTypes
from synapse.logging import opentracing
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.push import PusherConfigException
@@ -305,12 +306,23 @@ class HttpPusher(object):
@defer.inlineCallbacks
def _build_notification_dict(self, event, tweaks, badge):
+ priority = "low"
+ if (
+ event.type == EventTypes.Encrypted
+ or tweaks.get("highlight")
+ or tweaks.get("sound")
+ ):
+ # HACK send our push as high priority only if it generates a sound, highlight
+ # or may do so (i.e. is encrypted so has unknown effects).
+ priority = "high"
+
if self.data.get("format") == "event_id_only":
d = {
"notification": {
"event_id": event.event_id,
"room_id": event.room_id,
"counts": {"unread": badge},
+ "prio": priority,
"devices": [
{
"app_id": self.app_id,
@@ -334,9 +346,8 @@ class HttpPusher(object):
"room_id": event.room_id,
"type": event.type,
"sender": event.user_id,
- "counts": { # -- we don't mark messages as read yet so
- # we have no way of knowing
- # Just set the badge to 1 until we have read receipts
+ "prio": priority,
+ "counts": {
"unread": badge,
# 'missed_calls': 2
},
diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py
index 8e0d3a416d..2d79ada189 100644
--- a/synapse/push/push_rule_evaluator.py
+++ b/synapse/push/push_rule_evaluator.py
@@ -16,7 +16,7 @@
import logging
import re
-from typing import Pattern
+from typing import Any, Dict, List, Pattern, Union
from synapse.events import EventBase
from synapse.types import UserID
@@ -72,13 +72,36 @@ def _test_ineq_condition(condition, number):
return False
-def tweaks_for_actions(actions):
+def tweaks_for_actions(actions: List[Union[str, Dict]]) -> Dict[str, Any]:
+ """
+ Converts a list of actions into a `tweaks` dict (which can then be passed to
+ the push gateway).
+
+ This function ignores all actions other than `set_tweak` actions, and treats
+ absent `value`s as `True`, which agrees with the only spec-defined treatment
+ of absent `value`s (namely, for `highlight` tweaks).
+
+ Args:
+ actions: list of actions
+ e.g. [
+ {"set_tweak": "a", "value": "AAA"},
+ {"set_tweak": "b", "value": "BBB"},
+ {"set_tweak": "highlight"},
+ "notify"
+ ]
+
+ Returns:
+ dictionary of tweaks for those actions
+ e.g. {"a": "AAA", "b": "BBB", "highlight": True}
+ """
tweaks = {}
for a in actions:
if not isinstance(a, dict):
continue
- if "set_tweak" in a and "value" in a:
- tweaks[a["set_tweak"]] = a["value"]
+ if "set_tweak" in a:
+ # value is allowed to be absent in which case the value assumed
+ # should be True.
+ tweaks[a["set_tweak"]] = a.get("value", True)
return tweaks
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index b1cac901eb..8cfcdb0573 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -66,7 +66,7 @@ REQUIREMENTS = [
"pymacaroons>=0.13.0",
"msgpack>=0.5.2",
"phonenumbers>=8.2.0",
- "prometheus_client>=0.0.18,<0.8.0",
+ "prometheus_client>=0.0.18,<0.9.0",
# we use attr.validators.deep_iterable, which arrived in 19.1.0
"attrs>=19.1.0",
"netaddr>=0.7.18",
diff --git a/synapse/replication/http/__init__.py b/synapse/replication/http/__init__.py
index 19b69e0e11..5ef1c6c1dc 100644
--- a/synapse/replication/http/__init__.py
+++ b/synapse/replication/http/__init__.py
@@ -30,7 +30,8 @@ REPLICATION_PREFIX = "/_synapse/replication"
class ReplicationRestResource(JsonResource):
def __init__(self, hs):
- JsonResource.__init__(self, hs, canonical_json=False)
+ # We enable extracting jaeger contexts here as these are internal APIs.
+ super().__init__(hs, canonical_json=False, extract_context=True)
self.register_servlets(hs)
def register_servlets(self, hs):
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index 9caf1e80c1..fb0dd04f88 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -28,11 +28,7 @@ from synapse.api.errors import (
RequestSendFailed,
SynapseError,
)
-from synapse.logging.opentracing import (
- inject_active_span_byte_dict,
- trace,
- trace_servlet,
-)
+from synapse.logging.opentracing import inject_active_span_byte_dict, trace
from synapse.util.caches.response_cache import ResponseCache
from synapse.util.stringutils import random_string
@@ -96,11 +92,11 @@ class ReplicationEndpoint(object):
# assert here that sub classes don't try and use the name.
assert (
"instance_name" not in self.PATH_ARGS
- ), "`instance_name` is a reserved paramater name"
+ ), "`instance_name` is a reserved parameter name"
assert (
"instance_name"
not in signature(self.__class__._serialize_payload).parameters
- ), "`instance_name` is a reserved paramater name"
+ ), "`instance_name` is a reserved parameter name"
assert self.METHOD in ("PUT", "POST", "GET")
@@ -240,11 +236,8 @@ class ReplicationEndpoint(object):
args = "/".join("(?P<%s>[^/]+)" % (arg,) for arg in url_args)
pattern = re.compile("^/_synapse/replication/%s/%s$" % (self.NAME, args))
- handler = trace_servlet(self.__class__.__name__, extract_context=True)(handler)
- # We don't let register paths trace this servlet using the default tracing
- # options because we wish to extract the context explicitly.
http_server.register_paths(
- method, [pattern], handler, self.__class__.__name__, trace=False
+ method, [pattern], handler, self.__class__.__name__,
)
def _cached_handler(self, request, txn_id, **kwargs):
diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py
index a7174c4a8f..63ef6eb7be 100644
--- a/synapse/replication/http/membership.py
+++ b/synapse/replication/http/membership.py
@@ -14,11 +14,11 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Optional
from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint
-from synapse.types import Requester, UserID
+from synapse.types import JsonDict, Requester, UserID
from synapse.util.distributor import user_joined_room, user_left_room
if TYPE_CHECKING:
@@ -88,49 +88,54 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
- """Rejects the invite for the user and room.
+ """Rejects an out-of-band invite we have received from a remote server
Request format:
- POST /_synapse/replication/remote_reject_invite/:room_id/:user_id
+ POST /_synapse/replication/remote_reject_invite/:event_id
{
+ "txn_id": ...,
"requester": ...,
- "remote_room_hosts": [...],
"content": { ... }
}
"""
NAME = "remote_reject_invite"
- PATH_ARGS = ("room_id", "user_id")
+ PATH_ARGS = ("invite_event_id",)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super(ReplicationRemoteRejectInviteRestServlet, self).__init__(hs)
- self.federation_handler = hs.get_handlers().federation_handler
self.store = hs.get_datastore()
self.clock = hs.get_clock()
self.member_handler = hs.get_room_member_handler()
@staticmethod
- def _serialize_payload(requester, room_id, user_id, remote_room_hosts, content):
+ def _serialize_payload( # type: ignore
+ invite_event_id: str,
+ txn_id: Optional[str],
+ requester: Requester,
+ content: JsonDict,
+ ):
"""
Args:
- requester(Requester)
- room_id (str)
- user_id (str)
- remote_room_hosts (list[str]): Servers to try and reject via
+ invite_event_id: ID of the invite to be rejected
+ txn_id: optional transaction ID supplied by the client
+ requester: user making the rejection request, according to the access token
+ content: additional content to include in the rejection event.
+ Normally an empty dict.
"""
return {
+ "txn_id": txn_id,
"requester": requester.serialize(),
- "remote_room_hosts": remote_room_hosts,
"content": content,
}
- async def _handle_request(self, request, room_id, user_id):
+ async def _handle_request(self, request, invite_event_id):
content = parse_json_object_from_request(request)
- remote_room_hosts = content["remote_room_hosts"]
+ txn_id = content["txn_id"]
event_content = content["content"]
requester = Requester.deserialize(self.store, content["requester"])
@@ -138,60 +143,14 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
if requester.user:
request.authenticated_entity = requester.user.to_string()
- logger.info("remote_reject_invite: %s out of room: %s", user_id, room_id)
-
- try:
- event, stream_id = await self.federation_handler.do_remotely_reject_invite(
- remote_room_hosts, room_id, user_id, event_content,
- )
- event_id = event.event_id
- except Exception as e:
- # if we were unable to reject the exception, just mark
- # it as rejected on our end and plough ahead.
- #
- # The 'except' clause is very broad, but we need to
- # capture everything from DNS failures upwards
- #
- logger.warning("Failed to reject invite: %s", e)
-
- stream_id = await self.member_handler.locally_reject_invite(
- user_id, room_id
- )
- event_id = None
+ # hopefully we're now on the master, so this won't recurse!
+ event_id, stream_id = await self.member_handler.remote_reject_invite(
+ invite_event_id, txn_id, requester, event_content,
+ )
return 200, {"event_id": event_id, "stream_id": stream_id}
-class ReplicationLocallyRejectInviteRestServlet(ReplicationEndpoint):
- """Rejects the invite for the user and room locally.
-
- Request format:
-
- POST /_synapse/replication/locally_reject_invite/:room_id/:user_id
-
- {}
- """
-
- NAME = "locally_reject_invite"
- PATH_ARGS = ("room_id", "user_id")
-
- def __init__(self, hs: "HomeServer"):
- super().__init__(hs)
-
- self.member_handler = hs.get_room_member_handler()
-
- @staticmethod
- def _serialize_payload(room_id, user_id):
- return {}
-
- async def _handle_request(self, request, room_id, user_id):
- logger.info("locally_reject_invite: %s out of room: %s", user_id, room_id)
-
- stream_id = await self.member_handler.locally_reject_invite(user_id, room_id)
-
- return 200, {"stream_id": stream_id}
-
-
class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
"""Notifies that a user has joined or left the room
@@ -245,4 +204,3 @@ def register_servlets(hs, http_server):
ReplicationRemoteJoinRestServlet(hs).register(http_server)
ReplicationRemoteRejectInviteRestServlet(hs).register(http_server)
ReplicationUserJoinedLeftRoomRestServlet(hs).register(http_server)
- ReplicationLocallyRejectInviteRestServlet(hs).register(http_server)
diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py
index 9db6c62bc7..525b94fd87 100644
--- a/synapse/replication/slave/storage/account_data.py
+++ b/synapse/replication/slave/storage/account_data.py
@@ -16,6 +16,7 @@
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
+from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream
from synapse.storage.data_stores.main.account_data import AccountDataWorkerStore
from synapse.storage.data_stores.main.tags import TagsWorkerStore
from synapse.storage.database import Database
@@ -39,12 +40,12 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved
return self._account_data_id_gen.get_current_token()
def process_replication_rows(self, stream_name, instance_name, token, rows):
- if stream_name == "tag_account_data":
+ if stream_name == TagAccountDataStream.NAME:
self._account_data_id_gen.advance(token)
for row in rows:
self.get_tags_for_user.invalidate((row.user_id,))
self._account_data_stream_cache.entity_has_changed(row.user_id, token)
- elif stream_name == "account_data":
+ elif stream_name == AccountDataStream.NAME:
self._account_data_id_gen.advance(token)
for row in rows:
if not row.room_id:
diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py
index 6e7fd259d4..bd394f6b00 100644
--- a/synapse/replication/slave/storage/deviceinbox.py
+++ b/synapse/replication/slave/storage/deviceinbox.py
@@ -15,6 +15,7 @@
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
+from synapse.replication.tcp.streams import ToDeviceStream
from synapse.storage.data_stores.main.deviceinbox import DeviceInboxWorkerStore
from synapse.storage.database import Database
from synapse.util.caches.expiringcache import ExpiringCache
@@ -44,7 +45,7 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
)
def process_replication_rows(self, stream_name, instance_name, token, rows):
- if stream_name == "to_device":
+ if stream_name == ToDeviceStream.NAME:
self._device_inbox_id_gen.advance(token)
for row in rows:
if row.entity.startswith("@"):
diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py
index 1851e7d525..5d210fa3a1 100644
--- a/synapse/replication/slave/storage/groups.py
+++ b/synapse/replication/slave/storage/groups.py
@@ -15,6 +15,7 @@
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
+from synapse.replication.tcp.streams import GroupServerStream
from synapse.storage.data_stores.main.group_server import GroupServerWorkerStore
from synapse.storage.database import Database
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -38,7 +39,7 @@ class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore):
return self._group_updates_id_gen.get_current_token()
def process_replication_rows(self, stream_name, instance_name, token, rows):
- if stream_name == "groups":
+ if stream_name == GroupServerStream.NAME:
self._group_updates_id_gen.advance(token)
for row in rows:
self._group_updates_stream_cache.entity_has_changed(row.user_id, token)
diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py
index 4e0124842d..2938cb8e43 100644
--- a/synapse/replication/slave/storage/presence.py
+++ b/synapse/replication/slave/storage/presence.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from synapse.replication.tcp.streams import PresenceStream
from synapse.storage import DataStore
from synapse.storage.data_stores.main.presence import PresenceStore
from synapse.storage.database import Database
@@ -42,7 +43,7 @@ class SlavedPresenceStore(BaseSlavedStore):
return self._presence_id_gen.get_current_token()
def process_replication_rows(self, stream_name, instance_name, token, rows):
- if stream_name == "presence":
+ if stream_name == PresenceStream.NAME:
self._presence_id_gen.advance(token)
for row in rows:
self.presence_stream_cache.entity_has_changed(row.user_id, token)
diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py
index 6adb19463a..23ec1c5b11 100644
--- a/synapse/replication/slave/storage/push_rule.py
+++ b/synapse/replication/slave/storage/push_rule.py
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from synapse.replication.tcp.streams import PushRulesStream
from synapse.storage.data_stores.main.push_rule import PushRulesWorkerStore
from .events import SlavedEventStore
@@ -30,7 +31,7 @@ class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
return self._push_rules_stream_id_gen.get_current_token()
def process_replication_rows(self, stream_name, instance_name, token, rows):
- if stream_name == "push_rules":
+ if stream_name == PushRulesStream.NAME:
self._push_rules_stream_id_gen.advance(token)
for row in rows:
self.get_push_rules_for_user.invalidate((row.user_id,))
diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py
index cb78b49acb..ff449f3658 100644
--- a/synapse/replication/slave/storage/pushers.py
+++ b/synapse/replication/slave/storage/pushers.py
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from synapse.replication.tcp.streams import PushersStream
from synapse.storage.data_stores.main.pusher import PusherWorkerStore
from synapse.storage.database import Database
@@ -32,6 +33,6 @@ class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
return self._pushers_id_gen.get_current_token()
def process_replication_rows(self, stream_name, instance_name, token, rows):
- if stream_name == "pushers":
+ if stream_name == PushersStream.NAME:
self._pushers_id_gen.advance(token)
return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py
index be716cc558..6982686eb5 100644
--- a/synapse/replication/slave/storage/receipts.py
+++ b/synapse/replication/slave/storage/receipts.py
@@ -14,20 +14,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from synapse.replication.tcp.streams import ReceiptsStream
from synapse.storage.data_stores.main.receipts import ReceiptsWorkerStore
from synapse.storage.database import Database
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
-# So, um, we want to borrow a load of functions intended for reading from
-# a DataStore, but we don't want to take functions that either write to the
-# DataStore or are cached and don't have cache invalidation logic.
-#
-# Rather than write duplicate versions of those functions, or lift them to
-# a common base class, we going to grab the underlying __func__ object from
-# the method descriptor on the DataStore and chuck them into our class.
-
class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
def __init__(self, database: Database, db_conn, hs):
@@ -52,7 +45,7 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
self.get_receipts_for_room.invalidate((room_id, receipt_type))
def process_replication_rows(self, stream_name, instance_name, token, rows):
- if stream_name == "receipts":
+ if stream_name == ReceiptsStream.NAME:
self._receipts_id_gen.advance(token)
for row in rows:
self.invalidate_caches_for_receipt(
diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py
index 8873bf37e5..8710207ada 100644
--- a/synapse/replication/slave/storage/room.py
+++ b/synapse/replication/slave/storage/room.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from synapse.replication.tcp.streams import PublicRoomsStream
from synapse.storage.data_stores.main.room import RoomWorkerStore
from synapse.storage.database import Database
@@ -31,7 +32,7 @@ class RoomStore(RoomWorkerStore, BaseSlavedStore):
return self._public_room_id_gen.get_current_token()
def process_replication_rows(self, stream_name, instance_name, token, rows):
- if stream_name == "public_rooms":
+ if stream_name == PublicRoomsStream.NAME:
self._public_room_id_gen.advance(token)
return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/tcp/__init__.py b/synapse/replication/tcp/__init__.py
index 523a1358d4..1b8718b11d 100644
--- a/synapse/replication/tcp/__init__.py
+++ b/synapse/replication/tcp/__init__.py
@@ -25,7 +25,7 @@ Structure of the module:
* command.py - the definitions of all the valid commands
* protocol.py - the TCP protocol classes
* resource.py - handles streaming stream updates to replications
- * streams/ - the definitons of all the valid streams
+ * streams/ - the definitions of all the valid streams
The general interaction of the classes are:
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index df29732f51..4985e40b1f 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -33,8 +33,8 @@ from synapse.util.async_helpers import timeout_deferred
from synapse.util.metrics import Measure
if TYPE_CHECKING:
- from synapse.server import HomeServer
from synapse.replication.tcp.handler import ReplicationCommandHandler
+ from synapse.server import HomeServer
logger = logging.getLogger(__name__)
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index ea5937a20c..ccc7f1f0d1 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -18,18 +18,11 @@ The VALID_SERVER_COMMANDS and VALID_CLIENT_COMMANDS define which commands are
allowed to be sent by which side.
"""
import abc
+import json
import logging
-import platform
from typing import Tuple, Type
-if platform.python_implementation() == "PyPy":
- import json
-
- _json_encoder = json.JSONEncoder()
-else:
- import simplejson as json # type: ignore[no-redef] # noqa: F821
-
- _json_encoder = json.JSONEncoder(namedtuple_as_object=False) # type: ignore[call-arg] # noqa: F821
+_json_encoder = json.JSONEncoder()
logger = logging.getLogger(__name__)
@@ -54,7 +47,7 @@ class Command(metaclass=abc.ABCMeta):
@abc.abstractmethod
def to_line(self) -> str:
- """Serialises the comamnd for the wire. Does not include the command
+ """Serialises the command for the wire. Does not include the command
prefix.
"""
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index e6a2e2598b..55b3b79008 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -13,7 +13,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar
@@ -149,10 +148,11 @@ class ReplicationCommandHandler:
using TCP.
"""
if hs.config.redis.redis_enabled:
+ import txredisapi
+
from synapse.replication.tcp.redis import (
RedisDirectTcpReplicationClientFactory,
)
- import txredisapi
logger.info(
"Connecting to redis (host=%r port=%r)",
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index 4198eece71..ca47f5cc88 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -317,7 +317,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
def _queue_command(self, cmd):
"""Queue the command until the connection is ready to write to again.
"""
- logger.debug("[%s] Queing as conn %r, cmd: %r", self.id(), self.state, cmd)
+ logger.debug("[%s] Queueing as conn %r, cmd: %r", self.id(), self.state, cmd)
self.pending_commands.append(cmd)
if len(self.pending_commands) > self.max_line_buffer:
diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index e776b63183..0a7e7f67be 100644
--- a/synapse/replication/tcp/redis.py
+++ b/synapse/replication/tcp/redis.py
@@ -177,7 +177,7 @@ class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory):
Args:
hs
outbound_redis_connection: A connection to redis that will be used to
- send outbound commands (this is seperate to the redis connection
+ send outbound commands (this is separate to the redis connection
used to subscribe).
"""
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index f196eff072..9076bbe9f1 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -198,26 +198,6 @@ def current_token_without_instance(
return lambda instance_name: current_token()
-def db_query_to_update_function(
- query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]]
-) -> UpdateFunction:
- """Wraps a db query function which returns a list of rows to make it
- suitable for use as an `update_function` for the Stream class
- """
-
- async def update_function(instance_name, from_token, upto_token, limit):
- rows = await query_function(from_token, upto_token, limit)
- updates = [(row[0], row[1:]) for row in rows]
- limited = False
- if len(updates) >= limit:
- upto_token = updates[-1][0]
- limited = True
-
- return updates, upto_token, limited
-
- return update_function
-
-
def make_http_update_function(hs, stream_name: str) -> UpdateFunction:
"""Makes a suitable function for use as an `update_function` that queries
the master process for updates.
@@ -393,7 +373,7 @@ class PushersStream(Stream):
super().__init__(
hs.get_instance_name(),
current_token_without_instance(store.get_pushers_stream_token),
- db_query_to_update_function(store.get_all_updated_pushers_rows),
+ store.get_all_updated_pushers_rows,
)
@@ -421,26 +401,12 @@ class CachesStream(Stream):
ROW_TYPE = CachesStreamRow
def __init__(self, hs):
- self.store = hs.get_datastore()
+ store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
- self.store.get_cache_stream_token,
- self._update_function,
- )
-
- async def _update_function(
- self, instance_name: str, from_token: int, upto_token: int, limit: int
- ):
- rows = await self.store.get_all_updated_caches(
- instance_name, from_token, upto_token, limit
+ store.get_cache_stream_token,
+ store.get_all_updated_caches,
)
- updates = [(row[0], row[1:]) for row in rows]
- limited = False
- if len(updates) >= limit:
- upto_token = updates[-1][0]
- limited = True
-
- return updates, upto_token, limited
class PublicRoomsStream(Stream):
@@ -465,7 +431,7 @@ class PublicRoomsStream(Stream):
super().__init__(
hs.get_instance_name(),
current_token_without_instance(store.get_current_public_room_stream_id),
- db_query_to_update_function(store.get_all_new_public_rooms),
+ store.get_all_new_public_rooms,
)
@@ -486,7 +452,7 @@ class DeviceListsStream(Stream):
super().__init__(
hs.get_instance_name(),
current_token_without_instance(store.get_device_stream_token),
- db_query_to_update_function(store.get_all_device_list_changes_for_remotes),
+ store.get_all_device_list_changes_for_remotes,
)
@@ -504,7 +470,7 @@ class ToDeviceStream(Stream):
super().__init__(
hs.get_instance_name(),
current_token_without_instance(store.get_to_device_stream_token),
- db_query_to_update_function(store.get_all_new_device_messages),
+ store.get_all_new_device_messages,
)
@@ -524,7 +490,7 @@ class TagAccountDataStream(Stream):
super().__init__(
hs.get_instance_name(),
current_token_without_instance(store.get_max_account_data_stream_id),
- db_query_to_update_function(store.get_all_updated_tags),
+ store.get_all_updated_tags,
)
@@ -612,7 +578,7 @@ class GroupServerStream(Stream):
super().__init__(
hs.get_instance_name(),
current_token_without_instance(store.get_group_stream_token),
- db_query_to_update_function(store.get_all_groups_changes),
+ store.get_all_groups_changes,
)
@@ -630,7 +596,5 @@ class UserSignatureStream(Stream):
super().__init__(
hs.get_instance_name(),
current_token_without_instance(store.get_device_stream_token),
- db_query_to_update_function(
- store.get_all_user_signature_changes_for_remotes
- ),
+ store.get_all_user_signature_changes_for_remotes,
)
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index f370390331..1c2a4cce7f 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -13,7 +13,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import heapq
from collections import Iterable
from typing import List, Tuple, Type
@@ -22,7 +21,6 @@ import attr
from ._base import Stream, StreamUpdateResult, Token, current_token_without_instance
-
"""Handling of the 'events' replication stream
This stream contains rows of various types. Each row therefore contains a 'type'
@@ -64,7 +62,7 @@ class BaseEventsStreamRow(object):
Specifies how to identify, serialize and deserialize the different types.
"""
- # Unique string that ids the type. Must be overriden in sub classes.
+ # Unique string that ids the type. Must be overridden in sub classes.
TypeId = None # type: str
@classmethod
diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py
index 46e458e95b..2e81eeff65 100644
--- a/synapse/rest/__init__.py
+++ b/synapse/rest/__init__.py
@@ -118,6 +118,7 @@ class ClientRestResource(JsonResource):
room_upgrade_rest_servlet.register_servlets(hs, client_resource)
capabilities.register_servlets(hs, client_resource)
account_validity.register_servlets(hs, client_resource)
+ password_policy.register_servlets(hs, client_resource)
relations.register_servlets(hs, client_resource)
password_policy.register_servlets(hs, client_resource)
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index bf0f9bd077..64d5c58b65 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -14,6 +14,7 @@
# limitations under the License.
import logging
+from typing import Awaitable, Callable, Dict, Optional
from synapse.api.errors import Codes, LoginError, SynapseError
from synapse.api.ratelimiting import Ratelimiter
@@ -26,8 +27,9 @@ from synapse.http.servlet import (
from synapse.http.site import SynapseRequest
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.rest.well_known import WellKnownBuilder
-from synapse.types import UserID
+from synapse.types import JsonDict, UserID
from synapse.util.msisdn import phone_number_to_msisdn
+from synapse.util.threepids import canonicalise_email
logger = logging.getLogger(__name__)
@@ -113,7 +115,7 @@ class LoginRestServlet(RestServlet):
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
)
- def on_GET(self, request):
+ def on_GET(self, request: SynapseRequest):
flows = []
if self.jwt_enabled:
flows.append({"type": LoginRestServlet.JWT_TYPE})
@@ -141,10 +143,10 @@ class LoginRestServlet(RestServlet):
return 200, {"flows": flows}
- def on_OPTIONS(self, request):
+ def on_OPTIONS(self, request: SynapseRequest):
return 200, {}
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest):
self._address_ratelimiter.ratelimit(request.getClientIP())
login_submission = parse_json_object_from_request(request)
@@ -153,9 +155,9 @@ class LoginRestServlet(RestServlet):
login_submission["type"] == LoginRestServlet.JWT_TYPE
or login_submission["type"] == LoginRestServlet.JWT_TYPE_DEPRECATED
):
- result = await self.do_jwt_login(login_submission)
+ result = await self._do_jwt_login(login_submission)
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
- result = await self.do_token_login(login_submission)
+ result = await self._do_token_login(login_submission)
else:
result = await self._do_other_login(login_submission)
except KeyError:
@@ -166,14 +168,14 @@ class LoginRestServlet(RestServlet):
result["well_known"] = well_known_data
return 200, result
- async def _do_other_login(self, login_submission):
+ async def _do_other_login(self, login_submission: JsonDict) -> Dict[str, str]:
"""Handle non-token/saml/jwt logins
Args:
login_submission:
Returns:
- dict: HTTP response
+ HTTP response
"""
# Log the request we got, but only certain fields to minimise the chance of
# logging someone's password (even if they accidentally put it in the wrong
@@ -206,11 +208,14 @@ class LoginRestServlet(RestServlet):
if medium is None or address is None:
raise SynapseError(400, "Invalid thirdparty identifier")
+ # For emails, canonicalise the address.
+ # We store all email addresses canonicalised in the DB.
+ # (See add_threepid in synapse/handlers/auth.py)
if medium == "email":
- # For emails, transform the address to lowercase.
- # We store all email addreses as lowercase in the DB.
- # (See add_threepid in synapse/handlers/auth.py)
- address = address.lower()
+ try:
+ address = canonicalise_email(address)
+ except ValueError as e:
+ raise SynapseError(400, str(e))
# We also apply account rate limiting using the 3PID as a key, as
# otherwise using 3PID bypasses the ratelimiting based on user ID.
@@ -288,25 +293,30 @@ class LoginRestServlet(RestServlet):
return result
async def _complete_login(
- self, user_id, login_submission, callback=None, create_non_existent_users=False
- ):
+ self,
+ user_id: str,
+ login_submission: JsonDict,
+ callback: Optional[
+ Callable[[Dict[str, str]], Awaitable[Dict[str, str]]]
+ ] = None,
+ create_non_existent_users: bool = False,
+ ) -> Dict[str, str]:
"""Called when we've successfully authed the user and now need to
actually login them in (e.g. create devices). This gets called on
- all succesful logins.
+ all successful logins.
- Applies the ratelimiting for succesful login attempts against an
+ Applies the ratelimiting for successful login attempts against an
account.
Args:
- user_id (str): ID of the user to register.
- login_submission (dict): Dictionary of login information.
- callback (func|None): Callback function to run after registration.
- create_non_existent_users (bool): Whether to create the user if
- they don't exist. Defaults to False.
+ user_id: ID of the user to register.
+ login_submission: Dictionary of login information.
+ callback: Callback function to run after registration.
+ create_non_existent_users: Whether to create the user if they don't
+ exist. Defaults to False.
Returns:
- result (Dict[str,str]): Dictionary of account information after
- successful registration.
+ result: Dictionary of account information after successful registration.
"""
# Before we actually log them in we check if they've already logged in
@@ -340,7 +350,7 @@ class LoginRestServlet(RestServlet):
return result
- async def do_token_login(self, login_submission):
+ async def _do_token_login(self, login_submission: JsonDict) -> Dict[str, str]:
token = login_submission["token"]
auth_handler = self.auth_handler
user_id = await auth_handler.validate_short_term_login_token_and_get_user_id(
@@ -350,7 +360,7 @@ class LoginRestServlet(RestServlet):
result = await self._complete_login(user_id, login_submission)
return result
- async def do_jwt_login(self, login_submission):
+ async def _do_jwt_login(self, login_submission: JsonDict) -> Dict[str, str]:
token = login_submission.get("token", None)
if token is None:
raise LoginError(
diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py
index e7fe50ed72..165313b572 100644
--- a/synapse/rest/client/v1/profile.py
+++ b/synapse/rest/client/v1/profile.py
@@ -14,6 +14,7 @@
# limitations under the License.
""" This module contains REST servlets to do with profile: /profile/<paths> """
+from twisted.internet import defer
from synapse.api.errors import Codes, SynapseError
from synapse.http.servlet import RestServlet, parse_json_object_from_request
@@ -28,6 +29,7 @@ class ProfileDisplaynameRestServlet(RestServlet):
super(ProfileDisplaynameRestServlet, self).__init__()
self.hs = hs
self.profile_handler = hs.get_profile_handler()
+ self.http_client = hs.get_simple_http_client()
self.auth = hs.get_auth()
async def on_GET(self, request, user_id):
@@ -63,11 +65,27 @@ class ProfileDisplaynameRestServlet(RestServlet):
await self.profile_handler.set_displayname(user, requester, new_name, is_admin)
+ if self.hs.config.shadow_server:
+ shadow_user = UserID(user.localpart, self.hs.config.shadow_server.get("hs"))
+ self.shadow_displayname(shadow_user.to_string(), content)
+
return 200, {}
def on_OPTIONS(self, request, user_id):
return 200, {}
+ @defer.inlineCallbacks
+ def shadow_displayname(self, user_id, body):
+ # TODO: retries
+ shadow_hs_url = self.hs.config.shadow_server.get("hs_url")
+ as_token = self.hs.config.shadow_server.get("as_token")
+
+ yield self.http_client.put_json(
+ "%s/_matrix/client/r0/profile/%s/displayname?access_token=%s&user_id=%s"
+ % (shadow_hs_url, user_id, as_token, user_id),
+ body,
+ )
+
class ProfileAvatarURLRestServlet(RestServlet):
PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)/avatar_url", v1=True)
@@ -76,6 +94,7 @@ class ProfileAvatarURLRestServlet(RestServlet):
super(ProfileAvatarURLRestServlet, self).__init__()
self.hs = hs
self.profile_handler = hs.get_profile_handler()
+ self.http_client = hs.get_simple_http_client()
self.auth = hs.get_auth()
async def on_GET(self, request, user_id):
@@ -114,11 +133,27 @@ class ProfileAvatarURLRestServlet(RestServlet):
user, requester, new_avatar_url, is_admin
)
+ if self.hs.config.shadow_server:
+ shadow_user = UserID(user.localpart, self.hs.config.shadow_server.get("hs"))
+ self.shadow_avatar_url(shadow_user.to_string(), content)
+
return 200, {}
def on_OPTIONS(self, request, user_id):
return 200, {}
+ @defer.inlineCallbacks
+ def shadow_avatar_url(self, user_id, body):
+ # TODO: retries
+ shadow_hs_url = self.hs.config.shadow_server.get("hs_url")
+ as_token = self.hs.config.shadow_server.get("as_token")
+
+ yield self.http_client.put_json(
+ "%s/_matrix/client/r0/profile/%s/avatar_url?access_token=%s&user_id=%s"
+ % (shadow_hs_url, user_id, as_token, user_id),
+ body,
+ )
+
class ProfileRestServlet(RestServlet):
PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)", v1=True)
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 46811abbfa..01b80b86fa 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -217,10 +217,8 @@ class RoomStateEventRestServlet(TransactionRestServlet):
)
event_id = event.event_id
- ret = {} # type: dict
- if event_id:
- set_tag("event_id", event_id)
- ret = {"event_id": event_id}
+ set_tag("event_id", event_id)
+ ret = {"event_id": event_id}
return 200, ret
@@ -725,7 +723,8 @@ class RoomMembershipRestServlet(TransactionRestServlet):
content["id_server"],
requester,
txn_id,
- content.get("id_access_token"),
+ new_room=False,
+ id_access_token=content.get("id_access_token"),
)
return 200, {}
diff --git a/synapse/rest/client/v1/voip.py b/synapse/rest/client/v1/voip.py
index 747d46eac2..50277c6cf6 100644
--- a/synapse/rest/client/v1/voip.py
+++ b/synapse/rest/client/v1/voip.py
@@ -50,7 +50,7 @@ class VoipRestServlet(RestServlet):
# We need to use standard padded base64 encoding here
# encode_base64 because we need to add the standard padding to get the
# same result as the TURN server.
- password = base64.b64encode(mac.digest())
+ password = base64.b64encode(mac.digest()).decode("ascii")
elif turnUris and turnUsername and turnPassword and userLifetime:
username = turnUsername
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index 182a308eef..d4e0b962af 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2017 Vector Creations Ltd
-# Copyright 2018 New Vector Ltd
+# Copyright 2018, 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,8 +15,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+import re
+
from http import HTTPStatus
+from twisted.internet import defer
+
from synapse.api.constants import LoginType
from synapse.api.errors import Codes, SynapseError, ThreepidValidationError
from synapse.config.emailconfig import ThreepidBehaviour
@@ -28,9 +32,10 @@ from synapse.http.servlet import (
parse_string,
)
from synapse.push.mailer import Mailer, load_jinja2_templates
+from synapse.types import UserID
from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.stringutils import assert_valid_client_secret, random_string
-from synapse.util.threepids import check_3pid_allowed
+from synapse.util.threepids import canonicalise_email, check_3pid_allowed
from ._base import client_patterns, interactive_auth_handler
@@ -83,17 +88,29 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
client_secret = body["client_secret"]
assert_valid_client_secret(client_secret)
- email = body["email"]
+ # Canonicalise the email address. The addresses are all stored canonicalised
+ # in the database. This allows the user to reset his password without having to
+ # know the exact spelling (eg. upper and lower case) of address in the database.
+ # Stored in the database "foo@bar.com"
+ # User requests with "FOO@bar.com" would raise a Not Found error
+ try:
+ email = canonicalise_email(body["email"])
+ except ValueError as e:
+ raise SynapseError(400, str(e))
send_attempt = body["send_attempt"]
next_link = body.get("next_link") # Optional param
if not check_3pid_allowed(self.hs, "email", email):
raise SynapseError(
403,
- "Your email domain is not authorized on this server",
+ "Your email is not authorized on this server",
Codes.THREEPID_DENIED,
)
+ # The email will be sent to the stored address.
+ # This avoids a potential account hijack by requesting a password reset to
+ # an email address which is controlled by the attacker but which, after
+ # canonicalisation, matches the one in our database.
existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid(
"email", email
)
@@ -220,6 +237,7 @@ class PasswordRestServlet(RestServlet):
self.datastore = self.hs.get_datastore()
self.password_policy_handler = hs.get_password_policy_handler()
self._set_password_handler = hs.get_set_password_handler()
+ self.http_client = hs.get_simple_http_client()
@interactive_auth_handler
async def on_POST(self, request):
@@ -251,13 +269,17 @@ class PasswordRestServlet(RestServlet):
if self.auth.has_access_token(request):
requester = await self.auth.get_user_by_req(request)
- params = await self.auth_handler.validate_user_via_ui_auth(
- requester,
- request,
- body,
- self.hs.get_ip_from_request(request),
- "modify your account password",
- )
+ # blindly trust ASes without UI-authing them
+ if requester.app_service:
+ params = body
+ else:
+ params = await self.auth_handler.validate_user_via_ui_auth(
+ requester,
+ request,
+ body,
+ self.hs.get_ip_from_request(request),
+ "modify your account password",
+ )
user_id = requester.user.to_string()
else:
requester = None
@@ -274,10 +296,13 @@ class PasswordRestServlet(RestServlet):
if "medium" not in threepid or "address" not in threepid:
raise SynapseError(500, "Malformed threepid")
if threepid["medium"] == "email":
- # For emails, transform the address to lowercase.
- # We store all email addreses as lowercase in the DB.
+ # For emails, canonicalise the address.
+ # We store all email addresses canonicalised in the DB.
# (See add_threepid in synapse/handlers/auth.py)
- threepid["address"] = threepid["address"].lower()
+ try:
+ threepid["address"] = canonicalise_email(threepid["address"])
+ except ValueError as e:
+ raise SynapseError(400, str(e))
# if using email, we must know about the email they're authing with!
threepid_user_id = await self.datastore.get_user_id_by_threepid(
threepid["medium"], threepid["address"]
@@ -297,11 +322,29 @@ class PasswordRestServlet(RestServlet):
user_id, new_password_hash, logout_devices, requester
)
+ if self.hs.config.shadow_server:
+ shadow_user = UserID(
+ requester.user.localpart, self.hs.config.shadow_server.get("hs")
+ )
+ self.shadow_password(params, shadow_user.to_string())
+
return 200, {}
def on_OPTIONS(self, _):
return 200, {}
+ @defer.inlineCallbacks
+ def shadow_password(self, body, user_id):
+ # TODO: retries
+ shadow_hs_url = self.hs.config.shadow_server.get("hs_url")
+ as_token = self.hs.config.shadow_server.get("as_token")
+
+ yield self.http_client.post_json_get_json(
+ "%s/_matrix/client/r0/account/password?access_token=%s&user_id=%s"
+ % (shadow_hs_url, as_token, user_id),
+ body,
+ )
+
class DeactivateAccountRestServlet(RestServlet):
PATTERNS = client_patterns("/account/deactivate$")
@@ -392,20 +435,27 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
client_secret = body["client_secret"]
assert_valid_client_secret(client_secret)
- email = body["email"]
+ # Canonicalise the email address. The addresses are all stored canonicalised
+ # in the database.
+ # This ensures that the validation email is sent to the canonicalised address
+ # as it will later be entered into the database.
+ # Otherwise the email will be sent to "FOO@bar.com" and stored as
+ # "foo@bar.com" in database.
+ try:
+ email = canonicalise_email(body["email"])
+ except ValueError as e:
+ raise SynapseError(400, str(e))
send_attempt = body["send_attempt"]
next_link = body.get("next_link") # Optional param
- if not check_3pid_allowed(self.hs, "email", email):
+ if not (await check_3pid_allowed(self.hs, "email", email)):
raise SynapseError(
403,
- "Your email domain is not authorized on this server",
+ "Your email is not authorized on this server",
Codes.THREEPID_DENIED,
)
- existing_user_id = await self.store.get_user_id_by_threepid(
- "email", body["email"]
- )
+ existing_user_id = await self.store.get_user_id_by_threepid("email", email)
if existing_user_id is not None:
if self.config.request_token_inhibit_3pid_errors:
@@ -466,13 +516,15 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
msisdn = phone_number_to_msisdn(country, phone_number)
- if not check_3pid_allowed(self.hs, "msisdn", msisdn):
+ if not (await check_3pid_allowed(self.hs, "msisdn", msisdn)):
raise SynapseError(
403,
"Account phone numbers are not authorized on this server",
Codes.THREEPID_DENIED,
)
+ assert_valid_client_secret(body["client_secret"])
+
existing_user_id = await self.store.get_user_id_by_threepid("msisdn", msisdn)
if existing_user_id is not None:
@@ -631,7 +683,8 @@ class ThreepidRestServlet(RestServlet):
self.identity_handler = hs.get_handlers().identity_handler
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
- self.datastore = self.hs.get_datastore()
+ self.datastore = hs.get_datastore()
+ self.http_client = hs.get_simple_http_client()
async def on_GET(self, request):
requester = await self.auth.get_user_by_req(request)
@@ -650,6 +703,29 @@ class ThreepidRestServlet(RestServlet):
user_id = requester.user.to_string()
body = parse_json_object_from_request(request)
+ # skip validation if this is a shadow 3PID from an AS
+ if requester.app_service:
+ # XXX: ASes pass in a validated threepid directly to bypass the IS.
+ # This makes the API entirely change shape when we have an AS token;
+ # it really should be an entirely separate API - perhaps
+ # /account/3pid/replicate or something.
+ threepid = body.get("threepid")
+
+ await self.auth_handler.add_threepid(
+ user_id,
+ threepid["medium"],
+ threepid["address"],
+ threepid["validated_at"],
+ )
+
+ if self.hs.config.shadow_server:
+ shadow_user = UserID(
+ requester.user.localpart, self.hs.config.shadow_server.get("hs")
+ )
+ self.shadow_3pid({"threepid": threepid}, shadow_user.to_string())
+
+ return 200, {}
+
threepid_creds = body.get("threePidCreds") or body.get("three_pid_creds")
if threepid_creds is None:
raise SynapseError(
@@ -671,12 +747,36 @@ class ThreepidRestServlet(RestServlet):
validation_session["address"],
validation_session["validated_at"],
)
+
+ if self.hs.config.shadow_server:
+ shadow_user = UserID(
+ requester.user.localpart, self.hs.config.shadow_server.get("hs")
+ )
+ threepid = {
+ "medium": validation_session["medium"],
+ "address": validation_session["address"],
+ "validated_at": validation_session["validated_at"],
+ }
+ self.shadow_3pid({"threepid": threepid}, shadow_user.to_string())
+
return 200, {}
raise SynapseError(
400, "No validated 3pid session found", Codes.THREEPID_AUTH_FAILED
)
+ @defer.inlineCallbacks
+ def shadow_3pid(self, body, user_id):
+ # TODO: retries
+ shadow_hs_url = self.hs.config.shadow_server.get("hs_url")
+ as_token = self.hs.config.shadow_server.get("as_token")
+
+ yield self.http_client.post_json_get_json(
+ "%s/_matrix/client/r0/account/3pid?access_token=%s&user_id=%s"
+ % (shadow_hs_url, as_token, user_id),
+ body,
+ )
+
class ThreepidAddRestServlet(RestServlet):
PATTERNS = client_patterns("/account/3pid/add$")
@@ -687,6 +787,7 @@ class ThreepidAddRestServlet(RestServlet):
self.identity_handler = hs.get_handlers().identity_handler
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
+ self.http_client = hs.get_simple_http_client()
@interactive_auth_handler
async def on_POST(self, request):
@@ -722,12 +823,34 @@ class ThreepidAddRestServlet(RestServlet):
validation_session["address"],
validation_session["validated_at"],
)
+ if self.hs.config.shadow_server:
+ shadow_user = UserID(
+ requester.user.localpart, self.hs.config.shadow_server.get("hs")
+ )
+ threepid = {
+ "medium": validation_session["medium"],
+ "address": validation_session["address"],
+ "validated_at": validation_session["validated_at"],
+ }
+ self.shadow_3pid({"threepid": threepid}, shadow_user.to_string())
return 200, {}
raise SynapseError(
400, "No validated 3pid session found", Codes.THREEPID_AUTH_FAILED
)
+ @defer.inlineCallbacks
+ def shadow_3pid(self, body, user_id):
+ # TODO: retries
+ shadow_hs_url = self.hs.config.shadow_server.get("hs_url")
+ as_token = self.hs.config.shadow_server.get("as_token")
+
+ yield self.http_client.post_json_get_json(
+ "%s/_matrix/client/r0/account/3pid?access_token=%s&user_id=%s"
+ % (shadow_hs_url, as_token, user_id),
+ body,
+ )
+
class ThreepidBindRestServlet(RestServlet):
PATTERNS = client_patterns("/account/3pid/bind$")
@@ -797,6 +920,7 @@ class ThreepidDeleteRestServlet(RestServlet):
self.hs = hs
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
+ self.http_client = hs.get_simple_http_client()
async def on_POST(self, request):
if not self.hs.config.enable_3pid_changes:
@@ -821,6 +945,12 @@ class ThreepidDeleteRestServlet(RestServlet):
logger.exception("Failed to remove threepid")
raise SynapseError(500, "Failed to remove threepid")
+ if self.hs.config.shadow_server:
+ shadow_user = UserID(
+ requester.user.localpart, self.hs.config.shadow_server.get("hs")
+ )
+ self.shadow_3pid_delete(body, shadow_user.to_string())
+
if ret:
id_server_unbind_result = "success"
else:
@@ -828,6 +958,77 @@ class ThreepidDeleteRestServlet(RestServlet):
return 200, {"id_server_unbind_result": id_server_unbind_result}
+ @defer.inlineCallbacks
+ def shadow_3pid_delete(self, body, user_id):
+ # TODO: retries
+ shadow_hs_url = self.hs.config.shadow_server.get("hs_url")
+ as_token = self.hs.config.shadow_server.get("as_token")
+
+ yield self.http_client.post_json_get_json(
+ "%s/_matrix/client/r0/account/3pid/delete?access_token=%s&user_id=%s"
+ % (shadow_hs_url, as_token, user_id),
+ body,
+ )
+
+
+class ThreepidLookupRestServlet(RestServlet):
+ PATTERNS = [re.compile("^/_matrix/client/unstable/account/3pid/lookup$")]
+
+ def __init__(self, hs):
+ super(ThreepidLookupRestServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.identity_handler = hs.get_handlers().identity_handler
+
+ @defer.inlineCallbacks
+ def on_GET(self, request):
+ """Proxy a /_matrix/identity/api/v1/lookup request to an identity
+ server
+ """
+ yield self.auth.get_user_by_req(request)
+
+ # Verify query parameters
+ query_params = request.args
+ assert_params_in_dict(query_params, [b"medium", b"address", b"id_server"])
+
+ # Retrieve needed information from query parameters
+ medium = parse_string(request, "medium")
+ address = parse_string(request, "address")
+ id_server = parse_string(request, "id_server")
+
+ # Proxy the request to the identity server. lookup_3pid handles checking
+ # if the lookup is allowed so we don't need to do it here.
+ ret = yield self.identity_handler.proxy_lookup_3pid(id_server, medium, address)
+
+ defer.returnValue((200, ret))
+
+
+class ThreepidBulkLookupRestServlet(RestServlet):
+ PATTERNS = [re.compile("^/_matrix/client/unstable/account/3pid/bulk_lookup$")]
+
+ def __init__(self, hs):
+ super(ThreepidBulkLookupRestServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.identity_handler = hs.get_handlers().identity_handler
+
+ @defer.inlineCallbacks
+ def on_POST(self, request):
+ """Proxy a /_matrix/identity/api/v1/bulk_lookup request to an identity
+ server
+ """
+ yield self.auth.get_user_by_req(request)
+
+ body = parse_json_object_from_request(request)
+
+ assert_params_in_dict(body, ["threepids", "id_server"])
+
+ # Proxy the request to the identity server. lookup_3pid handles checking
+ # if the lookup is allowed so we don't need to do it here.
+ ret = yield self.identity_handler.proxy_bulk_lookup_3pid(
+ body["id_server"], body["threepids"]
+ )
+
+ defer.returnValue((200, ret))
+
class WhoamiRestServlet(RestServlet):
PATTERNS = client_patterns("/account/whoami$")
@@ -856,4 +1057,6 @@ def register_servlets(hs, http_server):
ThreepidBindRestServlet(hs).register(http_server)
ThreepidUnbindRestServlet(hs).register(http_server)
ThreepidDeleteRestServlet(hs).register(http_server)
+ ThreepidLookupRestServlet(hs).register(http_server)
+ ThreepidBulkLookupRestServlet(hs).register(http_server)
WhoamiRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/account_data.py b/synapse/rest/client/v2_alpha/account_data.py
index c1d4cd0caf..d31ec7c29d 100644
--- a/synapse/rest/client/v2_alpha/account_data.py
+++ b/synapse/rest/client/v2_alpha/account_data.py
@@ -17,6 +17,7 @@ import logging
from synapse.api.errors import AuthError, NotFoundError, SynapseError
from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.types import UserID
from ._base import client_patterns
@@ -39,6 +40,7 @@ class AccountDataServlet(RestServlet):
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
self._is_worker = hs.config.worker_app is not None
+ self._profile_handler = hs.get_profile_handler()
async def on_PUT(self, request, user_id, account_data_type):
if self._is_worker:
@@ -50,6 +52,11 @@ class AccountDataServlet(RestServlet):
body = parse_json_object_from_request(request)
+ if account_data_type == "im.vector.hide_profile":
+ user = UserID.from_string(user_id)
+ hide_profile = body.get("hide_profile")
+ await self._profile_handler.set_active([user], not hide_profile, True)
+
max_id = await self.store.add_account_data_for_user(
user_id, account_data_type, body
)
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index 56a451c42f..001f49fb3e 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
-# Copyright 2015 - 2016 OpenMarket Ltd
-# Copyright 2017 Vector Creations Ltd
+# Copyright 2015-2016 OpenMarket Ltd
+# Copyright 2017-2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,6 +17,7 @@
import hmac
import logging
+import re
from typing import List, Union
import synapse
@@ -47,7 +49,7 @@ from synapse.push.mailer import load_jinja2_templates
from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.ratelimitutils import FederationRateLimiter
from synapse.util.stringutils import assert_valid_client_secret, random_string
-from synapse.util.threepids import check_3pid_allowed
+from synapse.util.threepids import canonicalise_email, check_3pid_allowed
from ._base import client_patterns, interactive_auth_handler
@@ -116,19 +118,26 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
client_secret = body["client_secret"]
assert_valid_client_secret(client_secret)
- email = body["email"]
+ # For emails, canonicalise the address.
+ # We store all email addresses canonicalised in the DB.
+ # (See on_POST in EmailThreepidRequestTokenRestServlet
+ # in synapse/rest/client/v2_alpha/account.py)
+ try:
+ email = canonicalise_email(body["email"])
+ except ValueError as e:
+ raise SynapseError(400, str(e))
send_attempt = body["send_attempt"]
next_link = body.get("next_link") # Optional param
- if not check_3pid_allowed(self.hs, "email", email):
+ if not (await check_3pid_allowed(self.hs, "email", body["email"])):
raise SynapseError(
403,
- "Your email domain is not authorized to register on this server",
+ "Your email is not authorized to register on this server",
Codes.THREEPID_DENIED,
)
existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid(
- "email", body["email"]
+ "email", email
)
if existing_user_id is not None:
@@ -192,7 +201,9 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
msisdn = phone_number_to_msisdn(country, phone_number)
- if not check_3pid_allowed(self.hs, "msisdn", msisdn):
+ assert_valid_client_secret(body["client_secret"])
+
+ if not (await check_3pid_allowed(self.hs, "msisdn", msisdn)):
raise SynapseError(
403,
"Phone numbers are not authorized to register on this server",
@@ -348,15 +359,9 @@ class UsernameAvailabilityRestServlet(RestServlet):
403, "Registration has been disabled", errcode=Codes.FORBIDDEN
)
- ip = self.hs.get_ip_from_request(request)
- with self.ratelimiter.ratelimit(ip) as wait_deferred:
- await wait_deferred
-
- username = parse_string(request, "username", required=True)
-
- await self.registration_handler.check_username(username)
-
- return 200, {"available": True}
+ # We are not interested in logging in via a username in this deployment.
+ # Simply allow anything here as it won't be used later.
+ return 200, {"available": True}
class RegisterRestServlet(RestServlet):
@@ -407,6 +412,7 @@ class RegisterRestServlet(RestServlet):
# we do basic sanity checks here because the auth layer will store these
# in sessions. Pull out the username/password provided to us.
+ desired_password_hash = None
if "password" in body:
password = body.pop("password")
if not isinstance(password, str) or len(password) > 512:
@@ -418,12 +424,19 @@ class RegisterRestServlet(RestServlet):
if "password_hash" in body:
raise SynapseError(400, "Unexpected property: password_hash")
body["password_hash"] = await self.auth_handler.hash(password)
+ desired_password_hash = body["password_hash"]
+ # We don't care about usernames for this deployment. In fact, the act
+ # of checking whether they exist already can leak metadata about
+ # which users are already registered.
+ #
+ # Usernames are already derived via the provided email.
+ # So, if they're not necessary, just ignore them.
+ #
+ # (we do still allow appservices to set them below)
desired_username = None
- if "username" in body:
- if not isinstance(body["username"], str) or len(body["username"]) > 512:
- raise SynapseError(400, "Invalid username")
- desired_username = body["username"]
+
+ desired_display_name = body.get("display_name")
appservice = None
if self.auth.has_access_token(request):
@@ -437,7 +450,7 @@ class RegisterRestServlet(RestServlet):
# Set the desired user according to the AS API (which uses the
# 'user' key not 'username'). Since this is a new addition, we'll
# fallback to 'username' if they gave one.
- desired_username = body.get("user", desired_username)
+ desired_username = body.get("user", body.get("username"))
# XXX we should check that desired_username is valid. Currently
# we give appservices carte blanche for any insanity in mxids,
@@ -448,19 +461,14 @@ class RegisterRestServlet(RestServlet):
if isinstance(desired_username, str):
result = await self._do_appservice_registration(
- desired_username, access_token, body
+ desired_username,
+ desired_password_hash,
+ desired_display_name,
+ access_token,
+ body,
)
return 200, result # we throw for non 200 responses
- # for regular registration, downcase the provided username before
- # attempting to register it. This should mean
- # that people who try to register with upper-case in their usernames
- # don't get a nasty surprise. (Note that we treat username
- # case-insenstively in login, so they are free to carry on imagining
- # that their username is CrAzYh4cKeR if that keeps them happy)
- if desired_username is not None:
- desired_username = desired_username.lower()
-
# == Normal User Registration == (everyone else)
if not self.hs.config.enable_registration:
raise SynapseError(403, "Registration has been disabled")
@@ -486,13 +494,6 @@ class RegisterRestServlet(RestServlet):
session_id, "registered_user_id", None
)
- if desired_username is not None:
- await self.registration_handler.check_username(
- desired_username,
- guest_access_token=guest_access_token,
- assigned_user_id=registered_user_id,
- )
-
auth_result, params, session_id = await self.auth_handler.check_auth(
self._registration_flows,
request,
@@ -513,7 +514,7 @@ class RegisterRestServlet(RestServlet):
medium = auth_result[login_type]["medium"]
address = auth_result[login_type]["address"]
- if not check_3pid_allowed(self.hs, medium, address):
+ if not (await check_3pid_allowed(self.hs, medium, address)):
raise SynapseError(
403,
"Third party identifiers (email/phone numbers)"
@@ -521,6 +522,80 @@ class RegisterRestServlet(RestServlet):
Codes.THREEPID_DENIED,
)
+ existingUid = await self.store.get_user_id_by_threepid(
+ medium, address
+ )
+
+ if existingUid is not None:
+ raise SynapseError(
+ 400, "%s is already in use" % medium, Codes.THREEPID_IN_USE
+ )
+
+ if self.hs.config.register_mxid_from_3pid:
+ # override the desired_username based on the 3PID if any.
+ # reset it first to avoid folks picking their own username.
+ desired_username = None
+
+ # we should have an auth_result at this point if we're going to progress
+ # to register the user (i.e. we haven't picked up a registered_user_id
+ # from our session store), in which case get ready and gen the
+ # desired_username
+ if auth_result:
+ if (
+ self.hs.config.register_mxid_from_3pid == "email"
+ and LoginType.EMAIL_IDENTITY in auth_result
+ ):
+ address = auth_result[LoginType.EMAIL_IDENTITY]["address"]
+ desired_username = synapse.types.strip_invalid_mxid_characters(
+ address.replace("@", "-").lower()
+ )
+
+ # find a unique mxid for the account, suffixing numbers
+ # if needed
+ while True:
+ try:
+ await self.registration_handler.check_username(
+ desired_username,
+ guest_access_token=guest_access_token,
+ assigned_user_id=registered_user_id,
+ )
+ # if we got this far we passed the check.
+ break
+ except SynapseError as e:
+ if e.errcode == Codes.USER_IN_USE:
+ m = re.match(r"^(.*?)(\d+)$", desired_username)
+ if m:
+ desired_username = m.group(1) + str(
+ int(m.group(2)) + 1
+ )
+ else:
+ desired_username += "1"
+ else:
+ # something else went wrong.
+ break
+
+ if self.hs.config.register_just_use_email_for_display_name:
+ desired_display_name = address
+ else:
+ # Custom mapping between email address and display name
+ desired_display_name = _map_email_to_displayname(address)
+ elif (
+ self.hs.config.register_mxid_from_3pid == "msisdn"
+ and LoginType.MSISDN in auth_result
+ ):
+ desired_username = auth_result[LoginType.MSISDN]["address"]
+ else:
+ raise SynapseError(
+ 400, "Cannot derive mxid from 3pid; no recognised 3pid"
+ )
+
+ if desired_username is not None:
+ await self.registration_handler.check_username(
+ desired_username,
+ guest_access_token=guest_access_token,
+ assigned_user_id=registered_user_id,
+ )
+
if registered_user_id is not None:
logger.info(
"Already registered user ID %r for this session", registered_user_id
@@ -531,7 +606,12 @@ class RegisterRestServlet(RestServlet):
# NB: This may be from the auth handler and NOT from the POST
assert_params_in_dict(params, ["password_hash"])
- desired_username = params.get("username", None)
+ if not self.hs.config.register_mxid_from_3pid:
+ desired_username = params.get("username", None)
+ else:
+ # we keep the original desired_username derived from the 3pid above
+ pass
+
guest_access_token = params.get("guest_access_token", None)
new_password_hash = params.get("password_hash", None)
@@ -552,6 +632,15 @@ class RegisterRestServlet(RestServlet):
if login_type in auth_result:
medium = auth_result[login_type]["medium"]
address = auth_result[login_type]["address"]
+ # For emails, canonicalise the address.
+ # We store all email addresses canonicalised in the DB.
+ # (See on_POST in EmailThreepidRequestTokenRestServlet
+ # in synapse/rest/client/v2_alpha/account.py)
+ if medium == "email":
+ try:
+ address = canonicalise_email(address)
+ except ValueError as e:
+ raise SynapseError(400, str(e))
existing_user_id = await self.store.get_user_id_by_threepid(
medium, address
@@ -568,6 +657,7 @@ class RegisterRestServlet(RestServlet):
localpart=desired_username,
password_hash=new_password_hash,
guest_access_token=guest_access_token,
+ default_display_name=desired_display_name,
threepid=threepid,
address=client_addr,
)
@@ -579,6 +669,14 @@ class RegisterRestServlet(RestServlet):
):
await self.store.upsert_monthly_active_user(registered_user_id)
+ if self.hs.config.shadow_server:
+ await self.registration_handler.shadow_register(
+ localpart=desired_username,
+ display_name=desired_display_name,
+ auth_result=auth_result,
+ params=params,
+ )
+
# remember that we've now registered that user account, and with
# what user ID (since the user may not have specified)
await self.auth_handler.set_session_data(
@@ -603,11 +701,30 @@ class RegisterRestServlet(RestServlet):
def on_OPTIONS(self, _):
return 200, {}
- async def _do_appservice_registration(self, username, as_token, body):
+ async def _do_appservice_registration(
+ self, username, password_hash, display_name, as_token, body
+ ):
+ # FIXME: appservice_register() is horribly duplicated with register()
+ # and they should probably just be combined together with a config flag.
user_id = await self.registration_handler.appservice_register(
- username, as_token
+ username, as_token, password_hash, display_name
)
- return await self._create_registration_details(user_id, body)
+ result = await self._create_registration_details(user_id, body)
+
+ auth_result = body.get("auth_result")
+ if auth_result and LoginType.EMAIL_IDENTITY in auth_result:
+ threepid = auth_result[LoginType.EMAIL_IDENTITY]
+ await self.registration_handler.register_email_threepid(
+ user_id, threepid, result["access_token"]
+ )
+
+ if auth_result and LoginType.MSISDN in auth_result:
+ threepid = auth_result[LoginType.MSISDN]
+ await self.registration_handler.register_msisdn_threepid(
+ user_id, threepid, result["access_token"]
+ )
+
+ return result
async def _create_registration_details(self, user_id, params):
"""Complete registration of newly-registered user
@@ -658,6 +775,60 @@ class RegisterRestServlet(RestServlet):
)
+def cap(name):
+ """Capitalise parts of a name containing different words, including those
+ separated by hyphens.
+ For example, 'John-Doe'
+
+ Args:
+ name (str): The name to parse
+ """
+ if not name:
+ return name
+
+ # Split the name by whitespace then hyphens, capitalizing each part then
+ # joining it back together.
+ capatilized_name = " ".join(
+ "-".join(part.capitalize() for part in space_part.split("-"))
+ for space_part in name.split()
+ )
+ return capatilized_name
+
+
+def _map_email_to_displayname(address):
+ """Custom mapping from an email address to a user displayname
+
+ Args:
+ address (str): The email address to process
+ Returns:
+ str: The new displayname
+ """
+ # Split the part before and after the @ in the email.
+ # Replace all . with spaces in the first part
+ parts = address.replace(".", " ").split("@")
+
+ # Figure out which org this email address belongs to
+ org_parts = parts[1].split(" ")
+
+ # If this is a ...matrix.org email, mark them as an Admin
+ if org_parts[-2] == "matrix" and org_parts[-1] == "org":
+ org = "Tchap Admin"
+
+ # Is this is a ...gouv.fr address, set the org to whatever is before
+ # gouv.fr. If there isn't anything (a @gouv.fr email) simply mark their
+ # org as "gouv"
+ elif org_parts[-2] == "gouv" and org_parts[-1] == "fr":
+ org = org_parts[-3] if len(org_parts) > 2 else org_parts[-2]
+
+ # Otherwise, mark their org as the email's second-level domain name
+ else:
+ org = org_parts[-2]
+
+ desired_display_name = cap(parts[0]) + " [" + cap(org) + "]"
+
+ return desired_display_name
+
+
def _calculate_registration_flows(
# technically `config` has to provide *all* of these interfaces, not just one
config: Union[RegistrationConfig, ConsentConfig, CaptchaConfig],
diff --git a/synapse/rest/client/v2_alpha/user_directory.py b/synapse/rest/client/v2_alpha/user_directory.py
index bef91a2d3e..6e8300d6a5 100644
--- a/synapse/rest/client/v2_alpha/user_directory.py
+++ b/synapse/rest/client/v2_alpha/user_directory.py
@@ -14,9 +14,17 @@
# limitations under the License.
import logging
+from typing import Dict
-from synapse.api.errors import SynapseError
-from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from signedjson.sign import sign_json
+
+from synapse.api.errors import Codes, SynapseError
+from synapse.http.servlet import (
+ RestServlet,
+ assert_params_in_dict,
+ parse_json_object_from_request,
+)
+from synapse.types import UserID
from ._base import client_patterns
@@ -35,6 +43,7 @@ class UserDirectorySearchRestServlet(RestServlet):
self.hs = hs
self.auth = hs.get_auth()
self.user_directory_handler = hs.get_user_directory_handler()
+ self.http_client = hs.get_simple_http_client()
async def on_POST(self, request):
"""Searches for users in directory
@@ -61,6 +70,16 @@ class UserDirectorySearchRestServlet(RestServlet):
body = parse_json_object_from_request(request)
+ if self.hs.config.user_directory_defer_to_id_server:
+ signed_body = sign_json(
+ body, self.hs.hostname, self.hs.config.signing_key[0]
+ )
+ url = "%s/_matrix/identity/api/v1/user_directory/search" % (
+ self.hs.config.user_directory_defer_to_id_server,
+ )
+ resp = await self.http_client.post_json_get_json(url, signed_body)
+ return 200, resp
+
limit = body.get("limit", 10)
limit = min(limit, 50)
@@ -76,5 +95,125 @@ class UserDirectorySearchRestServlet(RestServlet):
return 200, results
+class SingleUserInfoServlet(RestServlet):
+ """
+ Deprecated and replaced by `/users/info`
+
+ GET /user/{user_id}/info HTTP/1.1
+ """
+
+ PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/info$")
+
+ def __init__(self, hs):
+ super(SingleUserInfoServlet, self).__init__()
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
+ self.transport_layer = hs.get_federation_transport_client()
+ registry = hs.get_federation_registry()
+
+ if not registry.query_handlers.get("user_info"):
+ registry.register_query_handler("user_info", self._on_federation_query)
+
+ async def on_GET(self, request, user_id):
+ # Ensure the user is authenticated
+ await self.auth.get_user_by_req(request)
+
+ user = UserID.from_string(user_id)
+ if not self.hs.is_mine(user):
+ # Attempt to make a federation request to the server that owns this user
+ args = {"user_id": user_id}
+ res = await self.transport_layer.make_query(
+ user.domain, "user_info", args, retry_on_dns_fail=True
+ )
+ return 200, res
+
+ user_id_to_info = await self.store.get_info_for_users([user_id])
+ return 200, user_id_to_info[user_id]
+
+ async def _on_federation_query(self, args):
+ """Called when a request for user information appears over federation
+
+ Args:
+ args (dict): Dictionary of query arguments provided by the request
+
+ Returns:
+ Deferred[dict]: Deactivation and expiration information for a given user
+ """
+ user_id = args.get("user_id")
+ if not user_id:
+ raise SynapseError(400, "user_id not provided")
+
+ user = UserID.from_string(user_id)
+ if not self.hs.is_mine(user):
+ raise SynapseError(400, "User is not hosted on this homeserver")
+
+ user_ids_to_info_dict = await self.store.get_info_for_users([user_id])
+ return user_ids_to_info_dict[user_id]
+
+
+class UserInfoServlet(RestServlet):
+ """Bulk version of `/user/{user_id}/info` endpoint
+
+ GET /users/info HTTP/1.1
+
+ Returns a dictionary of user_id to info dictionary. Supports remote users
+ """
+
+ PATTERNS = client_patterns("/users/info$", unstable=True, releases=())
+
+ def __init__(self, hs):
+ super(UserInfoServlet, self).__init__()
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
+ self.transport_layer = hs.get_federation_transport_client()
+
+ async def on_POST(self, request):
+ # Ensure the user is authenticated
+ await self.auth.get_user_by_req(request)
+
+ # Extract the user_ids from the request
+ body = parse_json_object_from_request(request)
+ assert_params_in_dict(body, required=["user_ids"])
+
+ user_ids = body["user_ids"]
+ if not isinstance(user_ids, list):
+ raise SynapseError(
+ 400,
+ "'user_ids' must be a list of user ID strings",
+ errcode=Codes.INVALID_PARAM,
+ )
+
+ # Separate local and remote users
+ local_user_ids = set()
+ remote_server_to_user_ids = {} # type: Dict[str, set]
+ for user_id in user_ids:
+ user = UserID.from_string(user_id)
+
+ if self.hs.is_mine(user):
+ local_user_ids.add(user_id)
+ else:
+ remote_server_to_user_ids.setdefault(user.domain, set())
+ remote_server_to_user_ids[user.domain].add(user_id)
+
+ # Retrieve info of all local users
+ user_id_to_info_dict = await self.store.get_info_for_users(local_user_ids)
+
+ # Request info of each remote user from their remote homeserver
+ for server_name, user_id_set in remote_server_to_user_ids.items():
+ # Make a request to the given server about their own users
+ res = await self.transport_layer.get_info_of_users(
+ server_name, list(user_id_set)
+ )
+
+ for user_id, info in res:
+ user_id_to_info_dict[user_id] = info
+
+ return 200, user_id_to_info_dict
+
+
def register_servlets(hs, http_server):
UserDirectorySearchRestServlet(hs).register(http_server)
+ SingleUserInfoServlet(hs).register(http_server)
+ UserInfoServlet(hs).register(http_server)
diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index 0d668df0b6..b1999d051b 100644
--- a/synapse/rest/client/versions.py
+++ b/synapse/rest/client/versions.py
@@ -57,9 +57,12 @@ class VersionsRestServlet(RestServlet):
# MSC2326.
"org.matrix.label_based_filtering": True,
# Implements support for cross signing as described in MSC1756
- "org.matrix.e2e_cross_signing": True,
+ # "org.matrix.e2e_cross_signing": True,
# Implements additional endpoints as described in MSC2432
"org.matrix.msc2432": True,
+ # Tchap does not currently assume this rule for r0.5.0
+ # XXX: Remove this when it does
+ "m.lazy_load_members": True,
},
},
)
diff --git a/synapse/rest/consent/consent_resource.py b/synapse/rest/consent/consent_resource.py
index 0a890c98cb..4386eb4e72 100644
--- a/synapse/rest/consent/consent_resource.py
+++ b/synapse/rest/consent/consent_resource.py
@@ -26,11 +26,7 @@ from twisted.internet import defer
from synapse.api.errors import NotFoundError, StoreError, SynapseError
from synapse.config import ConfigError
-from synapse.http.server import (
- DirectServeResource,
- respond_with_html,
- wrap_html_request_handler,
-)
+from synapse.http.server import DirectServeHtmlResource, respond_with_html
from synapse.http.servlet import parse_string
from synapse.types import UserID
@@ -48,7 +44,7 @@ else:
return a == b
-class ConsentResource(DirectServeResource):
+class ConsentResource(DirectServeHtmlResource):
"""A twisted Resource to display a privacy policy and gather consent to it
When accessed via GET, returns the privacy policy via a template.
@@ -119,7 +115,6 @@ class ConsentResource(DirectServeResource):
self._hmac_secret = hs.config.form_secret.encode("utf-8")
- @wrap_html_request_handler
async def _async_render_GET(self, request):
"""
Args:
@@ -160,7 +155,6 @@ class ConsentResource(DirectServeResource):
except TemplateNotFound:
raise NotFoundError("Unknown policy version")
- @wrap_html_request_handler
async def _async_render_POST(self, request):
"""
Args:
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index ab671f7334..e149ac1733 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -20,17 +20,13 @@ from signedjson.sign import sign_json
from synapse.api.errors import Codes, SynapseError
from synapse.crypto.keyring import ServerKeyFetcher
-from synapse.http.server import (
- DirectServeResource,
- respond_with_json_bytes,
- wrap_json_request_handler,
-)
+from synapse.http.server import DirectServeJsonResource, respond_with_json_bytes
from synapse.http.servlet import parse_integer, parse_json_object_from_request
logger = logging.getLogger(__name__)
-class RemoteKey(DirectServeResource):
+class RemoteKey(DirectServeJsonResource):
"""HTTP resource for retreiving the TLS certificate and NACL signature
verification keys for a collection of servers. Checks that the reported
X.509 TLS certificate matches the one used in the HTTPS connection. Checks
@@ -92,13 +88,14 @@ class RemoteKey(DirectServeResource):
isLeaf = True
def __init__(self, hs):
+ super().__init__()
+
self.fetcher = ServerKeyFetcher(hs)
self.store = hs.get_datastore()
self.clock = hs.get_clock()
self.federation_domain_whitelist = hs.config.federation_domain_whitelist
self.config = hs.config
- @wrap_json_request_handler
async def _async_render_GET(self, request):
if len(request.postpath) == 1:
(server,) = request.postpath
@@ -115,7 +112,6 @@ class RemoteKey(DirectServeResource):
await self.query_keys(request, query, query_remote_on_cache_miss=True)
- @wrap_json_request_handler
async def _async_render_POST(self, request):
content = parse_json_object_from_request(request)
diff --git a/synapse/rest/media/v1/config_resource.py b/synapse/rest/media/v1/config_resource.py
index 9f747de263..68dd2a1c8a 100644
--- a/synapse/rest/media/v1/config_resource.py
+++ b/synapse/rest/media/v1/config_resource.py
@@ -14,16 +14,10 @@
# limitations under the License.
#
-from twisted.web.server import NOT_DONE_YET
+from synapse.http.server import DirectServeJsonResource, respond_with_json
-from synapse.http.server import (
- DirectServeResource,
- respond_with_json,
- wrap_json_request_handler,
-)
-
-class MediaConfigResource(DirectServeResource):
+class MediaConfigResource(DirectServeJsonResource):
isLeaf = True
def __init__(self, hs):
@@ -33,11 +27,9 @@ class MediaConfigResource(DirectServeResource):
self.auth = hs.get_auth()
self.limits_dict = {"m.upload.size": config.max_upload_size}
- @wrap_json_request_handler
async def _async_render_GET(self, request):
await self.auth.get_user_by_req(request)
respond_with_json(request, 200, self.limits_dict, send_cors=True)
- def render_OPTIONS(self, request):
+ async def _async_render_OPTIONS(self, request):
respond_with_json(request, 200, {}, send_cors=True)
- return NOT_DONE_YET
diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/v1/download_resource.py
index 24d3ae5bbc..d3d8457303 100644
--- a/synapse/rest/media/v1/download_resource.py
+++ b/synapse/rest/media/v1/download_resource.py
@@ -15,18 +15,14 @@
import logging
import synapse.http.servlet
-from synapse.http.server import (
- DirectServeResource,
- set_cors_headers,
- wrap_json_request_handler,
-)
+from synapse.http.server import DirectServeJsonResource, set_cors_headers
from ._base import parse_media_id, respond_404
logger = logging.getLogger(__name__)
-class DownloadResource(DirectServeResource):
+class DownloadResource(DirectServeJsonResource):
isLeaf = True
def __init__(self, hs, media_repo):
@@ -34,10 +30,6 @@ class DownloadResource(DirectServeResource):
self.media_repo = media_repo
self.server_name = hs.hostname
- # this is expected by @wrap_json_request_handler
- self.clock = hs.get_clock()
-
- @wrap_json_request_handler
async def _async_render_GET(self, request):
set_cors_headers(request)
request.setHeader(
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index b4645cd608..e52c86c798 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -34,10 +34,9 @@ from twisted.internet.error import DNSLookupError
from synapse.api.errors import Codes, SynapseError
from synapse.http.client import SimpleHttpClient
from synapse.http.server import (
- DirectServeResource,
+ DirectServeJsonResource,
respond_with_json,
respond_with_json_bytes,
- wrap_json_request_handler,
)
from synapse.http.servlet import parse_integer, parse_string
from synapse.logging.context import make_deferred_yieldable, run_in_background
@@ -58,7 +57,7 @@ OG_TAG_NAME_MAXLEN = 50
OG_TAG_VALUE_MAXLEN = 1000
-class PreviewUrlResource(DirectServeResource):
+class PreviewUrlResource(DirectServeJsonResource):
isLeaf = True
def __init__(self, hs, media_repo, media_storage):
@@ -108,11 +107,10 @@ class PreviewUrlResource(DirectServeResource):
self._start_expire_url_cache_data, 10 * 1000
)
- def render_OPTIONS(self, request):
+ async def _async_render_OPTIONS(self, request):
request.setHeader(b"Allow", b"OPTIONS, GET")
- return respond_with_json(request, 200, {}, send_cors=True)
+ respond_with_json(request, 200, {}, send_cors=True)
- @wrap_json_request_handler
async def _async_render_GET(self, request):
# XXX: if get_user_by_req fails, what should we do in an async render?
diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py
index 0b87220234..a83535b97b 100644
--- a/synapse/rest/media/v1/thumbnail_resource.py
+++ b/synapse/rest/media/v1/thumbnail_resource.py
@@ -16,11 +16,7 @@
import logging
-from synapse.http.server import (
- DirectServeResource,
- set_cors_headers,
- wrap_json_request_handler,
-)
+from synapse.http.server import DirectServeJsonResource, set_cors_headers
from synapse.http.servlet import parse_integer, parse_string
from ._base import (
@@ -34,7 +30,7 @@ from ._base import (
logger = logging.getLogger(__name__)
-class ThumbnailResource(DirectServeResource):
+class ThumbnailResource(DirectServeJsonResource):
isLeaf = True
def __init__(self, hs, media_repo, media_storage):
@@ -45,9 +41,7 @@ class ThumbnailResource(DirectServeResource):
self.media_storage = media_storage
self.dynamic_thumbnails = hs.config.dynamic_thumbnails
self.server_name = hs.hostname
- self.clock = hs.get_clock()
- @wrap_json_request_handler
async def _async_render_GET(self, request):
set_cors_headers(request)
server_name, media_id, _ = parse_media_id(request)
diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py
index c234ea7421..7126997134 100644
--- a/synapse/rest/media/v1/thumbnailer.py
+++ b/synapse/rest/media/v1/thumbnailer.py
@@ -12,11 +12,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
from io import BytesIO
-import PIL.Image as Image
+from PIL import Image as Image
logger = logging.getLogger(__name__)
diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py
index 83d005812d..3ebf7a68e6 100644
--- a/synapse/rest/media/v1/upload_resource.py
+++ b/synapse/rest/media/v1/upload_resource.py
@@ -15,20 +15,14 @@
import logging
-from twisted.web.server import NOT_DONE_YET
-
from synapse.api.errors import Codes, SynapseError
-from synapse.http.server import (
- DirectServeResource,
- respond_with_json,
- wrap_json_request_handler,
-)
+from synapse.http.server import DirectServeJsonResource, respond_with_json
from synapse.http.servlet import parse_string
logger = logging.getLogger(__name__)
-class UploadResource(DirectServeResource):
+class UploadResource(DirectServeJsonResource):
isLeaf = True
def __init__(self, hs, media_repo):
@@ -43,11 +37,9 @@ class UploadResource(DirectServeResource):
self.max_upload_size = hs.config.max_upload_size
self.clock = hs.get_clock()
- def render_OPTIONS(self, request):
+ async def _async_render_OPTIONS(self, request):
respond_with_json(request, 200, {}, send_cors=True)
- return NOT_DONE_YET
- @wrap_json_request_handler
async def _async_render_POST(self, request):
requester = await self.auth.get_user_by_req(request)
# TODO: The checks here are a bit late. The content will have
diff --git a/synapse/rest/oidc/callback_resource.py b/synapse/rest/oidc/callback_resource.py
index c03194f001..f7a0bc4bdb 100644
--- a/synapse/rest/oidc/callback_resource.py
+++ b/synapse/rest/oidc/callback_resource.py
@@ -14,18 +14,17 @@
# limitations under the License.
import logging
-from synapse.http.server import DirectServeResource, wrap_html_request_handler
+from synapse.http.server import DirectServeHtmlResource
logger = logging.getLogger(__name__)
-class OIDCCallbackResource(DirectServeResource):
+class OIDCCallbackResource(DirectServeHtmlResource):
isLeaf = 1
def __init__(self, hs):
super().__init__()
self._oidc_handler = hs.get_oidc_handler()
- @wrap_html_request_handler
async def _async_render_GET(self, request):
- return await self._oidc_handler.handle_oidc_callback(request)
+ await self._oidc_handler.handle_oidc_callback(request)
diff --git a/synapse/rest/saml2/response_resource.py b/synapse/rest/saml2/response_resource.py
index 75e58043b4..c10188a5d7 100644
--- a/synapse/rest/saml2/response_resource.py
+++ b/synapse/rest/saml2/response_resource.py
@@ -16,10 +16,10 @@
from twisted.python import failure
from synapse.api.errors import SynapseError
-from synapse.http.server import DirectServeResource, return_html_error
+from synapse.http.server import DirectServeHtmlResource, return_html_error
-class SAML2ResponseResource(DirectServeResource):
+class SAML2ResponseResource(DirectServeHtmlResource):
"""A Twisted web resource which handles the SAML response"""
isLeaf = 1
diff --git a/synapse/rulecheck/__init__.py b/synapse/rulecheck/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/synapse/rulecheck/__init__.py
diff --git a/synapse/rulecheck/domain_rule_checker.py b/synapse/rulecheck/domain_rule_checker.py
new file mode 100644
index 0000000000..6f2a1931c5
--- /dev/null
+++ b/synapse/rulecheck/domain_rule_checker.py
@@ -0,0 +1,181 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from synapse.config._base import ConfigError
+
+logger = logging.getLogger(__name__)
+
+
+class DomainRuleChecker(object):
+ """
+ A re-implementation of the SpamChecker that prevents users in one domain from
+ inviting users in other domains to rooms, based on a configuration.
+
+ Takes a config in the format:
+
+ spam_checker:
+ module: "rulecheck.DomainRuleChecker"
+ config:
+ domain_mapping:
+ "inviter_domain": [ "invitee_domain_permitted", "other_domain_permitted" ]
+ "other_inviter_domain": [ "invitee_domain_permitted" ]
+ default: False
+
+ # Only let local users join rooms if they were explicitly invited.
+ can_only_join_rooms_with_invite: false
+
+ # Only let local users create rooms if they are inviting only one
+ # other user, and that user matches the rules above.
+ can_only_create_one_to_one_rooms: false
+
+ # Only let local users invite during room creation, regardless of the
+ # domain mapping rules above.
+ can_only_invite_during_room_creation: false
+
+ # Prevent local users from inviting users from certain domains to
+ # rooms published in the room directory.
+ domains_prevented_from_being_invited_to_published_rooms: []
+
+ # Allow third party invites
+ can_invite_by_third_party_id: true
+
+ Don't forget to consider if you can invite users from your own domain.
+ """
+
+ def __init__(self, config):
+ self.domain_mapping = config["domain_mapping"] or {}
+ self.default = config["default"]
+
+ self.can_only_join_rooms_with_invite = config.get(
+ "can_only_join_rooms_with_invite", False
+ )
+ self.can_only_create_one_to_one_rooms = config.get(
+ "can_only_create_one_to_one_rooms", False
+ )
+ self.can_only_invite_during_room_creation = config.get(
+ "can_only_invite_during_room_creation", False
+ )
+ self.can_invite_by_third_party_id = config.get(
+ "can_invite_by_third_party_id", True
+ )
+ self.domains_prevented_from_being_invited_to_published_rooms = config.get(
+ "domains_prevented_from_being_invited_to_published_rooms", []
+ )
+
+ def check_event_for_spam(self, event):
+ """Implements synapse.events.SpamChecker.check_event_for_spam
+ """
+ return False
+
+ def user_may_invite(
+ self,
+ inviter_userid,
+ invitee_userid,
+ third_party_invite,
+ room_id,
+ new_room,
+ published_room=False,
+ ):
+ """Implements synapse.events.SpamChecker.user_may_invite
+ """
+ if self.can_only_invite_during_room_creation and not new_room:
+ return False
+
+ if not self.can_invite_by_third_party_id and third_party_invite:
+ return False
+
+ # This is a third party invite (without a bound mxid), so unless we have
+ # banned all third party invites (above) we allow it.
+ if not invitee_userid:
+ return True
+
+ inviter_domain = self._get_domain_from_id(inviter_userid)
+ invitee_domain = self._get_domain_from_id(invitee_userid)
+
+ if inviter_domain not in self.domain_mapping:
+ return self.default
+
+ if (
+ published_room
+ and invitee_domain
+ in self.domains_prevented_from_being_invited_to_published_rooms
+ ):
+ return False
+
+ return invitee_domain in self.domain_mapping[inviter_domain]
+
+ def user_may_create_room(
+ self, userid, invite_list, third_party_invite_list, cloning
+ ):
+ """Implements synapse.events.SpamChecker.user_may_create_room
+ """
+
+ if cloning:
+ return True
+
+ if not self.can_invite_by_third_party_id and third_party_invite_list:
+ return False
+
+ number_of_invites = len(invite_list) + len(third_party_invite_list)
+
+ if self.can_only_create_one_to_one_rooms and number_of_invites != 1:
+ return False
+
+ return True
+
+ def user_may_create_room_alias(self, userid, room_alias):
+ """Implements synapse.events.SpamChecker.user_may_create_room_alias
+ """
+ return True
+
+ def user_may_publish_room(self, userid, room_id):
+ """Implements synapse.events.SpamChecker.user_may_publish_room
+ """
+ return True
+
+ def user_may_join_room(self, userid, room_id, is_invited):
+ """Implements synapse.events.SpamChecker.user_may_join_room
+ """
+ if self.can_only_join_rooms_with_invite and not is_invited:
+ return False
+
+ return True
+
+ @staticmethod
+ def parse_config(config):
+ """Implements synapse.events.SpamChecker.parse_config
+ """
+ if "default" in config:
+ return config
+ else:
+ raise ConfigError("No default set for spam_config DomainRuleChecker")
+
+ @staticmethod
+ def _get_domain_from_id(mxid):
+ """Parses a string and returns the domain part of the mxid.
+
+ Args:
+ mxid (str): a valid mxid
+
+ Returns:
+ str: the domain part of the mxid
+
+ """
+ idx = mxid.find(":")
+ if idx == -1:
+ raise Exception("Invalid ID: %r" % (mxid,))
+ return mxid[idx + 1 :]
diff --git a/synapse/secrets.py b/synapse/secrets.py
index 0b327a0f82..5f43f81eb0 100644
--- a/synapse/secrets.py
+++ b/synapse/secrets.py
@@ -19,7 +19,6 @@ Injectable secrets module for Synapse.
See https://docs.python.org/3/library/secrets.html#module-secrets for the API
used in Python 3.6, and the API emulated in Python 2.7.
"""
-
import sys
# secrets is available since python 3.6
@@ -31,8 +30,8 @@ if sys.version_info[0:2] >= (3, 6):
else:
- import os
import binascii
+ import os
class Secrets(object):
def token_bytes(self, nbytes=32):
diff --git a/synapse/server.py b/synapse/server.py
index fe94836a2c..f1078a3805 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -174,6 +174,7 @@ class HomeServer(object):
"event_builder_factory",
"filtering",
"http_client_context_factory",
+ "proxied_http_client",
"simple_http_client",
"proxied_http_client",
"media_repository",
@@ -212,6 +213,7 @@ class HomeServer(object):
"replication_streamer",
"replication_data_handler",
"replication_streams",
+ "password_policy_handler",
]
REQUIRED_ON_MASTER_STARTUP = ["user_directory_handler", "stats_handler"]
@@ -232,6 +234,8 @@ class HomeServer(object):
self._reactor = reactor
self.hostname = hostname
+ # the key we use to sign events and requests
+ self.signing_key = config.key.signing_key[0]
self.config = config
self._building = {}
self._listening_services = []
@@ -570,9 +574,6 @@ class HomeServer(object):
def build_event_client_serializer(self):
return EventClientSerializer(self)
- def build_password_policy_handler(self):
- return PasswordPolicyHandler(self)
-
def build_storage(self) -> Storage:
return Storage(self, self.datastores)
@@ -585,6 +586,9 @@ class HomeServer(object):
def build_replication_streams(self):
return {stream.NAME: stream(self) for stream in STREAMS_MAP.values()}
+ def build_password_policy_handler(self):
+ return PasswordPolicyHandler(self)
+
def remove_pusher(self, app_id, push_key, user_id):
return self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
diff --git a/synapse/storage/data_stores/main/cache.py b/synapse/storage/data_stores/main/cache.py
index eac5a4e55b..f39f556c20 100644
--- a/synapse/storage/data_stores/main/cache.py
+++ b/synapse/storage/data_stores/main/cache.py
@@ -16,10 +16,12 @@
import itertools
import logging
-from typing import Any, Iterable, Optional, Tuple
+from typing import Any, Iterable, List, Optional, Tuple
from synapse.api.constants import EventTypes
+from synapse.replication.tcp.streams import BackfillStream, CachesStream
from synapse.replication.tcp.streams.events import (
+ EventsStream,
EventsStreamCurrentStateRow,
EventsStreamEventRow,
)
@@ -44,13 +46,30 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
async def get_all_updated_caches(
self, instance_name: str, last_id: int, current_id: int, limit: int
- ):
- """Fetches cache invalidation rows between the two given IDs written
- by the given instance. Returns at most `limit` rows.
+ ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
+ """Get updates for caches replication stream.
+
+ Args:
+ instance_name: The writer we want to fetch updates from. Unused
+ here since there is only ever one writer.
+ last_id: The token to fetch updates from. Exclusive.
+ current_id: The token to fetch updates up to. Inclusive.
+ limit: The requested limit for the number of rows to return. The
+ function may return more or fewer rows.
+
+ Returns:
+ A tuple consisting of: the updates, a token to use to fetch
+ subsequent updates, and whether we returned fewer rows than exists
+ between the requested tokens due to the limit.
+
+ The token returned can be used in a subsequent call to this
+ function to get further updatees.
+
+ The updates are a list of 2-tuples of stream ID and the row data
"""
if last_id == current_id:
- return []
+ return [], current_id, False
def get_all_updated_caches_txn(txn):
# We purposefully don't bound by the current token, as we want to
@@ -64,17 +83,24 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
LIMIT ?
"""
txn.execute(sql, (last_id, instance_name, limit))
- return txn.fetchall()
+ updates = [(row[0], row[1:]) for row in txn]
+ limited = False
+ upto_token = current_id
+ if len(updates) >= limit:
+ upto_token = updates[-1][0]
+ limited = True
+
+ return updates, upto_token, limited
return await self.db.runInteraction(
"get_all_updated_caches", get_all_updated_caches_txn
)
def process_replication_rows(self, stream_name, instance_name, token, rows):
- if stream_name == "events":
+ if stream_name == EventsStream.NAME:
for row in rows:
self._process_event_stream_row(token, row)
- elif stream_name == "backfill":
+ elif stream_name == BackfillStream.NAME:
for row in rows:
self._invalidate_caches_for_event(
-token,
@@ -86,7 +112,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
row.relates_to,
backfilled=True,
)
- elif stream_name == "caches":
+ elif stream_name == CachesStream.NAME:
if self._cache_id_gen:
self._cache_id_gen.advance(instance_name, token)
diff --git a/synapse/storage/data_stores/main/deviceinbox.py b/synapse/storage/data_stores/main/deviceinbox.py
index 9a1178fb39..d313b9705f 100644
--- a/synapse/storage/data_stores/main/deviceinbox.py
+++ b/synapse/storage/data_stores/main/deviceinbox.py
@@ -14,6 +14,7 @@
# limitations under the License.
import logging
+from typing import List, Tuple
from canonicaljson import json
@@ -207,31 +208,46 @@ class DeviceInboxWorkerStore(SQLBaseStore):
"delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn
)
- def get_all_new_device_messages(self, last_pos, current_pos, limit):
- """
+ async def get_all_new_device_messages(
+ self, instance_name: str, last_id: int, current_id: int, limit: int
+ ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
+ """Get updates for to device replication stream.
+
Args:
- last_pos(int):
- current_pos(int):
- limit(int):
+ instance_name: The writer we want to fetch updates from. Unused
+ here since there is only ever one writer.
+ last_id: The token to fetch updates from. Exclusive.
+ current_id: The token to fetch updates up to. Inclusive.
+ limit: The requested limit for the number of rows to return. The
+ function may return more or fewer rows.
+
Returns:
- A deferred list of rows from the device inbox
+ A tuple consisting of: the updates, a token to use to fetch
+ subsequent updates, and whether we returned fewer rows than exists
+ between the requested tokens due to the limit.
+
+ The token returned can be used in a subsequent call to this
+ function to get further updatees.
+
+ The updates are a list of 2-tuples of stream ID and the row data
"""
- if last_pos == current_pos:
- return defer.succeed([])
+
+ if last_id == current_id:
+ return [], current_id, False
def get_all_new_device_messages_txn(txn):
# We limit like this as we might have multiple rows per stream_id, and
# we want to make sure we always get all entries for any stream_id
# we return.
- upper_pos = min(current_pos, last_pos + limit)
+ upper_pos = min(current_id, last_id + limit)
sql = (
"SELECT max(stream_id), user_id"
" FROM device_inbox"
" WHERE ? < stream_id AND stream_id <= ?"
" GROUP BY user_id"
)
- txn.execute(sql, (last_pos, upper_pos))
- rows = txn.fetchall()
+ txn.execute(sql, (last_id, upper_pos))
+ updates = [(row[0], row[1:]) for row in txn]
sql = (
"SELECT max(stream_id), destination"
@@ -239,15 +255,21 @@ class DeviceInboxWorkerStore(SQLBaseStore):
" WHERE ? < stream_id AND stream_id <= ?"
" GROUP BY destination"
)
- txn.execute(sql, (last_pos, upper_pos))
- rows.extend(txn)
+ txn.execute(sql, (last_id, upper_pos))
+ updates.extend((row[0], row[1:]) for row in txn)
# Order by ascending stream ordering
- rows.sort()
+ updates.sort()
- return rows
+ limited = False
+ upto_token = current_id
+ if len(updates) >= limit:
+ upto_token = updates[-1][0]
+ limited = True
- return self.db.runInteraction(
+ return updates, upto_token, limited
+
+ return await self.db.runInteraction(
"get_all_new_device_messages", get_all_new_device_messages_txn
)
diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py
index 0ff0542453..343cf9a2d5 100644
--- a/synapse/storage/data_stores/main/devices.py
+++ b/synapse/storage/data_stores/main/devices.py
@@ -582,32 +582,58 @@ class DeviceWorkerStore(SQLBaseStore):
return set()
async def get_all_device_list_changes_for_remotes(
- self, from_key: int, to_key: int, limit: int,
- ) -> List[Tuple[int, str]]:
- """Return a list of `(stream_id, entity)` which is the combined list of
- changes to devices and which destinations need to be poked. Entity is
- either a user ID (starting with '@') or a remote destination.
- """
+ self, instance_name: str, last_id: int, current_id: int, limit: int
+ ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
+ """Get updates for device lists replication stream.
- # This query Does The Right Thing where it'll correctly apply the
- # bounds to the inner queries.
- sql = """
- SELECT stream_id, entity FROM (
- SELECT stream_id, user_id AS entity FROM device_lists_stream
- UNION ALL
- SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes
- ) AS e
- WHERE ? < stream_id AND stream_id <= ?
- LIMIT ?
+ Args:
+ instance_name: The writer we want to fetch updates from. Unused
+ here since there is only ever one writer.
+ last_id: The token to fetch updates from. Exclusive.
+ current_id: The token to fetch updates up to. Inclusive.
+ limit: The requested limit for the number of rows to return. The
+ function may return more or fewer rows.
+
+ Returns:
+ A tuple consisting of: the updates, a token to use to fetch
+ subsequent updates, and whether we returned fewer rows than exists
+ between the requested tokens due to the limit.
+
+ The token returned can be used in a subsequent call to this
+ function to get further updatees.
+
+ The updates are a list of 2-tuples of stream ID and the row data
"""
- return await self.db.execute(
+ if last_id == current_id:
+ return [], current_id, False
+
+ def _get_all_device_list_changes_for_remotes(txn):
+ # This query Does The Right Thing where it'll correctly apply the
+ # bounds to the inner queries.
+ sql = """
+ SELECT stream_id, entity FROM (
+ SELECT stream_id, user_id AS entity FROM device_lists_stream
+ UNION ALL
+ SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes
+ ) AS e
+ WHERE ? < stream_id AND stream_id <= ?
+ LIMIT ?
+ """
+
+ txn.execute(sql, (last_id, current_id, limit))
+ updates = [(row[0], row[1:]) for row in txn]
+ limited = False
+ upto_token = current_id
+ if len(updates) >= limit:
+ upto_token = updates[-1][0]
+ limited = True
+
+ return updates, upto_token, limited
+
+ return await self.db.runInteraction(
"get_all_device_list_changes_for_remotes",
- None,
- sql,
- from_key,
- to_key,
- limit,
+ _get_all_device_list_changes_for_remotes,
)
@cached(max_entries=10000)
diff --git a/synapse/storage/data_stores/main/end_to_end_keys.py b/synapse/storage/data_stores/main/end_to_end_keys.py
index 1a0842d4b0..6c3cff82e1 100644
--- a/synapse/storage/data_stores/main/end_to_end_keys.py
+++ b/synapse/storage/data_stores/main/end_to_end_keys.py
@@ -14,7 +14,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Dict, List
+from typing import Dict, List, Tuple
from canonicaljson import encode_canonical_json, json
@@ -479,34 +479,61 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
return result
- def get_all_user_signature_changes_for_remotes(self, from_key, to_key, limit):
- """Return a list of changes from the user signature stream to notify remotes.
+ async def get_all_user_signature_changes_for_remotes(
+ self, instance_name: str, last_id: int, current_id: int, limit: int
+ ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
+ """Get updates for groups replication stream.
+
Note that the user signature stream represents when a user signs their
device with their user-signing key, which is not published to other
users or servers, so no `destination` is needed in the returned
list. However, this is needed to poke workers.
Args:
- from_key (int): the stream ID to start at (exclusive)
- to_key (int): the stream ID to end at (inclusive)
+ instance_name: The writer we want to fetch updates from. Unused
+ here since there is only ever one writer.
+ last_id: The token to fetch updates from. Exclusive.
+ current_id: The token to fetch updates up to. Inclusive.
+ limit: The requested limit for the number of rows to return. The
+ function may return more or fewer rows.
Returns:
- Deferred[list[(int,str)]] a list of `(stream_id, user_id)`
- """
- sql = """
- SELECT stream_id, from_user_id AS user_id
- FROM user_signature_stream
- WHERE ? < stream_id AND stream_id <= ?
- ORDER BY stream_id ASC
- LIMIT ?
+ A tuple consisting of: the updates, a token to use to fetch
+ subsequent updates, and whether we returned fewer rows than exists
+ between the requested tokens due to the limit.
+
+ The token returned can be used in a subsequent call to this
+ function to get further updatees.
+
+ The updates are a list of 2-tuples of stream ID and the row data
"""
- return self.db.execute(
+
+ if last_id == current_id:
+ return [], current_id, False
+
+ def _get_all_user_signature_changes_for_remotes_txn(txn):
+ sql = """
+ SELECT stream_id, from_user_id AS user_id
+ FROM user_signature_stream
+ WHERE ? < stream_id AND stream_id <= ?
+ ORDER BY stream_id ASC
+ LIMIT ?
+ """
+ txn.execute(sql, (last_id, current_id, limit))
+
+ updates = [(row[0], (row[1:])) for row in txn]
+
+ limited = False
+ upto_token = current_id
+ if len(updates) >= limit:
+ upto_token = updates[-1][0]
+ limited = True
+
+ return updates, upto_token, limited
+
+ return await self.db.runInteraction(
"get_all_user_signature_changes_for_remotes",
- None,
- sql,
- from_key,
- to_key,
- limit,
+ _get_all_user_signature_changes_for_remotes_txn,
)
diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py
index cfd24d2f06..230fb5cd7f 100644
--- a/synapse/storage/data_stores/main/events.py
+++ b/synapse/storage/data_stores/main/events.py
@@ -14,7 +14,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import itertools
import logging
from collections import OrderedDict, namedtuple
@@ -28,12 +27,7 @@ from prometheus_client import Counter
from twisted.internet import defer
import synapse.metrics
-from synapse.api.constants import (
- EventContentFields,
- EventTypes,
- Membership,
- RelationTypes,
-)
+from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
from synapse.api.room_versions import RoomVersions
from synapse.crypto.event_signing import compute_event_reference_hash
from synapse.events import EventBase # noqa: F401
@@ -48,8 +42,8 @@ from synapse.util.frozenutils import frozendict_json_encoder
from synapse.util.iterutils import batch_iter
if TYPE_CHECKING:
- from synapse.storage.data_stores.main import DataStore
from synapse.server import HomeServer
+ from synapse.storage.data_stores.main import DataStore
logger = logging.getLogger(__name__)
@@ -820,7 +814,6 @@ class PersistEventsStore:
"event_reference_hashes",
"event_search",
"event_to_state_groups",
- "local_invites",
"state_events",
"rejections",
"redactions",
@@ -1197,65 +1190,27 @@ class PersistEventsStore:
(event.state_key,),
)
- # We update the local_invites table only if the event is "current",
- # i.e., its something that has just happened. If the event is an
- # outlier it is only current if its an "out of band membership",
- # like a remote invite or a rejection of a remote invite.
- is_new_state = not backfilled and (
- not event.internal_metadata.is_outlier()
- or event.internal_metadata.is_out_of_band_membership()
- )
- is_mine = self.is_mine_id(event.state_key)
- if is_new_state and is_mine:
- if event.membership == Membership.INVITE:
- self.db.simple_insert_txn(
- txn,
- table="local_invites",
- values={
- "event_id": event.event_id,
- "invitee": event.state_key,
- "inviter": event.sender,
- "room_id": event.room_id,
- "stream_id": event.internal_metadata.stream_ordering,
- },
- )
- else:
- sql = (
- "UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE"
- " room_id = ? AND invitee = ? AND locally_rejected is NULL"
- " AND replaced_by is NULL"
- )
-
- txn.execute(
- sql,
- (
- event.internal_metadata.stream_ordering,
- event.event_id,
- event.room_id,
- event.state_key,
- ),
- )
-
- # We also update the `local_current_membership` table with
- # latest invite info. This will usually get updated by the
- # `current_state_events` handling, unless its an outlier.
- if event.internal_metadata.is_outlier():
- # This should only happen for out of band memberships, so
- # we add a paranoia check.
- assert event.internal_metadata.is_out_of_band_membership()
-
- self.db.simple_upsert_txn(
- txn,
- table="local_current_membership",
- keyvalues={
- "room_id": event.room_id,
- "user_id": event.state_key,
- },
- values={
- "event_id": event.event_id,
- "membership": event.membership,
- },
- )
+ # We update the local_current_membership table only if the event is
+ # "current", i.e., its something that has just happened.
+ #
+ # This will usually get updated by the `current_state_events` handling,
+ # unless its an outlier, and an outlier is only "current" if it's an "out of
+ # band membership", like a remote invite or a rejection of a remote invite.
+ if (
+ self.is_mine_id(event.state_key)
+ and not backfilled
+ and event.internal_metadata.is_outlier()
+ and event.internal_metadata.is_out_of_band_membership()
+ ):
+ self.db.simple_upsert_txn(
+ txn,
+ table="local_current_membership",
+ keyvalues={"room_id": event.room_id, "user_id": event.state_key},
+ values={
+ "event_id": event.event_id,
+ "membership": event.membership,
+ },
+ )
def _handle_event_relations(self, txn, event):
"""Handles inserting relation data during peristence of events
@@ -1586,31 +1541,3 @@ class PersistEventsStore:
if not ev.internal_metadata.is_outlier()
],
)
-
- async def locally_reject_invite(self, user_id: str, room_id: str) -> int:
- """Mark the invite has having been rejected even though we failed to
- create a leave event for it.
- """
-
- sql = (
- "UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE"
- " room_id = ? AND invitee = ? AND locally_rejected is NULL"
- " AND replaced_by is NULL"
- )
-
- def f(txn, stream_ordering):
- txn.execute(sql, (stream_ordering, True, room_id, user_id))
-
- # We also clear this entry from `local_current_membership`.
- # Ideally we'd point to a leave event, but we don't have one, so
- # nevermind.
- self.db.simple_delete_txn(
- txn,
- table="local_current_membership",
- keyvalues={"room_id": room_id, "user_id": user_id},
- )
-
- with self._stream_id_gen.get_next() as stream_ordering:
- await self.db.runInteraction("locally_reject_invite", f, stream_ordering)
-
- return stream_ordering
diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py
index a48c7a96ca..01cad7d4fa 100644
--- a/synapse/storage/data_stores/main/events_worker.py
+++ b/synapse/storage/data_stores/main/events_worker.py
@@ -38,6 +38,8 @@ from synapse.events.utils import prune_event
from synapse.logging.context import PreserveLoggingContext, current_context
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
+from synapse.replication.tcp.streams import BackfillStream
+from synapse.replication.tcp.streams.events import EventsStream
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
from synapse.storage.database import Database
from synapse.storage.util.id_generators import StreamIdGenerator
@@ -80,10 +82,7 @@ class EventsWorkerStore(SQLBaseStore):
# We are the process in charge of generating stream ids for events,
# so instantiate ID generators based on the database
self._stream_id_gen = StreamIdGenerator(
- db_conn,
- "events",
- "stream_ordering",
- extra_tables=[("local_invites", "stream_id")],
+ db_conn, "events", "stream_ordering",
)
self._backfill_id_gen = StreamIdGenerator(
db_conn,
@@ -113,9 +112,9 @@ class EventsWorkerStore(SQLBaseStore):
self._event_fetch_ongoing = 0
def process_replication_rows(self, stream_name, instance_name, token, rows):
- if stream_name == "events":
+ if stream_name == EventsStream.NAME:
self._stream_id_gen.advance(token)
- elif stream_name == "backfill":
+ elif stream_name == BackfillStream.NAME:
self._backfill_id_gen.advance(-token)
super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/storage/data_stores/main/group_server.py b/synapse/storage/data_stores/main/group_server.py
index fb1361f1c1..4fb9f9850c 100644
--- a/synapse/storage/data_stores/main/group_server.py
+++ b/synapse/storage/data_stores/main/group_server.py
@@ -14,6 +14,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import List, Tuple
+
from canonicaljson import json
from twisted.internet import defer
@@ -526,13 +528,35 @@ class GroupServerWorkerStore(SQLBaseStore):
"get_groups_changes_for_user", _get_groups_changes_for_user_txn
)
- def get_all_groups_changes(self, from_token, to_token, limit):
- from_token = int(from_token)
- has_changed = self._group_updates_stream_cache.has_any_entity_changed(
- from_token
- )
+ async def get_all_groups_changes(
+ self, instance_name: str, last_id: int, current_id: int, limit: int
+ ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
+ """Get updates for groups replication stream.
+
+ Args:
+ instance_name: The writer we want to fetch updates from. Unused
+ here since there is only ever one writer.
+ last_id: The token to fetch updates from. Exclusive.
+ current_id: The token to fetch updates up to. Inclusive.
+ limit: The requested limit for the number of rows to return. The
+ function may return more or fewer rows.
+
+ Returns:
+ A tuple consisting of: the updates, a token to use to fetch
+ subsequent updates, and whether we returned fewer rows than exists
+ between the requested tokens due to the limit.
+
+ The token returned can be used in a subsequent call to this
+ function to get further updatees.
+
+ The updates are a list of 2-tuples of stream ID and the row data
+ """
+
+ last_id = int(last_id)
+ has_changed = self._group_updates_stream_cache.has_any_entity_changed(last_id)
+
if not has_changed:
- return defer.succeed([])
+ return [], current_id, False
def _get_all_groups_changes_txn(txn):
sql = """
@@ -541,13 +565,21 @@ class GroupServerWorkerStore(SQLBaseStore):
WHERE ? < stream_id AND stream_id <= ?
LIMIT ?
"""
- txn.execute(sql, (from_token, to_token, limit))
- return [
- (stream_id, group_id, user_id, gtype, json.loads(content_json))
+ txn.execute(sql, (last_id, current_id, limit))
+ updates = [
+ (stream_id, (group_id, user_id, gtype, json.loads(content_json)))
for stream_id, group_id, user_id, gtype, content_json in txn
]
- return self.db.runInteraction(
+ limited = False
+ upto_token = current_id
+ if len(updates) >= limit:
+ upto_token = updates[-1][0]
+ limited = True
+
+ return updates, upto_token, limited
+
+ return await self.db.runInteraction(
"get_all_groups_changes", _get_all_groups_changes_txn
)
diff --git a/synapse/storage/data_stores/main/profile.py b/synapse/storage/data_stores/main/profile.py
index bfc9369f0b..cd2472feb0 100644
--- a/synapse/storage/data_stores/main/profile.py
+++ b/synapse/storage/data_stores/main/profile.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,11 +14,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import List, Tuple
+
from twisted.internet import defer
from synapse.api.errors import StoreError
from synapse.storage._base import SQLBaseStore
from synapse.storage.data_stores.main.roommember import ProfileInfo
+from synapse.types import UserID
+from synapse.util.caches.descriptors import cached
+
+BATCH_SIZE = 100
class ProfileWorkerStore(SQLBaseStore):
@@ -41,6 +48,7 @@ class ProfileWorkerStore(SQLBaseStore):
avatar_url=profile["avatar_url"], display_name=profile["displayname"]
)
+ @cached(max_entries=5000)
def get_profile_displayname(self, user_localpart):
return self.db.simple_select_one_onecol(
table="profiles",
@@ -49,6 +57,7 @@ class ProfileWorkerStore(SQLBaseStore):
desc="get_profile_displayname",
)
+ @cached(max_entries=5000)
def get_profile_avatar_url(self, user_localpart):
return self.db.simple_select_one_onecol(
table="profiles",
@@ -57,6 +66,54 @@ class ProfileWorkerStore(SQLBaseStore):
desc="get_profile_avatar_url",
)
+ def get_latest_profile_replication_batch_number(self):
+ def f(txn):
+ txn.execute("SELECT MAX(batch) as maxbatch FROM profiles")
+ rows = self.db.cursor_to_dict(txn)
+ return rows[0]["maxbatch"]
+
+ return self.db.runInteraction("get_latest_profile_replication_batch_number", f)
+
+ def get_profile_batch(self, batchnum):
+ return self.db.simple_select_list(
+ table="profiles",
+ keyvalues={"batch": batchnum},
+ retcols=("user_id", "displayname", "avatar_url", "active"),
+ desc="get_profile_batch",
+ )
+
+ def assign_profile_batch(self):
+ def f(txn):
+ sql = (
+ "UPDATE profiles SET batch = "
+ "(SELECT COALESCE(MAX(batch), -1) + 1 FROM profiles) "
+ "WHERE user_id in ("
+ " SELECT user_id FROM profiles WHERE batch is NULL limit ?"
+ ")"
+ )
+ txn.execute(sql, (BATCH_SIZE,))
+ return txn.rowcount
+
+ return self.db.runInteraction("assign_profile_batch", f)
+
+ def get_replication_hosts(self):
+ def f(txn):
+ txn.execute(
+ "SELECT host, last_synced_batch FROM profile_replication_status"
+ )
+ rows = self.db.cursor_to_dict(txn)
+ return {r["host"]: r["last_synced_batch"] for r in rows}
+
+ return self.db.runInteraction("get_replication_hosts", f)
+
+ def update_replication_batch_for_host(self, host, last_synced_batch):
+ return self.db.simple_upsert(
+ table="profile_replication_status",
+ keyvalues={"host": host},
+ values={"last_synced_batch": last_synced_batch},
+ desc="update_replication_batch_for_host",
+ )
+
def get_from_remote_profile_cache(self, user_id):
return self.db.simple_select_one(
table="remote_profile_cache",
@@ -71,24 +128,83 @@ class ProfileWorkerStore(SQLBaseStore):
table="profiles", values={"user_id": user_localpart}, desc="create_profile"
)
- def set_profile_displayname(self, user_localpart, new_displayname):
- return self.db.simple_update_one(
+ def set_profile_displayname(self, user_localpart, new_displayname, batchnum):
+ # Invalidate the read cache for this user
+ self.get_profile_displayname.invalidate((user_localpart,))
+
+ return self.db.simple_upsert(
table="profiles",
keyvalues={"user_id": user_localpart},
- updatevalues={"displayname": new_displayname},
+ values={"displayname": new_displayname, "batch": batchnum},
desc="set_profile_displayname",
+ lock=False, # we can do this because user_id has a unique index
)
- def set_profile_avatar_url(self, user_localpart, new_avatar_url):
- return self.db.simple_update_one(
+ def set_profile_avatar_url(self, user_localpart, new_avatar_url, batchnum):
+ # Invalidate the read cache for this user
+ self.get_profile_avatar_url.invalidate((user_localpart,))
+
+ return self.db.simple_upsert(
table="profiles",
keyvalues={"user_id": user_localpart},
- updatevalues={"avatar_url": new_avatar_url},
+ values={"avatar_url": new_avatar_url, "batch": batchnum},
desc="set_profile_avatar_url",
+ lock=False, # we can do this because user_id has a unique index
+ )
+
+ def set_profiles_active(
+ self, users: List[UserID], active: bool, hide: bool, batchnum: int,
+ ):
+ """Given a set of users, set active and hidden flags on them.
+
+ Args:
+ users: A list of UserIDs
+ active: Whether to set the users to active or inactive
+ hide: Whether to hide the users (withold from replication). If
+ False and active is False, users will have their profiles
+ erased
+ batchnum: The batch number, used for profile replication
+
+ Returns:
+ Deferred
+ """
+ # Convert list of localparts to list of tuples containing localparts
+ user_localparts = [(user.localpart,) for user in users]
+
+ # Generate list of value tuples for each user
+ value_names = ("active", "batch")
+ values = [(int(active), batchnum) for _ in user_localparts] # type: List[Tuple]
+
+ if not active and not hide:
+ # we are deactivating for real (not in hide mode)
+ # so clear the profile information
+ value_names += ("avatar_url", "displayname")
+ values = [v + (None, None) for v in values]
+
+ return self.db.runInteraction(
+ "set_profiles_active",
+ self.db.simple_upsert_many_txn,
+ table="profiles",
+ key_names=("user_id",),
+ key_values=user_localparts,
+ value_names=value_names,
+ value_values=values,
)
class ProfileStore(ProfileWorkerStore):
+ def __init__(self, database, db_conn, hs):
+
+ super(ProfileStore, self).__init__(database, db_conn, hs)
+
+ self.db.updates.register_background_index_update(
+ "profile_replication_status_host_index",
+ index_name="profile_replication_status_idx",
+ table="profile_replication_status",
+ columns=["host"],
+ unique=True,
+ )
+
def add_remote_profile_cache(self, user_id, displayname, avatar_url):
"""Ensure we are caching the remote user's profiles.
@@ -107,7 +223,7 @@ class ProfileStore(ProfileWorkerStore):
)
def update_remote_profile_cache(self, user_id, displayname, avatar_url):
- return self.db.simple_update(
+ return self.db.simple_upsert(
table="remote_profile_cache",
keyvalues={"user_id": user_id},
updatevalues={
diff --git a/synapse/storage/data_stores/main/purge_events.py b/synapse/storage/data_stores/main/purge_events.py
index a93e1ef198..6546569139 100644
--- a/synapse/storage/data_stores/main/purge_events.py
+++ b/synapse/storage/data_stores/main/purge_events.py
@@ -361,7 +361,6 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
"event_push_summary",
"pusher_throttle",
"group_summary_rooms",
- "local_invites",
"room_account_data",
"room_tags",
"local_current_membership",
diff --git a/synapse/storage/data_stores/main/pusher.py b/synapse/storage/data_stores/main/pusher.py
index 547b9d69cb..5461016240 100644
--- a/synapse/storage/data_stores/main/pusher.py
+++ b/synapse/storage/data_stores/main/pusher.py
@@ -15,7 +15,7 @@
# limitations under the License.
import logging
-from typing import Iterable, Iterator
+from typing import Iterable, Iterator, List, Tuple
from canonicaljson import encode_canonical_json, json
@@ -98,77 +98,69 @@ class PusherWorkerStore(SQLBaseStore):
rows = yield self.db.runInteraction("get_all_pushers", get_pushers)
return rows
- def get_all_updated_pushers(self, last_id, current_id, limit):
- if last_id == current_id:
- return defer.succeed(([], []))
-
- def get_all_updated_pushers_txn(txn):
- sql = (
- "SELECT id, user_name, access_token, profile_tag, kind,"
- " app_id, app_display_name, device_display_name, pushkey, ts,"
- " lang, data"
- " FROM pushers"
- " WHERE ? < id AND id <= ?"
- " ORDER BY id ASC LIMIT ?"
- )
- txn.execute(sql, (last_id, current_id, limit))
- updated = txn.fetchall()
-
- sql = (
- "SELECT stream_id, user_id, app_id, pushkey"
- " FROM deleted_pushers"
- " WHERE ? < stream_id AND stream_id <= ?"
- " ORDER BY stream_id ASC LIMIT ?"
- )
- txn.execute(sql, (last_id, current_id, limit))
- deleted = txn.fetchall()
+ async def get_all_updated_pushers_rows(
+ self, instance_name: str, last_id: int, current_id: int, limit: int
+ ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
+ """Get updates for pushers replication stream.
- return updated, deleted
+ Args:
+ instance_name: The writer we want to fetch updates from. Unused
+ here since there is only ever one writer.
+ last_id: The token to fetch updates from. Exclusive.
+ current_id: The token to fetch updates up to. Inclusive.
+ limit: The requested limit for the number of rows to return. The
+ function may return more or fewer rows.
- return self.db.runInteraction(
- "get_all_updated_pushers", get_all_updated_pushers_txn
- )
+ Returns:
+ A tuple consisting of: the updates, a token to use to fetch
+ subsequent updates, and whether we returned fewer rows than exists
+ between the requested tokens due to the limit.
- def get_all_updated_pushers_rows(self, last_id, current_id, limit):
- """Get all the pushers that have changed between the given tokens.
+ The token returned can be used in a subsequent call to this
+ function to get further updatees.
- Returns:
- Deferred(list(tuple)): each tuple consists of:
- stream_id (str)
- user_id (str)
- app_id (str)
- pushkey (str)
- was_deleted (bool): whether the pusher was added/updated (False)
- or deleted (True)
+ The updates are a list of 2-tuples of stream ID and the row data
"""
if last_id == current_id:
- return defer.succeed([])
+ return [], current_id, False
def get_all_updated_pushers_rows_txn(txn):
- sql = (
- "SELECT id, user_name, app_id, pushkey"
- " FROM pushers"
- " WHERE ? < id AND id <= ?"
- " ORDER BY id ASC LIMIT ?"
- )
+ sql = """
+ SELECT id, user_name, app_id, pushkey
+ FROM pushers
+ WHERE ? < id AND id <= ?
+ ORDER BY id ASC LIMIT ?
+ """
txn.execute(sql, (last_id, current_id, limit))
- results = [list(row) + [False] for row in txn]
-
- sql = (
- "SELECT stream_id, user_id, app_id, pushkey"
- " FROM deleted_pushers"
- " WHERE ? < stream_id AND stream_id <= ?"
- " ORDER BY stream_id ASC LIMIT ?"
- )
+ updates = [
+ (stream_id, (user_name, app_id, pushkey, False))
+ for stream_id, user_name, app_id, pushkey in txn
+ ]
+
+ sql = """
+ SELECT stream_id, user_id, app_id, pushkey
+ FROM deleted_pushers
+ WHERE ? < stream_id AND stream_id <= ?
+ ORDER BY stream_id ASC LIMIT ?
+ """
txn.execute(sql, (last_id, current_id, limit))
+ updates.extend(
+ (stream_id, (user_name, app_id, pushkey, True))
+ for stream_id, user_name, app_id, pushkey in txn
+ )
+
+ updates.sort() # Sort so that they're ordered by stream id
- results.extend(list(row) + [True] for row in txn)
- results.sort() # Sort so that they're ordered by stream id
+ limited = False
+ upper_bound = current_id
+ if len(updates) >= limit:
+ limited = True
+ upper_bound = updates[-1][0]
- return results
+ return updates, upper_bound, limited
- return self.db.runInteraction(
+ return await self.db.runInteraction(
"get_all_updated_pushers_rows", get_all_updated_pushers_rows_txn
)
diff --git a/synapse/storage/data_stores/main/registration.py b/synapse/storage/data_stores/main/registration.py
index 587d4b91c1..efb1a4fb4c 100644
--- a/synapse/storage/data_stores/main/registration.py
+++ b/synapse/storage/data_stores/main/registration.py
@@ -17,7 +17,7 @@
import logging
import re
-from typing import Optional
+from typing import List, Optional
from twisted.internet import defer
from twisted.internet.defer import Deferred
@@ -157,6 +157,37 @@ class RegistrationWorkerStore(SQLBaseStore):
)
@defer.inlineCallbacks
+ def get_expired_users(self):
+ """Get UserIDs of all expired users.
+
+ Users who are not active, or do not have profile information, are
+ excluded from the results.
+
+ Returns:
+ Deferred[List[UserID]]: List of expired user IDs
+ """
+
+ def get_expired_users_txn(txn, now_ms):
+ # We need to use pattern matching as profiles.user_id is confusingly just the
+ # user's localpart, whereas account_validity.user_id is a full user ID
+ sql = """
+ SELECT av.user_id from account_validity AS av
+ LEFT JOIN profiles as p
+ ON av.user_id LIKE '%%' || p.user_id || ':%%'
+ WHERE expiration_ts_ms <= ?
+ AND p.active = 1
+ """
+ txn.execute(sql, (now_ms,))
+ rows = txn.fetchall()
+
+ return [UserID.from_string(row[0]) for row in rows]
+
+ res = yield self.db.runInteraction(
+ "get_expired_users", get_expired_users_txn, self.clock.time_msec()
+ )
+ return res
+
+ @defer.inlineCallbacks
def set_renewal_token_for_user(self, user_id, renewal_token):
"""Defines a renewal token for a given user.
@@ -272,6 +303,55 @@ class RegistrationWorkerStore(SQLBaseStore):
desc="delete_account_validity_for_user",
)
+ @defer.inlineCallbacks
+ def get_info_for_users(
+ self, user_ids: List[str],
+ ):
+ """Return the user info for a given set of users
+
+ Args:
+ user_ids: A list of users to return information about
+
+ Returns:
+ Deferred[Dict[str, bool]]: A dictionary mapping each user ID to
+ a dict with the following keys:
+ * expired - whether this is an expired user
+ * deactivated - whether this is a deactivated user
+ """
+ # Get information of all our local users
+ def _get_info_for_users_txn(txn):
+ rows = []
+
+ for user_id in user_ids:
+ sql = """
+ SELECT u.name, u.deactivated, av.expiration_ts_ms
+ FROM users as u
+ LEFT JOIN account_validity as av
+ ON av.user_id = u.name
+ WHERE u.name = ?
+ """
+
+ txn.execute(sql, (user_id,))
+ row = txn.fetchone()
+ if row:
+ rows.append(row)
+
+ return rows
+
+ info_rows = yield self.db.runInteraction(
+ "get_info_for_users", _get_info_for_users_txn
+ )
+
+ return {
+ user_id: {
+ "expired": (
+ expiration is not None and self.clock.time_msec() >= expiration
+ ),
+ "deactivated": deactivated == 1,
+ }
+ for user_id, deactivated, expiration in info_rows
+ }
+
async def is_server_admin(self, user):
"""Determines if a user is an admin of this homeserver.
diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py
index 13e366536a..98a8b5a11a 100644
--- a/synapse/storage/data_stores/main/room.py
+++ b/synapse/storage/data_stores/main/room.py
@@ -340,6 +340,24 @@ class RoomWorkerStore(SQLBaseStore):
desc="is_room_blocked",
)
+ @defer.inlineCallbacks
+ def is_room_published(self, room_id):
+ """Check whether a room has been published in the local public room
+ directory.
+
+ Args:
+ room_id (str)
+ Returns:
+ bool: Whether the room is currently published in the room directory
+ """
+ # Get room information
+ room_info = yield self.get_room(room_id)
+ if not room_info:
+ defer.returnValue(False)
+
+ # Check the is_public value
+ defer.returnValue(room_info.get("is_public", False))
+
async def get_rooms_paginate(
self,
start: int,
@@ -548,6 +566,11 @@ class RoomWorkerStore(SQLBaseStore):
Returns:
dict[int, int]: "min_lifetime" and "max_lifetime" for this room.
"""
+ # If the room retention feature is disabled, return a policy with no minimum nor
+ # maximum, in order not to filter out events we should filter out when sending to
+ # the client.
+ if not self.config.retention_enabled:
+ defer.returnValue({"min_lifetime": None, "max_lifetime": None})
def get_retention_policy_for_room_txn(txn):
txn.execute(
@@ -803,7 +826,32 @@ class RoomWorkerStore(SQLBaseStore):
return total_media_quarantined
- def get_all_new_public_rooms(self, prev_id, current_id, limit):
+ async def get_all_new_public_rooms(
+ self, instance_name: str, last_id: int, current_id: int, limit: int
+ ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
+ """Get updates for public rooms replication stream.
+
+ Args:
+ instance_name: The writer we want to fetch updates from. Unused
+ here since there is only ever one writer.
+ last_id: The token to fetch updates from. Exclusive.
+ current_id: The token to fetch updates up to. Inclusive.
+ limit: The requested limit for the number of rows to return. The
+ function may return more or fewer rows.
+
+ Returns:
+ A tuple consisting of: the updates, a token to use to fetch
+ subsequent updates, and whether we returned fewer rows than exists
+ between the requested tokens due to the limit.
+
+ The token returned can be used in a subsequent call to this
+ function to get further updatees.
+
+ The updates are a list of 2-tuples of stream ID and the row data
+ """
+ if last_id == current_id:
+ return [], current_id, False
+
def get_all_new_public_rooms(txn):
sql = """
SELECT stream_id, room_id, visibility, appservice_id, network_id
@@ -813,13 +861,17 @@ class RoomWorkerStore(SQLBaseStore):
LIMIT ?
"""
- txn.execute(sql, (prev_id, current_id, limit))
- return txn.fetchall()
+ txn.execute(sql, (last_id, current_id, limit))
+ updates = [(row[0], row[1:]) for row in txn]
+ limited = False
+ upto_token = current_id
+ if len(updates) >= limit:
+ upto_token = updates[-1][0]
+ limited = True
- if prev_id == current_id:
- return defer.succeed([])
+ return updates, upto_token, limited
- return self.db.runInteraction(
+ return await self.db.runInteraction(
"get_all_new_public_rooms", get_all_new_public_rooms
)
diff --git a/synapse/storage/data_stores/main/schema/delta/25/fts.py b/synapse/storage/data_stores/main/schema/delta/25/fts.py
index 4b2ffd35fd..ee675e71ff 100644
--- a/synapse/storage/data_stores/main/schema/delta/25/fts.py
+++ b/synapse/storage/data_stores/main/schema/delta/25/fts.py
@@ -11,11 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+import json
import logging
-import simplejson
-
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.storage.prepare_database import get_statements
@@ -66,7 +64,7 @@ def run_create(cur, database_engine, *args, **kwargs):
"max_stream_id_exclusive": max_stream_id + 1,
"rows_inserted": 0,
}
- progress_json = simplejson.dumps(progress)
+ progress_json = json.dumps(progress)
sql = (
"INSERT into background_updates (update_name, progress_json)"
diff --git a/synapse/storage/data_stores/main/schema/delta/27/ts.py b/synapse/storage/data_stores/main/schema/delta/27/ts.py
index 414f9f5aa0..b7972cfa8e 100644
--- a/synapse/storage/data_stores/main/schema/delta/27/ts.py
+++ b/synapse/storage/data_stores/main/schema/delta/27/ts.py
@@ -11,11 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+import json
import logging
-import simplejson
-
from synapse.storage.prepare_database import get_statements
logger = logging.getLogger(__name__)
@@ -45,7 +43,7 @@ def run_create(cur, database_engine, *args, **kwargs):
"max_stream_id_exclusive": max_stream_id + 1,
"rows_inserted": 0,
}
- progress_json = simplejson.dumps(progress)
+ progress_json = json.dumps(progress)
sql = (
"INSERT into background_updates (update_name, progress_json)"
diff --git a/synapse/storage/data_stores/main/schema/delta/31/search_update.py b/synapse/storage/data_stores/main/schema/delta/31/search_update.py
index 7d8ca5f93f..63b757ade6 100644
--- a/synapse/storage/data_stores/main/schema/delta/31/search_update.py
+++ b/synapse/storage/data_stores/main/schema/delta/31/search_update.py
@@ -11,11 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+import json
import logging
-import simplejson
-
from synapse.storage.engines import PostgresEngine
from synapse.storage.prepare_database import get_statements
@@ -50,7 +48,7 @@ def run_create(cur, database_engine, *args, **kwargs):
"rows_inserted": 0,
"have_added_indexes": False,
}
- progress_json = simplejson.dumps(progress)
+ progress_json = json.dumps(progress)
sql = (
"INSERT into background_updates (update_name, progress_json)"
diff --git a/synapse/storage/data_stores/main/schema/delta/33/event_fields.py b/synapse/storage/data_stores/main/schema/delta/33/event_fields.py
index bff1256a7b..a3e81eeac7 100644
--- a/synapse/storage/data_stores/main/schema/delta/33/event_fields.py
+++ b/synapse/storage/data_stores/main/schema/delta/33/event_fields.py
@@ -11,11 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+import json
import logging
-import simplejson
-
from synapse.storage.prepare_database import get_statements
logger = logging.getLogger(__name__)
@@ -45,7 +43,7 @@ def run_create(cur, database_engine, *args, **kwargs):
"max_stream_id_exclusive": max_stream_id + 1,
"rows_inserted": 0,
}
- progress_json = simplejson.dumps(progress)
+ progress_json = json.dumps(progress)
sql = (
"INSERT into background_updates (update_name, progress_json)"
diff --git a/synapse/storage/data_stores/main/schema/delta/48/profiles_batch.sql b/synapse/storage/data_stores/main/schema/delta/48/profiles_batch.sql
new file mode 100644
index 0000000000..e744c02fe8
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/48/profiles_batch.sql
@@ -0,0 +1,36 @@
+/* Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+ * Add a batch number to track changes to profiles and the
+ * order they're made in so we can replicate user profiles
+ * to other hosts as they change
+ */
+ALTER TABLE profiles ADD COLUMN batch BIGINT DEFAULT NULL;
+
+/*
+ * Index on the batch number so we can get profiles
+ * by their batch
+ */
+CREATE INDEX profiles_batch_idx ON profiles(batch);
+
+/*
+ * A table to track what batch of user profiles has been
+ * synced to what profile replication target.
+ */
+CREATE TABLE profile_replication_status (
+ host TEXT NOT NULL,
+ last_synced_batch BIGINT NOT NULL
+);
diff --git a/synapse/storage/data_stores/main/schema/delta/50/profiles_deactivated_users.sql b/synapse/storage/data_stores/main/schema/delta/50/profiles_deactivated_users.sql
new file mode 100644
index 0000000000..c8893ecbe8
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/50/profiles_deactivated_users.sql
@@ -0,0 +1,23 @@
+/* Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+ * A flag saying whether the user owning the profile has been deactivated
+ * This really belongs on the users table, not here, but the users table
+ * stores users by their full user_id and profiles stores them by localpart,
+ * so we can't easily join between the two tables. Plus, the batch number
+ * realy ought to represent data in this table that has changed.
+ */
+ALTER TABLE profiles ADD COLUMN active SMALLINT DEFAULT 1 NOT NULL;
diff --git a/synapse/storage/data_stores/main/schema/delta/55/profile_replication_status_index.sql b/synapse/storage/data_stores/main/schema/delta/55/profile_replication_status_index.sql
new file mode 100644
index 0000000000..18a0f7e10c
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/55/profile_replication_status_index.sql
@@ -0,0 +1,17 @@
+/* Copyright 2019 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT INTO background_updates (update_name, progress_json) VALUES
+ ('profile_replication_status_host_index', '{}');
diff --git a/synapse/storage/data_stores/main/schema/delta/56/room_retention.sql b/synapse/storage/data_stores/main/schema/delta/55/room_retention.sql
index ee6cdf7a14..ee6cdf7a14 100644
--- a/synapse/storage/data_stores/main/schema/delta/56/room_retention.sql
+++ b/synapse/storage/data_stores/main/schema/delta/55/room_retention.sql
diff --git a/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.postgres b/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.postgres
index 889a9a0ce4..20c5af2eb7 100644
--- a/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.postgres
+++ b/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.postgres
@@ -658,10 +658,19 @@ CREATE TABLE presence_stream (
+CREATE TABLE profile_replication_status (
+ host text NOT NULL,
+ last_synced_batch bigint NOT NULL
+);
+
+
+
CREATE TABLE profiles (
user_id text NOT NULL,
displayname text,
- avatar_url text
+ avatar_url text,
+ batch bigint,
+ active smallint DEFAULT 1 NOT NULL
);
@@ -1788,6 +1797,10 @@ CREATE INDEX presence_stream_user_id ON presence_stream USING btree (user_id);
+CREATE INDEX profiles_batch_idx ON profiles USING btree (batch);
+
+
+
CREATE INDEX public_room_index ON rooms USING btree (is_public);
diff --git a/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.sqlite b/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.sqlite
index a0411ede7e..e28ec3fa45 100644
--- a/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.sqlite
+++ b/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.sqlite
@@ -6,7 +6,7 @@ CREATE TABLE presence_allow_inbound( observed_user_id TEXT NOT NULL, observer_us
CREATE TABLE users( name TEXT, password_hash TEXT, creation_ts BIGINT, admin SMALLINT DEFAULT 0 NOT NULL, upgrade_ts BIGINT, is_guest SMALLINT DEFAULT 0 NOT NULL, appservice_id TEXT, consent_version TEXT, consent_server_notice_sent TEXT, user_type TEXT DEFAULT NULL, UNIQUE(name) );
CREATE TABLE access_tokens( id BIGINT PRIMARY KEY, user_id TEXT NOT NULL, device_id TEXT, token TEXT NOT NULL, last_used BIGINT, UNIQUE(token) );
CREATE TABLE user_ips ( user_id TEXT NOT NULL, access_token TEXT NOT NULL, device_id TEXT, ip TEXT NOT NULL, user_agent TEXT NOT NULL, last_seen BIGINT NOT NULL );
-CREATE TABLE profiles( user_id TEXT NOT NULL, displayname TEXT, avatar_url TEXT, UNIQUE(user_id) );
+CREATE TABLE profiles( user_id TEXT NOT NULL, displayname TEXT, avatar_url TEXT, batch BIGINT DEFAULT NULL, active SMALLINT DEFAULT 1 NOT NULL, UNIQUE(user_id) );
CREATE TABLE received_transactions( transaction_id TEXT, origin TEXT, ts BIGINT, response_code INTEGER, response_json bytea, has_been_referenced smallint default 0, UNIQUE (transaction_id, origin) );
CREATE TABLE destinations( destination TEXT PRIMARY KEY, retry_last_ts BIGINT, retry_interval INTEGER );
CREATE TABLE events( stream_ordering INTEGER PRIMARY KEY, topological_ordering BIGINT NOT NULL, event_id TEXT NOT NULL, type TEXT NOT NULL, room_id TEXT NOT NULL, content TEXT, unrecognized_keys TEXT, processed BOOL NOT NULL, outlier BOOL NOT NULL, depth BIGINT DEFAULT 0 NOT NULL, origin_server_ts BIGINT, received_ts BIGINT, sender TEXT, contains_url BOOLEAN, UNIQUE (event_id) );
@@ -202,6 +202,8 @@ CREATE INDEX group_users_u_idx ON group_users(user_id);
CREATE INDEX group_invites_u_idx ON group_invites(user_id);
CREATE UNIQUE INDEX group_rooms_g_idx ON group_rooms(group_id, room_id);
CREATE INDEX group_rooms_r_idx ON group_rooms(room_id);
+CREATE INDEX profiles_batch_idx ON profiles(batch);
+CREATE TABLE profile_replication_status ( host TEXT NOT NULL, last_synced_batch BIGINT NOT NULL );
CREATE TABLE user_daily_visits ( user_id TEXT NOT NULL, device_id TEXT, timestamp BIGINT NOT NULL );
CREATE INDEX user_daily_visits_uts_idx ON user_daily_visits(user_id, timestamp);
CREATE INDEX user_daily_visits_ts_idx ON user_daily_visits(timestamp);
diff --git a/synapse/storage/data_stores/main/tags.py b/synapse/storage/data_stores/main/tags.py
index f8c776be3f..290317fd94 100644
--- a/synapse/storage/data_stores/main/tags.py
+++ b/synapse/storage/data_stores/main/tags.py
@@ -15,6 +15,7 @@
# limitations under the License.
import logging
+from typing import List, Tuple
from canonicaljson import json
@@ -53,18 +54,32 @@ class TagsWorkerStore(AccountDataWorkerStore):
return deferred
- @defer.inlineCallbacks
- def get_all_updated_tags(self, last_id, current_id, limit):
- """Get all the client tags that have changed on the server
+ async def get_all_updated_tags(
+ self, instance_name: str, last_id: int, current_id: int, limit: int
+ ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
+ """Get updates for tags replication stream.
+
Args:
- last_id(int): The position to fetch from.
- current_id(int): The position to fetch up to.
+ instance_name: The writer we want to fetch updates from. Unused
+ here since there is only ever one writer.
+ last_id: The token to fetch updates from. Exclusive.
+ current_id: The token to fetch updates up to. Inclusive.
+ limit: The requested limit for the number of rows to return. The
+ function may return more or fewer rows.
+
Returns:
- A deferred list of tuples of stream_id int, user_id string,
- room_id string, tag string and content string.
+ A tuple consisting of: the updates, a token to use to fetch
+ subsequent updates, and whether we returned fewer rows than exists
+ between the requested tokens due to the limit.
+
+ The token returned can be used in a subsequent call to this
+ function to get further updatees.
+
+ The updates are a list of 2-tuples of stream ID and the row data
"""
+
if last_id == current_id:
- return []
+ return [], current_id, False
def get_all_updated_tags_txn(txn):
sql = (
@@ -76,7 +91,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
txn.execute(sql, (last_id, current_id, limit))
return txn.fetchall()
- tag_ids = yield self.db.runInteraction(
+ tag_ids = await self.db.runInteraction(
"get_all_updated_tags", get_all_updated_tags_txn
)
@@ -89,21 +104,27 @@ class TagsWorkerStore(AccountDataWorkerStore):
for tag, content in txn:
tags.append(json.dumps(tag) + ":" + content)
tag_json = "{" + ",".join(tags) + "}"
- results.append((stream_id, user_id, room_id, tag_json))
+ results.append((stream_id, (user_id, room_id, tag_json)))
return results
batch_size = 50
results = []
for i in range(0, len(tag_ids), batch_size):
- tags = yield self.db.runInteraction(
+ tags = await self.db.runInteraction(
"get_all_updated_tag_content",
get_tag_content,
tag_ids[i : i + batch_size],
)
results.extend(tags)
- return results
+ limited = False
+ upto_token = current_id
+ if len(results) >= limit:
+ upto_token = results[-1][0]
+ limited = True
+
+ return results, upto_token, limited
@defer.inlineCallbacks
def get_updated_tags(self, user_id, stream_id):
diff --git a/synapse/storage/data_stores/main/ui_auth.py b/synapse/storage/data_stores/main/ui_auth.py
index ec2f38c373..4c044b1a15 100644
--- a/synapse/storage/data_stores/main/ui_auth.py
+++ b/synapse/storage/data_stores/main/ui_auth.py
@@ -17,10 +17,10 @@ from typing import Any, Dict, Optional, Union
import attr
-import synapse.util.stringutils as stringutils
from synapse.api.errors import StoreError
from synapse.storage._base import SQLBaseStore
from synapse.types import JsonDict
+from synapse.util import stringutils as stringutils
@attr.s
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index 6c7d08a6f2..a31588080d 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -92,7 +92,7 @@ class PostgresEngine(BaseDatabaseEngine):
errors.append(" - 'COLLATE' is set to %r. Should be 'C'" % (collation,))
if ctype != "C":
- errors.append(" - 'CTYPE' is set to %r. Should be 'C'" % (collation,))
+ errors.append(" - 'CTYPE' is set to %r. Should be 'C'" % (ctype,))
if errors:
raise IncorrectDatabaseSetup(
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index ec894a91cb..fa46041676 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -783,9 +783,3 @@ class EventsPersistenceStorage(object):
for user_id in left_users:
await self.main_store.mark_remote_user_device_list_as_unsubscribed(user_id)
-
- async def locally_reject_invite(self, user_id: str, room_id: str) -> int:
- """Mark the invite has having been rejected even though we failed to
- create a leave event for it.
- """
- return await self.persist_events_store.locally_reject_invite(user_id, room_id)
diff --git a/synapse/storage/types.py b/synapse/storage/types.py
index daff81c5ee..2d2b560e74 100644
--- a/synapse/storage/types.py
+++ b/synapse/storage/types.py
@@ -12,12 +12,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
from typing import Any, Iterable, Iterator, List, Tuple
from typing_extensions import Protocol
-
"""
Some very basic protocol definitions for the DB-API2 classes specified in PEP-249
"""
diff --git a/synapse/streams/config.py b/synapse/streams/config.py
index cd56cd91ed..ca7c16ff65 100644
--- a/synapse/streams/config.py
+++ b/synapse/streams/config.py
@@ -68,13 +68,13 @@ class PaginationConfig(object):
elif from_tok:
from_tok = StreamToken.from_string(from_tok)
except Exception:
- raise SynapseError(400, "'from' paramater is invalid")
+ raise SynapseError(400, "'from' parameter is invalid")
try:
if to_tok:
to_tok = StreamToken.from_string(to_tok)
except Exception:
- raise SynapseError(400, "'to' paramater is invalid")
+ raise SynapseError(400, "'to' parameter is invalid")
limit = parse_integer(request, "limit", default=default_limit)
diff --git a/synapse/streams/events.py b/synapse/streams/events.py
index fcd2aaa9c9..5d3eddcfdc 100644
--- a/synapse/streams/events.py
+++ b/synapse/streams/events.py
@@ -68,7 +68,7 @@ class EventSources(object):
The returned token does not have the current values for fields other
than `room`, since they are not used during pagination.
- Retuns:
+ Returns:
Deferred[StreamToken]
"""
token = StreamToken(
diff --git a/synapse/third_party_rules/__init__.py b/synapse/third_party_rules/__init__.py
new file mode 100644
index 0000000000..1453d04571
--- /dev/null
+++ b/synapse/third_party_rules/__init__.py
@@ -0,0 +1,14 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/synapse/third_party_rules/access_rules.py b/synapse/third_party_rules/access_rules.py
new file mode 100644
index 0000000000..2c9155d15c
--- /dev/null
+++ b/synapse/third_party_rules/access_rules.py
@@ -0,0 +1,588 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import email.utils
+
+from twisted.internet import defer
+
+from synapse.api.constants import EventTypes, JoinRules, Membership, RoomCreationPreset
+from synapse.api.errors import SynapseError
+from synapse.config._base import ConfigError
+from synapse.types import get_domain_from_id
+
+ACCESS_RULES_TYPE = "im.vector.room.access_rules"
+ACCESS_RULE_RESTRICTED = "restricted"
+ACCESS_RULE_UNRESTRICTED = "unrestricted"
+ACCESS_RULE_DIRECT = "direct"
+
+VALID_ACCESS_RULES = (
+ ACCESS_RULE_DIRECT,
+ ACCESS_RULE_RESTRICTED,
+ ACCESS_RULE_UNRESTRICTED,
+)
+
+# Rules to which we need to apply the power levels restrictions.
+#
+# These are all of the rules that neither:
+# * forbid users from joining based on a server blacklist (which means that there
+# is no need to apply power level restrictions), nor
+# * target direct chats (since we allow both users to be room admins in this case).
+#
+# The power-level restrictions, when they are applied, prevent the following:
+# * the default power level for users (users_default) being set to anything other than 0.
+# * a non-default power level being assigned to any user which would be forbidden from
+# joining a restricted room.
+RULES_WITH_RESTRICTED_POWER_LEVELS = (ACCESS_RULE_UNRESTRICTED,)
+
+
+class RoomAccessRules(object):
+ """Implementation of the ThirdPartyEventRules module API that allows federation admins
+ to define custom rules for specific events and actions.
+ Implements the custom behaviour for the "im.vector.room.access_rules" state event.
+
+ Takes a config in the format:
+
+ third_party_event_rules:
+ module: third_party_rules.RoomAccessRules
+ config:
+ # List of domains (server names) that can't be invited to rooms if the
+ # "restricted" rule is set. Defaults to an empty list.
+ domains_forbidden_when_restricted: []
+
+ # Identity server to use when checking the HS an email address belongs to
+ # using the /info endpoint. Required.
+ id_server: "vector.im"
+
+ Don't forget to consider if you can invite users from your own domain.
+ """
+
+ def __init__(self, config, http_client):
+ self.http_client = http_client
+
+ self.id_server = config["id_server"]
+
+ self.domains_forbidden_when_restricted = config.get(
+ "domains_forbidden_when_restricted", []
+ )
+
+ @staticmethod
+ def parse_config(config):
+ if "id_server" in config:
+ return config
+ else:
+ raise ConfigError("No IS for event rules TchapEventRules")
+
+ def on_create_room(self, requester, config, is_requester_admin) -> bool:
+ """Implements synapse.events.ThirdPartyEventRules.on_create_room
+
+ Checks if a im.vector.room.access_rules event is being set during room creation.
+ If yes, make sure the event is correct. Otherwise, append an event with the
+ default rule to the initial state.
+ """
+ is_direct = config.get("is_direct")
+ preset = config.get("preset")
+ access_rule = None
+ join_rule = None
+
+ # If there's a rules event in the initial state, check if it complies with the
+ # spec for im.vector.room.access_rules and deny the request if not.
+ for event in config.get("initial_state", []):
+ if event["type"] == ACCESS_RULES_TYPE:
+ access_rule = event["content"].get("rule")
+
+ # Make sure the event has a valid content.
+ if access_rule is None:
+ raise SynapseError(400, "Invalid access rule")
+
+ # Make sure the rule name is valid.
+ if access_rule not in VALID_ACCESS_RULES:
+ raise SynapseError(400, "Invalid access rule")
+
+ # Make sure the rule is "direct" if the room is a direct chat.
+ if (is_direct and access_rule != ACCESS_RULE_DIRECT) or (
+ access_rule == ACCESS_RULE_DIRECT and not is_direct
+ ):
+ raise SynapseError(400, "Invalid access rule")
+
+ if event["type"] == EventTypes.JoinRules:
+ join_rule = event["content"].get("join_rule")
+
+ if access_rule is None:
+ # If there's no access rules event in the initial state, create one with the
+ # default setting.
+ if is_direct:
+ default_rule = ACCESS_RULE_DIRECT
+ else:
+ # If the default value for non-direct chat changes, we should make another
+ # case here for rooms created with either a "public" join_rule or the
+ # "public_chat" preset to make sure those keep defaulting to "restricted"
+ default_rule = ACCESS_RULE_RESTRICTED
+
+ if not config.get("initial_state"):
+ config["initial_state"] = []
+
+ config["initial_state"].append(
+ {
+ "type": ACCESS_RULES_TYPE,
+ "state_key": "",
+ "content": {"rule": default_rule},
+ }
+ )
+
+ access_rule = default_rule
+
+ # Check that the preset or the join rule in use is compatible with the access
+ # rule, whether it's a user-defined one or the default one (i.e. if it involves
+ # a "public" join rule, the access rule must be "restricted").
+ if (
+ join_rule == JoinRules.PUBLIC or preset == RoomCreationPreset.PUBLIC_CHAT
+ ) and access_rule != ACCESS_RULE_RESTRICTED:
+ raise SynapseError(400, "Invalid access rule")
+
+ # Check if the creator can override values for the power levels.
+ allowed = self._is_power_level_content_allowed(
+ config.get("power_level_content_override", {}), access_rule
+ )
+ if not allowed:
+ raise SynapseError(400, "Invalid power levels content override")
+
+ # Second loop for events we need to know the current rule to process.
+ for event in config.get("initial_state", []):
+ if event["type"] == EventTypes.PowerLevels:
+ allowed = self._is_power_level_content_allowed(
+ event["content"], access_rule
+ )
+ if not allowed:
+ raise SynapseError(400, "Invalid power levels content")
+
+ return True
+
+ @defer.inlineCallbacks
+ def check_threepid_can_be_invited(self, medium, address, state_events):
+ """Implements synapse.events.ThirdPartyEventRules.check_threepid_can_be_invited
+
+ Check if a threepid can be invited to the room via a 3PID invite given the current
+ rules and the threepid's address, by retrieving the HS it's mapped to from the
+ configured identity server, and checking if we can invite users from it.
+ """
+ rule = self._get_rule_from_state(state_events)
+
+ if medium != "email":
+ defer.returnValue(False)
+
+ if rule != ACCESS_RULE_RESTRICTED:
+ # Only "restricted" requires filtering 3PID invites. We don't need to do
+ # anything for "direct" here, because only "restricted" requires filtering
+ # based on the HS the address is mapped to.
+ defer.returnValue(True)
+
+ parsed_address = email.utils.parseaddr(address)[1]
+ if parsed_address != address:
+ # Avoid reproducing the security issue described here:
+ # https://matrix.org/blog/2019/04/18/security-update-sydent-1-0-2
+ # It's probably not worth it but let's just be overly safe here.
+ defer.returnValue(False)
+
+ # Get the HS this address belongs to from the identity server.
+ res = yield self.http_client.get_json(
+ "https://%s/_matrix/identity/api/v1/info" % (self.id_server,),
+ {"medium": medium, "address": address},
+ )
+
+ # Look for a domain that's not forbidden from being invited.
+ if not res.get("hs"):
+ defer.returnValue(False)
+ if res.get("hs") in self.domains_forbidden_when_restricted:
+ defer.returnValue(False)
+
+ defer.returnValue(True)
+
+ def check_event_allowed(self, event, state_events):
+ """Implements synapse.events.ThirdPartyEventRules.check_event_allowed
+
+ Checks the event's type and the current rule and calls the right function to
+ determine whether the event can be allowed.
+ """
+ if event.type == ACCESS_RULES_TYPE:
+ return self._on_rules_change(event, state_events)
+
+ # We need to know the rule to apply when processing the event types below.
+ rule = self._get_rule_from_state(state_events)
+
+ if event.type == EventTypes.PowerLevels:
+ return self._is_power_level_content_allowed(event.content, rule)
+
+ if event.type == EventTypes.Member or event.type == EventTypes.ThirdPartyInvite:
+ return self._on_membership_or_invite(event, rule, state_events)
+
+ if event.type == EventTypes.JoinRules:
+ return self._on_join_rule_change(event, rule)
+
+ if event.type == EventTypes.RoomAvatar:
+ return self._on_room_avatar_change(event, rule)
+
+ if event.type == EventTypes.Name:
+ return self._on_room_name_change(event, rule)
+
+ if event.type == EventTypes.Topic:
+ return self._on_room_topic_change(event, rule)
+
+ return True
+
+ def _on_rules_change(self, event, state_events):
+ """Implement the checks and behaviour specified on allowing or forbidding a new
+ im.vector.room.access_rules event.
+
+ Args:
+ event (synapse.events.EventBase): The event to check.
+ state_events (dict[tuple[event type, state key], EventBase]): The state of the
+ room before the event was sent.
+ Returns:
+ bool, True if the event can be allowed, False otherwise.
+ """
+ new_rule = event.content.get("rule")
+
+ # Check for invalid values.
+ if new_rule not in VALID_ACCESS_RULES:
+ return False
+
+ # We must not allow rooms with the "public" join rule to be given any other access
+ # rule than "restricted".
+ join_rule = self._get_join_rule_from_state(state_events)
+ if join_rule == JoinRules.PUBLIC and new_rule != ACCESS_RULE_RESTRICTED:
+ return False
+
+ # Make sure we don't apply "direct" if the room has more than two members.
+ if new_rule == ACCESS_RULE_DIRECT:
+ existing_members, threepid_tokens = self._get_members_and_tokens_from_state(
+ state_events
+ )
+
+ if len(existing_members) > 2 or len(threepid_tokens) > 1:
+ return False
+
+ prev_rules_event = state_events.get((ACCESS_RULES_TYPE, ""))
+
+ # Now that we know the new rule doesn't break the "direct" case, we can allow any
+ # new rule in rooms that had none before.
+ if prev_rules_event is None:
+ return True
+
+ prev_rule = prev_rules_event.content.get("rule")
+
+ # Currently, we can only go from "restricted" to "unrestricted".
+ if prev_rule == ACCESS_RULE_RESTRICTED and new_rule == ACCESS_RULE_UNRESTRICTED:
+ return True
+
+ return False
+
+ def _on_membership_or_invite(self, event, rule, state_events):
+ """Applies the correct rule for incoming m.room.member and
+ m.room.third_party_invite events.
+
+ Args:
+ event (synapse.events.EventBase): The event to check.
+ rule (str): The name of the rule to apply.
+ state_events (dict[tuple[event type, state key], EventBase]): The state of the
+ room before the event was sent.
+ Returns:
+ bool, True if the event can be allowed, False otherwise.
+ """
+ if rule == ACCESS_RULE_RESTRICTED:
+ ret = self._on_membership_or_invite_restricted(event)
+ elif rule == ACCESS_RULE_UNRESTRICTED:
+ ret = self._on_membership_or_invite_unrestricted()
+ elif rule == ACCESS_RULE_DIRECT:
+ ret = self._on_membership_or_invite_direct(event, state_events)
+ else:
+ # We currently apply the default (restricted) if we don't know the rule, we
+ # might want to change that in the future.
+ ret = self._on_membership_or_invite_restricted(event)
+
+ return ret
+
+ def _on_membership_or_invite_restricted(self, event):
+ """Implements the checks and behaviour specified for the "restricted" rule.
+
+ "restricted" currently means that users can only invite users if their server is
+ included in a limited list of domains.
+
+ Args:
+ event (synapse.events.EventBase): The event to check.
+ Returns:
+ bool, True if the event can be allowed, False otherwise.
+ """
+ # We're not applying the rules on m.room.third_party_member events here because
+ # the filtering on threepids is done in check_threepid_can_be_invited, which is
+ # called before check_event_allowed.
+ if event.type == EventTypes.ThirdPartyInvite:
+ return True
+
+ # We only need to process "join" and "invite" memberships, in order to be backward
+ # compatible, e.g. if a user from a blacklisted server joined a restricted room
+ # before the rules started being enforced on the server, that user must be able to
+ # leave it.
+ if event.membership not in [Membership.JOIN, Membership.INVITE]:
+ return True
+
+ invitee_domain = get_domain_from_id(event.state_key)
+ return invitee_domain not in self.domains_forbidden_when_restricted
+
+ def _on_membership_or_invite_unrestricted(self):
+ """Implements the checks and behaviour specified for the "unrestricted" rule.
+
+ "unrestricted" currently means that every event is allowed.
+
+ Returns:
+ bool, True if the event can be allowed, False otherwise.
+ """
+ return True
+
+ def _on_membership_or_invite_direct(self, event, state_events):
+ """Implements the checks and behaviour specified for the "direct" rule.
+
+ "direct" currently means that no member is allowed apart from the two initial
+ members the room was created for (i.e. the room's creator and their first
+ invitee).
+
+ Args:
+ event (synapse.events.EventBase): The event to check.
+ state_events (dict[tuple[event type, state key], EventBase]): The state of the
+ room before the event was sent.
+ Returns:
+ bool, True if the event can be allowed, False otherwise.
+ """
+ # Get the room memberships and 3PID invite tokens from the room's state.
+ existing_members, threepid_tokens = self._get_members_and_tokens_from_state(
+ state_events
+ )
+
+ # There should never be more than one 3PID invite in the room state: if the second
+ # original user came and left, and we're inviting them using their email address,
+ # given we know they have a Matrix account binded to the address (so they could
+ # join the first time), Synapse will successfully look it up before attempting to
+ # store an invite on the IS.
+ if len(threepid_tokens) == 1 and event.type == EventTypes.ThirdPartyInvite:
+ # If we already have a 3PID invite in flight, don't accept another one, unless
+ # the new one has the same invite token as its state key. This is because 3PID
+ # invite revocations must be allowed, and a revocation is basically a new 3PID
+ # invite event with an empty content and the same token as the invite it
+ # revokes.
+ return event.state_key in threepid_tokens
+
+ if len(existing_members) == 2:
+ # If the user was within the two initial user of the room, Synapse would have
+ # looked it up successfully and thus sent a m.room.member here instead of
+ # m.room.third_party_invite.
+ if event.type == EventTypes.ThirdPartyInvite:
+ return False
+
+ # We can only have m.room.member events here. The rule in this case is to only
+ # allow the event if its target is one of the initial two members in the room,
+ # i.e. the state key of one of the two m.room.member states in the room.
+ return event.state_key in existing_members
+
+ # We're alone in the room (and always have been) and there's one 3PID invite in
+ # flight.
+ if len(existing_members) == 1 and len(threepid_tokens) == 1:
+ # We can only have m.room.member events here. In this case, we can only allow
+ # the event if it's either a m.room.member from the joined user (we can assume
+ # that the only m.room.member event is a join otherwise we wouldn't be able to
+ # send an event to the room) or an an invite event which target is the invited
+ # user.
+ target = event.state_key
+ is_from_threepid_invite = self._is_invite_from_threepid(
+ event, threepid_tokens[0]
+ )
+ if is_from_threepid_invite or target == existing_members[0]:
+ return True
+
+ return False
+
+ return True
+
+ def _is_power_level_content_allowed(self, content, access_rule):
+ """Check if a given power levels event is permitted under the given access rule.
+
+ It shouldn't be allowed if it either changes the default PL to a non-0 value or
+ gives a non-0 PL to a user that would have been forbidden from joining the room
+ under a more restrictive access rule.
+
+ Args:
+ content (dict[]): The content of the m.room.power_levels event to check.
+ access_rule (str): The access rule in place in this room.
+ Returns:
+ bool, True if the event can be allowed, False otherwise.
+ """
+ # Check if we need to apply the restrictions with the current rule.
+ if access_rule not in RULES_WITH_RESTRICTED_POWER_LEVELS:
+ return True
+
+ # If users_default is explicitly set to a non-0 value, deny the event.
+ users_default = content.get("users_default", 0)
+ if users_default:
+ return False
+
+ users = content.get("users", {})
+ for user_id, power_level in users.items():
+ server_name = get_domain_from_id(user_id)
+ # Check the domain against the blacklist. If found, and the PL isn't 0, deny
+ # the event.
+ if (
+ server_name in self.domains_forbidden_when_restricted
+ and power_level != 0
+ ):
+ return False
+
+ return True
+
+ def _on_join_rule_change(self, event, rule):
+ """Check whether a join rule change is allowed. A join rule change is always
+ allowed unless the new join rule is "public" and the current access rule isn't
+ "restricted".
+ The rationale is that external users (those whose server would be denied access
+ to rooms enforcing the "restricted" access rule) should always rely on non-
+ external users for access to rooms, therefore they shouldn't be able to access
+ rooms that don't require an invite to be joined.
+
+ Note that we currently rely on the default access rule being "restricted": during
+ room creation, the m.room.join_rules event will be sent *before* the
+ im.vector.room.access_rules one, so the access rule that will be considered here
+ in this case will be the default "restricted" one. This is fine since the
+ "restricted" access rule allows any value for the join rule, but we should keep
+ that in mind if we need to change the default access rule in the future.
+
+ Args:
+ event (synapse.events.EventBase): The event to check.
+ rule (str): The name of the rule to apply.
+ Returns:
+ bool, True if the event can be allowed, False otherwise.
+ """
+ if event.content.get("join_rule") == JoinRules.PUBLIC:
+ return rule == ACCESS_RULE_RESTRICTED
+
+ return True
+
+ def _on_room_avatar_change(self, event, rule):
+ """Check whether a change of room avatar is allowed.
+ The current rule is to forbid such a change in direct chats but allow it
+ everywhere else.
+
+ Args:
+ event (synapse.events.EventBase): The event to check.
+ rule (str): The name of the rule to apply.
+ Returns:
+ bool, True if the event can be allowed, False otherwise.
+ """
+ return rule != ACCESS_RULE_DIRECT
+
+ def _on_room_name_change(self, event, rule):
+ """Check whether a change of room name is allowed.
+ The current rule is to forbid such a change in direct chats but allow it
+ everywhere else.
+
+ Args:
+ event (synapse.events.EventBase): The event to check.
+ rule (str): The name of the rule to apply.
+ Returns:
+ bool, True if the event can be allowed, False otherwise.
+ """
+ return rule != ACCESS_RULE_DIRECT
+
+ def _on_room_topic_change(self, event, rule):
+ """Check whether a change of room topic is allowed.
+ The current rule is to forbid such a change in direct chats but allow it
+ everywhere else.
+
+ Args:
+ event (synapse.events.EventBase): The event to check.
+ rule (str): The name of the rule to apply.
+ Returns:
+ bool, True if the event can be allowed, False otherwise.
+ """
+ return rule != ACCESS_RULE_DIRECT
+
+ @staticmethod
+ def _get_rule_from_state(state_events):
+ """Extract the rule to be applied from the given set of state events.
+
+ Args:
+ state_events (dict[tuple[event type, state key], EventBase]): The set of state
+ events.
+ Returns:
+ str, the name of the rule (either "direct", "restricted" or "unrestricted")
+ """
+ access_rules = state_events.get((ACCESS_RULES_TYPE, ""))
+ if access_rules is None:
+ rule = ACCESS_RULE_RESTRICTED
+ else:
+ rule = access_rules.content.get("rule")
+ return rule
+
+ @staticmethod
+ def _get_join_rule_from_state(state_events):
+ """Extract the room's join rule from the given set of state events.
+
+ Args:
+ state_events (dict[tuple[event type, state key], EventBase]): The set of state
+ events.
+ Returns:
+ str, the name of the join rule (either "public", or "invite")
+ """
+ join_rule_event = state_events.get((EventTypes.JoinRules, ""))
+ if join_rule_event is None:
+ return None
+ return join_rule_event.content.get("join_rule")
+
+ @staticmethod
+ def _get_members_and_tokens_from_state(state_events):
+ """Retrieves from a list of state events the list of users that have a
+ m.room.member event in the room, and the tokens of 3PID invites in the room.
+
+ Args:
+ state_events (dict[tuple[event type, state key], EventBase]): The set of state
+ events.
+ Returns:
+ existing_members (list[str]): List of targets of the m.room.member events in
+ the state.
+ threepid_invite_tokens (list[str]): List of tokens of the 3PID invites in the
+ state.
+ """
+ existing_members = []
+ threepid_invite_tokens = []
+ for key, state_event in state_events.items():
+ if key[0] == EventTypes.Member and state_event.content:
+ existing_members.append(state_event.state_key)
+ if key[0] == EventTypes.ThirdPartyInvite and state_event.content:
+ # Don't include revoked invites.
+ threepid_invite_tokens.append(state_event.state_key)
+
+ return existing_members, threepid_invite_tokens
+
+ @staticmethod
+ def _is_invite_from_threepid(invite, threepid_invite_token):
+ """Checks whether the given invite follows the given 3PID invite.
+
+ Args:
+ invite (EventBase): The m.room.member event with "invite" membership.
+ threepid_invite_token (str): The state key from the 3PID invite.
+ """
+ token = (
+ invite.content.get("third_party_invite", {})
+ .get("signed", {})
+ .get("token", "")
+ )
+
+ return token == threepid_invite_token
diff --git a/synapse/types.py b/synapse/types.py
index acf60baddc..40cab6ed74 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -19,6 +19,8 @@ import sys
from collections import namedtuple
from typing import Any, Dict, Tuple, TypeVar
+from six.moves import filter
+
import attr
from signedjson.key import decode_verify_key_bytes
from unpaddedbase64 import decode_base64
@@ -29,7 +31,7 @@ from synapse.api.errors import Codes, SynapseError
if sys.version_info[:3] >= (3, 6, 0):
from typing import Collection
else:
- from typing import Sized, Iterable, Container
+ from typing import Container, Iterable, Sized
T_co = TypeVar("T_co", covariant=True)
@@ -267,6 +269,19 @@ def contains_invalid_mxid_characters(localpart):
return any(c not in mxid_localpart_allowed_characters for c in localpart)
+def strip_invalid_mxid_characters(localpart):
+ """Removes any invalid characters from an mxid
+
+ Args:
+ localpart (basestring): the localpart to be stripped
+
+ Returns:
+ localpart (basestring): the localpart having been stripped
+ """
+ filtered = filter(lambda c: c in mxid_localpart_allowed_characters, localpart)
+ return "".join(filtered)
+
+
UPPER_CASE_PATTERN = re.compile(b"[A-Z_]")
# the following is a pattern which matches '=', and bytes which are not allowed in a mxid
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index 60f0de70f7..c63256d3bd 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -55,7 +55,7 @@ class Clock(object):
return self._reactor.seconds()
def time_msec(self):
- """Returns the current system time in miliseconds since epoch."""
+ """Returns the current system time in milliseconds since epoch."""
return int(self.time() * 1000)
def looping_call(self, f, msec, *args, **kwargs):
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 65abf0846e..f562770922 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -352,7 +352,7 @@ class ReadWriteLock(object):
# resolved when they release the lock).
#
# Read: We know its safe to acquire a read lock when the latest writer has
- # been resolved. The new reader is appeneded to the list of latest readers.
+ # been resolved. The new reader is appended to the list of latest readers.
#
# Write: We know its safe to acquire the write lock when both the latest
# writers and readers have been resolved. The new writer replaces the latest
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 64f35fc288..9b09c08b89 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -516,7 +516,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
"""
Args:
orig (function)
- cached_method_name (str): The name of the chached method.
+ cached_method_name (str): The name of the cached method.
list_name (str): Name of the argument which is the bulk lookup list
num_args (int): number of positional arguments (excluding ``self``,
but including list_name) to use as cache keys. Defaults to all
diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py
index 45af8d3eeb..da20523b70 100644
--- a/synapse/util/distributor.py
+++ b/synapse/util/distributor.py
@@ -39,7 +39,7 @@ class Distributor(object):
Signals are named simply by strings.
TODO(paul): It would be nice to give signals stronger object identities,
- so we can attach metadata, docstrings, detect typoes, etc... But this
+ so we can attach metadata, docstrings, detect typos, etc... But this
model will do for today.
"""
diff --git a/synapse/util/patch_inline_callbacks.py b/synapse/util/patch_inline_callbacks.py
index 2605f3c65b..54c046b6e1 100644
--- a/synapse/util/patch_inline_callbacks.py
+++ b/synapse/util/patch_inline_callbacks.py
@@ -192,7 +192,7 @@ def _check_yield_points(f: Callable, changes: List[str]):
result = yield d
except Exception:
# this will fish an earlier Failure out of the stack where possible, and
- # thus is preferable to passing in an exeception to the Failure
+ # thus is preferable to passing in an exception to the Failure
# constructor, since it results in less stack-mangling.
result = Failure()
diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py
index af69587196..8794317caa 100644
--- a/synapse/util/retryutils.py
+++ b/synapse/util/retryutils.py
@@ -22,7 +22,7 @@ from synapse.api.errors import CodeMessageException
logger = logging.getLogger(__name__)
-# the intial backoff, after the first transaction fails
+# the initial backoff, after the first transaction fails
MIN_RETRY_INTERVAL = 10 * 60 * 1000
# how much we multiply the backoff by after each subsequent fail
@@ -174,7 +174,7 @@ class RetryDestinationLimiter(object):
# has been decommissioned.
# If we get a 401, then we should probably back off since they
# won't accept our requests for at least a while.
- # 429 is us being aggresively rate limited, so lets rate limit
+ # 429 is us being aggressively rate limited, so lets rate limit
# ourselves.
if exc_val.code == 404 and self.backoff_on_404:
valid_err_code = False
diff --git a/synapse/util/threepids.py b/synapse/util/threepids.py
index 3ec1dfb0c2..cfdaa1c5d9 100644
--- a/synapse/util/threepids.py
+++ b/synapse/util/threepids.py
@@ -16,11 +16,14 @@
import logging
import re
+from twisted.internet import defer
+
logger = logging.getLogger(__name__)
+@defer.inlineCallbacks
def check_3pid_allowed(hs, medium, address):
- """Checks whether a given format of 3PID is allowed to be used on this HS
+ """Checks whether a given 3PID is allowed to be used on this HS
Args:
hs (synapse.server.HomeServer): server
@@ -28,9 +31,36 @@ def check_3pid_allowed(hs, medium, address):
address (str): address within that medium (e.g. "wotan@matrix.org")
msisdns need to first have been canonicalised
Returns:
- bool: whether the 3PID medium/address is allowed to be added to this HS
+ defered bool: whether the 3PID medium/address is allowed to be added to this HS
"""
+ if hs.config.check_is_for_allowed_local_3pids:
+ data = yield hs.get_simple_http_client().get_json(
+ "https://%s%s"
+ % (
+ hs.config.check_is_for_allowed_local_3pids,
+ "/_matrix/identity/api/v1/internal-info",
+ ),
+ {"medium": medium, "address": address},
+ )
+
+ # Check for invalid response
+ if "hs" not in data and "shadow_hs" not in data:
+ defer.returnValue(False)
+
+ # Check if this user is intended to register for this homeserver
+ if (
+ data.get("hs") != hs.config.server_name
+ and data.get("shadow_hs") != hs.config.server_name
+ ):
+ defer.returnValue(False)
+
+ if data.get("requires_invite", False) and not data.get("invited", False):
+ # Requires an invite but hasn't been invited
+ defer.returnValue(False)
+
+ defer.returnValue(True)
+
if hs.config.allowed_local_3pids:
for constraint in hs.config.allowed_local_3pids:
logger.debug(
@@ -43,8 +73,31 @@ def check_3pid_allowed(hs, medium, address):
if medium == constraint["medium"] and re.match(
constraint["pattern"], address
):
- return True
+ defer.returnValue(True)
else:
- return True
+ defer.returnValue(True)
+
+ defer.returnValue(False)
+
+
+def canonicalise_email(address: str) -> str:
+ """'Canonicalise' email address
+ Case folding of local part of email address and lowercase domain part
+ See MSC2265, https://github.com/matrix-org/matrix-doc/pull/2265
+
+ Args:
+ address: email address to be canonicalised
+ Returns:
+ The canonical form of the email address
+ Raises:
+ ValueError if the address could not be parsed.
+ """
+
+ address = address.strip()
+
+ parts = address.split("@")
+ if len(parts) != 2:
+ logger.debug("Couldn't parse email address %s", address)
+ raise ValueError("Unable to parse email address")
- return False
+ return parts[0].casefold() + "@" + parts[1].lower()
diff --git a/synapse/visibility.py b/synapse/visibility.py
index 3dfd4af26c..0f042c5696 100644
--- a/synapse/visibility.py
+++ b/synapse/visibility.py
@@ -319,7 +319,7 @@ def filter_events_for_server(
return True
# Lets check to see if all the events have a history visibility
- # of "shared" or "world_readable". If thats the case then we don't
+ # of "shared" or "world_readable". If that's the case then we don't
# need to check membership (as we know the server is in the room).
event_to_state_ids = yield storage.state.get_state_ids_for_events(
frozenset(e.event_id for e in events),
@@ -335,7 +335,7 @@ def filter_events_for_server(
visibility_ids.add(hist)
# If we failed to find any history visibility events then the default
- # is "shared" visiblity.
+ # is "shared" visibility.
if not visibility_ids:
all_open = True
else:
|