diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index 616942b057..11da016ac5 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -64,6 +64,13 @@ class Codes(object):
INCOMPATIBLE_ROOM_VERSION = "M_INCOMPATIBLE_ROOM_VERSION"
WRONG_ROOM_KEYS_VERSION = "M_WRONG_ROOM_KEYS_VERSION"
EXPIRED_ACCOUNT = "ORG_MATRIX_EXPIRED_ACCOUNT"
+ PASSWORD_TOO_SHORT = "M_PASSWORD_TOO_SHORT"
+ PASSWORD_NO_DIGIT = "M_PASSWORD_NO_DIGIT"
+ PASSWORD_NO_UPPERCASE = "M_PASSWORD_NO_UPPERCASE"
+ PASSWORD_NO_LOWERCASE = "M_PASSWORD_NO_LOWERCASE"
+ PASSWORD_NO_SYMBOL = "M_PASSWORD_NO_SYMBOL"
+ PASSWORD_IN_DICTIONARY = "M_PASSWORD_IN_DICTIONARY"
+ WEAK_PASSWORD = "M_WEAK_PASSWORD"
INVALID_SIGNATURE = "M_INVALID_SIGNATURE"
USER_DEACTIVATED = "M_USER_DEACTIVATED"
BAD_ALIAS = "M_BAD_ALIAS"
@@ -439,6 +446,20 @@ class IncompatibleRoomVersionError(SynapseError):
return cs_error(self.msg, self.errcode, room_version=self._room_version)
+class PasswordRefusedError(SynapseError):
+ """A password has been refused, either during password reset/change or registration.
+ """
+
+ def __init__(
+ self,
+ msg="This password doesn't comply with the server's policy",
+ errcode=Codes.WEAK_PASSWORD,
+ ):
+ super(PasswordRefusedError, self).__init__(
+ code=400, msg=msg, errcode=errcode,
+ )
+
+
class RequestSendFailed(RuntimeError):
"""Sending a HTTP request over federation failed due to not being able to
talk to the remote server for some reason.
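A rough sketch of how the new error class surfaces to clients, assuming the standard SynapseError/cs_error serialization used by the other error classes in this file:

    from synapse.api.errors import Codes, PasswordRefusedError

    try:
        raise PasswordRefusedError(
            msg="The password must be at least 15 characters long",
            errcode=Codes.PASSWORD_TOO_SHORT,
        )
    except PasswordRefusedError as e:
        # error_dict() builds the JSON error body returned to the client:
        #   {"errcode": "M_PASSWORD_TOO_SHORT",
        #    "error": "The password must be at least 15 characters long"}
        print(e.code, e.error_dict())  # 400, plus the dict above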
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index c8fd8909a4..fba7ad9551 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -401,6 +401,9 @@ class GenericWorkerTyping(object):
self._room_serials[row.room_id] = token
self._room_typing[row.room_id] = row.user_ids
+ def get_current_token(self) -> int:
+ return self._latest_room_serial
+
class GenericWorkerSlavedStore(
# FIXME(#3714): We need to add UserDirectoryStore as we write directly
@@ -875,6 +878,9 @@ def start(config_options):
# Force the appservice to start since they will be disabled in the main config
config.notify_appservices = True
+ else:
+ # For other worker types we force this to off.
+ config.notify_appservices = False
if config.worker_app == "synapse.app.pusher":
if config.start_pushers:
@@ -888,6 +894,9 @@ def start(config_options):
# Force the pushers to start since they will be disabled in the main config
config.start_pushers = True
+ else:
+ # For other worker types we force this to off.
+ config.start_pushers = False
if config.worker_app == "synapse.app.user_dir":
if config.update_user_directory:
@@ -901,6 +910,9 @@ def start(config_options):
# Force the user directory to start since it will be disabled in the main config
config.update_user_directory = True
+ else:
+ # For other worker types we force this to off.
+ config.update_user_directory = False
if config.worker_app == "synapse.app.federation_sender":
if config.send_federation:
@@ -914,6 +926,9 @@ def start(config_options):
# Force the federation sender to start since it will be disabled in the main config
config.send_federation = True
+ else:
+ # For other worker types we force this to off.
+ config.send_federation = False
synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts
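All four of the `else` branches above follow the same force-on/force-off shape. A condensed sketch of the pattern (the flag names are the real config attributes; the helper itself is illustrative, and the real code additionally raises if the shared config already enabled the flag):

    WORKER_OWNED_FLAGS = {
        "synapse.app.appservice": "notify_appservices",
        "synapse.app.pusher": "start_pushers",
        "synapse.app.user_dir": "update_user_directory",
        "synapse.app.federation_sender": "send_federation",
    }

    def apply_worker_overrides(config):
        # Enable each duty only on the worker type that owns it, and force it
        # off everywhere else, so exactly one process performs it.
        for app, flag in WORKER_OWNED_FLAGS.items():
            setattr(config, flag, config.worker_app == app)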
diff --git a/synapse/config/captcha.py b/synapse/config/captcha.py
index f0171bb5b2..56c87fa296 100644
--- a/synapse/config/captcha.py
+++ b/synapse/config/captcha.py
@@ -24,7 +24,6 @@ class CaptchaConfig(Config):
self.enable_registration_captcha = config.get(
"enable_registration_captcha", False
)
- self.captcha_bypass_secret = config.get("captcha_bypass_secret")
self.recaptcha_siteverify_api = config.get(
"recaptcha_siteverify_api",
"https://www.recaptcha.net/recaptcha/api/siteverify",
@@ -49,10 +48,6 @@ class CaptchaConfig(Config):
#
#enable_registration_captcha: false
- # A secret key used to bypass the captcha test entirely.
- #
- #captcha_bypass_secret: "YOUR_SECRET_HERE"
-
# The API endpoint to use for verifying m.login.recaptcha responses.
#
#recaptcha_siteverify_api: "https://www.recaptcha.net/recaptcha/api/siteverify"
diff --git a/synapse/config/database.py b/synapse/config/database.py
index b8ab2f86ac..c27fef157b 100644
--- a/synapse/config/database.py
+++ b/synapse/config/database.py
@@ -20,6 +20,11 @@ from synapse.config._base import Config, ConfigError
logger = logging.getLogger(__name__)
+NON_SQLITE_DATABASE_PATH_WARNING = """\
+Ignoring 'database_path' setting: not using a sqlite3 database.
+--------------------------------------------------------------------------------
+"""
+
DEFAULT_CONFIG = """\
## Database ##
@@ -105,6 +110,11 @@ class DatabaseConnectionConfig:
class DatabaseConfig(Config):
section = "database"
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ self.databases = []
+
def read_config(self, config, **kwargs):
self.event_cache_size = self.parse_size(config.get("event_cache_size", "10K"))
@@ -125,12 +135,13 @@ class DatabaseConfig(Config):
multi_database_config = config.get("databases")
database_config = config.get("database")
+ database_path = config.get("database_path")
if multi_database_config and database_config:
raise ConfigError("Can't specify both 'database' and 'datbases' in config")
if multi_database_config:
- if config.get("database_path"):
+ if database_path:
raise ConfigError("Can't specify 'database_path' with 'databases'")
self.databases = [
@@ -138,13 +149,17 @@ class DatabaseConfig(Config):
for name, db_conf in multi_database_config.items()
]
- else:
- if database_config is None:
- database_config = {"name": "sqlite3", "args": {}}
-
+ if database_config:
self.databases = [DatabaseConnectionConfig("master", database_config)]
- self.set_databasepath(config.get("database_path"))
+ if database_path:
+ if self.databases and self.databases[0].name != "sqlite3":
+ logger.warning(NON_SQLITE_DATABASE_PATH_WARNING)
+ return
+
+ database_config = {"name": "sqlite3", "args": {}}
+ self.databases = [DatabaseConnectionConfig("master", database_config)]
+ self.set_databasepath(database_path)
def generate_config_section(self, data_dir_path, **kwargs):
return DEFAULT_CONFIG % {
@@ -152,27 +167,37 @@ class DatabaseConfig(Config):
}
def read_arguments(self, args):
- self.set_databasepath(args.database_path)
+ """
+ Cases for the cli input:
+ - If no databases are configured and no database_path is set, raise.
+ - No databases and only database_path available ==> sqlite3 db.
+ - If there are multiple databases and a database_path, raise an error.
+ - If the database set in the config file is sqlite then overwrite its
+ path with the command line argument.
+ """
- def set_databasepath(self, database_path):
- if database_path is None:
+ if args.database_path is None:
+ if not self.databases:
+ raise ConfigError("No database config provided")
return
- if database_path != ":memory:":
- database_path = self.abspath(database_path)
+ if len(self.databases) == 0:
+ database_config = {"name": "sqlite3", "args": {}}
+ self.databases = [DatabaseConnectionConfig("master", database_config)]
+ self.set_databasepath(args.database_path)
+ return
+
+ if self.get_single_database().name == "sqlite3":
+ self.set_databasepath(args.database_path)
+ else:
+ logger.warning(NON_SQLITE_DATABASE_PATH_WARNING)
- # We only support setting a database path if we have a single sqlite3
- # database.
- if len(self.databases) != 1:
- raise ConfigError("Cannot specify 'database_path' with multiple databases")
+ def set_databasepath(self, database_path):
- database = self.get_single_database()
- if database.config["name"] != "sqlite3":
- # We don't raise here as we haven't done so before for this case.
- logger.warn("Ignoring 'database_path' for non-sqlite3 database")
- return
+ if database_path != ":memory:":
+ database_path = self.abspath(database_path)
- database.config["args"]["database"] = database_path
+ self.databases[0].config["args"]["database"] = database_path
@staticmethod
def add_arguments(parser):
@@ -187,7 +212,7 @@ class DatabaseConfig(Config):
def get_single_database(self) -> DatabaseConnectionConfig:
"""Returns the database if there is only one, useful for e.g. tests
"""
- if len(self.databases) != 1:
+ if not self.databases:
raise Exception("More than one database exists")
return self.databases[0]
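A condensed sketch of the precedence the reworked config now implements, conflating the config-file (`read_config`) and command-line (`read_arguments`) stages into one hypothetical helper:

    def resolve_databases(config: dict, cli_database_path: str = None) -> list:
        databases = config.get("databases")
        database = config.get("database")
        path = config.get("database_path") or cli_database_path

        if databases:  # explicit multi-database config wins outright
            if config.get("database_path"):
                raise ValueError("Can't specify 'database_path' with 'databases'")
            return [(name, db) for name, db in databases.items()]

        if database:  # single explicit database
            if path:
                if database["name"] != "sqlite3":
                    # path is ignored, with NON_SQLITE_DATABASE_PATH_WARNING
                    pass
                else:
                    database.setdefault("args", {})["database"] = path
            return [("master", database)]

        if path:  # a bare database_path implies sqlite3
            return [("master", {"name": "sqlite3", "args": {"database": path}})]

        raise ValueError("No database config provided")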
diff --git a/synapse/config/password.py b/synapse/config/password.py
index 2a634ac751..9c0ea8c30a 100644
--- a/synapse/config/password.py
+++ b/synapse/config/password.py
@@ -31,6 +31,10 @@ class PasswordConfig(Config):
self.password_localdb_enabled = password_config.get("localdb_enabled", True)
self.password_pepper = password_config.get("pepper", "")
+ # Password policy
+ self.password_policy = password_config.get("policy") or {}
+ self.password_policy_enabled = self.password_policy.get("enabled", False)
+
def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """\
password_config:
@@ -48,4 +52,39 @@ class PasswordConfig(Config):
# DO NOT CHANGE THIS AFTER INITIAL SETUP!
#
#pepper: "EVEN_MORE_SECRET"
+
+ # Define and enforce a password policy. Each parameter is optional.
+ # This is an implementation of MSC2000.
+ #
+ policy:
+ # Whether to enforce the password policy.
+ # Defaults to 'false'.
+ #
+ #enabled: true
+
+ # Minimum accepted length for a password.
+ # Defaults to 0.
+ #
+ #minimum_length: 15
+
+ # Whether a password must contain at least one digit.
+ # Defaults to 'false'.
+ #
+ #require_digit: true
+
+ # Whether a password must contain at least one symbol.
+ # A symbol is any character that's not a number or a letter.
+ # Defaults to 'false'.
+ #
+ #require_symbol: true
+
+ # Whether a password must contain at least one lowercase letter.
+ # Defaults to 'false'.
+ #
+ #require_lowercase: true
+
+ # Whether a password must contain at least one uppercase letter.
+ # Defaults to 'false'.
+ #
+ #require_uppercase: true
"""
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index 9bb3beedbc..e7ea3a01cb 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -129,6 +129,10 @@ class RegistrationConfig(Config):
raise ConfigError("Invalid auto_join_rooms entry %s" % (room_alias,))
self.autocreate_auto_join_rooms = config.get("autocreate_auto_join_rooms", True)
+ self.enable_set_displayname = config.get("enable_set_displayname", True)
+ self.enable_set_avatar_url = config.get("enable_set_avatar_url", True)
+ self.enable_3pid_changes = config.get("enable_3pid_changes", True)
+
self.disable_msisdn_registration = config.get(
"disable_msisdn_registration", False
)
@@ -330,6 +334,29 @@ class RegistrationConfig(Config):
#email: https://example.com # Delegate email sending to example.com
#msisdn: http://localhost:8090 # Delegate SMS sending to this local process
+ # Whether users are allowed to change their displayname after it has
+ # been initially set. Useful when provisioning users based on the
+ # contents of a third-party directory.
+ #
+ # Does not apply to server administrators. Defaults to 'true'
+ #
+ #enable_set_displayname: false
+
+ # Whether users are allowed to change their avatar after it has been
+ # initially set. Useful when provisioning users based on the contents
+ # of a third-party directory.
+ #
+ # Does not apply to server administrators. Defaults to 'true'
+ #
+ #enable_set_avatar_url: false
+
+ # Whether users can change the 3PIDs associated with their accounts
+ # (email address and msisdn).
+ #
+ # Defaults to 'true'
+ #
+ #enable_3pid_changes: false
+
# Users who register on this homeserver will automatically be joined
# to these rooms
#
diff --git a/synapse/config/sso.py b/synapse/config/sso.py
index 95762689bc..ec3dca9efc 100644
--- a/synapse/config/sso.py
+++ b/synapse/config/sso.py
@@ -39,6 +39,17 @@ class SSOConfig(Config):
self.sso_client_whitelist = sso_config.get("client_whitelist") or []
+ # Attempt to also whitelist the server's login fallback, since that fallback sets
+ # the redirect URL to itself (so it can process the login token then return
+ # gracefully to the client). This would make it pointless to ask the user for
+ # confirmation, since the URL the confirmation page would be showing wouldn't be
+ # the client's.
+ # public_baseurl is an optional setting, so we only add the fallback's URL to the
+ # list if it's provided (because we can't figure out what that URL is otherwise).
+ if self.public_baseurl:
+ login_fallback_url = self.public_baseurl + "_matrix/static/client/login"
+ self.sso_client_whitelist.append(login_fallback_url)
+
def generate_config_section(self, **kwargs):
return """\
# Additional settings to use with single-sign on systems such as SAML2 and CAS.
@@ -54,6 +65,10 @@ class SSOConfig(Config):
# phishing attacks from evil.site. To avoid this, include a slash after the
# hostname: "https://my.client/".
#
+ # If public_baseurl is set, then the login fallback page (used by clients
+ # that don't natively support the required login flows) is whitelisted in
+ # addition to any URLs in this list.
+ #
# By default, this list is empty.
#
#client_whitelist:
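A worked example of the auto-whitelisting above (the URLs are illustrative):

    sso_config = {"client_whitelist": ["https://riot.example.com/"]}
    public_baseurl = "https://matrix.example.com/"

    client_whitelist = sso_config.get("client_whitelist") or []
    if public_baseurl:
        # The fallback redirects to itself to process the login token, so
        # asking the user to confirm its own URL would be pointless.
        client_whitelist.append(public_baseurl + "_matrix/static/client/login")

    # client_whitelist == [
    #     "https://riot.example.com/",
    #     "https://matrix.example.com/_matrix/static/client/login",
    # ]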
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 983f0ead8c..a9f4025bfe 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -43,8 +43,8 @@ from synapse.api.errors import (
SynapseError,
)
from synapse.logging.context import (
- LoggingContext,
PreserveLoggingContext,
+ current_context,
make_deferred_yieldable,
preserve_fn,
run_in_background,
@@ -236,7 +236,7 @@ class Keyring(object):
"""
try:
- ctx = LoggingContext.current_context()
+ ctx = current_context()
# map from server name to a set of outstanding request ids
server_to_request_ids = {}
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index b0b0eba41e..4b115aac04 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -32,8 +32,8 @@ from synapse.events import EventBase, make_event_from_dict
from synapse.events.utils import prune_event
from synapse.http.servlet import assert_params_in_dict
from synapse.logging.context import (
- LoggingContext,
PreserveLoggingContext,
+ current_context,
make_deferred_yieldable,
)
from synapse.types import JsonDict, get_domain_from_id
@@ -78,7 +78,7 @@ class FederationBase(object):
"""
deferreds = _check_sigs_on_pdus(self.keyring, room_version, pdus)
- ctx = LoggingContext.current_context()
+ ctx = current_context()
def callback(_, pdu: EventBase):
with PreserveLoggingContext(ctx):
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index 233cb33daf..a477578e44 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -499,4 +499,13 @@ class FederationSender(object):
self._get_per_destination_queue(destination).attempt_new_transaction()
def get_current_token(self) -> int:
+ # Dummy implementation for case where federation sender isn't offloaded
+ # to a worker.
return 0
+
+ async def get_replication_rows(
+ self, from_token, to_token, limit, federation_ack=None
+ ):
+ # Dummy implementation for case where federation sender isn't offloaded
+ # to a worker.
+ return []
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 7860f9625e..2ce1425dfa 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -125,7 +125,11 @@ class AuthHandler(BaseHandler):
@defer.inlineCallbacks
def validate_user_via_ui_auth(
- self, requester: Requester, request_body: Dict[str, Any], clientip: str
+ self,
+ requester: Requester,
+ request: SynapseRequest,
+ request_body: Dict[str, Any],
+ clientip: str,
):
"""
Checks that the user is who they claim to be, via a UI auth.
@@ -137,6 +141,8 @@ class AuthHandler(BaseHandler):
Args:
requester: The user, as given by the access token
+ request: The request sent by the client.
+
request_body: The body of the request sent by the client
clientip: The IP address of the client.
@@ -172,7 +178,9 @@ class AuthHandler(BaseHandler):
flows = [[login_type] for login_type in self._supported_login_types]
try:
- result, params, _ = yield self.check_auth(flows, request_body, clientip)
+ result, params, _ = yield self.check_auth(
+ flows, request, request_body, clientip
+ )
except LoginError:
# Update the ratelimiter to say we failed (`can_do_action` doesn't raise).
self._failed_uia_attempts_ratelimiter.can_do_action(
@@ -211,7 +219,11 @@ class AuthHandler(BaseHandler):
@defer.inlineCallbacks
def check_auth(
- self, flows: List[List[str]], clientdict: Dict[str, Any], clientip: str
+ self,
+ flows: List[List[str]],
+ request: SynapseRequest,
+ clientdict: Dict[str, Any],
+ clientip: str,
):
"""
Takes a dictionary sent by the client in the login / registration
@@ -231,6 +243,8 @@ class AuthHandler(BaseHandler):
strings representing auth-types. At least one full
flow must be completed in order for auth to be successful.
+ request: The request sent by the client.
+
clientdict: The dictionary from the client root level, not the
'auth' key: this method prompts for auth if none is sent.
@@ -270,13 +284,27 @@ class AuthHandler(BaseHandler):
# email auth link on there). It's probably too open to abuse
# because it lets unauthenticated clients store arbitrary objects
# on a homeserver.
- # Revisit: Assumimg the REST APIs do sensible validation, the data
+ # Revisit: Assuming the REST APIs do sensible validation, the data
# isn't arbitrary.
session["clientdict"] = clientdict
self._save_session(session)
elif "clientdict" in session:
clientdict = session["clientdict"]
+ # Ensure that the queried operation does not vary between stages of
+ # the UI authentication session. This is done by generating a stable
+ # comparator based on the URI, method, and body (minus the auth dict)
+ # and storing it during the initial query. Subsequent queries ensure
+ # that this comparator has not changed.
+ comparator = (request.uri, request.method, clientdict)
+ if "ui_auth" not in session:
+ session["ui_auth"] = comparator
+ elif session["ui_auth"] != comparator:
+ raise SynapseError(
+ 403,
+ "Requested operation has changed during the UI authentication session.",
+ )
+
if not authdict:
raise InteractiveAuthIncompleteError(
self._auth_dict_for_flows(flows, session)
@@ -322,6 +350,7 @@ class AuthHandler(BaseHandler):
creds,
list(clientdict),
)
+
return creds, clientdict, session["id"]
ret = self._auth_dict_for_flows(flows, session)
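To make the comparator's purpose concrete, here is a toy reproduction of the check (assuming, as in `check_auth`, that the `auth` key has already been stripped out of `clientdict`):

    def ui_auth_stage(session: dict, uri: str, method: str, clientdict: dict):
        comparator = (uri, method, clientdict)
        if "ui_auth" not in session:
            session["ui_auth"] = comparator  # first stage: bind the operation
        elif session["ui_auth"] != comparator:
            raise PermissionError("operation changed during UI auth session")

    session = {}
    ui_auth_stage(session, "/delete_devices", "POST", {"devices": ["A"]})
    # Completing the same session against a different body now fails, so a
    # half-finished auth cannot be replayed against another operation:
    ui_auth_stage(session, "/delete_devices", "POST", {"devices": ["B"]})
    # -> PermissionError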
diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py
new file mode 100644
index 0000000000..f8dc274b78
--- /dev/null
+++ b/synapse/handlers/cas_handler.py
@@ -0,0 +1,204 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import xml.etree.ElementTree as ET
+from typing import AnyStr, Dict, Optional, Tuple
+
+from six.moves import urllib
+
+from twisted.web.client import PartialDownloadError
+
+from synapse.api.errors import Codes, LoginError
+from synapse.http.site import SynapseRequest
+from synapse.types import UserID, map_username_to_mxid_localpart
+
+logger = logging.getLogger(__name__)
+
+
+class CasHandler:
+ """
+ Utility class for handling the response from a CAS SSO service.
+
+ Args:
+ hs (synapse.server.HomeServer)
+ """
+
+ def __init__(self, hs):
+ self._hostname = hs.hostname
+ self._auth_handler = hs.get_auth_handler()
+ self._registration_handler = hs.get_registration_handler()
+
+ self._cas_server_url = hs.config.cas_server_url
+ self._cas_service_url = hs.config.cas_service_url
+ self._cas_displayname_attribute = hs.config.cas_displayname_attribute
+ self._cas_required_attributes = hs.config.cas_required_attributes
+
+ self._http_client = hs.get_proxied_http_client()
+
+ def _build_service_param(self, client_redirect_url: AnyStr) -> str:
+ return "%s%s?%s" % (
+ self._cas_service_url,
+ "/_matrix/client/r0/login/cas/ticket",
+ urllib.parse.urlencode({"redirectUrl": client_redirect_url}),
+ )
+
+ async def _handle_cas_response(
+ self, request: SynapseRequest, cas_response_body: str, client_redirect_url: str
+ ) -> None:
+ """
+ Retrieves the user and display name from the CAS response and continues with the authentication.
+
+ Args:
+ request: The original client request.
+ cas_response_body: The response from the CAS server.
+ client_redirect_url: The URL to redirect the client to when
+ everything is done.
+ """
+ user, attributes = self._parse_cas_response(cas_response_body)
+ displayname = attributes.pop(self._cas_displayname_attribute, None)
+
+ for required_attribute, required_value in self._cas_required_attributes.items():
+ # If a required attribute is missing from the CAS response, reject the login
+ if required_attribute not in attributes:
+ raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
+
+ # Also need to check value
+ if required_value is not None:
+ actual_value = attributes[required_attribute]
+ # If the required attribute's value does not match the expected one, reject
+ if required_value != actual_value:
+ raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
+
+ await self._on_successful_auth(user, request, client_redirect_url, displayname)
+
+ def _parse_cas_response(
+ self, cas_response_body: str
+ ) -> Tuple[str, Dict[str, Optional[str]]]:
+ """
+ Retrieve the user and other parameters from the CAS response.
+
+ Args:
+ cas_response_body: The response from the CAS query.
+
+ Returns:
+ A tuple of the user and a mapping of other attributes.
+ """
+ user = None
+ attributes = {}
+ try:
+ root = ET.fromstring(cas_response_body)
+ if not root.tag.endswith("serviceResponse"):
+ raise Exception("root of CAS response is not serviceResponse")
+ success = root[0].tag.endswith("authenticationSuccess")
+ for child in root[0]:
+ if child.tag.endswith("user"):
+ user = child.text
+ if child.tag.endswith("attributes"):
+ for attribute in child:
+ # ElementTree library expands the namespace in
+ # attribute tags to the full URL of the namespace.
+ # We don't care about namespace here and it will always
+ # be encased in curly braces, so we remove them.
+ tag = attribute.tag
+ if "}" in tag:
+ tag = tag.split("}")[1]
+ attributes[tag] = attribute.text
+ if user is None:
+ raise Exception("CAS response does not contain user")
+ except Exception:
+ logger.exception("Error parsing CAS response")
+ raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
+ if not success:
+ raise LoginError(
+ 401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED
+ )
+ return user, attributes
+
+ async def _on_successful_auth(
+ self,
+ username: str,
+ request: SynapseRequest,
+ client_redirect_url: str,
+ user_display_name: Optional[str] = None,
+ ) -> None:
+ """Called once the user has successfully authenticated with the SSO.
+
+ Registers the user if necessary, and then returns a redirect (with
+ a login token) to the client.
+
+ Args:
+ username: the remote user id. We'll map this onto
+ something sane for an MXID localpart.
+
+ request: the incoming request from the browser. We'll
+ respond to it with a redirect.
+
+ client_redirect_url: the redirect_url the client gave us when
+ it first started the process.
+
+ user_display_name: if set, and we have to register a new user,
+ we will set their displayname to this.
+ """
+ localpart = map_username_to_mxid_localpart(username)
+ user_id = UserID(localpart, self._hostname).to_string()
+ registered_user_id = await self._auth_handler.check_user_exists(user_id)
+ if not registered_user_id:
+ registered_user_id = await self._registration_handler.register_user(
+ localpart=localpart, default_display_name=user_display_name
+ )
+
+ self._auth_handler.complete_sso_login(
+ registered_user_id, request, client_redirect_url
+ )
+
+ def handle_redirect_request(self, client_redirect_url: bytes) -> bytes:
+ """
+ Generates a URL to the CAS server where the client should be redirected.
+
+ Args:
+ client_redirect_url: The final URL the client should go to after the
+ user has negotiated SSO.
+
+ Returns:
+ The URL to redirect to.
+ """
+ args = urllib.parse.urlencode(
+ {"service": self._build_service_param(client_redirect_url)}
+ )
+
+ return ("%s/login?%s" % (self._cas_server_url, args)).encode("ascii")
+
+ async def handle_ticket_request(
+ self, request: SynapseRequest, client_redirect_url: str, ticket: str
+ ) -> None:
+ """
+ Validates a CAS ticket sent by the client for login/registration.
+
+ On a successful request, writes a redirect to the request.
+ """
+ uri = self._cas_server_url + "/proxyValidate"
+ args = {
+ "ticket": ticket,
+ "service": self._build_service_param(client_redirect_url),
+ }
+ try:
+ body = await self._http_client.get_raw(uri, args)
+ except PartialDownloadError as pde:
+ # Twisted raises this error if the connection is closed,
+ # even if that's being used old-http style to signal end-of-data
+ body = pde.response
+
+ await self._handle_cas_response(request, body, client_redirect_url)
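For reference, `_parse_cas_response` accepts a CAS success response of roughly this shape (the namespace and attribute names are illustrative):

    cas_response_body = """\
    <cas:serviceResponse xmlns:cas="http://www.yale.edu/tp/cas">
      <cas:authenticationSuccess>
        <cas:user>alice</cas:user>
        <cas:attributes>
          <cas:displayName>Alice Margatroid</cas:displayName>
        </cas:attributes>
      </cas:authenticationSuccess>
    </cas:serviceResponse>
    """

    # Given a CasHandler instance `handler`:
    #   handler._parse_cas_response(cas_response_body)
    #   -> ("alice", {"displayName": "Alice Margatroid"})
    # The namespace is stripped from attribute tags, and a response whose
    # first child is not authenticationSuccess raises LoginError(401).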
diff --git a/synapse/handlers/password_policy.py b/synapse/handlers/password_policy.py
new file mode 100644
index 0000000000..d06b110269
--- /dev/null
+++ b/synapse/handlers/password_policy.py
@@ -0,0 +1,93 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import re
+
+from synapse.api.errors import Codes, PasswordRefusedError
+
+logger = logging.getLogger(__name__)
+
+
+class PasswordPolicyHandler(object):
+ def __init__(self, hs):
+ self.policy = hs.config.password_policy
+ self.enabled = hs.config.password_policy_enabled
+
+ # Regexps for the spec'd policy parameters.
+ self.regexp_digit = re.compile("[0-9]")
+ self.regexp_symbol = re.compile("[^a-zA-Z0-9]")
+ self.regexp_uppercase = re.compile("[A-Z]")
+ self.regexp_lowercase = re.compile("[a-z]")
+
+ def validate_password(self, password):
+ """Checks whether a given password complies with the server's policy.
+
+ Args:
+ password (str): The password to check against the server's policy.
+
+ Raises:
+ PasswordRefusedError: The password doesn't comply with the server's policy.
+ """
+
+ if not self.enabled:
+ return
+
+ minimum_accepted_length = self.policy.get("minimum_length", 0)
+ if len(password) < minimum_accepted_length:
+ raise PasswordRefusedError(
+ msg=(
+ "The password must be at least %d characters long"
+ % minimum_accepted_length
+ ),
+ errcode=Codes.PASSWORD_TOO_SHORT,
+ )
+
+ if (
+ self.policy.get("require_digit", False)
+ and self.regexp_digit.search(password) is None
+ ):
+ raise PasswordRefusedError(
+ msg="The password must include at least one digit",
+ errcode=Codes.PASSWORD_NO_DIGIT,
+ )
+
+ if (
+ self.policy.get("require_symbol", False)
+ and self.regexp_symbol.search(password) is None
+ ):
+ raise PasswordRefusedError(
+ msg="The password must include at least one symbol",
+ errcode=Codes.PASSWORD_NO_SYMBOL,
+ )
+
+ if (
+ self.policy.get("require_uppercase", False)
+ and self.regexp_uppercase.search(password) is None
+ ):
+ raise PasswordRefusedError(
+ msg="The password must include at least one uppercase letter",
+ errcode=Codes.PASSWORD_NO_UPPERCASE,
+ )
+
+ if (
+ self.policy.get("require_lowercase", False)
+ and self.regexp_lowercase.search(password) is None
+ ):
+ raise PasswordRefusedError(
+ msg="The password must include at least one lowercase letter",
+ errcode=Codes.PASSWORD_NO_LOWERCASE,
+ )
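A minimal usage sketch, with a stubbed homeserver standing in for the real `hs` (the config attributes mirror those added to synapse/config/password.py above):

    from synapse.handlers.password_policy import PasswordPolicyHandler

    class FakeConfig:
        password_policy = {"enabled": True, "minimum_length": 10, "require_digit": True}
        password_policy_enabled = True

    class FakeHS:
        config = FakeConfig()

    handler = PasswordPolicyHandler(FakeHS())
    handler.validate_password("correct horse 9")  # ok: 15 chars, has a digit
    handler.validate_password("short")  # raises PasswordRefusedError with
                                        # errcode M_PASSWORD_TOO_SHORT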
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 50ce0c585b..6aa1c0f5e0 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -157,6 +157,15 @@ class BaseProfileHandler(BaseHandler):
if not by_admin and target_user != requester.user:
raise AuthError(400, "Cannot set another user's displayname")
+ if not by_admin and not self.hs.config.enable_set_displayname:
+ profile = yield self.store.get_profileinfo(target_user.localpart)
+ if profile.display_name:
+ raise SynapseError(
+ 400,
+ "Changing display name is disabled on this server",
+ Codes.FORBIDDEN,
+ )
+
if len(new_displayname) > MAX_DISPLAYNAME_LEN:
raise SynapseError(
400, "Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN,)
@@ -218,6 +227,13 @@ class BaseProfileHandler(BaseHandler):
if not by_admin and target_user != requester.user:
raise AuthError(400, "Cannot set another user's avatar_url")
+ if not by_admin and not self.hs.config.enable_set_avatar_url:
+ profile = yield self.store.get_profileinfo(target_user.localpart)
+ if profile.avatar_url:
+ raise SynapseError(
+ 400, "Changing avatar is disabled on this server", Codes.FORBIDDEN
+ )
+
if len(new_avatar_url) > MAX_AVATAR_URL_LEN:
raise SynapseError(
400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN,)
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index 72c109981b..dc04b53f43 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -26,6 +26,7 @@ from synapse.config import ConfigError
from synapse.http.server import finish_request
from synapse.http.servlet import parse_string
from synapse.module_api import ModuleApi
+from synapse.module_api.errors import RedirectException
from synapse.types import (
UserID,
map_username_to_mxid_localpart,
@@ -119,6 +120,9 @@ class SamlHandler:
try:
user_id = await self._map_saml_response_to_user(resp_bytes, relay_state)
+ except RedirectException:
+ # Raise the exception as per the wishes of the SAML module response
+ raise
except Exception as e:
# If decoding the response or mapping it to a user failed, then log the
# error and tell the user that something went wrong.
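The point of re-raising `RedirectException` is to let SAML module code abort the login and send the browser somewhere else. A hypothetical mapping-provider hook (the `needs_extra_signup_step` predicate is invented for illustration):

    from synapse.module_api.errors import RedirectException

    class MyMappingProvider:
        def saml_response_to_user_attributes(self, saml_response, failures):
            if needs_extra_signup_step(saml_response):  # hypothetical check
                # handle_saml_response re-raises this instead of swallowing
                # it, so the user lands on the IdP's signup page rather than
                # a generic "something went wrong" error page.
                raise RedirectException(b"https://idp.example.com/signup")
            ...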
diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py
index 12657ca698..7d1263caf2 100644
--- a/synapse/handlers/set_password.py
+++ b/synapse/handlers/set_password.py
@@ -32,6 +32,7 @@ class SetPasswordHandler(BaseHandler):
super(SetPasswordHandler, self).__init__(hs)
self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler()
+ self._password_policy_handler = hs.get_password_policy_handler()
@defer.inlineCallbacks
def set_password(
@@ -44,6 +45,7 @@ class SetPasswordHandler(BaseHandler):
if not self.hs.config.password_localdb_enabled:
raise SynapseError(403, "Password change disabled", errcode=Codes.FORBIDDEN)
+ self._password_policy_handler.validate_password(new_password)
password_hash = yield self._auth_handler.hash(new_password)
try:
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 669dbc8a48..5746fdea14 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -26,7 +26,7 @@ from prometheus_client import Counter
from synapse.api.constants import EventTypes, Membership
from synapse.api.filtering import FilterCollection
from synapse.events import EventBase
-from synapse.logging.context import LoggingContext
+from synapse.logging.context import current_context
from synapse.push.clientformat import format_push_rules_for_user
from synapse.storage.roommember import MemberSummary
from synapse.storage.state import StateFilter
@@ -301,7 +301,7 @@ class SyncHandler(object):
else:
sync_type = "incremental_sync"
- context = LoggingContext.current_context()
+ context = current_context()
if context:
context.tag = sync_type
diff --git a/synapse/http/request_metrics.py b/synapse/http/request_metrics.py
index 58f9cc61c8..b58ae3d9db 100644
--- a/synapse/http/request_metrics.py
+++ b/synapse/http/request_metrics.py
@@ -19,7 +19,7 @@ import threading
from prometheus_client.core import Counter, Histogram
-from synapse.logging.context import LoggingContext
+from synapse.logging.context import current_context
from synapse.metrics import LaterGauge
logger = logging.getLogger(__name__)
@@ -148,7 +148,7 @@ LaterGauge(
class RequestMetrics(object):
def start(self, time_sec, name, method):
self.start = time_sec
- self.start_context = LoggingContext.current_context()
+ self.start_context = current_context()
self.name = name
self.method = method
@@ -163,7 +163,7 @@ class RequestMetrics(object):
with _in_flight_requests_lock:
_in_flight_requests.discard(self)
- context = LoggingContext.current_context()
+ context = current_context()
tag = ""
if context:
diff --git a/synapse/logging/_structured.py b/synapse/logging/_structured.py
index ffa7b20ca8..7372450b45 100644
--- a/synapse/logging/_structured.py
+++ b/synapse/logging/_structured.py
@@ -42,7 +42,7 @@ from synapse.logging._terse_json import (
TerseJSONToConsoleLogObserver,
TerseJSONToTCPLogObserver,
)
-from synapse.logging.context import LoggingContext
+from synapse.logging.context import current_context
def stdlib_log_level_to_twisted(level: str) -> LogLevel:
@@ -86,7 +86,7 @@ class LogContextObserver(object):
].startswith("Timing out client"):
return
- context = LoggingContext.current_context()
+ context = current_context()
# Copy the context information to the log event.
if context is not None:
diff --git a/synapse/logging/context.py b/synapse/logging/context.py
index 860b99a4c6..a8eafb1c7c 100644
--- a/synapse/logging/context.py
+++ b/synapse/logging/context.py
@@ -175,7 +175,54 @@ class ContextResourceUsage(object):
return res
-LoggingContextOrSentinel = Union["LoggingContext", "LoggingContext.Sentinel"]
+LoggingContextOrSentinel = Union["LoggingContext", "_Sentinel"]
+
+
+class _Sentinel(object):
+ """Sentinel to represent the root context"""
+
+ __slots__ = ["previous_context", "finished", "request", "scope", "tag"]
+
+ def __init__(self) -> None:
+ # Minimal set for compatibility with LoggingContext
+ self.previous_context = None
+ self.finished = False
+ self.request = None
+ self.scope = None
+ self.tag = None
+
+ def __str__(self):
+ return "sentinel"
+
+ def copy_to(self, record):
+ pass
+
+ def copy_to_twisted_log_entry(self, record):
+ record["request"] = None
+ record["scope"] = None
+
+ def start(self):
+ pass
+
+ def stop(self):
+ pass
+
+ def add_database_transaction(self, duration_sec):
+ pass
+
+ def add_database_scheduled(self, sched_sec):
+ pass
+
+ def record_event_fetch(self, event_count):
+ pass
+
+ def __nonzero__(self):
+ return False
+
+ __bool__ = __nonzero__ # python3
+
+
+SENTINEL_CONTEXT = _Sentinel()
class LoggingContext(object):
@@ -199,76 +246,33 @@ class LoggingContext(object):
"_resource_usage",
"usage_start",
"main_thread",
- "alive",
+ "finished",
"request",
"tag",
"scope",
]
- thread_local = threading.local()
-
- class Sentinel(object):
- """Sentinel to represent the root context"""
-
- __slots__ = ["previous_context", "alive", "request", "scope", "tag"]
-
- def __init__(self) -> None:
- # Minimal set for compatibility with LoggingContext
- self.previous_context = None
- self.alive = None
- self.request = None
- self.scope = None
- self.tag = None
-
- def __str__(self):
- return "sentinel"
-
- def copy_to(self, record):
- pass
-
- def copy_to_twisted_log_entry(self, record):
- record["request"] = None
- record["scope"] = None
-
- def start(self):
- pass
-
- def stop(self):
- pass
-
- def add_database_transaction(self, duration_sec):
- pass
-
- def add_database_scheduled(self, sched_sec):
- pass
-
- def record_event_fetch(self, event_count):
- pass
-
- def __nonzero__(self):
- return False
-
- __bool__ = __nonzero__ # python3
-
- sentinel = Sentinel()
-
def __init__(self, name=None, parent_context=None, request=None) -> None:
- self.previous_context = LoggingContext.current_context()
+ self.previous_context = current_context()
self.name = name
# track the resources used by this context so far
self._resource_usage = ContextResourceUsage()
- # If alive has the thread resource usage when the logcontext last
- # became active.
+ # The thread resource usage when the logcontext became active. None
+ # if the context is not currently active.
self.usage_start = None
self.main_thread = get_thread_id()
self.request = None
self.tag = ""
- self.alive = True
self.scope = None # type: Optional[_LogContextScope]
+ # keep track of whether we have hit the __exit__ block for this context
+ # (suggesting that the thing that created the context thinks it should
+ # be finished, and that re-activating it would suggest an error).
+ self.finished = False
+
self.parent_context = parent_context
if self.parent_context is not None:
@@ -283,44 +287,15 @@ class LoggingContext(object):
return str(self.request)
return "%s@%x" % (self.name, id(self))
- @classmethod
- def current_context(cls) -> LoggingContextOrSentinel:
- """Get the current logging context from thread local storage
-
- Returns:
- LoggingContext: the current logging context
- """
- return getattr(cls.thread_local, "current_context", cls.sentinel)
-
- @classmethod
- def set_current_context(
- cls, context: LoggingContextOrSentinel
- ) -> LoggingContextOrSentinel:
- """Set the current logging context in thread local storage
- Args:
- context(LoggingContext): The context to activate.
- Returns:
- The context that was previously active
- """
- current = cls.current_context()
-
- if current is not context:
- current.stop()
- cls.thread_local.current_context = context
- context.start()
- return current
-
def __enter__(self) -> "LoggingContext":
"""Enters this logging context into thread local storage"""
- old_context = self.set_current_context(self)
+ old_context = set_current_context(self)
if self.previous_context != old_context:
logger.warning(
"Expected previous context %r, found %r",
self.previous_context,
old_context,
)
- self.alive = True
-
return self
def __exit__(self, type, value, traceback) -> None:
@@ -329,24 +304,19 @@ class LoggingContext(object):
Returns:
None to avoid suppressing any exceptions that were thrown.
"""
- current = self.set_current_context(self.previous_context)
+ current = set_current_context(self.previous_context)
if current is not self:
- if current is self.sentinel:
+ if current is SENTINEL_CONTEXT:
logger.warning("Expected logging context %s was lost", self)
else:
logger.warning(
"Expected logging context %s but found %s", self, current
)
- self.alive = False
-
- # if we have a parent, pass our CPU usage stats on
- if self.parent_context is not None and hasattr(
- self.parent_context, "_resource_usage"
- ):
- self.parent_context._resource_usage += self._resource_usage
- # reset them in case we get entered again
- self._resource_usage.reset()
+ # the fact that we are here suggests that the caller thinks that everything
+ # is done and dusted for this logcontext, and further activity will not get
+ # recorded against the correct metrics.
+ self.finished = True
def copy_to(self, record) -> None:
"""Copy logging fields from this context to a log record or
@@ -371,9 +341,14 @@ class LoggingContext(object):
logger.warning("Started logcontext %s on different thread", self)
return
+ if self.finished:
+ logger.warning("Re-starting finished log context %s", self)
+
# If we haven't already started record the thread resource usage so
# far
- if not self.usage_start:
+ if self.usage_start:
+ logger.warning("Re-starting already-active log context %s", self)
+ else:
self.usage_start = get_thread_resource_usage()
def stop(self) -> None:
@@ -396,6 +371,15 @@ class LoggingContext(object):
self.usage_start = None
+ # if we have a parent, pass our CPU usage stats on
+ if self.parent_context is not None and hasattr(
+ self.parent_context, "_resource_usage"
+ ):
+ self.parent_context._resource_usage += self._resource_usage
+
+ # reset them in case we get entered again
+ self._resource_usage.reset()
+
def get_resource_usage(self) -> ContextResourceUsage:
"""Get resources used by this logcontext so far.
@@ -409,7 +393,7 @@ class LoggingContext(object):
# If we are on the correct thread and we're currently running then we
# can include resource usage so far.
is_main_thread = get_thread_id() == self.main_thread
- if self.alive and self.usage_start and is_main_thread:
+ if self.usage_start and is_main_thread:
utime_delta, stime_delta = self._get_cputime()
res.ru_utime += utime_delta
res.ru_stime += stime_delta
@@ -492,7 +476,7 @@ class LoggingContextFilter(logging.Filter):
Returns:
True to include the record in the log output.
"""
- context = LoggingContext.current_context()
+ context = current_context()
for key, value in self.defaults.items():
setattr(record, key, value)
@@ -512,27 +496,24 @@ class PreserveLoggingContext(object):
__slots__ = ["current_context", "new_context", "has_parent"]
- def __init__(self, new_context: Optional[LoggingContextOrSentinel] = None) -> None:
- if new_context is None:
- self.new_context = LoggingContext.sentinel # type: LoggingContextOrSentinel
- else:
- self.new_context = new_context
+ def __init__(
+ self, new_context: LoggingContextOrSentinel = SENTINEL_CONTEXT
+ ) -> None:
+ self.new_context = new_context
def __enter__(self) -> None:
"""Captures the current logging context"""
- self.current_context = LoggingContext.set_current_context(self.new_context)
+ self.current_context = set_current_context(self.new_context)
if self.current_context:
self.has_parent = self.current_context.previous_context is not None
- if not self.current_context.alive:
- logger.debug("Entering dead context: %s", self.current_context)
def __exit__(self, type, value, traceback) -> None:
"""Restores the current logging context"""
- context = LoggingContext.set_current_context(self.current_context)
+ context = set_current_context(self.current_context)
if context != self.new_context:
- if context is LoggingContext.sentinel:
+ if not context:
logger.warning("Expected logging context %s was lost", self.new_context)
else:
logger.warning(
@@ -541,9 +522,30 @@ class PreserveLoggingContext(object):
context,
)
- if self.current_context is not LoggingContext.sentinel:
- if not self.current_context.alive:
- logger.debug("Restoring dead context: %s", self.current_context)
+
+_thread_local = threading.local()
+_thread_local.current_context = SENTINEL_CONTEXT
+
+
+def current_context() -> LoggingContextOrSentinel:
+ """Get the current logging context from thread local storage"""
+ return getattr(_thread_local, "current_context", SENTINEL_CONTEXT)
+
+
+def set_current_context(context: LoggingContextOrSentinel) -> LoggingContextOrSentinel:
+ """Set the current logging context in thread local storage
+ Args:
+ context(LoggingContext): The context to activate.
+ Returns:
+ The context that was previously active
+ """
+ current = current_context()
+
+ if current is not context:
+ current.stop()
+ _thread_local.current_context = context
+ context.start()
+ return current
def nested_logging_context(
@@ -572,7 +574,7 @@ def nested_logging_context(
if parent_context is not None:
context = parent_context # type: LoggingContextOrSentinel
else:
- context = LoggingContext.current_context()
+ context = current_context()
return LoggingContext(
parent_context=context, request=str(context.request) + "-" + suffix
)
@@ -604,7 +606,7 @@ def run_in_background(f, *args, **kwargs):
CRITICAL error about an unhandled error will be logged without much
indication about where it came from.
"""
- current = LoggingContext.current_context()
+ current = current_context()
try:
res = f(*args, **kwargs)
except: # noqa: E722
@@ -625,7 +627,7 @@ def run_in_background(f, *args, **kwargs):
# The function may have reset the context before returning, so
# we need to restore it now.
- ctx = LoggingContext.set_current_context(current)
+ ctx = set_current_context(current)
# The original context will be restored when the deferred
# completes, but there is nothing waiting for it, so it will
@@ -674,7 +676,7 @@ def make_deferred_yieldable(deferred):
# ok, we can't be sure that a yield won't block, so let's reset the
# logcontext, and add a callback to the deferred to restore it.
- prev_context = LoggingContext.set_current_context(LoggingContext.sentinel)
+ prev_context = set_current_context(SENTINEL_CONTEXT)
deferred.addBoth(_set_context_cb, prev_context)
return deferred
@@ -684,7 +686,7 @@ ResultT = TypeVar("ResultT")
def _set_context_cb(result: ResultT, context: LoggingContext) -> ResultT:
"""A callback function which just sets the logging context"""
- LoggingContext.set_current_context(context)
+ set_current_context(context)
return result
@@ -752,7 +754,7 @@ def defer_to_threadpool(reactor, threadpool, f, *args, **kwargs):
Deferred: A Deferred which fires a callback with the result of `f`, or an
errback if `f` throws an exception.
"""
- logcontext = LoggingContext.current_context()
+ logcontext = current_context()
def g():
with LoggingContext(parent_context=logcontext):
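A quick sketch of the module-level functions that replace the old classmethods:

    from synapse.logging.context import (
        LoggingContext,
        PreserveLoggingContext,
        current_context,
    )

    with LoggingContext(request="GET-42"):
        assert str(current_context()) == "GET-42"
        with PreserveLoggingContext():
            # Back to the sentinel while control is handed to the reactor.
            # The sentinel is falsy, which is what `if context:` relies on.
            assert not current_context()
        assert str(current_context()) == "GET-42"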
diff --git a/synapse/logging/scopecontextmanager.py b/synapse/logging/scopecontextmanager.py
index 4eed4f2338..dc3ab00cbb 100644
--- a/synapse/logging/scopecontextmanager.py
+++ b/synapse/logging/scopecontextmanager.py
@@ -19,7 +19,7 @@ from opentracing import Scope, ScopeManager
import twisted
-from synapse.logging.context import LoggingContext, nested_logging_context
+from synapse.logging.context import current_context, nested_logging_context
logger = logging.getLogger(__name__)
@@ -49,11 +49,8 @@ class LogContextScopeManager(ScopeManager):
(Scope) : the Scope that is active, or None if not
available.
"""
- ctx = LoggingContext.current_context()
- if ctx is LoggingContext.sentinel:
- return None
- else:
- return ctx.scope
+ ctx = current_context()
+ return ctx.scope
def activate(self, span, finish_on_close):
"""
@@ -70,9 +67,9 @@ class LogContextScopeManager(ScopeManager):
"""
enter_logcontext = False
- ctx = LoggingContext.current_context()
+ ctx = current_context()
- if ctx is LoggingContext.sentinel:
+ if not ctx:
# There is no logcontext to attach the scope to, so it can't take effect.
logger.error("Tried to activate scope outside of loggingcontext")
return Scope(None, span)
diff --git a/synapse/replication/http/__init__.py b/synapse/replication/http/__init__.py
index 28dbc6fcba..4613b2538c 100644
--- a/synapse/replication/http/__init__.py
+++ b/synapse/replication/http/__init__.py
@@ -21,6 +21,7 @@ from synapse.replication.http import (
membership,
register,
send_event,
+ streams,
)
REPLICATION_PREFIX = "/_synapse/replication"
@@ -38,3 +39,4 @@ class ReplicationRestResource(JsonResource):
login.register_servlets(hs, self)
register.register_servlets(hs, self)
devices.register_servlets(hs, self)
+ streams.register_servlets(hs, self)
diff --git a/synapse/replication/http/streams.py b/synapse/replication/http/streams.py
new file mode 100644
index 0000000000..ffd4c61993
--- /dev/null
+++ b/synapse/replication/http/streams.py
@@ -0,0 +1,78 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from synapse.api.errors import SynapseError
+from synapse.http.servlet import parse_integer
+from synapse.replication.http._base import ReplicationEndpoint
+
+logger = logging.getLogger(__name__)
+
+
+class ReplicationGetStreamUpdates(ReplicationEndpoint):
+ """Fetches stream updates from a server. Used for streams not persisted to
+ the database, e.g. typing notifications.
+
+ The API looks like:
+
+ GET /_synapse/replication/get_repl_stream_updates/events?from_token=0&upto_token=10&limit=100
+
+ 200 OK
+
+ {
+ updates: [ ... ],
+ upto_token: 10,
+ limited: False,
+ }
+
+ """
+
+ NAME = "get_repl_stream_updates"
+ PATH_ARGS = ("stream_name",)
+ METHOD = "GET"
+
+ def __init__(self, hs):
+ super().__init__(hs)
+
+ # We pull the streams from the replication streamer (if we try to make
+ # them ourselves we end up in an import loop).
+ self.streams = hs.get_replication_streamer().get_streams()
+
+ @staticmethod
+ def _serialize_payload(stream_name, from_token, upto_token, limit):
+ return {"from_token": from_token, "upto_token": upto_token, "limit": limit}
+
+ async def _handle_request(self, request, stream_name):
+ stream = self.streams.get(stream_name)
+ if stream is None:
+ raise SynapseError(400, "Unknown stream")
+
+ from_token = parse_integer(request, "from_token", required=True)
+ upto_token = parse_integer(request, "upto_token", required=True)
+ limit = parse_integer(request, "limit", required=True)
+
+ updates, upto_token, limited = await stream.get_updates_since(
+ from_token, upto_token, limit
+ )
+
+ return (
+ 200,
+ {"updates": updates, "upto_token": upto_token, "limited": limited},
+ )
+
+
+def register_servlets(hs, http_server):
+ ReplicationGetStreamUpdates(hs).register(http_server)
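A sketch of the calling side, using the stock `ReplicationEndpoint.make_client` helper (the stream name and tokens are illustrative):

    get_updates = ReplicationGetStreamUpdates.make_client(hs)

    result = await get_updates(
        stream_name="typing", from_token=0, upto_token=10, limit=100
    )
    # result == {"updates": [...], "upto_token": 10, "limited": False}
    # A True "limited" means more rows remain below upto_token, so the
    # caller should page again from the returned upto_token.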
diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
index f45cbd37a0..751c799d94 100644
--- a/synapse/replication/slave/storage/_base.py
+++ b/synapse/replication/slave/storage/_base.py
@@ -18,8 +18,10 @@ from typing import Dict, Optional
import six
-from synapse.storage._base import SQLBaseStore
-from synapse.storage.data_stores.main.cache import CURRENT_STATE_CACHE_NAME
+from synapse.storage.data_stores.main.cache import (
+ CURRENT_STATE_CACHE_NAME,
+ CacheInvalidationWorkerStore,
+)
from synapse.storage.database import Database
from synapse.storage.engines import PostgresEngine
@@ -35,7 +37,7 @@ def __func__(inp):
return inp.__func__
-class BaseSlavedStore(SQLBaseStore):
+class BaseSlavedStore(CacheInvalidationWorkerStore):
def __init__(self, database: Database, db_conn, hs):
super(BaseSlavedStore, self).__init__(database, db_conn, hs)
if isinstance(self.database_engine, PostgresEngine):
@@ -60,6 +62,12 @@ class BaseSlavedStore(SQLBaseStore):
pos["caches"] = self._cache_id_gen.get_current_token()
return pos
+ def get_cache_stream_token(self):
+ if self._cache_id_gen:
+ return self._cache_id_gen.get_current_token()
+ else:
+ return 0
+
def process_replication_rows(self, stream_name, token, rows):
if stream_name == "caches":
if self._cache_id_gen:
diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py
index f22c2d44a3..bce8a3d115 100644
--- a/synapse/replication/slave/storage/pushers.py
+++ b/synapse/replication/slave/storage/pushers.py
@@ -33,6 +33,9 @@ class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
result["pushers"] = self._pushers_id_gen.get_current_token()
return result
+ def get_pushers_stream_token(self):
+ return self._pushers_id_gen.get_current_token()
+
def process_replication_rows(self, stream_name, token, rows):
if stream_name == "pushers":
self._pushers_id_gen.advance(token)
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 02ab5b66ea..7e7ad0f798 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -55,6 +55,7 @@ class ReplicationClientFactory(ReconnectingClientFactory):
self.client_name = client_name
self.handler = handler
self.server_name = hs.config.server_name
+ self.hs = hs
self._clock = hs.get_clock() # As self.clock is defined in super class
hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.stopTrying)
@@ -65,7 +66,7 @@ class ReplicationClientFactory(ReconnectingClientFactory):
def buildProtocol(self, addr):
logger.info("Connected to replication: %r", addr)
return ClientReplicationStreamProtocol(
- self.client_name, self.server_name, self._clock, self.handler
+ self.hs, self.client_name, self.server_name, self._clock, self.handler,
)
def clientConnectionLost(self, connector, reason):
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index 451671412d..5a6b734094 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -136,8 +136,8 @@ class PositionCommand(Command):
"""Sent by the server to tell the client the stream postition without
needing to send an RDATA.
- Sent to the client after all missing updates for a stream have been sent
- to the client and they're now up to date.
+ On receipt of a POSITION command clients should check if they have missed
+ any updates, and if so then fetch them out of band.
"""
NAME = "POSITION"
@@ -179,42 +179,24 @@ class NameCommand(Command):
class ReplicateCommand(Command):
- """Sent by the client to subscribe to the stream.
+ """Sent by the client to subscribe to streams.
Format::
- REPLICATE <stream_name> <token>
-
- Where <token> may be either:
- * a numeric stream_id to stream updates from
- * "NOW" to stream all subsequent updates.
-
- The <stream_name> can be "ALL" to subscribe to all known streams, in which
- case the <token> must be set to "NOW", i.e.::
-
- REPLICATE ALL NOW
+ REPLICATE
"""
NAME = "REPLICATE"
- def __init__(self, stream_name, token):
- self.stream_name = stream_name
- self.token = token
+ def __init__(self):
+ pass
@classmethod
def from_line(cls, line):
- stream_name, token = line.split(" ", 1)
- if token in ("NOW", "now"):
- token = "NOW"
- else:
- token = int(token)
- return cls(stream_name, token)
+ return cls()
def to_line(self):
- return " ".join((self.stream_name, str(self.token)))
-
- def get_logcontext_id(self):
- return "REPLICATE-" + self.stream_name
+ return ""
class UserSyncCommand(Command):
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index bc1482a9bb..f81d2e2442 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -35,9 +35,7 @@ indicate which side is sending, these are *not* included on the wire::
> PING 1490197665618
< NAME synapse.app.appservice
< PING 1490197665618
- < REPLICATE events 1
- < REPLICATE backfill 1
- < REPLICATE caches 1
+ < REPLICATE
> POSITION events 1
> POSITION backfill 1
> POSITION caches 1
@@ -53,17 +51,15 @@ import fcntl
import logging
import struct
from collections import defaultdict
-from typing import Any, DefaultDict, Dict, List, Set, Tuple
+from typing import Any, DefaultDict, Dict, List, Set
-from six import iteritems, iterkeys
+from six import iteritems
from prometheus_client import Counter
-from twisted.internet import defer
from twisted.protocols.basic import LineOnlyReceiver
from twisted.python.failure import Failure
-from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.tcp.commands import (
@@ -82,11 +78,16 @@ from synapse.replication.tcp.commands import (
SyncCommand,
UserSyncCommand,
)
-from synapse.replication.tcp.streams import STREAMS_MAP
+from synapse.replication.tcp.streams import STREAMS_MAP, Stream
from synapse.types import Collection
from synapse.util import Clock
from synapse.util.stringutils import random_string
+MYPY = False
+if MYPY:
+ from synapse.server import HomeServer
+
+
connection_close_counter = Counter(
"synapse_replication_tcp_protocol_close_reason", "", ["reason_type"]
)
@@ -411,16 +412,6 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
self.server_name = server_name
self.streamer = streamer
- # The streams the client has subscribed to and is up to date with
- self.replication_streams = set() # type: Set[str]
-
- # The streams the client is currently subscribing to.
- self.connecting_streams = set() # type: Set[str]
-
- # Map from stream name to list of updates to send once we've finished
- # subscribing the client to the stream.
- self.pending_rdata = {} # type: Dict[str, List[Tuple[int, Any]]]
-
def connectionMade(self):
self.send_command(ServerCommand(self.server_name))
BaseReplicationStreamProtocol.connectionMade(self)
@@ -436,21 +427,10 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
)
async def on_REPLICATE(self, cmd):
- stream_name = cmd.stream_name
- token = cmd.token
-
- if stream_name == "ALL":
- # Subscribe to all streams we're publishing to.
- deferreds = [
- run_in_background(self.subscribe_to_stream, stream, token)
- for stream in iterkeys(self.streamer.streams_by_name)
- ]
-
- await make_deferred_yieldable(
- defer.gatherResults(deferreds, consumeErrors=True)
- )
- else:
- await self.subscribe_to_stream(stream_name, token)
+ # Subscribe to all streams we're publishing to.
+ for stream_name in self.streamer.streams_by_name:
+ current_token = self.streamer.get_stream_token(stream_name)
+ self.send_command(PositionCommand(stream_name, current_token))
async def on_FEDERATION_ACK(self, cmd):
self.streamer.federation_ack(cmd.token)
@@ -474,87 +454,12 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
cmd.last_seen,
)
- async def subscribe_to_stream(self, stream_name, token):
- """Subscribe the remote to a stream.
-
- This invloves checking if they've missed anything and sending those
- updates down if they have. During that time new updates for the stream
- are queued and sent once we've sent down any missed updates.
- """
- self.replication_streams.discard(stream_name)
- self.connecting_streams.add(stream_name)
-
- try:
- # Get missing updates
- updates, current_token = await self.streamer.get_stream_updates(
- stream_name, token
- )
-
- # Send all the missing updates
- for update in updates:
- token, row = update[0], update[1]
- self.send_command(RdataCommand(stream_name, token, row))
-
- # We send a POSITION command to ensure that they have an up to
- # date token (especially useful if we didn't send any updates
- # above)
- self.send_command(PositionCommand(stream_name, current_token))
-
- # Now we can send any updates that came in while we were subscribing
- pending_rdata = self.pending_rdata.pop(stream_name, [])
- updates = []
- for token, update in pending_rdata:
- # If the token is null, it is part of a batch update. Batches
- # are multiple updates that share a single token. To denote
- # this, the token is set to None for all tokens in the batch
- # except for the last. If we find a None token, we keep looking
- # through tokens until we find one that is not None and then
- # process all previous updates in the batch as if they had the
- # final token.
- if token is None:
- # Store this update as part of a batch
- updates.append(update)
- continue
-
- if token <= current_token:
- # This update or batch of updates is older than
- # current_token, dismiss it
- updates = []
- continue
-
- updates.append(update)
-
- # Send all updates that are part of this batch with the
- # found token
- for update in updates:
- self.send_command(RdataCommand(stream_name, token, update))
-
- # Clear stored updates
- updates = []
-
- # They're now fully subscribed
- self.replication_streams.add(stream_name)
- except Exception as e:
- logger.exception("[%s] Failed to handle REPLICATE command", self.id())
- self.send_error("failed to handle replicate: %r", e)
- finally:
- self.connecting_streams.discard(stream_name)
-
def stream_update(self, stream_name, token, data):
"""Called when a new update is available to stream to clients.
- We need to check if the client is interested in the stream or not
"""
- if stream_name in self.replication_streams:
- # The client is subscribed to the stream
- self.send_command(RdataCommand(stream_name, token, data))
- elif stream_name in self.connecting_streams:
- # The client is being subscribed to the stream
- logger.debug("[%s] Queuing RDATA %r %r", self.id(), stream_name, token)
- self.pending_rdata.setdefault(stream_name, []).append((token, data))
- else:
- # The client isn't subscribed
- logger.debug("[%s] Dropping RDATA %r %r", self.id(), stream_name, token)
+ self.send_command(RdataCommand(stream_name, token, data))
def send_sync(self, data):
self.send_command(SyncCommand(data))
@@ -638,6 +543,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
def __init__(
self,
+ hs: "HomeServer",
client_name: str,
server_name: str,
clock: Clock,
@@ -649,22 +555,25 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
self.server_name = server_name
self.handler = handler
+ self.streams = {
+ stream.NAME: stream(hs) for stream in STREAMS_MAP.values()
+ } # type: Dict[str, Stream]
+
# Set of stream names that have been subscribe to, but haven't yet
# caught up with. This is used to track when the client has been fully
# connected to the remote.
- self.streams_connecting = set() # type: Set[str]
+ self.streams_connecting = set(STREAMS_MAP) # type: Set[str]
# Map of stream to batched updates. See RdataCommand for info on how
# batching works.
- self.pending_batches = {} # type: Dict[str, Any]
+ self.pending_batches = {} # type: Dict[str, List[Any]]
def connectionMade(self):
self.send_command(NameCommand(self.client_name))
BaseReplicationStreamProtocol.connectionMade(self)
# Once we've connected subscribe to the necessary streams
- for stream_name, token in iteritems(self.handler.get_streams_to_replicate()):
- self.replicate(stream_name, token)
+ self.replicate()
# Tell the server if we have any users currently syncing (should only
# happen on synchrotrons)
@@ -676,10 +585,6 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
- # We've now finished connecting to so inform the client handler
+ # We've now finished connecting, so inform the client handler
self.handler.update_connection(self)
- # This will happen if we don't actually subscribe to any streams
- if not self.streams_connecting:
- self.handler.finished_connecting()
-
async def on_SERVER(self, cmd):
if cmd.data != self.server_name:
logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data)
@@ -697,7 +602,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
)
raise
- if cmd.token is None:
+ if cmd.token is None or stream_name in self.streams_connecting:
# I.e. this is part of a batch of updates for this stream. Batch
# until we get an update for the stream with a non None token
self.pending_batches.setdefault(stream_name, []).append(row)
@@ -707,14 +612,55 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
rows.append(row)
await self.handler.on_rdata(stream_name, cmd.token, rows)
- async def on_POSITION(self, cmd):
- # When we get a `POSITION` command it means we've finished getting
- # missing updates for the given stream, and are now up to date.
+ async def on_POSITION(self, cmd: PositionCommand):
+ stream = self.streams.get(cmd.stream_name)
+ if not stream:
+ logger.error("Got POSITION for unknown stream: %s", cmd.stream_name)
+ return
+
+ # Find where we previously streamed up to.
+ current_token = self.handler.get_streams_to_replicate().get(cmd.stream_name)
+ if current_token is None:
+ logger.warning(
+ "Got POSITION for stream we're not subscribed to: %s", cmd.stream_name
+ )
+ return
+
+ # Fetch all updates between then and now.
+ limited = True
+ while limited:
+ updates, current_token, limited = await stream.get_updates_since(
+ current_token, cmd.token
+ )
+
+ # Check if the connection was closed underneath us, if so we bail
+ # rather than risk having concurrent catch ups going on.
+ if self.state == ConnectionStates.CLOSED:
+ return
+
+ if updates:
+ await self.handler.on_rdata(
+ cmd.stream_name,
+ current_token,
+ [stream.parse_row(update[1]) for update in updates],
+ )
+
+ # We've now caught up to position sent to us, notify handler.
+ await self.handler.on_position(cmd.stream_name, cmd.token)
+
self.streams_connecting.discard(cmd.stream_name)
if not self.streams_connecting:
self.handler.finished_connecting()
- await self.handler.on_position(cmd.stream_name, cmd.token)
+ # Check if the connection was closed underneath us, if so we bail
+ # rather than risk having concurrent catch ups going on.
+ if self.state == ConnectionStates.CLOSED:
+ return
+
+ # Handle any RDATA that came in while we were catching up.
+ rows = self.pending_batches.pop(cmd.stream_name, [])
+ if rows:
+ await self.handler.on_rdata(cmd.stream_name, rows[-1].token, rows)
async def on_SYNC(self, cmd):
self.handler.on_sync(cmd.data)
@@ -722,22 +668,12 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
self.handler.on_remote_server_up(cmd.data)
- def replicate(self, stream_name, token):
+ def replicate(self):
"""Send the subscription request to the server
"""
- if stream_name not in STREAMS_MAP:
- raise Exception("Invalid stream name %r" % (stream_name,))
-
- logger.info(
- "[%s] Subscribing to replication stream: %r from %r",
- self.id(),
- stream_name,
- token,
- )
-
- self.streams_connecting.add(stream_name)
+ logger.info("[%s] Subscribing to replication streams", self.id())
- self.send_command(ReplicateCommand(stream_name, token))
+ self.send_command(ReplicateCommand())
def on_connection_closed(self):
BaseReplicationStreamProtocol.on_connection_closed(self)
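
Note: the client-side catch-up that on_POSITION now performs, as a hedged sketch. StubStream is a stand-in honouring the new (from_token, upto_token) signature; the real code also bails out if the connection closes between iterations.

    import asyncio

    class StubStream:
        """Stand-in stream with the new get_updates_since signature."""

        def __init__(self, rows):
            self._rows = rows  # list of (token, row) pairs

        async def get_updates_since(self, from_token, upto_token, limit=100):
            batch = [r for r in self._rows if from_token < r[0] <= upto_token][:limit]
            if not batch:
                return [], upto_token, False
            new_token = batch[-1][0]
            return batch, new_token, new_token < upto_token

    async def catch_up(stream, last_seen_token, position_token):
        # Loop until the stream reports it is no longer limited, i.e. we
        # have fetched everything up to the POSITION token.
        limited = True
        while limited:
            updates, last_seen_token, limited = await stream.get_updates_since(
                last_seen_token, position_token, limit=2
            )
            for token, row in updates:
                print("RDATA", token, row)

    asyncio.run(catch_up(StubStream([(1, "a"), (2, "b"), (3, "c")]), 0, 3))
    # RDATA 1 a
    # RDATA 2 b
    # RDATA 3 c
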
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index 6e2ebaf614..4374e99e32 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -17,7 +17,7 @@
import logging
import random
-from typing import Any, List
+from typing import Any, Dict, List
from six import itervalues
@@ -30,7 +30,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util.metrics import Measure, measure_func
from .protocol import ServerReplicationStreamProtocol
-from .streams import STREAMS_MAP
+from .streams import STREAMS_MAP, Stream
from .streams.federation import FederationStream
stream_updates_counter = Counter(
@@ -52,7 +52,7 @@ class ReplicationStreamProtocolFactory(Factory):
"""
def __init__(self, hs):
- self.streamer = ReplicationStreamer(hs)
+ self.streamer = hs.get_replication_streamer()
self.clock = hs.get_clock()
self.server_name = hs.config.server_name
@@ -133,6 +133,11 @@ class ReplicationStreamer(object):
for conn in self.connections:
conn.send_error("server shutting down")
+ def get_streams(self) -> Dict[str, Stream]:
+ """Get a mapp from stream name to stream instance.
+ """
+ return self.streams_by_name
+
def on_notifier_poke(self):
"""Checks if there is actually any new data and sends it to the
connections if there are.
@@ -190,7 +195,8 @@ class ReplicationStreamer(object):
stream.current_token(),
)
try:
- updates, current_token = await stream.get_updates()
+ updates, current_token, limited = await stream.get_updates()
+ self.pending_updates |= limited
except Exception:
logger.info("Failed to handle stream %s", stream.NAME)
raise
@@ -226,8 +232,7 @@ class ReplicationStreamer(object):
self.pending_updates = False
self.is_looping = False
- @measure_func("repl.get_stream_updates")
- async def get_stream_updates(self, stream_name, token):
+ def get_stream_token(self, stream_name):
"""For a given stream get all updates since token. This is called when
a client first subscribes to a stream.
"""
@@ -235,7 +240,7 @@ class ReplicationStreamer(object):
if not stream:
raise Exception("unknown stream %s", stream_name)
- return await stream.get_updates_since(token)
+ return stream.current_token()
@measure_func("repl.federation_ack")
def federation_ack(self, token):
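
Note: how the streamer's poke loop consumes the new triplet, sketched with stand-in names. A `limited` result re-arms the pending flag so the loop runs again, rather than assuming a single get_updates() call drained the stream.

    import asyncio

    class StubStream:
        """Stand-in yielding pre-canned (updates, token, limited) triplets."""

        def __init__(self, batches):
            self._batches = list(batches)

        async def get_updates(self):
            return self._batches.pop(0)

    async def poke(stream):
        pending_updates = True
        while pending_updates:
            pending_updates = False
            updates, _token, limited = await stream.get_updates()
            # A limited batch means there is more to fetch: go round again.
            pending_updates |= limited
            for row_token, row in updates:
                print("RDATA", row_token, row)

    asyncio.run(poke(StubStream([([(1, "a")], 1, True), ([(2, "b")], 2, False)])))
    # RDATA 1 a
    # RDATA 2 b
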
diff --git a/synapse/replication/tcp/streams/__init__.py b/synapse/replication/tcp/streams/__init__.py
index 29199f5b46..37bcd3de66 100644
--- a/synapse/replication/tcp/streams/__init__.py
+++ b/synapse/replication/tcp/streams/__init__.py
@@ -24,6 +24,9 @@ Each stream is defined by the following information:
current_token: The function that returns the current token for the stream
update_function: The function that returns a list of updates between two tokens
"""
+
+from typing import Dict, Type
+
from synapse.replication.tcp.streams._base import (
AccountDataStream,
BackfillStream,
@@ -35,6 +38,7 @@ from synapse.replication.tcp.streams._base import (
PushersStream,
PushRulesStream,
ReceiptsStream,
+ Stream,
TagAccountDataStream,
ToDeviceStream,
TypingStream,
@@ -63,10 +67,12 @@ STREAMS_MAP = {
GroupServerStream,
UserSignatureStream,
)
-}
+} # type: Dict[str, Type[Stream]]
+
__all__ = [
"STREAMS_MAP",
+ "Stream",
"BackfillStream",
"PresenceStream",
"TypingStream",
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 32d9514883..c14dff6c64 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -14,13 +14,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import itertools
import logging
from collections import namedtuple
-from typing import Any, List, Optional, Tuple
+from typing import Any, Awaitable, Callable, List, Optional, Tuple
import attr
+from synapse.replication.http.streams import ReplicationGetStreamUpdates
from synapse.types import JsonDict
logger = logging.getLogger(__name__)
@@ -29,6 +29,15 @@ logger = logging.getLogger(__name__)
MAX_EVENTS_BEHIND = 500000
+# Some type aliases to make things a bit easier.
+
+# A stream position token
+Token = int
+
+# A pair of position in stream and args used to create an instance of `ROW_TYPE`.
+StreamRow = Tuple[Token, tuple]
+
+
class Stream(object):
"""Base class for the streams.
@@ -56,6 +65,7 @@ class Stream(object):
return cls.ROW_TYPE(*row)
def __init__(self, hs):
+
# The token from which we last asked for updates
self.last_token = self.current_token()
@@ -65,61 +75,46 @@ class Stream(object):
"""
self.last_token = self.current_token()
- async def get_updates(self):
+ async def get_updates(self) -> Tuple[List[Tuple[Token, JsonDict]], Token, bool]:
"""Gets all updates since the last time this function was called (or
since the stream was constructed if it hadn't been called before).
Returns:
- Deferred[Tuple[List[Tuple[int, Any]], int]:
- Resolves to a pair ``(updates, current_token)``, where ``updates`` is a
- list of ``(token, row)`` entries. ``row`` will be json-serialised and
- sent over the replication steam.
+ A triplet `(updates, new_last_token, limited)`, where `updates` is
+ a list of `(token, row)` entries, `new_last_token` is the new
+ position in stream, and `limited` is whether there are more updates
+ to fetch.
"""
- updates, current_token = await self.get_updates_since(self.last_token)
+ current_token = self.current_token()
+ updates, current_token, limited = await self.get_updates_since(
+ self.last_token, current_token
+ )
self.last_token = current_token
- return updates, current_token
+ return updates, current_token, limited
async def get_updates_since(
- self, from_token: int
- ) -> Tuple[List[Tuple[int, JsonDict]], int]:
+ self, from_token: Token, upto_token: Token, limit: int = 100
+ ) -> Tuple[List[Tuple[Token, JsonDict]], Token, bool]:
"""Like get_updates except allows specifying from when we should
stream updates
Returns:
- Resolves to a pair `(updates, new_last_token)`, where `updates` is
- a list of `(token, row)` entries and `new_last_token` is the new
- position in stream.
+ A triplet `(updates, new_last_token, limited)`, where `updates` is
+ a list of `(token, row)` entries, `new_last_token` is the new
+ position in stream, and `limited` is whether there are more updates
+ to fetch.
"""
- if from_token in ("NOW", "now"):
- return [], self.current_token()
-
- current_token = self.current_token()
-
from_token = int(from_token)
- if from_token == current_token:
- return [], current_token
+ if from_token == upto_token:
+ return [], upto_token, False
- rows = await self.update_function(
- from_token, current_token, limit=MAX_EVENTS_BEHIND + 1
+ updates, upto_token, limited = await self.update_function(
+ from_token, upto_token, limit=limit,
)
-
- # never turn more than MAX_EVENTS_BEHIND + 1 into updates.
- rows = itertools.islice(rows, MAX_EVENTS_BEHIND + 1)
-
- updates = [(row[0], row[1:]) for row in rows]
-
- # check we didn't get more rows than the limit.
- # doing it like this allows the update_function to be a generator.
- if len(updates) >= MAX_EVENTS_BEHIND:
- raise Exception("stream %s has fallen behind" % (self.NAME))
-
- # The update function didn't hit the limit, so we must have got all
- # the updates to `current_token`, and can return that as our new
- # stream position.
- return updates, current_token
+ return updates, upto_token, limited
def current_token(self):
"""Gets the current token of the underlying streams. Should be provided
@@ -141,6 +136,48 @@ class Stream(object):
raise NotImplementedError()
+def db_query_to_update_function(
+ query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]]
+) -> Callable[[Token, Token, int], Awaitable[Tuple[List[StreamRow], Token, bool]]]:
+ """Wraps a db query function which returns a list of rows to make it
+ suitable for use as an `update_function` for the Stream class
+ """
+
+ async def update_function(from_token, upto_token, limit):
+ rows = await query_function(from_token, upto_token, limit)
+ updates = [(row[0], row[1:]) for row in rows]
+ limited = False
+ if len(updates) == limit:
+ upto_token = rows[-1][0]
+ limited = True
+
+ return updates, upto_token, limited
+
+ return update_function
+
+
+def make_http_update_function(
+ hs, stream_name: str
+) -> Callable[[Token, Token, int], Awaitable[Tuple[List[StreamRow], Token, bool]]]:
+ """Makes a suitable function for use as an `update_function` that queries
+ the master process for updates.
+ """
+
+ client = ReplicationGetStreamUpdates.make_client(hs)
+
+ async def update_function(
+ from_token: int, upto_token: int, limit: int
+ ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
+ return await client(
+ stream_name=stream_name,
+ from_token=from_token,
+ upto_token=upto_token,
+ limit=limit,
+ )
+
+ return update_function
+
+
class BackfillStream(Stream):
"""We fetched some old events and either we had never seen that event before
or it went from being an outlier to not.
@@ -164,7 +201,7 @@ class BackfillStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
self.current_token = store.get_current_backfill_token # type: ignore
- self.update_function = store.get_all_new_backfill_event_rows # type: ignore
+ self.update_function = db_query_to_update_function(store.get_all_new_backfill_event_rows) # type: ignore
super(BackfillStream, self).__init__(hs)
@@ -190,8 +227,15 @@ class PresenceStream(Stream):
store = hs.get_datastore()
presence_handler = hs.get_presence_handler()
+ self._is_worker = hs.config.worker_app is not None
+
self.current_token = store.get_current_presence_token # type: ignore
- self.update_function = presence_handler.get_all_presence_updates # type: ignore
+
+ if hs.config.worker_app is None:
+ self.update_function = db_query_to_update_function(presence_handler.get_all_presence_updates) # type: ignore
+ else:
+ # Query master process
+ self.update_function = make_http_update_function(hs, self.NAME) # type: ignore
super(PresenceStream, self).__init__(hs)
@@ -208,7 +252,12 @@ class TypingStream(Stream):
typing_handler = hs.get_typing_handler()
self.current_token = typing_handler.get_current_token # type: ignore
- self.update_function = typing_handler.get_all_typing_updates # type: ignore
+
+ if hs.config.worker_app is None:
+ self.update_function = db_query_to_update_function(typing_handler.get_all_typing_updates) # type: ignore
+ else:
+ # Query master process
+ self.update_function = make_http_update_function(hs, self.NAME) # type: ignore
super(TypingStream, self).__init__(hs)
@@ -232,7 +281,7 @@ class ReceiptsStream(Stream):
store = hs.get_datastore()
self.current_token = store.get_max_receipt_stream_id # type: ignore
- self.update_function = store.get_all_updated_receipts # type: ignore
+ self.update_function = db_query_to_update_function(store.get_all_updated_receipts) # type: ignore
super(ReceiptsStream, self).__init__(hs)
@@ -256,7 +305,13 @@ class PushRulesStream(Stream):
async def update_function(self, from_token, to_token, limit):
rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit)
- return [(row[0], row[2]) for row in rows]
+
+ limited = False
+ if len(rows) == limit:
+ to_token = rows[-1][0]
+ limited = True
+
+ return [(row[0], (row[2],)) for row in rows], to_token, limited
class PushersStream(Stream):
@@ -275,7 +330,7 @@ class PushersStream(Stream):
store = hs.get_datastore()
self.current_token = store.get_pushers_stream_token # type: ignore
- self.update_function = store.get_all_updated_pushers_rows # type: ignore
+ self.update_function = db_query_to_update_function(store.get_all_updated_pushers_rows) # type: ignore
super(PushersStream, self).__init__(hs)
@@ -307,7 +362,7 @@ class CachesStream(Stream):
store = hs.get_datastore()
self.current_token = store.get_cache_stream_token # type: ignore
- self.update_function = store.get_all_updated_caches # type: ignore
+ self.update_function = db_query_to_update_function(store.get_all_updated_caches) # type: ignore
super(CachesStream, self).__init__(hs)
@@ -333,7 +388,7 @@ class PublicRoomsStream(Stream):
store = hs.get_datastore()
self.current_token = store.get_current_public_room_stream_id # type: ignore
- self.update_function = store.get_all_new_public_rooms # type: ignore
+ self.update_function = db_query_to_update_function(store.get_all_new_public_rooms) # type: ignore
super(PublicRoomsStream, self).__init__(hs)
@@ -354,7 +409,7 @@ class DeviceListsStream(Stream):
store = hs.get_datastore()
self.current_token = store.get_device_stream_token # type: ignore
- self.update_function = store.get_all_device_list_changes_for_remotes # type: ignore
+ self.update_function = db_query_to_update_function(store.get_all_device_list_changes_for_remotes) # type: ignore
super(DeviceListsStream, self).__init__(hs)
@@ -372,7 +427,7 @@ class ToDeviceStream(Stream):
store = hs.get_datastore()
self.current_token = store.get_to_device_stream_token # type: ignore
- self.update_function = store.get_all_new_device_messages # type: ignore
+ self.update_function = db_query_to_update_function(store.get_all_new_device_messages) # type: ignore
super(ToDeviceStream, self).__init__(hs)
@@ -392,7 +447,7 @@ class TagAccountDataStream(Stream):
store = hs.get_datastore()
self.current_token = store.get_max_account_data_stream_id # type: ignore
- self.update_function = store.get_all_updated_tags # type: ignore
+ self.update_function = db_query_to_update_function(store.get_all_updated_tags) # type: ignore
super(TagAccountDataStream, self).__init__(hs)
@@ -412,10 +467,11 @@ class AccountDataStream(Stream):
self.store = hs.get_datastore()
self.current_token = self.store.get_max_account_data_stream_id # type: ignore
+ self.update_function = db_query_to_update_function(self._update_function) # type: ignore
super(AccountDataStream, self).__init__(hs)
- async def update_function(self, from_token, to_token, limit):
+ async def _update_function(self, from_token, to_token, limit):
global_results, room_results = await self.store.get_all_updated_account_data(
from_token, from_token, to_token, limit
)
@@ -442,7 +498,7 @@ class GroupServerStream(Stream):
store = hs.get_datastore()
self.current_token = store.get_group_stream_token # type: ignore
- self.update_function = store.get_all_groups_changes # type: ignore
+ self.update_function = db_query_to_update_function(store.get_all_groups_changes) # type: ignore
super(GroupServerStream, self).__init__(hs)
@@ -460,6 +516,6 @@ class UserSignatureStream(Stream):
store = hs.get_datastore()
self.current_token = store.get_device_stream_token # type: ignore
- self.update_function = store.get_all_user_signature_changes_for_remotes # type: ignore
+ self.update_function = db_query_to_update_function(store.get_all_user_signature_changes_for_remotes) # type: ignore
super(UserSignatureStream, self).__init__(hs)
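
Note: a runnable sketch of what db_query_to_update_function produces, with a stand-in query function in place of a real datastore method. When exactly `limit` rows come back, the wrapper assumes there may be more, reports the last row's token as the new position, and sets limited=True.

    import asyncio

    # Stand-in for a datastore query method: up to `limit` rows of
    # (stream_id, *fields) between two tokens.
    async def fake_query(from_token, upto_token, limit):
        rows = [(t, "user-%d" % t, "data") for t in range(from_token + 1, upto_token + 1)]
        return rows[:limit]

    def db_query_to_update_function(query_function):
        # Mirrors the helper added above.
        async def update_function(from_token, upto_token, limit):
            rows = await query_function(from_token, upto_token, limit)
            updates = [(row[0], row[1:]) for row in rows]
            limited = False
            if len(updates) == limit:
                upto_token = rows[-1][0]
                limited = True
            return updates, upto_token, limited

        return update_function

    update_function = db_query_to_update_function(fake_query)
    print(asyncio.run(update_function(0, 5, 3)))
    # ([(1, ('user-1', 'data')), (2, ('user-2', 'data')), (3, ('user-3', 'data'))], 3, True)
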
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index b3afabb8cd..c6a595629f 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -19,7 +19,7 @@ from typing import Tuple, Type
import attr
-from ._base import Stream
+from ._base import Stream, db_query_to_update_function
"""Handling of the 'events' replication stream
@@ -117,10 +117,11 @@ class EventsStream(Stream):
def __init__(self, hs):
self._store = hs.get_datastore()
self.current_token = self._store.get_current_events_token # type: ignore
+ self.update_function = db_query_to_update_function(self._update_function) # type: ignore
super(EventsStream, self).__init__(hs)
- async def update_function(self, from_token, current_token, limit=None):
+ async def _update_function(self, from_token, current_token, limit=None):
event_rows = await self._store.get_all_new_forward_event_rows(
from_token, current_token, limit
)
diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py
index f5f9336430..48c1d45718 100644
--- a/synapse/replication/tcp/streams/federation.py
+++ b/synapse/replication/tcp/streams/federation.py
@@ -15,7 +15,9 @@
# limitations under the License.
from collections import namedtuple
-from ._base import Stream
+from twisted.internet import defer
+
+from synapse.replication.tcp.streams._base import Stream, db_query_to_update_function
class FederationStream(Stream):
@@ -33,11 +35,18 @@ class FederationStream(Stream):
NAME = "federation"
ROW_TYPE = FederationStreamRow
+ _QUERY_MASTER = True
def __init__(self, hs):
- federation_sender = hs.get_federation_sender()
-
- self.current_token = federation_sender.get_current_token # type: ignore
- self.update_function = federation_sender.get_replication_rows # type: ignore
+ # Not all synapse instances will have a federation sender instance,
+ # whether that's a `FederationSender` or a `FederationRemoteSendQueue`,
+ # so we stub the stream out when that is the case.
+ if hs.config.worker_app is None or hs.should_send_federation():
+ federation_sender = hs.get_federation_sender()
+ self.current_token = federation_sender.get_current_token # type: ignore
+ self.update_function = db_query_to_update_function(federation_sender.get_replication_rows) # type: ignore
+ else:
+ self.current_token = lambda: 0 # type: ignore
+ self.update_function = lambda from_token, upto_token, limit: defer.succeed(([], upto_token, False)) # type: ignore
super(FederationStream, self).__init__(hs)
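
Note: a sketch of what the stubbed-out stream amounts to on instances without a federation sender: a fixed token and an update function that never yields rows. Plain functions stand in for the lambda/Deferred pair used above.

    import asyncio

    # Stand-ins for the stubbed FederationStream callbacks.
    def current_token():
        return 0

    async def update_function(from_token, upto_token, limit):
        # Always "no updates, and not limited".
        return [], upto_token, False

    print(current_token())                          # 0
    print(asyncio.run(update_function(0, 0, 100)))  # ([], 0, False)
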
diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py
index 4a1fc2ec2b..46e458e95b 100644
--- a/synapse/rest/__init__.py
+++ b/synapse/rest/__init__.py
@@ -41,6 +41,7 @@ from synapse.rest.client.v2_alpha import (
keys,
notifications,
openid,
+ password_policy,
read_marker,
receipts,
register,
@@ -118,6 +119,7 @@ class ClientRestResource(JsonResource):
capabilities.register_servlets(hs, client_resource)
account_validity.register_servlets(hs, client_resource)
relations.register_servlets(hs, client_resource)
+ password_policy.register_servlets(hs, client_resource)
# moving to /_synapse/admin
synapse.rest.admin.register_servlets_for_client_rest_resource(
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index 42cc2b062a..ed70d448a1 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -29,7 +29,11 @@ from synapse.rest.admin._base import (
from synapse.rest.admin.groups import DeleteGroupAdminRestServlet
from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_media_repo
from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet
-from synapse.rest.admin.rooms import ListRoomRestServlet, ShutdownRoomRestServlet
+from synapse.rest.admin.rooms import (
+ JoinRoomAliasServlet,
+ ListRoomRestServlet,
+ ShutdownRoomRestServlet,
+)
from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet
from synapse.rest.admin.users import (
AccountValidityRenewServlet,
@@ -189,6 +193,7 @@ def register_servlets(hs, http_server):
"""
register_servlets_for_client_rest_resource(hs, http_server)
ListRoomRestServlet(hs).register(http_server)
+ JoinRoomAliasServlet(hs).register(http_server)
PurgeRoomServlet(hs).register(http_server)
SendServerNoticeServlet(hs).register(http_server)
VersionServlet(hs).register(http_server)
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index f9b8c0a4f0..659b8a10ee 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -13,9 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+from typing import List, Optional
-from synapse.api.constants import Membership
-from synapse.api.errors import Codes, SynapseError
+from synapse.api.constants import EventTypes, JoinRules, Membership
+from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
@@ -29,7 +30,7 @@ from synapse.rest.admin._base import (
historical_admin_path_patterns,
)
from synapse.storage.data_stores.main.room import RoomSortOrder
-from synapse.types import create_requester
+from synapse.types import RoomAlias, RoomID, UserID, create_requester
from synapse.util.async_helpers import maybe_awaitable
logger = logging.getLogger(__name__)
@@ -237,3 +238,75 @@ class ListRoomRestServlet(RestServlet):
response["prev_batch"] = 0
return 200, response
+
+
+class JoinRoomAliasServlet(RestServlet):
+
+ PATTERNS = admin_patterns("/join/(?P<room_identifier>[^/]*)")
+
+ def __init__(self, hs):
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.room_member_handler = hs.get_room_member_handler()
+ self.admin_handler = hs.get_handlers().admin_handler
+ self.state_handler = hs.get_state_handler()
+
+ async def on_POST(self, request, room_identifier):
+ requester = await self.auth.get_user_by_req(request)
+ await assert_user_is_admin(self.auth, requester.user)
+
+ content = parse_json_object_from_request(request)
+
+ assert_params_in_dict(content, ["user_id"])
+ target_user = UserID.from_string(content["user_id"])
+
+ if not self.hs.is_mine(target_user):
+ raise SynapseError(400, "This endpoint can only be used with local users")
+
+ if not await self.admin_handler.get_user(target_user):
+ raise NotFoundError("User not found")
+
+ if RoomID.is_valid(room_identifier):
+ room_id = room_identifier
+ try:
+ remote_room_hosts = [
+ x.decode("ascii") for x in request.args[b"server_name"]
+ ] # type: Optional[List[str]]
+ except Exception:
+ remote_room_hosts = None
+ elif RoomAlias.is_valid(room_identifier):
+ handler = self.room_member_handler
+ room_alias = RoomAlias.from_string(room_identifier)
+ room_id, remote_room_hosts = await handler.lookup_room_alias(room_alias)
+ room_id = room_id.to_string()
+ else:
+ raise SynapseError(
+ 400, "%s was not legal room ID or room alias" % (room_identifier,)
+ )
+
+ fake_requester = create_requester(target_user)
+
+ # send invite if room has "JoinRules.INVITE"
+ room_state = await self.state_handler.get_current_state(room_id)
+ join_rules_event = room_state.get((EventTypes.JoinRules, ""))
+ if join_rules_event:
+ if not (join_rules_event.content.get("join_rule") == JoinRules.PUBLIC):
+ await self.room_member_handler.update_membership(
+ requester=requester,
+ target=fake_requester.user,
+ room_id=room_id,
+ action="invite",
+ remote_room_hosts=remote_room_hosts,
+ ratelimit=False,
+ )
+
+ await self.room_member_handler.update_membership(
+ requester=fake_requester,
+ target=fake_requester.user,
+ room_id=room_id,
+ action="join",
+ remote_room_hosts=remote_room_hosts,
+ ratelimit=False,
+ )
+
+ return 200, {"room_id": room_id}
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 56d713462a..59593cbf6e 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -14,11 +14,6 @@
# limitations under the License.
import logging
-import xml.etree.ElementTree as ET
-
-from six.moves import urllib
-
-from twisted.web.client import PartialDownloadError
from synapse.api.errors import Codes, LoginError, SynapseError
from synapse.api.ratelimiting import Ratelimiter
@@ -28,9 +23,10 @@ from synapse.http.servlet import (
parse_json_object_from_request,
parse_string,
)
+from synapse.http.site import SynapseRequest
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.rest.well_known import WellKnownBuilder
-from synapse.types import UserID, map_username_to_mxid_localpart
+from synapse.types import UserID
from synapse.util.msisdn import phone_number_to_msisdn
logger = logging.getLogger(__name__)
@@ -72,14 +68,6 @@ def login_id_thirdparty_from_phone(identifier):
return {"type": "m.id.thirdparty", "medium": "msisdn", "address": msisdn}
-def build_service_param(cas_service_url, client_redirect_url):
- return "%s%s?redirectUrl=%s" % (
- cas_service_url,
- "/_matrix/client/r0/login/cas/ticket",
- urllib.parse.quote(client_redirect_url, safe=""),
- )
-
-
class LoginRestServlet(RestServlet):
PATTERNS = client_patterns("/login$", v1=True)
CAS_TYPE = "m.login.cas"
@@ -409,7 +397,7 @@ class BaseSSORedirectServlet(RestServlet):
PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True)
- def on_GET(self, request):
+ def on_GET(self, request: SynapseRequest):
args = request.args
if b"redirectUrl" not in args:
return 400, "Redirect URL not specified for SSO auth"
@@ -418,15 +406,15 @@ class BaseSSORedirectServlet(RestServlet):
request.redirect(sso_url)
finish_request(request)
- def get_sso_url(self, client_redirect_url):
+ def get_sso_url(self, client_redirect_url: bytes) -> bytes:
"""Get the URL to redirect to, to perform SSO auth
Args:
- client_redirect_url (bytes): the URL that we should redirect the
+ client_redirect_url: the URL that we should redirect the
client to when everything is done
Returns:
- bytes: URL to redirect to
+ URL to redirect to
"""
# to be implemented by subclasses
raise NotImplementedError()
@@ -434,16 +422,10 @@ class BaseSSORedirectServlet(RestServlet):
class CasRedirectServlet(BaseSSORedirectServlet):
def __init__(self, hs):
- super(CasRedirectServlet, self).__init__()
- self.cas_server_url = hs.config.cas_server_url
- self.cas_service_url = hs.config.cas_service_url
+ self._cas_handler = hs.get_cas_handler()
- def get_sso_url(self, client_redirect_url):
- args = urllib.parse.urlencode(
- {"service": build_service_param(self.cas_service_url, client_redirect_url)}
- )
-
- return "%s/login?%s" % (self.cas_server_url, args)
+ def get_sso_url(self, client_redirect_url: bytes) -> bytes:
+ return self._cas_handler.handle_redirect_request(client_redirect_url)
class CasTicketServlet(RestServlet):
@@ -451,81 +433,15 @@ class CasTicketServlet(RestServlet):
def __init__(self, hs):
super(CasTicketServlet, self).__init__()
- self.cas_server_url = hs.config.cas_server_url
- self.cas_service_url = hs.config.cas_service_url
- self.cas_displayname_attribute = hs.config.cas_displayname_attribute
- self.cas_required_attributes = hs.config.cas_required_attributes
- self._sso_auth_handler = SSOAuthHandler(hs)
- self._http_client = hs.get_proxied_http_client()
-
- async def on_GET(self, request):
- client_redirect_url = parse_string(request, "redirectUrl", required=True)
- uri = self.cas_server_url + "/proxyValidate"
- args = {
- "ticket": parse_string(request, "ticket", required=True),
- "service": build_service_param(self.cas_service_url, client_redirect_url),
- }
- try:
- body = await self._http_client.get_raw(uri, args)
- except PartialDownloadError as pde:
- # Twisted raises this error if the connection is closed,
- # even if that's being used old-http style to signal end-of-data
- body = pde.response
- result = await self.handle_cas_response(request, body, client_redirect_url)
- return result
+ self._cas_handler = hs.get_cas_handler()
- def handle_cas_response(self, request, cas_response_body, client_redirect_url):
- user, attributes = self.parse_cas_response(cas_response_body)
- displayname = attributes.pop(self.cas_displayname_attribute, None)
-
- for required_attribute, required_value in self.cas_required_attributes.items():
- # If required attribute was not in CAS Response - Forbidden
- if required_attribute not in attributes:
- raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
-
- # Also need to check value
- if required_value is not None:
- actual_value = attributes[required_attribute]
- # If required attribute value does not match expected - Forbidden
- if required_value != actual_value:
- raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
-
- return self._sso_auth_handler.on_successful_auth(
- user, request, client_redirect_url, displayname
+ async def on_GET(self, request: SynapseRequest) -> None:
+ client_redirect_url = parse_string(request, "redirectUrl", required=True)
+ ticket = parse_string(request, "ticket", required=True)
+ await self._cas_handler.handle_ticket_request(
+ request, client_redirect_url, ticket
)
- def parse_cas_response(self, cas_response_body):
- user = None
- attributes = {}
- try:
- root = ET.fromstring(cas_response_body)
- if not root.tag.endswith("serviceResponse"):
- raise Exception("root of CAS response is not serviceResponse")
- success = root[0].tag.endswith("authenticationSuccess")
- for child in root[0]:
- if child.tag.endswith("user"):
- user = child.text
- if child.tag.endswith("attributes"):
- for attribute in child:
- # ElementTree library expands the namespace in
- # attribute tags to the full URL of the namespace.
- # We don't care about namespace here and it will always
- # be encased in curly braces, so we remove them.
- tag = attribute.tag
- if "}" in tag:
- tag = tag.split("}")[1]
- attributes[tag] = attribute.text
- if user is None:
- raise Exception("CAS response does not contain user")
- except Exception:
- logger.exception("Error parsing CAS response")
- raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
- if not success:
- raise LoginError(
- 401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED
- )
- return user, attributes
-
class SAMLRedirectServlet(BaseSSORedirectServlet):
PATTERNS = client_patterns("/login/sso/redirect", v1=True)
@@ -533,65 +449,10 @@ class SAMLRedirectServlet(BaseSSORedirectServlet):
def __init__(self, hs):
self._saml_handler = hs.get_saml_handler()
- def get_sso_url(self, client_redirect_url):
+ def get_sso_url(self, client_redirect_url: bytes) -> bytes:
return self._saml_handler.handle_redirect_request(client_redirect_url)
-class SSOAuthHandler(object):
- """
- Utility class for Resources and Servlets which handle the response from a SSO
- service
-
- Args:
- hs (synapse.server.HomeServer)
- """
-
- def __init__(self, hs):
- self._hostname = hs.hostname
- self._auth_handler = hs.get_auth_handler()
- self._registration_handler = hs.get_registration_handler()
- self._macaroon_gen = hs.get_macaroon_generator()
-
- # cast to tuple for use with str.startswith
- self._whitelisted_sso_clients = tuple(hs.config.sso_client_whitelist)
-
- async def on_successful_auth(
- self, username, request, client_redirect_url, user_display_name=None
- ):
- """Called once the user has successfully authenticated with the SSO.
-
- Registers the user if necessary, and then returns a redirect (with
- a login token) to the client.
-
- Args:
- username (unicode|bytes): the remote user id. We'll map this onto
- something sane for a MXID localpath.
-
- request (SynapseRequest): the incoming request from the browser. We'll
- respond to it with a redirect.
-
- client_redirect_url (unicode): the redirect_url the client gave us when
- it first started the process.
-
- user_display_name (unicode|None): if set, and we have to register a new user,
- we will set their displayname to this.
-
- Returns:
- Deferred[none]: Completes once we have handled the request.
- """
- localpart = map_username_to_mxid_localpart(username)
- user_id = UserID(localpart, self._hostname).to_string()
- registered_user_id = await self._auth_handler.check_user_exists(user_id)
- if not registered_user_id:
- registered_user_id = await self._registration_handler.register_user(
- localpart=localpart, default_display_name=user_display_name
- )
-
- self._auth_handler.complete_sso_login(
- registered_user_id, request, client_redirect_url
- )
-
-
def register_servlets(hs, http_server):
LoginRestServlet(hs).register(http_server)
if hs.config.cas_enabled:
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index 631cc74cb4..f80b5e40ea 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -234,13 +234,16 @@ class PasswordRestServlet(RestServlet):
if self.auth.has_access_token(request):
requester = await self.auth.get_user_by_req(request)
params = await self.auth_handler.validate_user_via_ui_auth(
- requester, body, self.hs.get_ip_from_request(request)
+ requester, request, body, self.hs.get_ip_from_request(request),
)
user_id = requester.user.to_string()
else:
requester = None
result, params, _ = await self.auth_handler.check_auth(
- [[LoginType.EMAIL_IDENTITY]], body, self.hs.get_ip_from_request(request)
+ [[LoginType.EMAIL_IDENTITY]],
+ request,
+ body,
+ self.hs.get_ip_from_request(request),
)
if LoginType.EMAIL_IDENTITY in result:
@@ -308,7 +311,7 @@ class DeactivateAccountRestServlet(RestServlet):
return 200, {}
await self.auth_handler.validate_user_via_ui_auth(
- requester, body, self.hs.get_ip_from_request(request)
+ requester, request, body, self.hs.get_ip_from_request(request),
)
result = await self._deactivate_account_handler.deactivate_account(
requester.user.to_string(), erase, id_server=body.get("id_server")
@@ -602,6 +605,11 @@ class ThreepidRestServlet(RestServlet):
return 200, {"threepids": threepids}
async def on_POST(self, request):
+ if not self.hs.config.enable_3pid_changes:
+ raise SynapseError(
+ 400, "3PID changes are disabled on this server", Codes.FORBIDDEN
+ )
+
requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
body = parse_json_object_from_request(request)
@@ -646,6 +654,11 @@ class ThreepidAddRestServlet(RestServlet):
@interactive_auth_handler
async def on_POST(self, request):
+ if not self.hs.config.enable_3pid_changes:
+ raise SynapseError(
+ 400, "3PID changes are disabled on this server", Codes.FORBIDDEN
+ )
+
requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
body = parse_json_object_from_request(request)
@@ -656,7 +669,7 @@ class ThreepidAddRestServlet(RestServlet):
assert_valid_client_secret(client_secret)
await self.auth_handler.validate_user_via_ui_auth(
- requester, body, self.hs.get_ip_from_request(request)
+ requester, request, body, self.hs.get_ip_from_request(request),
)
validation_session = await self.identity_handler.validate_threepid_session(
@@ -741,10 +754,16 @@ class ThreepidDeleteRestServlet(RestServlet):
def __init__(self, hs):
super(ThreepidDeleteRestServlet, self).__init__()
+ self.hs = hs
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
async def on_POST(self, request):
+ if not self.hs.config.enable_3pid_changes:
+ raise SynapseError(
+ 400, "3PID changes are disabled on this server", Codes.FORBIDDEN
+ )
+
body = parse_json_object_from_request(request)
assert_params_in_dict(body, ["medium", "address"])
diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py
index 94ff73f384..119d979052 100644
--- a/synapse/rest/client/v2_alpha/devices.py
+++ b/synapse/rest/client/v2_alpha/devices.py
@@ -81,7 +81,7 @@ class DeleteDevicesRestServlet(RestServlet):
assert_params_in_dict(body, ["devices"])
await self.auth_handler.validate_user_via_ui_auth(
- requester, body, self.hs.get_ip_from_request(request)
+ requester, request, body, self.hs.get_ip_from_request(request),
)
await self.device_handler.delete_devices(
@@ -127,7 +127,7 @@ class DeviceRestServlet(RestServlet):
raise
await self.auth_handler.validate_user_via_ui_auth(
- requester, body, self.hs.get_ip_from_request(request)
+ requester, request, body, self.hs.get_ip_from_request(request),
)
await self.device_handler.delete_device(requester.user.to_string(), device_id)
diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py
index f7ed4daf90..5eb7ef35a4 100644
--- a/synapse/rest/client/v2_alpha/keys.py
+++ b/synapse/rest/client/v2_alpha/keys.py
@@ -263,7 +263,7 @@ class SigningKeyUploadServlet(RestServlet):
body = parse_json_object_from_request(request)
await self.auth_handler.validate_user_via_ui_auth(
- requester, body, self.hs.get_ip_from_request(request)
+ requester, request, body, self.hs.get_ip_from_request(request),
)
result = await self.e2e_keys_handler.upload_signing_keys_for_user(user_id, body)
diff --git a/synapse/rest/client/v2_alpha/password_policy.py b/synapse/rest/client/v2_alpha/password_policy.py
new file mode 100644
index 0000000000..968403cca4
--- /dev/null
+++ b/synapse/rest/client/v2_alpha/password_policy.py
@@ -0,0 +1,58 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from synapse.http.servlet import RestServlet
+
+from ._base import client_patterns
+
+logger = logging.getLogger(__name__)
+
+
+class PasswordPolicyServlet(RestServlet):
+ PATTERNS = client_patterns("/password_policy$")
+
+ def __init__(self, hs):
+ """
+ Args:
+ hs (synapse.server.HomeServer): server
+ """
+ super(PasswordPolicyServlet, self).__init__()
+
+ self.policy = hs.config.password_policy
+ self.enabled = hs.config.password_policy_enabled
+
+ def on_GET(self, request):
+ if not self.enabled or not self.policy:
+ return (200, {})
+
+ policy = {}
+
+ for param in [
+ "minimum_length",
+ "require_digit",
+ "require_symbol",
+ "require_lowercase",
+ "require_uppercase",
+ ]:
+ if param in self.policy:
+ policy["m.%s" % param] = self.policy[param]
+
+ return (200, policy)
+
+
+def register_servlets(hs, http_server):
+ PasswordPolicyServlet(hs).register(http_server)
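
Note: the mapping on_GET performs, sketched with an assumed policy dict (in a deployment this comes from the password_policy section of the homeserver config). Only recognised keys are exposed, each prefixed with "m." in the response.

    # Assumed policy configuration for illustration.
    policy_config = {"minimum_length": 10, "require_digit": True}

    response = {
        "m.%s" % param: policy_config[param]
        for param in (
            "minimum_length",
            "require_digit",
            "require_symbol",
            "require_lowercase",
            "require_uppercase",
        )
        if param in policy_config
    }

    print(response)  # {'m.minimum_length': 10, 'm.require_digit': True}
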
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index a09189b1b4..66fc8ec179 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -373,6 +373,7 @@ class RegisterRestServlet(RestServlet):
self.room_member_handler = hs.get_room_member_handler()
self.macaroon_gen = hs.get_macaroon_generator()
self.ratelimiter = hs.get_registration_ratelimiter()
+ self.password_policy_handler = hs.get_password_policy_handler()
self.clock = hs.get_clock()
self._registration_flows = _calculate_registration_flows(
@@ -420,6 +421,7 @@ class RegisterRestServlet(RestServlet):
or len(body["password"]) > 512
):
raise SynapseError(400, "Invalid password")
+ self.password_policy_handler.validate_password(body["password"])
desired_username = None
if "username" in body:
@@ -499,7 +501,10 @@ class RegisterRestServlet(RestServlet):
)
auth_result, params, session_id = await self.auth_handler.check_auth(
- self._registration_flows, body, self.hs.get_ip_from_request(request)
+ self._registration_flows,
+ request,
+ body,
+ self.hs.get_ip_from_request(request),
)
# Check that we're not trying to register a denied 3pid.
diff --git a/synapse/rest/client/v2_alpha/room_keys.py b/synapse/rest/client/v2_alpha/room_keys.py
index 38952a1d27..59529707df 100644
--- a/synapse/rest/client/v2_alpha/room_keys.py
+++ b/synapse/rest/client/v2_alpha/room_keys.py
@@ -188,7 +188,7 @@ class RoomKeysServlet(RestServlet):
"""
requester = await self.auth.get_user_by_req(request, allow_guest=False)
user_id = requester.user.to_string()
- version = parse_string(request, "version")
+ version = parse_string(request, "version", required=True)
room_keys = await self.e2e_room_keys_handler.get_room_keys(
user_id, version, room_id, session_id
diff --git a/synapse/server.py b/synapse/server.py
index 1b980371de..c7ca2bda0d 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -56,6 +56,7 @@ from synapse.handlers.account_validity import AccountValidityHandler
from synapse.handlers.acme import AcmeHandler
from synapse.handlers.appservice import ApplicationServicesHandler
from synapse.handlers.auth import AuthHandler, MacaroonGenerator
+from synapse.handlers.cas_handler import CasHandler
from synapse.handlers.deactivate_account import DeactivateAccountHandler
from synapse.handlers.device import DeviceHandler, DeviceWorkerHandler
from synapse.handlers.devicemessage import DeviceMessageHandler
@@ -66,6 +67,7 @@ from synapse.handlers.groups_local import GroupsLocalHandler, GroupsLocalWorkerH
from synapse.handlers.initial_sync import InitialSyncHandler
from synapse.handlers.message import EventCreationHandler, MessageHandler
from synapse.handlers.pagination import PaginationHandler
+from synapse.handlers.password_policy import PasswordPolicyHandler
from synapse.handlers.presence import PresenceHandler
from synapse.handlers.profile import BaseProfileHandler, MasterProfileHandler
from synapse.handlers.read_marker import ReadMarkerHandler
@@ -85,6 +87,7 @@ from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.notifier import Notifier
from synapse.push.action_generator import ActionGenerator
from synapse.push.pusherpool import PusherPool
+from synapse.replication.tcp.resource import ReplicationStreamer
from synapse.rest.media.v1.media_repository import (
MediaRepository,
MediaRepositoryResource,
@@ -196,9 +199,12 @@ class HomeServer(object):
"sendmail",
"registration_handler",
"account_validity_handler",
+ "cas_handler",
"saml_handler",
"event_client_serializer",
+ "password_policy_handler",
"storage",
+ "replication_streamer",
]
REQUIRED_ON_MASTER_STARTUP = ["user_directory_handler", "stats_handler"]
@@ -525,6 +531,9 @@ class HomeServer(object):
def build_account_validity_handler(self):
return AccountValidityHandler(self)
+ def build_cas_handler(self):
+ return CasHandler(self)
+
def build_saml_handler(self):
from synapse.handlers.saml_handler import SamlHandler
@@ -533,9 +542,15 @@ class HomeServer(object):
def build_event_client_serializer(self):
return EventClientSerializer(self)
+ def build_password_policy_handler(self):
+ return PasswordPolicyHandler(self)
+
def build_storage(self) -> Storage:
return Storage(self, self.datastores)
+ def build_replication_streamer(self) -> ReplicationStreamer:
+ return ReplicationStreamer(self)
+
def remove_pusher(self, app_id, push_key, user_id):
return self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
diff --git a/synapse/static/client/login/index.html b/synapse/static/client/login/index.html
index bcb6bc6bb7..712b0e3980 100644
--- a/synapse/static/client/login/index.html
+++ b/synapse/static/client/login/index.html
@@ -9,7 +9,7 @@
<body onload="matrixLogin.onLoad()">
<center>
<br/>
- <h1>Log in with one of the following methods</h1>
+ <h1 id="title"></h1>
<span id="feedback" style="color: #f00"></span>
diff --git a/synapse/static/client/login/js/login.js b/synapse/static/client/login/js/login.js
index 276c271bbe..debe464371 100644
--- a/synapse/static/client/login/js/login.js
+++ b/synapse/static/client/login/js/login.js
@@ -1,37 +1,41 @@
window.matrixLogin = {
endpoint: location.origin + "/_matrix/client/r0/login",
serverAcceptsPassword: false,
- serverAcceptsCas: false,
serverAcceptsSso: false,
};
+var title_pre_auth = "Log in with one of the following methods";
+var title_post_auth = "Logging in...";
+
var submitPassword = function(user, pwd) {
console.log("Logging in with password...");
+ set_title(title_post_auth);
var data = {
type: "m.login.password",
user: user,
password: pwd,
};
$.post(matrixLogin.endpoint, JSON.stringify(data), function(response) {
- show_login();
matrixLogin.onLogin(response);
}).error(errorFunc);
};
var submitToken = function(loginToken) {
console.log("Logging in with login token...");
+ set_title(title_post_auth);
var data = {
type: "m.login.token",
token: loginToken
};
$.post(matrixLogin.endpoint, JSON.stringify(data), function(response) {
- show_login();
matrixLogin.onLogin(response);
}).error(errorFunc);
};
var errorFunc = function(err) {
- show_login();
+ // We want to show the error to the user rather than redirecting immediately to the
+ // SSO portal (if SSO is the only login option), so we inhibit the redirect.
+ show_login(true);
if (err.responseJSON && err.responseJSON.error) {
setFeedbackString(err.responseJSON.error + " (" + err.responseJSON.errcode + ")");
@@ -45,26 +49,33 @@ var setFeedbackString = function(text) {
$("#feedback").text(text);
};
-var show_login = function() {
- $("#loading").hide();
-
+var show_login = function(inhibit_redirect) {
var this_page = window.location.origin + window.location.pathname;
$("#sso_redirect_url").val(this_page);
- if (matrixLogin.serverAcceptsPassword) {
- $("#password_flow").show();
+ // If inhibit_redirect is false, and SSO is the only supported login method, we can
+ // redirect straight to the SSO page
+ if (matrixLogin.serverAcceptsSso) {
+ if (!inhibit_redirect && !matrixLogin.serverAcceptsPassword) {
+ $("#sso_form").submit();
+ return;
+ }
+
+ // Otherwise, show the SSO form
+ $("#sso_form").show();
}
- if (matrixLogin.serverAcceptsSso) {
- $("#sso_flow").show();
- } else if (matrixLogin.serverAcceptsCas) {
- $("#sso_form").attr("action", "/_matrix/client/r0/login/cas/redirect");
- $("#sso_flow").show();
+ if (matrixLogin.serverAcceptsPassword) {
+ $("#password_flow").show();
}
- if (!matrixLogin.serverAcceptsPassword && !matrixLogin.serverAcceptsCas && !matrixLogin.serverAcceptsSso) {
+ if (!matrixLogin.serverAcceptsPassword && !matrixLogin.serverAcceptsSso) {
$("#no_login_types").show();
}
+
+ set_title(title_pre_auth);
+
+ $("#loading").hide();
};
var show_spinner = function() {
@@ -74,17 +85,15 @@ var show_spinner = function() {
$("#loading").show();
};
+var set_title = function(title) {
+ $("#title").text(title);
+};
var fetch_info = function(cb) {
$.get(matrixLogin.endpoint, function(response) {
var serverAcceptsPassword = false;
- var serverAcceptsCas = false;
for (var i=0; i<response.flows.length; i++) {
var flow = response.flows[i];
- if ("m.login.cas" === flow.type) {
- matrixLogin.serverAcceptsCas = true;
- console.log("Server accepts CAS");
- }
if ("m.login.sso" === flow.type) {
matrixLogin.serverAcceptsSso = true;
console.log("Server accepts SSO");
@@ -102,7 +111,7 @@ var fetch_info = function(cb) {
matrixLogin.onLoad = function() {
fetch_info(function() {
if (!try_token()) {
- show_login();
+ show_login(false);
}
});
};
diff --git a/synapse/storage/data_stores/main/cache.py b/synapse/storage/data_stores/main/cache.py
index d4c44dcc75..4dc5da3fe8 100644
--- a/synapse/storage/data_stores/main/cache.py
+++ b/synapse/storage/data_stores/main/cache.py
@@ -32,7 +32,29 @@ logger = logging.getLogger(__name__)
CURRENT_STATE_CACHE_NAME = "cs_cache_fake"
-class CacheInvalidationStore(SQLBaseStore):
+class CacheInvalidationWorkerStore(SQLBaseStore):
+ def get_all_updated_caches(self, last_id, current_id, limit):
+ if last_id == current_id:
+ return defer.succeed([])
+
+ def get_all_updated_caches_txn(txn):
+ # We purposefully don't bound by the current token, as we want to
+ # send across cache invalidations as quickly as possible. Cache
+ # invalidations are idempotent, so duplicates are fine.
+ sql = (
+ "SELECT stream_id, cache_func, keys, invalidation_ts"
+ " FROM cache_invalidation_stream"
+ " WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?"
+ )
+ txn.execute(sql, (last_id, limit))
+ return txn.fetchall()
+
+ return self.db.runInteraction(
+ "get_all_updated_caches", get_all_updated_caches_txn
+ )
+
+
+class CacheInvalidationStore(CacheInvalidationWorkerStore):
async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]):
"""Invalidates the cache and adds it to the cache stream so slaves
will know to invalidate their caches.
@@ -145,26 +167,6 @@ class CacheInvalidationStore(SQLBaseStore):
},
)
- def get_all_updated_caches(self, last_id, current_id, limit):
- if last_id == current_id:
- return defer.succeed([])
-
- def get_all_updated_caches_txn(txn):
- # We purposefully don't bound by the current token, as we want to
- # send across cache invalidations as quickly as possible. Cache
- # invalidations are idempotent, so duplicates are fine.
- sql = (
- "SELECT stream_id, cache_func, keys, invalidation_ts"
- " FROM cache_invalidation_stream"
- " WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?"
- )
- txn.execute(sql, (last_id, limit))
- return txn.fetchall()
-
- return self.db.runInteraction(
- "get_all_updated_caches", get_all_updated_caches_txn
- )
-
def get_cache_stream_token(self):
if self._cache_id_gen:
return self._cache_id_gen.get_current_token()
diff --git a/synapse/storage/data_stores/main/deviceinbox.py b/synapse/storage/data_stores/main/deviceinbox.py
index 0613b49f4a..9a1178fb39 100644
--- a/synapse/storage/data_stores/main/deviceinbox.py
+++ b/synapse/storage/data_stores/main/deviceinbox.py
@@ -207,6 +207,50 @@ class DeviceInboxWorkerStore(SQLBaseStore):
"delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn
)
+ def get_all_new_device_messages(self, last_pos, current_pos, limit):
+ """
+ Args:
+ last_pos(int):
+ current_pos(int):
+ limit(int):
+ Returns:
+ A deferred list of rows from the device inbox
+ """
+ if last_pos == current_pos:
+ return defer.succeed([])
+
+ def get_all_new_device_messages_txn(txn):
+ # We limit like this as we might have multiple rows per stream_id, and
+ # we want to make sure we always get all entries for any stream_id
+ # we return.
+ upper_pos = min(current_pos, last_pos + limit)
+ sql = (
+ "SELECT max(stream_id), user_id"
+ " FROM device_inbox"
+ " WHERE ? < stream_id AND stream_id <= ?"
+ " GROUP BY user_id"
+ )
+ txn.execute(sql, (last_pos, upper_pos))
+ rows = txn.fetchall()
+
+ sql = (
+ "SELECT max(stream_id), destination"
+ " FROM device_federation_outbox"
+ " WHERE ? < stream_id AND stream_id <= ?"
+ " GROUP BY destination"
+ )
+ txn.execute(sql, (last_pos, upper_pos))
+ rows.extend(txn)
+
+ # Order by ascending stream ordering
+ rows.sort()
+
+ return rows
+
+ return self.db.runInteraction(
+ "get_all_new_device_messages", get_all_new_device_messages_txn
+ )
+
class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
@@ -411,47 +455,3 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
rows.append((user_id, device_id, stream_id, message_json))
txn.executemany(sql, rows)
-
- def get_all_new_device_messages(self, last_pos, current_pos, limit):
- """
- Args:
- last_pos(int):
- current_pos(int):
- limit(int):
- Returns:
- A deferred list of rows from the device inbox
- """
- if last_pos == current_pos:
- return defer.succeed([])
-
- def get_all_new_device_messages_txn(txn):
- # We limit like this as we might have multiple rows per stream_id, and
- # we want to make sure we always get all entries for any stream_id
- # we return.
- upper_pos = min(current_pos, last_pos + limit)
- sql = (
- "SELECT max(stream_id), user_id"
- " FROM device_inbox"
- " WHERE ? < stream_id AND stream_id <= ?"
- " GROUP BY user_id"
- )
- txn.execute(sql, (last_pos, upper_pos))
- rows = txn.fetchall()
-
- sql = (
- "SELECT max(stream_id), destination"
- " FROM device_federation_outbox"
- " WHERE ? < stream_id AND stream_id <= ?"
- " GROUP BY destination"
- )
- txn.execute(sql, (last_pos, upper_pos))
- rows.extend(txn)
-
- # Order by ascending stream ordering
- rows.sort()
-
- return rows
-
- return self.db.runInteraction(
- "get_all_new_device_messages", get_all_new_device_messages_txn
- )
diff --git a/synapse/storage/data_stores/main/e2e_room_keys.py b/synapse/storage/data_stores/main/e2e_room_keys.py
index 84594cf0a9..23f4570c4b 100644
--- a/synapse/storage/data_stores/main/e2e_room_keys.py
+++ b/synapse/storage/data_stores/main/e2e_room_keys.py
@@ -146,7 +146,8 @@ class EndToEndRoomKeyStore(SQLBaseStore):
room_entry["sessions"][row["session_id"]] = {
"first_message_index": row["first_message_index"],
"forwarded_count": row["forwarded_count"],
- "is_verified": row["is_verified"],
+ # is_verified must be returned to the client as a boolean
+ "is_verified": bool(row["is_verified"]),
"session_data": json.loads(row["session_data"]),
}
diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py
index d593ef47b8..e71c23541d 100644
--- a/synapse/storage/data_stores/main/events.py
+++ b/synapse/storage/data_stores/main/events.py
@@ -1267,104 +1267,6 @@ class EventsStore(
ret = yield self.db.runInteraction("count_daily_active_rooms", _count)
return ret
- def get_current_backfill_token(self):
- """The current minimum token that backfilled events have reached"""
- return -self._backfill_id_gen.get_current_token()
-
- def get_current_events_token(self):
- """The current maximum token that events have reached"""
- return self._stream_id_gen.get_current_token()
-
- def get_all_new_forward_event_rows(self, last_id, current_id, limit):
- if last_id == current_id:
- return defer.succeed([])
-
- def get_all_new_forward_event_rows(txn):
- sql = (
- "SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
- " state_key, redacts, relates_to_id"
- " FROM events AS e"
- " LEFT JOIN redactions USING (event_id)"
- " LEFT JOIN state_events USING (event_id)"
- " LEFT JOIN event_relations USING (event_id)"
- " WHERE ? < stream_ordering AND stream_ordering <= ?"
- " ORDER BY stream_ordering ASC"
- " LIMIT ?"
- )
- txn.execute(sql, (last_id, current_id, limit))
- new_event_updates = txn.fetchall()
-
- if len(new_event_updates) == limit:
- upper_bound = new_event_updates[-1][0]
- else:
- upper_bound = current_id
-
- sql = (
- "SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
- " state_key, redacts, relates_to_id"
- " FROM events AS e"
- " INNER JOIN ex_outlier_stream USING (event_id)"
- " LEFT JOIN redactions USING (event_id)"
- " LEFT JOIN state_events USING (event_id)"
- " LEFT JOIN event_relations USING (event_id)"
- " WHERE ? < event_stream_ordering"
- " AND event_stream_ordering <= ?"
- " ORDER BY event_stream_ordering DESC"
- )
- txn.execute(sql, (last_id, upper_bound))
- new_event_updates.extend(txn)
-
- return new_event_updates
-
- return self.db.runInteraction(
- "get_all_new_forward_event_rows", get_all_new_forward_event_rows
- )
-
- def get_all_new_backfill_event_rows(self, last_id, current_id, limit):
- if last_id == current_id:
- return defer.succeed([])
-
- def get_all_new_backfill_event_rows(txn):
- sql = (
- "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
- " state_key, redacts, relates_to_id"
- " FROM events AS e"
- " LEFT JOIN redactions USING (event_id)"
- " LEFT JOIN state_events USING (event_id)"
- " LEFT JOIN event_relations USING (event_id)"
- " WHERE ? > stream_ordering AND stream_ordering >= ?"
- " ORDER BY stream_ordering ASC"
- " LIMIT ?"
- )
- txn.execute(sql, (-last_id, -current_id, limit))
- new_event_updates = txn.fetchall()
-
- if len(new_event_updates) == limit:
- upper_bound = new_event_updates[-1][0]
- else:
- upper_bound = current_id
-
- sql = (
- "SELECT -event_stream_ordering, e.event_id, e.room_id, e.type,"
- " state_key, redacts, relates_to_id"
- " FROM events AS e"
- " INNER JOIN ex_outlier_stream USING (event_id)"
- " LEFT JOIN redactions USING (event_id)"
- " LEFT JOIN state_events USING (event_id)"
- " LEFT JOIN event_relations USING (event_id)"
- " WHERE ? > event_stream_ordering"
- " AND event_stream_ordering >= ?"
- " ORDER BY event_stream_ordering DESC"
- )
- txn.execute(sql, (-last_id, -upper_bound))
- new_event_updates.extend(txn.fetchall())
-
- return new_event_updates
-
- return self.db.runInteraction(
- "get_all_new_backfill_event_rows", get_all_new_backfill_event_rows
- )
-
@cached(num_args=5, max_entries=10)
def get_all_new_events(
self,
@@ -1850,22 +1752,6 @@ class EventsStore(
return (int(res["topological_ordering"]), int(res["stream_ordering"]))
- def get_all_updated_current_state_deltas(self, from_token, to_token, limit):
- def get_all_updated_current_state_deltas_txn(txn):
- sql = """
- SELECT stream_id, room_id, type, state_key, event_id
- FROM current_state_delta_stream
- WHERE ? < stream_id AND stream_id <= ?
- ORDER BY stream_id ASC LIMIT ?
- """
- txn.execute(sql, (from_token, to_token, limit))
- return txn.fetchall()
-
- return self.db.runInteraction(
- "get_all_updated_current_state_deltas",
- get_all_updated_current_state_deltas_txn,
- )
-
def insert_labels_for_event_txn(
self, txn, event_id, labels, room_id, topological_ordering
):
diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py
index ca237c6f12..16ea8948b1 100644
--- a/synapse/storage/data_stores/main/events_worker.py
+++ b/synapse/storage/data_stores/main/events_worker.py
@@ -35,7 +35,7 @@ from synapse.api.room_versions import (
)
from synapse.events import make_event_from_dict
from synapse.events.utils import prune_event
-from synapse.logging.context import LoggingContext, PreserveLoggingContext
+from synapse.logging.context import PreserveLoggingContext, current_context
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
from synapse.storage.database import Database
@@ -409,7 +409,7 @@ class EventsWorkerStore(SQLBaseStore):
missing_events_ids = [e for e in event_ids if e not in event_entry_map]
if missing_events_ids:
- log_ctx = LoggingContext.current_context()
+ log_ctx = current_context()
log_ctx.record_event_fetch(len(missing_events_ids))
# Note that _get_events_from_db is also responsible for turning db rows
@@ -963,3 +963,117 @@ class EventsWorkerStore(SQLBaseStore):
complexity_v1 = round(state_events / 500, 2)
return {"v1": complexity_v1}
+
+ def get_current_backfill_token(self):
+ """The current minimum token that backfilled events have reached"""
+ return -self._backfill_id_gen.get_current_token()
+
+ def get_current_events_token(self):
+ """The current maximum token that events have reached"""
+ return self._stream_id_gen.get_current_token()
+
+ def get_all_new_forward_event_rows(self, last_id, current_id, limit):
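+ """Returns rows for events with a stream ordering between the two
+ tokens, together with any events de-outliered in that window (via
+ ex_outlier_stream), for streaming to replication clients.
+ """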
+ if last_id == current_id:
+ return defer.succeed([])
+
+ def get_all_new_forward_event_rows_txn(txn):
+ sql = (
+ "SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
+ " state_key, redacts, relates_to_id"
+ " FROM events AS e"
+ " LEFT JOIN redactions USING (event_id)"
+ " LEFT JOIN state_events USING (event_id)"
+ " LEFT JOIN event_relations USING (event_id)"
+ " WHERE ? < stream_ordering AND stream_ordering <= ?"
+ " ORDER BY stream_ordering ASC"
+ " LIMIT ?"
+ )
+ txn.execute(sql, (last_id, current_id, limit))
+ new_event_updates = txn.fetchall()
+
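+ # If we hit the limit the window may have been truncated, so only
+ # pull ex-outliers up to the last stream ordering actually returned.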
+ if len(new_event_updates) == limit:
+ upper_bound = new_event_updates[-1][0]
+ else:
+ upper_bound = current_id
+
+ sql = (
+ "SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
+ " state_key, redacts, relates_to_id"
+ " FROM events AS e"
+ " INNER JOIN ex_outlier_stream USING (event_id)"
+ " LEFT JOIN redactions USING (event_id)"
+ " LEFT JOIN state_events USING (event_id)"
+ " LEFT JOIN event_relations USING (event_id)"
+ " WHERE ? < event_stream_ordering"
+ " AND event_stream_ordering <= ?"
+ " ORDER BY event_stream_ordering DESC"
+ )
+ txn.execute(sql, (last_id, upper_bound))
+ new_event_updates.extend(txn)
+
+ return new_event_updates
+
+ return self.db.runInteraction(
+ "get_all_new_forward_event_rows", get_all_new_forward_event_rows
+ )
+
+ def get_all_new_backfill_event_rows(self, last_id, current_id, limit):
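+ """As get_all_new_forward_event_rows, but for backfilled events.
+
+ Backfilled events have negative stream orderings, so the (positive)
+ tokens are negated for the query and the orderings negated again in
+ the results.
+ """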
+ if last_id == current_id:
+ return defer.succeed([])
+
+ def get_all_new_backfill_event_rows_txn(txn):
+ sql = (
+ "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
+ " state_key, redacts, relates_to_id"
+ " FROM events AS e"
+ " LEFT JOIN redactions USING (event_id)"
+ " LEFT JOIN state_events USING (event_id)"
+ " LEFT JOIN event_relations USING (event_id)"
+ " WHERE ? > stream_ordering AND stream_ordering >= ?"
+ " ORDER BY stream_ordering ASC"
+ " LIMIT ?"
+ )
+ txn.execute(sql, (-last_id, -current_id, limit))
+ new_event_updates = txn.fetchall()
+
+ if len(new_event_updates) == limit:
+ upper_bound = new_event_updates[-1][0]
+ else:
+ upper_bound = current_id
+
+ sql = (
+ "SELECT -event_stream_ordering, e.event_id, e.room_id, e.type,"
+ " state_key, redacts, relates_to_id"
+ " FROM events AS e"
+ " INNER JOIN ex_outlier_stream USING (event_id)"
+ " LEFT JOIN redactions USING (event_id)"
+ " LEFT JOIN state_events USING (event_id)"
+ " LEFT JOIN event_relations USING (event_id)"
+ " WHERE ? > event_stream_ordering"
+ " AND event_stream_ordering >= ?"
+ " ORDER BY event_stream_ordering DESC"
+ )
+ txn.execute(sql, (-last_id, -upper_bound))
+ new_event_updates.extend(txn.fetchall())
+
+ return new_event_updates
+
+ return self.db.runInteraction(
+ "get_all_new_backfill_event_rows", get_all_new_backfill_event_rows
+ )
+
+ def get_all_updated_current_state_deltas(self, from_token, to_token, limit):
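+ """Fetches rows from current_state_delta_stream with a stream id
+ greater than from_token and no greater than to_token, ordered by
+ stream id ascending and bounded by limit.
+ """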
+ def get_all_updated_current_state_deltas_txn(txn):
+ sql = """
+ SELECT stream_id, room_id, type, state_key, event_id
+ FROM current_state_delta_stream
+ WHERE ? < stream_id AND stream_id <= ?
+ ORDER BY stream_id ASC LIMIT ?
+ """
+ txn.execute(sql, (from_token, to_token, limit))
+ return txn.fetchall()
+
+ return self.db.runInteraction(
+ "get_all_updated_current_state_deltas",
+ get_all_updated_current_state_deltas_txn,
+ )
diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py
index e6c10c6316..aaebe427d3 100644
--- a/synapse/storage/data_stores/main/room.py
+++ b/synapse/storage/data_stores/main/room.py
@@ -732,6 +732,26 @@ class RoomWorkerStore(SQLBaseStore):
return total_media_quarantined
+ def get_all_new_public_rooms(self, prev_id, current_id, limit):
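+ """Fetches changes to the published room list between the two stream
+ ids, as (stream_id, room_id, visibility, appservice_id, network_id)
+ rows.
+ """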
+ if prev_id == current_id:
+ return defer.succeed([])
+
+ def get_all_new_public_rooms_txn(txn):
+ sql = """
+ SELECT stream_id, room_id, visibility, appservice_id, network_id
+ FROM public_room_list_stream
+ WHERE stream_id > ? AND stream_id <= ?
+ ORDER BY stream_id ASC
+ LIMIT ?
+ """
+
+ txn.execute(sql, (prev_id, current_id, limit))
+ return txn.fetchall()
+
+ return self.db.runInteraction(
+ "get_all_new_public_rooms", get_all_new_public_rooms
+ )
+
class RoomBackgroundUpdateStore(SQLBaseStore):
REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory"
@@ -1249,26 +1269,6 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
def get_current_public_room_stream_id(self):
return self._public_room_id_gen.get_current_token()
- def get_all_new_public_rooms(self, prev_id, current_id, limit):
- def get_all_new_public_rooms(txn):
- sql = """
- SELECT stream_id, room_id, visibility, appservice_id, network_id
- FROM public_room_list_stream
- WHERE stream_id > ? AND stream_id <= ?
- ORDER BY stream_id ASC
- LIMIT ?
- """
-
- txn.execute(sql, (prev_id, current_id, limit))
- return txn.fetchall()
-
- if prev_id == current_id:
- return defer.succeed([])
-
- return self.db.runInteraction(
- "get_all_new_public_rooms", get_all_new_public_rooms
- )
-
@defer.inlineCallbacks
def block_room(self, room_id, user_id):
"""Marks the room as blocked. Can be called multiple times.
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index e61595336c..715c0346dd 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -32,6 +32,7 @@ from synapse.config.database import DatabaseConnectionConfig
from synapse.logging.context import (
LoggingContext,
LoggingContextOrSentinel,
+ current_context,
make_deferred_yieldable,
)
from synapse.metrics.background_process_metrics import run_as_background_process
@@ -483,7 +484,7 @@ class Database(object):
end = monotonic_time()
duration = end - start
- LoggingContext.current_context().add_database_transaction(duration)
+ current_context().add_database_transaction(duration)
transaction_logger.debug("[TXN END] {%s} %f sec", name, duration)
@@ -510,7 +511,7 @@ class Database(object):
after_callbacks = [] # type: List[_CallbackListEntry]
exception_callbacks = [] # type: List[_CallbackListEntry]
- if LoggingContext.current_context() == LoggingContext.sentinel:
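+ # The sentinel logcontext is falsy, so this detects running without
+ # a logging context.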
+ if not current_context():
logger.warning("Starting db txn '%s' from sentinel context", desc)
try:
@@ -547,10 +548,8 @@ class Database(object):
Returns:
Deferred: The result of func
"""
- parent_context = (
- LoggingContext.current_context()
- ) # type: Optional[LoggingContextOrSentinel]
- if parent_context == LoggingContext.sentinel:
+ parent_context = current_context() # type: Optional[LoggingContextOrSentinel]
+ if not parent_context:
logger.warning(
"Starting db connection from sentinel context: metrics will be lost"
)
diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py
index 2bfeefd54e..3bc2e8b986 100644
--- a/synapse/storage/engines/sqlite.py
+++ b/synapse/storage/engines/sqlite.py
@@ -12,14 +12,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import sqlite3
import struct
import threading
+import typing
from synapse.storage.engines import BaseDatabaseEngine
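+# sqlite3 is only needed for the type annotation below; the generic
+# parameter is given as a string so nothing extra is imported at runtime.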
+if typing.TYPE_CHECKING:
+ import sqlite3 # noqa: F401
-class Sqlite3Engine(BaseDatabaseEngine[sqlite3.Connection]):
+
+class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]):
def __init__(self, database_module, database_config):
super().__init__(database_module, database_config)
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index 7b18455469..ec61e14423 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -21,7 +21,7 @@ from prometheus_client import Counter
from twisted.internet import defer
-from synapse.logging.context import LoggingContext
+from synapse.logging.context import LoggingContext, current_context
from synapse.metrics import InFlightGauge
logger = logging.getLogger(__name__)
@@ -106,7 +106,7 @@ class Measure(object):
raise RuntimeError("Measure() objects cannot be re-used")
self.start = self.clock.time()
- parent_context = LoggingContext.current_context()
+ parent_context = current_context()
self._logging_context = LoggingContext(
"Measure[%s]" % (self.name,), parent_context
)
diff --git a/synapse/util/patch_inline_callbacks.py b/synapse/util/patch_inline_callbacks.py
index 3925927f9f..fdff195771 100644
--- a/synapse/util/patch_inline_callbacks.py
+++ b/synapse/util/patch_inline_callbacks.py
@@ -32,7 +32,7 @@ def do_patch():
Patch defer.inlineCallbacks so that it checks the state of the logcontext on exit
"""
- from synapse.logging.context import LoggingContext
+ from synapse.logging.context import current_context
global _already_patched
@@ -43,35 +43,35 @@ def do_patch():
def new_inline_callbacks(f):
@functools.wraps(f)
def wrapped(*args, **kwargs):
- start_context = LoggingContext.current_context()
+ start_context = current_context()
changes = [] # type: List[str]
orig = orig_inline_callbacks(_check_yield_points(f, changes))
try:
res = orig(*args, **kwargs)
except Exception:
- if LoggingContext.current_context() != start_context:
+ if current_context() != start_context:
for err in changes:
print(err, file=sys.stderr)
err = "%s changed context from %s to %s on exception" % (
f,
start_context,
- LoggingContext.current_context(),
+ current_context(),
)
print(err, file=sys.stderr)
raise Exception(err)
raise
if not isinstance(res, Deferred) or res.called:
- if LoggingContext.current_context() != start_context:
+ if current_context() != start_context:
for err in changes:
print(err, file=sys.stderr)
err = "Completed %s changed context from %s to %s" % (
f,
start_context,
- LoggingContext.current_context(),
+ current_context(),
)
# print the error to stderr because otherwise all we
# see in travis-ci is the 500 error
@@ -79,23 +79,23 @@ def do_patch():
raise Exception(err)
return res
- if LoggingContext.current_context() != LoggingContext.sentinel:
+ if current_context():
err = (
"%s returned incomplete deferred in non-sentinel context "
"%s (start was %s)"
- ) % (f, LoggingContext.current_context(), start_context)
+ ) % (f, current_context(), start_context)
print(err, file=sys.stderr)
raise Exception(err)
def check_ctx(r):
- if LoggingContext.current_context() != start_context:
+ if current_context() != start_context:
for err in changes:
print(err, file=sys.stderr)
err = "%s completion of %s changed context from %s to %s" % (
"Failure" if isinstance(r, Failure) else "Success",
f,
start_context,
- LoggingContext.current_context(),
+ current_context(),
)
print(err, file=sys.stderr)
raise Exception(err)
@@ -127,7 +127,7 @@ def _check_yield_points(f: Callable, changes: List[str]):
function
"""
- from synapse.logging.context import LoggingContext
+ from synapse.logging.context import current_context
@functools.wraps(f)
def check_yield_points_inner(*args, **kwargs):
@@ -136,7 +136,7 @@ def _check_yield_points(f: Callable, changes: List[str]):
last_yield_line_no = gen.gi_frame.f_lineno
result = None # type: Any
while True:
- expected_context = LoggingContext.current_context()
+ expected_context = current_context()
try:
isFailure = isinstance(result, Failure)
@@ -145,7 +145,7 @@ def _check_yield_points(f: Callable, changes: List[str]):
else:
d = gen.send(result)
except (StopIteration, defer._DefGen_Return) as e:
- if LoggingContext.current_context() != expected_context:
+ if current_context() != expected_context:
# This happens when the context is lost sometime *after* the
# final yield and returning. E.g. we forgot to yield on a
# function that returns a deferred.
@@ -159,7 +159,7 @@ def _check_yield_points(f: Callable, changes: List[str]):
% (
f.__qualname__,
expected_context,
- LoggingContext.current_context(),
+ current_context(),
f.__code__.co_filename,
last_yield_line_no,
)
@@ -173,13 +173,13 @@ def _check_yield_points(f: Callable, changes: List[str]):
# This happens if we yield on a deferred that doesn't follow
# the log context rules without wrapping in a `make_deferred_yieldable`.
# We raise here as this should never happen.
- if LoggingContext.current_context() is not LoggingContext.sentinel:
+ if current_context():
err = (
"%s yielded with context %s rather than sentinel,"
" yielded on line %d in %s"
% (
frame.f_code.co_name,
- LoggingContext.current_context(),
+ current_context(),
frame.f_lineno,
frame.f_code.co_filename,
)
@@ -191,7 +191,7 @@ def _check_yield_points(f: Callable, changes: List[str]):
except Exception as e:
result = Failure(e)
- if LoggingContext.current_context() != expected_context:
+ if current_context() != expected_context:
# This happens because the context is lost sometime *after* the
# previous yield and *after* the current yield. E.g. the
@@ -206,7 +206,7 @@ def _check_yield_points(f: Callable, changes: List[str]):
% (
frame.f_code.co_name,
expected_context,
- LoggingContext.current_context(),
+ current_context(),
last_yield_line_no,
frame.f_lineno,
frame.f_code.co_filename,
|