diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index d4e887a3e0..4df3f93c1c 100644
--- a/synapse/config/logger.py
+++ b/synapse/config/logger.py
@@ -206,7 +206,7 @@ def _setup_stdlib_logging(config, log_config_path, logBeginner: LogBeginner) ->
# filter options, but care must when using e.g. MemoryHandler to buffer
# writes.
- log_context_filter = LoggingContextFilter(request="")
+ log_context_filter = LoggingContextFilter()
log_metadata_filter = MetadataFilter({"server_name": config.server_name})
old_factory = logging.getLogRecordFactory()
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index d58dc3cc29..f385c72526 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -198,27 +198,25 @@ class AuthHandler(BaseHandler):
self._password_enabled = hs.config.password_enabled
self._password_localdb_enabled = hs.config.password_localdb_enabled
- # we keep this as a list despite the O(N^2) implication so that we can
- # keep PASSWORD first and avoid confusing clients which pick the first
- # type in the list. (NB that the spec doesn't require us to do so and
- # clients which favour types that they don't understand over those that
- # they do are technically broken)
-
# start out by assuming PASSWORD is enabled; we will remove it later if not.
- login_types = []
+ login_types = set()
if self._password_localdb_enabled:
- login_types.append(LoginType.PASSWORD)
+ login_types.add(LoginType.PASSWORD)
for provider in self.password_providers:
- if hasattr(provider, "get_supported_login_types"):
- for t in provider.get_supported_login_types().keys():
- if t not in login_types:
- login_types.append(t)
+ login_types.update(provider.get_supported_login_types().keys())
if not self._password_enabled:
+ login_types.discard(LoginType.PASSWORD)
+
+ # Some clients just pick the first type in the list. In this case, we want
+ # them to use PASSWORD (rather than token or whatever), so we want to make sure
+ # that comes first, where it's present.
+ self._supported_login_types = []
+ if LoginType.PASSWORD in login_types:
+ self._supported_login_types.append(LoginType.PASSWORD)
login_types.remove(LoginType.PASSWORD)
-
- self._supported_login_types = login_types
+ self._supported_login_types.extend(login_types)
# Ratelimiter for failed auth during UIA. Uses same ratelimit config
# as per `rc_login.failed_attempts`.
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index f2ca1ddb53..6001fe3e27 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -163,6 +163,29 @@ class SamlHandler(BaseHandler):
return
logger.debug("SAML2 response: %s", saml2_auth.origxml)
+
+ await self._handle_authn_response(request, saml2_auth, relay_state)
+
+ async def _handle_authn_response(
+ self,
+ request: SynapseRequest,
+ saml2_auth: saml2.response.AuthnResponse,
+ relay_state: str,
+ ) -> None:
+ """Handle an AuthnResponse, having parsed it from the request params
+
+ Assumes that the signature on the response object has been checked. Maps
+ the user onto an MXID, registering them if necessary, and returns a response
+ to the browser.
+
+ Args:
+ request: the incoming request from the browser. We'll respond to it with an
+ HTML page or a redirect
+ saml2_auth: the parsed AuthnResponse object
+ relay_state: the RelayState query param, which encodes the URI to rediret
+ back to
+ """
+
for assertion in saml2_auth.assertions:
# kibana limits the length of a log field, whereas this is all rather
# useful, so split it up.
diff --git a/synapse/http/site.py b/synapse/http/site.py
index 5f0581dc3f..5a5790831b 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -128,8 +128,7 @@ class SynapseRequest(Request):
# create a LogContext for this request
request_id = self.get_request_id()
- logcontext = self.logcontext = LoggingContext(request_id)
- logcontext.request = request_id
+ self.logcontext = LoggingContext(request_id, request=request_id)
# override the Server header which is set by twisted
self.setHeader("Server", self.site.server_version_string)
diff --git a/synapse/logging/context.py b/synapse/logging/context.py
index ca0c774cc5..a507a83e93 100644
--- a/synapse/logging/context.py
+++ b/synapse/logging/context.py
@@ -203,10 +203,6 @@ class _Sentinel:
def copy_to(self, record):
pass
- def copy_to_twisted_log_entry(self, record):
- record["request"] = None
- record["scope"] = None
-
def start(self, rusage: "Optional[resource._RUsage]"):
pass
@@ -372,13 +368,6 @@ class LoggingContext:
# we also track the current scope:
record.scope = self.scope
- def copy_to_twisted_log_entry(self, record) -> None:
- """
- Copy logging fields from this context to a Twisted log record.
- """
- record["request"] = self.request
- record["scope"] = self.scope
-
def start(self, rusage: "Optional[resource._RUsage]") -> None:
"""
Record that this logcontext is currently running.
@@ -542,13 +531,10 @@ class LoggingContext:
class LoggingContextFilter(logging.Filter):
"""Logging filter that adds values from the current logging context to each
record.
- Args:
- **defaults: Default values to avoid formatters complaining about
- missing fields
"""
- def __init__(self, **defaults) -> None:
- self.defaults = defaults
+ def __init__(self, request: str = ""):
+ self._default_request = request
def filter(self, record) -> Literal[True]:
"""Add each fields from the logging contexts to the record.
@@ -556,14 +542,14 @@ class LoggingContextFilter(logging.Filter):
True to include the record in the log output.
"""
context = current_context()
- for key, value in self.defaults.items():
- setattr(record, key, value)
+ record.request = self._default_request
# context should never be None, but if it somehow ends up being, then
# we end up in a death spiral of infinite loops, so let's check, for
# robustness' sake.
if context is not None:
- context.copy_to(record)
+ # Logging is interested in the request.
+ record.request = context.request
return True
diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py
index 76b7decf26..70e0fa45d9 100644
--- a/synapse/metrics/background_process_metrics.py
+++ b/synapse/metrics/background_process_metrics.py
@@ -199,8 +199,7 @@ def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwar
_background_process_start_count.labels(desc).inc()
_background_process_in_flight_count.labels(desc).inc()
- with BackgroundProcessLoggingContext(desc) as context:
- context.request = "%s-%i" % (desc, count)
+ with BackgroundProcessLoggingContext(desc, "%s-%i" % (desc, count)) as context:
try:
ctx = noop_context_manager()
if bg_start_span:
@@ -244,8 +243,8 @@ class BackgroundProcessLoggingContext(LoggingContext):
__slots__ = ["_proc"]
- def __init__(self, name: str):
- super().__init__(name)
+ def __init__(self, name: str, request: Optional[str] = None):
+ super().__init__(name, request=request)
self._proc = _BackgroundProcess(name, self)
diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py
index 3d2e874838..ad07ee86f6 100644
--- a/synapse/push/__init__.py
+++ b/synapse/push/__init__.py
@@ -14,7 +14,7 @@
# limitations under the License.
import abc
-from typing import TYPE_CHECKING, Any, Dict, Optional
+from typing import TYPE_CHECKING, Any, Dict
from synapse.types import RoomStreamToken
@@ -36,12 +36,21 @@ class Pusher(metaclass=abc.ABCMeta):
# This is the highest stream ordering we know it's safe to process.
# When new events arrive, we'll be given a window of new events: we
# should honour this rather than just looking for anything higher
- # because of potential out-of-order event serialisation. This starts
- # off as None though as we don't know any better.
- self.max_stream_ordering = None # type: Optional[int]
+ # because of potential out-of-order event serialisation.
+ self.max_stream_ordering = self.store.get_room_max_stream_ordering()
- @abc.abstractmethod
def on_new_notifications(self, max_token: RoomStreamToken) -> None:
+ # We just use the minimum stream ordering and ignore the vector clock
+ # component. This is safe to do as long as we *always* ignore the vector
+ # clock components.
+ max_stream_ordering = max_token.stream
+
+ self.max_stream_ordering = max(max_stream_ordering, self.max_stream_ordering)
+ self._start_processing()
+
+ @abc.abstractmethod
+ def _start_processing(self):
+ """Start processing push notifications."""
raise NotImplementedError()
@abc.abstractmethod
diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py
index 64a35c1994..11a97b8df4 100644
--- a/synapse/push/emailpusher.py
+++ b/synapse/push/emailpusher.py
@@ -22,7 +22,6 @@ from twisted.internet.error import AlreadyCalled, AlreadyCancelled
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.push import Pusher
from synapse.push.mailer import Mailer
-from synapse.types import RoomStreamToken
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
@@ -93,20 +92,6 @@ class EmailPusher(Pusher):
pass
self.timed_call = None
- def on_new_notifications(self, max_token: RoomStreamToken) -> None:
- # We just use the minimum stream ordering and ignore the vector clock
- # component. This is safe to do as long as we *always* ignore the vector
- # clock components.
- max_stream_ordering = max_token.stream
-
- if self.max_stream_ordering:
- self.max_stream_ordering = max(
- max_stream_ordering, self.max_stream_ordering
- )
- else:
- self.max_stream_ordering = max_stream_ordering
- self._start_processing()
-
def on_new_receipts(self, min_stream_id: int, max_stream_id: int) -> None:
# We could wake up and cancel the timer but there tend to be quite a
# lot of read receipts so it's probably less work to just let the
@@ -172,7 +157,6 @@ class EmailPusher(Pusher):
being run.
"""
start = 0 if INCLUDE_ALL_UNREAD_NOTIFS else self.last_stream_ordering
- assert self.max_stream_ordering is not None
unprocessed = await self.store.get_unread_push_actions_for_user_in_range_for_email(
self.user_id, start, self.max_stream_ordering
)
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index 5408aa1295..e8b25bcd2a 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -26,7 +26,6 @@ from synapse.events import EventBase
from synapse.logging import opentracing
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.push import Pusher, PusherConfigException
-from synapse.types import RoomStreamToken
from . import push_rule_evaluator, push_tools
@@ -122,17 +121,6 @@ class HttpPusher(Pusher):
if should_check_for_notifs:
self._start_processing()
- def on_new_notifications(self, max_token: RoomStreamToken) -> None:
- # We just use the minimum stream ordering and ignore the vector clock
- # component. This is safe to do as long as we *always* ignore the vector
- # clock components.
- max_stream_ordering = max_token.stream
-
- self.max_stream_ordering = max(
- max_stream_ordering, self.max_stream_ordering or 0
- )
- self._start_processing()
-
def on_new_receipts(self, min_stream_id: int, max_stream_id: int) -> None:
# Note that the min here shouldn't be relied upon to be accurate.
@@ -192,10 +180,7 @@ class HttpPusher(Pusher):
Never call this directly: use _process which will only allow this to
run once per pusher.
"""
-
- fn = self.store.get_unread_push_actions_for_user_in_range_for_http
- assert self.max_stream_ordering is not None
- unprocessed = await fn(
+ unprocessed = await self.store.get_unread_push_actions_for_user_in_range_for_http(
self.user_id, self.last_stream_ordering, self.max_stream_ordering
)
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index 748b1407c6..409e46da8f 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -129,9 +129,8 @@ class PusherPool:
)
# create the pusher setting last_stream_ordering to the current maximum
- # stream ordering in event_push_actions, so it will process
- # pushes from this point onwards.
- last_stream_ordering = await self.store.get_latest_push_action_stream_ordering()
+ # stream ordering, so it will process pushes from this point onwards.
+ last_stream_ordering = self.store.get_room_max_stream_ordering()
await self.store.add_pusher(
user_id=user_id,
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index a509e599c2..804da994ea 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -172,8 +172,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
# a logcontext which we use for processing incoming commands. We declare it as a
# background process so that the CPU stats get reported to prometheus.
ctx_name = "replication-conn-%s" % self.conn_id
- self._logging_context = BackgroundProcessLoggingContext(ctx_name)
- self._logging_context.request = ctx_name
+ self._logging_context = BackgroundProcessLoggingContext(ctx_name, ctx_name)
def connectionMade(self):
logger.info("[%s] Connection established", self.id())
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 2e56dfaf31..e5c03cc609 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -894,16 +894,6 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"])
return push_actions
- async def get_latest_push_action_stream_ordering(self):
- def f(txn):
- txn.execute("SELECT MAX(stream_ordering) FROM event_push_actions")
- return txn.fetchone()
-
- result = await self.db_pool.runInteraction(
- "get_latest_push_action_stream_ordering", f
- )
- return result[0] or 0
-
def _remove_old_push_actions_before_txn(
self, txn, room_id, user_id, stream_ordering
):
|