diff options
author | Erik Johnston <erik@matrix.org> | 2020-08-11 22:03:14 +0100 |
---|---|---|
committer | Erik Johnston <erik@matrix.org> | 2020-08-11 22:03:14 +0100 |
commit | fdb46b5442c63212a52e2296491a23f1935f9929 (patch) | |
tree | c46d0940415f96e3c0383c35f1d39b7462d5c20c /synapse | |
parent | Add comment explaining cast (diff) | |
parent | Auto set logging filter (#8051) (diff) | |
download | synapse-fdb46b5442c63212a52e2296491a23f1935f9929.tar.xz |
Merge remote-tracking branch 'origin/develop' into erikj/type_server
Diffstat (limited to 'synapse')
-rw-r--r-- | synapse/config/_util.py | 49 | ||||
-rw-r--r-- | synapse/config/logger.py | 63 | ||||
-rw-r--r-- | synapse/config/saml2_config.py | 50 | ||||
-rw-r--r-- | synapse/handlers/events.py | 4 | ||||
-rw-r--r-- | synapse/handlers/saml_handler.py | 42 | ||||
-rw-r--r-- | synapse/http/client.py | 2 | ||||
-rw-r--r-- | synapse/http/federation/matrix_federation_agent.py | 2 | ||||
-rw-r--r-- | synapse/http/matrixfederationclient.py | 94 | ||||
-rw-r--r-- | synapse/notifier.py | 131 | ||||
-rw-r--r-- | synapse/res/templates/saml_error.html | 17 |
10 files changed, 354 insertions, 100 deletions
diff --git a/synapse/config/_util.py b/synapse/config/_util.py new file mode 100644 index 0000000000..cd31b1c3c9 --- /dev/null +++ b/synapse/config/_util.py @@ -0,0 +1,49 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, List + +import jsonschema + +from synapse.config._base import ConfigError +from synapse.types import JsonDict + + +def validate_config(json_schema: JsonDict, config: Any, config_path: List[str]) -> None: + """Validates a config setting against a JsonSchema definition + + This can be used to validate a section of the config file against a schema + definition. If the validation fails, a ConfigError is raised with a textual + description of the problem. + + Args: + json_schema: the schema to validate against + config: the configuration value to be validated + config_path: the path within the config file. This will be used as a basis + for the error message. + """ + try: + jsonschema.validate(config, json_schema) + except jsonschema.ValidationError as e: + # copy `config_path` before modifying it. + path = list(config_path) + for p in list(e.path): + if isinstance(p, int): + path.append("<item %i>" % p) + else: + path.append(str(p)) + + raise ConfigError( + "Unable to parse configuration: %s at %s" % (e.message, ".".join(path)) + ) diff --git a/synapse/config/logger.py b/synapse/config/logger.py index dd775a97e8..c96e6ef62a 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -55,24 +55,33 @@ formatters: format: '%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - \ %(request)s - %(message)s' -filters: - context: - (): synapse.logging.context.LoggingContextFilter - request: "" - handlers: file: - class: logging.handlers.RotatingFileHandler + class: logging.handlers.TimedRotatingFileHandler formatter: precise filename: ${log_file} - maxBytes: 104857600 - backupCount: 10 - filters: [context] + when: midnight + backupCount: 3 # Does not include the current log file. encoding: utf8 + + # Default to buffering writes to log file for efficiency. This means that + # will be a delay for INFO/DEBUG logs to get written, but WARNING/ERROR + # logs will still be flushed immediately. + buffer: + class: logging.handlers.MemoryHandler + target: file + # The capacity is the number of log lines that are buffered before + # being written to disk. Increasing this will lead to better + # performance, at the expensive of it taking longer for log lines to + # be written to disk. + capacity: 10 + flushLevel: 30 # Flush for WARNING logs as well + + # A handler that writes logs to stderr. Unused by default, but can be used + # instead of "buffer" and "file" in the logger handlers. console: class: logging.StreamHandler formatter: precise - filters: [context] loggers: synapse.storage.SQL: @@ -80,9 +89,24 @@ loggers: # information such as access tokens. level: INFO + twisted: + # We send the twisted logging directly to the file handler, + # to work around https://github.com/matrix-org/synapse/issues/3471 + # when using "buffer" logger. Use "console" to log to stderr instead. + handlers: [file] + propagate: false + root: level: INFO - handlers: [file, console] + + # Write logs to the `buffer` handler, which will buffer them together in memory, + # then write them to a file. + # + # Replace "buffer" with "console" to log to stderr instead. (Note that you'll + # also need to update the configuation for the `twisted` logger above, in + # this case.) + # + handlers: [buffer] disable_existing_loggers: false """ @@ -168,11 +192,26 @@ def _setup_stdlib_logging(config, log_config, logBeginner: LogBeginner): handler = logging.StreamHandler() handler.setFormatter(formatter) - handler.addFilter(LoggingContextFilter(request="")) logger.addHandler(handler) else: logging.config.dictConfig(log_config) + # We add a log record factory that runs all messages through the + # LoggingContextFilter so that we get the context *at the time we log* + # rather than when we write to a handler. This can be done in config using + # filter options, but care must when using e.g. MemoryHandler to buffer + # writes. + + log_filter = LoggingContextFilter(request="") + old_factory = logging.getLogRecordFactory() + + def factory(*args, **kwargs): + record = old_factory(*args, **kwargs) + log_filter.filter(record) + return record + + logging.setLogRecordFactory(factory) + # Route Twisted's native logging through to the standard library logging # system. observer = STDLibLogObserver() diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py index 293643b2de..9277b5f342 100644 --- a/synapse/config/saml2_config.py +++ b/synapse/config/saml2_config.py @@ -15,7 +15,9 @@ # limitations under the License. import logging +from typing import Any, List +import attr import jinja2 import pkg_resources @@ -23,6 +25,7 @@ from synapse.python_dependencies import DependencyException, check_requirements from synapse.util.module_loader import load_module, load_python_module from ._base import Config, ConfigError +from ._util import validate_config logger = logging.getLogger(__name__) @@ -80,6 +83,11 @@ class SAML2Config(Config): self.saml2_enabled = True + attribute_requirements = saml2_config.get("attribute_requirements") or [] + self.attribute_requirements = _parse_attribute_requirements_def( + attribute_requirements + ) + self.saml2_grandfathered_mxid_source_attribute = saml2_config.get( "grandfathered_mxid_source_attribute", "uid" ) @@ -341,6 +349,17 @@ class SAML2Config(Config): # #grandfathered_mxid_source_attribute: upn + # It is possible to configure Synapse to only allow logins if SAML attributes + # match particular values. The requirements can be listed under + # `attribute_requirements` as shown below. All of the listed attributes must + # match for the login to be permitted. + # + #attribute_requirements: + # - attribute: userGroup + # value: "staff" + # - attribute: department + # value: "sales" + # Directory in which Synapse will try to find the template files below. # If not set, default templates from within the Synapse package will be used. # @@ -368,3 +387,34 @@ class SAML2Config(Config): """ % { "config_dir_path": config_dir_path } + + +@attr.s(frozen=True) +class SamlAttributeRequirement: + """Object describing a single requirement for SAML attributes.""" + + attribute = attr.ib(type=str) + value = attr.ib(type=str) + + JSON_SCHEMA = { + "type": "object", + "properties": {"attribute": {"type": "string"}, "value": {"type": "string"}}, + "required": ["attribute", "value"], + } + + +ATTRIBUTE_REQUIREMENTS_SCHEMA = { + "type": "array", + "items": SamlAttributeRequirement.JSON_SCHEMA, +} + + +def _parse_attribute_requirements_def( + attribute_requirements: Any, +) -> List[SamlAttributeRequirement]: + validate_config( + ATTRIBUTE_REQUIREMENTS_SCHEMA, + attribute_requirements, + config_path=["saml2_config", "attribute_requirements"], + ) + return [SamlAttributeRequirement(**x) for x in attribute_requirements] diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index 71a89f09c7..1924636c4d 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -57,13 +57,10 @@ class EventStreamHandler(BaseHandler): timeout=0, as_client_event=True, affect_presence=True, - only_keys=None, room_id=None, is_guest=False, ): """Fetches the events stream for a given user. - - If `only_keys` is not None, events from keys will be sent down. """ if room_id: @@ -93,7 +90,6 @@ class EventStreamHandler(BaseHandler): auth_user, pagin_config, timeout, - only_keys=only_keys, is_guest=is_guest, explicit_room_id=room_id, ) diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py index 2d506dc1f2..c1fcb98454 100644 --- a/synapse/handlers/saml_handler.py +++ b/synapse/handlers/saml_handler.py @@ -14,15 +14,16 @@ # limitations under the License. import logging import re -from typing import Callable, Dict, Optional, Set, Tuple +from typing import TYPE_CHECKING, Callable, Dict, Optional, Set, Tuple import attr import saml2 import saml2.response from saml2.client import Saml2Client -from synapse.api.errors import SynapseError +from synapse.api.errors import AuthError, SynapseError from synapse.config import ConfigError +from synapse.config.saml2_config import SamlAttributeRequirement from synapse.http.servlet import parse_string from synapse.http.site import SynapseRequest from synapse.module_api import ModuleApi @@ -34,6 +35,9 @@ from synapse.types import ( from synapse.util.async_helpers import Linearizer from synapse.util.iterutils import chunk_seq +if TYPE_CHECKING: + import synapse.server + logger = logging.getLogger(__name__) @@ -49,7 +53,7 @@ class Saml2SessionData: class SamlHandler: - def __init__(self, hs): + def __init__(self, hs: "synapse.server.HomeServer"): self._saml_client = Saml2Client(hs.config.saml2_sp_config) self._auth = hs.get_auth() self._auth_handler = hs.get_auth_handler() @@ -62,6 +66,7 @@ class SamlHandler: self._grandfathered_mxid_source_attribute = ( hs.config.saml2_grandfathered_mxid_source_attribute ) + self._saml2_attribute_requirements = hs.config.saml2.attribute_requirements # plugin to do custom mapping from saml response to mxid self._user_mapping_provider = hs.config.saml2_user_mapping_provider_class( @@ -73,7 +78,7 @@ class SamlHandler: self._auth_provider_id = "saml" # a map from saml session id to Saml2SessionData object - self._outstanding_requests_dict = {} + self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData] # a lock on the mappings self._mapping_lock = Linearizer(name="saml_mapping", clock=self._clock) @@ -165,11 +170,18 @@ class SamlHandler: saml2.BINDING_HTTP_POST, outstanding=self._outstanding_requests_dict, ) + except saml2.response.UnsolicitedResponse as e: + # the pysaml2 library helpfully logs an ERROR here, but neglects to log + # the session ID. I don't really want to put the full text of the exception + # in the (user-visible) exception message, so let's log the exception here + # so we can track down the session IDs later. + logger.warning(str(e)) + raise SynapseError(400, "Unexpected SAML2 login.") except Exception as e: - raise SynapseError(400, "Unable to parse SAML2 response: %s" % (e,)) + raise SynapseError(400, "Unable to parse SAML2 response: %s." % (e,)) if saml2_auth.not_signed: - raise SynapseError(400, "SAML2 response was not signed") + raise SynapseError(400, "SAML2 response was not signed.") logger.debug("SAML2 response: %s", saml2_auth.origxml) for assertion in saml2_auth.assertions: @@ -188,6 +200,9 @@ class SamlHandler: saml2_auth.in_response_to, None ) + for requirement in self._saml2_attribute_requirements: + _check_attribute_requirement(saml2_auth.ava, requirement) + remote_user_id = self._user_mapping_provider.get_remote_user_id( saml2_auth, client_redirect_url ) @@ -294,6 +309,21 @@ class SamlHandler: del self._outstanding_requests_dict[reqid] +def _check_attribute_requirement(ava: dict, req: SamlAttributeRequirement): + values = ava.get(req.attribute, []) + for v in values: + if v == req.value: + return + + logger.info( + "SAML2 attribute %s did not match required value '%s' (was '%s')", + req.attribute, + req.value, + values, + ) + raise AuthError(403, "You are not authorized to log in here.") + + DOT_REPLACE_PATTERN = re.compile( ("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),)) ) diff --git a/synapse/http/client.py b/synapse/http/client.py index 529532a063..8aeb70cdec 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -297,7 +297,7 @@ class SimpleHttpClient(object): outgoing_requests_counter.labels(method).inc() # log request but strip `access_token` (AS requests for example include this) - logger.info("Sending request %s %s", method, redact_uri(uri)) + logger.debug("Sending request %s %s", method, redact_uri(uri)) with start_active_span( "outgoing-client-request", diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py index 0c02648015..369bf9c2fc 100644 --- a/synapse/http/federation/matrix_federation_agent.py +++ b/synapse/http/federation/matrix_federation_agent.py @@ -247,7 +247,7 @@ class MatrixHostnameEndpoint(object): port = server.port try: - logger.info("Connecting to %s:%i", host.decode("ascii"), port) + logger.debug("Connecting to %s:%i", host.decode("ascii"), port) endpoint = HostnameEndpoint(self._reactor, host, port) if self._tls_options: endpoint = wrapClientTLS(self._tls_options, endpoint) diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 2a6373937a..738be43f46 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -29,10 +29,11 @@ from zope.interface import implementer from twisted.internet import defer, protocol from twisted.internet.error import DNSLookupError -from twisted.internet.interfaces import IReactorPluggableNameResolver +from twisted.internet.interfaces import IReactorPluggableNameResolver, IReactorTime from twisted.internet.task import _EPSILON, Cooperator from twisted.web._newclient import ResponseDone from twisted.web.http_headers import Headers +from twisted.web.iweb import IResponse import synapse.metrics import synapse.util.retryutils @@ -74,7 +75,7 @@ MAXINT = sys.maxsize _next_id = 1 -@attr.s +@attr.s(frozen=True) class MatrixFederationRequest(object): method = attr.ib() """HTTP method @@ -110,26 +111,52 @@ class MatrixFederationRequest(object): :type: str|None """ + uri = attr.ib(init=False, type=bytes) + """The URI of this request + """ + def __attrs_post_init__(self): global _next_id - self.txn_id = "%s-O-%s" % (self.method, _next_id) + txn_id = "%s-O-%s" % (self.method, _next_id) _next_id = (_next_id + 1) % (MAXINT - 1) + object.__setattr__(self, "txn_id", txn_id) + + destination_bytes = self.destination.encode("ascii") + path_bytes = self.path.encode("ascii") + if self.query: + query_bytes = encode_query_args(self.query) + else: + query_bytes = b"" + + # The object is frozen so we can pre-compute this. + uri = urllib.parse.urlunparse( + (b"matrix", destination_bytes, path_bytes, None, query_bytes, b"") + ) + object.__setattr__(self, "uri", uri) + def get_json(self): if self.json_callback: return self.json_callback() return self.json -async def _handle_json_response(reactor, timeout_sec, request, response): +async def _handle_json_response( + reactor: IReactorTime, + timeout_sec: float, + request: MatrixFederationRequest, + response: IResponse, + start_ms: int, +): """ Reads the JSON body of a response, with a timeout Args: - reactor (IReactor): twisted reactor, for the timeout - timeout_sec (float): number of seconds to wait for response to complete - request (MatrixFederationRequest): the request that triggered the response - response (IResponse): response to the request + reactor: twisted reactor, for the timeout + timeout_sec: number of seconds to wait for response to complete + request: the request that triggered the response + response: response to the request + start_ms: Timestamp when request was made Returns: dict: parsed JSON response @@ -143,23 +170,35 @@ async def _handle_json_response(reactor, timeout_sec, request, response): body = await make_deferred_yieldable(d) except TimeoutError as e: logger.warning( - "{%s} [%s] Timed out reading response", request.txn_id, request.destination, + "{%s} [%s] Timed out reading response - %s %s", + request.txn_id, + request.destination, + request.method, + request.uri.decode("ascii"), ) raise RequestSendFailed(e, can_retry=True) from e except Exception as e: logger.warning( - "{%s} [%s] Error reading response: %s", + "{%s} [%s] Error reading response %s %s: %s", request.txn_id, request.destination, + request.method, + request.uri.decode("ascii"), e, ) raise + + time_taken_secs = reactor.seconds() - start_ms / 1000 + logger.info( - "{%s} [%s] Completed: %d %s", + "{%s} [%s] Completed request: %d %s in %.2f secs - %s %s", request.txn_id, request.destination, response.code, response.phrase.decode("ascii", errors="replace"), + time_taken_secs, + request.method, + request.uri.decode("ascii"), ) return body @@ -261,7 +300,9 @@ class MatrixFederationHttpClient(object): # 'M_UNRECOGNIZED' which some endpoints can return when omitting a # trailing slash on Synapse <= v0.99.3. logger.info("Retrying request with trailing slash") - request.path += "/" + + # Request is frozen so we create a new instance + request = attr.evolve(request, path=request.path + "/") response = await self._send_request(request, **send_request_args) @@ -373,9 +414,7 @@ class MatrixFederationHttpClient(object): else: retries_left = MAX_SHORT_RETRIES - url_bytes = urllib.parse.urlunparse( - (b"matrix", destination_bytes, path_bytes, None, query_bytes, b"") - ) + url_bytes = request.uri url_str = url_bytes.decode("ascii") url_to_sign_bytes = urllib.parse.urlunparse( @@ -402,7 +441,7 @@ class MatrixFederationHttpClient(object): headers_dict[b"Authorization"] = auth_headers - logger.info( + logger.debug( "{%s} [%s] Sending request: %s %s; timeout %fs", request.txn_id, request.destination, @@ -436,7 +475,6 @@ class MatrixFederationHttpClient(object): except DNSLookupError as e: raise RequestSendFailed(e, can_retry=retry_on_dns_fail) from e except Exception as e: - logger.info("Failed to send request: %s", e) raise RequestSendFailed(e, can_retry=True) from e incoming_responses_counter.labels( @@ -496,7 +534,7 @@ class MatrixFederationHttpClient(object): break except RequestSendFailed as e: - logger.warning( + logger.info( "{%s} [%s] Request failed: %s %s: %s", request.txn_id, request.destination, @@ -654,6 +692,8 @@ class MatrixFederationHttpClient(object): json=data, ) + start_ms = self.clock.time_msec() + response = await self._send_request_with_optional_trailing_slash( request, try_trailing_slash_on_400, @@ -664,7 +704,7 @@ class MatrixFederationHttpClient(object): ) body = await _handle_json_response( - self.reactor, self.default_timeout, request, response + self.reactor, self.default_timeout, request, response, start_ms ) return body @@ -720,6 +760,8 @@ class MatrixFederationHttpClient(object): method="POST", destination=destination, path=path, query=args, json=data ) + start_ms = self.clock.time_msec() + response = await self._send_request( request, long_retries=long_retries, @@ -733,7 +775,7 @@ class MatrixFederationHttpClient(object): _sec_timeout = self.default_timeout body = await _handle_json_response( - self.reactor, _sec_timeout, request, response + self.reactor, _sec_timeout, request, response, start_ms, ) return body @@ -786,6 +828,8 @@ class MatrixFederationHttpClient(object): method="GET", destination=destination, path=path, query=args ) + start_ms = self.clock.time_msec() + response = await self._send_request_with_optional_trailing_slash( request, try_trailing_slash_on_400, @@ -796,7 +840,7 @@ class MatrixFederationHttpClient(object): ) body = await _handle_json_response( - self.reactor, self.default_timeout, request, response + self.reactor, self.default_timeout, request, response, start_ms ) return body @@ -846,6 +890,8 @@ class MatrixFederationHttpClient(object): method="DELETE", destination=destination, path=path, query=args ) + start_ms = self.clock.time_msec() + response = await self._send_request( request, long_retries=long_retries, @@ -854,7 +900,7 @@ class MatrixFederationHttpClient(object): ) body = await _handle_json_response( - self.reactor, self.default_timeout, request, response + self.reactor, self.default_timeout, request, response, start_ms ) return body @@ -914,12 +960,14 @@ class MatrixFederationHttpClient(object): ) raise logger.info( - "{%s} [%s] Completed: %d %s [%d bytes]", + "{%s} [%s] Completed: %d %s [%d bytes] %s %s", request.txn_id, request.destination, response.code, response.phrase.decode("ascii", errors="replace"), length, + request.method, + request.uri.decode("ascii"), ) return (length, headers) diff --git a/synapse/notifier.py b/synapse/notifier.py index 22ab4a9da5..694efe7116 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -15,7 +15,17 @@ import logging from collections import namedtuple -from typing import Callable, Iterable, List, TypeVar +from typing import ( + Awaitable, + Callable, + Dict, + Iterable, + List, + Optional, + Set, + Tuple, + TypeVar, +) from prometheus_client import Counter @@ -24,12 +34,14 @@ from twisted.internet import defer import synapse.server from synapse.api.constants import EventTypes, Membership from synapse.api.errors import AuthError +from synapse.events import EventBase from synapse.handlers.presence import format_user_presence_state from synapse.logging.context import PreserveLoggingContext from synapse.logging.utils import log_function from synapse.metrics import LaterGauge from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.types import StreamToken +from synapse.streams.config import PaginationConfig +from synapse.types import Collection, StreamToken, UserID from synapse.util.async_helpers import ObservableDeferred, timeout_deferred from synapse.util.metrics import Measure from synapse.visibility import filter_events_for_client @@ -77,7 +89,13 @@ class _NotifierUserStream(object): so that it can remove itself from the indexes in the Notifier class. """ - def __init__(self, user_id, rooms, current_token, time_now_ms): + def __init__( + self, + user_id: str, + rooms: Collection[str], + current_token: StreamToken, + time_now_ms: int, + ): self.user_id = user_id self.rooms = set(rooms) self.current_token = current_token @@ -93,13 +111,13 @@ class _NotifierUserStream(object): with PreserveLoggingContext(): self.notify_deferred = ObservableDeferred(defer.Deferred()) - def notify(self, stream_key, stream_id, time_now_ms): + def notify(self, stream_key: str, stream_id: int, time_now_ms: int): """Notify any listeners for this user of a new event from an event source. Args: - stream_key(str): The stream the event came from. - stream_id(str): The new id for the stream the event came from. - time_now_ms(int): The current time in milliseconds. + stream_key: The stream the event came from. + stream_id: The new id for the stream the event came from. + time_now_ms: The current time in milliseconds. """ self.current_token = self.current_token.copy_and_advance(stream_key, stream_id) self.last_notified_token = self.current_token @@ -112,7 +130,7 @@ class _NotifierUserStream(object): self.notify_deferred = ObservableDeferred(defer.Deferred()) noify_deferred.callback(self.current_token) - def remove(self, notifier): + def remove(self, notifier: "Notifier"): """ Remove this listener from all the indexes in the Notifier it knows about. """ @@ -123,10 +141,10 @@ class _NotifierUserStream(object): notifier.user_to_user_stream.pop(self.user_id) - def count_listeners(self): + def count_listeners(self) -> int: return len(self.notify_deferred.observers()) - def new_listener(self, token): + def new_listener(self, token: StreamToken) -> _NotificationListener: """Returns a deferred that is resolved when there is a new token greater than the given token. @@ -159,14 +177,16 @@ class Notifier(object): UNUSED_STREAM_EXPIRY_MS = 10 * 60 * 1000 def __init__(self, hs: "synapse.server.HomeServer"): - self.user_to_user_stream = {} - self.room_to_user_streams = {} + self.user_to_user_stream = {} # type: Dict[str, _NotifierUserStream] + self.room_to_user_streams = {} # type: Dict[str, Set[_NotifierUserStream]] self.hs = hs self.storage = hs.get_storage() self.event_sources = hs.get_event_sources() self.store = hs.get_datastore() - self.pending_new_room_events = [] + self.pending_new_room_events = ( + [] + ) # type: List[Tuple[int, EventBase, Collection[str]]] # Called when there are new things to stream over replication self.replication_callbacks = [] # type: List[Callable[[], None]] @@ -178,10 +198,9 @@ class Notifier(object): self.clock = hs.get_clock() self.appservice_handler = hs.get_application_service_handler() + self.federation_sender = None if hs.should_send_federation(): self.federation_sender = hs.get_federation_sender() - else: - self.federation_sender = None self.state_handler = hs.get_state_handler() @@ -193,12 +212,12 @@ class Notifier(object): # when rendering the metrics page, which is likely once per minute at # most when scraping it. def count_listeners(): - all_user_streams = set() + all_user_streams = set() # type: Set[_NotifierUserStream] - for x in list(self.room_to_user_streams.values()): - all_user_streams |= x - for x in list(self.user_to_user_stream.values()): - all_user_streams.add(x) + for streams in list(self.room_to_user_streams.values()): + all_user_streams |= streams + for stream in list(self.user_to_user_stream.values()): + all_user_streams.add(stream) return sum(stream.count_listeners() for stream in all_user_streams) @@ -223,7 +242,11 @@ class Notifier(object): self.replication_callbacks.append(cb) def on_new_room_event( - self, event, room_stream_id, max_room_stream_id, extra_users=[] + self, + event: EventBase, + room_stream_id: int, + max_room_stream_id: int, + extra_users: Collection[str] = [], ): """ Used by handlers to inform the notifier something has happened in the room, room event wise. @@ -241,11 +264,11 @@ class Notifier(object): self.notify_replication() - def _notify_pending_new_room_events(self, max_room_stream_id): + def _notify_pending_new_room_events(self, max_room_stream_id: int): """Notify for the room events that were queued waiting for a previous event to be persisted. Args: - max_room_stream_id(int): The highest stream_id below which all + max_room_stream_id: The highest stream_id below which all events have been persisted. """ pending = self.pending_new_room_events @@ -258,7 +281,9 @@ class Notifier(object): else: self._on_new_room_event(event, room_stream_id, extra_users) - def _on_new_room_event(self, event, room_stream_id, extra_users=[]): + def _on_new_room_event( + self, event: EventBase, room_stream_id: int, extra_users: Collection[str] = [] + ): """Notify any user streams that are interested in this room event""" # poke any interested application service. run_as_background_process( @@ -275,13 +300,19 @@ class Notifier(object): "room_key", room_stream_id, users=extra_users, rooms=[event.room_id] ) - async def _notify_app_services(self, room_stream_id): + async def _notify_app_services(self, room_stream_id: int): try: await self.appservice_handler.notify_interested_services(room_stream_id) except Exception: logger.exception("Error notifying application services of event") - def on_new_event(self, stream_key, new_token, users=[], rooms=[]): + def on_new_event( + self, + stream_key: str, + new_token: int, + users: Collection[str] = [], + rooms: Collection[str] = [], + ): """ Used to inform listeners that something has happened event wise. Will wake up all listeners for the given users and rooms. @@ -307,14 +338,19 @@ class Notifier(object): self.notify_replication() - def on_new_replication_data(self): + def on_new_replication_data(self) -> None: """Used to inform replication listeners that something has happend without waking up any of the normal user event streams""" self.notify_replication() async def wait_for_events( - self, user_id, timeout, callback, room_ids=None, from_token=StreamToken.START - ): + self, + user_id: str, + timeout: int, + callback: Callable[[StreamToken, StreamToken], Awaitable[T]], + room_ids=None, + from_token=StreamToken.START, + ) -> T: """Wait until the callback returns a non empty response or the timeout fires. """ @@ -377,19 +413,16 @@ class Notifier(object): async def get_events_for( self, - user, - pagination_config, - timeout, - only_keys=None, - is_guest=False, - explicit_room_id=None, - ): + user: UserID, + pagination_config: PaginationConfig, + timeout: int, + is_guest: bool = False, + explicit_room_id: str = None, + ) -> EventStreamResult: """ For the given user and rooms, return any new events for them. If there are no new events wait for up to `timeout` milliseconds for any new events to happen before returning. - If `only_keys` is not None, events from keys will be sent down. - If explicit_room_id is not set, the user's joined rooms will be polled for events. If explicit_room_id is set, that room will be polled for events only if @@ -404,11 +437,13 @@ class Notifier(object): room_ids, is_joined = await self._get_room_ids(user, explicit_room_id) is_peeking = not is_joined - async def check_for_updates(before_token, after_token): + async def check_for_updates( + before_token: StreamToken, after_token: StreamToken + ) -> EventStreamResult: if not after_token.is_after(before_token): return EventStreamResult([], (from_token, from_token)) - events = [] + events = [] # type: List[EventBase] end_token = from_token for name, source in self.event_sources.sources.items(): @@ -417,8 +452,6 @@ class Notifier(object): after_id = getattr(after_token, keyname) if before_id == after_id: continue - if only_keys and name not in only_keys: - continue new_events, new_key = await source.get_new_events( user=user, @@ -476,7 +509,9 @@ class Notifier(object): return result - async def _get_room_ids(self, user, explicit_room_id): + async def _get_room_ids( + self, user: UserID, explicit_room_id: Optional[str] + ) -> Tuple[Collection[str], bool]: joined_room_ids = await self.store.get_rooms_for_user(user.to_string()) if explicit_room_id: if explicit_room_id in joined_room_ids: @@ -486,7 +521,7 @@ class Notifier(object): raise AuthError(403, "Non-joined access not allowed") return joined_room_ids, True - async def _is_world_readable(self, room_id): + async def _is_world_readable(self, room_id: str) -> bool: state = await self.state_handler.get_current_state( room_id, EventTypes.RoomHistoryVisibility, "" ) @@ -496,7 +531,7 @@ class Notifier(object): return False @log_function - def remove_expired_streams(self): + def remove_expired_streams(self) -> None: time_now_ms = self.clock.time_msec() expired_streams = [] expire_before_ts = time_now_ms - self.UNUSED_STREAM_EXPIRY_MS @@ -510,21 +545,21 @@ class Notifier(object): expired_stream.remove(self) @log_function - def _register_with_keys(self, user_stream): + def _register_with_keys(self, user_stream: _NotifierUserStream): self.user_to_user_stream[user_stream.user_id] = user_stream for room in user_stream.rooms: s = self.room_to_user_streams.setdefault(room, set()) s.add(user_stream) - def _user_joined_room(self, user_id, room_id): + def _user_joined_room(self, user_id: str, room_id: str): new_user_stream = self.user_to_user_stream.get(user_id) if new_user_stream is not None: room_streams = self.room_to_user_streams.setdefault(room_id, set()) room_streams.add(new_user_stream) new_user_stream.rooms.add(room_id) - def notify_replication(self): + def notify_replication(self) -> None: """Notify the any replication listeners that there's a new event""" for cb in self.replication_callbacks: cb() diff --git a/synapse/res/templates/saml_error.html b/synapse/res/templates/saml_error.html index bfd6449c5d..01cd9bdaf3 100644 --- a/synapse/res/templates/saml_error.html +++ b/synapse/res/templates/saml_error.html @@ -2,10 +2,17 @@ <html lang="en"> <head> <meta charset="UTF-8"> - <title>SSO error</title> + <title>SSO login error</title> </head> <body> - <p>Oops! Something went wrong during authentication<span id="errormsg"></span>.</p> +{# a 403 means we have actively rejected their login #} +{% if code == 403 %} + <p>You are not allowed to log in here.</p> +{% else %} + <p> + There was an error during authentication: + </p> + <div id="errormsg" style="margin:20px 80px">{{ msg }}</div> <p> If you are seeing this page after clicking a link sent to you via email, make sure you only click the confirmation link once, and that you open the @@ -37,9 +44,9 @@ // to print one. let errorDesc = new URLSearchParams(searchStr).get("error_description") if (errorDesc) { - - document.getElementById("errormsg").innerText = ` ("${errorDesc}")`; + document.getElementById("errormsg").innerText = errorDesc; } </script> +{% endif %} </body> -</html> \ No newline at end of file +</html> |