summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2020-08-11 22:03:14 +0100
committerErik Johnston <erik@matrix.org>2020-08-11 22:03:14 +0100
commitfdb46b5442c63212a52e2296491a23f1935f9929 (patch)
treec46d0940415f96e3c0383c35f1d39b7462d5c20c /synapse
parentAdd comment explaining cast (diff)
parentAuto set logging filter (#8051) (diff)
downloadsynapse-fdb46b5442c63212a52e2296491a23f1935f9929.tar.xz
Merge remote-tracking branch 'origin/develop' into erikj/type_server
Diffstat (limited to 'synapse')
-rw-r--r--synapse/config/_util.py49
-rw-r--r--synapse/config/logger.py63
-rw-r--r--synapse/config/saml2_config.py50
-rw-r--r--synapse/handlers/events.py4
-rw-r--r--synapse/handlers/saml_handler.py42
-rw-r--r--synapse/http/client.py2
-rw-r--r--synapse/http/federation/matrix_federation_agent.py2
-rw-r--r--synapse/http/matrixfederationclient.py94
-rw-r--r--synapse/notifier.py131
-rw-r--r--synapse/res/templates/saml_error.html17
10 files changed, 354 insertions, 100 deletions
diff --git a/synapse/config/_util.py b/synapse/config/_util.py
new file mode 100644
index 0000000000..cd31b1c3c9
--- /dev/null
+++ b/synapse/config/_util.py
@@ -0,0 +1,49 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any, List
+
+import jsonschema
+
+from synapse.config._base import ConfigError
+from synapse.types import JsonDict
+
+
+def validate_config(json_schema: JsonDict, config: Any, config_path: List[str]) -> None:
+    """Validates a config setting against a JsonSchema definition
+
+    This can be used to validate a section of the config file against a schema
+    definition. If the validation fails, a ConfigError is raised with a textual
+    description of the problem.
+
+    Args:
+        json_schema: the schema to validate against
+        config: the configuration value to be validated
+        config_path: the path within the config file. This will be used as a basis
+           for the error message.
+    """
+    try:
+        jsonschema.validate(config, json_schema)
+    except jsonschema.ValidationError as e:
+        # copy `config_path` before modifying it.
+        path = list(config_path)
+        for p in list(e.path):
+            if isinstance(p, int):
+                path.append("<item %i>" % p)
+            else:
+                path.append(str(p))
+
+        raise ConfigError(
+            "Unable to parse configuration: %s at %s" % (e.message, ".".join(path))
+        )
diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index dd775a97e8..c96e6ef62a 100644
--- a/synapse/config/logger.py
+++ b/synapse/config/logger.py
@@ -55,24 +55,33 @@ formatters:
         format: '%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - \
 %(request)s - %(message)s'
 
-filters:
-    context:
-        (): synapse.logging.context.LoggingContextFilter
-        request: ""
-
 handlers:
     file:
-        class: logging.handlers.RotatingFileHandler
+        class: logging.handlers.TimedRotatingFileHandler
         formatter: precise
         filename: ${log_file}
-        maxBytes: 104857600
-        backupCount: 10
-        filters: [context]
+        when: midnight
+        backupCount: 3  # Does not include the current log file.
         encoding: utf8
+
+    # Default to buffering writes to log file for efficiency. This means that
+    # will be a delay for INFO/DEBUG logs to get written, but WARNING/ERROR
+    # logs will still be flushed immediately.
+    buffer:
+        class: logging.handlers.MemoryHandler
+        target: file
+        # The capacity is the number of log lines that are buffered before
+        # being written to disk. Increasing this will lead to better
+        # performance, at the expensive of it taking longer for log lines to
+        # be written to disk.
+        capacity: 10
+        flushLevel: 30  # Flush for WARNING logs as well
+
+    # A handler that writes logs to stderr. Unused by default, but can be used
+    # instead of "buffer" and "file" in the logger handlers.
     console:
         class: logging.StreamHandler
         formatter: precise
-        filters: [context]
 
 loggers:
     synapse.storage.SQL:
@@ -80,9 +89,24 @@ loggers:
         # information such as access tokens.
         level: INFO
 
+    twisted:
+        # We send the twisted logging directly to the file handler,
+        # to work around https://github.com/matrix-org/synapse/issues/3471
+        # when using "buffer" logger. Use "console" to log to stderr instead.
+        handlers: [file]
+        propagate: false
+
 root:
     level: INFO
-    handlers: [file, console]
+
+    # Write logs to the `buffer` handler, which will buffer them together in memory,
+    # then write them to a file.
+    #
+    # Replace "buffer" with "console" to log to stderr instead. (Note that you'll
+    # also need to update the configuation for the `twisted` logger above, in
+    # this case.)
+    #
+    handlers: [buffer]
 
 disable_existing_loggers: false
 """
@@ -168,11 +192,26 @@ def _setup_stdlib_logging(config, log_config, logBeginner: LogBeginner):
 
         handler = logging.StreamHandler()
         handler.setFormatter(formatter)
-        handler.addFilter(LoggingContextFilter(request=""))
         logger.addHandler(handler)
     else:
         logging.config.dictConfig(log_config)
 
+    # We add a log record factory that runs all messages through the
+    # LoggingContextFilter so that we get the context *at the time we log*
+    # rather than when we write to a handler. This can be done in config using
+    # filter options, but care must when using e.g. MemoryHandler to buffer
+    # writes.
+
+    log_filter = LoggingContextFilter(request="")
+    old_factory = logging.getLogRecordFactory()
+
+    def factory(*args, **kwargs):
+        record = old_factory(*args, **kwargs)
+        log_filter.filter(record)
+        return record
+
+    logging.setLogRecordFactory(factory)
+
     # Route Twisted's native logging through to the standard library logging
     # system.
     observer = STDLibLogObserver()
diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py
index 293643b2de..9277b5f342 100644
--- a/synapse/config/saml2_config.py
+++ b/synapse/config/saml2_config.py
@@ -15,7 +15,9 @@
 # limitations under the License.
 
 import logging
+from typing import Any, List
 
+import attr
 import jinja2
 import pkg_resources
 
@@ -23,6 +25,7 @@ from synapse.python_dependencies import DependencyException, check_requirements
 from synapse.util.module_loader import load_module, load_python_module
 
 from ._base import Config, ConfigError
+from ._util import validate_config
 
 logger = logging.getLogger(__name__)
 
@@ -80,6 +83,11 @@ class SAML2Config(Config):
 
         self.saml2_enabled = True
 
+        attribute_requirements = saml2_config.get("attribute_requirements") or []
+        self.attribute_requirements = _parse_attribute_requirements_def(
+            attribute_requirements
+        )
+
         self.saml2_grandfathered_mxid_source_attribute = saml2_config.get(
             "grandfathered_mxid_source_attribute", "uid"
         )
@@ -341,6 +349,17 @@ class SAML2Config(Config):
           #
           #grandfathered_mxid_source_attribute: upn
 
+          # It is possible to configure Synapse to only allow logins if SAML attributes
+          # match particular values. The requirements can be listed under
+          # `attribute_requirements` as shown below. All of the listed attributes must
+          # match for the login to be permitted.
+          #
+          #attribute_requirements:
+          #  - attribute: userGroup
+          #    value: "staff"
+          #  - attribute: department
+          #    value: "sales"
+
           # Directory in which Synapse will try to find the template files below.
           # If not set, default templates from within the Synapse package will be used.
           #
@@ -368,3 +387,34 @@ class SAML2Config(Config):
         """ % {
             "config_dir_path": config_dir_path
         }
+
+
+@attr.s(frozen=True)
+class SamlAttributeRequirement:
+    """Object describing a single requirement for SAML attributes."""
+
+    attribute = attr.ib(type=str)
+    value = attr.ib(type=str)
+
+    JSON_SCHEMA = {
+        "type": "object",
+        "properties": {"attribute": {"type": "string"}, "value": {"type": "string"}},
+        "required": ["attribute", "value"],
+    }
+
+
+ATTRIBUTE_REQUIREMENTS_SCHEMA = {
+    "type": "array",
+    "items": SamlAttributeRequirement.JSON_SCHEMA,
+}
+
+
+def _parse_attribute_requirements_def(
+    attribute_requirements: Any,
+) -> List[SamlAttributeRequirement]:
+    validate_config(
+        ATTRIBUTE_REQUIREMENTS_SCHEMA,
+        attribute_requirements,
+        config_path=["saml2_config", "attribute_requirements"],
+    )
+    return [SamlAttributeRequirement(**x) for x in attribute_requirements]
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index 71a89f09c7..1924636c4d 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -57,13 +57,10 @@ class EventStreamHandler(BaseHandler):
         timeout=0,
         as_client_event=True,
         affect_presence=True,
-        only_keys=None,
         room_id=None,
         is_guest=False,
     ):
         """Fetches the events stream for a given user.
-
-        If `only_keys` is not None, events from keys will be sent down.
         """
 
         if room_id:
@@ -93,7 +90,6 @@ class EventStreamHandler(BaseHandler):
                 auth_user,
                 pagin_config,
                 timeout,
-                only_keys=only_keys,
                 is_guest=is_guest,
                 explicit_room_id=room_id,
             )
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index 2d506dc1f2..c1fcb98454 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -14,15 +14,16 @@
 # limitations under the License.
 import logging
 import re
-from typing import Callable, Dict, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Callable, Dict, Optional, Set, Tuple
 
 import attr
 import saml2
 import saml2.response
 from saml2.client import Saml2Client
 
-from synapse.api.errors import SynapseError
+from synapse.api.errors import AuthError, SynapseError
 from synapse.config import ConfigError
+from synapse.config.saml2_config import SamlAttributeRequirement
 from synapse.http.servlet import parse_string
 from synapse.http.site import SynapseRequest
 from synapse.module_api import ModuleApi
@@ -34,6 +35,9 @@ from synapse.types import (
 from synapse.util.async_helpers import Linearizer
 from synapse.util.iterutils import chunk_seq
 
+if TYPE_CHECKING:
+    import synapse.server
+
 logger = logging.getLogger(__name__)
 
 
@@ -49,7 +53,7 @@ class Saml2SessionData:
 
 
 class SamlHandler:
-    def __init__(self, hs):
+    def __init__(self, hs: "synapse.server.HomeServer"):
         self._saml_client = Saml2Client(hs.config.saml2_sp_config)
         self._auth = hs.get_auth()
         self._auth_handler = hs.get_auth_handler()
@@ -62,6 +66,7 @@ class SamlHandler:
         self._grandfathered_mxid_source_attribute = (
             hs.config.saml2_grandfathered_mxid_source_attribute
         )
+        self._saml2_attribute_requirements = hs.config.saml2.attribute_requirements
 
         # plugin to do custom mapping from saml response to mxid
         self._user_mapping_provider = hs.config.saml2_user_mapping_provider_class(
@@ -73,7 +78,7 @@ class SamlHandler:
         self._auth_provider_id = "saml"
 
         # a map from saml session id to Saml2SessionData object
-        self._outstanding_requests_dict = {}
+        self._outstanding_requests_dict = {}  # type: Dict[str, Saml2SessionData]
 
         # a lock on the mappings
         self._mapping_lock = Linearizer(name="saml_mapping", clock=self._clock)
@@ -165,11 +170,18 @@ class SamlHandler:
                 saml2.BINDING_HTTP_POST,
                 outstanding=self._outstanding_requests_dict,
             )
+        except saml2.response.UnsolicitedResponse as e:
+            # the pysaml2 library helpfully logs an ERROR here, but neglects to log
+            # the session ID. I don't really want to put the full text of the exception
+            # in the (user-visible) exception message, so let's log the exception here
+            # so we can track down the session IDs later.
+            logger.warning(str(e))
+            raise SynapseError(400, "Unexpected SAML2 login.")
         except Exception as e:
-            raise SynapseError(400, "Unable to parse SAML2 response: %s" % (e,))
+            raise SynapseError(400, "Unable to parse SAML2 response: %s." % (e,))
 
         if saml2_auth.not_signed:
-            raise SynapseError(400, "SAML2 response was not signed")
+            raise SynapseError(400, "SAML2 response was not signed.")
 
         logger.debug("SAML2 response: %s", saml2_auth.origxml)
         for assertion in saml2_auth.assertions:
@@ -188,6 +200,9 @@ class SamlHandler:
             saml2_auth.in_response_to, None
         )
 
+        for requirement in self._saml2_attribute_requirements:
+            _check_attribute_requirement(saml2_auth.ava, requirement)
+
         remote_user_id = self._user_mapping_provider.get_remote_user_id(
             saml2_auth, client_redirect_url
         )
@@ -294,6 +309,21 @@ class SamlHandler:
             del self._outstanding_requests_dict[reqid]
 
 
+def _check_attribute_requirement(ava: dict, req: SamlAttributeRequirement):
+    values = ava.get(req.attribute, [])
+    for v in values:
+        if v == req.value:
+            return
+
+    logger.info(
+        "SAML2 attribute %s did not match required value '%s' (was '%s')",
+        req.attribute,
+        req.value,
+        values,
+    )
+    raise AuthError(403, "You are not authorized to log in here.")
+
+
 DOT_REPLACE_PATTERN = re.compile(
     ("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),))
 )
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 529532a063..8aeb70cdec 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -297,7 +297,7 @@ class SimpleHttpClient(object):
         outgoing_requests_counter.labels(method).inc()
 
         # log request but strip `access_token` (AS requests for example include this)
-        logger.info("Sending request %s %s", method, redact_uri(uri))
+        logger.debug("Sending request %s %s", method, redact_uri(uri))
 
         with start_active_span(
             "outgoing-client-request",
diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
index 0c02648015..369bf9c2fc 100644
--- a/synapse/http/federation/matrix_federation_agent.py
+++ b/synapse/http/federation/matrix_federation_agent.py
@@ -247,7 +247,7 @@ class MatrixHostnameEndpoint(object):
             port = server.port
 
             try:
-                logger.info("Connecting to %s:%i", host.decode("ascii"), port)
+                logger.debug("Connecting to %s:%i", host.decode("ascii"), port)
                 endpoint = HostnameEndpoint(self._reactor, host, port)
                 if self._tls_options:
                     endpoint = wrapClientTLS(self._tls_options, endpoint)
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 2a6373937a..738be43f46 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -29,10 +29,11 @@ from zope.interface import implementer
 
 from twisted.internet import defer, protocol
 from twisted.internet.error import DNSLookupError
-from twisted.internet.interfaces import IReactorPluggableNameResolver
+from twisted.internet.interfaces import IReactorPluggableNameResolver, IReactorTime
 from twisted.internet.task import _EPSILON, Cooperator
 from twisted.web._newclient import ResponseDone
 from twisted.web.http_headers import Headers
+from twisted.web.iweb import IResponse
 
 import synapse.metrics
 import synapse.util.retryutils
@@ -74,7 +75,7 @@ MAXINT = sys.maxsize
 _next_id = 1
 
 
-@attr.s
+@attr.s(frozen=True)
 class MatrixFederationRequest(object):
     method = attr.ib()
     """HTTP method
@@ -110,26 +111,52 @@ class MatrixFederationRequest(object):
     :type: str|None
     """
 
+    uri = attr.ib(init=False, type=bytes)
+    """The URI of this request
+    """
+
     def __attrs_post_init__(self):
         global _next_id
-        self.txn_id = "%s-O-%s" % (self.method, _next_id)
+        txn_id = "%s-O-%s" % (self.method, _next_id)
         _next_id = (_next_id + 1) % (MAXINT - 1)
 
+        object.__setattr__(self, "txn_id", txn_id)
+
+        destination_bytes = self.destination.encode("ascii")
+        path_bytes = self.path.encode("ascii")
+        if self.query:
+            query_bytes = encode_query_args(self.query)
+        else:
+            query_bytes = b""
+
+        # The object is frozen so we can pre-compute this.
+        uri = urllib.parse.urlunparse(
+            (b"matrix", destination_bytes, path_bytes, None, query_bytes, b"")
+        )
+        object.__setattr__(self, "uri", uri)
+
     def get_json(self):
         if self.json_callback:
             return self.json_callback()
         return self.json
 
 
-async def _handle_json_response(reactor, timeout_sec, request, response):
+async def _handle_json_response(
+    reactor: IReactorTime,
+    timeout_sec: float,
+    request: MatrixFederationRequest,
+    response: IResponse,
+    start_ms: int,
+):
     """
     Reads the JSON body of a response, with a timeout
 
     Args:
-        reactor (IReactor): twisted reactor, for the timeout
-        timeout_sec (float): number of seconds to wait for response to complete
-        request (MatrixFederationRequest): the request that triggered the response
-        response (IResponse): response to the request
+        reactor: twisted reactor, for the timeout
+        timeout_sec: number of seconds to wait for response to complete
+        request: the request that triggered the response
+        response: response to the request
+        start_ms: Timestamp when request was made
 
     Returns:
         dict: parsed JSON response
@@ -143,23 +170,35 @@ async def _handle_json_response(reactor, timeout_sec, request, response):
         body = await make_deferred_yieldable(d)
     except TimeoutError as e:
         logger.warning(
-            "{%s} [%s] Timed out reading response", request.txn_id, request.destination,
+            "{%s} [%s] Timed out reading response - %s %s",
+            request.txn_id,
+            request.destination,
+            request.method,
+            request.uri.decode("ascii"),
         )
         raise RequestSendFailed(e, can_retry=True) from e
     except Exception as e:
         logger.warning(
-            "{%s} [%s] Error reading response: %s",
+            "{%s} [%s] Error reading response %s %s: %s",
             request.txn_id,
             request.destination,
+            request.method,
+            request.uri.decode("ascii"),
             e,
         )
         raise
+
+    time_taken_secs = reactor.seconds() - start_ms / 1000
+
     logger.info(
-        "{%s} [%s] Completed: %d %s",
+        "{%s} [%s] Completed request: %d %s in %.2f secs - %s %s",
         request.txn_id,
         request.destination,
         response.code,
         response.phrase.decode("ascii", errors="replace"),
+        time_taken_secs,
+        request.method,
+        request.uri.decode("ascii"),
     )
     return body
 
@@ -261,7 +300,9 @@ class MatrixFederationHttpClient(object):
             # 'M_UNRECOGNIZED' which some endpoints can return when omitting a
             # trailing slash on Synapse <= v0.99.3.
             logger.info("Retrying request with trailing slash")
-            request.path += "/"
+
+            # Request is frozen so we create a new instance
+            request = attr.evolve(request, path=request.path + "/")
 
             response = await self._send_request(request, **send_request_args)
 
@@ -373,9 +414,7 @@ class MatrixFederationHttpClient(object):
             else:
                 retries_left = MAX_SHORT_RETRIES
 
-            url_bytes = urllib.parse.urlunparse(
-                (b"matrix", destination_bytes, path_bytes, None, query_bytes, b"")
-            )
+            url_bytes = request.uri
             url_str = url_bytes.decode("ascii")
 
             url_to_sign_bytes = urllib.parse.urlunparse(
@@ -402,7 +441,7 @@ class MatrixFederationHttpClient(object):
 
                     headers_dict[b"Authorization"] = auth_headers
 
-                    logger.info(
+                    logger.debug(
                         "{%s} [%s] Sending request: %s %s; timeout %fs",
                         request.txn_id,
                         request.destination,
@@ -436,7 +475,6 @@ class MatrixFederationHttpClient(object):
                     except DNSLookupError as e:
                         raise RequestSendFailed(e, can_retry=retry_on_dns_fail) from e
                     except Exception as e:
-                        logger.info("Failed to send request: %s", e)
                         raise RequestSendFailed(e, can_retry=True) from e
 
                     incoming_responses_counter.labels(
@@ -496,7 +534,7 @@ class MatrixFederationHttpClient(object):
 
                     break
                 except RequestSendFailed as e:
-                    logger.warning(
+                    logger.info(
                         "{%s} [%s] Request failed: %s %s: %s",
                         request.txn_id,
                         request.destination,
@@ -654,6 +692,8 @@ class MatrixFederationHttpClient(object):
             json=data,
         )
 
+        start_ms = self.clock.time_msec()
+
         response = await self._send_request_with_optional_trailing_slash(
             request,
             try_trailing_slash_on_400,
@@ -664,7 +704,7 @@ class MatrixFederationHttpClient(object):
         )
 
         body = await _handle_json_response(
-            self.reactor, self.default_timeout, request, response
+            self.reactor, self.default_timeout, request, response, start_ms
         )
 
         return body
@@ -720,6 +760,8 @@ class MatrixFederationHttpClient(object):
             method="POST", destination=destination, path=path, query=args, json=data
         )
 
+        start_ms = self.clock.time_msec()
+
         response = await self._send_request(
             request,
             long_retries=long_retries,
@@ -733,7 +775,7 @@ class MatrixFederationHttpClient(object):
             _sec_timeout = self.default_timeout
 
         body = await _handle_json_response(
-            self.reactor, _sec_timeout, request, response
+            self.reactor, _sec_timeout, request, response, start_ms,
         )
         return body
 
@@ -786,6 +828,8 @@ class MatrixFederationHttpClient(object):
             method="GET", destination=destination, path=path, query=args
         )
 
+        start_ms = self.clock.time_msec()
+
         response = await self._send_request_with_optional_trailing_slash(
             request,
             try_trailing_slash_on_400,
@@ -796,7 +840,7 @@ class MatrixFederationHttpClient(object):
         )
 
         body = await _handle_json_response(
-            self.reactor, self.default_timeout, request, response
+            self.reactor, self.default_timeout, request, response, start_ms
         )
 
         return body
@@ -846,6 +890,8 @@ class MatrixFederationHttpClient(object):
             method="DELETE", destination=destination, path=path, query=args
         )
 
+        start_ms = self.clock.time_msec()
+
         response = await self._send_request(
             request,
             long_retries=long_retries,
@@ -854,7 +900,7 @@ class MatrixFederationHttpClient(object):
         )
 
         body = await _handle_json_response(
-            self.reactor, self.default_timeout, request, response
+            self.reactor, self.default_timeout, request, response, start_ms
         )
         return body
 
@@ -914,12 +960,14 @@ class MatrixFederationHttpClient(object):
             )
             raise
         logger.info(
-            "{%s} [%s] Completed: %d %s [%d bytes]",
+            "{%s} [%s] Completed: %d %s [%d bytes] %s %s",
             request.txn_id,
             request.destination,
             response.code,
             response.phrase.decode("ascii", errors="replace"),
             length,
+            request.method,
+            request.uri.decode("ascii"),
         )
         return (length, headers)
 
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 22ab4a9da5..694efe7116 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -15,7 +15,17 @@
 
 import logging
 from collections import namedtuple
-from typing import Callable, Iterable, List, TypeVar
+from typing import (
+    Awaitable,
+    Callable,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Set,
+    Tuple,
+    TypeVar,
+)
 
 from prometheus_client import Counter
 
@@ -24,12 +34,14 @@ from twisted.internet import defer
 import synapse.server
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.errors import AuthError
+from synapse.events import EventBase
 from synapse.handlers.presence import format_user_presence_state
 from synapse.logging.context import PreserveLoggingContext
 from synapse.logging.utils import log_function
 from synapse.metrics import LaterGauge
 from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.types import StreamToken
+from synapse.streams.config import PaginationConfig
+from synapse.types import Collection, StreamToken, UserID
 from synapse.util.async_helpers import ObservableDeferred, timeout_deferred
 from synapse.util.metrics import Measure
 from synapse.visibility import filter_events_for_client
@@ -77,7 +89,13 @@ class _NotifierUserStream(object):
     so that it can remove itself from the indexes in the Notifier class.
     """
 
-    def __init__(self, user_id, rooms, current_token, time_now_ms):
+    def __init__(
+        self,
+        user_id: str,
+        rooms: Collection[str],
+        current_token: StreamToken,
+        time_now_ms: int,
+    ):
         self.user_id = user_id
         self.rooms = set(rooms)
         self.current_token = current_token
@@ -93,13 +111,13 @@ class _NotifierUserStream(object):
         with PreserveLoggingContext():
             self.notify_deferred = ObservableDeferred(defer.Deferred())
 
-    def notify(self, stream_key, stream_id, time_now_ms):
+    def notify(self, stream_key: str, stream_id: int, time_now_ms: int):
         """Notify any listeners for this user of a new event from an
         event source.
         Args:
-            stream_key(str): The stream the event came from.
-            stream_id(str): The new id for the stream the event came from.
-            time_now_ms(int): The current time in milliseconds.
+            stream_key: The stream the event came from.
+            stream_id: The new id for the stream the event came from.
+            time_now_ms: The current time in milliseconds.
         """
         self.current_token = self.current_token.copy_and_advance(stream_key, stream_id)
         self.last_notified_token = self.current_token
@@ -112,7 +130,7 @@ class _NotifierUserStream(object):
             self.notify_deferred = ObservableDeferred(defer.Deferred())
             noify_deferred.callback(self.current_token)
 
-    def remove(self, notifier):
+    def remove(self, notifier: "Notifier"):
         """ Remove this listener from all the indexes in the Notifier
         it knows about.
         """
@@ -123,10 +141,10 @@ class _NotifierUserStream(object):
 
         notifier.user_to_user_stream.pop(self.user_id)
 
-    def count_listeners(self):
+    def count_listeners(self) -> int:
         return len(self.notify_deferred.observers())
 
-    def new_listener(self, token):
+    def new_listener(self, token: StreamToken) -> _NotificationListener:
         """Returns a deferred that is resolved when there is a new token
         greater than the given token.
 
@@ -159,14 +177,16 @@ class Notifier(object):
     UNUSED_STREAM_EXPIRY_MS = 10 * 60 * 1000
 
     def __init__(self, hs: "synapse.server.HomeServer"):
-        self.user_to_user_stream = {}
-        self.room_to_user_streams = {}
+        self.user_to_user_stream = {}  # type: Dict[str, _NotifierUserStream]
+        self.room_to_user_streams = {}  # type: Dict[str, Set[_NotifierUserStream]]
 
         self.hs = hs
         self.storage = hs.get_storage()
         self.event_sources = hs.get_event_sources()
         self.store = hs.get_datastore()
-        self.pending_new_room_events = []
+        self.pending_new_room_events = (
+            []
+        )  # type: List[Tuple[int, EventBase, Collection[str]]]
 
         # Called when there are new things to stream over replication
         self.replication_callbacks = []  # type: List[Callable[[], None]]
@@ -178,10 +198,9 @@ class Notifier(object):
         self.clock = hs.get_clock()
         self.appservice_handler = hs.get_application_service_handler()
 
+        self.federation_sender = None
         if hs.should_send_federation():
             self.federation_sender = hs.get_federation_sender()
-        else:
-            self.federation_sender = None
 
         self.state_handler = hs.get_state_handler()
 
@@ -193,12 +212,12 @@ class Notifier(object):
         # when rendering the metrics page, which is likely once per minute at
         # most when scraping it.
         def count_listeners():
-            all_user_streams = set()
+            all_user_streams = set()  # type: Set[_NotifierUserStream]
 
-            for x in list(self.room_to_user_streams.values()):
-                all_user_streams |= x
-            for x in list(self.user_to_user_stream.values()):
-                all_user_streams.add(x)
+            for streams in list(self.room_to_user_streams.values()):
+                all_user_streams |= streams
+            for stream in list(self.user_to_user_stream.values()):
+                all_user_streams.add(stream)
 
             return sum(stream.count_listeners() for stream in all_user_streams)
 
@@ -223,7 +242,11 @@ class Notifier(object):
         self.replication_callbacks.append(cb)
 
     def on_new_room_event(
-        self, event, room_stream_id, max_room_stream_id, extra_users=[]
+        self,
+        event: EventBase,
+        room_stream_id: int,
+        max_room_stream_id: int,
+        extra_users: Collection[str] = [],
     ):
         """ Used by handlers to inform the notifier something has happened
         in the room, room event wise.
@@ -241,11 +264,11 @@ class Notifier(object):
 
         self.notify_replication()
 
-    def _notify_pending_new_room_events(self, max_room_stream_id):
+    def _notify_pending_new_room_events(self, max_room_stream_id: int):
         """Notify for the room events that were queued waiting for a previous
         event to be persisted.
         Args:
-            max_room_stream_id(int): The highest stream_id below which all
+            max_room_stream_id: The highest stream_id below which all
                 events have been persisted.
         """
         pending = self.pending_new_room_events
@@ -258,7 +281,9 @@ class Notifier(object):
             else:
                 self._on_new_room_event(event, room_stream_id, extra_users)
 
-    def _on_new_room_event(self, event, room_stream_id, extra_users=[]):
+    def _on_new_room_event(
+        self, event: EventBase, room_stream_id: int, extra_users: Collection[str] = []
+    ):
         """Notify any user streams that are interested in this room event"""
         # poke any interested application service.
         run_as_background_process(
@@ -275,13 +300,19 @@ class Notifier(object):
             "room_key", room_stream_id, users=extra_users, rooms=[event.room_id]
         )
 
-    async def _notify_app_services(self, room_stream_id):
+    async def _notify_app_services(self, room_stream_id: int):
         try:
             await self.appservice_handler.notify_interested_services(room_stream_id)
         except Exception:
             logger.exception("Error notifying application services of event")
 
-    def on_new_event(self, stream_key, new_token, users=[], rooms=[]):
+    def on_new_event(
+        self,
+        stream_key: str,
+        new_token: int,
+        users: Collection[str] = [],
+        rooms: Collection[str] = [],
+    ):
         """ Used to inform listeners that something has happened event wise.
 
         Will wake up all listeners for the given users and rooms.
@@ -307,14 +338,19 @@ class Notifier(object):
 
                 self.notify_replication()
 
-    def on_new_replication_data(self):
+    def on_new_replication_data(self) -> None:
         """Used to inform replication listeners that something has happend
         without waking up any of the normal user event streams"""
         self.notify_replication()
 
     async def wait_for_events(
-        self, user_id, timeout, callback, room_ids=None, from_token=StreamToken.START
-    ):
+        self,
+        user_id: str,
+        timeout: int,
+        callback: Callable[[StreamToken, StreamToken], Awaitable[T]],
+        room_ids=None,
+        from_token=StreamToken.START,
+    ) -> T:
         """Wait until the callback returns a non empty response or the
         timeout fires.
         """
@@ -377,19 +413,16 @@ class Notifier(object):
 
     async def get_events_for(
         self,
-        user,
-        pagination_config,
-        timeout,
-        only_keys=None,
-        is_guest=False,
-        explicit_room_id=None,
-    ):
+        user: UserID,
+        pagination_config: PaginationConfig,
+        timeout: int,
+        is_guest: bool = False,
+        explicit_room_id: str = None,
+    ) -> EventStreamResult:
         """ For the given user and rooms, return any new events for them. If
         there are no new events wait for up to `timeout` milliseconds for any
         new events to happen before returning.
 
-        If `only_keys` is not None, events from keys will be sent down.
-
         If explicit_room_id is not set, the user's joined rooms will be polled
         for events.
         If explicit_room_id is set, that room will be polled for events only if
@@ -404,11 +437,13 @@ class Notifier(object):
         room_ids, is_joined = await self._get_room_ids(user, explicit_room_id)
         is_peeking = not is_joined
 
-        async def check_for_updates(before_token, after_token):
+        async def check_for_updates(
+            before_token: StreamToken, after_token: StreamToken
+        ) -> EventStreamResult:
             if not after_token.is_after(before_token):
                 return EventStreamResult([], (from_token, from_token))
 
-            events = []
+            events = []  # type: List[EventBase]
             end_token = from_token
 
             for name, source in self.event_sources.sources.items():
@@ -417,8 +452,6 @@ class Notifier(object):
                 after_id = getattr(after_token, keyname)
                 if before_id == after_id:
                     continue
-                if only_keys and name not in only_keys:
-                    continue
 
                 new_events, new_key = await source.get_new_events(
                     user=user,
@@ -476,7 +509,9 @@ class Notifier(object):
 
         return result
 
-    async def _get_room_ids(self, user, explicit_room_id):
+    async def _get_room_ids(
+        self, user: UserID, explicit_room_id: Optional[str]
+    ) -> Tuple[Collection[str], bool]:
         joined_room_ids = await self.store.get_rooms_for_user(user.to_string())
         if explicit_room_id:
             if explicit_room_id in joined_room_ids:
@@ -486,7 +521,7 @@ class Notifier(object):
             raise AuthError(403, "Non-joined access not allowed")
         return joined_room_ids, True
 
-    async def _is_world_readable(self, room_id):
+    async def _is_world_readable(self, room_id: str) -> bool:
         state = await self.state_handler.get_current_state(
             room_id, EventTypes.RoomHistoryVisibility, ""
         )
@@ -496,7 +531,7 @@ class Notifier(object):
             return False
 
     @log_function
-    def remove_expired_streams(self):
+    def remove_expired_streams(self) -> None:
         time_now_ms = self.clock.time_msec()
         expired_streams = []
         expire_before_ts = time_now_ms - self.UNUSED_STREAM_EXPIRY_MS
@@ -510,21 +545,21 @@ class Notifier(object):
             expired_stream.remove(self)
 
     @log_function
-    def _register_with_keys(self, user_stream):
+    def _register_with_keys(self, user_stream: _NotifierUserStream):
         self.user_to_user_stream[user_stream.user_id] = user_stream
 
         for room in user_stream.rooms:
             s = self.room_to_user_streams.setdefault(room, set())
             s.add(user_stream)
 
-    def _user_joined_room(self, user_id, room_id):
+    def _user_joined_room(self, user_id: str, room_id: str):
         new_user_stream = self.user_to_user_stream.get(user_id)
         if new_user_stream is not None:
             room_streams = self.room_to_user_streams.setdefault(room_id, set())
             room_streams.add(new_user_stream)
             new_user_stream.rooms.add(room_id)
 
-    def notify_replication(self):
+    def notify_replication(self) -> None:
         """Notify the any replication listeners that there's a new event"""
         for cb in self.replication_callbacks:
             cb()
diff --git a/synapse/res/templates/saml_error.html b/synapse/res/templates/saml_error.html
index bfd6449c5d..01cd9bdaf3 100644
--- a/synapse/res/templates/saml_error.html
+++ b/synapse/res/templates/saml_error.html
@@ -2,10 +2,17 @@
 <html lang="en">
 <head>
     <meta charset="UTF-8">
-    <title>SSO error</title>
+    <title>SSO login error</title>
 </head>
 <body>
-    <p>Oops! Something went wrong during authentication<span id="errormsg"></span>.</p>
+{# a 403 means we have actively rejected their login #}
+{% if code == 403 %}
+    <p>You are not allowed to log in here.</p>
+{% else %}
+    <p>
+        There was an error during authentication:
+    </p>
+    <div id="errormsg" style="margin:20px 80px">{{ msg }}</div>
     <p>
         If you are seeing this page after clicking a link sent to you via email, make
         sure you only click the confirmation link once, and that you open the
@@ -37,9 +44,9 @@
         // to print one.
         let errorDesc = new URLSearchParams(searchStr).get("error_description")
         if (errorDesc) {
-
-            document.getElementById("errormsg").innerText = ` ("${errorDesc}")`;
+            document.getElementById("errormsg").innerText = errorDesc;
         }
     </script>
+{% endif %}
 </body>
-</html>
\ No newline at end of file
+</html>