diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 1951f6e178..48c4d7b0be 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -23,7 +23,7 @@ from twisted.web.server import Request
import synapse.types
from synapse import event_auth
from synapse.api.auth_blocking import AuthBlocking
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import EventTypes, HistoryVisibility, Membership
from synapse.api.errors import (
AuthError,
Codes,
@@ -648,7 +648,8 @@ class Auth:
)
if (
visibility
- and visibility.content["history_visibility"] == "world_readable"
+ and visibility.content.get("history_visibility")
+ == HistoryVisibility.WORLD_READABLE
):
return Membership.JOIN, None
raise AuthError(
diff --git a/synapse/api/auth_blocking.py b/synapse/api/auth_blocking.py
index 9c227218e0..d8088f524a 100644
--- a/synapse/api/auth_blocking.py
+++ b/synapse/api/auth_blocking.py
@@ -36,6 +36,7 @@ class AuthBlocking:
self._limit_usage_by_mau = hs.config.limit_usage_by_mau
self._mau_limits_reserved_threepids = hs.config.mau_limits_reserved_threepids
self._server_name = hs.hostname
+ self._track_appservice_user_ips = hs.config.appservice.track_appservice_user_ips
async def check_auth_blocking(
self,
@@ -76,6 +77,12 @@ class AuthBlocking:
# We never block the server from doing actions on behalf of
# users.
return
+ elif requester.app_service and not self._track_appservice_user_ips:
+ # If we're authenticated as an appservice then we only block
+ # auth if `track_appservice_user_ips` is set, as that option
+ # implicitly means that application services are part of MAU
+ # limits.
+ return
# Never fail an auth check for the server notices users or support user
# This can be a problem where event creation is prohibited due to blocking
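The bypass added above boils down to a single predicate. A minimal sketch of that logic (a hypothetical helper, not part of the diff, assuming a requester object with an `app_service` attribute as used in the hunk):

```python
# Appservice senders are only subject to MAU/blocking checks when
# `track_appservice_user_ips` is enabled, since only then do they count
# towards monthly-active-user limits.
def may_skip_auth_blocking(requester, track_appservice_user_ips: bool) -> bool:
    return requester.app_service is not None and not track_appservice_user_ips
```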
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index 592abd844b..565a8cd76a 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -95,6 +95,8 @@ class EventTypes:
Presence = "m.presence"
+ Dummy = "org.matrix.dummy_event"
+
class RejectedReason:
AUTH_ERROR = "auth_error"
@@ -160,3 +162,10 @@ class RoomEncryptionAlgorithms:
class AccountDataTypes:
DIRECT = "m.direct"
IGNORED_USER_LIST = "m.ignored_user_list"
+
+
+class HistoryVisibility:
+ INVITED = "invited"
+ JOINED = "joined"
+ SHARED = "shared"
+ WORLD_READABLE = "world_readable"
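A short sketch of how the new constants pair with the safer `.get()` lookup from the `synapse/api/auth.py` hunk above (illustrative helper only, not code from the diff):

```python
from synapse.api.constants import HistoryVisibility

def is_world_readable(visibility_content: dict) -> bool:
    # `.get()` avoids a KeyError on a malformed m.room.history_visibility
    # event, and the constant replaces the bare "world_readable" literal.
    return (
        visibility_content.get("history_visibility")
        == HistoryVisibility.WORLD_READABLE
    )
```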
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index bbb7407838..8d9b53be53 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -63,6 +63,7 @@ from synapse.rest import ClientRestResource
from synapse.rest.admin import AdminRestResource
from synapse.rest.health import HealthResource
from synapse.rest.key.v2 import KeyApiV2Resource
+from synapse.rest.synapse.client.pick_username import pick_username_resource
from synapse.rest.well_known import WellKnownResource
from synapse.server import HomeServer
from synapse.storage import DataStore
@@ -192,6 +193,7 @@ class SynapseHomeServer(HomeServer):
"/_matrix/client/versions": client_resource,
"/.well-known/matrix/client": WellKnownResource(self),
"/_synapse/admin": AdminRestResource(self),
+ "/_synapse/client/pick_username": pick_username_resource(self),
}
)
diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi
index ed26e2fb60..29aa064e57 100644
--- a/synapse/config/_base.pyi
+++ b/synapse/config/_base.pyi
@@ -3,6 +3,7 @@ from typing import Any, Iterable, List, Optional
from synapse.config import (
api,
appservice,
+ auth,
captcha,
cas,
consent_config,
@@ -14,7 +15,6 @@ from synapse.config import (
logger,
metrics,
oidc_config,
- password,
password_auth_providers,
push,
ratelimiting,
@@ -65,7 +65,7 @@ class RootConfig:
sso: sso.SSOConfig
oidc: oidc_config.OIDCConfig
jwt: jwt_config.JWTConfig
- password: password.PasswordConfig
+ auth: auth.AuthConfig
email: emailconfig.EmailConfig
worker: workers.WorkerConfig
authproviders: password_auth_providers.PasswordAuthProviderConfig
diff --git a/synapse/config/password.py b/synapse/config/auth.py
index 9c0ea8c30a..2b3e2ce87b 100644
--- a/synapse/config/password.py
+++ b/synapse/config/auth.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,11 +17,11 @@
from ._base import Config
-class PasswordConfig(Config):
- """Password login configuration
+class AuthConfig(Config):
+ """Password and login configuration
"""
- section = "password"
+ section = "auth"
def read_config(self, config, **kwargs):
password_config = config.get("password_config", {})
@@ -35,6 +36,10 @@ class PasswordConfig(Config):
self.password_policy = password_config.get("policy") or {}
self.password_policy_enabled = self.password_policy.get("enabled", False)
+ # User-interactive authentication
+ ui_auth = config.get("ui_auth") or {}
+ self.ui_auth_session_timeout = ui_auth.get("session_timeout", 0)
+
def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """\
password_config:
@@ -87,4 +92,19 @@ class PasswordConfig(Config):
# Defaults to 'false'.
#
#require_uppercase: true
+
+ ui_auth:
+ # The number of milliseconds to allow a user-interactive authentication
+ # session to be active.
+ #
+ # This defaults to 0, meaning the user is queried for their credentials
+ # before every action, but this can be overridden to allow a single
+ # validation to be re-used. This weakens the protections afforded by
+ # the user-interactive authentication process, by allowing for multiple
+ # (and potentially different) operations to use the same validation session.
+ #
+ # Uncomment below to allow for credential validation to last for 15
+ # seconds.
+ #
+ #session_timeout: 15000
"""
diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py
index 7c8b64d84b..d4328c46b9 100644
--- a/synapse/config/emailconfig.py
+++ b/synapse/config/emailconfig.py
@@ -322,6 +322,22 @@ class EmailConfig(Config):
self.email_subjects = EmailSubjectConfig(**subjects)
+ # The invite client location should be an HTTP(S) URL or None.
+ self.invite_client_location = email_config.get("invite_client_location") or None
+ if self.invite_client_location:
+ if not isinstance(self.invite_client_location, str):
+ raise ConfigError(
+ "Config option email.invite_client_location must be type str"
+ )
+ if not (
+ self.invite_client_location.startswith("http://")
+ or self.invite_client_location.startswith("https://")
+ ):
+ raise ConfigError(
+ "Config option email.invite_client_location must be a http or https URL",
+ path=("email", "invite_client_location"),
+ )
+
def generate_config_section(self, config_dir_path, server_name, **kwargs):
return (
"""\
@@ -389,6 +405,12 @@ class EmailConfig(Config):
#
#validation_token_lifetime: 15m
+ # The web client location to direct users to during an invite. This is passed
+ # to the identity server as the org.matrix.web_client_location key. Defaults
+ # to unset, giving no guidance to the identity server.
+ #
+ #invite_client_location: https://app.element.io
+
# Directory in which Synapse will try to find the template files below.
# If not set, or the files named below are not found within the template
# directory, default templates from within the Synapse package will be used.
diff --git a/synapse/config/federation.py b/synapse/config/federation.py
index a03a419e23..9f3c57e6a1 100644
--- a/synapse/config/federation.py
+++ b/synapse/config/federation.py
@@ -56,18 +56,6 @@ class FederationConfig(Config):
# - nyc.example.com
# - syd.example.com
- # List of IP address CIDR ranges that should be allowed for federation,
- # identity servers, push servers, and for checking key validity for
- # third-party invite events. This is useful for specifying exceptions to
- # wide-ranging blacklisted target IP ranges - e.g. for communication with
- # a push server only visible in your network.
- #
- # This whitelist overrides ip_range_blacklist and defaults to an empty
- # list.
- #
- #ip_range_whitelist:
- # - '192.168.1.1'
-
# Report prometheus metrics on the age of PDUs being sent to and received from
# the following domains. This can be used to give an idea of "delay" on inbound
# and outbound federation, though be aware that any delay can be due to problems
diff --git a/synapse/config/groups.py b/synapse/config/groups.py
index d6862d9a64..7b7860ea71 100644
--- a/synapse/config/groups.py
+++ b/synapse/config/groups.py
@@ -32,5 +32,5 @@ class GroupsConfig(Config):
# If enabled, non server admins can only create groups with local parts
# starting with this prefix
#
- #group_creation_prefix: "unofficial/"
+ #group_creation_prefix: "unofficial_"
"""
diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py
index be65554524..4bd2b3587b 100644
--- a/synapse/config/homeserver.py
+++ b/synapse/config/homeserver.py
@@ -17,6 +17,7 @@
from ._base import RootConfig
from .api import ApiConfig
from .appservice import AppServiceConfig
+from .auth import AuthConfig
from .cache import CacheConfig
from .captcha import CaptchaConfig
from .cas import CasConfig
@@ -30,7 +31,6 @@ from .key import KeyConfig
from .logger import LoggingConfig
from .metrics import MetricsConfig
from .oidc_config import OIDCConfig
-from .password import PasswordConfig
from .password_auth_providers import PasswordAuthProviderConfig
from .push import PushConfig
from .ratelimiting import RatelimitConfig
@@ -76,7 +76,7 @@ class HomeServerConfig(RootConfig):
CasConfig,
SSOConfig,
JWTConfig,
- PasswordConfig,
+ AuthConfig,
EmailConfig,
PasswordAuthProviderConfig,
PushConfig,
diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py
index 1abf8ed405..4e3055282d 100644
--- a/synapse/config/oidc_config.py
+++ b/synapse/config/oidc_config.py
@@ -203,9 +203,10 @@ class OIDCConfig(Config):
# * user: The claims returned by the UserInfo Endpoint and/or in the ID
# Token
#
- # This must be configured if using the default mapping provider.
+ # If this is not set, the user will be prompted to choose their
+ # own username.
#
- localpart_template: "{{{{ user.preferred_username }}}}"
+ #localpart_template: "{{{{ user.preferred_username }}}}"
# Jinja2 template for the display name to set on first login.
#
diff --git a/synapse/config/server.py b/synapse/config/server.py
index f3815e5add..7242a4aa8e 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -832,6 +832,18 @@ class ServerConfig(Config):
#ip_range_blacklist:
%(ip_range_blacklist)s
+ # List of IP address CIDR ranges that should be allowed for federation,
+ # identity servers, push servers, and for checking key validity for
+ # third-party invite events. This is useful for specifying exceptions to
+ # wide-ranging blacklisted target IP ranges - e.g. for communication with
+ # a push server only visible in your network.
+ #
+ # This whitelist overrides ip_range_blacklist and defaults to an empty
+ # list.
+ #
+ #ip_range_whitelist:
+ # - '192.168.1.1'
+
# List of ports that Synapse should listen on, their purpose and their
# configuration.
#
diff --git a/synapse/crypto/context_factory.py b/synapse/crypto/context_factory.py
index 57fd426e87..74b67b230a 100644
--- a/synapse/crypto/context_factory.py
+++ b/synapse/crypto/context_factory.py
@@ -227,7 +227,7 @@ class ConnectionVerifier:
# This code is based on twisted.internet.ssl.ClientTLSOptions.
- def __init__(self, hostname: bytes, verify_certs):
+ def __init__(self, hostname: bytes, verify_certs: bool):
self._verify_certs = verify_certs
_decoded = hostname.decode("ascii")
diff --git a/synapse/crypto/event_signing.py b/synapse/crypto/event_signing.py
index 0422c43fab..8fb116ae18 100644
--- a/synapse/crypto/event_signing.py
+++ b/synapse/crypto/event_signing.py
@@ -18,7 +18,7 @@
import collections.abc
import hashlib
import logging
-from typing import Dict
+from typing import Any, Callable, Dict, Tuple
from canonicaljson import encode_canonical_json
from signedjson.sign import sign_json
@@ -27,13 +27,18 @@ from unpaddedbase64 import decode_base64, encode_base64
from synapse.api.errors import Codes, SynapseError
from synapse.api.room_versions import RoomVersion
+from synapse.events import EventBase
from synapse.events.utils import prune_event, prune_event_dict
from synapse.types import JsonDict
logger = logging.getLogger(__name__)
+Hasher = Callable[[bytes], "hashlib._Hash"]
-def check_event_content_hash(event, hash_algorithm=hashlib.sha256):
+
+def check_event_content_hash(
+ event: EventBase, hash_algorithm: Hasher = hashlib.sha256
+) -> bool:
"""Check whether the hash for this PDU matches the contents"""
name, expected_hash = compute_content_hash(event.get_pdu_json(), hash_algorithm)
logger.debug(
@@ -67,18 +72,19 @@ def check_event_content_hash(event, hash_algorithm=hashlib.sha256):
return message_hash_bytes == expected_hash
-def compute_content_hash(event_dict, hash_algorithm):
+def compute_content_hash(
+ event_dict: Dict[str, Any], hash_algorithm: Hasher
+) -> Tuple[str, bytes]:
"""Compute the content hash of an event, which is the hash of the
unredacted event.
Args:
- event_dict (dict): The unredacted event as a dict
+ event_dict: The unredacted event as a dict
hash_algorithm: A hasher from `hashlib`, e.g. hashlib.sha256, to use
to hash the event
Returns:
- tuple[str, bytes]: A tuple of the name of hash and the hash as raw
- bytes.
+ A tuple of the name of hash and the hash as raw bytes.
"""
event_dict = dict(event_dict)
event_dict.pop("age_ts", None)
@@ -94,18 +100,19 @@ def compute_content_hash(event_dict, hash_algorithm):
return hashed.name, hashed.digest()
-def compute_event_reference_hash(event, hash_algorithm=hashlib.sha256):
+def compute_event_reference_hash(
+ event, hash_algorithm: Hasher = hashlib.sha256
+) -> Tuple[str, bytes]:
"""Computes the event reference hash. This is the hash of the redacted
event.
Args:
- event (FrozenEvent)
+ event
hash_algorithm: A hasher from `hashlib`, e.g. hashlib.sha256, to use
to hash the event
Returns:
- tuple[str, bytes]: A tuple of the name of hash and the hash as raw
- bytes.
+ A tuple of the name of hash and the hash as raw bytes.
"""
tmp_event = prune_event(event)
event_dict = tmp_event.get_pdu_json()
@@ -156,7 +163,7 @@ def add_hashes_and_signatures(
event_dict: JsonDict,
signature_name: str,
signing_key: SigningKey,
-):
+) -> None:
"""Add content hash and sign the event
Args:
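Per the docstrings above, a content hash is taken over the canonical JSON of the event with transient fields stripped. A rough sketch of that flow, assuming the `canonicaljson` dependency (not the exact Synapse implementation, which strips a few additional fields):

```python
import hashlib
from canonicaljson import encode_canonical_json

def content_hash(event_dict: dict, hash_algorithm=hashlib.sha256):
    event_dict = dict(event_dict)
    # Transient and signature-related fields never contribute to the hash.
    for key in ("age_ts", "unsigned", "signatures", "hashes"):
        event_dict.pop(key, None)
    hashed = hash_algorithm(encode_canonical_json(event_dict))
    return hashed.name, hashed.digest()  # e.g. ("sha256", b"...")
```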
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index f23eacc0d7..902128a23c 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -14,9 +14,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import abc
import logging
import urllib
from collections import defaultdict
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
import attr
from signedjson.key import (
@@ -40,6 +42,7 @@ from synapse.api.errors import (
RequestSendFailed,
SynapseError,
)
+from synapse.config.key import TrustedKeyServer
from synapse.logging.context import (
PreserveLoggingContext,
make_deferred_yieldable,
@@ -47,11 +50,15 @@ from synapse.logging.context import (
run_in_background,
)
from synapse.storage.keys import FetchKeyResult
+from synapse.types import JsonDict
from synapse.util import unwrapFirstError
from synapse.util.async_helpers import yieldable_gather_results
from synapse.util.metrics import Measure
from synapse.util.retryutils import NotRetryingDestination
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
logger = logging.getLogger(__name__)
@@ -61,16 +68,17 @@ class VerifyJsonRequest:
A request to verify a JSON object.
Attributes:
- server_name(str): The name of the server to verify against.
-
- key_ids(set[str]): The set of key_ids to that could be used to verify the
- JSON object
+ server_name: The name of the server to verify against.
- json_object(dict): The JSON object to verify.
+ json_object: The JSON object to verify.
- minimum_valid_until_ts (int): time at which we require the signing key to
+ minimum_valid_until_ts: time at which we require the signing key to
be valid. (0 implies we don't care)
+ request_name: The name of the request.
+
+ key_ids: The set of key_ids that could be used to verify the JSON object
+
key_ready (Deferred[str, str, nacl.signing.VerifyKey]):
A deferred (server_name, key_id, verify_key) tuple that resolves when
a verify key has been fetched. The deferreds' callbacks are run with no
@@ -80,12 +88,12 @@ class VerifyJsonRequest:
errbacks with an M_UNAUTHORIZED SynapseError.
"""
- server_name = attr.ib()
- json_object = attr.ib()
- minimum_valid_until_ts = attr.ib()
- request_name = attr.ib()
- key_ids = attr.ib(init=False)
- key_ready = attr.ib(default=attr.Factory(defer.Deferred))
+ server_name = attr.ib(type=str)
+ json_object = attr.ib(type=JsonDict)
+ minimum_valid_until_ts = attr.ib(type=int)
+ request_name = attr.ib(type=str)
+ key_ids = attr.ib(init=False, type=List[str])
+ key_ready = attr.ib(default=attr.Factory(defer.Deferred), type=defer.Deferred)
def __attrs_post_init__(self):
self.key_ids = signature_ids(self.json_object, self.server_name)
@@ -96,7 +104,9 @@ class KeyLookupError(ValueError):
class Keyring:
- def __init__(self, hs, key_fetchers=None):
+ def __init__(
+ self, hs: "HomeServer", key_fetchers: "Optional[Iterable[KeyFetcher]]" = None
+ ):
self.clock = hs.get_clock()
if key_fetchers is None:
@@ -112,22 +122,26 @@ class Keyring:
# completes.
#
# These are regular, logcontext-agnostic Deferreds.
- self.key_downloads = {}
+ self.key_downloads = {} # type: Dict[str, defer.Deferred]
def verify_json_for_server(
- self, server_name, json_object, validity_time, request_name
- ):
+ self,
+ server_name: str,
+ json_object: JsonDict,
+ validity_time: int,
+ request_name: str,
+ ) -> defer.Deferred:
"""Verify that a JSON object has been signed by a given server
Args:
- server_name (str): name of the server which must have signed this object
+ server_name: name of the server which must have signed this object
- json_object (dict): object to be checked
+ json_object: object to be checked
- validity_time (int): timestamp at which we require the signing key to
+ validity_time: timestamp at which we require the signing key to
be valid. (0 implies we don't care)
- request_name (str): an identifier for this json object (eg, an event id)
+ request_name: an identifier for this json object (eg, an event id)
for logging.
Returns:
@@ -138,12 +152,14 @@ class Keyring:
requests = (req,)
return make_deferred_yieldable(self._verify_objects(requests)[0])
- def verify_json_objects_for_server(self, server_and_json):
+ def verify_json_objects_for_server(
+ self, server_and_json: Iterable[Tuple[str, dict, int, str]]
+ ) -> List[defer.Deferred]:
"""Bulk verifies signatures of json objects, bulk fetching keys as
necessary.
Args:
- server_and_json (iterable[Tuple[str, dict, int, str]):
+ server_and_json:
Iterable of (server_name, json_object, validity_time, request_name)
tuples.
@@ -164,13 +180,14 @@ class Keyring:
for server_name, json_object, validity_time, request_name in server_and_json
)
- def _verify_objects(self, verify_requests):
+ def _verify_objects(
+ self, verify_requests: Iterable[VerifyJsonRequest]
+ ) -> List[defer.Deferred]:
"""Does the work of verify_json_[objects_]for_server
Args:
- verify_requests (iterable[VerifyJsonRequest]):
- Iterable of verification requests.
+ verify_requests: Iterable of verification requests.
Returns:
List<Deferred[None]>: for each input item, a deferred indicating success
@@ -182,7 +199,7 @@ class Keyring:
key_lookups = []
handle = preserve_fn(_handle_key_deferred)
- def process(verify_request):
+ def process(verify_request: VerifyJsonRequest) -> defer.Deferred:
"""Process an entry in the request list
Adds a key request to key_lookups, and returns a deferred which
@@ -222,18 +239,20 @@ class Keyring:
return results
- async def _start_key_lookups(self, verify_requests):
+ async def _start_key_lookups(
+ self, verify_requests: List[VerifyJsonRequest]
+ ) -> None:
"""Sets off the key fetches for each verify request
Once each fetch completes, verify_request.key_ready will be resolved.
Args:
- verify_requests (List[VerifyJsonRequest]):
+ verify_requests:
"""
try:
# map from server name to a set of outstanding request ids
- server_to_request_ids = {}
+ server_to_request_ids = {} # type: Dict[str, Set[int]]
for verify_request in verify_requests:
server_name = verify_request.server_name
@@ -275,11 +294,11 @@ class Keyring:
except Exception:
logger.exception("Error starting key lookups")
- async def wait_for_previous_lookups(self, server_names) -> None:
+ async def wait_for_previous_lookups(self, server_names: Iterable[str]) -> None:
"""Waits for any previous key lookups for the given servers to finish.
Args:
- server_names (Iterable[str]): list of servers which we want to look up
+ server_names: list of servers which we want to look up
Returns:
Resolves once all key lookups for the given servers have
@@ -304,7 +323,7 @@ class Keyring:
loop_count += 1
- def _get_server_verify_keys(self, verify_requests):
+ def _get_server_verify_keys(self, verify_requests: List[VerifyJsonRequest]) -> None:
"""Tries to find at least one key for each verify request
For each verify_request, verify_request.key_ready is called back with
@@ -312,7 +331,7 @@ class Keyring:
with a SynapseError if none of the keys are found.
Args:
- verify_requests (list[VerifyJsonRequest]): list of verify requests
+ verify_requests: list of verify requests
"""
remaining_requests = {rq for rq in verify_requests if not rq.key_ready.called}
@@ -366,17 +385,19 @@ class Keyring:
run_in_background(do_iterations)
- async def _attempt_key_fetches_with_fetcher(self, fetcher, remaining_requests):
+ async def _attempt_key_fetches_with_fetcher(
+ self, fetcher: "KeyFetcher", remaining_requests: Set[VerifyJsonRequest]
+ ):
"""Use a key fetcher to attempt to satisfy some key requests
Args:
- fetcher (KeyFetcher): fetcher to use to fetch the keys
- remaining_requests (set[VerifyJsonRequest]): outstanding key requests.
+ fetcher: fetcher to use to fetch the keys
+ remaining_requests: outstanding key requests.
Any successfully-completed requests will be removed from the list.
"""
- # dict[str, dict[str, int]]: keys to fetch.
+ # The keys to fetch.
# server_name -> key_id -> min_valid_ts
- missing_keys = defaultdict(dict)
+ missing_keys = defaultdict(dict) # type: Dict[str, Dict[str, int]]
for verify_request in remaining_requests:
# any completed requests should already have been removed
@@ -438,16 +459,18 @@ class Keyring:
remaining_requests.difference_update(completed)
-class KeyFetcher:
- async def get_keys(self, keys_to_fetch):
+class KeyFetcher(metaclass=abc.ABCMeta):
+ @abc.abstractmethod
+ async def get_keys(
+ self, keys_to_fetch: Dict[str, Dict[str, int]]
+ ) -> Dict[str, Dict[str, FetchKeyResult]]:
"""
Args:
- keys_to_fetch (dict[str, dict[str, int]]):
+ keys_to_fetch:
the keys to be fetched. server_name -> key_id -> min_valid_ts
Returns:
- Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]]:
- map from server_name -> key_id -> FetchKeyResult
+ Map from server_name -> key_id -> FetchKeyResult
"""
raise NotImplementedError
@@ -455,31 +478,35 @@ class KeyFetcher:
class StoreKeyFetcher(KeyFetcher):
"""KeyFetcher impl which fetches keys from our data store"""
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
- async def get_keys(self, keys_to_fetch):
+ async def get_keys(
+ self, keys_to_fetch: Dict[str, Dict[str, int]]
+ ) -> Dict[str, Dict[str, FetchKeyResult]]:
"""see KeyFetcher.get_keys"""
- keys_to_fetch = (
+ key_ids_to_fetch = (
(server_name, key_id)
for server_name, keys_for_server in keys_to_fetch.items()
for key_id in keys_for_server.keys()
)
- res = await self.store.get_server_verify_keys(keys_to_fetch)
- keys = {}
+ res = await self.store.get_server_verify_keys(key_ids_to_fetch)
+ keys = {} # type: Dict[str, Dict[str, FetchKeyResult]]
for (server_name, key_id), key in res.items():
keys.setdefault(server_name, {})[key_id] = key
return keys
-class BaseV2KeyFetcher:
- def __init__(self, hs):
+class BaseV2KeyFetcher(KeyFetcher):
+ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.config = hs.get_config()
- async def process_v2_response(self, from_server, response_json, time_added_ms):
+ async def process_v2_response(
+ self, from_server: str, response_json: JsonDict, time_added_ms: int
+ ) -> Dict[str, FetchKeyResult]:
"""Parse a 'Server Keys' structure from the result of a /key request
This is used to parse either the entirety of the response from
@@ -493,16 +520,16 @@ class BaseV2KeyFetcher:
to /_matrix/key/v2/query.
Args:
- from_server (str): the name of the server producing this result: either
+ from_server: the name of the server producing this result: either
the origin server for a /_matrix/key/v2/server request, or the notary
for a /_matrix/key/v2/query.
- response_json (dict): the json-decoded Server Keys response object
+ response_json: the json-decoded Server Keys response object
- time_added_ms (int): the timestamp to record in server_keys_json
+ time_added_ms: the timestamp to record in server_keys_json
Returns:
- Deferred[dict[str, FetchKeyResult]]: map from key_id to result object
+ Map from key_id to result object
"""
ts_valid_until_ms = response_json["valid_until_ts"]
@@ -575,21 +602,22 @@ class BaseV2KeyFetcher:
class PerspectivesKeyFetcher(BaseV2KeyFetcher):
"""KeyFetcher impl which fetches keys from the "perspectives" servers"""
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.clock = hs.get_clock()
self.client = hs.get_federation_http_client()
self.key_servers = self.config.key_servers
- async def get_keys(self, keys_to_fetch):
+ async def get_keys(
+ self, keys_to_fetch: Dict[str, Dict[str, int]]
+ ) -> Dict[str, Dict[str, FetchKeyResult]]:
"""see KeyFetcher.get_keys"""
- async def get_key(key_server):
+ async def get_key(key_server: TrustedKeyServer) -> Dict:
try:
- result = await self.get_server_verify_key_v2_indirect(
+ return await self.get_server_verify_key_v2_indirect(
keys_to_fetch, key_server
)
- return result
except KeyLookupError as e:
logger.warning(
"Key lookup failed from %r: %s", key_server.server_name, e
@@ -611,25 +639,25 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
).addErrback(unwrapFirstError)
)
- union_of_keys = {}
+ union_of_keys = {} # type: Dict[str, Dict[str, FetchKeyResult]]
for result in results:
for server_name, keys in result.items():
union_of_keys.setdefault(server_name, {}).update(keys)
return union_of_keys
- async def get_server_verify_key_v2_indirect(self, keys_to_fetch, key_server):
+ async def get_server_verify_key_v2_indirect(
+ self, keys_to_fetch: Dict[str, Dict[str, int]], key_server: TrustedKeyServer
+ ) -> Dict[str, Dict[str, FetchKeyResult]]:
"""
Args:
- keys_to_fetch (dict[str, dict[str, int]]):
+ keys_to_fetch:
the keys to be fetched. server_name -> key_id -> min_valid_ts
- key_server (synapse.config.key.TrustedKeyServer): notary server to query for
- the keys
+ key_server: notary server to query for the keys
Returns:
- dict[str, dict[str, synapse.storage.keys.FetchKeyResult]]: map
- from server_name -> key_id -> FetchKeyResult
+ Map from server_name -> key_id -> FetchKeyResult
Raises:
KeyLookupError if there was an error processing the entire response from
@@ -662,11 +690,12 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
except HttpResponseException as e:
raise KeyLookupError("Remote server returned an error: %s" % (e,))
- keys = {}
- added_keys = []
+ keys = {} # type: Dict[str, Dict[str, FetchKeyResult]]
+ added_keys = [] # type: List[Tuple[str, str, FetchKeyResult]]
time_now_ms = self.clock.time_msec()
+ assert isinstance(query_response, dict)
for response in query_response["server_keys"]:
# do this first, so that we can give useful errors thereafter
server_name = response.get("server_name")
@@ -704,14 +733,15 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
return keys
- def _validate_perspectives_response(self, key_server, response):
+ def _validate_perspectives_response(
+ self, key_server: TrustedKeyServer, response: JsonDict
+ ) -> None:
"""Optionally check the signature on the result of a /key/query request
Args:
- key_server (synapse.config.key.TrustedKeyServer): the notary server that
- produced this result
+ key_server: the notary server that produced this result
- response (dict): the json-decoded Server Keys response object
+ response: the json-decoded Server Keys response object
"""
perspective_name = key_server.server_name
perspective_keys = key_server.verify_keys
@@ -745,25 +775,26 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
class ServerKeyFetcher(BaseV2KeyFetcher):
"""KeyFetcher impl which fetches keys from the origin servers"""
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.clock = hs.get_clock()
self.client = hs.get_federation_http_client()
- async def get_keys(self, keys_to_fetch):
+ async def get_keys(
+ self, keys_to_fetch: Dict[str, Dict[str, int]]
+ ) -> Dict[str, Dict[str, FetchKeyResult]]:
"""
Args:
- keys_to_fetch (dict[str, iterable[str]]):
+ keys_to_fetch:
the keys to be fetched. server_name -> key_ids
Returns:
- dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]:
- map from server_name -> key_id -> FetchKeyResult
+ Map from server_name -> key_id -> FetchKeyResult
"""
results = {}
- async def get_key(key_to_fetch_item):
+ async def get_key(key_to_fetch_item: Tuple[str, Dict[str, int]]) -> None:
server_name, key_ids = key_to_fetch_item
try:
keys = await self.get_server_verify_key_v2_direct(server_name, key_ids)
@@ -778,20 +809,22 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
await yieldable_gather_results(get_key, keys_to_fetch.items())
return results
- async def get_server_verify_key_v2_direct(self, server_name, key_ids):
+ async def get_server_verify_key_v2_direct(
+ self, server_name: str, key_ids: Iterable[str]
+ ) -> Dict[str, FetchKeyResult]:
"""
Args:
- server_name (str):
- key_ids (iterable[str]):
+ server_name:
+ key_ids:
Returns:
- dict[str, FetchKeyResult]: map from key ID to lookup result
+ Map from key ID to lookup result
Raises:
KeyLookupError if there was a problem making the lookup
"""
- keys = {} # type: dict[str, FetchKeyResult]
+ keys = {} # type: Dict[str, FetchKeyResult]
for requested_key_id in key_ids:
# we may have found this key as a side-effect of asking for another.
@@ -825,6 +858,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
except HttpResponseException as e:
raise KeyLookupError("Remote server returned an error: %s" % (e,))
+ assert isinstance(response, dict)
if response["server_name"] != server_name:
raise KeyLookupError(
"Expected a response for server %r not %r"
@@ -846,11 +880,11 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
return keys
-async def _handle_key_deferred(verify_request) -> None:
+async def _handle_key_deferred(verify_request: VerifyJsonRequest) -> None:
"""Waits for the key to become available, and then performs a verification
Args:
- verify_request (VerifyJsonRequest):
+ verify_request:
Raises:
SynapseError if there was a problem performing the verification
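Since `KeyFetcher` is now an abstract base class, any extra fetcher must implement `get_keys` with the typed shape above. A minimal hypothetical fetcher for illustration:

```python
from typing import Dict

from synapse.crypto.keyring import KeyFetcher
from synapse.storage.keys import FetchKeyResult

class StaticKeyFetcher(KeyFetcher):
    """Hypothetical fetcher that serves keys from a fixed in-memory mapping."""

    def __init__(self, known_keys: Dict[str, Dict[str, FetchKeyResult]]):
        self._known_keys = known_keys

    async def get_keys(
        self, keys_to_fetch: Dict[str, Dict[str, int]]
    ) -> Dict[str, Dict[str, FetchKeyResult]]:
        results = {}  # type: Dict[str, Dict[str, FetchKeyResult]]
        for server_name, key_ids in keys_to_fetch.items():
            for key_id in key_ids:
                key = self._known_keys.get(server_name, {}).get(key_id)
                if key is not None:
                    results.setdefault(server_name, {})[key_id] = key
        return results
```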
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 434718ddfc..cfd094e58f 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -144,7 +144,7 @@ class Authenticator:
):
raise FederationDeniedError(origin)
- if not json_request["signatures"]:
+ if origin is None or not json_request["signatures"]:
raise NoAuthenticationError(
401, "Missing Authorization headers", Codes.UNAUTHORIZED
)
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index a703944543..37e63da9b1 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -13,27 +13,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import abc
import logging
-from typing import List
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set
from synapse.api.constants import Membership
-from synapse.events import FrozenEvent
-from synapse.types import RoomStreamToken, StateMap
+from synapse.events import EventBase
+from synapse.types import JsonDict, RoomStreamToken, StateMap, UserID
from synapse.visibility import filter_events_for_client
from ._base import BaseHandler
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
logger = logging.getLogger(__name__)
class AdminHandler(BaseHandler):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.storage = hs.get_storage()
self.state_store = self.storage.state
- async def get_whois(self, user):
+ async def get_whois(self, user: UserID) -> JsonDict:
connections = []
sessions = await self.store.get_user_ip_and_agents(user)
@@ -53,7 +57,7 @@ class AdminHandler(BaseHandler):
return ret
- async def get_user(self, user):
+ async def get_user(self, user: UserID) -> Optional[JsonDict]:
"""Function to get user details"""
ret = await self.store.get_user_by_id(user.to_string())
if ret:
@@ -64,12 +68,12 @@ class AdminHandler(BaseHandler):
ret["threepids"] = threepids
return ret
- async def export_user_data(self, user_id, writer):
+ async def export_user_data(self, user_id: str, writer: "ExfiltrationWriter") -> Any:
"""Write all data we have on the user to the given writer.
Args:
- user_id (str)
- writer (ExfiltrationWriter)
+ user_id: The user ID to fetch data of.
+ writer: The writer to write to.
Returns:
Resolves when all data for a user has been written.
@@ -128,7 +132,8 @@ class AdminHandler(BaseHandler):
from_key = RoomStreamToken(0, 0)
to_key = RoomStreamToken(None, stream_ordering)
- written_events = set() # Events that we've processed in this room
+ # Events that we've processed in this room
+ written_events = set() # type: Set[str]
# We need to track gaps in the events stream so that we can then
# write out the state at those events. We do this by keeping track
@@ -140,8 +145,8 @@ class AdminHandler(BaseHandler):
# The reverse mapping to above, i.e. map from unseen event to events
# that have the unseen event in their prev_events, i.e. the unseen
- # events "children". dict[str, set[str]]
- unseen_to_child_events = {}
+ # events "children".
+ unseen_to_child_events = {} # type: Dict[str, Set[str]]
# We fetch events in the room the user could see by fetching *all*
# events that we have and then filtering, this isn't the most
@@ -197,38 +202,46 @@ class AdminHandler(BaseHandler):
return writer.finished()
-class ExfiltrationWriter:
+class ExfiltrationWriter(metaclass=abc.ABCMeta):
"""Interface used to specify how to write exported data.
"""
- def write_events(self, room_id: str, events: List[FrozenEvent]):
+ @abc.abstractmethod
+ def write_events(self, room_id: str, events: List[EventBase]) -> None:
"""Write a batch of events for a room.
"""
- pass
+ raise NotImplementedError()
- def write_state(self, room_id: str, event_id: str, state: StateMap[FrozenEvent]):
+ @abc.abstractmethod
+ def write_state(
+ self, room_id: str, event_id: str, state: StateMap[EventBase]
+ ) -> None:
"""Write the state at the given event in the room.
This only gets called for backward extremities rather than for each
event.
"""
- pass
+ raise NotImplementedError()
- def write_invite(self, room_id: str, event: FrozenEvent, state: StateMap[dict]):
+ @abc.abstractmethod
+ def write_invite(
+ self, room_id: str, event: EventBase, state: StateMap[dict]
+ ) -> None:
"""Write an invite for the room, with associated invite state.
Args:
- room_id
- event
- state: A subset of the state at the
- invite, with a subset of the event keys (type, state_key
- content and sender)
+ room_id: The room ID the invite is for.
+ event: The invite event.
+ state: A subset of the state at the invite, with a subset of the
+ event keys (type, state_key, content and sender).
"""
+ raise NotImplementedError()
- def finished(self):
+ @abc.abstractmethod
+ def finished(self) -> Any:
"""Called when all data has successfully been exported and written.
This functions return value is passed to the caller of
`export_user_data`.
"""
- pass
+ raise NotImplementedError()
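With `ExfiltrationWriter` turned into an ABC, concrete exporters must implement all four methods. A minimal hypothetical in-memory writer, for illustration:

```python
from typing import Any, List

from synapse.events import EventBase
from synapse.handlers.admin import ExfiltrationWriter
from synapse.types import StateMap

class InMemoryExfiltrationWriter(ExfiltrationWriter):
    """Hypothetical writer that collects exported data in memory."""

    def __init__(self) -> None:
        self.events = {}   # room_id -> list of events
        self.state = {}    # (room_id, event_id) -> state map
        self.invites = []  # (room_id, event, state) tuples

    def write_events(self, room_id: str, events: List[EventBase]) -> None:
        self.events.setdefault(room_id, []).extend(events)

    def write_state(
        self, room_id: str, event_id: str, state: StateMap[EventBase]
    ) -> None:
        self.state[(room_id, event_id)] = state

    def write_invite(
        self, room_id: str, event: EventBase, state: StateMap[dict]
    ) -> None:
        self.invites.append((room_id, event, state))

    def finished(self) -> Any:
        # This value is handed back to the caller of `export_user_data`.
        return {"events": self.events, "state": self.state, "invites": self.invites}
```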
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 21e568f226..f4434673dc 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -226,6 +226,9 @@ class AuthHandler(BaseHandler):
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
)
+ # The number of seconds to keep a UI auth session active.
+ self._ui_auth_session_timeout = hs.config.ui_auth_session_timeout
+
# Ratelimitier for failed /login attempts
self._failed_login_attempts_ratelimiter = Ratelimiter(
clock=hs.get_clock(),
@@ -283,7 +286,7 @@ class AuthHandler(BaseHandler):
request_body: Dict[str, Any],
clientip: str,
description: str,
- ) -> Tuple[dict, str]:
+ ) -> Tuple[dict, Optional[str]]:
"""
Checks that the user is who they claim to be, via a UI auth.
@@ -310,7 +313,8 @@ class AuthHandler(BaseHandler):
have been given only in a previous call).
'session_id' is the ID of this session, either passed in by the
- client or assigned by this call
+ client or assigned by this call. This is None if UI auth was
+ skipped (by re-using a previous validation).
Raises:
InteractiveAuthIncompleteError if the client has not yet completed
@@ -324,6 +328,16 @@ class AuthHandler(BaseHandler):
"""
+ if self._ui_auth_session_timeout:
+ last_validated = await self.store.get_access_token_last_validated(
+ requester.access_token_id
+ )
+ if self.clock.time_msec() - last_validated < self._ui_auth_session_timeout:
+ # Return the input parameters, minus the auth key, which matches
+ # the logic in check_ui_auth.
+ request_body.pop("auth", None)
+ return request_body, None
+
user_id = requester.user.to_string()
# Check if we should be ratelimited due to too many previous failed attempts
@@ -359,6 +373,9 @@ class AuthHandler(BaseHandler):
if user_id != requester.user.to_string():
raise AuthError(403, "Invalid auth")
+ # Note that the access token has been validated.
+ await self.store.update_access_token_last_validated(requester.access_token_id)
+
return params, session_id
async def _get_available_ui_auth_types(self, user: UserID) -> Iterable[str]:
@@ -452,13 +469,10 @@ class AuthHandler(BaseHandler):
all the stages in any of the permitted flows.
"""
- authdict = None
sid = None # type: Optional[str]
- if clientdict and "auth" in clientdict:
- authdict = clientdict["auth"]
- del clientdict["auth"]
- if "session" in authdict:
- sid = authdict["session"]
+ authdict = clientdict.pop("auth", {})
+ if "session" in authdict:
+ sid = authdict["session"]
# Convert the URI and method to strings.
uri = request.uri.decode("utf-8")
@@ -563,6 +577,8 @@ class AuthHandler(BaseHandler):
creds = await self.store.get_completed_ui_auth_stages(session.session_id)
for f in flows:
+ # If all the required credentials have been supplied, the user has
+ # successfully completed the UI auth process!
if len(set(f) - set(creds)) == 0:
# it's very useful to know what args are stored, but this can
# include the password in the case of registering, so only log
@@ -738,6 +754,7 @@ class AuthHandler(BaseHandler):
device_id: Optional[str],
valid_until_ms: Optional[int],
puppets_user_id: Optional[str] = None,
+ is_appservice_ghost: bool = False,
) -> str:
"""
Creates a new access token for the user with the given user ID.
@@ -754,6 +771,7 @@ class AuthHandler(BaseHandler):
we should always have a device ID)
valid_until_ms: when the token is valid until. None for
no expiry.
+ is_appservice_ghost: Whether the user is an application service ghost user
Returns:
The access token for the user's session.
Raises:
@@ -774,7 +792,11 @@ class AuthHandler(BaseHandler):
"Logging in user %s on device %s%s", user_id, device_id, fmt_expiry
)
- await self.auth.check_auth_blocking(user_id)
+ if (
+ not is_appservice_ghost
+ or self.hs.config.appservice.track_appservice_user_ips
+ ):
+ await self.auth.check_auth_blocking(user_id)
access_token = self.macaroon_gen.generate_access_token(user_id)
await self.store.add_access_token_to_user(
diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py
index f4ea0a9767..fca210a5a6 100644
--- a/synapse/handlers/cas_handler.py
+++ b/synapse/handlers/cas_handler.py
@@ -13,13 +13,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-import urllib
-from typing import TYPE_CHECKING, Dict, Optional, Tuple
+import urllib.parse
+from typing import TYPE_CHECKING, Dict, Optional
from xml.etree import ElementTree as ET
+import attr
+
from twisted.web.client import PartialDownloadError
-from synapse.api.errors import Codes, LoginError
+from synapse.api.errors import HttpResponseException
+from synapse.handlers.sso import MappingException, UserAttributes
from synapse.http.site import SynapseRequest
from synapse.types import UserID, map_username_to_mxid_localpart
@@ -29,6 +32,26 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+class CasError(Exception):
+ """Used to catch errors when validating the CAS ticket.
+ """
+
+ def __init__(self, error, error_description=None):
+ self.error = error
+ self.error_description = error_description
+
+ def __str__(self):
+ if self.error_description:
+ return "{}: {}".format(self.error, self.error_description)
+ return self.error
+
+
+@attr.s(slots=True, frozen=True)
+class CasResponse:
+ username = attr.ib(type=str)
+ attributes = attr.ib(type=Dict[str, Optional[str]])
+
+
class CasHandler:
"""
Utility class for to handle the response from a CAS SSO service.
@@ -40,6 +63,7 @@ class CasHandler:
def __init__(self, hs: "HomeServer"):
self.hs = hs
self._hostname = hs.hostname
+ self._store = hs.get_datastore()
self._auth_handler = hs.get_auth_handler()
self._registration_handler = hs.get_registration_handler()
@@ -50,6 +74,11 @@ class CasHandler:
self._http_client = hs.get_proxied_http_client()
+ # identifier for the external_ids table
+ self._auth_provider_id = "cas"
+
+ self._sso_handler = hs.get_sso_handler()
+
def _build_service_param(self, args: Dict[str, str]) -> str:
"""
Generates a value to use as the "service" parameter when redirecting or
@@ -69,14 +98,20 @@ class CasHandler:
async def _validate_ticket(
self, ticket: str, service_args: Dict[str, str]
- ) -> Tuple[str, Optional[str]]:
+ ) -> CasResponse:
"""
- Validate a CAS ticket with the server, parse the response, and return the user and display name.
+ Validate a CAS ticket with the server, and return the parsed response.
Args:
ticket: The CAS ticket from the client.
service_args: Additional arguments to include in the service URL.
Should be the same as those passed to `get_redirect_url`.
+
+ Raises:
+ CasError: If there's an error parsing the CAS response.
+
+ Returns:
+ The parsed CAS response.
"""
uri = self._cas_server_url + "/proxyValidate"
args = {
@@ -89,66 +124,65 @@ class CasHandler:
# Twisted raises this error if the connection is closed,
# even if that's being used old-http style to signal end-of-data
body = pde.response
+ except HttpResponseException as e:
+ description = (
+ 'Authorization server responded with a "{status}" error '
+ "while validating the CAS ticket."
+ ).format(status=e.code)
+ raise CasError("server_error", description) from e
- user, attributes = self._parse_cas_response(body)
- displayname = attributes.pop(self._cas_displayname_attribute, None)
-
- for required_attribute, required_value in self._cas_required_attributes.items():
- # If required attribute was not in CAS Response - Forbidden
- if required_attribute not in attributes:
- raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
-
- # Also need to check value
- if required_value is not None:
- actual_value = attributes[required_attribute]
- # If required attribute value does not match expected - Forbidden
- if required_value != actual_value:
- raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
-
- return user, displayname
+ return self._parse_cas_response(body)
- def _parse_cas_response(
- self, cas_response_body: bytes
- ) -> Tuple[str, Dict[str, Optional[str]]]:
+ def _parse_cas_response(self, cas_response_body: bytes) -> CasResponse:
"""
Retrieve the user and other parameters from the CAS response.
Args:
cas_response_body: The response from the CAS query.
+ Raises:
+ CasError: If there's an error parsing the CAS response.
+
Returns:
- A tuple of the user and a mapping of other attributes.
+ The parsed CAS response.
"""
+
+ # Ensure the response is valid.
+ root = ET.fromstring(cas_response_body)
+ if not root.tag.endswith("serviceResponse"):
+ raise CasError(
+ "missing_service_response",
+ "root of CAS response is not serviceResponse",
+ )
+
+ success = root[0].tag.endswith("authenticationSuccess")
+ if not success:
+ raise CasError("unsucessful_response", "Unsuccessful CAS response")
+
+ # Iterate through the nodes and pull out the user and any extra attributes.
user = None
attributes = {}
- try:
- root = ET.fromstring(cas_response_body)
- if not root.tag.endswith("serviceResponse"):
- raise Exception("root of CAS response is not serviceResponse")
- success = root[0].tag.endswith("authenticationSuccess")
- for child in root[0]:
- if child.tag.endswith("user"):
- user = child.text
- if child.tag.endswith("attributes"):
- for attribute in child:
- # ElementTree library expands the namespace in
- # attribute tags to the full URL of the namespace.
- # We don't care about namespace here and it will always
- # be encased in curly braces, so we remove them.
- tag = attribute.tag
- if "}" in tag:
- tag = tag.split("}")[1]
- attributes[tag] = attribute.text
- if user is None:
- raise Exception("CAS response does not contain user")
- except Exception:
- logger.exception("Error parsing CAS response")
- raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
- if not success:
- raise LoginError(
- 401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED
- )
- return user, attributes
+ for child in root[0]:
+ if child.tag.endswith("user"):
+ user = child.text
+ if child.tag.endswith("attributes"):
+ for attribute in child:
+ # ElementTree library expands the namespace in
+ # attribute tags to the full URL of the namespace.
+ # We don't care about namespace here and it will always
+ # be encased in curly braces, so we remove them.
+ tag = attribute.tag
+ if "}" in tag:
+ tag = tag.split("}")[1]
+ attributes[tag] = attribute.text
+
+ # Ensure a user was found.
+ if user is None:
+ raise CasError("no_user", "CAS response does not contain user")
+
+ return CasResponse(user, attributes)
def get_redirect_url(self, service_args: Dict[str, str]) -> str:
"""
@@ -201,59 +235,150 @@ class CasHandler:
args["redirectUrl"] = client_redirect_url
if session:
args["session"] = session
- username, user_display_name = await self._validate_ticket(ticket, args)
- # Pull out the user-agent and IP from the request.
- user_agent = request.get_user_agent("")
- ip_address = self.hs.get_ip_from_request(request)
-
- # Get the matrix ID from the CAS username.
- user_id = await self._map_cas_user_to_matrix_user(
- username, user_display_name, user_agent, ip_address
+ try:
+ cas_response = await self._validate_ticket(ticket, args)
+ except CasError as e:
+ logger.exception("Could not validate ticket")
+ self._sso_handler.render_error(request, e.error, e.error_description, 401)
+ return
+
+ await self._handle_cas_response(
+ request, cas_response, client_redirect_url, session
)
+ async def _handle_cas_response(
+ self,
+ request: SynapseRequest,
+ cas_response: CasResponse,
+ client_redirect_url: Optional[str],
+ session: Optional[str],
+ ) -> None:
+ """Handle a CAS response to a ticket request.
+
+ Assumes that the response has been validated. Maps the user onto an MXID,
+ registering them if necessary, and returns a response to the browser.
+
+ Args:
+ request: the incoming request from the browser. We'll respond to it with an
+ HTML page or a redirect
+
+ cas_response: The parsed CAS response.
+
+ client_redirect_url: the redirectUrl parameter from the `/cas/ticket` HTTP request, if given.
+ This should be the same as the redirectUrl from the original `/login/sso/redirect` request.
+
+ session: The session parameter from the `/cas/ticket` HTTP request, if given.
+ This should be the UI Auth session id.
+ """
+
+ # first check if we're doing a UIA
if session:
- await self._auth_handler.complete_sso_ui_auth(
- user_id, session, request,
+ return await self._sso_handler.complete_sso_ui_auth_request(
+ self._auth_provider_id, cas_response.username, session, request,
)
- else:
- # If this not a UI auth request than there must be a redirect URL.
- assert client_redirect_url
- await self._auth_handler.complete_sso_login(
- user_id, request, client_redirect_url
- )
+ # otherwise, we're handling a login request.
+
+ # Ensure that the attributes of the logged in user meet the required
+ # attributes.
+ for required_attribute, required_value in self._cas_required_attributes.items():
+ # If required attribute was not in CAS Response - Forbidden
+ if required_attribute not in cas_response.attributes:
+ self._sso_handler.render_error(
+ request,
+ "unauthorised",
+ "You are not authorised to log in here.",
+ 401,
+ )
+ return
+
+ # Also need to check value
+ if required_value is not None:
+ actual_value = cas_response.attributes[required_attribute]
+ # If required attribute value does not match expected - Forbidden
+ if required_value != actual_value:
+ self._sso_handler.render_error(
+ request,
+ "unauthorised",
+ "You are not authorised to log in here.",
+ 401,
+ )
+ return
+
+ # Call the mapper to register/login the user
+
+ # If this is not a UI auth request then there must be a redirect URL.
+ assert client_redirect_url is not None
+
+ try:
+ await self._complete_cas_login(cas_response, request, client_redirect_url)
+ except MappingException as e:
+ logger.exception("Could not map user")
+ self._sso_handler.render_error(request, "mapping_error", str(e))
- async def _map_cas_user_to_matrix_user(
+ async def _complete_cas_login(
self,
- remote_user_id: str,
- display_name: Optional[str],
- user_agent: str,
- ip_address: str,
- ) -> str:
+ cas_response: CasResponse,
+ request: SynapseRequest,
+ client_redirect_url: str,
+ ) -> None:
"""
- Given a CAS username, retrieve the user ID for it and possibly register the user.
+ Given a CAS response, complete the login flow
- Args:
- remote_user_id: The username from the CAS response.
- display_name: The display name from the CAS response.
- user_agent: The user agent of the client making the request.
- ip_address: The IP address of the client making the request.
+ Retrieves the remote user ID, registers the user if necessary, and serves
+ a redirect back to the client with a login-token.
- Returns:
- The user ID associated with this response.
+ Args:
+ cas_response: The parsed CAS response.
+ request: The request to respond to
+ client_redirect_url: The redirect URL passed in by the client.
+
+ Raises:
+ MappingException if there was a problem mapping the response to a user.
+ RedirectException: some mapping providers may raise this if they need
+ to redirect to an interstitial page.
"""
+ # Note that CAS does not support a mapping provider, so the logic is hard-coded.
+ localpart = map_username_to_mxid_localpart(cas_response.username)
+
+ async def cas_response_to_user_attributes(failures: int) -> UserAttributes:
+ """
+ Map from CAS attributes to user attributes.
+ """
+ # Due to the grandfathering logic matching any previously registered
+ # mxids, it isn't expected for there to be any failures.
+ if failures:
+ raise RuntimeError("CAS is not expected to de-duplicate Matrix IDs")
+
+ display_name = cas_response.attributes.get(
+ self._cas_displayname_attribute, None
+ )
- localpart = map_username_to_mxid_localpart(remote_user_id)
- user_id = UserID(localpart, self._hostname).to_string()
- registered_user_id = await self._auth_handler.check_user_exists(user_id)
+ return UserAttributes(localpart=localpart, display_name=display_name)
- # If the user does not exist, register it.
- if not registered_user_id:
- registered_user_id = await self._registration_handler.register_user(
- localpart=localpart,
- default_display_name=display_name,
- user_agent_ips=[(user_agent, ip_address)],
+ async def grandfather_existing_users() -> Optional[str]:
+ # Since CAS did not always use the user_external_ids table, always
+ # attempt to map to existing users.
+ user_id = UserID(localpart, self._hostname).to_string()
+
+ logger.debug(
+ "Looking for existing account based on mapped %s", user_id,
)
- return registered_user_id
+ users = await self._store.get_users_by_id_case_insensitive(user_id)
+ if users:
+ registered_user_id = list(users.keys())[0]
+ logger.info("Grandfathering mapping to %s", registered_user_id)
+ return registered_user_id
+
+ return None
+
+ await self._sso_handler.complete_sso_login_request(
+ self._auth_provider_id,
+ cas_response.username,
+ request,
+ client_redirect_url,
+ cas_response_to_user_attributes,
+ grandfather_existing_users,
+ )
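For context on `_parse_cas_response`, a successful CAS `serviceResponse` looks roughly like the sample below; ElementTree expands each tag to `{namespace}name`, which is why the handler trims everything up to the closing brace. (Sample XML is illustrative, not taken from the CAS spec verbatim.)

```python
from xml.etree import ElementTree as ET

body = b"""<cas:serviceResponse xmlns:cas="http://www.yale.edu/tp/cas">
  <cas:authenticationSuccess>
    <cas:user>alice</cas:user>
    <cas:attributes>
      <cas:department>engineering</cas:department>
    </cas:attributes>
  </cas:authenticationSuccess>
</cas:serviceResponse>"""

root = ET.fromstring(body)
assert root.tag.endswith("serviceResponse")
for child in root[0]:
    if child.tag.endswith("user"):
        print("user:", child.text)  # user: alice
    if child.tag.endswith("attributes"):
        for attribute in child:
            # Strip the expanded "{namespace}" prefix, as the handler does.
            tag = attribute.tag.split("}")[-1]
            print(tag, "=", attribute.text)  # department = engineering
```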
diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py
index abd8d2af44..df29edeb83 100644
--- a/synapse/handlers/groups_local.py
+++ b/synapse/handlers/groups_local.py
@@ -29,7 +29,7 @@ def _create_rerouter(func_name):
async def f(self, group_id, *args, **kwargs):
if not GroupID.is_valid(group_id):
- raise SynapseError(400, "%s was not legal group ID" % (group_id,))
+ raise SynapseError(400, "%s is not a legal group ID" % (group_id,))
if self.is_mine_id(group_id):
return await getattr(self.groups_server_handler, func_name)(
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 7301c24710..c05036ad1f 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -55,6 +55,8 @@ class IdentityHandler(BaseHandler):
self.federation_http_client = hs.get_federation_http_client()
self.hs = hs
+ self._web_client_location = hs.config.invite_client_location
+
async def threepid_from_creds(
self, id_server: str, creds: Dict[str, str]
) -> Optional[JsonDict]:
@@ -803,6 +805,9 @@ class IdentityHandler(BaseHandler):
"sender_display_name": inviter_display_name,
"sender_avatar_url": inviter_avatar_url,
}
+ # If a custom web client location is available, include it in the request.
+ if self._web_client_location:
+ invite_config["org.matrix.web_client_location"] = self._web_client_location
# Add the identity service access token to the JSON body and use the v2
# Identity Service endpoints if id_access_token is present
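Illustratively, with `invite_client_location: https://app.element.io` configured, the invite payload sent to the identity server gains one extra key (other field values below are examples, abridged from the hunk above):

```python
invite_config = {
    "sender_display_name": "Alice",
    "sender_avatar_url": "mxc://example.com/abc123",
    # ...
    "org.matrix.web_client_location": "https://app.element.io",
}
```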
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index cb11754bf8..fbd8df9dcc 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -323,9 +323,7 @@ class InitialSyncHandler(BaseHandler):
member_event_id: str,
is_peeking: bool,
) -> JsonDict:
- room_state = await self.state_store.get_state_for_events([member_event_id])
-
- room_state = room_state[member_event_id]
+ room_state = await self.state_store.get_state_for_event(member_event_id)
limit = pagin_config.limit if pagin_config else None
if limit is None:
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index cbac43c536..97c4b1f262 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -1261,7 +1261,7 @@ class EventCreationHandler:
event, context = await self.create_event(
requester,
{
- "type": "org.matrix.dummy_event",
+ "type": EventTypes.Dummy,
"content": {},
"room_id": room_id,
"sender": user_id,
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index f626117f76..709f8dfc13 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -115,8 +115,6 @@ class OidcHandler(BaseHandler):
self._allow_existing_users = hs.config.oidc_allow_existing_users # type: bool
self._http_client = hs.get_proxied_http_client()
- self._auth_handler = hs.get_auth_handler()
- self._registration_handler = hs.get_registration_handler()
self._server_name = hs.config.server_name # type: str
self._macaroon_secret_key = hs.config.macaroon_secret_key
@@ -689,33 +687,14 @@ class OidcHandler(BaseHandler):
# otherwise, it's a login
- # Pull out the user-agent and IP from the request.
- user_agent = request.get_user_agent("")
- ip_address = self.hs.get_ip_from_request(request)
-
# Call the mapper to register/login the user
try:
- user_id = await self._map_userinfo_to_user(
- userinfo, token, user_agent, ip_address
+ await self._complete_oidc_login(
+ userinfo, token, request, client_redirect_url
)
except MappingException as e:
logger.exception("Could not map user")
self._sso_handler.render_error(request, "mapping_error", str(e))
- return
-
- # Mapping providers might not have get_extra_attributes: only call this
- # method if it exists.
- extra_attributes = None
- get_extra_attributes = getattr(
- self._user_mapping_provider, "get_extra_attributes", None
- )
- if get_extra_attributes:
- extra_attributes = await get_extra_attributes(userinfo, token)
-
- # and finally complete the login
- await self._auth_handler.complete_sso_login(
- user_id, request, client_redirect_url, extra_attributes
- )
def _generate_oidc_session_token(
self,
@@ -838,10 +817,14 @@ class OidcHandler(BaseHandler):
now = self.clock.time_msec()
return now < expiry
- async def _map_userinfo_to_user(
- self, userinfo: UserInfo, token: Token, user_agent: str, ip_address: str
- ) -> str:
- """Maps a UserInfo object to a mxid.
+ async def _complete_oidc_login(
+ self,
+ userinfo: UserInfo,
+ token: Token,
+ request: SynapseRequest,
+ client_redirect_url: str,
+ ) -> None:
+ """Given a UserInfo response, complete the login flow
UserInfo should have a claim that uniquely identifies users. This claim
is usually `sub`, but can be configured with `oidc_config.subject_claim`.
@@ -853,17 +836,16 @@ class OidcHandler(BaseHandler):
If a user already exists with the mxid we've mapped and allow_existing_users
is disabled, raise an exception.
+ Otherwise, render a redirect back to the client_redirect_url with a loginToken.
+
Args:
userinfo: an object representing the user
token: a dict with the tokens obtained from the provider
- user_agent: The user agent of the client making the request.
- ip_address: The IP address of the client making the request.
+ request: The request to respond to
+ client_redirect_url: The redirect URL passed in by the client.
Raises:
MappingException: if there was an error while mapping some properties
-
- Returns:
- The mxid of the user
"""
try:
remote_user_id = self._remote_id_from_userinfo(userinfo)
@@ -931,13 +913,23 @@ class OidcHandler(BaseHandler):
return None
- return await self._sso_handler.get_mxid_from_sso(
+ # Mapping providers might not have get_extra_attributes: only call this
+ # method if it exists.
+ extra_attributes = None
+ get_extra_attributes = getattr(
+ self._user_mapping_provider, "get_extra_attributes", None
+ )
+ if get_extra_attributes:
+ extra_attributes = await get_extra_attributes(userinfo, token)
+
+ await self._sso_handler.complete_sso_login_request(
self._auth_provider_id,
remote_user_id,
- user_agent,
- ip_address,
+ request,
+ client_redirect_url,
oidc_response_to_user_attributes,
grandfather_existing_users,
+ extra_attributes,
)
def _remote_id_from_userinfo(self, userinfo: UserInfo) -> str:
@@ -955,7 +947,7 @@ class OidcHandler(BaseHandler):
UserAttributeDict = TypedDict(
- "UserAttributeDict", {"localpart": str, "display_name": Optional[str]}
+ "UserAttributeDict", {"localpart": Optional[str], "display_name": Optional[str]}
)
C = TypeVar("C")
@@ -1036,10 +1028,10 @@ env = Environment(finalize=jinja_finalize)
@attr.s
class JinjaOidcMappingConfig:
- subject_claim = attr.ib() # type: str
- localpart_template = attr.ib() # type: Template
- display_name_template = attr.ib() # type: Optional[Template]
- extra_attributes = attr.ib() # type: Dict[str, Template]
+ subject_claim = attr.ib(type=str)
+ localpart_template = attr.ib(type=Optional[Template])
+ display_name_template = attr.ib(type=Optional[Template])
+ extra_attributes = attr.ib(type=Dict[str, Template])
class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
@@ -1055,18 +1047,14 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
def parse_config(config: dict) -> JinjaOidcMappingConfig:
subject_claim = config.get("subject_claim", "sub")
- if "localpart_template" not in config:
- raise ConfigError(
- "missing key: oidc_config.user_mapping_provider.config.localpart_template"
- )
-
- try:
- localpart_template = env.from_string(config["localpart_template"])
- except Exception as e:
- raise ConfigError(
- "invalid jinja template for oidc_config.user_mapping_provider.config.localpart_template: %r"
- % (e,)
- )
+ localpart_template = None # type: Optional[Template]
+ if "localpart_template" in config:
+ try:
+ localpart_template = env.from_string(config["localpart_template"])
+ except Exception as e:
+ raise ConfigError(
+ "invalid jinja template", path=["localpart_template"]
+ ) from e
display_name_template = None # type: Optional[Template]
if "display_name_template" in config:
@@ -1074,26 +1062,22 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
display_name_template = env.from_string(config["display_name_template"])
except Exception as e:
raise ConfigError(
- "invalid jinja template for oidc_config.user_mapping_provider.config.display_name_template: %r"
- % (e,)
- )
+ "invalid jinja template", path=["display_name_template"]
+ ) from e
    extra_attributes = {}  # type: Dict[str, Template]
if "extra_attributes" in config:
extra_attributes_config = config.get("extra_attributes") or {}
if not isinstance(extra_attributes_config, dict):
- raise ConfigError(
- "oidc_config.user_mapping_provider.config.extra_attributes must be a dict"
- )
+ raise ConfigError("must be a dict", path=["extra_attributes"])
for key, value in extra_attributes_config.items():
try:
extra_attributes[key] = env.from_string(value)
except Exception as e:
raise ConfigError(
- "invalid jinja template for oidc_config.user_mapping_provider.config.extra_attributes.%s: %r"
- % (key, e)
- )
+ "invalid jinja template", path=["extra_attributes", key]
+ ) from e
return JinjaOidcMappingConfig(
subject_claim=subject_claim,
@@ -1108,14 +1092,17 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
async def map_user_attributes(
self, userinfo: UserInfo, token: Token, failures: int
) -> UserAttributeDict:
- localpart = self._config.localpart_template.render(user=userinfo).strip()
+ localpart = None
+
+ if self._config.localpart_template:
+ localpart = self._config.localpart_template.render(user=userinfo).strip()
- # Ensure only valid characters are included in the MXID.
- localpart = map_username_to_mxid_localpart(localpart)
+ # Ensure only valid characters are included in the MXID.
+ localpart = map_username_to_mxid_localpart(localpart)
- # Append suffix integer if last call to this function failed to produce
- # a usable mxid.
- localpart += str(failures) if failures else ""
+ # Append suffix integer if last call to this function failed to produce
+ # a usable mxid.
+ localpart += str(failures) if failures else ""
display_name = None # type: Optional[str]
if self._config.display_name_template is not None:
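
Since localpart_template is now optional, the mapping provider may yield no localpart at all, which later routes the user to the username picker. A standalone sketch of that behaviour, using jinja2 directly and a plain dict standing in for the UserInfo object:

    from jinja2 import Environment

    env = Environment()

    def render_localpart(config: dict, userinfo: dict, failures: int = 0):
        template_src = config.get("localpart_template")
        if template_src is None:
            return None  # the mapper has not picked a localpart
        localpart = env.from_string(template_src).render(user=userinfo).strip()
        # Append a suffix if a previous attempt collided with an existing mxid.
        return localpart + (str(failures) if failures else "")

    print(render_localpart({}, {"preferred_username": "alice"}))  # None
    print(render_localpart(
        {"localpart_template": "{{ user.preferred_username }}"},
        {"preferred_username": "alice"},
    ))  # alice
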
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index e850e45e46..a9abdf42e0 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -13,17 +13,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import List, Tuple
+from typing import TYPE_CHECKING, List, Optional, Tuple
from synapse.appservice import ApplicationService
from synapse.handlers._base import BaseHandler
from synapse.types import JsonDict, ReadReceipt, get_domain_from_id
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
logger = logging.getLogger(__name__)
class ReceiptsHandler(BaseHandler):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.server_name = hs.config.server_name
@@ -36,7 +39,7 @@ class ReceiptsHandler(BaseHandler):
self.clock = self.hs.get_clock()
self.state = hs.get_state_handler()
- async def _received_remote_receipt(self, origin, content):
+ async def _received_remote_receipt(self, origin: str, content: JsonDict) -> None:
"""Called when we receive an EDU of type m.receipt from a remote HS.
"""
receipts = []
@@ -63,11 +66,11 @@ class ReceiptsHandler(BaseHandler):
await self._handle_new_receipts(receipts)
- async def _handle_new_receipts(self, receipts):
+ async def _handle_new_receipts(self, receipts: List[ReadReceipt]) -> bool:
"""Takes a list of receipts, stores them and informs the notifier.
"""
- min_batch_id = None
- max_batch_id = None
+ min_batch_id = None # type: Optional[int]
+ max_batch_id = None # type: Optional[int]
for receipt in receipts:
res = await self.store.insert_receipt(
@@ -89,7 +92,8 @@ class ReceiptsHandler(BaseHandler):
if max_batch_id is None or max_persisted_id > max_batch_id:
max_batch_id = max_persisted_id
- if min_batch_id is None:
+ # Either both of these should be None or neither.
+ if min_batch_id is None or max_batch_id is None:
# no new receipts
return False
@@ -103,7 +107,9 @@ class ReceiptsHandler(BaseHandler):
return True
- async def received_client_receipt(self, room_id, receipt_type, user_id, event_id):
+ async def received_client_receipt(
+ self, room_id: str, receipt_type: str, user_id: str, event_id: str
+ ) -> None:
"""Called when a client tells us a local user has read up to the given
event_id in the room.
"""
@@ -123,10 +129,12 @@ class ReceiptsHandler(BaseHandler):
class ReceiptEventSource:
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
- async def get_new_events(self, from_key, room_ids, **kwargs):
+ async def get_new_events(
+ self, from_key: int, room_ids: List[str], **kwargs
+ ) -> Tuple[List[JsonDict], int]:
from_key = int(from_key)
to_key = self.get_current_key()
@@ -171,5 +179,5 @@ class ReceiptEventSource:
return (events, to_key)
- def get_current_key(self, direction="f"):
+ def get_current_key(self, direction: str = "f") -> int:
return self.store.get_max_receipt_stream_id()
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 94b5610acd..a2cf0f6f3e 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -630,6 +630,7 @@ class RegistrationHandler(BaseHandler):
device_id: Optional[str],
initial_display_name: Optional[str],
is_guest: bool = False,
+ is_appservice_ghost: bool = False,
) -> Tuple[str, str]:
"""Register a device for a user and generate an access token.
@@ -651,6 +652,7 @@ class RegistrationHandler(BaseHandler):
device_id=device_id,
initial_display_name=initial_display_name,
is_guest=is_guest,
+ is_appservice_ghost=is_appservice_ghost,
)
return r["device_id"], r["access_token"]
@@ -672,7 +674,10 @@ class RegistrationHandler(BaseHandler):
)
else:
access_token = await self._auth_handler.get_access_token_for_user_id(
- user_id, device_id=registered_device_id, valid_until_ms=valid_until_ms
+ user_id,
+ device_id=registered_device_id,
+ valid_until_ms=valid_until_ms,
+ is_appservice_ghost=is_appservice_ghost,
)
return (registered_device_id, access_token)
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 7583418946..1f809fa161 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -27,6 +27,7 @@ from typing import TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Tuple
from synapse.api.constants import (
EventTypes,
+ HistoryVisibility,
JoinRules,
Membership,
RoomCreationPreset,
@@ -81,21 +82,21 @@ class RoomCreationHandler(BaseHandler):
self._presets_dict = {
RoomCreationPreset.PRIVATE_CHAT: {
"join_rules": JoinRules.INVITE,
- "history_visibility": "shared",
+ "history_visibility": HistoryVisibility.SHARED,
"original_invitees_have_ops": False,
"guest_can_join": True,
"power_level_content_override": {"invite": 0},
},
RoomCreationPreset.TRUSTED_PRIVATE_CHAT: {
"join_rules": JoinRules.INVITE,
- "history_visibility": "shared",
+ "history_visibility": HistoryVisibility.SHARED,
"original_invitees_have_ops": True,
"guest_can_join": True,
"power_level_content_override": {"invite": 0},
},
RoomCreationPreset.PUBLIC_CHAT: {
"join_rules": JoinRules.PUBLIC,
- "history_visibility": "shared",
+ "history_visibility": HistoryVisibility.SHARED,
"original_invitees_have_ops": False,
"guest_can_join": False,
"power_level_content_override": {},
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index 9dedb9a4b3..a2c0340a3c 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -15,19 +15,22 @@
import logging
from collections import namedtuple
-from typing import Any, Dict, Optional
+from typing import TYPE_CHECKING, Optional, Tuple
import msgpack
from unpaddedbase64 import decode_base64, encode_base64
-from synapse.api.constants import EventTypes, JoinRules
+from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules
from synapse.api.errors import Codes, HttpResponseException
-from synapse.types import ThirdPartyInstanceID
+from synapse.types import JsonDict, ThirdPartyInstanceID
from synapse.util.caches.descriptors import cached
from synapse.util.caches.response_cache import ResponseCache
from ._base import BaseHandler
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
logger = logging.getLogger(__name__)
REMOTE_ROOM_LIST_POLL_INTERVAL = 60 * 1000
@@ -37,38 +40,39 @@ EMPTY_THIRD_PARTY_ID = ThirdPartyInstanceID(None, None)
class RoomListHandler(BaseHandler):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.enable_room_list_search = hs.config.enable_room_list_search
- self.response_cache = ResponseCache(hs, "room_list", timeout_ms=10 * 60 * 1000)
+ self.response_cache = ResponseCache(
+ hs, "room_list", timeout_ms=10 * 60 * 1000
+ ) # type: ResponseCache[Tuple[Optional[int], Optional[str], ThirdPartyInstanceID]]
self.remote_response_cache = ResponseCache(
hs, "remote_room_list", timeout_ms=30 * 1000
- )
+ ) # type: ResponseCache[Tuple[str, Optional[int], Optional[str], bool, Optional[str]]]
async def get_local_public_room_list(
self,
- limit=None,
- since_token=None,
- search_filter=None,
- network_tuple=EMPTY_THIRD_PARTY_ID,
- from_federation=False,
- ):
+ limit: Optional[int] = None,
+ since_token: Optional[str] = None,
+ search_filter: Optional[dict] = None,
+ network_tuple: ThirdPartyInstanceID = EMPTY_THIRD_PARTY_ID,
+ from_federation: bool = False,
+ ) -> JsonDict:
"""Generate a local public room list.
There are multiple different lists: the main one plus one per third
party network. A client can ask for a specific list or to return all.
Args:
- limit (int|None)
- since_token (str|None)
- search_filter (dict|None)
- network_tuple (ThirdPartyInstanceID): Which public list to use.
+ limit
+ since_token
+ search_filter
+ network_tuple: Which public list to use.
This can be (None, None) to indicate the main list, or a particular
appservice and network id to use an appservice specific one.
Setting to None returns all public rooms across all lists.
- from_federation (bool): true iff the request comes from the federation
- API
+ from_federation: true iff the request comes from the federation API
"""
if not self.enable_room_list_search:
return {"chunk": [], "total_room_count_estimate": 0}
@@ -108,10 +112,10 @@ class RoomListHandler(BaseHandler):
self,
limit: Optional[int] = None,
since_token: Optional[str] = None,
- search_filter: Optional[Dict] = None,
+ search_filter: Optional[dict] = None,
network_tuple: ThirdPartyInstanceID = EMPTY_THIRD_PARTY_ID,
from_federation: bool = False,
- ) -> Dict[str, Any]:
+ ) -> JsonDict:
"""Generate a public room list.
Args:
limit: Maximum amount of rooms to return.
@@ -132,13 +136,17 @@ class RoomListHandler(BaseHandler):
if since_token:
batch_token = RoomListNextBatch.from_token(since_token)
- bounds = (batch_token.last_joined_members, batch_token.last_room_id)
+ bounds = (
+ batch_token.last_joined_members,
+ batch_token.last_room_id,
+ ) # type: Optional[Tuple[int, str]]
forwards = batch_token.direction_is_forward
+ has_batch_token = True
else:
- batch_token = None
bounds = None
forwards = True
+ has_batch_token = False
# we request one more than wanted to see if there are more pages to come
probing_limit = limit + 1 if limit is not None else None
@@ -160,7 +168,8 @@ class RoomListHandler(BaseHandler):
"canonical_alias": room["canonical_alias"],
"num_joined_members": room["joined_members"],
"avatar_url": room["avatar"],
- "world_readable": room["history_visibility"] == "world_readable",
+ "world_readable": room["history_visibility"]
+ == HistoryVisibility.WORLD_READABLE,
"guest_can_join": room["guest_access"] == "can_join",
}
@@ -169,7 +178,7 @@ class RoomListHandler(BaseHandler):
results = [build_room_entry(r) for r in results]
- response = {}
+ response = {} # type: JsonDict
num_results = len(results)
if limit is not None:
more_to_come = num_results == probing_limit
@@ -187,7 +196,7 @@ class RoomListHandler(BaseHandler):
initial_entry = results[0]
if forwards:
- if batch_token:
+ if has_batch_token:
# If there was a token given then we assume that there
# must be previous results.
response["prev_batch"] = RoomListNextBatch(
@@ -203,7 +212,7 @@ class RoomListHandler(BaseHandler):
direction_is_forward=True,
).to_token()
else:
- if batch_token:
+ if has_batch_token:
response["next_batch"] = RoomListNextBatch(
last_joined_members=final_entry["num_joined_members"],
last_room_id=final_entry["room_id"],
@@ -293,7 +302,7 @@ class RoomListHandler(BaseHandler):
return None
# Return whether this room is open to federation users or not
- create_event = current_state.get((EventTypes.Create, ""))
+ create_event = current_state[EventTypes.Create, ""]
result["m.federate"] = create_event.content.get("m.federate", True)
name_event = current_state.get((EventTypes.Name, ""))
@@ -318,7 +327,7 @@ class RoomListHandler(BaseHandler):
visibility = None
if visibility_event:
visibility = visibility_event.content.get("history_visibility", None)
- result["world_readable"] = visibility == "world_readable"
+ result["world_readable"] = visibility == HistoryVisibility.WORLD_READABLE
guest_event = current_state.get((EventTypes.GuestAccess, ""))
guest = None
@@ -336,13 +345,13 @@ class RoomListHandler(BaseHandler):
async def get_remote_public_room_list(
self,
- server_name,
- limit=None,
- since_token=None,
- search_filter=None,
- include_all_networks=False,
- third_party_instance_id=None,
- ):
+ server_name: str,
+ limit: Optional[int] = None,
+ since_token: Optional[str] = None,
+ search_filter: Optional[dict] = None,
+ include_all_networks: bool = False,
+ third_party_instance_id: Optional[str] = None,
+ ) -> JsonDict:
if not self.enable_room_list_search:
return {"chunk": [], "total_room_count_estimate": 0}
@@ -399,13 +408,13 @@ class RoomListHandler(BaseHandler):
async def _get_remote_list_cached(
self,
- server_name,
- limit=None,
- since_token=None,
- search_filter=None,
- include_all_networks=False,
- third_party_instance_id=None,
- ):
+ server_name: str,
+ limit: Optional[int] = None,
+ since_token: Optional[str] = None,
+ search_filter: Optional[dict] = None,
+ include_all_networks: bool = False,
+ third_party_instance_id: Optional[str] = None,
+ ) -> JsonDict:
repl_layer = self.hs.get_federation_client()
if search_filter:
# We can't cache when asking for search
@@ -456,24 +465,24 @@ class RoomListNextBatch(
REVERSE_KEY_DICT = {v: k for k, v in KEY_DICT.items()}
@classmethod
- def from_token(cls, token):
+ def from_token(cls, token: str) -> "RoomListNextBatch":
decoded = msgpack.loads(decode_base64(token), raw=False)
return RoomListNextBatch(
**{cls.REVERSE_KEY_DICT[key]: val for key, val in decoded.items()}
)
- def to_token(self):
+ def to_token(self) -> str:
return encode_base64(
msgpack.dumps(
{self.KEY_DICT[key]: val for key, val in self._asdict().items()}
)
)
- def copy_and_replace(self, **kwds):
+ def copy_and_replace(self, **kwds) -> "RoomListNextBatch":
return self._replace(**kwds)
-def _matches_room_entry(room_entry, search_filter):
+def _matches_room_entry(room_entry: JsonDict, search_filter: dict) -> bool:
if search_filter and search_filter.get("generic_search_term", None):
generic_search_term = search_filter["generic_search_term"].upper()
if generic_search_term in room_entry.get("name", "").upper():
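
For reference, the pagination token handled by RoomListNextBatch round-trips through msgpack and unpadded base64, with field names shortened via KEY_DICT. A standalone sketch; the shortened keys here are illustrative rather than the handler's exact mapping:

    import msgpack
    from unpaddedbase64 import decode_base64, encode_base64

    KEY_DICT = {"last_joined_members": "m", "last_room_id": "r", "direction_is_forward": "d"}
    REVERSE_KEY_DICT = {v: k for k, v in KEY_DICT.items()}

    batch = {"last_joined_members": 42, "last_room_id": "!abc:example.com", "direction_is_forward": True}

    token = encode_base64(msgpack.dumps({KEY_DICT[k]: v for k, v in batch.items()}))
    decoded = msgpack.loads(decode_base64(token), raw=False)
    assert {REVERSE_KEY_DICT[k]: v for k, v in decoded.items()} == batch
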
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index f2ca1ddb53..5fa7ab3f8b 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -58,8 +58,6 @@ class SamlHandler(BaseHandler):
super().__init__(hs)
self._saml_client = Saml2Client(hs.config.saml2_sp_config)
self._saml_idp_entityid = hs.config.saml2_idp_entityid
- self._auth_handler = hs.get_auth_handler()
- self._registration_handler = hs.get_registration_handler()
self._saml2_session_lifetime = hs.config.saml2_session_lifetime
self._grandfathered_mxid_source_attribute = (
@@ -163,6 +161,29 @@ class SamlHandler(BaseHandler):
return
logger.debug("SAML2 response: %s", saml2_auth.origxml)
+
+ await self._handle_authn_response(request, saml2_auth, relay_state)
+
+ async def _handle_authn_response(
+ self,
+ request: SynapseRequest,
+ saml2_auth: saml2.response.AuthnResponse,
+ relay_state: str,
+ ) -> None:
+ """Handle an AuthnResponse, having parsed it from the request params
+
+ Assumes that the signature on the response object has been checked. Maps
+ the user onto an MXID, registering them if necessary, and returns a response
+ to the browser.
+
+ Args:
+ request: the incoming request from the browser. We'll respond to it with an
+ HTML page or a redirect
+ saml2_auth: the parsed AuthnResponse object
+            relay_state: the RelayState query param, which encodes the URI to redirect
+ back to
+ """
+
for assertion in saml2_auth.assertions:
# kibana limits the length of a log field, whereas this is all rather
# useful, so split it up.
@@ -206,40 +227,29 @@ class SamlHandler(BaseHandler):
)
return
- # Pull out the user-agent and IP from the request.
- user_agent = request.get_user_agent("")
- ip_address = self.hs.get_ip_from_request(request)
-
# Call the mapper to register/login the user
try:
- user_id = await self._map_saml_response_to_user(
- saml2_auth, relay_state, user_agent, ip_address
- )
+ await self._complete_saml_login(saml2_auth, request, relay_state)
except MappingException as e:
logger.exception("Could not map user")
self._sso_handler.render_error(request, "mapping_error", str(e))
- return
-
- await self._auth_handler.complete_sso_login(user_id, request, relay_state)
- async def _map_saml_response_to_user(
+ async def _complete_saml_login(
self,
saml2_auth: saml2.response.AuthnResponse,
+ request: SynapseRequest,
client_redirect_url: str,
- user_agent: str,
- ip_address: str,
- ) -> str:
+ ) -> None:
"""
- Given a SAML response, retrieve the user ID for it and possibly register the user.
+ Given a SAML response, complete the login flow
+
+ Retrieves the remote user ID, registers the user if necessary, and serves
+ a redirect back to the client with a login-token.
Args:
saml2_auth: The parsed SAML2 response.
+ request: The request to respond to
client_redirect_url: The redirect URL passed in by the client.
- user_agent: The user agent of the client making the request.
- ip_address: The IP address of the client making the request.
-
- Returns:
- The user ID associated with this response.
Raises:
MappingException if there was a problem mapping the response to a user.
@@ -295,11 +305,11 @@ class SamlHandler(BaseHandler):
return None
- return await self._sso_handler.get_mxid_from_sso(
+ await self._sso_handler.complete_sso_login_request(
self._auth_provider_id,
remote_user_id,
- user_agent,
- ip_address,
+ request,
+ client_redirect_url,
saml_response_to_remapped_user_attributes,
grandfather_existing_users,
)
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 112a7d5b2c..33cd6bc178 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -13,16 +13,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional
+from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional
import attr
+from typing_extensions import NoReturn
from twisted.web.http import Request
-from synapse.api.errors import RedirectException
+from synapse.api.errors import RedirectException, SynapseError
from synapse.http.server import respond_with_html
-from synapse.types import UserID, contains_invalid_mxid_characters
+from synapse.http.site import SynapseRequest
+from synapse.types import JsonDict, UserID, contains_invalid_mxid_characters
from synapse.util.async_helpers import Linearizer
+from synapse.util.stringutils import random_string
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -39,16 +42,52 @@ class MappingException(Exception):
@attr.s
class UserAttributes:
- localpart = attr.ib(type=str)
+ # the localpart of the mxid that the mapper has assigned to the user.
+ # if `None`, the mapper has not picked a userid, and the user should be prompted to
+ # enter one.
+ localpart = attr.ib(type=Optional[str])
display_name = attr.ib(type=Optional[str], default=None)
emails = attr.ib(type=List[str], default=attr.Factory(list))
+@attr.s(slots=True)
+class UsernameMappingSession:
+ """Data we track about SSO sessions"""
+
+ # A unique identifier for this SSO provider, e.g. "oidc" or "saml".
+ auth_provider_id = attr.ib(type=str)
+
+ # user ID on the IdP server
+ remote_user_id = attr.ib(type=str)
+
+ # attributes returned by the ID mapper
+ display_name = attr.ib(type=Optional[str])
+ emails = attr.ib(type=List[str])
+
+ # An optional dictionary of extra attributes to be provided to the client in the
+ # login response.
+ extra_login_attributes = attr.ib(type=Optional[JsonDict])
+
+ # where to redirect the client back to
+ client_redirect_url = attr.ib(type=str)
+
+ # expiry time for the session, in milliseconds
+ expiry_time_ms = attr.ib(type=int)
+
+
+# the HTTP cookie used to track the mapping session id
+USERNAME_MAPPING_SESSION_COOKIE_NAME = b"username_mapping_session"
+
+
class SsoHandler:
# The number of attempts to ask the mapping provider for when generating an MXID.
_MAP_USERNAME_RETRIES = 1000
+ # the time a UsernameMappingSession remains valid for
+ _MAPPING_SESSION_VALIDITY_PERIOD_MS = 15 * 60 * 1000
+
def __init__(self, hs: "HomeServer"):
+ self._clock = hs.get_clock()
self._store = hs.get_datastore()
self._server_name = hs.hostname
self._registration_handler = hs.get_registration_handler()
@@ -58,8 +97,15 @@ class SsoHandler:
# a lock on the mappings
self._mapping_lock = Linearizer(name="sso_user_mapping", clock=hs.get_clock())
+ # a map from session id to session data
+ self._username_mapping_sessions = {} # type: Dict[str, UsernameMappingSession]
+
def render_error(
- self, request, error: str, error_description: Optional[str] = None
+ self,
+ request: Request,
+ error: str,
+ error_description: Optional[str] = None,
+ code: int = 400,
) -> None:
"""Renders the error template and responds with it.
@@ -71,11 +117,12 @@ class SsoHandler:
We'll respond with an HTML page describing the error.
error: A technical identifier for this error.
error_description: A human-readable description of the error.
+ code: The integer error code (an HTTP response code)
"""
html = self._error_template.render(
error=error, error_description=error_description
)
- respond_with_html(request, 400, html)
+ respond_with_html(request, code, html)
async def get_sso_user_by_remote_user_id(
self, auth_provider_id: str, remote_user_id: str
@@ -119,15 +166,16 @@ class SsoHandler:
# No match.
return None
- async def get_mxid_from_sso(
+ async def complete_sso_login_request(
self,
auth_provider_id: str,
remote_user_id: str,
- user_agent: str,
- ip_address: str,
+ request: SynapseRequest,
+ client_redirect_url: str,
sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],
- grandfather_existing_users: Optional[Callable[[], Awaitable[Optional[str]]]],
- ) -> str:
+ grandfather_existing_users: Callable[[], Awaitable[Optional[str]]],
+ extra_login_attributes: Optional[JsonDict] = None,
+ ) -> None:
"""
Given an SSO ID, retrieve the user ID for it and possibly register the user.
@@ -146,12 +194,18 @@ class SsoHandler:
given user-agent and IP address and the SSO ID is linked to this matrix
ID for subsequent calls.
+ Finally, we generate a redirect to the supplied redirect uri, with a login token
+
Args:
auth_provider_id: A unique identifier for this SSO provider, e.g.
"oidc" or "saml".
+
remote_user_id: The unique identifier from the SSO provider.
- user_agent: The user agent of the client making the request.
- ip_address: The IP address of the client making the request.
+
+ request: The request to respond to
+
+ client_redirect_url: The redirect URL passed in by the client.
+
sso_to_matrix_id_mapper: A callable to generate the user attributes.
The only parameter is an integer which represents the amount of
times the returned mxid localpart mapping has failed.
@@ -163,12 +217,13 @@ class SsoHandler:
to the user.
RedirectException to redirect to an additional page (e.g.
to prompt the user for more information).
+
            grandfather_existing_users: A callable which can return a previously
existing matrix ID. The SSO ID is then linked to the returned
matrix ID.
- Returns:
- The user ID associated with the SSO response.
+ extra_login_attributes: An optional dictionary of extra
+ attributes to be provided to the client in the login response.
Raises:
MappingException if there was a problem mapping the response to a user.
@@ -181,28 +236,45 @@ class SsoHandler:
# interstitial pages.
with await self._mapping_lock.queue(auth_provider_id):
# first of all, check if we already have a mapping for this user
- previously_registered_user_id = await self.get_sso_user_by_remote_user_id(
+ user_id = await self.get_sso_user_by_remote_user_id(
auth_provider_id, remote_user_id,
)
- if previously_registered_user_id:
- return previously_registered_user_id
# Check for grandfathering of users.
- if grandfather_existing_users:
- previously_registered_user_id = await grandfather_existing_users()
- if previously_registered_user_id:
+ if not user_id:
+ user_id = await grandfather_existing_users()
+ if user_id:
# Future logins should also match this user ID.
await self._store.record_user_external_id(
- auth_provider_id, remote_user_id, previously_registered_user_id
+ auth_provider_id, remote_user_id, user_id
)
- return previously_registered_user_id
# Otherwise, generate a new user.
- attributes = await self._call_attribute_mapper(sso_to_matrix_id_mapper)
- user_id = await self._register_mapped_user(
- attributes, auth_provider_id, remote_user_id, user_agent, ip_address,
- )
- return user_id
+ if not user_id:
+ attributes = await self._call_attribute_mapper(sso_to_matrix_id_mapper)
+
+ if attributes.localpart is None:
+ # the mapper doesn't return a username. bail out with a redirect to
+ # the username picker.
+ await self._redirect_to_username_picker(
+ auth_provider_id,
+ remote_user_id,
+ attributes,
+ client_redirect_url,
+ extra_login_attributes,
+ )
+
+ user_id = await self._register_mapped_user(
+ attributes,
+ auth_provider_id,
+ remote_user_id,
+ request.get_user_agent(""),
+ request.getClientIP(),
+ )
+
+ await self._auth_handler.complete_sso_login(
+ user_id, request, client_redirect_url, extra_login_attributes
+ )
async def _call_attribute_mapper(
self, sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],
@@ -229,10 +301,8 @@ class SsoHandler:
)
if not attributes.localpart:
- raise MappingException(
- "Error parsing SSO response: SSO mapping provider plugin "
- "did not return a localpart value"
- )
+ # the mapper has not picked a localpart
+ return attributes
# Check if this mxid already exists
user_id = UserID(attributes.localpart, self._server_name).to_string()
@@ -247,6 +317,59 @@ class SsoHandler:
)
return attributes
+ async def _redirect_to_username_picker(
+ self,
+ auth_provider_id: str,
+ remote_user_id: str,
+ attributes: UserAttributes,
+ client_redirect_url: str,
+ extra_login_attributes: Optional[JsonDict],
+ ) -> NoReturn:
+ """Creates a UsernameMappingSession and redirects the browser
+
+ Called if the user mapping provider doesn't return a localpart for a new user.
+ Raises a RedirectException which redirects the browser to the username picker.
+
+ Args:
+ auth_provider_id: A unique identifier for this SSO provider, e.g.
+ "oidc" or "saml".
+
+ remote_user_id: The unique identifier from the SSO provider.
+
+ attributes: the user attributes returned by the user mapping provider.
+
+ client_redirect_url: The redirect URL passed in by the client, which we
+ will eventually redirect back to.
+
+ extra_login_attributes: An optional dictionary of extra
+ attributes to be provided to the client in the login response.
+
+ Raises:
+ RedirectException
+ """
+ session_id = random_string(16)
+ now = self._clock.time_msec()
+ session = UsernameMappingSession(
+ auth_provider_id=auth_provider_id,
+ remote_user_id=remote_user_id,
+ display_name=attributes.display_name,
+ emails=attributes.emails,
+ client_redirect_url=client_redirect_url,
+ expiry_time_ms=now + self._MAPPING_SESSION_VALIDITY_PERIOD_MS,
+ extra_login_attributes=extra_login_attributes,
+ )
+
+ self._username_mapping_sessions[session_id] = session
+ logger.info("Recorded registration session id %s", session_id)
+
+ # Set the cookie and redirect to the username picker
+ e = RedirectException(b"/_synapse/client/pick_username")
+ e.cookies.append(
+ b"%s=%s; path=/"
+ % (USERNAME_MAPPING_SESSION_COOKIE_NAME, session_id.encode("ascii"))
+ )
+ raise e
+
async def _register_mapped_user(
self,
attributes: UserAttributes,
@@ -255,9 +378,38 @@ class SsoHandler:
user_agent: str,
ip_address: str,
) -> str:
+ """Register a new SSO user.
+
+ This is called once we have successfully mapped the remote user id onto a local
+ user id, one way or another.
+
+ Args:
+ attributes: user attributes returned by the user mapping provider,
+ including a non-empty localpart.
+
+ auth_provider_id: A unique identifier for this SSO provider, e.g.
+ "oidc" or "saml".
+
+ remote_user_id: The unique identifier from the SSO provider.
+
+ user_agent: The user-agent in the HTTP request (used for potential
+ shadow-banning.)
+
+ ip_address: The IP address of the requester (used for potential
+ shadow-banning.)
+
+ Raises:
+ a MappingException if the localpart is invalid.
+
+ a SynapseError with code 400 and errcode Codes.USER_IN_USE if the localpart
+ is already taken.
+ """
+
# Since the localpart is provided via a potentially untrusted module,
# ensure the MXID is valid before registering.
- if contains_invalid_mxid_characters(attributes.localpart):
+ if not attributes.localpart or contains_invalid_mxid_characters(
+ attributes.localpart
+ ):
raise MappingException("localpart is invalid: %s" % (attributes.localpart,))
logger.debug("Mapped SSO user to local part %s", attributes.localpart)
@@ -312,3 +464,108 @@ class SsoHandler:
await self._auth_handler.complete_sso_ui_auth(
user_id, ui_auth_session_id, request
)
+
+ async def check_username_availability(
+ self, localpart: str, session_id: str,
+ ) -> bool:
+ """Handle an "is username available" callback check
+
+ Args:
+ localpart: desired localpart
+ session_id: the session id for the username picker
+ Returns:
+ True if the username is available
+ Raises:
+ SynapseError if the localpart is invalid or the session is unknown
+ """
+
+ # make sure that there is a valid mapping session, to stop people dictionary-
+ # scanning for accounts
+
+ self._expire_old_sessions()
+ session = self._username_mapping_sessions.get(session_id)
+ if not session:
+ logger.info("Couldn't find session id %s", session_id)
+ raise SynapseError(400, "unknown session")
+
+ logger.info(
+ "[session %s] Checking for availability of username %s",
+ session_id,
+ localpart,
+ )
+
+ if contains_invalid_mxid_characters(localpart):
+ raise SynapseError(400, "localpart is invalid: %s" % (localpart,))
+ user_id = UserID(localpart, self._server_name).to_string()
+ user_infos = await self._store.get_users_by_id_case_insensitive(user_id)
+
+ logger.info("[session %s] users: %s", session_id, user_infos)
+ return not user_infos
+
+ async def handle_submit_username_request(
+ self, request: SynapseRequest, localpart: str, session_id: str
+ ) -> None:
+ """Handle a request to the username-picker 'submit' endpoint
+
+ Will serve an HTTP response to the request.
+
+ Args:
+ request: HTTP request
+ localpart: localpart requested by the user
+ session_id: ID of the username mapping session, extracted from a cookie
+ """
+ self._expire_old_sessions()
+ session = self._username_mapping_sessions.get(session_id)
+ if not session:
+ logger.info("Couldn't find session id %s", session_id)
+ raise SynapseError(400, "unknown session")
+
+ logger.info("[session %s] Registering localpart %s", session_id, localpart)
+
+ attributes = UserAttributes(
+ localpart=localpart,
+ display_name=session.display_name,
+ emails=session.emails,
+ )
+
+ # the following will raise a 400 error if the username has been taken in the
+ # meantime.
+ user_id = await self._register_mapped_user(
+ attributes,
+ session.auth_provider_id,
+ session.remote_user_id,
+ request.get_user_agent(""),
+ request.getClientIP(),
+ )
+
+ logger.info("[session %s] Registered userid %s", session_id, user_id)
+
+        # delete the mapping session
+ del self._username_mapping_sessions[session_id]
+
+ # delete the cookie
+ request.addCookie(
+ USERNAME_MAPPING_SESSION_COOKIE_NAME,
+ b"",
+ expires=b"Thu, 01 Jan 1970 00:00:00 GMT",
+ path=b"/",
+ )
+
+ await self._auth_handler.complete_sso_login(
+ user_id,
+ request,
+ session.client_redirect_url,
+ session.extra_login_attributes,
+ )
+
+ def _expire_old_sessions(self):
+ to_expire = []
+ now = int(self._clock.time_msec())
+
+ for session_id, session in self._username_mapping_sessions.items():
+ if session.expiry_time_ms <= now:
+ to_expire.append(session_id)
+
+ for session_id in to_expire:
+ logger.info("Expiring mapping session %s", session_id)
+ del self._username_mapping_sessions[session_id]
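
A standalone sketch of the mapping-session lifecycle introduced above: a random session id keyed to the pending SSO attributes, with stale entries swept by comparing millisecond timestamps. The 15-minute validity mirrors the handler; the storage here is a plain dict rather than the handler's attrs class:

    import secrets
    import time

    SESSION_VALIDITY_MS = 15 * 60 * 1000
    sessions = {}  # session_id -> pending SSO attributes

    def now_ms() -> int:
        return int(time.time() * 1000)

    def create_session(remote_user_id: str, display_name=None) -> str:
        session_id = secrets.token_hex(8)
        sessions[session_id] = {
            "remote_user_id": remote_user_id,
            "display_name": display_name,
            "expiry_time_ms": now_ms() + SESSION_VALIDITY_MS,
        }
        return session_id

    def expire_old_sessions() -> None:
        now = now_ms()
        for session_id in [s for s, v in sessions.items() if v["expiry_time_ms"] <= now]:
            del sessions[session_id]
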
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index b9ae70adbe..893a571466 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -555,7 +555,7 @@ class SyncHandler:
event.event_id, state_filter=state_filter
)
if event.is_state():
- state_ids = state_ids.copy()
+ state_ids = dict(state_ids)
state_ids[(event.type, event.state_key)] = event.event_id
return state_ids
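
The switch from .copy() to dict(...) guards against the state map coming back from a cache as an immutable mapping type: dict(...) always yields a plain mutable dict, whatever mapping the cache hands back. A sketch of the copy-before-mutate pattern, using MappingProxyType as a stand-in for the immutable mapping:

    from types import MappingProxyType  # stand-in for an immutable cached mapping

    cached_state = MappingProxyType({("m.room.name", ""): "$event1"})
    # cached_state[("m.room.topic", "")] = "$event2"   # TypeError: read-only
    state_ids = dict(cached_state)                     # plain mutable dict
    state_ids[("m.room.topic", "")] = "$event2"
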
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index f263a638f8..d4651c8348 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -14,14 +14,19 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
import synapse.metrics
-from synapse.api.constants import EventTypes, JoinRules, Membership
+from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules, Membership
from synapse.handlers.state_deltas import StateDeltasHandler
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.roommember import ProfileInfo
+from synapse.types import JsonDict
from synapse.util.metrics import Measure
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
logger = logging.getLogger(__name__)
@@ -36,7 +41,7 @@ class UserDirectoryHandler(StateDeltasHandler):
be in the directory or not when necessary.
"""
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.store = hs.get_datastore()
@@ -49,7 +54,7 @@ class UserDirectoryHandler(StateDeltasHandler):
self.search_all_users = hs.config.user_directory_search_all_users
self.spam_checker = hs.get_spam_checker()
# The current position in the current_state_delta stream
- self.pos = None
+ self.pos = None # type: Optional[int]
# Guard to ensure we only process deltas one at a time
self._is_processing = False
@@ -61,7 +66,9 @@ class UserDirectoryHandler(StateDeltasHandler):
# we start populating the user directory
self.clock.call_later(0, self.notify_new_event)
- async def search_users(self, user_id, search_term, limit):
+ async def search_users(
+ self, user_id: str, search_term: str, limit: int
+ ) -> JsonDict:
"""Searches for users in directory
Returns:
@@ -89,7 +96,7 @@ class UserDirectoryHandler(StateDeltasHandler):
return results
- def notify_new_event(self):
+ def notify_new_event(self) -> None:
"""Called when there may be more deltas to process
"""
if not self.update_user_directory:
@@ -107,27 +114,33 @@ class UserDirectoryHandler(StateDeltasHandler):
self._is_processing = True
run_as_background_process("user_directory.notify_new_event", process)
- async def handle_local_profile_change(self, user_id, profile):
+ async def handle_local_profile_change(
+ self, user_id: str, profile: ProfileInfo
+ ) -> None:
"""Called to update index of our local user profiles when they change
irrespective of any rooms the user may be in.
"""
# FIXME(#3714): We should probably do this in the same worker as all
# the other changes.
- is_support = await self.store.is_support_user(user_id)
+
# Support users are for diagnostics and should not appear in the user directory.
- if not is_support:
+ is_support = await self.store.is_support_user(user_id)
+        # When the profile information of a deactivated user changes, it should not appear in the user directory.
+ is_deactivated = await self.store.get_user_deactivated_status(user_id)
+
+ if not (is_support or is_deactivated):
await self.store.update_profile_in_user_dir(
user_id, profile.display_name, profile.avatar_url
)
- async def handle_user_deactivated(self, user_id):
+ async def handle_user_deactivated(self, user_id: str) -> None:
"""Called when a user ID is deactivated
"""
# FIXME(#3714): We should probably do this in the same worker as all
# the other changes.
await self.store.remove_from_user_dir(user_id)
- async def _unsafe_process(self):
+ async def _unsafe_process(self) -> None:
# If self.pos is None then means we haven't fetched it from DB
if self.pos is None:
self.pos = await self.store.get_user_directory_stream_pos()
@@ -162,7 +175,7 @@ class UserDirectoryHandler(StateDeltasHandler):
await self.store.update_user_directory_stream_pos(max_pos)
- async def _handle_deltas(self, deltas):
+ async def _handle_deltas(self, deltas: List[Dict[str, Any]]) -> None:
"""Called with the state deltas to process
"""
for delta in deltas:
@@ -232,16 +245,20 @@ class UserDirectoryHandler(StateDeltasHandler):
logger.debug("Ignoring irrelevant type: %r", typ)
async def _handle_room_publicity_change(
- self, room_id, prev_event_id, event_id, typ
- ):
+ self,
+ room_id: str,
+ prev_event_id: Optional[str],
+ event_id: Optional[str],
+ typ: str,
+ ) -> None:
"""Handle a room having potentially changed from/to world_readable/publicly
joinable.
Args:
- room_id (str)
- prev_event_id (str|None): The previous event before the state change
- event_id (str|None): The new event after the state change
- typ (str): Type of the event
+ room_id: The ID of the room which changed.
+ prev_event_id: The previous event before the state change
+ event_id: The new event after the state change
+ typ: Type of the event
"""
logger.debug("Handling change for %s: %s", typ, room_id)
@@ -250,7 +267,7 @@ class UserDirectoryHandler(StateDeltasHandler):
prev_event_id,
event_id,
key_name="history_visibility",
- public_value="world_readable",
+ public_value=HistoryVisibility.WORLD_READABLE,
)
elif typ == EventTypes.JoinRules:
change = await self._get_key_change(
@@ -299,12 +316,14 @@ class UserDirectoryHandler(StateDeltasHandler):
for user_id, profile in users_with_profile.items():
await self._handle_new_user(room_id, user_id, profile)
- async def _handle_new_user(self, room_id, user_id, profile):
+ async def _handle_new_user(
+ self, room_id: str, user_id: str, profile: ProfileInfo
+ ) -> None:
"""Called when we might need to add user to directory
Args:
- room_id (str): room_id that user joined or started being public
- user_id (str)
+            room_id: The room ID that the user joined, or that started being public
+            user_id: The ID of the user to add
"""
logger.debug("Adding new user to dir, %r", user_id)
@@ -352,12 +371,12 @@ class UserDirectoryHandler(StateDeltasHandler):
if to_insert:
await self.store.add_users_who_share_private_room(room_id, to_insert)
- async def _handle_remove_user(self, room_id, user_id):
+ async def _handle_remove_user(self, room_id: str, user_id: str) -> None:
"""Called when we might need to remove user from directory
Args:
- room_id (str): room_id that user left or stopped being public that
- user_id (str)
+            room_id: The room ID that the user left, or that stopped being public
+            user_id: The ID of the user to remove
"""
logger.debug("Removing user %r", user_id)
@@ -370,7 +389,13 @@ class UserDirectoryHandler(StateDeltasHandler):
if len(rooms_user_is_in) == 0:
await self.store.remove_from_user_dir(user_id)
- async def _handle_profile_change(self, user_id, room_id, prev_event_id, event_id):
+ async def _handle_profile_change(
+ self,
+ user_id: str,
+ room_id: str,
+ prev_event_id: Optional[str],
+ event_id: Optional[str],
+ ) -> None:
"""Check member event changes for any profile changes and update the
database if there are.
"""
diff --git a/synapse/http/client.py b/synapse/http/client.py
index df7730078f..29f40ddf5f 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -720,11 +720,14 @@ class SimpleHttpClient:
try:
length = await make_deferred_yieldable(
- readBodyToFile(response, output_stream, max_size)
+ read_body_with_max_size(response, output_stream, max_size)
+ )
+ except BodyExceededMaxSize:
+            raise SynapseError(
+ 502,
+ "Requested file is too large > %r bytes" % (max_size,),
+ Codes.TOO_LARGE,
)
- except SynapseError:
- # This can happen e.g. because the body is too large.
- raise
except Exception as e:
raise SynapseError(502, ("Failed to download remote body: %s" % e)) from e
@@ -748,7 +751,11 @@ def _timeout_to_request_timed_out_error(f: Failure):
return f
-class _ReadBodyToFileProtocol(protocol.Protocol):
+class BodyExceededMaxSize(Exception):
+ """The maximum allowed size of the HTTP body was exceeded."""
+
+
+class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
def __init__(
self, stream: BinaryIO, deferred: defer.Deferred, max_size: Optional[int]
):
@@ -761,13 +768,7 @@ class _ReadBodyToFileProtocol(protocol.Protocol):
self.stream.write(data)
self.length += len(data)
if self.max_size is not None and self.length >= self.max_size:
- self.deferred.errback(
- SynapseError(
- 502,
- "Requested file is too large > %r bytes" % (self.max_size,),
- Codes.TOO_LARGE,
- )
- )
+ self.deferred.errback(BodyExceededMaxSize())
self.deferred = defer.Deferred()
self.transport.loseConnection()
@@ -782,12 +783,15 @@ class _ReadBodyToFileProtocol(protocol.Protocol):
self.deferred.errback(reason)
-def readBodyToFile(
+def read_body_with_max_size(
response: IResponse, stream: BinaryIO, max_size: Optional[int]
) -> defer.Deferred:
"""
Read a HTTP response body to a file-object. Optionally enforcing a maximum file size.
+    If the maximum file size is reached, the returned Deferred fails with a
+    BodyExceededMaxSize exception.
+
Args:
response: The HTTP response to read from.
stream: The file-object to write to.
@@ -798,7 +802,7 @@ def readBodyToFile(
"""
d = defer.Deferred()
- response.deliverBody(_ReadBodyToFileProtocol(stream, d, max_size))
+ response.deliverBody(_ReadBodyWithMaxSizeProtocol(stream, d, max_size))
return d
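
A standalone sketch of the size-capping pattern that _ReadBodyWithMaxSizeProtocol implements: count bytes as chunks stream in and abort as soon as the cap is crossed, instead of buffering an unbounded body. This is a synchronous analogue for illustration, not the Twisted protocol itself:

    import io

    class BodyExceededMaxSize(Exception):
        """The maximum allowed size of the HTTP body was exceeded."""

    def read_capped(chunks, stream: io.BytesIO, max_size: int) -> int:
        length = 0
        for data in chunks:
            stream.write(data)
            length += len(data)
            if length >= max_size:
                # mirrors the protocol erroring back and dropping the connection
                raise BodyExceededMaxSize()
        return length

    out = io.BytesIO()
    read_capped([b"a" * 10, b"b" * 10], out, max_size=50)   # ok: 20 bytes written
    # read_capped([b"a" * 60], out, max_size=50)            # raises BodyExceededMaxSize
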
diff --git a/synapse/http/federation/well_known_resolver.py b/synapse/http/federation/well_known_resolver.py
index 5e08ef1664..b3b6dbcab0 100644
--- a/synapse/http/federation/well_known_resolver.py
+++ b/synapse/http/federation/well_known_resolver.py
@@ -15,17 +15,19 @@
import logging
import random
import time
+from io import BytesIO
from typing import Callable, Dict, Optional, Tuple
import attr
from twisted.internet import defer
from twisted.internet.interfaces import IReactorTime
-from twisted.web.client import RedirectAgent, readBody
+from twisted.web.client import RedirectAgent
from twisted.web.http import stringToDatetime
from twisted.web.http_headers import Headers
from twisted.web.iweb import IAgent, IResponse
+from synapse.http.client import BodyExceededMaxSize, read_body_with_max_size
from synapse.logging.context import make_deferred_yieldable
from synapse.util import Clock, json_decoder
from synapse.util.caches.ttlcache import TTLCache
@@ -53,6 +55,9 @@ WELL_KNOWN_MAX_CACHE_PERIOD = 48 * 3600
# lower bound for .well-known cache period
WELL_KNOWN_MIN_CACHE_PERIOD = 5 * 60
+# The maximum size (in bytes) to allow a well-known file to be.
+WELL_KNOWN_MAX_SIZE = 50 * 1024 # 50 KiB
+
# Attempt to refetch a cached well-known N% of the TTL before it expires.
# e.g. if set to 0.2 and we have a cached entry with a TTL of 5mins, then
# we'll start trying to refetch 1 minute before it expires.
@@ -229,6 +234,9 @@ class WellKnownResolver:
server_name: name of the server, from the requested url
retry: Whether to retry the request if it fails.
+ Raises:
+ _FetchWellKnownFailure if we fail to lookup a result
+
Returns:
Returns the response object and body. Response may be a non-200 response.
"""
@@ -250,7 +258,11 @@ class WellKnownResolver:
b"GET", uri, headers=Headers(headers)
)
)
- body = await make_deferred_yieldable(readBody(response))
+ body_stream = BytesIO()
+ await make_deferred_yieldable(
+ read_body_with_max_size(response, body_stream, WELL_KNOWN_MAX_SIZE)
+ )
+ body = body_stream.getvalue()
if 500 <= response.code < 600:
raise Exception("Non-200 response %s" % (response.code,))
@@ -259,6 +271,15 @@ class WellKnownResolver:
except defer.CancelledError:
# Bail if we've been cancelled
raise
+ except BodyExceededMaxSize:
+ # If the well-known file was too large, do not keep attempting
+ # to download it, but consider it a temporary error.
+ logger.warning(
+ "Requested .well-known file for %s is too large > %r bytes",
+ server_name.decode("ascii"),
+ WELL_KNOWN_MAX_SIZE,
+ )
+ raise _FetchWellKnownFailure(temporary=True)
except Exception as e:
if not retry or i >= WELL_KNOWN_RETRY_ATTEMPTS:
logger.info("Error fetching %s: %s", uri_str, e)
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index c962994727..b261e078c4 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -37,16 +37,19 @@ from twisted.web.iweb import IBodyProducer, IResponse
import synapse.metrics
import synapse.util.retryutils
from synapse.api.errors import (
+ Codes,
FederationDeniedError,
HttpResponseException,
RequestSendFailed,
+ SynapseError,
)
from synapse.http import QuieterFileBodyProducer
from synapse.http.client import (
BlacklistingAgentWrapper,
BlacklistingReactorWrapper,
+ BodyExceededMaxSize,
encode_query_args,
- readBodyToFile,
+ read_body_with_max_size,
)
from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
from synapse.logging.context import make_deferred_yieldable
@@ -975,9 +978,15 @@ class MatrixFederationHttpClient:
headers = dict(response.headers.getAllRawHeaders())
try:
- d = readBodyToFile(response, output_stream, max_size)
+ d = read_body_with_max_size(response, output_stream, max_size)
d.addTimeout(self.default_timeout, self.reactor)
length = await make_deferred_yieldable(d)
+ except BodyExceededMaxSize:
+ msg = "Requested file is too large > %r bytes" % (max_size,)
+ logger.warning(
+ "{%s} [%s] %s", request.txn_id, request.destination, msg,
+ )
+            raise SynapseError(502, msg, Codes.TOO_LARGE)
except Exception as e:
logger.warning(
"{%s} [%s] Error reading response: %s",
diff --git a/synapse/notifier.py b/synapse/notifier.py
index a17352ef46..c4c8bb271d 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -34,7 +34,7 @@ from prometheus_client import Counter
from twisted.internet import defer
import synapse.server
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import EventTypes, HistoryVisibility, Membership
from synapse.api.errors import AuthError
from synapse.events import EventBase
from synapse.handlers.presence import format_user_presence_state
@@ -611,7 +611,9 @@ class Notifier:
room_id, EventTypes.RoomHistoryVisibility, ""
)
if state and "history_visibility" in state.content:
- return state.content["history_visibility"] == "world_readable"
+ return (
+ state.content["history_visibility"] == HistoryVisibility.WORLD_READABLE
+ )
else:
return False
diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py
index ad07ee86f6..9e7ac149a1 100644
--- a/synapse/push/__init__.py
+++ b/synapse/push/__init__.py
@@ -14,24 +14,70 @@
# limitations under the License.
import abc
-from typing import TYPE_CHECKING, Any, Dict
+from typing import TYPE_CHECKING, Any, Dict, Optional
-from synapse.types import RoomStreamToken
+import attr
+
+from synapse.types import JsonDict, RoomStreamToken
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
+@attr.s(slots=True)
+class PusherConfig:
+ """Parameters necessary to configure a pusher."""
+
+ id = attr.ib(type=Optional[str])
+ user_name = attr.ib(type=str)
+ access_token = attr.ib(type=Optional[int])
+ profile_tag = attr.ib(type=str)
+ kind = attr.ib(type=str)
+ app_id = attr.ib(type=str)
+ app_display_name = attr.ib(type=str)
+ device_display_name = attr.ib(type=str)
+ pushkey = attr.ib(type=str)
+ ts = attr.ib(type=int)
+ lang = attr.ib(type=Optional[str])
+ data = attr.ib(type=Optional[JsonDict])
+ last_stream_ordering = attr.ib(type=Optional[int])
+ last_success = attr.ib(type=Optional[int])
+ failing_since = attr.ib(type=Optional[int])
+
+ def as_dict(self) -> Dict[str, Any]:
+ """Information that can be retrieved about a pusher after creation."""
+ return {
+ "app_display_name": self.app_display_name,
+ "app_id": self.app_id,
+ "data": self.data,
+ "device_display_name": self.device_display_name,
+ "kind": self.kind,
+ "lang": self.lang,
+ "profile_tag": self.profile_tag,
+ "pushkey": self.pushkey,
+ }
+
+
+@attr.s(slots=True)
+class ThrottleParams:
+ """Parameters for controlling the rate of sending pushes via email."""
+
+ last_sent_ts = attr.ib(type=int)
+ throttle_ms = attr.ib(type=int)
+
+
class Pusher(metaclass=abc.ABCMeta):
- def __init__(self, hs: "HomeServer", pusherdict: Dict[str, Any]):
+ def __init__(self, hs: "HomeServer", pusher_config: PusherConfig):
self.hs = hs
self.store = self.hs.get_datastore()
self.clock = self.hs.get_clock()
- self.pusher_id = pusherdict["id"]
- self.user_id = pusherdict["user_name"]
- self.app_id = pusherdict["app_id"]
- self.pushkey = pusherdict["pushkey"]
+ self.pusher_id = pusher_config.id
+ self.user_id = pusher_config.user_name
+ self.app_id = pusher_config.app_id
+ self.pushkey = pusher_config.pushkey
+
+ self.last_stream_ordering = pusher_config.last_stream_ordering
# This is the highest stream ordering we know it's safe to process.
# When new events arrive, we'll be given a window of new events: we
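
A sketch of constructing the PusherConfig introduced above and the public view exposed by as_dict(); all values are invented for illustration. Note that access_token deliberately stays out of the returned dict:

    config = PusherConfig(
        id="1",
        user_name="@alice:example.com",
        access_token=1234,
        profile_tag="",
        kind="http",
        app_id="com.example.app",
        app_display_name="Example App",
        device_display_name="Alice's phone",
        pushkey="a_push_key",
        ts=1600000000000,
        lang="en",
        data={"url": "https://push.example.com/_matrix/push/v1/notify"},
        last_stream_ordering=None,
        last_success=None,
        failing_since=None,
    )
    assert "access_token" not in config.as_dict()  # secrets are not exposed
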
diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py
index 11a97b8df4..d2eff75a58 100644
--- a/synapse/push/emailpusher.py
+++ b/synapse/push/emailpusher.py
@@ -14,13 +14,13 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Dict, List, Optional
from twisted.internet.base import DelayedCall
from twisted.internet.error import AlreadyCalled, AlreadyCancelled
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.push import Pusher
+from synapse.push import Pusher, PusherConfig, ThrottleParams
from synapse.push.mailer import Mailer
if TYPE_CHECKING:
@@ -60,15 +60,14 @@ class EmailPusher(Pusher):
factor out the common parts
"""
- def __init__(self, hs: "HomeServer", pusherdict: Dict[str, Any], mailer: Mailer):
- super().__init__(hs, pusherdict)
+ def __init__(self, hs: "HomeServer", pusher_config: PusherConfig, mailer: Mailer):
+ super().__init__(hs, pusher_config)
self.mailer = mailer
self.store = self.hs.get_datastore()
- self.email = pusherdict["pushkey"]
- self.last_stream_ordering = pusherdict["last_stream_ordering"]
+ self.email = pusher_config.pushkey
self.timed_call = None # type: Optional[DelayedCall]
- self.throttle_params = {} # type: Dict[str, Dict[str, int]]
+ self.throttle_params = {} # type: Dict[str, ThrottleParams]
self._inited = False
self._is_processing = False
@@ -132,6 +131,7 @@ class EmailPusher(Pusher):
if not self._inited:
# this is our first loop: load up the throttle params
+ assert self.pusher_id is not None
self.throttle_params = await self.store.get_throttle_params_by_room(
self.pusher_id
)
@@ -157,6 +157,7 @@ class EmailPusher(Pusher):
being run.
"""
start = 0 if INCLUDE_ALL_UNREAD_NOTIFS else self.last_stream_ordering
+ assert start is not None
unprocessed = await self.store.get_unread_push_actions_for_user_in_range_for_email(
self.user_id, start, self.max_stream_ordering
)
@@ -244,13 +245,13 @@ class EmailPusher(Pusher):
def get_room_throttle_ms(self, room_id: str) -> int:
if room_id in self.throttle_params:
- return self.throttle_params[room_id]["throttle_ms"]
+ return self.throttle_params[room_id].throttle_ms
else:
return 0
def get_room_last_sent_ts(self, room_id: str) -> int:
if room_id in self.throttle_params:
- return self.throttle_params[room_id]["last_sent_ts"]
+ return self.throttle_params[room_id].last_sent_ts
else:
return 0
@@ -301,10 +302,10 @@ class EmailPusher(Pusher):
new_throttle_ms = min(
current_throttle_ms * THROTTLE_MULTIPLIER, THROTTLE_MAX_MS
)
- self.throttle_params[room_id] = {
- "last_sent_ts": self.clock.time_msec(),
- "throttle_ms": new_throttle_ms,
- }
+ self.throttle_params[room_id] = ThrottleParams(
+ self.clock.time_msec(), new_throttle_ms,
+ )
+ assert self.pusher_id is not None
await self.store.set_throttle_params(
self.pusher_id, room_id, self.throttle_params[room_id]
)
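The throttling hunk above doubles and caps the per-room delay; a self-contained sketch of that arithmetic follows. `THROTTLE_MULTIPLIER` and `THROTTLE_MAX_MS` appear in the hunk but are defined outside this diff, so their values here are assumptions:

    # Illustration of the email backoff maths; constant values are assumptions.
    THROTTLE_MULTIPLIER = 2
    THROTTLE_MAX_MS = 24 * 60 * 60 * 1000  # assumed one-day cap

    def next_throttle_ms(current_throttle_ms: int) -> int:
        # Back off further each time we mail about the room, up to the cap.
        return min(current_throttle_ms * THROTTLE_MULTIPLIER, THROTTLE_MAX_MS)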
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index ff05ef705a..f05fb054b4 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -25,7 +25,7 @@ from synapse.api.constants import EventTypes
from synapse.events import EventBase
from synapse.logging import opentracing
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.push import Pusher, PusherConfigException
+from synapse.push import Pusher, PusherConfig, PusherConfigException
from . import push_rule_evaluator, push_tools
@@ -62,33 +62,29 @@ class HttpPusher(Pusher):
# This one's in ms because we compare it against the clock
GIVE_UP_AFTER_MS = 24 * 60 * 60 * 1000
- def __init__(self, hs: "HomeServer", pusherdict: Dict[str, Any]):
- super().__init__(hs, pusherdict)
+ def __init__(self, hs: "HomeServer", pusher_config: PusherConfig):
+ super().__init__(hs, pusher_config)
self.storage = self.hs.get_storage()
- self.app_display_name = pusherdict["app_display_name"]
- self.device_display_name = pusherdict["device_display_name"]
- self.pushkey_ts = pusherdict["ts"]
- self.data = pusherdict["data"]
- self.last_stream_ordering = pusherdict["last_stream_ordering"]
+ self.app_display_name = pusher_config.app_display_name
+ self.device_display_name = pusher_config.device_display_name
+ self.pushkey_ts = pusher_config.ts
+ self.data = pusher_config.data
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
- self.failing_since = pusherdict["failing_since"]
+ self.failing_since = pusher_config.failing_since
self.timed_call = None
self._is_processing = False
self._group_unread_count_by_room = hs.config.push_group_unread_count_by_room
- if "data" not in pusherdict:
- raise PusherConfigException("No 'data' key for HTTP pusher")
- self.data = pusherdict["data"]
+ self.data = pusher_config.data
+ if self.data is None:
+ raise PusherConfigException("'data' key can not be null for HTTP pusher")
self.name = "%s/%s/%s" % (
- pusherdict["user_name"],
- pusherdict["app_id"],
- pusherdict["pushkey"],
+ pusher_config.user_name,
+ pusher_config.app_id,
+ pusher_config.pushkey,
)
- if self.data is None:
- raise PusherConfigException("data can not be null for HTTP pusher")
-
# Validate that there's a URL and it is of the proper form.
if "url" not in self.data:
raise PusherConfigException("'url' required in data for HTTP pusher")
@@ -185,6 +181,7 @@ class HttpPusher(Pusher):
Never call this directly: use _process which will only allow this to
run once per pusher.
"""
+ assert self.last_stream_ordering is not None
unprocessed = await self.store.get_unread_push_actions_for_user_in_range_for_http(
self.user_id, self.last_stream_ordering, self.max_stream_ordering
)
@@ -213,6 +210,7 @@ class HttpPusher(Pusher):
http_push_processed_counter.inc()
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
self.last_stream_ordering = push_action["stream_ordering"]
+ assert self.last_stream_ordering is not None
pusher_still_exists = await self.store.update_pusher_last_stream_ordering_and_success(
self.app_id,
self.pushkey,
@@ -319,6 +317,8 @@ class HttpPusher(Pusher):
# or may do so (i.e. is encrypted so has unknown effects).
priority = "high"
+ # This was checked in the __init__, but mypy doesn't seem to know that.
+ assert self.data is not None
if self.data.get("format") == "event_id_only":
d = {
"notification": {
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index 9ff092e8bb..4d875dcb91 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -486,7 +486,11 @@ class Mailer:
def add_image_message_vars(
self, messagevars: Dict[str, Any], event: EventBase
) -> None:
- messagevars["image_url"] = event.content["url"]
+ """
+ Potentially add an image URL to the message variables.
+ """
+ if "url" in event.content:
+ messagevars["image_url"] = event.content["url"]
async def make_summary_text(
self,
diff --git a/synapse/push/pusher.py b/synapse/push/pusher.py
index 8f1072b094..2aa7918fb4 100644
--- a/synapse/push/pusher.py
+++ b/synapse/push/pusher.py
@@ -14,9 +14,9 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
+from typing import TYPE_CHECKING, Callable, Dict, Optional
-from synapse.push import Pusher
+from synapse.push import Pusher, PusherConfig
from synapse.push.emailpusher import EmailPusher
from synapse.push.httppusher import HttpPusher
from synapse.push.mailer import Mailer
@@ -34,7 +34,7 @@ class PusherFactory:
self.pusher_types = {
"http": HttpPusher
- } # type: Dict[str, Callable[[HomeServer, dict], Pusher]]
+ } # type: Dict[str, Callable[[HomeServer, PusherConfig], Pusher]]
logger.info("email enable notifs: %r", hs.config.email_enable_notifs)
if hs.config.email_enable_notifs:
@@ -47,18 +47,18 @@ class PusherFactory:
logger.info("defined email pusher type")
- def create_pusher(self, pusherdict: Dict[str, Any]) -> Optional[Pusher]:
- kind = pusherdict["kind"]
+ def create_pusher(self, pusher_config: PusherConfig) -> Optional[Pusher]:
+ kind = pusher_config.kind
f = self.pusher_types.get(kind, None)
if not f:
return None
- logger.debug("creating %s pusher for %r", kind, pusherdict)
- return f(self.hs, pusherdict)
+ logger.debug("creating %s pusher for %r", kind, pusher_config)
+ return f(self.hs, pusher_config)
def _create_email_pusher(
- self, _hs: "HomeServer", pusherdict: Dict[str, Any]
+ self, _hs: "HomeServer", pusher_config: PusherConfig
) -> EmailPusher:
- app_name = self._app_name_from_pusherdict(pusherdict)
+ app_name = self._app_name_from_pusherdict(pusher_config)
mailer = self.mailers.get(app_name)
if not mailer:
mailer = Mailer(
@@ -68,10 +68,10 @@ class PusherFactory:
template_text=self._notif_template_text,
)
self.mailers[app_name] = mailer
- return EmailPusher(self.hs, pusherdict, mailer)
+ return EmailPusher(self.hs, pusher_config, mailer)
- def _app_name_from_pusherdict(self, pusherdict: Dict[str, Any]) -> str:
- data = pusherdict["data"]
+ def _app_name_from_pusherdict(self, pusher_config: PusherConfig) -> str:
+ data = pusher_config.data
if isinstance(data, dict):
brand = data.get("brand")
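The factory change is essentially a dispatch-table retype; a small usage sketch, under the assumption that `factory` is a `PusherFactory` and `config` a `PusherConfig` built elsewhere:

    # Illustration only: unknown kinds (or email pushers when notifications
    # are disabled) come back as None rather than raising.
    pusher = factory.create_pusher(config)
    if pusher is None:
        logger.debug("no pusher type registered for kind %r", config.kind)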
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index 9c12d81cfb..8158356d40 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -15,7 +15,7 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Any, Dict, Optional
+from typing import TYPE_CHECKING, Dict, Iterable, Optional
from prometheus_client import Gauge
@@ -23,9 +23,9 @@ from synapse.metrics.background_process_metrics import (
run_as_background_process,
wrap_as_background_process,
)
-from synapse.push import Pusher, PusherConfigException
+from synapse.push import Pusher, PusherConfig, PusherConfigException
from synapse.push.pusher import PusherFactory
-from synapse.types import RoomStreamToken
+from synapse.types import JsonDict, RoomStreamToken
from synapse.util.async_helpers import concurrently_execute
if TYPE_CHECKING:
@@ -77,7 +77,7 @@ class PusherPool:
# map from user id to app_id:pushkey to pusher
self.pushers = {} # type: Dict[str, Dict[str, Pusher]]
- def start(self):
+ def start(self) -> None:
"""Starts the pushers off in a background process.
"""
if not self._should_start_pushers:
@@ -87,16 +87,16 @@ class PusherPool:
async def add_pusher(
self,
- user_id,
- access_token,
- kind,
- app_id,
- app_display_name,
- device_display_name,
- pushkey,
- lang,
- data,
- profile_tag="",
+ user_id: str,
+ access_token: Optional[int],
+ kind: str,
+ app_id: str,
+ app_display_name: str,
+ device_display_name: str,
+ pushkey: str,
+ lang: Optional[str],
+ data: JsonDict,
+ profile_tag: str = "",
) -> Optional[Pusher]:
"""Creates a new pusher and adds it to the pool
@@ -111,21 +111,23 @@ class PusherPool:
# recreated, added and started: this means we have only one
# code path adding pushers.
self.pusher_factory.create_pusher(
- {
- "id": None,
- "user_name": user_id,
- "kind": kind,
- "app_id": app_id,
- "app_display_name": app_display_name,
- "device_display_name": device_display_name,
- "pushkey": pushkey,
- "ts": time_now_msec,
- "lang": lang,
- "data": data,
- "last_stream_ordering": None,
- "last_success": None,
- "failing_since": None,
- }
+ PusherConfig(
+ id=None,
+ user_name=user_id,
+ access_token=access_token,
+ profile_tag=profile_tag,
+ kind=kind,
+ app_id=app_id,
+ app_display_name=app_display_name,
+ device_display_name=device_display_name,
+ pushkey=pushkey,
+ ts=time_now_msec,
+ lang=lang,
+ data=data,
+ last_stream_ordering=None,
+ last_success=None,
+ failing_since=None,
+ )
)
# create the pusher setting last_stream_ordering to the current maximum
@@ -151,43 +153,44 @@ class PusherPool:
return pusher
async def remove_pushers_by_app_id_and_pushkey_not_user(
- self, app_id, pushkey, not_user_id
- ):
+ self, app_id: str, pushkey: str, not_user_id: str
+ ) -> None:
to_remove = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey)
for p in to_remove:
- if p["user_name"] != not_user_id:
+ if p.user_name != not_user_id:
logger.info(
"Removing pusher for app id %s, pushkey %s, user %s",
app_id,
pushkey,
- p["user_name"],
+ p.user_name,
)
- await self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"])
+ await self.remove_pusher(p.app_id, p.pushkey, p.user_name)
- async def remove_pushers_by_access_token(self, user_id, access_tokens):
+ async def remove_pushers_by_access_token(
+ self, user_id: str, access_tokens: Iterable[int]
+ ) -> None:
"""Remove the pushers for a given user corresponding to a set of
access_tokens.
Args:
- user_id (str): user to remove pushers for
- access_tokens (Iterable[int]): access token *ids* to remove pushers
- for
+ user_id: user to remove pushers for
+ access_tokens: access token *ids* to remove pushers for
"""
if not self._pusher_shard_config.should_handle(self._instance_name, user_id):
return
tokens = set(access_tokens)
for p in await self.store.get_pushers_by_user_id(user_id):
- if p["access_token"] in tokens:
+ if p.access_token in tokens:
logger.info(
"Removing pusher for app id %s, pushkey %s, user %s",
- p["app_id"],
- p["pushkey"],
- p["user_name"],
+ p.app_id,
+ p.pushkey,
+ p.user_name,
)
- await self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"])
+ await self.remove_pusher(p.app_id, p.pushkey, p.user_name)
- def on_new_notifications(self, max_token: RoomStreamToken):
+ def on_new_notifications(self, max_token: RoomStreamToken) -> None:
if not self.pushers:
# nothing to do here.
return
@@ -206,7 +209,7 @@ class PusherPool:
self._on_new_notifications(max_token)
@wrap_as_background_process("on_new_notifications")
- async def _on_new_notifications(self, max_token: RoomStreamToken):
+ async def _on_new_notifications(self, max_token: RoomStreamToken) -> None:
# We just use the minimum stream ordering and ignore the vector clock
# component. This is safe to do as long as we *always* ignore the vector
# clock components.
@@ -236,7 +239,9 @@ class PusherPool:
except Exception:
logger.exception("Exception in pusher on_new_notifications")
- async def on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids):
+ async def on_new_receipts(
+ self, min_stream_id: int, max_stream_id: int, affected_room_ids: Iterable[str]
+ ) -> None:
if not self.pushers:
# nothing to do here.
return
@@ -280,14 +285,14 @@ class PusherPool:
resultlist = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey)
- pusher_dict = None
+ pusher_config = None
for r in resultlist:
- if r["user_name"] == user_id:
- pusher_dict = r
+ if r.user_name == user_id:
+ pusher_config = r
pusher = None
- if pusher_dict:
- pusher = await self._start_pusher(pusher_dict)
+ if pusher_config:
+ pusher = await self._start_pusher(pusher_config)
return pusher
@@ -302,44 +307,44 @@ class PusherPool:
logger.info("Started pushers")
- async def _start_pusher(self, pusherdict: Dict[str, Any]) -> Optional[Pusher]:
+ async def _start_pusher(self, pusher_config: PusherConfig) -> Optional[Pusher]:
"""Start the given pusher
Args:
- pusherdict: dict with the values pulled from the db table
+ pusher_config: The pusher configuration with the values pulled from the db table
Returns:
The newly created pusher or None.
"""
if not self._pusher_shard_config.should_handle(
- self._instance_name, pusherdict["user_name"]
+ self._instance_name, pusher_config.user_name
):
return None
try:
- p = self.pusher_factory.create_pusher(pusherdict)
+ p = self.pusher_factory.create_pusher(pusher_config)
except PusherConfigException as e:
logger.warning(
"Pusher incorrectly configured id=%i, user=%s, appid=%s, pushkey=%s: %s",
- pusherdict["id"],
- pusherdict.get("user_name"),
- pusherdict.get("app_id"),
- pusherdict.get("pushkey"),
+ pusher_config.id,
+ pusher_config.user_name,
+ pusher_config.app_id,
+ pusher_config.pushkey,
e,
)
return None
except Exception:
logger.exception(
- "Couldn't start pusher id %i: caught Exception", pusherdict["id"],
+ "Couldn't start pusher id %i: caught Exception", pusher_config.id,
)
return None
if not p:
return None
- appid_pushkey = "%s:%s" % (pusherdict["app_id"], pusherdict["pushkey"])
+ appid_pushkey = "%s:%s" % (pusher_config.app_id, pusher_config.pushkey)
- byuser = self.pushers.setdefault(pusherdict["user_name"], {})
+ byuser = self.pushers.setdefault(pusher_config.user_name, {})
if appid_pushkey in byuser:
byuser[appid_pushkey].on_stop()
byuser[appid_pushkey] = p
@@ -349,8 +354,8 @@ class PusherPool:
# Check if there *may* be push to process. We do this as this check is a
# lot cheaper to do than actually fetching the exact rows we need to
# push.
- user_id = pusherdict["user_name"]
- last_stream_ordering = pusherdict["last_stream_ordering"]
+ user_id = pusher_config.user_name
+ last_stream_ordering = pusher_config.last_stream_ordering
if last_stream_ordering:
have_notifs = await self.store.get_if_maybe_push_in_range_for_user(
user_id, last_stream_ordering
@@ -364,7 +369,7 @@ class PusherPool:
return p
- async def remove_pusher(self, app_id, pushkey, user_id):
+ async def remove_pusher(self, app_id: str, pushkey: str, user_id: str) -> None:
appid_pushkey = "%s:%s" % (app_id, pushkey)
byuser = self.pushers.get(user_id, {})
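`_start_pusher` above keys running pushers first by user, then by `app_id:pushkey`; a sketch of that two-level index with illustrative names:

    # user_id -> "app_id:pushkey" -> Pusher, as in PusherPool above.
    pushers = {}

    appid_pushkey = "%s:%s" % (config.app_id, config.pushkey)
    byuser = pushers.setdefault(config.user_name, {})
    if appid_pushkey in byuser:
        # Replacing an existing pusher: stop the old one first.
        byuser[appid_pushkey].on_stop()
    byuser[appid_pushkey] = new_pusher  # new_pusher assumed built by the factory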
diff --git a/synapse/replication/http/login.py b/synapse/replication/http/login.py
index 4c81e2d784..36071feb36 100644
--- a/synapse/replication/http/login.py
+++ b/synapse/replication/http/login.py
@@ -36,7 +36,9 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
self.registration_handler = hs.get_registration_handler()
@staticmethod
- async def _serialize_payload(user_id, device_id, initial_display_name, is_guest):
+ async def _serialize_payload(
+ user_id, device_id, initial_display_name, is_guest, is_appservice_ghost
+ ):
"""
Args:
device_id (str|None): Device ID to use, if None a new one is
@@ -48,6 +50,7 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
"device_id": device_id,
"initial_display_name": initial_display_name,
"is_guest": is_guest,
+ "is_appservice_ghost": is_appservice_ghost,
}
async def _handle_request(self, request, user_id):
@@ -56,9 +59,14 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
device_id = content["device_id"]
initial_display_name = content["initial_display_name"]
is_guest = content["is_guest"]
+ is_appservice_ghost = content["is_appservice_ghost"]
device_id, access_token = await self.registration_handler.register_device(
- user_id, device_id, initial_display_name, is_guest
+ user_id,
+ device_id,
+ initial_display_name,
+ is_guest,
+ is_appservice_ghost=is_appservice_ghost,
)
return 200, {"device_id": device_id, "access_token": access_token}
diff --git a/synapse/replication/slave/storage/_slaved_id_tracker.py b/synapse/replication/slave/storage/_slaved_id_tracker.py
index eb74903d68..0d39a93ed2 100644
--- a/synapse/replication/slave/storage/_slaved_id_tracker.py
+++ b/synapse/replication/slave/storage/_slaved_id_tracker.py
@@ -12,21 +12,31 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import List, Optional, Tuple
+from synapse.storage.types import Connection
from synapse.storage.util.id_generators import _load_current_id
class SlavedIdTracker:
- def __init__(self, db_conn, table, column, extra_tables=[], step=1):
+ def __init__(
+ self,
+ db_conn: Connection,
+ table: str,
+ column: str,
+ extra_tables: Optional[List[Tuple[str, str]]] = None,
+ step: int = 1,
+ ):
self.step = step
self._current = _load_current_id(db_conn, table, column, step)
- for table, column in extra_tables:
- self.advance(None, _load_current_id(db_conn, table, column))
+ if extra_tables:
+ for table, column in extra_tables:
+ self.advance(None, _load_current_id(db_conn, table, column))
- def advance(self, instance_name, new_id):
+ def advance(self, instance_name: Optional[str], new_id: int):
self._current = (max if self.step > 0 else min)(self._current, new_id)
- def get_current_token(self):
+ def get_current_token(self) -> int:
"""
Returns:
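The `(max if self.step > 0 else min)` trick keeps the tracked id monotone in the stream's direction; a runnable sketch:

    # Standalone illustration of SlavedIdTracker.advance(): a positive step
    # keeps the maximum id seen, a negative step keeps the minimum.
    class Tracker:
        def __init__(self, start: int, step: int = 1):
            self.step = step
            self._current = start

        def advance(self, new_id: int) -> None:
            self._current = (max if self.step > 0 else min)(self._current, new_id)

    t = Tracker(5)
    t.advance(3)
    t.advance(9)
    assert t._current == 9  # a forward stream never moves backwards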
diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py
index c418730ba8..045bd014da 100644
--- a/synapse/replication/slave/storage/pushers.py
+++ b/synapse/replication/slave/storage/pushers.py
@@ -13,26 +13,33 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import TYPE_CHECKING
from synapse.replication.tcp.streams import PushersStream
from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.pusher import PusherWorkerStore
+from synapse.storage.types import Connection
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
- self._pushers_id_gen = SlavedIdTracker(
+ self._pushers_id_gen = SlavedIdTracker( # type: ignore
db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
)
- def get_pushers_stream_token(self):
+ def get_pushers_stream_token(self) -> int:
return self._pushers_id_gen.get_current_token()
- def process_replication_rows(self, stream_name, instance_name, token, rows):
+ def process_replication_rows(
+ self, stream_name: str, instance_name: str, token, rows
+ ) -> None:
if stream_name == PushersStream.NAME:
- self._pushers_id_gen.advance(instance_name, token)
+ self._pushers_id_gen.advance(instance_name, token) # type: ignore
return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/res/templates/notif.html b/synapse/res/templates/notif.html
index 6d76064d13..0aaef97df8 100644
--- a/synapse/res/templates/notif.html
+++ b/synapse/res/templates/notif.html
@@ -29,7 +29,7 @@
{{ message.body_text_html }}
{%- elif message.msgtype == "m.notice" %}
{{ message.body_text_html }}
- {%- elif message.msgtype == "m.image" %}
+ {%- elif message.msgtype == "m.image" and message.image_url %}
<img src="{{ message.image_url|mxc_to_http(640, 480, scale) }}" />
{%- elif message.msgtype == "m.file" %}
<span class="filename">{{ message.body_text_plain }}</span>
diff --git a/synapse/res/username_picker/index.html b/synapse/res/username_picker/index.html
new file mode 100644
index 0000000000..37ea8bb6d8
--- /dev/null
+++ b/synapse/res/username_picker/index.html
@@ -0,0 +1,19 @@
+<!DOCTYPE html>
+<html lang="en">
+ <head>
+ <title>Synapse Login</title>
+ <link rel="stylesheet" href="style.css" type="text/css" />
+ </head>
+ <body>
+ <div class="card">
+ <form method="post" class="form__input" id="form" action="submit">
+ <label for="field-username">Please pick your username:</label>
+ <input type="text" name="username" id="field-username" autofocus="">
+ <input type="submit" class="button button--full-width" id="button-submit" value="Submit">
+ </form>
+ <!-- this is used for feedback -->
+ <div role=alert class="tooltip hidden" id="message"></div>
+ <script src="script.js"></script>
+ </div>
+ </body>
+</html>
diff --git a/synapse/res/username_picker/script.js b/synapse/res/username_picker/script.js
new file mode 100644
index 0000000000..416a7c6f41
--- /dev/null
+++ b/synapse/res/username_picker/script.js
@@ -0,0 +1,95 @@
+let inputField = document.getElementById("field-username");
+let inputForm = document.getElementById("form");
+let submitButton = document.getElementById("button-submit");
+let message = document.getElementById("message");
+
+// Submit username and receive response
+function showMessage(messageText) {
+ // Unhide the message text
+ message.classList.remove("hidden");
+
+ message.textContent = messageText;
+};
+
+function doSubmit() {
+ showMessage("Success. Please wait a moment for your browser to redirect.");
+
+ // remove the event handler before re-submitting the form.
+ delete inputForm.onsubmit;
+ inputForm.submit();
+}
+
+function onResponse(response) {
+ // Display message
+ showMessage(response);
+
+ // Enable submit button and input field
+ submitButton.classList.remove('button--disabled');
+ submitButton.value = "Submit";
+};
+
+let allowedUsernameCharacters = RegExp("[^a-z0-9\\.\\_\\=\\-\\/]");
+function usernameIsValid(username) {
+ return !allowedUsernameCharacters.test(username);
+}
+let allowedCharactersString = "lowercase letters, digits, ., _, -, /, =";
+
+function buildQueryString(params) {
+ return Object.keys(params)
+ .map(k => encodeURIComponent(k) + '=' + encodeURIComponent(params[k]))
+ .join('&');
+}
+
+function submitUsername(username) {
+ if(username.length == 0) {
+ onResponse("Please enter a username.");
+ return;
+ }
+ if(!usernameIsValid(username)) {
+ onResponse("Invalid username. Only the following characters are allowed: " + allowedCharactersString);
+ return;
+ }
+
+ // if this browser doesn't support fetch, skip the availability check.
+ if(!window.fetch) {
+ doSubmit();
+ return;
+ }
+
+ let check_uri = 'check?' + buildQueryString({"username": username});
+ fetch(check_uri, {
+ // include the cookie
+ "credentials": "same-origin",
+ }).then((response) => {
+ if(!response.ok) {
+ // for non-200 responses, raise the body of the response as an exception
+ return response.text().then((text) => { throw text; });
+ } else {
+ return response.json();
+ }
+ }).then((json) => {
+ if(json.error) {
+ throw json.error;
+ } else if(json.available) {
+ doSubmit();
+ } else {
+ onResponse("This username is not available, please choose another.");
+ }
+ }).catch((err) => {
+ onResponse("Error checking username availability: " + err);
+ });
+}
+
+function clickSubmit() {
+ event.preventDefault();
+ if(submitButton.classList.contains('button--disabled')) { return; }
+
+ // Disable submit button and input field
+ submitButton.classList.add('button--disabled');
+
+ // Submit username
+ submitButton.value = "Checking...";
+ submitUsername(inputField.value);
+};
+
+inputForm.onsubmit = clickSubmit;
diff --git a/synapse/res/username_picker/style.css b/synapse/res/username_picker/style.css
new file mode 100644
index 0000000000..745bd4c684
--- /dev/null
+++ b/synapse/res/username_picker/style.css
@@ -0,0 +1,27 @@
+input[type="text"] {
+ font-size: 100%;
+ background-color: #ededf0;
+ border: 1px solid #fff;
+ border-radius: .2em;
+ padding: .5em .9em;
+ display: block;
+ width: 26em;
+}
+
+.button--disabled {
+ border-color: #fff;
+ background-color: transparent;
+ color: #000;
+ text-transform: none;
+}
+
+.hidden {
+ display: none;
+}
+
+.tooltip {
+ background-color: #f9f9fa;
+ padding: 1em;
+ margin: 1em 0;
+}
+
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index 55ddebb4fe..6f7dc06503 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -38,6 +38,7 @@ from synapse.rest.admin.rooms import (
DeleteRoomRestServlet,
JoinRoomAliasServlet,
ListRoomRestServlet,
+ MakeRoomAdminRestServlet,
RoomMembersRestServlet,
RoomRestServlet,
ShutdownRoomRestServlet,
@@ -228,6 +229,7 @@ def register_servlets(hs, http_server):
EventReportDetailRestServlet(hs).register(http_server)
EventReportsRestServlet(hs).register(http_server)
PushersRestServlet(hs).register(http_server)
+ MakeRoomAdminRestServlet(hs).register(http_server)
def register_servlets_for_client_rest_resource(hs, http_server):
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index b902af8028..ab7cc9102a 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -16,8 +16,8 @@ import logging
from http import HTTPStatus
from typing import TYPE_CHECKING, List, Optional, Tuple
-from synapse.api.constants import EventTypes, JoinRules
-from synapse.api.errors import Codes, NotFoundError, SynapseError
+from synapse.api.constants import EventTypes, JoinRules, Membership
+from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
@@ -37,6 +37,7 @@ from synapse.types import JsonDict, RoomAlias, RoomID, UserID, create_requester
if TYPE_CHECKING:
from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -367,3 +368,134 @@ class JoinRoomAliasServlet(RestServlet):
)
return 200, {"room_id": room_id}
+
+
+class MakeRoomAdminRestServlet(RestServlet):
+ """Allows a server admin to get power in a room if a local user has power in
+ a room. Will also invite the user if they're not in the room and it's a
+ private room. Can specify another user (rather than the admin user) to be
+ granted power, e.g.:
+
+ POST/_synapse/admin/v1/rooms/<room_id_or_alias>/make_room_admin
+ {
+ "user_id": "@foo:example.com"
+ }
+ """
+
+ PATTERNS = admin_patterns("/rooms/(?P<room_identifier>[^/]*)/make_room_admin")
+
+ def __init__(self, hs: "HomeServer"):
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.room_member_handler = hs.get_room_member_handler()
+ self.event_creation_handler = hs.get_event_creation_handler()
+ self.state_handler = hs.get_state_handler()
+ self.is_mine_id = hs.is_mine_id
+
+ async def on_POST(self, request, room_identifier):
+ requester = await self.auth.get_user_by_req(request)
+ await assert_user_is_admin(self.auth, requester.user)
+ content = parse_json_object_from_request(request, allow_empty_body=True)
+
+ # Resolve to a room ID, if necessary.
+ if RoomID.is_valid(room_identifier):
+ room_id = room_identifier
+ elif RoomAlias.is_valid(room_identifier):
+ room_alias = RoomAlias.from_string(room_identifier)
+ room_id, _ = await self.room_member_handler.lookup_room_alias(room_alias)
+ room_id = room_id.to_string()
+ else:
+ raise SynapseError(
+ 400, "%s was not legal room ID or room alias" % (room_identifier,)
+ )
+
+ # Which user to grant room admin rights to.
+ user_to_add = content.get("user_id", requester.user.to_string())
+
+ # Figure out which local users currently have power in the room, if any.
+ room_state = await self.state_handler.get_current_state(room_id)
+ if not room_state:
+ raise SynapseError(400, "Server not in room")
+
+ create_event = room_state[(EventTypes.Create, "")]
+ power_levels = room_state.get((EventTypes.PowerLevels, ""))
+
+ if power_levels is not None:
+ # We pick the local user with the highest power.
+ user_power = power_levels.content.get("users", {})
+ admin_users = [
+ user_id for user_id in user_power if self.is_mine_id(user_id)
+ ]
+ admin_users.sort(key=lambda user: user_power[user])
+
+ if not admin_users:
+ raise SynapseError(400, "No local admin user in room")
+
+ admin_user_id = admin_users[-1]
+
+ pl_content = power_levels.content
+ else:
+ # If there are no power level events then the creator has the rights.
+ pl_content = {}
+ admin_user_id = create_event.sender
+ if not self.is_mine_id(admin_user_id):
+ raise SynapseError(
+ 400, "No local admin user in room",
+ )
+
+ # Grant the user power equal to the room admin by attempting to send an
+ # updated power level event.
+ new_pl_content = dict(pl_content)
+ new_pl_content["users"] = dict(pl_content.get("users", {}))
+ new_pl_content["users"][user_to_add] = new_pl_content["users"][admin_user_id]
+
+ fake_requester = create_requester(
+ admin_user_id, authenticated_entity=requester.authenticated_entity,
+ )
+
+ try:
+ await self.event_creation_handler.create_and_send_nonmember_event(
+ fake_requester,
+ event_dict={
+ "content": new_pl_content,
+ "sender": admin_user_id,
+ "type": EventTypes.PowerLevels,
+ "state_key": "",
+ "room_id": room_id,
+ },
+ )
+ except AuthError:
+ # The admin user we found turned out not to have enough power.
+ raise SynapseError(
+ 400, "No local admin user in room with power to update power levels."
+ )
+
+ # Now we check if the user we're granting admin rights to is already in
+ # the room. If not and it's not a public room we invite them.
+ member_event = room_state.get((EventTypes.Member, user_to_add))
+ is_joined = False
+ if member_event:
+ is_joined = member_event.content["membership"] in (
+ Membership.JOIN,
+ Membership.INVITE,
+ )
+
+ if is_joined:
+ return 200, {}
+
+ join_rules = room_state.get((EventTypes.JoinRules, ""))
+ is_public = False
+ if join_rules:
+ is_public = join_rules.content.get("join_rule") == JoinRules.PUBLIC
+
+ if is_public:
+ return 200, {}
+
+ await self.room_member_handler.update_membership(
+ fake_requester,
+ target=UserID.from_string(user_to_add),
+ room_id=room_id,
+ action=Membership.INVITE,
+ )
+
+ return 200, {}
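For reviewers trying the endpoint out, a hypothetical client call; the path and body shape come from the servlet docstring, while the host and token are invented:

    import requests

    resp = requests.post(
        "https://homeserver.example.com/_synapse/admin/v1/rooms/"
        "!someroom:example.com/make_room_admin",
        headers={"Authorization": "Bearer <admin_access_token>"},
        json={"user_id": "@foo:example.com"},  # optional; defaults to the admin
    )
    resp.raise_for_status()  # 200 with an empty JSON object on success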
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index 88cba369f5..6658c2da56 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -42,17 +42,6 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-_GET_PUSHERS_ALLOWED_KEYS = {
- "app_display_name",
- "app_id",
- "data",
- "device_display_name",
- "kind",
- "lang",
- "profile_tag",
- "pushkey",
-}
-
class UsersRestServlet(RestServlet):
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)$")
@@ -770,10 +759,7 @@ class PushersRestServlet(RestServlet):
pushers = await self.store.get_pushers_by_user_id(user_id)
- filtered_pushers = [
- {k: v for k, v in p.items() if k in _GET_PUSHERS_ALLOWED_KEYS}
- for p in pushers
- ]
+ filtered_pushers = [p.as_dict() for p in pushers]
return 200, {"pushers": filtered_pushers, "total": len(filtered_pushers)}
diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py
index 8fe83f321a..89823fcc39 100644
--- a/synapse/rest/client/v1/pusher.py
+++ b/synapse/rest/client/v1/pusher.py
@@ -28,17 +28,6 @@ from synapse.rest.client.v2_alpha._base import client_patterns
logger = logging.getLogger(__name__)
-ALLOWED_KEYS = {
- "app_display_name",
- "app_id",
- "data",
- "device_display_name",
- "kind",
- "lang",
- "profile_tag",
- "pushkey",
-}
-
class PushersRestServlet(RestServlet):
PATTERNS = client_patterns("/pushers$", v1=True)
@@ -54,9 +43,7 @@ class PushersRestServlet(RestServlet):
pushers = await self.hs.get_datastore().get_pushers_by_user_id(user.to_string())
- filtered_pushers = [
- {k: v for k, v in p.items() if k in ALLOWED_KEYS} for p in pushers
- ]
+ filtered_pushers = [p.as_dict() for p in pushers]
return 200, {"pushers": filtered_pushers}
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index eebee44a44..d837bde1d6 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -254,14 +254,18 @@ class PasswordRestServlet(RestServlet):
logger.error("Auth succeeded but no known type! %r", result.keys())
raise SynapseError(500, "", Codes.UNKNOWN)
- # If we have a password in this request, prefer it. Otherwise, there
- # must be a password hash from an earlier request.
+ # If we have a password in this request, prefer it. Otherwise, use the
+ # password hash from an earlier request.
if new_password:
password_hash = await self.auth_handler.hash(new_password)
- else:
+ elif session_id is not None:
password_hash = await self.auth_handler.get_session_data(
session_id, "password_hash", None
)
+ else:
+ # UI validation was skipped, but the request did not include a new
+ # password.
+ password_hash = None
if not password_hash:
raise SynapseError(400, "Missing params: password", Codes.MISSING_PARAM)
diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py
index a3bb095c2d..5b5da71815 100644
--- a/synapse/rest/client/v2_alpha/groups.py
+++ b/synapse/rest/client/v2_alpha/groups.py
@@ -15,6 +15,7 @@
# limitations under the License.
import logging
+from functools import wraps
from synapse.api.errors import SynapseError
from synapse.http.servlet import RestServlet, parse_json_object_from_request
@@ -25,6 +26,22 @@ from ._base import client_patterns
logger = logging.getLogger(__name__)
+def _validate_group_id(f):
+ """Wrapper to validate the form of the group ID.
+
+ Can be applied to any on_FOO method that accepts a group ID as a URL parameter.
+ """
+
+ @wraps(f)
+ def wrapper(self, request, group_id, *args, **kwargs):
+ if not GroupID.is_valid(group_id):
+ raise SynapseError(400, "%s is not a legal group ID" % (group_id,))
+
+ return f(self, request, group_id, *args, **kwargs)
+
+ return wrapper
+
+
class GroupServlet(RestServlet):
"""Get the group profile
"""
@@ -37,6 +54,7 @@ class GroupServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
+ @_validate_group_id
async def on_GET(self, request, group_id):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -47,6 +65,7 @@ class GroupServlet(RestServlet):
return 200, group_description
+ @_validate_group_id
async def on_POST(self, request, group_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -71,6 +90,7 @@ class GroupSummaryServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
+ @_validate_group_id
async def on_GET(self, request, group_id):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -102,6 +122,7 @@ class GroupSummaryRoomsCatServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
+ @_validate_group_id
async def on_PUT(self, request, group_id, category_id, room_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -117,6 +138,7 @@ class GroupSummaryRoomsCatServlet(RestServlet):
return 200, resp
+ @_validate_group_id
async def on_DELETE(self, request, group_id, category_id, room_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -142,6 +164,7 @@ class GroupCategoryServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
+ @_validate_group_id
async def on_GET(self, request, group_id, category_id):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -152,6 +175,7 @@ class GroupCategoryServlet(RestServlet):
return 200, category
+ @_validate_group_id
async def on_PUT(self, request, group_id, category_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -163,6 +187,7 @@ class GroupCategoryServlet(RestServlet):
return 200, resp
+ @_validate_group_id
async def on_DELETE(self, request, group_id, category_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -186,6 +211,7 @@ class GroupCategoriesServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
+ @_validate_group_id
async def on_GET(self, request, group_id):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -209,6 +235,7 @@ class GroupRoleServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
+ @_validate_group_id
async def on_GET(self, request, group_id, role_id):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -219,6 +246,7 @@ class GroupRoleServlet(RestServlet):
return 200, category
+ @_validate_group_id
async def on_PUT(self, request, group_id, role_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -230,6 +258,7 @@ class GroupRoleServlet(RestServlet):
return 200, resp
+ @_validate_group_id
async def on_DELETE(self, request, group_id, role_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -253,6 +282,7 @@ class GroupRolesServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
+ @_validate_group_id
async def on_GET(self, request, group_id):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -284,6 +314,7 @@ class GroupSummaryUsersRoleServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
+ @_validate_group_id
async def on_PUT(self, request, group_id, role_id, user_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -299,6 +330,7 @@ class GroupSummaryUsersRoleServlet(RestServlet):
return 200, resp
+ @_validate_group_id
async def on_DELETE(self, request, group_id, role_id, user_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -322,13 +354,11 @@ class GroupRoomServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
+ @_validate_group_id
async def on_GET(self, request, group_id):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
- if not GroupID.is_valid(group_id):
- raise SynapseError(400, "%s was not legal group ID" % (group_id,))
-
result = await self.groups_handler.get_rooms_in_group(
group_id, requester_user_id
)
@@ -348,6 +378,7 @@ class GroupUsersServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
+ @_validate_group_id
async def on_GET(self, request, group_id):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -371,6 +402,7 @@ class GroupInvitedUsersServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
+ @_validate_group_id
async def on_GET(self, request, group_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -393,6 +425,7 @@ class GroupSettingJoinPolicyServlet(RestServlet):
self.auth = hs.get_auth()
self.groups_handler = hs.get_groups_local_handler()
+ @_validate_group_id
async def on_PUT(self, request, group_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -449,6 +482,7 @@ class GroupAdminRoomsServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
+ @_validate_group_id
async def on_PUT(self, request, group_id, room_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -460,6 +494,7 @@ class GroupAdminRoomsServlet(RestServlet):
return 200, result
+ @_validate_group_id
async def on_DELETE(self, request, group_id, room_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -486,6 +521,7 @@ class GroupAdminRoomsConfigServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
+ @_validate_group_id
async def on_PUT(self, request, group_id, room_id, config_key):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -514,6 +550,7 @@ class GroupAdminUsersInviteServlet(RestServlet):
self.store = hs.get_datastore()
self.is_mine_id = hs.is_mine_id
+ @_validate_group_id
async def on_PUT(self, request, group_id, user_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -541,6 +578,7 @@ class GroupAdminUsersKickServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
+ @_validate_group_id
async def on_PUT(self, request, group_id, user_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -565,6 +603,7 @@ class GroupSelfLeaveServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
+ @_validate_group_id
async def on_PUT(self, request, group_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -589,6 +628,7 @@ class GroupSelfJoinServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
+ @_validate_group_id
async def on_PUT(self, request, group_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -613,6 +653,7 @@ class GroupSelfAcceptInviteServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
+ @_validate_group_id
async def on_PUT(self, request, group_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -637,6 +678,7 @@ class GroupSelfUpdatePublicityServlet(RestServlet):
self.clock = hs.get_clock()
self.store = hs.get_datastore()
+ @_validate_group_id
async def on_PUT(self, request, group_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
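Since the decorator raises before the wrapped coroutine is created, every decorated handler can assume a well-formed ID. A usage sketch with a hypothetical servlet (registration elided):

    # Hypothetical servlet for illustration only.
    class ExampleGroupServlet(RestServlet):
        @_validate_group_id
        async def on_GET(self, request, group_id):
            # group_id has already passed GroupID.is_valid() here.
            return 200, {"group_id": group_id}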
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index 9041e7ed76..6b5a1b7109 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -655,9 +655,13 @@ class RegisterRestServlet(RestServlet):
user_id = await self.registration_handler.appservice_register(
username, as_token
)
- return await self._create_registration_details(user_id, body)
+ return await self._create_registration_details(
+ user_id, body, is_appservice_ghost=True,
+ )
- async def _create_registration_details(self, user_id, params):
+ async def _create_registration_details(
+ self, user_id, params, is_appservice_ghost=False
+ ):
"""Complete registration of newly-registered user
Allocates device_id if one was not given; also creates access_token.
@@ -674,7 +678,11 @@ class RegisterRestServlet(RestServlet):
device_id = params.get("device_id")
initial_display_name = params.get("initial_device_display_name")
device_id, access_token = await self.registration_handler.register_device(
- user_id, device_id, initial_display_name, is_guest=False
+ user_id,
+ device_id,
+ initial_display_name,
+ is_guest=False,
+ is_appservice_ghost=is_appservice_ghost,
)
result.update({"access_token": access_token, "device_id": device_id})
diff --git a/synapse/rest/client/v2_alpha/sendtodevice.py b/synapse/rest/client/v2_alpha/sendtodevice.py
index bc4f43639a..a3dee14ed4 100644
--- a/synapse/rest/client/v2_alpha/sendtodevice.py
+++ b/synapse/rest/client/v2_alpha/sendtodevice.py
@@ -17,7 +17,7 @@ import logging
from typing import Tuple
from synapse.http import servlet
-from synapse.http.servlet import parse_json_object_from_request
+from synapse.http.servlet import assert_params_in_dict, parse_json_object_from_request
from synapse.logging.opentracing import set_tag, trace
from synapse.rest.client.transactions import HttpTransactionCache
@@ -54,6 +54,7 @@ class SendToDeviceRestServlet(servlet.RestServlet):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
content = parse_json_object_from_request(request)
+ assert_params_in_dict(content, ("messages",))
sender_user_id = requester.user.to_string()
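`assert_params_in_dict` rejects a request missing required top-level keys with a 400 and `Codes.MISSING_PARAM`; an equivalent check, sketched (the real helper lives in `synapse.http.servlet`, this re-implementation is for illustration only):

    from synapse.api.errors import Codes, SynapseError

    def assert_params(content, required):
        absent = [key for key in required if key not in content]
        if absent:
            raise SynapseError(400, "Missing params: %r" % (absent,), Codes.MISSING_PARAM)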
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index f843f02454..c57ac22e58 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -13,7 +13,7 @@
# limitations under the License.
import logging
-from typing import Dict, Set
+from typing import Dict
from signedjson.sign import sign_json
@@ -142,12 +142,13 @@ class RemoteKey(DirectServeJsonResource):
time_now_ms = self.clock.time_msec()
- cache_misses = {} # type: Dict[str, Set[str]]
+ # Note that the value is unused.
+ cache_misses = {} # type: Dict[str, Dict[str, int]]
for (server_name, key_id, from_server), results in cached.items():
results = [(result["ts_added_ms"], result) for result in results]
if not results and key_id is not None:
- cache_misses.setdefault(server_name, set()).add(key_id)
+ cache_misses.setdefault(server_name, {})[key_id] = 0
continue
if key_id is not None:
@@ -201,7 +202,7 @@ class RemoteKey(DirectServeJsonResource):
)
if miss:
- cache_misses.setdefault(server_name, set()).add(key_id)
+ cache_misses.setdefault(server_name, {})[key_id] = 0
# Cast to bytes since postgresql returns a memoryview.
json_results.add(bytes(most_recent_result["key_json"]))
else:
diff --git a/synapse/rest/synapse/client/pick_username.py b/synapse/rest/synapse/client/pick_username.py
new file mode 100644
index 0000000000..d3b6803e65
--- /dev/null
+++ b/synapse/rest/synapse/client/pick_username.py
@@ -0,0 +1,88 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+import pkg_resources
+
+from twisted.web.http import Request
+from twisted.web.resource import Resource
+from twisted.web.static import File
+
+from synapse.api.errors import SynapseError
+from synapse.handlers.sso import USERNAME_MAPPING_SESSION_COOKIE_NAME
+from synapse.http.server import DirectServeHtmlResource, DirectServeJsonResource
+from synapse.http.servlet import parse_string
+from synapse.http.site import SynapseRequest
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+
+def pick_username_resource(hs: "HomeServer") -> Resource:
+ """Factory method to generate the username picker resource.
+
+ This resource gets mounted under /_synapse/client/pick_username. The top-level
+ resource is just a File resource which serves up the static files in the resources
+ "res" directory, but it has a couple of children:
+
+ * "submit", which does the mechanics of registering the new user, and redirects the
+ browser back to the client URL
+
+ * "check": checks if a userid is free.
+ """
+
+ # XXX should we make this path customisable so that admins can restyle it?
+ base_path = pkg_resources.resource_filename("synapse", "res/username_picker")
+
+ res = File(base_path)
+ res.putChild(b"submit", SubmitResource(hs))
+ res.putChild(b"check", AvailabilityCheckResource(hs))
+
+ return res
+
+
+class AvailabilityCheckResource(DirectServeJsonResource):
+ def __init__(self, hs: "HomeServer"):
+ super().__init__()
+ self._sso_handler = hs.get_sso_handler()
+
+ async def _async_render_GET(self, request: Request):
+ localpart = parse_string(request, "username", required=True)
+
+ session_id = request.getCookie(USERNAME_MAPPING_SESSION_COOKIE_NAME)
+ if not session_id:
+ raise SynapseError(code=400, msg="missing session_id")
+
+ is_available = await self._sso_handler.check_username_availability(
+ localpart, session_id.decode("ascii", errors="replace")
+ )
+ return 200, {"available": is_available}
+
+
+class SubmitResource(DirectServeHtmlResource):
+ def __init__(self, hs: "HomeServer"):
+ super().__init__()
+ self._sso_handler = hs.get_sso_handler()
+
+ async def _async_render_POST(self, request: SynapseRequest):
+ localpart = parse_string(request, "username", required=True)
+
+ session_id = request.getCookie(USERNAME_MAPPING_SESSION_COOKIE_NAME)
+ if not session_id:
+ raise SynapseError(code=400, msg="missing session_id")
+
+ await self._sso_handler.handle_submit_username_request(
+ request, localpart, session_id.decode("ascii", errors="replace")
+ )
diff --git a/synapse/state/v2.py b/synapse/state/v2.py
index f85124bf81..e585954bd8 100644
--- a/synapse/state/v2.py
+++ b/synapse/state/v2.py
@@ -658,7 +658,7 @@ async def _get_mainline_depth_for_event(
# We do an iterative search, replacing `tmp_event` with the power level
# event in its auth events (if any)
while tmp_event:
- depth = mainline_map.get(event.event_id)
+ depth = mainline_map.get(tmp_event.event_id)
if depth is not None:
return depth
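The one-character fix above matters because the loop must consult the *current* event on each iteration; with the bug, it looked up the starting event's depth every time, so the walk could never find a mainline ancestor's depth. A minimal sketch of the corrected walk, with an assumed helper for stepping along the auth chain:

    # Sketch of _get_mainline_depth_for_event's corrected loop; helper assumed.
    async def mainline_depth(event, mainline_map):
        tmp_event = event
        while tmp_event:
            depth = mainline_map.get(tmp_event.event_id)  # the fixed lookup
            if depth is not None:
                return depth
            # Step to the power level event in tmp_event's auth events, or
            # None if there is none, ending the walk.
            tmp_event = await power_event_in_auth_chain(tmp_event)
        return None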
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index bbff3c8d5b..c0d9d1240f 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -27,6 +27,7 @@ There are also schemas that get applied to every database, regardless of the
data stores associated with them (e.g. the schema version tables), which are
stored in `synapse.storage.schema`.
"""
+from typing import TYPE_CHECKING
from synapse.storage.databases import Databases
from synapse.storage.databases.main import DataStore
@@ -34,14 +35,18 @@ from synapse.storage.persist_events import EventsPersistenceStorage
from synapse.storage.purge_events import PurgeEventsStorage
from synapse.storage.state import StateGroupStorage
-__all__ = ["DataStores", "DataStore"]
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
+
+__all__ = ["Databases", "DataStore"]
class Storage:
"""The high level interfaces for talking to various storage layers.
"""
- def __init__(self, hs, stores: Databases):
+ def __init__(self, hs: "HomeServer", stores: Databases):
# We include the main data store here mainly so that we don't have to
# rewrite all the existing code to split it into high vs low level
# interfaces.
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 2b196ded1b..a25c4093bc 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -17,14 +17,18 @@
import logging
import random
from abc import ABCMeta
-from typing import Any, Optional
+from typing import TYPE_CHECKING, Any, Iterable, Optional, Union
from synapse.storage.database import LoggingTransaction # noqa: F401
from synapse.storage.database import make_in_list_sql_clause # noqa: F401
from synapse.storage.database import DatabasePool
-from synapse.types import Collection, get_domain_from_id
+from synapse.storage.types import Connection
+from synapse.types import Collection, StreamToken, get_domain_from_id
from synapse.util import json_decoder
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
logger = logging.getLogger(__name__)
@@ -36,24 +40,31 @@ class SQLBaseStore(metaclass=ABCMeta):
per data store (and not one per physical database).
"""
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
self.hs = hs
self._clock = hs.get_clock()
self.database_engine = database.engine
self.db_pool = database
self.rand = random.SystemRandom()
- def process_replication_rows(self, stream_name, instance_name, token, rows):
+ def process_replication_rows(
+ self,
+ stream_name: str,
+ instance_name: str,
+ token: StreamToken,
+ rows: Iterable[Any],
+ ) -> None:
pass
- def _invalidate_state_caches(self, room_id, members_changed):
+ def _invalidate_state_caches(
+ self, room_id: str, members_changed: Iterable[str]
+ ) -> None:
"""Invalidates caches that are based on the current state, but does
not stream invalidations down replication.
Args:
- room_id (str): Room where state changed
- members_changed (iterable[str]): The user_ids of members that have
- changed
+ room_id: Room where state changed
+ members_changed: The user_ids of members that have changed
"""
for host in {get_domain_from_id(u) for u in members_changed}:
self._attempt_to_invalidate_cache("is_host_joined", (room_id, host))
@@ -64,7 +75,7 @@ class SQLBaseStore(metaclass=ABCMeta):
def _attempt_to_invalidate_cache(
self, cache_name: str, key: Optional[Collection[Any]]
- ):
+ ) -> None:
"""Attempts to invalidate the cache of the given name, ignoring if the
cache doesn't exist. Mainly used for invalidating caches on workers,
where they may not have the cache.
@@ -88,12 +99,15 @@ class SQLBaseStore(metaclass=ABCMeta):
cache.invalidate(tuple(key))
-def db_to_json(db_content):
+def db_to_json(db_content: Union[memoryview, bytes, bytearray, str]) -> Any:
"""
Take some data from a database row and return a JSON-decoded object.
Args:
- db_content (memoryview|buffer|bytes|bytearray|unicode)
+ db_content: The JSON-encoded contents from the database.
+
+ Returns:
+ The object decoded from JSON.
"""
# psycopg2 on Python 3 returns memoryview objects, which we need to
# cast to bytes to decode
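A sketch of the behaviour the widened signature documents: psycopg2 on Python 3 hands back `memoryview`s, SQLite `bytes` or `str`, and everything is normalised to `str` before JSON decoding. This is an illustration, not the patched implementation:

    import json
    from typing import Any, Union

    def db_to_json_sketch(db_content: Union[memoryview, bytes, bytearray, str]) -> Any:
        if isinstance(db_content, memoryview):
            db_content = bytes(db_content)
        if isinstance(db_content, (bytes, bytearray)):
            db_content = db_content.decode("utf8")
        return json.loads(db_content)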
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 810721ebe9..29b8ca676a 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -12,29 +12,34 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
-from typing import Optional
+from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Iterable, Optional
from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage.types import Connection
+from synapse.types import JsonDict
from synapse.util import json_encoder
from . import engines
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+ from synapse.storage.database import DatabasePool, LoggingTransaction
+
logger = logging.getLogger(__name__)
class BackgroundUpdatePerformance:
"""Tracks the how long a background update is taking to update its items"""
- def __init__(self, name):
+ def __init__(self, name: str):
self.name = name
self.total_item_count = 0
- self.total_duration_ms = 0
- self.avg_item_count = 0
- self.avg_duration_ms = 0
+ self.total_duration_ms = 0.0
+ self.avg_item_count = 0.0
+ self.avg_duration_ms = 0.0
- def update(self, item_count, duration_ms):
+ def update(self, item_count: int, duration_ms: float) -> None:
"""Update the stats after doing an update"""
self.total_item_count += item_count
self.total_duration_ms += duration_ms
@@ -44,7 +49,7 @@ class BackgroundUpdatePerformance:
self.avg_item_count += 0.1 * (item_count - self.avg_item_count)
self.avg_duration_ms += 0.1 * (duration_ms - self.avg_duration_ms)
- def average_items_per_ms(self):
+ def average_items_per_ms(self) -> Optional[float]:
"""An estimate of how long it takes to do a single update.
Returns:
A duration in ms as a float
@@ -58,7 +63,7 @@ class BackgroundUpdatePerformance:
# changes in how long the update process takes.
return float(self.avg_item_count) / float(self.avg_duration_ms)
- def total_items_per_ms(self):
+ def total_items_per_ms(self) -> Optional[float]:
"""An estimate of how long it takes to do a single update.
Returns:
A duration in ms as a float
@@ -83,21 +88,25 @@ class BackgroundUpdater:
BACKGROUND_UPDATE_INTERVAL_MS = 1000
BACKGROUND_UPDATE_DURATION_MS = 100
- def __init__(self, hs, database):
+ def __init__(self, hs: "HomeServer", database: "DatabasePool"):
self._clock = hs.get_clock()
self.db_pool = database
# if a background update is currently running, its name.
self._current_background_update = None # type: Optional[str]
- self._background_update_performance = {}
- self._background_update_handlers = {}
+ self._background_update_performance = (
+ {}
+ ) # type: Dict[str, BackgroundUpdatePerformance]
+ self._background_update_handlers = (
+ {}
+ ) # type: Dict[str, Callable[[JsonDict, int], Awaitable[int]]]
self._all_done = False
- def start_doing_background_updates(self):
+ def start_doing_background_updates(self) -> None:
run_as_background_process("background_updates", self.run_background_updates)
- async def run_background_updates(self, sleep=True):
+ async def run_background_updates(self, sleep: bool = True) -> None:
logger.info("Starting background schema updates")
while True:
if sleep:
@@ -148,7 +157,7 @@ class BackgroundUpdater:
return False
- async def has_completed_background_update(self, update_name) -> bool:
+ async def has_completed_background_update(self, update_name: str) -> bool:
"""Check if the given background update has finished running.
"""
if self._all_done:
@@ -173,8 +182,7 @@ class BackgroundUpdater:
Returns once some amount of work is done.
Args:
- desired_duration_ms(float): How long we want to spend
- updating.
+ desired_duration_ms: How long we want to spend updating.
Returns:
True if we have finished running all the background updates, otherwise False
"""
@@ -220,6 +228,7 @@ class BackgroundUpdater:
return False
async def _do_background_update(self, desired_duration_ms: float) -> int:
+ assert self._current_background_update is not None
update_name = self._current_background_update
logger.info("Starting update batch on background update '%s'", update_name)
@@ -273,7 +282,11 @@ class BackgroundUpdater:
return len(self._background_update_performance)
- def register_background_update_handler(self, update_name, update_handler):
+ def register_background_update_handler(
+ self,
+ update_name: str,
+ update_handler: Callable[[JsonDict, int], Awaitable[int]],
+ ) -> None:
"""Register a handler for doing a background update.
The handler should take two arguments:
@@ -287,12 +300,12 @@ class BackgroundUpdater:
The handler is responsible for updating the progress of the update.
Args:
- update_name(str): The name of the update that this code handles.
- update_handler(function): The function that does the update.
+ update_name: The name of the update that this code handles.
+ update_handler: The function that does the update.
"""
self._background_update_handlers[update_name] = update_handler
- def register_noop_background_update(self, update_name):
+ def register_noop_background_update(self, update_name: str) -> None:
"""Register a noop handler for a background update.
This is useful when we previously did a background update, but no
@@ -302,10 +315,10 @@ class BackgroundUpdater:
also be called to clear the update.
Args:
- update_name (str): Name of update
+ update_name: Name of update
"""
- async def noop_update(progress, batch_size):
+ async def noop_update(progress: JsonDict, batch_size: int) -> int:
await self._end_background_update(update_name)
return 1
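A hypothetical call site for this helper, as a sketch (the update name is
illustrative only):

self.db_pool.updates.register_noop_background_update("some_retired_update")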
@@ -313,14 +326,14 @@ class BackgroundUpdater:
def register_background_index_update(
self,
- update_name,
- index_name,
- table,
- columns,
- where_clause=None,
- unique=False,
- psql_only=False,
- ):
+ update_name: str,
+ index_name: str,
+ table: str,
+ columns: Iterable[str],
+ where_clause: Optional[str] = None,
+ unique: bool = False,
+ psql_only: bool = False,
+ ) -> None:
"""Helper for store classes to do a background index addition
To use:
@@ -332,19 +345,19 @@ class BackgroundUpdater:
2. In the Store constructor, call this method
Args:
- update_name (str): update_name to register for
- index_name (str): name of index to add
- table (str): table to add index to
- columns (list[str]): columns/expressions to include in index
- unique (bool): true to make a UNIQUE index
+ update_name: update_name to register for
+ index_name: name of index to add
+ table: table to add index to
+ columns: columns/expressions to include in index
+ unique: true to make a UNIQUE index
psql_only: true to only create this index on psql databases (useful
for virtual sqlite tables)
"""
- def create_index_psql(conn):
+ def create_index_psql(conn: Connection) -> None:
conn.rollback()
# postgres insists on autocommit for the index
- conn.set_session(autocommit=True)
+ conn.set_session(autocommit=True) # type: ignore
try:
c = conn.cursor()
@@ -371,9 +384,9 @@ class BackgroundUpdater:
logger.debug("[SQL] %s", sql)
c.execute(sql)
finally:
- conn.set_session(autocommit=False)
+ conn.set_session(autocommit=False) # type: ignore
- def create_index_sqlite(conn):
+ def create_index_sqlite(conn: Connection) -> None:
# Sqlite doesn't support concurrent creation of indexes.
#
# We don't use partial indices on SQLite as it wasn't introduced
@@ -399,7 +412,7 @@ class BackgroundUpdater:
c.execute(sql)
if isinstance(self.db_pool.engine, engines.PostgresEngine):
- runner = create_index_psql
+ runner = create_index_psql # type: Optional[Callable[[Connection], None]]
elif psql_only:
runner = None
else:
@@ -433,7 +446,9 @@ class BackgroundUpdater:
"background_updates", keyvalues={"update_name": update_name}
)
- async def _background_update_progress(self, update_name: str, progress: dict):
+ async def _background_update_progress(
+ self, update_name: str, progress: dict
+ ) -> None:
"""Update the progress of a background update
Args:
@@ -441,20 +456,22 @@ class BackgroundUpdater:
progress: The progress of the update.
"""
- return await self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"background_update_progress",
self._background_update_progress_txn,
update_name,
progress,
)
- def _background_update_progress_txn(self, txn, update_name, progress):
+ def _background_update_progress_txn(
+ self, txn: "LoggingTransaction", update_name: str, progress: JsonDict
+ ) -> None:
"""Update the progress of a background update
Args:
- txn(cursor): The transaction.
- update_name(str): The name of the background update task
- progress(dict): The progress of the update.
+ txn: The transaction.
+ update_name: The name of the background update task
+ progress: The progress of the update.
"""
progress_json = json_encoder.encode(progress)
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 43660ec4fb..701748f93b 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -149,9 +149,6 @@ class DataStore(
self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
- self._pushers_id_gen = StreamIdGenerator(
- db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
- )
self._group_updates_id_gen = StreamIdGenerator(
db_conn, "local_group_updates", "stream_id"
)
@@ -342,12 +339,13 @@ class DataStore(
filters = []
args = [self.hs.config.server_name]
+ # `name` is already stored in lower case in the database
if name:
- filters.append("(name LIKE ? OR displayname LIKE ?)")
- args.extend(["@%" + name + "%:%", "%" + name + "%"])
+ filters.append("(name LIKE ? OR LOWER(displayname) LIKE ?)")
+ args.extend(["@%" + name.lower() + "%:%", "%" + name.lower() + "%"])
elif user_id:
filters.append("name LIKE ?")
- args.extend(["%" + user_id + "%"])
+ args.extend(["%" + user_id.lower() + "%"])
if not guests:
filters.append("is_guest = 0")
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index 2408432738..c5468c7b0d 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -14,11 +14,12 @@
# limitations under the License.
import logging
-from typing import Dict, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Union
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool, make_tuple_comparison_clause
+from synapse.types import UserID
from synapse.util.caches.lrucache import LruCache
logger = logging.getLogger(__name__)
@@ -546,7 +547,9 @@ class ClientIpStore(ClientIpWorkerStore):
}
return ret
- async def get_user_ip_and_agents(self, user):
+ async def get_user_ip_and_agents(
+ self, user: UserID
+ ) -> List[Dict[str, Union[str, int]]]:
user_id = user.to_string()
results = {}
diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py
index f8f4bb9b3f..04ac2d0ced 100644
--- a/synapse/storage/databases/main/keys.py
+++ b/synapse/storage/databases/main/keys.py
@@ -22,6 +22,7 @@ from signedjson.key import decode_verify_key_bytes
from synapse.storage._base import SQLBaseStore
from synapse.storage.keys import FetchKeyResult
+from synapse.storage.types import Cursor
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.iterutils import batch_iter
@@ -44,7 +45,7 @@ class KeyStore(SQLBaseStore):
)
async def get_server_verify_keys(
self, server_name_and_key_ids: Iterable[Tuple[str, str]]
- ) -> Dict[Tuple[str, str], Optional[FetchKeyResult]]:
+ ) -> Dict[Tuple[str, str], FetchKeyResult]:
"""
Args:
server_name_and_key_ids:
@@ -56,7 +57,7 @@ class KeyStore(SQLBaseStore):
"""
keys = {}
- def _get_keys(txn, batch):
+ def _get_keys(txn: Cursor, batch: Tuple[Tuple[str, str], ...]) -> None:
"""Processes a batch of keys to fetch, and adds the result to `keys`."""
# batch_iter always returns tuples so it's safe to do len(batch)
@@ -77,13 +78,12 @@ class KeyStore(SQLBaseStore):
# `ts_valid_until_ms`.
ts_valid_until_ms = 0
- res = FetchKeyResult(
+ keys[(server_name, key_id)] = FetchKeyResult(
verify_key=decode_verify_key_bytes(key_id, bytes(key_bytes)),
valid_until_ts=ts_valid_until_ms,
)
- keys[(server_name, key_id)] = res
- def _txn(txn):
+ def _txn(txn: Cursor) -> Dict[Tuple[str, str], FetchKeyResult]:
for batch in batch_iter(server_name_and_key_ids, 50):
_get_keys(txn, batch)
return keys
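For reference, batch_iter chunks an iterable into fixed-size tuples, which is
why len(batch) is safe in _get_keys above; a small illustrative sketch:

from synapse.util.iterutils import batch_iter

assert list(batch_iter(range(5), 2)) == [(0, 1), (2, 3), (4,)]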
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index 7997242d90..77ba9d819e 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -15,18 +15,32 @@
# limitations under the License.
import logging
-from typing import Iterable, Iterator, List, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Tuple
from canonicaljson import encode_canonical_json
+from synapse.push import PusherConfig, ThrottleParams
from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage.database import DatabasePool
+from synapse.storage.types import Connection
+from synapse.storage.util.id_generators import StreamIdGenerator
+from synapse.types import JsonDict
from synapse.util.caches.descriptors import cached, cachedList
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
logger = logging.getLogger(__name__)
class PusherWorkerStore(SQLBaseStore):
- def _decode_pushers_rows(self, rows: Iterable[dict]) -> Iterator[dict]:
+ def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+ super().__init__(database, db_conn, hs)
+ self._pushers_id_gen = StreamIdGenerator(
+ db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
+ )
+
+ def _decode_pushers_rows(self, rows: Iterable[dict]) -> Iterator[PusherConfig]:
"""JSON-decode the data in the rows returned from the `pushers` table
Drops any rows whose data cannot be decoded
@@ -44,21 +58,23 @@ class PusherWorkerStore(SQLBaseStore):
)
continue
- yield r
+ yield PusherConfig(**r)
- async def user_has_pusher(self, user_id):
+ async def user_has_pusher(self, user_id: str) -> bool:
ret = await self.db_pool.simple_select_one_onecol(
"pushers", {"user_name": user_id}, "id", allow_none=True
)
return ret is not None
- def get_pushers_by_app_id_and_pushkey(self, app_id, pushkey):
- return self.get_pushers_by({"app_id": app_id, "pushkey": pushkey})
+ async def get_pushers_by_app_id_and_pushkey(
+ self, app_id: str, pushkey: str
+ ) -> Iterator[PusherConfig]:
+ return await self.get_pushers_by({"app_id": app_id, "pushkey": pushkey})
- def get_pushers_by_user_id(self, user_id):
- return self.get_pushers_by({"user_name": user_id})
+ async def get_pushers_by_user_id(self, user_id: str) -> Iterator[PusherConfig]:
+ return await self.get_pushers_by({"user_name": user_id})
- async def get_pushers_by(self, keyvalues):
+ async def get_pushers_by(self, keyvalues: Dict[str, Any]) -> Iterator[PusherConfig]:
ret = await self.db_pool.simple_select_list(
"pushers",
keyvalues,
@@ -83,7 +99,7 @@ class PusherWorkerStore(SQLBaseStore):
)
return self._decode_pushers_rows(ret)
- async def get_all_pushers(self):
+ async def get_all_pushers(self) -> Iterator[PusherConfig]:
def get_pushers(txn):
txn.execute("SELECT * FROM pushers")
rows = self.db_pool.cursor_to_dict(txn)
@@ -159,14 +175,16 @@ class PusherWorkerStore(SQLBaseStore):
)
@cached(num_args=1, max_entries=15000)
- async def get_if_user_has_pusher(self, user_id):
+ async def get_if_user_has_pusher(self, user_id: str):
# This only exists for the cachedList decorator
raise NotImplementedError()
@cachedList(
cached_method_name="get_if_user_has_pusher", list_name="user_ids", num_args=1,
)
- async def get_if_users_have_pushers(self, user_ids):
+ async def get_if_users_have_pushers(
+ self, user_ids: Iterable[str]
+ ) -> Dict[str, bool]:
rows = await self.db_pool.simple_select_many_batch(
table="pushers",
column="user_name",
@@ -224,7 +242,7 @@ class PusherWorkerStore(SQLBaseStore):
return bool(updated)
async def update_pusher_failing_since(
- self, app_id, pushkey, user_id, failing_since
+ self, app_id: str, pushkey: str, user_id: str, failing_since: Optional[int]
) -> None:
await self.db_pool.simple_update(
table="pushers",
@@ -233,7 +251,9 @@ class PusherWorkerStore(SQLBaseStore):
desc="update_pusher_failing_since",
)
- async def get_throttle_params_by_room(self, pusher_id):
+ async def get_throttle_params_by_room(
+ self, pusher_id: str
+ ) -> Dict[str, ThrottleParams]:
res = await self.db_pool.simple_select_list(
"pusher_throttle",
{"pusher": pusher_id},
@@ -243,43 +263,44 @@ class PusherWorkerStore(SQLBaseStore):
params_by_room = {}
for row in res:
- params_by_room[row["room_id"]] = {
- "last_sent_ts": row["last_sent_ts"],
- "throttle_ms": row["throttle_ms"],
- }
+ params_by_room[row["room_id"]] = ThrottleParams(
+ row["last_sent_ts"], row["throttle_ms"],
+ )
return params_by_room
- async def set_throttle_params(self, pusher_id, room_id, params) -> None:
+ async def set_throttle_params(
+ self, pusher_id: str, room_id: str, params: ThrottleParams
+ ) -> None:
# no need to lock because `pusher_throttle` has a primary key on
# (pusher, room_id) so simple_upsert will retry
await self.db_pool.simple_upsert(
"pusher_throttle",
{"pusher": pusher_id, "room_id": room_id},
- params,
+ {"last_sent_ts": params.last_sent_ts, "throttle_ms": params.throttle_ms},
desc="set_throttle_params",
lock=False,
)
class PusherStore(PusherWorkerStore):
- def get_pushers_stream_token(self):
+ def get_pushers_stream_token(self) -> int:
return self._pushers_id_gen.get_current_token()
async def add_pusher(
self,
- user_id,
- access_token,
- kind,
- app_id,
- app_display_name,
- device_display_name,
- pushkey,
- pushkey_ts,
- lang,
- data,
- last_stream_ordering,
- profile_tag="",
+ user_id: str,
+ access_token: Optional[int],
+ kind: str,
+ app_id: str,
+ app_display_name: str,
+ device_display_name: str,
+ pushkey: str,
+ pushkey_ts: int,
+ lang: Optional[str],
+ data: Optional[JsonDict],
+ last_stream_ordering: int,
+ profile_tag: str = "",
) -> None:
async with self._pushers_id_gen.get_next() as stream_id:
# no need to lock because `pushers` has a unique key on
@@ -311,16 +332,16 @@ class PusherStore(PusherWorkerStore):
# invalidate, since the user might not have had a pusher before
await self.db_pool.runInteraction(
"add_pusher",
- self._invalidate_cache_and_stream,
+ self._invalidate_cache_and_stream, # type: ignore
self.get_if_user_has_pusher,
(user_id,),
)
async def delete_pusher_by_app_id_pushkey_user_id(
- self, app_id, pushkey, user_id
+ self, app_id: str, pushkey: str, user_id: str
) -> None:
def delete_pusher_txn(txn, stream_id):
- self._invalidate_cache_and_stream(
+ self._invalidate_cache_and_stream( # type: ignore
txn, self.get_if_user_has_pusher, (user_id,)
)
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index ff96c34c2e..8d05288ed4 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -943,6 +943,42 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
desc="del_user_pending_deactivation",
)
+ async def get_access_token_last_validated(self, token_id: int) -> int:
+ """Retrieves the time (in milliseconds) of the last validation of an access token.
+
+ Args:
+ token_id: The ID of the access token to look up.
+ Raises:
+ StoreError if the access token was not found.
+
+ Returns:
+ The last validation time.
+ """
+ result = await self.db_pool.simple_select_one_onecol(
+ "access_tokens", {"id": token_id}, "last_validated"
+ )
+
+ # If this token has not been validated (since starting to track this),
+ # return 0 instead of None.
+ return result or 0
+
+ async def update_access_token_last_validated(self, token_id: int) -> None:
+ """Updates the last time an access token was validated.
+
+ Args:
+ token_id: The ID of the access token to update.
+ Raises:
+ StoreError if there was a problem updating this.
+ """
+ now = self._clock.time_msec()
+
+ await self.db_pool.simple_update_one(
+ "access_tokens",
+ {"id": token_id},
+ {"last_validated": now},
+ desc="update_access_token_last_validated",
+ )
+
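A hedged sketch of how a caller might combine the two helpers above to spot
tokens that have not been validated recently (the helper name and threshold
are illustrative, not part of this changeset):

ONE_DAY_MS = 24 * 60 * 60 * 1000

async def is_token_stale(store, token_id: int, now_ms: int) -> bool:
    # get_access_token_last_validated returns 0 for never-validated tokens
    last_validated = await store.get_access_token_last_validated(token_id)
    return now_ms - last_validated > ONE_DAY_MS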
class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
@@ -1150,6 +1186,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
The token ID
"""
next_id = self._access_tokens_id_gen.get_next()
+ now = self._clock.time_msec()
await self.db_pool.simple_insert(
"access_tokens",
@@ -1160,6 +1197,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
"device_id": device_id,
"valid_until_ms": valid_until_ms,
"puppets_user_id": puppets_user_id,
+ "last_validated": now,
},
desc="add_access_token_to_user",
)
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 6b89db15c9..4650d0689b 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -379,14 +379,14 @@ class RoomWorkerStore(SQLBaseStore):
# Filter room names by a string
where_statement = ""
if search_term:
- where_statement = "WHERE state.name LIKE ?"
+ where_statement = "WHERE LOWER(state.name) LIKE ?"
# Our postgres db driver converts ? -> %s in SQL strings as that's the
# placeholder for postgres.
# HOWEVER, if you put a % into your SQL then everything goes wibbly.
# To get around this, we're going to surround search_term with %'s
# before giving it to the database in python instead
- search_term = "%" + search_term + "%"
+ search_term = "%" + search_term.lower() + "%"
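+ # e.g. a search_term of "Foo" becomes the pattern "%foo%", which now
+ # matches case-insensitively against LOWER(state.name)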
# Set ordering
if RoomSortOrder(order_by) == RoomSortOrder.SIZE:
diff --git a/synapse/storage/databases/main/schema/delta/58/26access_token_last_validated.sql b/synapse/storage/databases/main/schema/delta/58/26access_token_last_validated.sql
new file mode 100644
index 0000000000..1a101cd5eb
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/26access_token_last_validated.sql
@@ -0,0 +1,18 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- The last time this access token was "validated" (i.e. logged in or succeeded
+-- at user-interactive authentication).
+ALTER TABLE access_tokens ADD COLUMN last_validated BIGINT;
diff --git a/synapse/storage/databases/main/schema/delta/58/27local_invites.sql b/synapse/storage/databases/main/schema/delta/58/27local_invites.sql
new file mode 100644
index 0000000000..44b2a0572f
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/27local_invites.sql
@@ -0,0 +1,18 @@
+/*
+ * Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- This is unused since Synapse v1.17.0.
+DROP TABLE local_invites;
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index d87ceec6da..ef11f1c3b3 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -17,7 +17,7 @@ import logging
import re
from typing import Any, Dict, Iterable, Optional, Set, Tuple
-from synapse.api.constants import EventTypes, JoinRules
+from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules
from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.state import StateFilter
from synapse.storage.databases.main.state_deltas import StateDeltasStore
@@ -360,7 +360,10 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
if hist_vis_id:
hist_vis_ev = await self.get_event(hist_vis_id, allow_none=True)
if hist_vis_ev:
- if hist_vis_ev.content.get("history_visibility") == "world_readable":
+ if (
+ hist_vis_ev.content.get("history_visibility")
+ == HistoryVisibility.WORLD_READABLE
+ ):
return True
return False
@@ -393,9 +396,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
sql = """
INSERT INTO user_directory_search(user_id, vector)
VALUES (?,
- setweight(to_tsvector('english', ?), 'A')
- || setweight(to_tsvector('english', ?), 'D')
- || setweight(to_tsvector('english', COALESCE(?, '')), 'B')
+ setweight(to_tsvector('simple', ?), 'A')
+ || setweight(to_tsvector('simple', ?), 'D')
+ || setweight(to_tsvector('simple', COALESCE(?, '')), 'B')
) ON CONFLICT (user_id) DO UPDATE SET vector=EXCLUDED.vector
"""
txn.execute(
@@ -415,9 +418,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
sql = """
INSERT INTO user_directory_search(user_id, vector)
VALUES (?,
- setweight(to_tsvector('english', ?), 'A')
- || setweight(to_tsvector('english', ?), 'D')
- || setweight(to_tsvector('english', COALESCE(?, '')), 'B')
+ setweight(to_tsvector('simple', ?), 'A')
+ || setweight(to_tsvector('simple', ?), 'D')
+ || setweight(to_tsvector('simple', COALESCE(?, '')), 'B')
)
"""
txn.execute(
@@ -432,9 +435,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
elif new_entry is False:
sql = """
UPDATE user_directory_search
- SET vector = setweight(to_tsvector('english', ?), 'A')
- || setweight(to_tsvector('english', ?), 'D')
- || setweight(to_tsvector('english', COALESCE(?, '')), 'B')
+ SET vector = setweight(to_tsvector('simple', ?), 'A')
+ || setweight(to_tsvector('simple', ?), 'D')
+ || setweight(to_tsvector('simple', COALESCE(?, '')), 'B')
WHERE user_id = ?
"""
txn.execute(
@@ -761,7 +764,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
INNER JOIN user_directory AS d USING (user_id)
WHERE
%s
- AND vector @@ to_tsquery('english', ?)
+ AND vector @@ to_tsquery('simple', ?)
ORDER BY
(CASE WHEN d.user_id IS NOT NULL THEN 4.0 ELSE 1.0 END)
* (CASE WHEN display_name IS NOT NULL THEN 1.2 ELSE 1.0 END)
@@ -770,13 +773,13 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
3 * ts_rank_cd(
'{0.1, 0.1, 0.9, 1.0}',
vector,
- to_tsquery('english', ?),
+ to_tsquery('simple', ?),
8
)
+ ts_rank_cd(
'{0.1, 0.1, 0.9, 1.0}',
vector,
- to_tsquery('english', ?),
+ to_tsquery('simple', ?),
8
)
)
diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py
index afd10f7bae..c03871f393 100644
--- a/synapse/storage/keys.py
+++ b/synapse/storage/keys.py
@@ -17,11 +17,12 @@
import logging
import attr
+from signedjson.types import VerifyKey
logger = logging.getLogger(__name__)
@attr.s(slots=True, frozen=True)
class FetchKeyResult:
- verify_key = attr.ib() # VerifyKey: the key itself
- valid_until_ts = attr.ib() # int: how long we can use this key for
+ verify_key = attr.ib(type=VerifyKey) # the key itself
+ valid_until_ts = attr.ib(type=int) # how long we can use this key for
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index 70e636b0ba..61fc49c69c 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -31,7 +31,14 @@ from synapse.logging.context import PreserveLoggingContext, make_deferred_yielda
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.databases import Databases
from synapse.storage.databases.main.events import DeltaState
-from synapse.types import Collection, PersistedEventPosition, RoomStreamToken, StateMap
+from synapse.storage.databases.main.events_worker import EventRedactBehaviour
+from synapse.types import (
+ Collection,
+ PersistedEventPosition,
+ RoomStreamToken,
+ StateMap,
+ get_domain_from_id,
+)
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.metrics import Measure
@@ -68,6 +75,21 @@ stale_forward_extremities_counter = Histogram(
buckets=(0, 1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"),
)
+state_resolutions_during_persistence = Counter(
+ "synapse_storage_events_state_resolutions_during_persistence",
+ "Number of times we had to do state res to calculate new current state",
+)
+
+potential_times_prune_extremities = Counter(
+ "synapse_storage_events_potential_times_prune_extremities",
+ "Number of times we might be able to prune extremities",
+)
+
+times_pruned_extremities = Counter(
+ "synapse_storage_events_times_pruned_extremities",
+ "Number of times we were actually be able to prune extremities",
+)
+
class _EventPeristenceQueue:
"""Queues up events so that they can be persisted in bulk with only one
@@ -454,7 +476,15 @@ class EventsPersistenceStorage:
latest_event_ids,
new_latest_event_ids,
)
- current_state, delta_ids = res
+ current_state, delta_ids, new_latest_event_ids = res
+
+ # there should always be at least one forward extremity.
+ # (except during the initial persistence of the send_join
+ # results, in which case there will be no existing
+ # extremities, so we'll `continue` above and skip this bit.)
+ assert new_latest_event_ids, "No forward extremities left!"
+
+ new_forward_extremeties[room_id] = new_latest_event_ids
# If either is not None then there has been a change,
# and we need to work out the delta (or use that
@@ -573,29 +603,35 @@ class EventsPersistenceStorage:
self,
room_id: str,
events_context: List[Tuple[EventBase, EventContext]],
- old_latest_event_ids: Iterable[str],
- new_latest_event_ids: Iterable[str],
- ) -> Tuple[Optional[StateMap[str]], Optional[StateMap[str]]]:
+ old_latest_event_ids: Set[str],
+ new_latest_event_ids: Set[str],
+ ) -> Tuple[Optional[StateMap[str]], Optional[StateMap[str]], Set[str]]:
"""Calculate the current state dict after adding some new events to
a room
Args:
- room_id (str):
+ room_id:
room to which the events are being added. Used for logging etc
- events_context (list[(EventBase, EventContext)]):
+ events_context:
events and contexts which are being added to the room
- old_latest_event_ids (iterable[str]):
+ old_latest_event_ids:
the old forward extremities for the room.
- new_latest_event_ids (iterable[str]):
+ new_latest_event_ids:
the new forward extremities for the room.
Returns:
- Returns a tuple of two state maps, the first being the full new current
- state and the second being the delta to the existing current state.
- If both are None then there has been no change.
+ Returns a tuple of two state maps and a set of new forward
+ extremities.
+
+ The first state map is the full new current state and the second
+ is the delta to the existing current state. If both are None then
+ there has been no change.
+
+ The function may prune some old entries from the set of new
+ forward extremities if it's safe to do so.
If there has been a change then we only return the delta if it's
already been calculated. Conversely if we do know the delta then
@@ -672,7 +708,7 @@ class EventsPersistenceStorage:
# If the old and new groups are the same then we don't need to do
# anything.
if old_state_groups == new_state_groups:
- return None, None
+ return None, None, new_latest_event_ids
if len(new_state_groups) == 1 and len(old_state_groups) == 1:
# If we're going from one state group to another, lets check if
@@ -689,7 +725,7 @@ class EventsPersistenceStorage:
# the current state in memory then lets also return that,
# but it doesn't matter if we don't.
new_state = state_groups_map.get(new_state_group)
- return new_state, delta_ids
+ return new_state, delta_ids, new_latest_event_ids
# Now that we have calculated new_state_groups we need to get
# their state IDs so we can resolve to a single state set.
@@ -701,7 +737,7 @@ class EventsPersistenceStorage:
if len(new_state_groups) == 1:
# If there is only one state group, then we know what the current
# state is.
- return state_groups_map[new_state_groups.pop()], None
+ return state_groups_map[new_state_groups.pop()], None, new_latest_event_ids
# Ok, we need to defer to the state handler to resolve our state sets.
@@ -734,7 +770,139 @@ class EventsPersistenceStorage:
state_res_store=StateResolutionStore(self.main_store),
)
- return res.state, None
+ state_resolutions_during_persistence.inc()
+
+ # If the returned state matches the state group of one of the new
+ # forward extremities then we check if we are able to prune some state
+ # extremities.
+ if res.state_group and res.state_group in new_state_groups:
+ new_latest_event_ids = await self._prune_extremities(
+ room_id,
+ new_latest_event_ids,
+ res.state_group,
+ event_id_to_state_group,
+ events_context,
+ )
+
+ return res.state, None, new_latest_event_ids
+
+ async def _prune_extremities(
+ self,
+ room_id: str,
+ new_latest_event_ids: Set[str],
+ resolved_state_group: int,
+ event_id_to_state_group: Dict[str, int],
+ events_context: List[Tuple[EventBase, EventContext]],
+ ) -> Set[str]:
+ """See if we can prune any of the extremities after calculating the
+ resolved state.
+ """
+ potential_times_prune_extremities.inc()
+
+ # We keep all the extremities that have the same state group, and
+ # see if we can drop the others.
+ new_new_extrems = {
+ e
+ for e in new_latest_event_ids
+ if event_id_to_state_group[e] == resolved_state_group
+ }
+
+ dropped_extrems = set(new_latest_event_ids) - new_new_extrems
+
+ logger.debug("Might drop extremities: %s", dropped_extrems)
+
+ # We only drop events from the extremities list if:
+ # 1. we're not currently persisting them;
+ # 2. they're not our own events (or are dummy events); and
+ # 3. they're either:
+ # 1. over N hours old and more than N events ago (we use depth to
+ # calculate); or
+ # 2. we are persisting an event from the same domain and more than
+ # M events ago.
+ #
+ # The idea is that we don't want to drop events that are "legitimate"
+ # extremities (that we would want to include as prev events), only
+ # "stuck" extremities that are e.g. due to a gap in the graph.
+ #
+ # Note that we either drop all of them or none of them. If we only drop
+ # some of the events we don't know if state res would come to the same
+ # conclusion.
+
+ for ev, _ in events_context:
+ if ev.event_id in dropped_extrems:
+ logger.debug(
+ "Not dropping extremities: %s is being persisted", ev.event_id
+ )
+ return new_latest_event_ids
+
+ dropped_events = await self.main_store.get_events(
+ dropped_extrems,
+ allow_rejected=True,
+ redact_behaviour=EventRedactBehaviour.AS_IS,
+ )
+
+ new_senders = {get_domain_from_id(e.sender) for e, _ in events_context}
+
+ one_day_ago = self._clock.time_msec() - 24 * 60 * 60 * 1000
+ current_depth = max(e.depth for e, _ in events_context)
+ for event in dropped_events.values():
+ # If the event is a local dummy event then we should check it
+ # doesn't reference any local events, as we want to reference those
+ # if we send any new events.
+ #
+ # Note we do this recursively to handle the case where a dummy event
+ # references a dummy event that only references remote events.
+ #
+ # Ideally we'd figure out a way of still being able to drop old
+ # dummy events that reference local events, but this is good enough
+ # as a first cut.
+ events_to_check = [event]
+ while events_to_check:
+ new_events = set()
+ for event_to_check in events_to_check:
+ if self.is_mine_id(event_to_check.sender):
+ if event_to_check.type != EventTypes.Dummy:
+ logger.debug("Not dropping own event")
+ return new_latest_event_ids
+ new_events.update(event_to_check.prev_event_ids())
+
+ prev_events = await self.main_store.get_events(
+ new_events,
+ allow_rejected=True,
+ redact_behaviour=EventRedactBehaviour.AS_IS,
+ )
+ events_to_check = prev_events.values()
+
+ if (
+ event.origin_server_ts < one_day_ago
+ and event.depth < current_depth - 100
+ ):
+ continue
+
+ # We can be less conservative about dropping extremities from the
+ # same domain, though we do want to wait a little bit (otherwise
+ # we'll immediately remove all extremities from a given server).
+ if (
+ get_domain_from_id(event.sender) in new_senders
+ and event.depth < current_depth - 20
+ ):
+ continue
+
+ logger.debug(
+ "Not dropping as too new and not in new_senders: %s", new_senders,
+ )
+
+ return new_latest_event_ids
+
+ times_pruned_extremities.inc()
+
+ logger.info(
+ "Pruning forward extremities in room %s: from %s -> %s",
+ room_id,
+ new_latest_event_ids,
+ new_new_extrems,
+ )
+ return new_new_extrems
async def _calculate_state_delta(
self, room_id: str, current_state: StateMap[str]
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 459754feab..f91a2eae7a 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -18,9 +18,10 @@ import logging
import os
import re
from collections import Counter
-from typing import Optional, TextIO
+from typing import Generator, Iterable, List, Optional, TextIO, Tuple
import attr
+from typing_extensions import Counter as CounterType
from synapse.config.homeserver import HomeServerConfig
from synapse.storage.database import LoggingDatabaseConnection
@@ -70,7 +71,7 @@ def prepare_database(
db_conn: LoggingDatabaseConnection,
database_engine: BaseDatabaseEngine,
config: Optional[HomeServerConfig],
- databases: Collection[str] = ["main", "state"],
+ databases: Collection[str] = ("main", "state"),
):
"""Prepares a physical database for usage. Will either create all necessary tables
or upgrade from an older schema version.
@@ -155,7 +156,9 @@ def prepare_database(
raise
-def _setup_new_database(cur, database_engine, databases):
+def _setup_new_database(
+ cur: Cursor, database_engine: BaseDatabaseEngine, databases: Collection[str]
+) -> None:
"""Sets up the physical database by finding a base set of "full schemas" and
then applying any necessary deltas, including schemas from the given data
stores.
@@ -188,10 +191,9 @@ def _setup_new_database(cur, database_engine, databases):
folder as well as those in the data stores specified.
Args:
- cur (Cursor): a database cursor
- database_engine (DatabaseEngine)
- databases (list[str]): The names of the databases to instantiate
- on the given physical database.
+ cur: a database cursor
+ database_engine
+ databases: The names of the databases to instantiate on the given physical database.
"""
# We're about to set up a brand new database so we check that its
@@ -199,12 +201,11 @@ def _setup_new_database(cur, database_engine, databases):
database_engine.check_new_database(cur)
current_dir = os.path.join(dir_path, "schema", "full_schemas")
- directory_entries = os.listdir(current_dir)
# First we find the highest full schema version we have
valid_versions = []
- for filename in directory_entries:
+ for filename in os.listdir(current_dir):
try:
ver = int(filename)
except ValueError:
@@ -237,7 +238,7 @@ def _setup_new_database(cur, database_engine, databases):
for database in databases
)
- directory_entries = []
+ directory_entries = [] # type: List[_DirectoryListing]
for directory in directories:
directory_entries.extend(
_DirectoryListing(file_name, os.path.join(directory, file_name))
@@ -275,15 +276,15 @@ def _setup_new_database(cur, database_engine, databases):
def _upgrade_existing_database(
- cur,
- current_version,
- applied_delta_files,
- upgraded,
- database_engine,
- config,
- databases,
- is_empty=False,
-):
+ cur: Cursor,
+ current_version: int,
+ applied_delta_files: List[str],
+ upgraded: bool,
+ database_engine: BaseDatabaseEngine,
+ config: Optional[HomeServerConfig],
+ databases: Collection[str],
+ is_empty: bool = False,
+) -> None:
"""Upgrades an existing physical database.
Delta files can either be SQL stored in *.sql files, or python modules
@@ -323,21 +324,20 @@ def _upgrade_existing_database(
for a version before applying those in the next version.
Args:
- cur (Cursor)
- current_version (int): The current version of the schema.
- applied_delta_files (list): A list of deltas that have already been
- applied.
- upgraded (bool): Whether the current version was generated by having
+ cur
+ current_version: The current version of the schema.
+ applied_delta_files: A list of deltas that have already been applied.
+ upgraded: Whether the current version was generated by having
applied deltas or from full schema file. If `True` the function
will never apply delta files for the given `current_version`, since
the current_version wasn't generated by applying those delta files.
- database_engine (DatabaseEngine)
- config (synapse.config.homeserver.HomeServerConfig|None):
+ database_engine
+ config:
None if we are initialising a blank database, otherwise the application
config
- databases (list[str]): The names of the databases to instantiate
+ databases: The names of the databases to instantiate
on the given physical database.
- is_empty (bool): Is this a blank database? I.e. do we need to run the
+ is_empty: Is this a blank database? I.e. do we need to run the
upgrade portions of the delta scripts.
"""
if is_empty:
@@ -358,6 +358,7 @@ def _upgrade_existing_database(
if not is_empty and "main" in databases:
from synapse.storage.databases.main import check_database_before_upgrade
+ assert config is not None
check_database_before_upgrade(cur, database_engine, config)
start_ver = current_version
@@ -388,10 +389,10 @@ def _upgrade_existing_database(
)
# Used to check if we have any duplicate file names
- file_name_counter = Counter()
+ file_name_counter = Counter() # type: CounterType[str]
# Now find which directories have anything of interest.
- directory_entries = []
+ directory_entries = [] # type: List[_DirectoryListing]
for directory in directories:
logger.debug("Looking for schema deltas in %s", directory)
try:
@@ -445,11 +446,11 @@ def _upgrade_existing_database(
module_name = "synapse.storage.v%d_%s" % (v, root_name)
with open(absolute_path) as python_file:
- module = imp.load_source(module_name, absolute_path, python_file)
+ module = imp.load_source(module_name, absolute_path, python_file) # type: ignore
logger.info("Running script %s", relative_path)
- module.run_create(cur, database_engine)
+ module.run_create(cur, database_engine) # type: ignore
if not is_empty:
- module.run_upgrade(cur, database_engine, config=config)
+ module.run_upgrade(cur, database_engine, config=config) # type: ignore
elif ext == ".pyc" or file_name == "__pycache__":
# Sometimes .pyc files turn up anyway even though we've
# disabled their generation; e.g. from distribution package
@@ -497,14 +498,15 @@ def _upgrade_existing_database(
logger.info("Schema now up to date")
-def _apply_module_schemas(txn, database_engine, config):
+def _apply_module_schemas(
+ txn: Cursor, database_engine: BaseDatabaseEngine, config: HomeServerConfig
+) -> None:
"""Apply the module schemas for the dynamic modules, if any
Args:
cur: database cursor
- database_engine: synapse database engine class
- config (synapse.config.homeserver.HomeServerConfig):
- application config
+ database_engine:
+ config: application config
"""
for (mod, _config) in config.password_providers:
if not hasattr(mod, "get_db_schema_files"):
@@ -515,15 +517,19 @@ def _apply_module_schemas(txn, database_engine, config):
)
-def _apply_module_schema_files(cur, database_engine, modname, names_and_streams):
+def _apply_module_schema_files(
+ cur: Cursor,
+ database_engine: BaseDatabaseEngine,
+ modname: str,
+ names_and_streams: Iterable[Tuple[str, TextIO]],
+) -> None:
"""Apply the module schemas for a single module
Args:
cur: database cursor
database_engine: synapse database engine class
- modname (str): fully qualified name of the module
- names_and_streams (Iterable[(str, file)]): the names and streams of
- schemas to be applied
+ modname: fully qualified name of the module
+ names_and_streams: the names and streams of schemas to be applied
"""
cur.execute(
"SELECT file FROM applied_module_schemas WHERE module_name = ?", (modname,),
@@ -549,7 +555,7 @@ def _apply_module_schema_files(cur, database_engine, modname, names_and_streams)
)
-def get_statements(f):
+def get_statements(f: Iterable[str]) -> Generator[str, None, None]:
statement_buffer = ""
in_comment = False # If we're in a /* ... */ style comment
@@ -594,17 +600,19 @@ def get_statements(f):
statement_buffer = statements[-1].strip()
-def executescript(txn, schema_path):
+def executescript(txn: Cursor, schema_path: str) -> None:
with open(schema_path, "r") as f:
execute_statements_from_stream(txn, f)
-def execute_statements_from_stream(cur: Cursor, f: TextIO):
+def execute_statements_from_stream(cur: Cursor, f: TextIO) -> None:
for statement in get_statements(f):
cur.execute(statement)
-def _get_or_create_schema_state(txn, database_engine):
+def _get_or_create_schema_state(
+ txn: Cursor, database_engine: BaseDatabaseEngine
+) -> Optional[Tuple[int, List[str], bool]]:
# Bluntly try creating the schema_version tables.
schema_path = os.path.join(dir_path, "schema", "schema_version.sql")
executescript(txn, schema_path)
@@ -612,7 +620,6 @@ def _get_or_create_schema_state(txn, database_engine):
txn.execute("SELECT version, upgraded FROM schema_version")
row = txn.fetchone()
current_version = int(row[0]) if row else None
- upgraded = bool(row[1]) if row else None
if current_version:
txn.execute(
@@ -620,6 +627,7 @@ def _get_or_create_schema_state(txn, database_engine):
(current_version,),
)
applied_deltas = [d for d, in txn]
+ upgraded = bool(row[1])
return current_version, applied_deltas, upgraded
return None
@@ -634,5 +642,5 @@ class _DirectoryListing:
`file_name` attr is kept first.
"""
- file_name = attr.ib()
- absolute_path = attr.ib()
+ file_name = attr.ib(type=str)
+ absolute_path = attr.ib(type=str)
diff --git a/synapse/storage/purge_events.py b/synapse/storage/purge_events.py
index bfa0a9fd06..6c359c1aae 100644
--- a/synapse/storage/purge_events.py
+++ b/synapse/storage/purge_events.py
@@ -15,7 +15,12 @@
import itertools
import logging
-from typing import Set
+from typing import TYPE_CHECKING, Set
+
+from synapse.storage.databases import Databases
+
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__)
@@ -24,10 +29,10 @@ class PurgeEventsStorage:
"""High level interface for purging rooms and event history.
"""
- def __init__(self, hs, stores):
+ def __init__(self, hs: "HomeServer", stores: Databases):
self.stores = stores
- async def purge_room(self, room_id: str):
+ async def purge_room(self, room_id: str) -> None:
"""Deletes all record of a room
"""
diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py
index cec96ad6a7..2564f34b47 100644
--- a/synapse/storage/relations.py
+++ b/synapse/storage/relations.py
@@ -14,10 +14,12 @@
# limitations under the License.
import logging
+from typing import Any, Dict, List, Optional, Tuple
import attr
from synapse.api.errors import SynapseError
+from synapse.types import JsonDict
logger = logging.getLogger(__name__)
@@ -27,18 +29,18 @@ class PaginationChunk:
"""Returned by relation pagination APIs.
Attributes:
- chunk (list): The rows returned by pagination
- next_batch (Any|None): Token to fetch next set of results with, if
+ chunk: The rows returned by pagination
+ next_batch: Token to fetch next set of results with, if
None then there are no more results.
- prev_batch (Any|None): Token to fetch previous set of results with, if
+ prev_batch: Token to fetch previous set of results with, if
None then there are no previous results.
"""
- chunk = attr.ib()
- next_batch = attr.ib(default=None)
- prev_batch = attr.ib(default=None)
+ chunk = attr.ib(type=List[JsonDict])
+ next_batch = attr.ib(type=Optional[Any], default=None)
+ prev_batch = attr.ib(type=Optional[Any], default=None)
- def to_dict(self):
+ def to_dict(self) -> Dict[str, Any]:
d = {"chunk": self.chunk}
if self.next_batch:
@@ -59,25 +61,25 @@ class RelationPaginationToken:
boundaries of the chunk as pagination tokens.
Attributes:
- topological (int): The topological ordering of the boundary event
- stream (int): The stream ordering of the boundary event.
+ topological: The topological ordering of the boundary event
+ stream: The stream ordering of the boundary event.
"""
- topological = attr.ib()
- stream = attr.ib()
+ topological = attr.ib(type=int)
+ stream = attr.ib(type=int)
@staticmethod
- def from_string(string):
+ def from_string(string: str) -> "RelationPaginationToken":
try:
t, s = string.split("-")
return RelationPaginationToken(int(t), int(s))
except ValueError:
raise SynapseError(400, "Invalid token")
- def to_string(self):
+ def to_string(self) -> str:
return "%d-%d" % (self.topological, self.stream)
- def as_tuple(self):
+ def as_tuple(self) -> Tuple[Any, ...]:
return attr.astuple(self)
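A round-trip of the "<topological>-<stream>" token format, as an illustrative
sketch:

tok = RelationPaginationToken.from_string("12-345")
assert (tok.topological, tok.stream) == (12, 345)
assert tok.to_string() == "12-345"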
@@ -89,23 +91,23 @@ class AggregationPaginationToken:
aggregation groups, we can just use them as our pagination token.
Attributes:
- count (int): The count of relations in the boundar group.
- stream (int): The MAX stream ordering in the boundary group.
+ count: The count of relations in the boundary group.
+ stream: The MAX stream ordering in the boundary group.
"""
- count = attr.ib()
- stream = attr.ib()
+ count = attr.ib(type=int)
+ stream = attr.ib(type=int)
@staticmethod
- def from_string(string):
+ def from_string(string: str) -> "AggregationPaginationToken":
try:
c, s = string.split("-")
return AggregationPaginationToken(int(c), int(s))
except ValueError:
raise SynapseError(400, "Invalid token")
- def to_string(self):
+ def to_string(self) -> str:
return "%d-%d" % (self.count, self.stream)
- def as_tuple(self):
+ def as_tuple(self) -> Tuple[Any, ...]:
return attr.astuple(self)
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 08a69f2f96..31ccbf23dc 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -12,9 +12,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
-from typing import Awaitable, Dict, Iterable, List, Optional, Set, Tuple, TypeVar
+from typing import (
+ TYPE_CHECKING,
+ Awaitable,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Set,
+ Tuple,
+ TypeVar,
+)
import attr
@@ -22,6 +31,10 @@ from synapse.api.constants import EventTypes
from synapse.events import EventBase
from synapse.types import MutableStateMap, StateMap
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+ from synapse.storage.databases import Databases
+
logger = logging.getLogger(__name__)
# Used for generic functions below
@@ -330,10 +343,12 @@ class StateGroupStorage:
"""High level interface to fetching state for event.
"""
- def __init__(self, hs, stores):
+ def __init__(self, hs: "HomeServer", stores: "Databases"):
self.stores = stores
- async def get_state_group_delta(self, state_group: int):
+ async def get_state_group_delta(
+ self, state_group: int
+ ) -> Tuple[Optional[int], Optional[StateMap[str]]]:
"""Given a state group try to return a previous group and a delta between
the old and the new.
@@ -341,8 +356,8 @@ class StateGroupStorage:
state_group: The state group used to retrieve state deltas.
Returns:
- Tuple[Optional[int], Optional[StateMap[str]]]:
- (prev_group, delta_ids)
+ A tuple of the previous group and a state map of the event IDs which
+ make up the delta between the old and new state groups.
"""
return await self.stores.state.get_state_group_delta(state_group)
@@ -436,7 +451,7 @@ class StateGroupStorage:
async def get_state_for_events(
self, event_ids: List[str], state_filter: StateFilter = StateFilter.all()
- ):
+ ) -> Dict[str, StateMap[EventBase]]:
"""Given a list of event_ids and type tuples, return a list of state
dicts for each event.
@@ -472,7 +487,7 @@ class StateGroupStorage:
async def get_state_ids_for_events(
self, event_ids: List[str], state_filter: StateFilter = StateFilter.all()
- ):
+ ) -> Dict[str, StateMap[str]]:
"""
Get the state dicts corresponding to a list of events, containing the event_ids
of the state events (as opposed to the events themselves)
@@ -500,7 +515,7 @@ class StateGroupStorage:
async def get_state_for_event(
self, event_id: str, state_filter: StateFilter = StateFilter.all()
- ):
+ ) -> StateMap[EventBase]:
"""
Get the state dict corresponding to a particular event
@@ -516,7 +531,7 @@ class StateGroupStorage:
async def get_state_ids_for_event(
self, event_id: str, state_filter: StateFilter = StateFilter.all()
- ):
+ ) -> StateMap[str]:
"""
Get the state dict corresponding to a particular event
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 02d71302ea..133c0e7a28 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -153,12 +153,12 @@ class StreamIdGenerator:
return _AsyncCtxManagerWrapper(manager())
- def get_current_token(self):
+ def get_current_token(self) -> int:
"""Returns the maximum stream id such that all stream ids less than or
equal to it have been successfully persisted.
Returns:
- int
+ The maximum stream id.
"""
with self._lock:
if self._unfinished_ids:
diff --git a/synapse/types.py b/synapse/types.py
index 3ab6bdbe06..c7d4e95809 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -349,15 +349,17 @@ NON_MXID_CHARACTER_PATTERN = re.compile(
)
-def map_username_to_mxid_localpart(username, case_sensitive=False):
+def map_username_to_mxid_localpart(
+ username: Union[str, bytes], case_sensitive: bool = False
+) -> str:
"""Map a username onto a string suitable for a MXID
This follows the algorithm laid out at
https://matrix.org/docs/spec/appendices.html#mapping-from-other-character-sets.
Args:
- username (unicode|bytes): username to be mapped
- case_sensitive (bool): true if TEST and test should be mapped
+ username: username to be mapped
+ case_sensitive: true if TEST and test should be mapped
onto different mxids
Returns:
diff --git a/synapse/visibility.py b/synapse/visibility.py
index 527365498e..ec50e7e977 100644
--- a/synapse/visibility.py
+++ b/synapse/visibility.py
@@ -12,11 +12,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
import operator
-from synapse.api.constants import AccountDataTypes, EventTypes, Membership
+from synapse.api.constants import (
+ AccountDataTypes,
+ EventTypes,
+ HistoryVisibility,
+ Membership,
+)
from synapse.events.utils import prune_event
from synapse.storage import Storage
from synapse.storage.state import StateFilter
@@ -25,7 +29,12 @@ from synapse.types import get_domain_from_id
logger = logging.getLogger(__name__)
-VISIBILITY_PRIORITY = ("world_readable", "shared", "invited", "joined")
+VISIBILITY_PRIORITY = (
+ HistoryVisibility.WORLD_READABLE,
+ HistoryVisibility.SHARED,
+ HistoryVisibility.INVITED,
+ HistoryVisibility.JOINED,
+)
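+ # Lower index == less restrictive: world_readable sorts before shared,
+ # invited and joined, which the boundary handling below relies on.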
MEMBERSHIP_PRIORITY = (
@@ -116,7 +125,7 @@ async def filter_events_for_client(
# see events in the room at that point in the DAG, and that shouldn't be decided
# on those checks.
if filter_send_to_client:
- if event.type == "org.matrix.dummy_event":
+ if event.type == EventTypes.Dummy:
return None
if not event.is_state() and event.sender in ignore_list:
@@ -150,12 +159,14 @@ async def filter_events_for_client(
# get the room_visibility at the time of the event.
visibility_event = state.get((EventTypes.RoomHistoryVisibility, ""), None)
if visibility_event:
- visibility = visibility_event.content.get("history_visibility", "shared")
+ visibility = visibility_event.content.get(
+ "history_visibility", HistoryVisibility.SHARED
+ )
else:
- visibility = "shared"
+ visibility = HistoryVisibility.SHARED
if visibility not in VISIBILITY_PRIORITY:
- visibility = "shared"
+ visibility = HistoryVisibility.SHARED
# Always allow history visibility events on boundaries. This is done
# by setting the effective visibility to the least restrictive
@@ -165,7 +176,7 @@ async def filter_events_for_client(
prev_visibility = prev_content.get("history_visibility", None)
if prev_visibility not in VISIBILITY_PRIORITY:
- prev_visibility = "shared"
+ prev_visibility = HistoryVisibility.SHARED
new_priority = VISIBILITY_PRIORITY.index(visibility)
old_priority = VISIBILITY_PRIORITY.index(prev_visibility)
@@ -210,17 +221,17 @@ async def filter_events_for_client(
# otherwise, it depends on the room visibility.
- if visibility == "joined":
+ if visibility == HistoryVisibility.JOINED:
# we weren't a member at the time of the event, so we can't
# see this event.
return None
- elif visibility == "invited":
+ elif visibility == HistoryVisibility.INVITED:
# user can also see the event if they were *invited* at the time
# of the event.
return event if membership == Membership.INVITE else None
- elif visibility == "shared" and is_peeking:
+ elif visibility == HistoryVisibility.SHARED and is_peeking:
# if the visibility is shared, users cannot see the event unless
# they have *subsequently* joined the room (or were members at the
# time, of course)
@@ -284,8 +295,10 @@ async def filter_events_for_server(
def check_event_is_visible(event, state):
history = state.get((EventTypes.RoomHistoryVisibility, ""), None)
if history:
- visibility = history.content.get("history_visibility", "shared")
- if visibility in ["invited", "joined"]:
+ visibility = history.content.get(
+ "history_visibility", HistoryVisibility.SHARED
+ )
+ if visibility in [HistoryVisibility.INVITED, HistoryVisibility.JOINED]:
# We now loop through all state events looking for
# membership states for the requesting server to determine
# if the server is either in the room or has been invited
@@ -305,7 +318,7 @@ async def filter_events_for_server(
if memtype == Membership.JOIN:
return True
elif memtype == Membership.INVITE:
- if visibility == "invited":
+ if visibility == HistoryVisibility.INVITED:
return True
else:
# server has no users in the room: redact
@@ -336,7 +349,8 @@ async def filter_events_for_server(
else:
event_map = await storage.main.get_events(visibility_ids)
all_open = all(
- e.content.get("history_visibility") in (None, "shared", "world_readable")
+ e.content.get("history_visibility")
+ in (None, HistoryVisibility.SHARED, HistoryVisibility.WORLD_READABLE)
for e in event_map.values()
)
|