Diffstat (limited to '')
-rw-r--r--  synapse/config/logger.py                              |  2
-rw-r--r--  synapse/handlers/auth.py                              | 26
-rw-r--r--  synapse/handlers/saml_handler.py                      | 23
-rw-r--r--  synapse/http/site.py                                  |  3
-rw-r--r--  synapse/logging/context.py                            | 24
-rw-r--r--  synapse/metrics/background_process_metrics.py         |  7
-rw-r--r--  synapse/push/__init__.py                              | 19
-rw-r--r--  synapse/push/emailpusher.py                           | 16
-rw-r--r--  synapse/push/httppusher.py                            | 17
-rw-r--r--  synapse/push/pusherpool.py                            |  5
-rw-r--r--  synapse/replication/tcp/protocol.py                   |  3
-rw-r--r--  synapse/storage/databases/main/event_push_actions.py  | 10
12 files changed, 63 insertions, 92 deletions
diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index d4e887a3e0..4df3f93c1c 100644
--- a/synapse/config/logger.py
+++ b/synapse/config/logger.py
@@ -206,7 +206,7 @@ def _setup_stdlib_logging(config, log_config_path, logBeginner: LogBeginner) ->
     # filter options, but care must when using e.g. MemoryHandler to buffer
     # writes.
 
-    log_context_filter = LoggingContextFilter(request="")
+    log_context_filter = LoggingContextFilter()
     log_metadata_filter = MetadataFilter({"server_name": config.server_name})
     old_factory = logging.getLogRecordFactory()
 
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index d58dc3cc29..f385c72526 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -198,27 +198,25 @@ class AuthHandler(BaseHandler):
 
         self._password_enabled = hs.config.password_enabled
         self._password_localdb_enabled = hs.config.password_localdb_enabled
 
-        # we keep this as a list despite the O(N^2) implication so that we can
-        # keep PASSWORD first and avoid confusing clients which pick the first
-        # type in the list. (NB that the spec doesn't require us to do so and
-        # clients which favour types that they don't understand over those that
-        # they do are technically broken)
-        # start out by assuming PASSWORD is enabled; we will remove it later if not.
-        login_types = []
+        login_types = set()
         if self._password_localdb_enabled:
-            login_types.append(LoginType.PASSWORD)
+            login_types.add(LoginType.PASSWORD)
 
         for provider in self.password_providers:
-            if hasattr(provider, "get_supported_login_types"):
-                for t in provider.get_supported_login_types().keys():
-                    if t not in login_types:
-                        login_types.append(t)
+            login_types.update(provider.get_supported_login_types().keys())
 
         if not self._password_enabled:
+            login_types.discard(LoginType.PASSWORD)
+
+        # Some clients just pick the first type in the list. In this case, we want
+        # them to use PASSWORD (rather than token or whatever), so we want to make sure
+        # that comes first, where it's present.
+        self._supported_login_types = []
+        if LoginType.PASSWORD in login_types:
+            self._supported_login_types.append(LoginType.PASSWORD)
             login_types.remove(LoginType.PASSWORD)
-
-        self._supported_login_types = login_types
+        self._supported_login_types.extend(login_types)
 
         # Ratelimiter for failed auth during UIA. Uses same ratelimit config
         # as per `rc_login.failed_attempts`.
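
The auth.py hunk above replaces the O(N^2) list bookkeeping with a set, then rebuilds the client-facing list so that PASSWORD stays first for clients which naively pick the first entry. A reduced standalone sketch of that ordering step, assuming the login types have already been collected into a set (the function name is illustrative; the constant is the Matrix password login type):

PASSWORD = "m.login.password"

def order_login_types(login_types: set) -> list:
    # Keep PASSWORD at the front so clients which just pick the first
    # entry still offer password login; the rest can follow in any order.
    ordered = []
    if PASSWORD in login_types:
        ordered.append(PASSWORD)
        login_types = login_types - {PASSWORD}
    ordered.extend(login_types)
    return ordered

# e.g. order_login_types({"m.login.sso", "m.login.password", "m.login.token"})
# always returns a list with "m.login.password" first.
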
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index f2ca1ddb53..6001fe3e27 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -163,6 +163,29 @@ class SamlHandler(BaseHandler):
             return
 
         logger.debug("SAML2 response: %s", saml2_auth.origxml)
+
+        await self._handle_authn_response(request, saml2_auth, relay_state)
+
+    async def _handle_authn_response(
+        self,
+        request: SynapseRequest,
+        saml2_auth: saml2.response.AuthnResponse,
+        relay_state: str,
+    ) -> None:
+        """Handle an AuthnResponse, having parsed it from the request params
+
+        Assumes that the signature on the response object has been checked. Maps
+        the user onto an MXID, registering them if necessary, and returns a response
+        to the browser.
+
+        Args:
+            request: the incoming request from the browser. We'll respond to it with an
+                HTML page or a redirect
+            saml2_auth: the parsed AuthnResponse object
+            relay_state: the RelayState query param, which encodes the URI to rediret
+                back to
+        """
+
         for assertion in saml2_auth.assertions:
             # kibana limits the length of a log field, whereas this is all rather
             # useful, so split it up.
diff --git a/synapse/http/site.py b/synapse/http/site.py
index 5f0581dc3f..5a5790831b 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -128,8 +128,7 @@ class SynapseRequest(Request):
 
         # create a LogContext for this request
         request_id = self.get_request_id()
-        logcontext = self.logcontext = LoggingContext(request_id)
-        logcontext.request = request_id
+        self.logcontext = LoggingContext(request_id, request=request_id)
 
         # override the Server header which is set by twisted
         self.setHeader("Server", self.site.server_version_string)
diff --git a/synapse/logging/context.py b/synapse/logging/context.py
index ca0c774cc5..a507a83e93 100644
--- a/synapse/logging/context.py
+++ b/synapse/logging/context.py
@@ -203,10 +203,6 @@ class _Sentinel:
     def copy_to(self, record):
         pass
 
-    def copy_to_twisted_log_entry(self, record):
-        record["request"] = None
-        record["scope"] = None
-
     def start(self, rusage: "Optional[resource._RUsage]"):
         pass
 
@@ -372,13 +368,6 @@ class LoggingContext:
         # we also track the current scope:
         record.scope = self.scope
 
-    def copy_to_twisted_log_entry(self, record) -> None:
-        """
-        Copy logging fields from this context to a Twisted log record.
-        """
-        record["request"] = self.request
-        record["scope"] = self.scope
-
     def start(self, rusage: "Optional[resource._RUsage]") -> None:
         """
         Record that this logcontext is currently running.
@@ -542,13 +531,10 @@ class LoggingContext:
 class LoggingContextFilter(logging.Filter):
     """Logging filter that adds values from the current logging context to each
     record.
-    Args:
-        **defaults: Default values to avoid formatters complaining about
-            missing fields
     """
 
-    def __init__(self, **defaults) -> None:
-        self.defaults = defaults
+    def __init__(self, request: str = ""):
+        self._default_request = request
 
     def filter(self, record) -> Literal[True]:
         """Add each fields from the logging contexts to the record.
@@ -556,14 +542,14 @@
             True to include the record in the log output.
         """
         context = current_context()
-        for key, value in self.defaults.items():
-            setattr(record, key, value)
+        record.request = self._default_request
 
         # context should never be None, but if it somehow ends up being, then
         # we end up in a death spiral of infinite loops, so let's check, for
         # robustness' sake.
         if context is not None:
-            context.copy_to(record)
+            # Logging is interested in the request.
+            record.request = context.request
 
         return True
 
diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py
index 76b7decf26..70e0fa45d9 100644
--- a/synapse/metrics/background_process_metrics.py
+++ b/synapse/metrics/background_process_metrics.py
@@ -199,8 +199,7 @@ def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwar
         _background_process_start_count.labels(desc).inc()
         _background_process_in_flight_count.labels(desc).inc()
 
-        with BackgroundProcessLoggingContext(desc) as context:
-            context.request = "%s-%i" % (desc, count)
+        with BackgroundProcessLoggingContext(desc, "%s-%i" % (desc, count)) as context:
             try:
                 ctx = noop_context_manager()
                 if bg_start_span:
@@ -244,8 +243,8 @@ class BackgroundProcessLoggingContext(LoggingContext):
 
     __slots__ = ["_proc"]
 
-    def __init__(self, name: str):
-        super().__init__(name)
+    def __init__(self, name: str, request: Optional[str] = None):
+        super().__init__(name, request=request)
 
         self._proc = _BackgroundProcess(name, self)
 
diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py
index 3d2e874838..ad07ee86f6 100644
--- a/synapse/push/__init__.py
+++ b/synapse/push/__init__.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 import abc
-from typing import TYPE_CHECKING, Any, Dict, Optional
+from typing import TYPE_CHECKING, Any, Dict
 
 from synapse.types import RoomStreamToken
 
@@ -36,12 +36,21 @@ class Pusher(metaclass=abc.ABCMeta):
         # This is the highest stream ordering we know it's safe to process.
         # When new events arrive, we'll be given a window of new events: we
         # should honour this rather than just looking for anything higher
-        # because of potential out-of-order event serialisation. This starts
-        # off as None though as we don't know any better.
-        self.max_stream_ordering = None  # type: Optional[int]
+        # because of potential out-of-order event serialisation.
+        self.max_stream_ordering = self.store.get_room_max_stream_ordering()
 
-    @abc.abstractmethod
     def on_new_notifications(self, max_token: RoomStreamToken) -> None:
+        # We just use the minimum stream ordering and ignore the vector clock
+        # component. This is safe to do as long as we *always* ignore the vector
+        # clock components.
+        max_stream_ordering = max_token.stream
+
+        self.max_stream_ordering = max(max_stream_ordering, self.max_stream_ordering)
+        self._start_processing()
+
+    @abc.abstractmethod
+    def _start_processing(self):
+        """Start processing push notifications."""
         raise NotImplementedError()
 
     @abc.abstractmethod
diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py
index 64a35c1994..11a97b8df4 100644
--- a/synapse/push/emailpusher.py
+++ b/synapse/push/emailpusher.py
@@ -22,7 +22,6 @@ from twisted.internet.error import AlreadyCalled, AlreadyCancelled
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.push import Pusher
 from synapse.push.mailer import Mailer
-from synapse.types import RoomStreamToken
 
 if TYPE_CHECKING:
     from synapse.app.homeserver import HomeServer
@@ -93,20 +92,6 @@ class EmailPusher(Pusher):
             pass
         self.timed_call = None
 
-    def on_new_notifications(self, max_token: RoomStreamToken) -> None:
-        # We just use the minimum stream ordering and ignore the vector clock
-        # component. This is safe to do as long as we *always* ignore the vector
-        # clock components.
-        max_stream_ordering = max_token.stream
-
-        if self.max_stream_ordering:
-            self.max_stream_ordering = max(
-                max_stream_ordering, self.max_stream_ordering
-            )
-        else:
-            self.max_stream_ordering = max_stream_ordering
-        self._start_processing()
-
     def on_new_receipts(self, min_stream_id: int, max_stream_id: int) -> None:
         # We could wake up and cancel the timer but there tend to be quite a
         # lot of read receipts so it's probably less work to just let the
@@ -172,7 +157,6 @@ class EmailPusher(Pusher):
             being run.
         """
         start = 0 if INCLUDE_ALL_UNREAD_NOTIFS else self.last_stream_ordering
-        assert self.max_stream_ordering is not None
         unprocessed = await self.store.get_unread_push_actions_for_user_in_range_for_email(
             self.user_id, start, self.max_stream_ordering
         )
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index 5408aa1295..e8b25bcd2a 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -26,7 +26,6 @@ from synapse.events import EventBase
 from synapse.logging import opentracing
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.push import Pusher, PusherConfigException
-from synapse.types import RoomStreamToken
 
 from . import push_rule_evaluator, push_tools
 
@@ -122,17 +121,6 @@ class HttpPusher(Pusher):
         if should_check_for_notifs:
             self._start_processing()
 
-    def on_new_notifications(self, max_token: RoomStreamToken) -> None:
-        # We just use the minimum stream ordering and ignore the vector clock
-        # component. This is safe to do as long as we *always* ignore the vector
-        # clock components.
-        max_stream_ordering = max_token.stream
-
-        self.max_stream_ordering = max(
-            max_stream_ordering, self.max_stream_ordering or 0
-        )
-        self._start_processing()
-
     def on_new_receipts(self, min_stream_id: int, max_stream_id: int) -> None:
         # Note that the min here shouldn't be relied upon to be accurate.
 
@@ -192,10 +180,7 @@ class HttpPusher(Pusher):
         Never call this directly: use _process which will only allow this to
         run once per pusher.
         """
-
-        fn = self.store.get_unread_push_actions_for_user_in_range_for_http
-        assert self.max_stream_ordering is not None
-        unprocessed = await fn(
+        unprocessed = await self.store.get_unread_push_actions_for_user_in_range_for_http(
             self.user_id, self.last_stream_ordering, self.max_stream_ordering
         )
 
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index 748b1407c6..409e46da8f 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -129,9 +129,8 @@ class PusherPool:
         )
 
         # create the pusher setting last_stream_ordering to the current maximum
-        # stream ordering in event_push_actions, so it will process
-        # pushes from this point onwards.
-        last_stream_ordering = await self.store.get_latest_push_action_stream_ordering()
+        # stream ordering, so it will process pushes from this point onwards.
+        last_stream_ordering = self.store.get_room_max_stream_ordering()
 
         await self.store.add_pusher(
             user_id=user_id,
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index a509e599c2..804da994ea 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -172,8 +172,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
         # a logcontext which we use for processing incoming commands. We declare it as a
         # background process so that the CPU stats get reported to prometheus.
         ctx_name = "replication-conn-%s" % self.conn_id
-        self._logging_context = BackgroundProcessLoggingContext(ctx_name)
-        self._logging_context.request = ctx_name
+        self._logging_context = BackgroundProcessLoggingContext(ctx_name, ctx_name)
 
     def connectionMade(self):
         logger.info("[%s] Connection established", self.id())
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 2e56dfaf31..e5c03cc609 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -894,16 +894,6 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
             pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"])
         return push_actions
 
-    async def get_latest_push_action_stream_ordering(self):
-        def f(txn):
-            txn.execute("SELECT MAX(stream_ordering) FROM event_push_actions")
-            return txn.fetchone()
-
-        result = await self.db_pool.runInteraction(
-            "get_latest_push_action_stream_ordering", f
-        )
-        return result[0] or 0
-
     def _remove_old_push_actions_before_txn(
         self, txn, room_id, user_id, stream_ordering
     ):