From 0cc3bf97b4399234cf20f52ae4c09d03661225ff Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 9 Dec 2021 11:15:46 -0500 Subject: Additional type hints for the config module, part 2. (#11480) --- synapse/config/key.py | 36 +++++++++++++++++++++--------------- synapse/config/metrics.py | 6 ++++-- synapse/config/server.py | 2 +- synapse/config/tls.py | 2 +- 4 files changed, 27 insertions(+), 19 deletions(-) (limited to 'synapse/config') diff --git a/synapse/config/key.py b/synapse/config/key.py index 035ee2416b..ee83c6c06b 100644 --- a/synapse/config/key.py +++ b/synapse/config/key.py @@ -16,12 +16,14 @@ import hashlib import logging import os -from typing import Any, Dict +from typing import Any, Dict, Iterator, List, Optional import attr import jsonschema from signedjson.key import ( NACL_ED25519, + SigningKey, + VerifyKey, decode_signing_key_base64, decode_verify_key_bytes, generate_signing_key, @@ -31,6 +33,7 @@ from signedjson.key import ( ) from unpaddedbase64 import decode_base64 +from synapse.types import JsonDict from synapse.util.stringutils import random_string, random_string_with_symbols from ._base import Config, ConfigError @@ -81,14 +84,13 @@ To suppress this warning and continue using 'matrix.org', admins should set logger = logging.getLogger(__name__) -@attr.s +@attr.s(slots=True, auto_attribs=True) class TrustedKeyServer: - # string: name of the server. - server_name = attr.ib() + # name of the server. + server_name: str - # dict[str,VerifyKey]|None: map from key id to key object, or None to disable - # signature verification. - verify_keys = attr.ib(default=None) + # map from key id to key object, or None to disable signature verification. + verify_keys: Optional[Dict[str, VerifyKey]] = None class KeyConfig(Config): @@ -279,15 +281,15 @@ class KeyConfig(Config): % locals() ) - def read_signing_keys(self, signing_key_path, name): + def read_signing_keys(self, signing_key_path: str, name: str) -> List[SigningKey]: """Read the signing keys in the given path. Args: - signing_key_path (str) - name (str): Associated config key name + signing_key_path + name: Associated config key name Returns: - list[SigningKey] + The signing keys read from the given path. """ signing_keys = self.read_file(signing_key_path, name) @@ -296,7 +298,9 @@ class KeyConfig(Config): except Exception as e: raise ConfigError("Error reading %s: %s" % (name, str(e))) - def read_old_signing_keys(self, old_signing_keys): + def read_old_signing_keys( + self, old_signing_keys: Optional[JsonDict] + ) -> Dict[str, VerifyKey]: if old_signing_keys is None: return {} keys = {} @@ -340,7 +344,7 @@ class KeyConfig(Config): write_signing_keys(signing_key_file, (key,)) -def _perspectives_to_key_servers(config): +def _perspectives_to_key_servers(config: JsonDict) -> Iterator[JsonDict]: """Convert old-style 'perspectives' configs into new-style 'trusted_key_servers' Returns an iterable of entries to add to trusted_key_servers. @@ -402,7 +406,9 @@ TRUSTED_KEY_SERVERS_SCHEMA = { } -def _parse_key_servers(key_servers, federation_verify_certificates): +def _parse_key_servers( + key_servers: List[Any], federation_verify_certificates: bool +) -> Iterator[TrustedKeyServer]: try: jsonschema.validate(key_servers, TRUSTED_KEY_SERVERS_SCHEMA) except jsonschema.ValidationError as e: @@ -444,7 +450,7 @@ def _parse_key_servers(key_servers, federation_verify_certificates): yield result -def _assert_keyserver_has_verify_keys(trusted_key_server): +def _assert_keyserver_has_verify_keys(trusted_key_server: TrustedKeyServer) -> None: if not trusted_key_server.verify_keys: raise ConfigError(INSECURE_NOTARY_ERROR) diff --git a/synapse/config/metrics.py b/synapse/config/metrics.py index 7ac82edb0e..1cc26e7578 100644 --- a/synapse/config/metrics.py +++ b/synapse/config/metrics.py @@ -22,10 +22,12 @@ from ._base import Config, ConfigError @attr.s class MetricsFlags: - known_servers = attr.ib(default=False, validator=attr.validators.instance_of(bool)) + known_servers: bool = attr.ib( + default=False, validator=attr.validators.instance_of(bool) + ) @classmethod - def all_off(cls): + def all_off(cls) -> "MetricsFlags": """ Instantiate the flags with all options set to off. """ diff --git a/synapse/config/server.py b/synapse/config/server.py index ba5b954263..1de2dea9b0 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -1257,7 +1257,7 @@ class ServerConfig(Config): help="Turn on the twisted telnet manhole service on the given port.", ) - def read_gc_intervals(self, durations) -> Optional[Tuple[float, float, float]]: + def read_gc_intervals(self, durations: Any) -> Optional[Tuple[float, float, float]]: """Reads the three durations for the GC min interval option, returning seconds.""" if durations is None: return None diff --git a/synapse/config/tls.py b/synapse/config/tls.py index 4ca111618f..ffb316e4c0 100644 --- a/synapse/config/tls.py +++ b/synapse/config/tls.py @@ -132,7 +132,7 @@ class TlsConfig(Config): self.tls_certificate: Optional[crypto.X509] = None self.tls_private_key: Optional[crypto.PKey] = None - def read_certificate_from_disk(self): + def read_certificate_from_disk(self) -> None: """ Read the certificates and private key from disk. """ -- cgit 1.5.1 From 2519beaad25d2c824129c26eab10962773e643a6 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 14 Dec 2021 12:02:46 -0500 Subject: Add missing type hints to `synapse.appservice` (#11360) --- changelog.d/11360.misc | 1 + mypy.ini | 3 ++ synapse/appservice/__init__.py | 101 ++++++++++++++++++++++-------------- synapse/appservice/api.py | 48 ++++++++++++----- synapse/appservice/scheduler.py | 74 +++++++++++++++----------- synapse/config/appservice.py | 3 +- tests/appservice/test_appservice.py | 11 ++-- 7 files changed, 148 insertions(+), 93 deletions(-) create mode 100644 changelog.d/11360.misc (limited to 'synapse/config') diff --git a/changelog.d/11360.misc b/changelog.d/11360.misc new file mode 100644 index 0000000000..43e25720c5 --- /dev/null +++ b/changelog.d/11360.misc @@ -0,0 +1 @@ +Add type hints to `synapse.appservice`. diff --git a/mypy.ini b/mypy.ini index 9aeeca2bb2..4551302c82 100644 --- a/mypy.ini +++ b/mypy.ini @@ -143,6 +143,9 @@ disallow_untyped_defs = True [mypy-synapse.app.*] disallow_untyped_defs = True +[mypy-synapse.appservice.*] +disallow_untyped_defs = True + [mypy-synapse.config._base] disallow_untyped_defs = True diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index f9d3bd337d..8c9ff93b2c 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -11,10 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import logging import re from enum import Enum -from typing import TYPE_CHECKING, Iterable, List, Match, Optional +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Pattern + +import attr +from netaddr import IPSet from synapse.api.constants import EventTypes from synapse.events import EventBase @@ -33,6 +37,13 @@ class ApplicationServiceState(Enum): UP = "up" +@attr.s(slots=True, frozen=True, auto_attribs=True) +class Namespace: + exclusive: bool + group_id: Optional[str] + regex: Pattern[str] + + class ApplicationService: """Defines an application service. This definition is mostly what is provided to the /register AS API. @@ -50,17 +61,17 @@ class ApplicationService: def __init__( self, - token, - hostname, - id, - sender, - url=None, - namespaces=None, - hs_token=None, - protocols=None, - rate_limited=True, - ip_range_whitelist=None, - supports_ephemeral=False, + token: str, + hostname: str, + id: str, + sender: str, + url: Optional[str] = None, + namespaces: Optional[JsonDict] = None, + hs_token: Optional[str] = None, + protocols: Optional[Iterable[str]] = None, + rate_limited: bool = True, + ip_range_whitelist: Optional[IPSet] = None, + supports_ephemeral: bool = False, ): self.token = token self.url = ( @@ -85,27 +96,33 @@ class ApplicationService: self.rate_limited = rate_limited - def _check_namespaces(self, namespaces): + def _check_namespaces( + self, namespaces: Optional[JsonDict] + ) -> Dict[str, List[Namespace]]: # Sanity check that it is of the form: # { # users: [ {regex: "[A-z]+.*", exclusive: true}, ...], # aliases: [ {regex: "[A-z]+.*", exclusive: true}, ...], # rooms: [ {regex: "[A-z]+.*", exclusive: true}, ...], # } - if not namespaces: + if namespaces is None: namespaces = {} + result: Dict[str, List[Namespace]] = {} + for ns in ApplicationService.NS_LIST: + result[ns] = [] + if ns not in namespaces: - namespaces[ns] = [] continue - if type(namespaces[ns]) != list: + if not isinstance(namespaces[ns], list): raise ValueError("Bad namespace value for '%s'" % ns) for regex_obj in namespaces[ns]: if not isinstance(regex_obj, dict): raise ValueError("Expected dict regex for ns '%s'" % ns) - if not isinstance(regex_obj.get("exclusive"), bool): + exclusive = regex_obj.get("exclusive") + if not isinstance(exclusive, bool): raise ValueError("Expected bool for 'exclusive' in ns '%s'" % ns) group_id = regex_obj.get("group_id") if group_id: @@ -126,22 +143,26 @@ class ApplicationService: ) regex = regex_obj.get("regex") - if isinstance(regex, str): - regex_obj["regex"] = re.compile(regex) # Pre-compile regex - else: + if not isinstance(regex, str): raise ValueError("Expected string for 'regex' in ns '%s'" % ns) - return namespaces - def _matches_regex(self, test_string: str, namespace_key: str) -> Optional[Match]: - for regex_obj in self.namespaces[namespace_key]: - if regex_obj["regex"].match(test_string): - return regex_obj + # Pre-compile regex. + result[ns].append(Namespace(exclusive, group_id, re.compile(regex))) + + return result + + def _matches_regex( + self, namespace_key: str, test_string: str + ) -> Optional[Namespace]: + for namespace in self.namespaces[namespace_key]: + if namespace.regex.match(test_string): + return namespace return None - def _is_exclusive(self, ns_key: str, test_string: str) -> bool: - regex_obj = self._matches_regex(test_string, ns_key) - if regex_obj: - return regex_obj["exclusive"] + def _is_exclusive(self, namespace_key: str, test_string: str) -> bool: + namespace = self._matches_regex(namespace_key, test_string) + if namespace: + return namespace.exclusive return False async def _matches_user( @@ -260,15 +281,15 @@ class ApplicationService: def is_interested_in_user(self, user_id: str) -> bool: return ( - bool(self._matches_regex(user_id, ApplicationService.NS_USERS)) + bool(self._matches_regex(ApplicationService.NS_USERS, user_id)) or user_id == self.sender ) def is_interested_in_alias(self, alias: str) -> bool: - return bool(self._matches_regex(alias, ApplicationService.NS_ALIASES)) + return bool(self._matches_regex(ApplicationService.NS_ALIASES, alias)) def is_interested_in_room(self, room_id: str) -> bool: - return bool(self._matches_regex(room_id, ApplicationService.NS_ROOMS)) + return bool(self._matches_regex(ApplicationService.NS_ROOMS, room_id)) def is_exclusive_user(self, user_id: str) -> bool: return ( @@ -285,14 +306,14 @@ class ApplicationService: def is_exclusive_room(self, room_id: str) -> bool: return self._is_exclusive(ApplicationService.NS_ROOMS, room_id) - def get_exclusive_user_regexes(self): + def get_exclusive_user_regexes(self) -> List[Pattern[str]]: """Get the list of regexes used to determine if a user is exclusively registered by the AS """ return [ - regex_obj["regex"] - for regex_obj in self.namespaces[ApplicationService.NS_USERS] - if regex_obj["exclusive"] + namespace.regex + for namespace in self.namespaces[ApplicationService.NS_USERS] + if namespace.exclusive ] def get_groups_for_user(self, user_id: str) -> Iterable[str]: @@ -305,15 +326,15 @@ class ApplicationService: An iterable that yields group_id strings. """ return ( - regex_obj["group_id"] - for regex_obj in self.namespaces[ApplicationService.NS_USERS] - if "group_id" in regex_obj and regex_obj["regex"].match(user_id) + namespace.group_id + for namespace in self.namespaces[ApplicationService.NS_USERS] + if namespace.group_id and namespace.regex.match(user_id) ) def is_rate_limited(self) -> bool: return self.rate_limited - def __str__(self): + def __str__(self) -> str: # copy dictionary and redact token fields so they don't get logged dict_copy = self.__dict__.copy() dict_copy["token"] = "" diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index f51b636417..def4424af0 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -import urllib -from typing import TYPE_CHECKING, List, Optional, Tuple +import urllib.parse +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple from prometheus_client import Counter @@ -53,7 +53,7 @@ HOUR_IN_MS = 60 * 60 * 1000 APP_SERVICE_PREFIX = "/_matrix/app/unstable" -def _is_valid_3pe_metadata(info): +def _is_valid_3pe_metadata(info: JsonDict) -> bool: if "instances" not in info: return False if not isinstance(info["instances"], list): @@ -61,7 +61,7 @@ def _is_valid_3pe_metadata(info): return True -def _is_valid_3pe_result(r, field): +def _is_valid_3pe_result(r: JsonDict, field: str) -> bool: if not isinstance(r, dict): return False @@ -93,9 +93,13 @@ class ApplicationServiceApi(SimpleHttpClient): hs.get_clock(), "as_protocol_meta", timeout_ms=HOUR_IN_MS ) - async def query_user(self, service, user_id): + async def query_user(self, service: "ApplicationService", user_id: str) -> bool: if service.url is None: return False + + # This is required by the configuration. + assert service.hs_token is not None + uri = service.url + ("/users/%s" % urllib.parse.quote(user_id)) try: response = await self.get_json(uri, {"access_token": service.hs_token}) @@ -109,9 +113,13 @@ class ApplicationServiceApi(SimpleHttpClient): logger.warning("query_user to %s threw exception %s", uri, ex) return False - async def query_alias(self, service, alias): + async def query_alias(self, service: "ApplicationService", alias: str) -> bool: if service.url is None: return False + + # This is required by the configuration. + assert service.hs_token is not None + uri = service.url + ("/rooms/%s" % urllib.parse.quote(alias)) try: response = await self.get_json(uri, {"access_token": service.hs_token}) @@ -125,7 +133,13 @@ class ApplicationServiceApi(SimpleHttpClient): logger.warning("query_alias to %s threw exception %s", uri, ex) return False - async def query_3pe(self, service, kind, protocol, fields): + async def query_3pe( + self, + service: "ApplicationService", + kind: str, + protocol: str, + fields: Dict[bytes, List[bytes]], + ) -> List[JsonDict]: if kind == ThirdPartyEntityKind.USER: required_field = "userid" elif kind == ThirdPartyEntityKind.LOCATION: @@ -205,11 +219,14 @@ class ApplicationServiceApi(SimpleHttpClient): events: List[EventBase], ephemeral: List[JsonDict], txn_id: Optional[int] = None, - ): + ) -> bool: if service.url is None: return True - events = self._serialize(service, events) + # This is required by the configuration. + assert service.hs_token is not None + + serialized_events = self._serialize(service, events) if txn_id is None: logger.warning( @@ -221,9 +238,12 @@ class ApplicationServiceApi(SimpleHttpClient): # Never send ephemeral events to appservices that do not support it if service.supports_ephemeral: - body = {"events": events, "de.sorunome.msc2409.ephemeral": ephemeral} + body = { + "events": serialized_events, + "de.sorunome.msc2409.ephemeral": ephemeral, + } else: - body = {"events": events} + body = {"events": serialized_events} try: await self.put_json( @@ -238,7 +258,7 @@ class ApplicationServiceApi(SimpleHttpClient): [event.get("event_id") for event in events], ) sent_transactions_counter.labels(service.id).inc() - sent_events_counter.labels(service.id).inc(len(events)) + sent_events_counter.labels(service.id).inc(len(serialized_events)) return True except CodeMessageException as e: logger.warning( @@ -260,7 +280,9 @@ class ApplicationServiceApi(SimpleHttpClient): failed_transactions_counter.labels(service.id).inc() return False - def _serialize(self, service, events): + def _serialize( + self, service: "ApplicationService", events: Iterable[EventBase] + ) -> List[JsonDict]: time_now = self.clock.time_msec() return [ serialize_event( diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py index 6a2ce99b55..185e3a5278 100644 --- a/synapse/appservice/scheduler.py +++ b/synapse/appservice/scheduler.py @@ -48,13 +48,19 @@ This is all tied together by the AppServiceScheduler which DIs the required components. """ import logging -from typing import List, Optional +from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional, Set from synapse.appservice import ApplicationService, ApplicationServiceState +from synapse.appservice.api import ApplicationServiceApi from synapse.events import EventBase from synapse.logging.context import run_in_background from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.storage.databases.main import DataStore from synapse.types import JsonDict +from synapse.util import Clock + +if TYPE_CHECKING: + from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -72,7 +78,7 @@ class ApplicationServiceScheduler: case is a simple array. """ - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.clock = hs.get_clock() self.store = hs.get_datastore() self.as_api = hs.get_application_service_api() @@ -80,7 +86,7 @@ class ApplicationServiceScheduler: self.txn_ctrl = _TransactionController(self.clock, self.store, self.as_api) self.queuer = _ServiceQueuer(self.txn_ctrl, self.clock) - async def start(self): + async def start(self) -> None: logger.info("Starting appservice scheduler") # check for any DOWN ASes and start recoverers for them. @@ -91,12 +97,14 @@ class ApplicationServiceScheduler: for service in services: self.txn_ctrl.start_recoverer(service) - def submit_event_for_as(self, service: ApplicationService, event: EventBase): + def submit_event_for_as( + self, service: ApplicationService, event: EventBase + ) -> None: self.queuer.enqueue_event(service, event) def submit_ephemeral_events_for_as( self, service: ApplicationService, events: List[JsonDict] - ): + ) -> None: self.queuer.enqueue_ephemeral(service, events) @@ -108,16 +116,18 @@ class _ServiceQueuer: appservice at a given time. """ - def __init__(self, txn_ctrl, clock): - self.queued_events = {} # dict of {service_id: [events]} - self.queued_ephemeral = {} # dict of {service_id: [events]} + def __init__(self, txn_ctrl: "_TransactionController", clock: Clock): + # dict of {service_id: [events]} + self.queued_events: Dict[str, List[EventBase]] = {} + # dict of {service_id: [events]} + self.queued_ephemeral: Dict[str, List[JsonDict]] = {} # the appservices which currently have a transaction in flight - self.requests_in_flight = set() + self.requests_in_flight: Set[str] = set() self.txn_ctrl = txn_ctrl self.clock = clock - def _start_background_request(self, service): + def _start_background_request(self, service: ApplicationService) -> None: # start a sender for this appservice if we don't already have one if service.id in self.requests_in_flight: return @@ -126,15 +136,17 @@ class _ServiceQueuer: "as-sender-%s" % (service.id,), self._send_request, service ) - def enqueue_event(self, service: ApplicationService, event: EventBase): + def enqueue_event(self, service: ApplicationService, event: EventBase) -> None: self.queued_events.setdefault(service.id, []).append(event) self._start_background_request(service) - def enqueue_ephemeral(self, service: ApplicationService, events: List[JsonDict]): + def enqueue_ephemeral( + self, service: ApplicationService, events: List[JsonDict] + ) -> None: self.queued_ephemeral.setdefault(service.id, []).extend(events) self._start_background_request(service) - async def _send_request(self, service: ApplicationService): + async def _send_request(self, service: ApplicationService) -> None: # sanity-check: we shouldn't get here if this service already has a sender # running. assert service.id not in self.requests_in_flight @@ -168,20 +180,15 @@ class _TransactionController: if a transaction fails. (Note we have only have one of these in the homeserver.) - - Args: - clock (synapse.util.Clock): - store (synapse.storage.DataStore): - as_api (synapse.appservice.api.ApplicationServiceApi): """ - def __init__(self, clock, store, as_api): + def __init__(self, clock: Clock, store: DataStore, as_api: ApplicationServiceApi): self.clock = clock self.store = store self.as_api = as_api # map from service id to recoverer instance - self.recoverers = {} + self.recoverers: Dict[str, "_Recoverer"] = {} # for UTs self.RECOVERER_CLASS = _Recoverer @@ -191,7 +198,7 @@ class _TransactionController: service: ApplicationService, events: List[EventBase], ephemeral: Optional[List[JsonDict]] = None, - ): + ) -> None: try: txn = await self.store.create_appservice_txn( service=service, events=events, ephemeral=ephemeral or [] @@ -207,7 +214,7 @@ class _TransactionController: logger.exception("Error creating appservice transaction") run_in_background(self._on_txn_fail, service) - async def on_recovered(self, recoverer): + async def on_recovered(self, recoverer: "_Recoverer") -> None: logger.info( "Successfully recovered application service AS ID %s", recoverer.service.id ) @@ -217,18 +224,18 @@ class _TransactionController: recoverer.service, ApplicationServiceState.UP ) - async def _on_txn_fail(self, service): + async def _on_txn_fail(self, service: ApplicationService) -> None: try: await self.store.set_appservice_state(service, ApplicationServiceState.DOWN) self.start_recoverer(service) except Exception: logger.exception("Error starting AS recoverer") - def start_recoverer(self, service): + def start_recoverer(self, service: ApplicationService) -> None: """Start a Recoverer for the given service Args: - service (synapse.appservice.ApplicationService): + service: """ logger.info("Starting recoverer for AS ID %s", service.id) assert service.id not in self.recoverers @@ -257,7 +264,14 @@ class _Recoverer: callback (callable[_Recoverer]): called once the service recovers. """ - def __init__(self, clock, store, as_api, service, callback): + def __init__( + self, + clock: Clock, + store: DataStore, + as_api: ApplicationServiceApi, + service: ApplicationService, + callback: Callable[["_Recoverer"], Awaitable[None]], + ): self.clock = clock self.store = store self.as_api = as_api @@ -265,8 +279,8 @@ class _Recoverer: self.callback = callback self.backoff_counter = 1 - def recover(self): - def _retry(): + def recover(self) -> None: + def _retry() -> None: run_as_background_process( "as-recoverer-%s" % (self.service.id,), self.retry ) @@ -275,13 +289,13 @@ class _Recoverer: logger.info("Scheduling retries on %s in %fs", self.service.id, delay) self.clock.call_later(delay, _retry) - def _backoff(self): + def _backoff(self) -> None: # cap the backoff to be around 8.5min => (2^9) = 512 secs if self.backoff_counter < 9: self.backoff_counter += 1 self.recover() - async def retry(self): + async def retry(self) -> None: logger.info("Starting retries on %s", self.service.id) try: while True: diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py index e4bb7224a4..7fad2e0422 100644 --- a/synapse/config/appservice.py +++ b/synapse/config/appservice.py @@ -147,8 +147,7 @@ def _load_appservice( # protocols check protocols = as_info.get("protocols") if protocols: - # Because strings are lists in python - if isinstance(protocols, str) or not isinstance(protocols, list): + if not isinstance(protocols, list): raise KeyError("Optional 'protocols' must be a list if present.") for p in protocols: if not isinstance(p, str): diff --git a/tests/appservice/test_appservice.py b/tests/appservice/test_appservice.py index f386b5e128..ba2a2bfd64 100644 --- a/tests/appservice/test_appservice.py +++ b/tests/appservice/test_appservice.py @@ -16,13 +16,13 @@ from unittest.mock import Mock from twisted.internet import defer -from synapse.appservice import ApplicationService +from synapse.appservice import ApplicationService, Namespace from tests import unittest -def _regex(regex, exclusive=True): - return {"regex": re.compile(regex), "exclusive": exclusive} +def _regex(regex: str, exclusive: bool = True) -> Namespace: + return Namespace(exclusive, None, re.compile(regex)) class ApplicationServiceTestCase(unittest.TestCase): @@ -33,11 +33,6 @@ class ApplicationServiceTestCase(unittest.TestCase): url="some_url", token="some_token", hostname="matrix.org", # only used by get_groups_for_user - namespaces={ - ApplicationService.NS_USERS: [], - ApplicationService.NS_ROOMS: [], - ApplicationService.NS_ALIASES: [], - }, ) self.event = Mock( type="m.something", room_id="!foo:bar", sender="@someone:somewhere" -- cgit 1.5.1 From 17886d2603112531d4eda459d312f84d0d677652 Mon Sep 17 00:00:00 2001 From: reivilibre Date: Wed, 15 Dec 2021 10:40:52 +0000 Subject: Add experimental support for MSC3202: allowing application services to masquerade as specific devices. (#11538) --- changelog.d/11538.feature | 1 + synapse/api/auth.py | 86 ++++++++++++++++++++++++++----- synapse/config/experimental.py | 5 ++ synapse/storage/databases/main/devices.py | 20 +++++++ tests/api/test_auth.py | 64 +++++++++++++++++++++++ 5 files changed, 162 insertions(+), 14 deletions(-) create mode 100644 changelog.d/11538.feature (limited to 'synapse/config') diff --git a/changelog.d/11538.feature b/changelog.d/11538.feature new file mode 100644 index 0000000000..b6229e2b45 --- /dev/null +++ b/changelog.d/11538.feature @@ -0,0 +1 @@ +Add experimental support for MSC3202: allowing application services to masquerade as specific devices. \ No newline at end of file diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 44883c6663..0bf58dff08 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -155,7 +155,11 @@ class Auth: access_token = self.get_access_token_from_request(request) - user_id, app_service = await self._get_appservice_user_id(request) + ( + user_id, + device_id, + app_service, + ) = await self._get_appservice_user_id_and_device_id(request) if user_id and app_service: if ip_addr and self._track_appservice_user_ips: await self.store.insert_client_ip( @@ -163,16 +167,22 @@ class Auth: access_token=access_token, ip=ip_addr, user_agent=user_agent, - device_id="dummy-device", # stubbed + device_id="dummy-device" + if device_id is None + else device_id, # stubbed ) - requester = create_requester(user_id, app_service=app_service) + requester = create_requester( + user_id, app_service=app_service, device_id=device_id + ) request.requester = user_id if user_id in self._force_tracing_for_users: opentracing.force_tracing() opentracing.set_tag("authenticated_entity", user_id) opentracing.set_tag("user_id", user_id) + if device_id is not None: + opentracing.set_tag("device_id", device_id) opentracing.set_tag("appservice_id", app_service.id) return requester @@ -274,33 +284,81 @@ class Auth: 403, "Application service has not registered this user (%s)" % user_id ) - async def _get_appservice_user_id( + async def _get_appservice_user_id_and_device_id( self, request: Request - ) -> Tuple[Optional[str], Optional[ApplicationService]]: + ) -> Tuple[Optional[str], Optional[str], Optional[ApplicationService]]: + """ + Given a request, reads the request parameters to determine: + - whether it's an application service that's making this request + - what user the application service should be treated as controlling + (the user_id URI parameter allows an application service to masquerade + any applicable user in its namespace) + - what device the application service should be treated as controlling + (the device_id[^1] URI parameter allows an application service to masquerade + as any device that exists for the relevant user) + + [^1] Unstable and provided by MSC3202. + Must use `org.matrix.msc3202.device_id` in place of `device_id` for now. + + Returns: + 3-tuple of + (user ID?, device ID?, application service?) + + Postconditions: + - If an application service is returned, so is a user ID + - A user ID is never returned without an application service + - A device ID is never returned without a user ID or an application service + - The returned application service, if present, is permitted to control the + returned user ID. + - The returned device ID, if present, has been checked to be a valid device ID + for the returned user ID. + """ + DEVICE_ID_ARG_NAME = b"org.matrix.msc3202.device_id" + app_service = self.store.get_app_service_by_token( self.get_access_token_from_request(request) ) if app_service is None: - return None, None + return None, None, None if app_service.ip_range_whitelist: ip_address = IPAddress(request.getClientIP()) if ip_address not in app_service.ip_range_whitelist: - return None, None + return None, None, None # This will always be set by the time Twisted calls us. assert request.args is not None - if b"user_id" not in request.args: - return app_service.sender, app_service + if b"user_id" in request.args: + effective_user_id = request.args[b"user_id"][0].decode("utf8") + await self.validate_appservice_can_control_user_id( + app_service, effective_user_id + ) + else: + effective_user_id = app_service.sender - user_id = request.args[b"user_id"][0].decode("utf8") - await self.validate_appservice_can_control_user_id(app_service, user_id) + effective_device_id: Optional[str] = None - if app_service.sender == user_id: - return app_service.sender, app_service + if ( + self.hs.config.experimental.msc3202_device_masquerading_enabled + and DEVICE_ID_ARG_NAME in request.args + ): + effective_device_id = request.args[DEVICE_ID_ARG_NAME][0].decode("utf8") + # We only just set this so it can't be None! + assert effective_device_id is not None + device_opt = await self.store.get_device( + effective_user_id, effective_device_id + ) + if device_opt is None: + # For now, use 400 M_EXCLUSIVE if the device doesn't exist. + # This is an open thread of discussion on MSC3202 as of 2021-12-09. + raise AuthError( + 400, + f"Application service trying to use a device that doesn't exist ('{effective_device_id}' for {effective_user_id})", + Codes.EXCLUSIVE, + ) - return user_id, app_service + return effective_user_id, effective_device_id, app_service async def get_user_by_access_token( self, diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index d78a15097c..678c78d565 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -49,3 +49,8 @@ class ExperimentalConfig(Config): # MSC3030 (Jump to date API endpoint) self.msc3030_enabled: bool = experimental.get("msc3030_enabled", False) + + # The portion of MSC3202 which is related to device masquerading. + self.msc3202_device_masquerading_enabled: bool = experimental.get( + "msc3202_device_masquerading", False + ) diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 3932599988..273adb61fd 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -128,6 +128,26 @@ class DeviceWorkerStore(SQLBaseStore): allow_none=True, ) + async def get_device_opt( + self, user_id: str, device_id: str + ) -> Optional[Dict[str, Any]]: + """Retrieve a device. Only returns devices that are not marked as + hidden. + + Args: + user_id: The ID of the user which owns the device + device_id: The ID of the device to retrieve + Returns: + A dict containing the device information, or None if the device does not exist. + """ + return await self.db_pool.simple_select_one( + table="devices", + keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, + retcols=("user_id", "device_id", "display_name"), + desc="get_device", + allow_none=True, + ) + async def get_devices_by_user(self, user_id: str) -> Dict[str, Dict[str, str]]: """Retrieve all of a user's registered devices. Only returns devices that are not marked as hidden. diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index 3aa9ba3c43..a2dfa1ed05 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -31,6 +31,7 @@ from synapse.types import Requester from tests import unittest from tests.test_utils import simple_async_mock +from tests.unittest import override_config from tests.utils import mock_getRawHeaders @@ -210,6 +211,69 @@ class AuthTestCase(unittest.HomeserverTestCase): request.requestHeaders.getRawHeaders = mock_getRawHeaders() self.get_failure(self.auth.get_user_by_req(request), AuthError) + @override_config({"experimental_features": {"msc3202_device_masquerading": True}}) + def test_get_user_by_req_appservice_valid_token_valid_device_id(self): + """ + Tests that when an application service passes the device_id URL parameter + with the ID of a valid device for the user in question, + the requester instance tracks that device ID. + """ + masquerading_user_id = b"@doppelganger:matrix.org" + masquerading_device_id = b"DOPPELDEVICE" + app_service = Mock( + token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None + ) + app_service.is_interested_in_user = Mock(return_value=True) + self.store.get_app_service_by_token = Mock(return_value=app_service) + # This just needs to return a truth-y value. + self.store.get_user_by_id = simple_async_mock({"is_guest": False}) + self.store.get_user_by_access_token = simple_async_mock(None) + # This also needs to just return a truth-y value + self.store.get_device = simple_async_mock({"hidden": False}) + + request = Mock(args={}) + request.getClientIP.return_value = "127.0.0.1" + request.args[b"access_token"] = [self.test_token] + request.args[b"user_id"] = [masquerading_user_id] + request.args[b"org.matrix.msc3202.device_id"] = [masquerading_device_id] + request.requestHeaders.getRawHeaders = mock_getRawHeaders() + requester = self.get_success(self.auth.get_user_by_req(request)) + self.assertEquals( + requester.user.to_string(), masquerading_user_id.decode("utf8") + ) + self.assertEquals(requester.device_id, masquerading_device_id.decode("utf8")) + + @override_config({"experimental_features": {"msc3202_device_masquerading": True}}) + def test_get_user_by_req_appservice_valid_token_invalid_device_id(self): + """ + Tests that when an application service passes the device_id URL parameter + with an ID that is not a valid device ID for the user in question, + the request fails with the appropriate error code. + """ + masquerading_user_id = b"@doppelganger:matrix.org" + masquerading_device_id = b"NOT_A_REAL_DEVICE_ID" + app_service = Mock( + token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None + ) + app_service.is_interested_in_user = Mock(return_value=True) + self.store.get_app_service_by_token = Mock(return_value=app_service) + # This just needs to return a truth-y value. + self.store.get_user_by_id = simple_async_mock({"is_guest": False}) + self.store.get_user_by_access_token = simple_async_mock(None) + # This also needs to just return a falsey value + self.store.get_device = simple_async_mock(None) + + request = Mock(args={}) + request.getClientIP.return_value = "127.0.0.1" + request.args[b"access_token"] = [self.test_token] + request.args[b"user_id"] = [masquerading_user_id] + request.args[b"org.matrix.msc3202.device_id"] = [masquerading_device_id] + request.requestHeaders.getRawHeaders = mock_getRawHeaders() + + failure = self.get_failure(self.auth.get_user_by_req(request), AuthError) + self.assertEquals(failure.value.code, 400) + self.assertEquals(failure.value.errcode, Codes.EXCLUSIVE) + def test_get_user_from_macaroon(self): self.store.get_user_by_access_token = simple_async_mock( TokenLookupResult(user_id="@baldrick:matrix.org", device_id="device") -- cgit 1.5.1 From 43f5cc7adc02a05ba4075b8aab3b479bda67f441 Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Thu, 16 Dec 2021 11:25:37 -0600 Subject: Add MSC2716 and MSC3030 to `/versions` -> `unstable_features` (#11582) As suggested in https://github.com/matrix-org/matrix-react-sdk/pull/7372#discussion_r769523369 --- changelog.d/11582.misc | 1 + synapse/config/experimental.py | 2 +- synapse/rest/client/versions.py | 4 ++++ 3 files changed, 6 insertions(+), 1 deletion(-) create mode 100644 changelog.d/11582.misc (limited to 'synapse/config') diff --git a/changelog.d/11582.misc b/changelog.d/11582.misc new file mode 100644 index 0000000000..a0291f64e2 --- /dev/null +++ b/changelog.d/11582.misc @@ -0,0 +1 @@ +Add [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716) and [MSC3030](https://github.com/matrix-org/matrix-doc/pull/3030) to `/versions` -> `unstable_features` to detect server support. diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 678c78d565..dbaeb10918 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -32,7 +32,7 @@ class ExperimentalConfig(Config): # MSC3026 (busy presence state) self.msc3026_enabled: bool = experimental.get("msc3026_enabled", False) - # MSC2716 (backfill existing history) + # MSC2716 (importing historical messages) self.msc2716_enabled: bool = experimental.get("msc2716_enabled", False) # MSC2285 (hidden read receipts) diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index 8d888f4565..2290c57c12 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -93,6 +93,10 @@ class VersionsRestServlet(RestServlet): "org.matrix.msc3026.busy_presence": self.config.experimental.msc3026_enabled, # Supports receiving hidden read receipts as per MSC2285 "org.matrix.msc2285": self.config.experimental.msc2285_enabled, + # Adds support for importing historical messages as per MSC2716 + "org.matrix.msc2716": self.config.experimental.msc2716_enabled, + # Adds support for jump to date endpoints (/timestamp_to_event) as per MSC3030 + "org.matrix.msc3030": self.config.experimental.msc3030_enabled, }, }, ) -- cgit 1.5.1 From cbd82d0b2db069400b5d43373838817d8a0209e7 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 30 Dec 2021 13:47:12 -0500 Subject: Convert all namedtuples to attrs. (#11665) To improve type hints throughout the code. --- changelog.d/11665.misc | 1 + synapse/api/filtering.py | 3 +- synapse/config/repository.py | 34 +++---- synapse/federation/federation_base.py | 5 - synapse/federation/send_queue.py | 47 +++++----- synapse/handlers/appservice.py | 4 +- synapse/handlers/directory.py | 10 +- synapse/handlers/room_list.py | 22 ++--- synapse/handlers/typing.py | 14 ++- synapse/http/server.py | 10 +- synapse/replication/tcp/streams/_base.py | 129 +++++++++++++------------- synapse/replication/tcp/streams/federation.py | 15 ++- synapse/rest/media/v1/media_repository.py | 19 ++-- synapse/state/__init__.py | 5 +- synapse/storage/databases/main/directory.py | 10 +- synapse/storage/databases/main/events.py | 13 ++- synapse/storage/databases/main/room.py | 26 ++++-- synapse/storage/databases/main/search.py | 16 +++- synapse/storage/databases/main/state.py | 14 --- synapse/storage/databases/main/stream.py | 12 ++- synapse/types.py | 22 ++--- tests/replication/test_federation_ack.py | 6 +- 22 files changed, 231 insertions(+), 206 deletions(-) create mode 100644 changelog.d/11665.misc (limited to 'synapse/config') diff --git a/changelog.d/11665.misc b/changelog.d/11665.misc new file mode 100644 index 0000000000..e7cc8ff23f --- /dev/null +++ b/changelog.d/11665.misc @@ -0,0 +1 @@ +Convert `namedtuples` to `attrs`. diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index 13dd6ce248..d087c816db 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -351,8 +351,7 @@ class Filter: True if the event matches the filter. """ # We usually get the full "events" as dictionaries coming through, - # except for presence which actually gets passed around as its own - # namedtuple type. + # except for presence which actually gets passed around as its own type. if isinstance(event, UserPresenceState): user_id = event.user_id field_matchers = { diff --git a/synapse/config/repository.py b/synapse/config/repository.py index b129b9dd68..1980351e77 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -14,10 +14,11 @@ import logging import os -from collections import namedtuple from typing import Dict, List, Tuple from urllib.request import getproxies_environment # type: ignore +import attr + from synapse.config.server import DEFAULT_IP_RANGE_BLACKLIST, generate_ip_set from synapse.python_dependencies import DependencyException, check_requirements from synapse.types import JsonDict @@ -44,18 +45,20 @@ THUMBNAIL_SIZE_YAML = """\ HTTP_PROXY_SET_WARNING = """\ The Synapse config url_preview_ip_range_blacklist will be ignored as an HTTP(s) proxy is configured.""" -ThumbnailRequirement = namedtuple( - "ThumbnailRequirement", ["width", "height", "method", "media_type"] -) -MediaStorageProviderConfig = namedtuple( - "MediaStorageProviderConfig", - ( - "store_local", # Whether to store newly uploaded local files - "store_remote", # Whether to store newly downloaded remote files - "store_synchronous", # Whether to wait for successful storage for local uploads - ), -) +@attr.s(frozen=True, slots=True, auto_attribs=True) +class ThumbnailRequirement: + width: int + height: int + method: str + media_type: str + + +@attr.s(frozen=True, slots=True, auto_attribs=True) +class MediaStorageProviderConfig: + store_local: bool # Whether to store newly uploaded local files + store_remote: bool # Whether to store newly downloaded remote files + store_synchronous: bool # Whether to wait for successful storage for local uploads def parse_thumbnail_requirements( @@ -66,11 +69,10 @@ def parse_thumbnail_requirements( method, and thumbnail media type to precalculate Args: - thumbnail_sizes(list): List of dicts with "width", "height", and - "method" keys + thumbnail_sizes: List of dicts with "width", "height", and "method" keys + Returns: - Dictionary mapping from media type string to list of - ThumbnailRequirement tuples. + Dictionary mapping from media type string to list of ThumbnailRequirement. """ requirements: Dict[str, List[ThumbnailRequirement]] = {} for size in thumbnail_sizes: diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index f56344a3b9..4df90e02d7 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from collections import namedtuple from typing import TYPE_CHECKING from synapse.api.constants import MAX_DEPTH, EventContentFields, EventTypes, Membership @@ -104,10 +103,6 @@ class FederationBase: return pdu -class PduToCheckSig(namedtuple("PduToCheckSig", ["pdu", "sender_domain", "deferreds"])): - pass - - async def _check_sigs_on_pdu( keyring: Keyring, room_version: RoomVersion, pdu: EventBase ) -> None: diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py index 63289a5a33..0d7c4f5067 100644 --- a/synapse/federation/send_queue.py +++ b/synapse/federation/send_queue.py @@ -30,7 +30,6 @@ Events are replicated via a separate events stream. """ import logging -from collections import namedtuple from typing import ( TYPE_CHECKING, Dict, @@ -43,6 +42,7 @@ from typing import ( Type, ) +import attr from sortedcontainers import SortedDict from synapse.api.presence import UserPresenceState @@ -382,13 +382,11 @@ class BaseFederationRow: raise NotImplementedError() -class PresenceDestinationsRow( - BaseFederationRow, - namedtuple( - "PresenceDestinationsRow", - ("state", "destinations"), # UserPresenceState # list[str] - ), -): +@attr.s(slots=True, frozen=True, auto_attribs=True) +class PresenceDestinationsRow(BaseFederationRow): + state: UserPresenceState + destinations: List[str] + TypeId = "pd" @staticmethod @@ -404,17 +402,15 @@ class PresenceDestinationsRow( buff.presence_destinations.append((self.state, self.destinations)) -class KeyedEduRow( - BaseFederationRow, - namedtuple( - "KeyedEduRow", - ("key", "edu"), # tuple(str) - the edu key passed to send_edu # Edu - ), -): +@attr.s(slots=True, frozen=True, auto_attribs=True) +class KeyedEduRow(BaseFederationRow): """Streams EDUs that have an associated key that is ued to clobber. For example, typing EDUs clobber based on room_id. """ + key: Tuple[str, ...] # the edu key passed to send_edu + edu: Edu + TypeId = "k" @staticmethod @@ -428,9 +424,12 @@ class KeyedEduRow( buff.keyed_edus.setdefault(self.edu.destination, {})[self.key] = self.edu -class EduRow(BaseFederationRow, namedtuple("EduRow", ("edu",))): # Edu +@attr.s(slots=True, frozen=True, auto_attribs=True) +class EduRow(BaseFederationRow): """Streams EDUs that don't have keys. See KeyedEduRow""" + edu: Edu + TypeId = "e" @staticmethod @@ -453,14 +452,14 @@ _rowtypes: Tuple[Type[BaseFederationRow], ...] = ( TypeToRow = {Row.TypeId: Row for Row in _rowtypes} -ParsedFederationStreamData = namedtuple( - "ParsedFederationStreamData", - ( - "presence_destinations", # list of tuples of UserPresenceState and destinations - "keyed_edus", # dict of destination -> { key -> Edu } - "edus", # dict of destination -> [Edu] - ), -) +@attr.s(slots=True, frozen=True, auto_attribs=True) +class ParsedFederationStreamData: + # list of tuples of UserPresenceState and destinations + presence_destinations: List[Tuple[UserPresenceState, List[str]]] + # dict of destination -> { key -> Edu } + keyed_edus: Dict[str, Dict[Tuple[str, ...], Edu]] + # dict of destination -> [Edu] + edus: Dict[str, List[Edu]] def process_rows_for_federation( diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 9abdad262b..7833e77e2b 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -462,9 +462,9 @@ class ApplicationServicesHandler: Args: room_alias: The room alias to query. + Returns: - namedtuple: with keys "room_id" and "servers" or None if no - association can be found. + RoomAliasMapping or None if no association can be found. """ room_alias_str = room_alias.to_string() services = self.store.get_app_services() diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 7ee5c47fd9..082f521791 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -278,13 +278,15 @@ class DirectoryHandler: users = await self.store.get_users_in_room(room_id) extra_servers = {get_domain_from_id(u) for u in users} - servers = set(extra_servers) | set(servers) + servers_set = set(extra_servers) | set(servers) # If this server is in the list of servers, return it first. - if self.server_name in servers: - servers = [self.server_name] + [s for s in servers if s != self.server_name] + if self.server_name in servers_set: + servers = [self.server_name] + [ + s for s in servers_set if s != self.server_name + ] else: - servers = list(servers) + servers = list(servers_set) return {"room_id": room_id, "servers": servers} diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py index ba7a14d651..1a33211a1f 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py @@ -13,9 +13,9 @@ # limitations under the License. import logging -from collections import namedtuple from typing import TYPE_CHECKING, Any, Optional, Tuple +import attr import msgpack from unpaddedbase64 import decode_base64, encode_base64 @@ -474,16 +474,12 @@ class RoomListHandler: ) -class RoomListNextBatch( - namedtuple( - "RoomListNextBatch", - ( - "last_joined_members", # The count to get rooms after/before - "last_room_id", # The room_id to get rooms after/before - "direction_is_forward", # Bool if this is a next_batch, false if prev_batch - ), - ) -): +@attr.s(slots=True, frozen=True, auto_attribs=True) +class RoomListNextBatch: + last_joined_members: int # The count to get rooms after/before + last_room_id: str # The room_id to get rooms after/before + direction_is_forward: bool # True if this is a next_batch, false if prev_batch + KEY_DICT = { "last_joined_members": "m", "last_room_id": "r", @@ -502,12 +498,12 @@ class RoomListNextBatch( def to_token(self) -> str: return encode_base64( msgpack.dumps( - {self.KEY_DICT[key]: val for key, val in self._asdict().items()} + {self.KEY_DICT[key]: val for key, val in attr.asdict(self).items()} ) ) def copy_and_replace(self, **kwds: Any) -> "RoomListNextBatch": - return self._replace(**kwds) + return attr.evolve(self, **kwds) def _matches_room_entry(room_entry: JsonDict, search_filter: dict) -> bool: diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 1676ebd057..e43c22832d 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -13,9 +13,10 @@ # limitations under the License. import logging import random -from collections import namedtuple from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple +import attr + from synapse.api.errors import AuthError, ShadowBanError, SynapseError from synapse.appservice import ApplicationService from synapse.metrics.background_process_metrics import ( @@ -37,7 +38,10 @@ logger = logging.getLogger(__name__) # A tiny object useful for storing a user's membership in a room, as a mapping # key -RoomMember = namedtuple("RoomMember", ("room_id", "user_id")) +@attr.s(slots=True, frozen=True, auto_attribs=True) +class RoomMember: + room_id: str + user_id: str # How often we expect remote servers to resend us presence. @@ -119,7 +123,7 @@ class FollowerTypingHandler: self.wheel_timer.insert(now=now, obj=member, then=now + 60 * 1000) def is_typing(self, member: RoomMember) -> bool: - return member.user_id in self._room_typing.get(member.room_id, []) + return member.user_id in self._room_typing.get(member.room_id, set()) async def _push_remote(self, member: RoomMember, typing: bool) -> None: if not self.federation: @@ -166,9 +170,9 @@ class FollowerTypingHandler: for row in rows: self._room_serials[row.room_id] = token - prev_typing = set(self._room_typing.get(row.room_id, [])) + prev_typing = self._room_typing.get(row.room_id, set()) now_typing = set(row.user_ids) - self._room_typing[row.room_id] = row.user_ids + self._room_typing[row.room_id] = now_typing if self.federation: run_as_background_process( diff --git a/synapse/http/server.py b/synapse/http/server.py index e302946591..09b4125489 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -14,7 +14,6 @@ # limitations under the License. import abc -import collections import html import logging import types @@ -37,6 +36,7 @@ from typing import ( Union, ) +import attr import jinja2 from canonicaljson import encode_canonical_json from typing_extensions import Protocol @@ -354,9 +354,11 @@ class DirectServeJsonResource(_AsyncResource): return_json_error(f, request) -_PathEntry = collections.namedtuple( - "_PathEntry", ["pattern", "callback", "servlet_classname"] -) +@attr.s(slots=True, frozen=True, auto_attribs=True) +class _PathEntry: + pattern: Pattern + callback: ServletCallback + servlet_classname: str class JsonResource(DirectServeJsonResource): diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 743a01da08..5a2d90c530 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -15,7 +15,6 @@ import heapq import logging -from collections import namedtuple from typing import ( TYPE_CHECKING, Any, @@ -30,6 +29,7 @@ from typing import ( import attr from synapse.replication.http.streams import ReplicationGetStreamUpdates +from synapse.types import JsonDict if TYPE_CHECKING: from synapse.server import HomeServer @@ -226,17 +226,14 @@ class BackfillStream(Stream): or it went from being an outlier to not. """ - BackfillStreamRow = namedtuple( - "BackfillStreamRow", - ( - "event_id", # str - "room_id", # str - "type", # str - "state_key", # str, optional - "redacts", # str, optional - "relates_to", # str, optional - ), - ) + @attr.s(slots=True, frozen=True, auto_attribs=True) + class BackfillStreamRow: + event_id: str + room_id: str + type: str + state_key: Optional[str] + redacts: Optional[str] + relates_to: Optional[str] NAME = "backfill" ROW_TYPE = BackfillStreamRow @@ -256,18 +253,15 @@ class BackfillStream(Stream): class PresenceStream(Stream): - PresenceStreamRow = namedtuple( - "PresenceStreamRow", - ( - "user_id", # str - "state", # str - "last_active_ts", # int - "last_federation_update_ts", # int - "last_user_sync_ts", # int - "status_msg", # str - "currently_active", # bool - ), - ) + @attr.s(slots=True, frozen=True, auto_attribs=True) + class PresenceStreamRow: + user_id: str + state: str + last_active_ts: int + last_federation_update_ts: int + last_user_sync_ts: int + status_msg: str + currently_active: bool NAME = "presence" ROW_TYPE = PresenceStreamRow @@ -302,7 +296,7 @@ class PresenceFederationStream(Stream): send. """ - @attr.s(slots=True, auto_attribs=True) + @attr.s(slots=True, frozen=True, auto_attribs=True) class PresenceFederationStreamRow: destination: str user_id: str @@ -320,9 +314,10 @@ class PresenceFederationStream(Stream): class TypingStream(Stream): - TypingStreamRow = namedtuple( - "TypingStreamRow", ("room_id", "user_ids") # str # list(str) - ) + @attr.s(slots=True, frozen=True, auto_attribs=True) + class TypingStreamRow: + room_id: str + user_ids: List[str] NAME = "typing" ROW_TYPE = TypingStreamRow @@ -348,16 +343,13 @@ class TypingStream(Stream): class ReceiptsStream(Stream): - ReceiptsStreamRow = namedtuple( - "ReceiptsStreamRow", - ( - "room_id", # str - "receipt_type", # str - "user_id", # str - "event_id", # str - "data", # dict - ), - ) + @attr.s(slots=True, frozen=True, auto_attribs=True) + class ReceiptsStreamRow: + room_id: str + receipt_type: str + user_id: str + event_id: str + data: dict NAME = "receipts" ROW_TYPE = ReceiptsStreamRow @@ -374,7 +366,9 @@ class ReceiptsStream(Stream): class PushRulesStream(Stream): """A user has changed their push rules""" - PushRulesStreamRow = namedtuple("PushRulesStreamRow", ("user_id",)) # str + @attr.s(slots=True, frozen=True, auto_attribs=True) + class PushRulesStreamRow: + user_id: str NAME = "push_rules" ROW_TYPE = PushRulesStreamRow @@ -396,10 +390,12 @@ class PushRulesStream(Stream): class PushersStream(Stream): """A user has added/changed/removed a pusher""" - PushersStreamRow = namedtuple( - "PushersStreamRow", - ("user_id", "app_id", "pushkey", "deleted"), # str # str # str # bool - ) + @attr.s(slots=True, frozen=True, auto_attribs=True) + class PushersStreamRow: + user_id: str + app_id: str + pushkey: str + deleted: bool NAME = "pushers" ROW_TYPE = PushersStreamRow @@ -419,7 +415,7 @@ class CachesStream(Stream): the cache on the workers """ - @attr.s(slots=True) + @attr.s(slots=True, frozen=True, auto_attribs=True) class CachesStreamRow: """Stream to inform workers they should invalidate their cache. @@ -430,9 +426,9 @@ class CachesStream(Stream): invalidation_ts: Timestamp of when the invalidation took place. """ - cache_func = attr.ib(type=str) - keys = attr.ib(type=Optional[List[Any]]) - invalidation_ts = attr.ib(type=int) + cache_func: str + keys: Optional[List[Any]] + invalidation_ts: int NAME = "caches" ROW_TYPE = CachesStreamRow @@ -451,9 +447,9 @@ class DeviceListsStream(Stream): told about a device update. """ - @attr.s(slots=True) + @attr.s(slots=True, frozen=True, auto_attribs=True) class DeviceListsStreamRow: - entity = attr.ib(type=str) + entity: str NAME = "device_lists" ROW_TYPE = DeviceListsStreamRow @@ -470,7 +466,9 @@ class DeviceListsStream(Stream): class ToDeviceStream(Stream): """New to_device messages for a client""" - ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", ("entity",)) # str + @attr.s(slots=True, frozen=True, auto_attribs=True) + class ToDeviceStreamRow: + entity: str NAME = "to_device" ROW_TYPE = ToDeviceStreamRow @@ -487,9 +485,11 @@ class ToDeviceStream(Stream): class TagAccountDataStream(Stream): """Someone added/removed a tag for a room""" - TagAccountDataStreamRow = namedtuple( - "TagAccountDataStreamRow", ("user_id", "room_id", "data") # str # str # dict - ) + @attr.s(slots=True, frozen=True, auto_attribs=True) + class TagAccountDataStreamRow: + user_id: str + room_id: str + data: JsonDict NAME = "tag_account_data" ROW_TYPE = TagAccountDataStreamRow @@ -506,10 +506,11 @@ class TagAccountDataStream(Stream): class AccountDataStream(Stream): """Global or per room account data was changed""" - AccountDataStreamRow = namedtuple( - "AccountDataStreamRow", - ("user_id", "room_id", "data_type"), # str # Optional[str] # str - ) + @attr.s(slots=True, frozen=True, auto_attribs=True) + class AccountDataStreamRow: + user_id: str + room_id: Optional[str] + data_type: str NAME = "account_data" ROW_TYPE = AccountDataStreamRow @@ -573,10 +574,12 @@ class AccountDataStream(Stream): class GroupServerStream(Stream): - GroupsStreamRow = namedtuple( - "GroupsStreamRow", - ("group_id", "user_id", "type", "content"), # str # str # str # dict - ) + @attr.s(slots=True, frozen=True, auto_attribs=True) + class GroupsStreamRow: + group_id: str + user_id: str + type: str + content: JsonDict NAME = "groups" ROW_TYPE = GroupsStreamRow @@ -593,7 +596,9 @@ class GroupServerStream(Stream): class UserSignatureStream(Stream): """A user has signed their own device with their user-signing key""" - UserSignatureStreamRow = namedtuple("UserSignatureStreamRow", ("user_id")) # str + @attr.s(slots=True, frozen=True, auto_attribs=True) + class UserSignatureStreamRow: + user_id: str NAME = "user_signature" ROW_TYPE = UserSignatureStreamRow diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py index 0600cdbf36..4046bdec69 100644 --- a/synapse/replication/tcp/streams/federation.py +++ b/synapse/replication/tcp/streams/federation.py @@ -12,14 +12,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections import namedtuple from typing import TYPE_CHECKING, Any, Awaitable, Callable, List, Tuple +import attr + from synapse.replication.tcp.streams._base import ( Stream, current_token_without_instance, make_http_update_function, ) +from synapse.types import JsonDict if TYPE_CHECKING: from synapse.server import HomeServer @@ -30,13 +32,10 @@ class FederationStream(Stream): sending disabled. """ - FederationStreamRow = namedtuple( - "FederationStreamRow", - ( - "type", # str, the type of data as defined in the BaseFederationRows - "data", # dict, serialization of a federation.send_queue.BaseFederationRow - ), - ) + @attr.s(slots=True, frozen=True, auto_attribs=True) + class FederationStreamRow: + type: str # the type of data as defined in the BaseFederationRows + data: JsonDict # serialization of a federation.send_queue.BaseFederationRow NAME = "federation" ROW_TYPE = FederationStreamRow diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 244ba261bb..71b9a34b14 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -739,14 +739,21 @@ class MediaRepository: # We deduplicate the thumbnail sizes by ignoring the cropped versions if # they have the same dimensions of a scaled one. thumbnails: Dict[Tuple[int, int, str], str] = {} - for r_width, r_height, r_method, r_type in requirements: - if r_method == "crop": - thumbnails.setdefault((r_width, r_height, r_type), r_method) - elif r_method == "scale": - t_width, t_height = thumbnailer.aspect(r_width, r_height) + for requirement in requirements: + if requirement.method == "crop": + thumbnails.setdefault( + (requirement.width, requirement.height, requirement.media_type), + requirement.method, + ) + elif requirement.method == "scale": + t_width, t_height = thumbnailer.aspect( + requirement.width, requirement.height + ) t_width = min(m_width, t_width) t_height = min(m_height, t_height) - thumbnails[(t_width, t_height, r_type)] = r_method + thumbnails[ + (t_width, t_height, requirement.media_type) + ] = requirement.method # Now we generate the thumbnails for each dimension, store it for (t_width, t_height, t_type), t_method in thumbnails.items(): diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 446204dbe5..69ac8c3423 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -14,7 +14,7 @@ # limitations under the License. import heapq import logging -from collections import defaultdict, namedtuple +from collections import defaultdict from typing import ( TYPE_CHECKING, Any, @@ -69,9 +69,6 @@ state_groups_histogram = Histogram( ) -KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key")) - - EVICTION_TIMEOUT_SECONDS = 60 * 60 diff --git a/synapse/storage/databases/main/directory.py b/synapse/storage/databases/main/directory.py index a3442814d7..f76c6121e8 100644 --- a/synapse/storage/databases/main/directory.py +++ b/synapse/storage/databases/main/directory.py @@ -12,16 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. -from collections import namedtuple from typing import Iterable, List, Optional, Tuple +import attr + from synapse.api.errors import SynapseError from synapse.storage.database import LoggingTransaction from synapse.storage.databases.main import CacheInvalidationWorkerStore from synapse.types import RoomAlias from synapse.util.caches.descriptors import cached -RoomAliasMapping = namedtuple("RoomAliasMapping", ("room_id", "room_alias", "servers")) + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class RoomAliasMapping: + room_id: str + room_alias: str + servers: List[str] class DirectoryWorkerStore(CacheInvalidationWorkerStore): diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 81e67ece55..dd255aefb9 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1976,14 +1976,17 @@ class PersistEventsStore: txn, self.store.get_retention_policy_for_room, (event.room_id,) ) - def store_event_search_txn(self, txn, event, key, value): + def store_event_search_txn( + self, txn: LoggingTransaction, event: EventBase, key: str, value: str + ) -> None: """Add event to the search table Args: - txn (cursor): - event (EventBase): - key (str): - value (str): + txn: The database transaction. + event: The event being added to the search table. + key: A key describing the search value (one of "content.name", + "content.topic", or "content.body") + value: The value from the event's content. """ self.store.store_search_entries_txn( txn, diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 4472335af9..c0e837854a 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -13,11 +13,22 @@ # See the License for the specific language governing permissions and # limitations under the License. -import collections import logging from abc import abstractmethod from enum import Enum -from typing import TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Tuple, cast +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Dict, + List, + Optional, + Tuple, + Union, + cast, +) + +import attr from synapse.api.constants import EventContentFields, EventTypes, JoinRules from synapse.api.errors import StoreError @@ -43,9 +54,10 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -RatelimitOverride = collections.namedtuple( - "RatelimitOverride", ("messages_per_second", "burst_count") -) +@attr.s(slots=True, frozen=True, auto_attribs=True) +class RatelimitOverride: + messages_per_second: int + burst_count: int class RoomSortOrder(Enum): @@ -207,6 +219,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): WHERE appservice_id = ? AND network_id = ? """ query_args.append(network_tuple.appservice_id) + assert network_tuple.network_id is not None query_args.append(network_tuple.network_id) else: published_sql = """ @@ -284,7 +297,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): """ where_clauses = [] - query_args = [] + query_args: List[Union[str, int]] = [] if network_tuple: if network_tuple.appservice_id: @@ -293,6 +306,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): WHERE appservice_id = ? AND network_id = ? """ query_args.append(network_tuple.appservice_id) + assert network_tuple.network_id is not None query_args.append(network_tuple.network_id) else: published_sql = """ diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py index f87acfb866..2d085a5764 100644 --- a/synapse/storage/databases/main/search.py +++ b/synapse/storage/databases/main/search.py @@ -14,9 +14,10 @@ import logging import re -from collections import namedtuple from typing import TYPE_CHECKING, Collection, Iterable, List, Optional, Set +import attr + from synapse.api.errors import SynapseError from synapse.events import EventBase from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause @@ -33,10 +34,15 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -SearchEntry = namedtuple( - "SearchEntry", - ["key", "value", "event_id", "room_id", "stream_ordering", "origin_server_ts"], -) + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class SearchEntry: + key: str + value: str + event_id: str + room_id: str + stream_ordering: Optional[int] + origin_server_ts: int def _clean_value_for_search(value: str) -> str: diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index 4bc044fb16..7e5a6aae18 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -14,7 +14,6 @@ # limitations under the License. import collections.abc import logging -from collections import namedtuple from typing import TYPE_CHECKING, Iterable, Optional, Set from synapse.api.constants import EventTypes, Membership @@ -43,19 +42,6 @@ logger = logging.getLogger(__name__) MAX_STATE_DELTA_HOPS = 100 -class _GetStateGroupDelta( - namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids")) -): - """Return type of get_state_group_delta that implements __len__, which lets - us use the itrable flag when caching - """ - - __slots__ = [] - - def __len__(self): - return len(self.delta_ids) if self.delta_ids else 0 - - # this inherits from EventsWorkerStore because it calls self.get_events class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): """The parts of StateGroupStore that can be called from workers.""" diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 9488fd5094..b0642ca69f 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -36,9 +36,9 @@ what sort order was used: """ import abc import logging -from collections import namedtuple from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Set, Tuple +import attr from frozendict import frozendict from twisted.internet import defer @@ -74,9 +74,11 @@ _TOPOLOGICAL_TOKEN = "topological" # Used as return values for pagination APIs -_EventDictReturn = namedtuple( - "_EventDictReturn", ("event_id", "topological_ordering", "stream_ordering") -) +@attr.s(slots=True, frozen=True, auto_attribs=True) +class _EventDictReturn: + event_id: str + topological_ordering: Optional[int] + stream_ordering: int def generate_pagination_where_clause( @@ -825,7 +827,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): for event, row in zip(events, rows): stream = row.stream_ordering if topo_order and row.topological_ordering: - topo = row.topological_ordering + topo: Optional[int] = row.topological_ordering else: topo = None internal = event.internal_metadata diff --git a/synapse/types.py b/synapse/types.py index b06979e8e8..42aeaf6270 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -15,7 +15,6 @@ import abc import re import string -from collections import namedtuple from typing import ( TYPE_CHECKING, Any, @@ -227,8 +226,7 @@ class DomainSpecificString(metaclass=abc.ABCMeta): localpart = attr.ib(type=str) domain = attr.ib(type=str) - # Because this class is a namedtuple of strings and booleans, it is deeply - # immutable. + # Because this is a frozen class, it is deeply immutable. def __copy__(self): return self @@ -708,16 +706,18 @@ class PersistedEventPosition: return RoomStreamToken(None, self.stream) -class ThirdPartyInstanceID( - namedtuple("ThirdPartyInstanceID", ("appservice_id", "network_id")) -): +@attr.s(slots=True, frozen=True, auto_attribs=True) +class ThirdPartyInstanceID: + appservice_id: Optional[str] + network_id: Optional[str] + # Deny iteration because it will bite you if you try to create a singleton # set by: # users = set(user) def __iter__(self): raise ValueError("Attempted to iterate a %s" % (type(self).__name__,)) - # Because this class is a namedtuple of strings, it is deeply immutable. + # Because this class is a frozen class, it is deeply immutable. def __copy__(self): return self @@ -725,22 +725,18 @@ class ThirdPartyInstanceID( return self @classmethod - def from_string(cls, s): + def from_string(cls, s: str) -> "ThirdPartyInstanceID": bits = s.split("|", 2) if len(bits) != 2: raise SynapseError(400, "Invalid ID %r" % (s,)) return cls(appservice_id=bits[0], network_id=bits[1]) - def to_string(self): + def to_string(self) -> str: return "%s|%s" % (self.appservice_id, self.network_id) __str__ = to_string - @classmethod - def create(cls, appservice_id, network_id): - return cls(appservice_id=appservice_id, network_id=network_id) - @attr.s(slots=True) class ReadReceipt: diff --git a/tests/replication/test_federation_ack.py b/tests/replication/test_federation_ack.py index 04a869e295..1b6a4bf4b0 100644 --- a/tests/replication/test_federation_ack.py +++ b/tests/replication/test_federation_ack.py @@ -62,7 +62,11 @@ class FederationAckTestCase(HomeserverTestCase): "federation", "master", token=10, - rows=[FederationStream.FederationStreamRow(type="x", data=[1, 2, 3])], + rows=[ + FederationStream.FederationStreamRow( + type="x", data={"test": [1, 2, 3]} + ) + ], ) ) -- cgit 1.5.1 From 8422a7f7f6154bcbadc52f0d0d27b8e6b989cb4c Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 4 Jan 2022 11:08:08 -0500 Subject: Include the topic event in the prejoin state, per MSC3173. (#11666) Invites and knocks will now include the topic in the stripped state send to clients before joining the room. --- changelog.d/11666.feature | 1 + docs/sample_config.yaml | 1 + synapse/config/api.py | 2 ++ tests/federation/transport/test_knocking.py | 9 +++++++++ 4 files changed, 13 insertions(+) create mode 100644 changelog.d/11666.feature (limited to 'synapse/config') diff --git a/changelog.d/11666.feature b/changelog.d/11666.feature new file mode 100644 index 0000000000..6f6b127e22 --- /dev/null +++ b/changelog.d/11666.feature @@ -0,0 +1 @@ +Include the room topic in the stripped state included with invites and knocking. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 6696ed5d1e..00dfd2c013 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1488,6 +1488,7 @@ room_prejoin_state: # - m.room.encryption # - m.room.name # - m.room.create + # - m.room.topic # # Uncomment the following to disable these defaults (so that only the event # types listed in 'additional_event_types' are shared). Defaults to 'false'. diff --git a/synapse/config/api.py b/synapse/config/api.py index b18044f982..25538b82d5 100644 --- a/synapse/config/api.py +++ b/synapse/config/api.py @@ -107,6 +107,8 @@ _DEFAULT_PREJOIN_STATE_TYPES = [ EventTypes.Name, # Per MSC1772. EventTypes.Create, + # Per MSC3173. + EventTypes.Topic, ] diff --git a/tests/federation/transport/test_knocking.py b/tests/federation/transport/test_knocking.py index 663960ff53..bfa156eebb 100644 --- a/tests/federation/transport/test_knocking.py +++ b/tests/federation/transport/test_knocking.py @@ -108,6 +108,15 @@ class KnockingStrippedStateEventHelperMixin(TestCase): "state_key": "", }, ), + ( + EventTypes.Topic, + { + "content": { + "topic": "A really cool room", + }, + "state_key": "", + }, + ), ] ) -- cgit 1.5.1 From 84bfe47b01402cf037417f0026ed4077223ed259 Mon Sep 17 00:00:00 2001 From: reivilibre Date: Wed, 5 Jan 2022 11:41:49 +0000 Subject: Re-apply: Move glob_to_regex and re_word_boundary to matrix-python-common #11505 (#11687) Co-authored-by: Sean Quah --- changelog.d/11505.misc | 1 + changelog.d/11687.misc | 1 + synapse/config/room_directory.py | 3 +- synapse/config/tls.py | 3 +- synapse/federation/federation_server.py | 3 +- synapse/push/push_rule_evaluator.py | 7 ++-- synapse/python_dependencies.py | 1 + synapse/util/__init__.py | 59 +-------------------------------- tests/util/test_glob_to_regex.py | 59 --------------------------------- 9 files changed, 14 insertions(+), 123 deletions(-) create mode 100644 changelog.d/11505.misc create mode 100644 changelog.d/11687.misc delete mode 100644 tests/util/test_glob_to_regex.py (limited to 'synapse/config') diff --git a/changelog.d/11505.misc b/changelog.d/11505.misc new file mode 100644 index 0000000000..926b562fad --- /dev/null +++ b/changelog.d/11505.misc @@ -0,0 +1 @@ +Move `glob_to_regex` and `re_word_boundary` to `matrix-python-common`. diff --git a/changelog.d/11687.misc b/changelog.d/11687.misc new file mode 100644 index 0000000000..926b562fad --- /dev/null +++ b/changelog.d/11687.misc @@ -0,0 +1 @@ +Move `glob_to_regex` and `re_word_boundary` to `matrix-python-common`. diff --git a/synapse/config/room_directory.py b/synapse/config/room_directory.py index 57316c59b6..3c5e0f7ce7 100644 --- a/synapse/config/room_directory.py +++ b/synapse/config/room_directory.py @@ -15,8 +15,9 @@ from typing import List +from matrix_common.regex import glob_to_regex + from synapse.types import JsonDict -from synapse.util import glob_to_regex from ._base import Config, ConfigError diff --git a/synapse/config/tls.py b/synapse/config/tls.py index ffb316e4c0..6e673d65a7 100644 --- a/synapse/config/tls.py +++ b/synapse/config/tls.py @@ -16,11 +16,12 @@ import logging import os from typing import List, Optional, Pattern +from matrix_common.regex import glob_to_regex + from OpenSSL import SSL, crypto from twisted.internet._sslverify import Certificate, trustRootFromCertificates from synapse.config._base import Config, ConfigError -from synapse.util import glob_to_regex logger = logging.getLogger(__name__) diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index cf067b56c6..ee71f289c8 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -28,6 +28,7 @@ from typing import ( Union, ) +from matrix_common.regex import glob_to_regex from prometheus_client import Counter, Gauge, Histogram from twisted.internet.abstract import isIPAddress @@ -65,7 +66,7 @@ from synapse.replication.http.federation import ( ) from synapse.storage.databases.main.lock import Lock from synapse.types import JsonDict, get_domain_from_id -from synapse.util import glob_to_regex, json_decoder, unwrapFirstError +from synapse.util import json_decoder, unwrapFirstError from synapse.util.async_helpers import Linearizer, concurrently_execute, gather_results from synapse.util.caches.response_cache import ResponseCache from synapse.util.stringutils import parse_server_name diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py index 7f68092ec5..659a53805d 100644 --- a/synapse/push/push_rule_evaluator.py +++ b/synapse/push/push_rule_evaluator.py @@ -17,9 +17,10 @@ import logging import re from typing import Any, Dict, List, Optional, Pattern, Tuple, Union +from matrix_common.regex import glob_to_regex, to_word_pattern + from synapse.events import EventBase from synapse.types import JsonDict, UserID -from synapse.util import glob_to_regex, re_word_boundary from synapse.util.caches.lrucache import LruCache logger = logging.getLogger(__name__) @@ -184,7 +185,7 @@ class PushRuleEvaluatorForEvent: r = regex_cache.get((display_name, False, True), None) if not r: r1 = re.escape(display_name) - r1 = re_word_boundary(r1) + r1 = to_word_pattern(r1) r = re.compile(r1, flags=re.IGNORECASE) regex_cache[(display_name, False, True)] = r @@ -213,7 +214,7 @@ def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool: try: r = regex_cache.get((glob, True, word_boundary), None) if not r: - r = glob_to_regex(glob, word_boundary) + r = glob_to_regex(glob, word_boundary=word_boundary) regex_cache[(glob, True, word_boundary)] = r return bool(r.search(value)) except re.error: diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index 13fb69460e..d844fbb3b3 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -88,6 +88,7 @@ REQUIREMENTS = [ # with the latest security patches. "cryptography>=3.4.7", "ijson>=3.1", + "matrix-common==1.0.0", ] CONDITIONAL_REQUIREMENTS = { diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index 95f23e27b6..f157132210 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -14,9 +14,8 @@ import json import logging -import re import typing -from typing import Any, Callable, Dict, Generator, Optional, Pattern +from typing import Any, Callable, Dict, Generator, Optional import attr from frozendict import frozendict @@ -35,9 +34,6 @@ if typing.TYPE_CHECKING: logger = logging.getLogger(__name__) -_WILDCARD_RUN = re.compile(r"([\?\*]+)") - - def _reject_invalid_json(val: Any) -> None: """Do not allow Infinity, -Infinity, or NaN values in JSON.""" raise ValueError("Invalid JSON value: '%s'" % val) @@ -185,56 +181,3 @@ def log_failure( if not consumeErrors: return failure return None - - -def glob_to_regex(glob: str, word_boundary: bool = False) -> Pattern: - """Converts a glob to a compiled regex object. - - Args: - glob: pattern to match - word_boundary: If True, the pattern will be allowed to match at word boundaries - anywhere in the string. Otherwise, the pattern is anchored at the start and - end of the string. - - Returns: - compiled regex pattern - """ - - # Patterns with wildcards must be simplified to avoid performance cliffs - # - The glob `?**?**?` is equivalent to the glob `???*` - # - The glob `???*` is equivalent to the regex `.{3,}` - chunks = [] - for chunk in _WILDCARD_RUN.split(glob): - # No wildcards? re.escape() - if not _WILDCARD_RUN.match(chunk): - chunks.append(re.escape(chunk)) - continue - - # Wildcards? Simplify. - qmarks = chunk.count("?") - if "*" in chunk: - chunks.append(".{%d,}" % qmarks) - else: - chunks.append(".{%d}" % qmarks) - - res = "".join(chunks) - - if word_boundary: - res = re_word_boundary(res) - else: - # \A anchors at start of string, \Z at end of string - res = r"\A" + res + r"\Z" - - return re.compile(res, re.IGNORECASE) - - -def re_word_boundary(r: str) -> str: - """ - Adds word boundary characters to the start and end of an - expression to require that the match occur as a whole word, - but do so respecting the fact that strings starting or ending - with non-word characters will change word boundaries. - """ - # we can't use \b as it chokes on unicode. however \W seems to be okay - # as shorthand for [^0-9A-Za-z_]. - return r"(^|\W)%s(\W|$)" % (r,) diff --git a/tests/util/test_glob_to_regex.py b/tests/util/test_glob_to_regex.py deleted file mode 100644 index 220accb92b..0000000000 --- a/tests/util/test_glob_to_regex.py +++ /dev/null @@ -1,59 +0,0 @@ -# Copyright 2021 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from synapse.util import glob_to_regex - -from tests.unittest import TestCase - - -class GlobToRegexTestCase(TestCase): - def test_literal_match(self): - """patterns without wildcards should match""" - pat = glob_to_regex("foobaz") - self.assertTrue( - pat.match("FoobaZ"), "patterns should match and be case-insensitive" - ) - self.assertFalse( - pat.match("x foobaz"), "pattern should not match at word boundaries" - ) - - def test_wildcard_match(self): - pat = glob_to_regex("f?o*baz") - - self.assertTrue( - pat.match("FoobarbaZ"), - "* should match string and pattern should be case-insensitive", - ) - self.assertTrue(pat.match("foobaz"), "* should match 0 characters") - self.assertFalse(pat.match("fooxaz"), "the character after * must match") - self.assertFalse(pat.match("fobbaz"), "? should not match 0 characters") - self.assertFalse(pat.match("fiiobaz"), "? should not match 2 characters") - - def test_multi_wildcard(self): - """patterns with multiple wildcards in a row should match""" - pat = glob_to_regex("**baz") - self.assertTrue(pat.match("agsgsbaz"), "** should match any string") - self.assertTrue(pat.match("baz"), "** should match the empty string") - self.assertEqual(pat.pattern, r"\A.{0,}baz\Z") - - pat = glob_to_regex("*?baz") - self.assertTrue(pat.match("agsgsbaz"), "*? should match any string") - self.assertTrue(pat.match("abaz"), "*? should match a single char") - self.assertFalse(pat.match("baz"), "*? should not match the empty string") - self.assertEqual(pat.pattern, r"\A.{1,}baz\Z") - - pat = glob_to_regex("a?*?*?baz") - self.assertTrue(pat.match("a g baz"), "?*?*? should match 3 chars") - self.assertFalse(pat.match("a..baz"), "?*?*? should not match 2 chars") - self.assertTrue(pat.match("a.gg.baz"), "?*?*? should match 4 chars") - self.assertEqual(pat.pattern, r"\Aa.{3,}baz\Z") -- cgit 1.5.1 From eedb4527f1524fc3b83cd4838774b04d6f1e3911 Mon Sep 17 00:00:00 2001 From: Philipp Matthias Schäfer Date: Wed, 5 Jan 2022 13:16:52 +0100 Subject: Fix link from generated configuration file to documentation (#11678) Co-authored-by: reivilibre Co-authored-by: reivilibre --- changelog.d/11678.doc | 1 + docs/sample_config.yaml | 2 +- synapse/config/modules.py | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) create mode 100644 changelog.d/11678.doc (limited to 'synapse/config') diff --git a/changelog.d/11678.doc b/changelog.d/11678.doc new file mode 100644 index 0000000000..dff663e782 --- /dev/null +++ b/changelog.d/11678.doc @@ -0,0 +1 @@ +Fix the documentation link in newly-generated configuration files. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 00dfd2c013..810a14b077 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -37,7 +37,7 @@ # Server admins can expand Synapse's functionality with external modules. # -# See https://matrix-org.github.io/synapse/latest/modules.html for more +# See https://matrix-org.github.io/synapse/latest/modules/index.html for more # documentation on how to configure or create custom modules for Synapse. # modules: diff --git a/synapse/config/modules.py b/synapse/config/modules.py index ae0821e5a5..85fb05890d 100644 --- a/synapse/config/modules.py +++ b/synapse/config/modules.py @@ -37,7 +37,7 @@ class ModulesConfig(Config): # Server admins can expand Synapse's functionality with external modules. # - # See https://matrix-org.github.io/synapse/latest/modules.html for more + # See https://matrix-org.github.io/synapse/latest/modules/index.html for more # documentation on how to configure or create custom modules for Synapse. # modules: -- cgit 1.5.1