diff options
Diffstat (limited to 'synapse')
41 files changed, 747 insertions, 523 deletions
diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 1071a0576e..bff87fabde 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -34,7 +34,6 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import EventBase from synapse.logging import opentracing as opentracing from synapse.types import StateMap, UserID -from synapse.util.caches import register_cache from synapse.util.caches.lrucache import LruCache from synapse.util.metrics import Measure @@ -70,8 +69,9 @@ class Auth: self.store = hs.get_datastore() self.state = hs.get_state_handler() - self.token_cache = LruCache(10000) - register_cache("cache", "token_cache", self.token_cache) + self.token_cache = LruCache( + 10000, "token_cache" + ) # type: LruCache[str, Tuple[str, bool]] self._auth_blocking = AuthBlocking(self.hs) diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py index ad3c408519..58291afc22 100644 --- a/synapse/appservice/scheduler.py +++ b/synapse/appservice/scheduler.py @@ -60,6 +60,13 @@ from synapse.types import JsonDict logger = logging.getLogger(__name__) +# Maximum number of events to provide in an AS transaction. +MAX_PERSISTENT_EVENTS_PER_TRANSACTION = 100 + +# Maximum number of ephemeral events to provide in an AS transaction. +MAX_EPHEMERAL_EVENTS_PER_TRANSACTION = 100 + + class ApplicationServiceScheduler: """ Public facing API for this module. Does the required DI to tie the components together. This also serves as the "event_pool", which in this @@ -136,10 +143,17 @@ class _ServiceQueuer: self.requests_in_flight.add(service.id) try: while True: - events = self.queued_events.pop(service.id, []) - ephemeral = self.queued_ephemeral.pop(service.id, []) + all_events = self.queued_events.get(service.id, []) + events = all_events[:MAX_PERSISTENT_EVENTS_PER_TRANSACTION] + del all_events[:MAX_PERSISTENT_EVENTS_PER_TRANSACTION] + + all_events_ephemeral = self.queued_ephemeral.get(service.id, []) + ephemeral = all_events_ephemeral[:MAX_EPHEMERAL_EVENTS_PER_TRANSACTION] + del all_events_ephemeral[:MAX_EPHEMERAL_EVENTS_PER_TRANSACTION] + if not events and not ephemeral: return + try: await self.txn_ctrl.send(service, events, ephemeral) except Exception: diff --git a/synapse/events/builder.py b/synapse/events/builder.py index df4f950fec..07df258e6e 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -98,7 +98,7 @@ class EventBuilder: return self._state_key is not None async def build( - self, prev_event_ids: List[str], auth_event_ids: Optional[List[str]] + self, prev_event_ids: List[str], auth_event_ids: Optional[List[str]], ) -> EventBase: """Transform into a fully signed and hashed event diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index 0206320e96..bd8e71ae56 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -14,6 +14,7 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING, Optional import synapse.state import synapse.storage @@ -22,6 +23,9 @@ from synapse.api.constants import EventTypes, Membership from synapse.api.ratelimiting import Ratelimiter from synapse.types import UserID +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) @@ -30,11 +34,7 @@ class BaseHandler: Common base class for the event handlers. """ - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): - """ + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() # type: synapse.storage.DataStore self.auth = hs.get_auth() self.notifier = hs.get_notifier() @@ -56,7 +56,7 @@ class BaseHandler: clock=self.clock, rate_hz=self.hs.config.rc_admin_redaction.per_second, burst_count=self.hs.config.rc_admin_redaction.burst_count, - ) + ) # type: Optional[Ratelimiter] else: self.admin_redaction_ratelimiter = None @@ -127,15 +127,15 @@ class BaseHandler: if guest_access != "can_join": if context: current_state_ids = await context.get_current_state_ids() - current_state = await self.store.get_events( + current_state_dict = await self.store.get_events( list(current_state_ids.values()) ) + current_state = list(current_state_dict.values()) else: - current_state = await self.state_handler.get_current_state( + current_state_map = await self.state_handler.get_current_state( event.room_id ) - - current_state = list(current_state.values()) + current_state = list(current_state_map.values()) logger.info("maybe_kick_guest_users %r", current_state) await self.kick_guest_users(current_state) diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py index f33044e97a..fd4f762f33 100644 --- a/synapse/handlers/account_validity.py +++ b/synapse/handlers/account_validity.py @@ -22,7 +22,7 @@ from typing import List from synapse.api.errors import StoreError from synapse.logging.context import make_deferred_yieldable -from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.types import UserID from synapse.util import stringutils @@ -63,16 +63,10 @@ class AccountValidityHandler: self._raw_from = email.utils.parseaddr(self._from_string)[1] # Check the renewal emails to send and send them every 30min. - def send_emails(): - # run as a background process to make sure that the database transactions - # have a logcontext to report to - return run_as_background_process( - "send_renewals", self._send_renewal_emails - ) - if hs.config.run_background_tasks: - self.clock.looping_call(send_emails, 30 * 60 * 1000) + self.clock.looping_call(self._send_renewal_emails, 30 * 60 * 1000) + @wrap_as_background_process("send_renewals") async def _send_renewal_emails(self): """Gets the list of users whose account is expiring in the amount of time configured in the ``renew_at`` parameter from the ``account_validity`` diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 1d1ddc2245..8619fbb982 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -1122,20 +1122,22 @@ class AuthHandler(BaseHandler): Whether self.hash(password) == stored_hash. """ - def _do_validate_hash(): + def _do_validate_hash(checked_hash: bytes): # Normalise the Unicode in the password pw = unicodedata.normalize("NFKC", password) return bcrypt.checkpw( pw.encode("utf8") + self.hs.config.password_pepper.encode("utf8"), - stored_hash, + checked_hash, ) if stored_hash: if not isinstance(stored_hash, bytes): stored_hash = stored_hash.encode("ascii") - return await defer_to_thread(self.hs.get_reactor(), _do_validate_hash) + return await defer_to_thread( + self.hs.get_reactor(), _do_validate_hash, stored_hash + ) else: return False diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 98075f48d2..cb11754bf8 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -293,6 +293,10 @@ class InitialSyncHandler(BaseHandler): user_id, room_id, pagin_config, membership, is_peeking ) elif membership == Membership.LEAVE: + # The member_event_id will always be available if membership is set + # to leave. + assert member_event_id + result = await self._room_initial_sync_parted( user_id, room_id, pagin_config, membership, member_event_id, is_peeking ) @@ -315,7 +319,7 @@ class InitialSyncHandler(BaseHandler): user_id: str, room_id: str, pagin_config: PaginationConfig, - membership: Membership, + membership: str, member_event_id: str, is_peeking: bool, ) -> JsonDict: @@ -367,7 +371,7 @@ class InitialSyncHandler(BaseHandler): user_id: str, room_id: str, pagin_config: PaginationConfig, - membership: Membership, + membership: str, is_peeking: bool, ) -> JsonDict: current_state = await self.state.get_current_state(room_id=room_id) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index e37bca3d12..e4f2115617 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -1364,7 +1364,12 @@ class EventCreationHandler: for k, v in original_event.internal_metadata.get_dict().items(): setattr(builder.internal_metadata, k, v) - event = await builder.build(prev_event_ids=original_event.prev_event_ids()) + # the event type hasn't changed, so there's no point in re-calculating the + # auth events. + event = await builder.build( + prev_event_ids=original_event.prev_event_ids(), + auth_event_ids=original_event.auth_event_ids(), + ) # we rebuild the event context, to be on the safe side. If nothing else, # delta_ids might need an update. diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index b784938755..92700b589c 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -12,9 +12,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import logging import random +from typing import TYPE_CHECKING, Optional from synapse.api.errors import ( AuthError, @@ -24,11 +24,20 @@ from synapse.api.errors import ( StoreError, SynapseError, ) -from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.types import UserID, create_requester, get_domain_from_id +from synapse.metrics.background_process_metrics import wrap_as_background_process +from synapse.types import ( + JsonDict, + Requester, + UserID, + create_requester, + get_domain_from_id, +) from ._base import BaseHandler +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) MAX_DISPLAYNAME_LEN = 256 @@ -45,7 +54,7 @@ class ProfileHandler(BaseHandler): PROFILE_UPDATE_MS = 60 * 1000 PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000 - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.federation = hs.get_federation_client() @@ -57,10 +66,10 @@ class ProfileHandler(BaseHandler): if hs.config.run_background_tasks: self.clock.looping_call( - self._start_update_remote_profile_cache, self.PROFILE_UPDATE_MS + self._update_remote_profile_cache, self.PROFILE_UPDATE_MS ) - async def get_profile(self, user_id): + async def get_profile(self, user_id: str) -> JsonDict: target_user = UserID.from_string(user_id) if self.hs.is_mine(target_user): @@ -91,7 +100,7 @@ class ProfileHandler(BaseHandler): except HttpResponseException as e: raise e.to_synapse_error() - async def get_profile_from_cache(self, user_id): + async def get_profile_from_cache(self, user_id: str) -> JsonDict: """Get the profile information from our local cache. If the user is ours then the profile information will always be corect. Otherwise, it may be out of date/missing. @@ -115,7 +124,7 @@ class ProfileHandler(BaseHandler): profile = await self.store.get_from_remote_profile_cache(user_id) return profile or {} - async def get_displayname(self, target_user): + async def get_displayname(self, target_user: UserID) -> str: if self.hs.is_mine(target_user): try: displayname = await self.store.get_profile_displayname( @@ -143,15 +152,19 @@ class ProfileHandler(BaseHandler): return result["displayname"] async def set_displayname( - self, target_user, requester, new_displayname, by_admin=False - ): + self, + target_user: UserID, + requester: Requester, + new_displayname: str, + by_admin: bool = False, + ) -> None: """Set the displayname of a user Args: - target_user (UserID): the user whose displayname is to be changed. - requester (Requester): The user attempting to make this change. - new_displayname (str): The displayname to give this user. - by_admin (bool): Whether this change was made by an administrator. + target_user: the user whose displayname is to be changed. + requester: The user attempting to make this change. + new_displayname: The displayname to give this user. + by_admin: Whether this change was made by an administrator. """ if not self.hs.is_mine(target_user): raise SynapseError(400, "User is not hosted on this homeserver") @@ -176,8 +189,9 @@ class ProfileHandler(BaseHandler): 400, "Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN,) ) + displayname_to_set = new_displayname # type: Optional[str] if new_displayname == "": - new_displayname = None + displayname_to_set = None # If the admin changes the display name of a user, the requesting user cannot send # the join event to update the displayname in the rooms. @@ -185,7 +199,9 @@ class ProfileHandler(BaseHandler): if by_admin: requester = create_requester(target_user) - await self.store.set_profile_displayname(target_user.localpart, new_displayname) + await self.store.set_profile_displayname( + target_user.localpart, displayname_to_set + ) if self.hs.config.user_directory_search_all_users: profile = await self.store.get_profileinfo(target_user.localpart) @@ -195,7 +211,7 @@ class ProfileHandler(BaseHandler): await self._update_join_states(requester, target_user) - async def get_avatar_url(self, target_user): + async def get_avatar_url(self, target_user: UserID) -> str: if self.hs.is_mine(target_user): try: avatar_url = await self.store.get_profile_avatar_url( @@ -222,15 +238,19 @@ class ProfileHandler(BaseHandler): return result["avatar_url"] async def set_avatar_url( - self, target_user, requester, new_avatar_url, by_admin=False + self, + target_user: UserID, + requester: Requester, + new_avatar_url: str, + by_admin: bool = False, ): """Set a new avatar URL for a user. Args: - target_user (UserID): the user whose avatar URL is to be changed. - requester (Requester): The user attempting to make this change. - new_avatar_url (str): The avatar URL to give this user. - by_admin (bool): Whether this change was made by an administrator. + target_user: the user whose avatar URL is to be changed. + requester: The user attempting to make this change. + new_avatar_url: The avatar URL to give this user. + by_admin: Whether this change was made by an administrator. """ if not self.hs.is_mine(target_user): raise SynapseError(400, "User is not hosted on this homeserver") @@ -267,7 +287,7 @@ class ProfileHandler(BaseHandler): await self._update_join_states(requester, target_user) - async def on_profile_query(self, args): + async def on_profile_query(self, args: JsonDict) -> JsonDict: user = UserID.from_string(args["user_id"]) if not self.hs.is_mine(user): raise SynapseError(400, "User is not hosted on this homeserver") @@ -292,7 +312,9 @@ class ProfileHandler(BaseHandler): return response - async def _update_join_states(self, requester, target_user): + async def _update_join_states( + self, requester: Requester, target_user: UserID + ) -> None: if not self.hs.is_mine(target_user): return @@ -323,15 +345,17 @@ class ProfileHandler(BaseHandler): "Failed to update join event for room %s - %s", room_id, str(e) ) - async def check_profile_query_allowed(self, target_user, requester=None): + async def check_profile_query_allowed( + self, target_user: UserID, requester: Optional[UserID] = None + ) -> None: """Checks whether a profile query is allowed. If the 'require_auth_for_profile_requests' config flag is set to True and a 'requester' is provided, the query is only allowed if the two users share a room. Args: - target_user (UserID): The owner of the queried profile. - requester (None|UserID): The user querying for the profile. + target_user: The owner of the queried profile. + requester: The user querying for the profile. Raises: SynapseError(403): The two users share no room, or ne user couldn't @@ -370,11 +394,7 @@ class ProfileHandler(BaseHandler): raise SynapseError(403, "Profile isn't available", Codes.FORBIDDEN) raise - def _start_update_remote_profile_cache(self): - return run_as_background_process( - "Update remote profile", self._update_remote_profile_cache - ) - + @wrap_as_background_process("Update remote profile") async def _update_remote_profile_cache(self): """Called periodically to check profiles of remote users we haven't checked in a while. diff --git a/synapse/logging/_remote.py b/synapse/logging/_remote.py new file mode 100644 index 0000000000..0caf325916 --- /dev/null +++ b/synapse/logging/_remote.py @@ -0,0 +1,225 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import traceback +from collections import deque +from ipaddress import IPv4Address, IPv6Address, ip_address +from math import floor +from typing import Callable, Optional + +import attr +from zope.interface import implementer + +from twisted.application.internet import ClientService +from twisted.internet.defer import Deferred +from twisted.internet.endpoints import ( + HostnameEndpoint, + TCP4ClientEndpoint, + TCP6ClientEndpoint, +) +from twisted.internet.interfaces import IPushProducer, ITransport +from twisted.internet.protocol import Factory, Protocol +from twisted.logger import ILogObserver, Logger, LogLevel + + +@attr.s +@implementer(IPushProducer) +class LogProducer: + """ + An IPushProducer that writes logs from its buffer to its transport when it + is resumed. + + Args: + buffer: Log buffer to read logs from. + transport: Transport to write to. + format_event: A callable to format the log entry to a string. + """ + + transport = attr.ib(type=ITransport) + format_event = attr.ib(type=Callable[[dict], str]) + _buffer = attr.ib(type=deque) + _paused = attr.ib(default=False, type=bool, init=False) + + def pauseProducing(self): + self._paused = True + + def stopProducing(self): + self._paused = True + self._buffer = deque() + + def resumeProducing(self): + self._paused = False + + while self._paused is False and (self._buffer and self.transport.connected): + try: + # Request the next event and format it. + event = self._buffer.popleft() + msg = self.format_event(event) + + # Send it as a new line over the transport. + self.transport.write(msg.encode("utf8")) + except Exception: + # Something has gone wrong writing to the transport -- log it + # and break out of the while. + traceback.print_exc(file=sys.__stderr__) + break + + +@attr.s +@implementer(ILogObserver) +class TCPLogObserver: + """ + An IObserver that writes JSON logs to a TCP target. + + Args: + hs (HomeServer): The homeserver that is being logged for. + host: The host of the logging target. + port: The logging target's port. + format_event: A callable to format the log entry to a string. + maximum_buffer: The maximum buffer size. + """ + + hs = attr.ib() + host = attr.ib(type=str) + port = attr.ib(type=int) + format_event = attr.ib(type=Callable[[dict], str]) + maximum_buffer = attr.ib(type=int) + _buffer = attr.ib(default=attr.Factory(deque), type=deque) + _connection_waiter = attr.ib(default=None, type=Optional[Deferred]) + _logger = attr.ib(default=attr.Factory(Logger)) + _producer = attr.ib(default=None, type=Optional[LogProducer]) + + def start(self) -> None: + + # Connect without DNS lookups if it's a direct IP. + try: + ip = ip_address(self.host) + if isinstance(ip, IPv4Address): + endpoint = TCP4ClientEndpoint( + self.hs.get_reactor(), self.host, self.port + ) + elif isinstance(ip, IPv6Address): + endpoint = TCP6ClientEndpoint( + self.hs.get_reactor(), self.host, self.port + ) + else: + raise ValueError("Unknown IP address provided: %s" % (self.host,)) + except ValueError: + endpoint = HostnameEndpoint(self.hs.get_reactor(), self.host, self.port) + + factory = Factory.forProtocol(Protocol) + self._service = ClientService(endpoint, factory, clock=self.hs.get_reactor()) + self._service.startService() + self._connect() + + def stop(self): + self._service.stopService() + + def _connect(self) -> None: + """ + Triggers an attempt to connect then write to the remote if not already writing. + """ + if self._connection_waiter: + return + + self._connection_waiter = self._service.whenConnected(failAfterFailures=1) + + @self._connection_waiter.addErrback + def fail(r): + r.printTraceback(file=sys.__stderr__) + self._connection_waiter = None + self._connect() + + @self._connection_waiter.addCallback + def writer(r): + # We have a connection. If we already have a producer, and its + # transport is the same, just trigger a resumeProducing. + if self._producer and r.transport is self._producer.transport: + self._producer.resumeProducing() + self._connection_waiter = None + return + + # If the producer is still producing, stop it. + if self._producer: + self._producer.stopProducing() + + # Make a new producer and start it. + self._producer = LogProducer( + buffer=self._buffer, + transport=r.transport, + format_event=self.format_event, + ) + r.transport.registerProducer(self._producer, True) + self._producer.resumeProducing() + self._connection_waiter = None + + def _handle_pressure(self) -> None: + """ + Handle backpressure by shedding events. + + The buffer will, in this order, until the buffer is below the maximum: + - Shed DEBUG events + - Shed INFO events + - Shed the middle 50% of the events. + """ + if len(self._buffer) <= self.maximum_buffer: + return + + # Strip out DEBUGs + self._buffer = deque( + filter(lambda event: event["log_level"] != LogLevel.debug, self._buffer) + ) + + if len(self._buffer) <= self.maximum_buffer: + return + + # Strip out INFOs + self._buffer = deque( + filter(lambda event: event["log_level"] != LogLevel.info, self._buffer) + ) + + if len(self._buffer) <= self.maximum_buffer: + return + + # Cut the middle entries out + buffer_split = floor(self.maximum_buffer / 2) + + old_buffer = self._buffer + self._buffer = deque() + + for i in range(buffer_split): + self._buffer.append(old_buffer.popleft()) + + end_buffer = [] + for i in range(buffer_split): + end_buffer.append(old_buffer.pop()) + + self._buffer.extend(reversed(end_buffer)) + + def __call__(self, event: dict) -> None: + self._buffer.append(event) + + # Handle backpressure, if it exists. + try: + self._handle_pressure() + except Exception: + # If handling backpressure fails,clear the buffer and log the + # exception. + self._buffer.clear() + self._logger.failure("Failed clearing backpressure") + + # Try and write immediately. + self._connect() diff --git a/synapse/logging/_terse_json.py b/synapse/logging/_terse_json.py index 1b8916cfa2..9b46956ca9 100644 --- a/synapse/logging/_terse_json.py +++ b/synapse/logging/_terse_json.py @@ -18,26 +18,11 @@ Log formatters that output terse JSON. """ import json -import sys -import traceback -from collections import deque -from ipaddress import IPv4Address, IPv6Address, ip_address -from math import floor -from typing import IO, Optional +from typing import IO -import attr -from zope.interface import implementer +from twisted.logger import FileLogObserver -from twisted.application.internet import ClientService -from twisted.internet.defer import Deferred -from twisted.internet.endpoints import ( - HostnameEndpoint, - TCP4ClientEndpoint, - TCP6ClientEndpoint, -) -from twisted.internet.interfaces import IPushProducer, ITransport -from twisted.internet.protocol import Factory, Protocol -from twisted.logger import FileLogObserver, ILogObserver, Logger +from synapse.logging._remote import TCPLogObserver _encoder = json.JSONEncoder(ensure_ascii=False, separators=(",", ":")) @@ -150,180 +135,22 @@ def TerseJSONToConsoleLogObserver(outFile: IO[str], metadata: dict) -> FileLogOb return FileLogObserver(outFile, formatEvent) -@attr.s -@implementer(IPushProducer) -class LogProducer: +def TerseJSONToTCPLogObserver( + hs, host: str, port: int, metadata: dict, maximum_buffer: int +) -> FileLogObserver: """ - An IPushProducer that writes logs from its buffer to its transport when it - is resumed. - - Args: - buffer: Log buffer to read logs from. - transport: Transport to write to. - """ - - transport = attr.ib(type=ITransport) - _buffer = attr.ib(type=deque) - _paused = attr.ib(default=False, type=bool, init=False) - - def pauseProducing(self): - self._paused = True - - def stopProducing(self): - self._paused = True - self._buffer = deque() - - def resumeProducing(self): - self._paused = False - - while self._paused is False and (self._buffer and self.transport.connected): - try: - event = self._buffer.popleft() - self.transport.write(_encoder.encode(event).encode("utf8")) - self.transport.write(b"\n") - except Exception: - # Something has gone wrong writing to the transport -- log it - # and break out of the while. - traceback.print_exc(file=sys.__stderr__) - break - - -@attr.s -@implementer(ILogObserver) -class TerseJSONToTCPLogObserver: - """ - An IObserver that writes JSON logs to a TCP target. + A log observer that formats events to a flattened JSON representation. Args: hs (HomeServer): The homeserver that is being logged for. host: The host of the logging target. port: The logging target's port. - metadata: Metadata to be added to each log entry. + metadata: Metadata to be added to each log object. + maximum_buffer: The maximum buffer size. """ - hs = attr.ib() - host = attr.ib(type=str) - port = attr.ib(type=int) - metadata = attr.ib(type=dict) - maximum_buffer = attr.ib(type=int) - _buffer = attr.ib(default=attr.Factory(deque), type=deque) - _connection_waiter = attr.ib(default=None, type=Optional[Deferred]) - _logger = attr.ib(default=attr.Factory(Logger)) - _producer = attr.ib(default=None, type=Optional[LogProducer]) - - def start(self) -> None: - - # Connect without DNS lookups if it's a direct IP. - try: - ip = ip_address(self.host) - if isinstance(ip, IPv4Address): - endpoint = TCP4ClientEndpoint( - self.hs.get_reactor(), self.host, self.port - ) - elif isinstance(ip, IPv6Address): - endpoint = TCP6ClientEndpoint( - self.hs.get_reactor(), self.host, self.port - ) - except ValueError: - endpoint = HostnameEndpoint(self.hs.get_reactor(), self.host, self.port) - - factory = Factory.forProtocol(Protocol) - self._service = ClientService(endpoint, factory, clock=self.hs.get_reactor()) - self._service.startService() - self._connect() - - def stop(self): - self._service.stopService() - - def _connect(self) -> None: - """ - Triggers an attempt to connect then write to the remote if not already writing. - """ - if self._connection_waiter: - return - - self._connection_waiter = self._service.whenConnected(failAfterFailures=1) - - @self._connection_waiter.addErrback - def fail(r): - r.printTraceback(file=sys.__stderr__) - self._connection_waiter = None - self._connect() - - @self._connection_waiter.addCallback - def writer(r): - # We have a connection. If we already have a producer, and its - # transport is the same, just trigger a resumeProducing. - if self._producer and r.transport is self._producer.transport: - self._producer.resumeProducing() - self._connection_waiter = None - return - - # If the producer is still producing, stop it. - if self._producer: - self._producer.stopProducing() - - # Make a new producer and start it. - self._producer = LogProducer(buffer=self._buffer, transport=r.transport) - r.transport.registerProducer(self._producer, True) - self._producer.resumeProducing() - self._connection_waiter = None - - def _handle_pressure(self) -> None: - """ - Handle backpressure by shedding events. - - The buffer will, in this order, until the buffer is below the maximum: - - Shed DEBUG events - - Shed INFO events - - Shed the middle 50% of the events. - """ - if len(self._buffer) <= self.maximum_buffer: - return - - # Strip out DEBUGs - self._buffer = deque( - filter(lambda event: event["level"] != "DEBUG", self._buffer) - ) - - if len(self._buffer) <= self.maximum_buffer: - return - - # Strip out INFOs - self._buffer = deque( - filter(lambda event: event["level"] != "INFO", self._buffer) - ) - - if len(self._buffer) <= self.maximum_buffer: - return - - # Cut the middle entries out - buffer_split = floor(self.maximum_buffer / 2) - - old_buffer = self._buffer - self._buffer = deque() - - for i in range(buffer_split): - self._buffer.append(old_buffer.popleft()) - - end_buffer = [] - for i in range(buffer_split): - end_buffer.append(old_buffer.pop()) - - self._buffer.extend(reversed(end_buffer)) - - def __call__(self, event: dict) -> None: - flattened = flatten_event(event, self.metadata, include_time=True) - self._buffer.append(flattened) - - # Handle backpressure, if it exists. - try: - self._handle_pressure() - except Exception: - # If handling backpressure fails,clear the buffer and log the - # exception. - self._buffer.clear() - self._logger.failure("Failed clearing backpressure") + def formatEvent(_event: dict) -> str: + flattened = flatten_event(_event, metadata, include_time=True) + return _encoder.encode(flattened) + "\n" - # Try and write immediately. - self._connect() + return TCPLogObserver(hs, host, port, formatEvent, maximum_buffer) diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py index 5b73463504..ea5f1c7b62 100644 --- a/synapse/metrics/background_process_metrics.py +++ b/synapse/metrics/background_process_metrics.py @@ -24,6 +24,7 @@ from prometheus_client.core import REGISTRY, Counter, Gauge from twisted.internet import defer from synapse.logging.context import LoggingContext, PreserveLoggingContext +from synapse.logging.opentracing import start_active_span if TYPE_CHECKING: import resource @@ -197,14 +198,14 @@ def run_as_background_process(desc: str, func, *args, **kwargs): with BackgroundProcessLoggingContext(desc) as context: context.request = "%s-%i" % (desc, count) - try: - result = func(*args, **kwargs) + with start_active_span(desc, tags={"request_id": context.request}): + result = func(*args, **kwargs) - if inspect.isawaitable(result): - result = await result + if inspect.isawaitable(result): + result = await result - return result + return result except Exception: logger.exception( "Background process '%s' threw an exception", desc, diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index c440f2545c..a701defcdd 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -496,6 +496,6 @@ class _Invalidation(namedtuple("_Invalidation", ("cache", "room_id"))): # dedupe when we add callbacks to lru cache nodes, otherwise the number # of callbacks would grow. def __call__(self): - rules = self.cache.get(self.room_id, None, update_metrics=False) + rules = self.cache.get_immediate(self.room_id, None, update_metrics=False) if rules: rules.invalidate_all() diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index 455a1acb46..155791b754 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -387,8 +387,8 @@ class Mailer: return ret async def get_message_vars(self, notif, event, room_state_ids): - if event.type != EventTypes.Message: - return + if event.type != EventTypes.Message and event.type != EventTypes.Encrypted: + return None sender_state_event_id = room_state_ids[("m.room.member", event.sender)] sender_state_event = await self.store.get_event(sender_state_event_id) @@ -399,10 +399,8 @@ class Mailer: # sender_hash % the number of default images to choose from sender_hash = string_ordinal_total(event.sender) - msgtype = event.content.get("msgtype") - ret = { - "msgtype": msgtype, + "event_type": event.type, "is_historical": event.event_id != notif["event_id"], "id": event.event_id, "ts": event.origin_server_ts, @@ -411,6 +409,14 @@ class Mailer: "sender_hash": sender_hash, } + # Encrypted messages don't have any additional useful information. + if event.type == EventTypes.Encrypted: + return ret + + msgtype = event.content.get("msgtype") + + ret["msgtype"] = msgtype + if msgtype == "m.text": self.add_text_message_vars(ret, event) elif msgtype == "m.image": diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py index 3a68ce636f..2ce9e444ab 100644 --- a/synapse/push/push_rule_evaluator.py +++ b/synapse/push/push_rule_evaluator.py @@ -16,11 +16,10 @@ import logging import re -from typing import Any, Dict, List, Optional, Pattern, Union +from typing import Any, Dict, List, Optional, Pattern, Tuple, Union from synapse.events import EventBase from synapse.types import UserID -from synapse.util.caches import register_cache from synapse.util.caches.lrucache import LruCache logger = logging.getLogger(__name__) @@ -174,20 +173,21 @@ class PushRuleEvaluatorForEvent: # Similar to _glob_matches, but do not treat display_name as a glob. r = regex_cache.get((display_name, False, True), None) if not r: - r = re.escape(display_name) - r = _re_word_boundary(r) - r = re.compile(r, flags=re.IGNORECASE) + r1 = re.escape(display_name) + r1 = _re_word_boundary(r1) + r = re.compile(r1, flags=re.IGNORECASE) regex_cache[(display_name, False, True)] = r - return r.search(body) + return bool(r.search(body)) def _get_value(self, dotted_key: str) -> Optional[str]: return self._value_cache.get(dotted_key, None) # Caches (string, is_glob, word_boundary) -> regex for push. See _glob_matches -regex_cache = LruCache(50000) -register_cache("cache", "regex_push_cache", regex_cache) +regex_cache = LruCache( + 50000, "regex_push_cache" +) # type: LruCache[Tuple[str, bool, bool], Pattern] def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool: @@ -205,7 +205,7 @@ def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool: if not r: r = _glob_to_re(glob, word_boundary) regex_cache[(glob, True, word_boundary)] = r - return r.search(value) + return bool(r.search(value)) except re.error: logger.warning("Failed to parse glob to regex: %r", glob) return False diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py index 4b0ea0cc01..0f5b7adef7 100644 --- a/synapse/replication/slave/storage/client_ips.py +++ b/synapse/replication/slave/storage/client_ips.py @@ -15,7 +15,7 @@ from synapse.storage.database import DatabasePool from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY -from synapse.util.caches.deferred_cache import DeferredCache +from synapse.util.caches.lrucache import LruCache from ._base import BaseSlavedStore @@ -24,9 +24,9 @@ class SlavedClientIpStore(BaseSlavedStore): def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) - self.client_ip_last_seen = DeferredCache( - name="client_ip_last_seen", keylen=4, max_entries=50000 - ) # type: DeferredCache[tuple, int] + self.client_ip_last_seen = LruCache( + cache_name="client_ip_last_seen", keylen=4, max_size=50000 + ) # type: LruCache[tuple, int] async def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id): now = int(self._clock.time_msec()) @@ -41,7 +41,7 @@ class SlavedClientIpStore(BaseSlavedStore): if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY: return - self.client_ip_last_seen.prefill(key, now) + self.client_ip_last_seen.set(key, now) self.hs.get_tcp_replication().send_user_ip( user_id, access_token, ip, user_agent, device_id, now diff --git a/synapse/res/templates/notif.html b/synapse/res/templates/notif.html index 1a6c70b562..6d76064d13 100644 --- a/synapse/res/templates/notif.html +++ b/synapse/res/templates/notif.html @@ -1,41 +1,47 @@ -{% for message in notif.messages %} +{%- for message in notif.messages %} <tr class="{{ "historical_message" if message.is_historical else "message" }}"> <td class="sender_avatar"> - {% if loop.index0 == 0 or notif.messages[loop.index0 - 1].sender_name != notif.messages[loop.index0].sender_name %} - {% if message.sender_avatar_url %} + {%- if loop.index0 == 0 or notif.messages[loop.index0 - 1].sender_name != notif.messages[loop.index0].sender_name %} + {%- if message.sender_avatar_url %} <img alt="" class="sender_avatar" src="{{ message.sender_avatar_url|mxc_to_http(32,32) }}" /> - {% else %} - {% if message.sender_hash % 3 == 0 %} + {%- else %} + {%- if message.sender_hash % 3 == 0 %} <img class="sender_avatar" src="https://riot.im/img/external/avatar-1.png" /> - {% elif message.sender_hash % 3 == 1 %} + {%- elif message.sender_hash % 3 == 1 %} <img class="sender_avatar" src="https://riot.im/img/external/avatar-2.png" /> - {% else %} + {%- else %} <img class="sender_avatar" src="https://riot.im/img/external/avatar-3.png" /> - {% endif %} - {% endif %} - {% endif %} + {%- endif %} + {%- endif %} + {%- endif %} </td> <td class="message_contents"> - {% if loop.index0 == 0 or notif.messages[loop.index0 - 1].sender_name != notif.messages[loop.index0].sender_name %} - <div class="sender_name">{% if message.msgtype == "m.emote" %}*{% endif %} {{ message.sender_name }}</div> - {% endif %} + {%- if loop.index0 == 0 or notif.messages[loop.index0 - 1].sender_name != notif.messages[loop.index0].sender_name %} + <div class="sender_name">{%- if message.msgtype == "m.emote" %}*{%- endif %} {{ message.sender_name }}</div> + {%- endif %} <div class="message_body"> - {% if message.msgtype == "m.text" %} - {{ message.body_text_html }} - {% elif message.msgtype == "m.emote" %} - {{ message.body_text_html }} - {% elif message.msgtype == "m.notice" %} - {{ message.body_text_html }} - {% elif message.msgtype == "m.image" %} - <img src="{{ message.image_url|mxc_to_http(640, 480, scale) }}" /> - {% elif message.msgtype == "m.file" %} - <span class="filename">{{ message.body_text_plain }}</span> - {% endif %} + {%- if message.event_type == "m.room.encrypted" %} + An encrypted message. + {%- elif message.event_type == "m.room.message" %} + {%- if message.msgtype == "m.text" %} + {{ message.body_text_html }} + {%- elif message.msgtype == "m.emote" %} + {{ message.body_text_html }} + {%- elif message.msgtype == "m.notice" %} + {{ message.body_text_html }} + {%- elif message.msgtype == "m.image" %} + <img src="{{ message.image_url|mxc_to_http(640, 480, scale) }}" /> + {%- elif message.msgtype == "m.file" %} + <span class="filename">{{ message.body_text_plain }}</span> + {%- else %} + A message with unrecognised content. + {%- endif %} + {%- endif %} </div> </td> <td class="message_time">{{ message.ts|format_ts("%H:%M") }}</td> </tr> -{% endfor %} +{%- endfor %} <tr class="notif_link"> <td></td> <td> diff --git a/synapse/res/templates/notif.txt b/synapse/res/templates/notif.txt index a37bee9833..1ee7da3c50 100644 --- a/synapse/res/templates/notif.txt +++ b/synapse/res/templates/notif.txt @@ -1,16 +1,22 @@ -{% for message in notif.messages %} -{% if message.msgtype == "m.emote" %}* {% endif %}{{ message.sender_name }} ({{ message.ts|format_ts("%H:%M") }}) -{% if message.msgtype == "m.text" %} +{%- for message in notif.messages %} +{%- if message.event_type == "m.room.encrypted" %} +An encrypted message. +{%- elif message.event_type == "m.room.message" %} +{%- if message.msgtype == "m.emote" %}* {%- endif %}{{ message.sender_name }} ({{ message.ts|format_ts("%H:%M") }}) +{%- if message.msgtype == "m.text" %} {{ message.body_text_plain }} -{% elif message.msgtype == "m.emote" %} +{%- elif message.msgtype == "m.emote" %} {{ message.body_text_plain }} -{% elif message.msgtype == "m.notice" %} +{%- elif message.msgtype == "m.notice" %} {{ message.body_text_plain }} -{% elif message.msgtype == "m.image" %} +{%- elif message.msgtype == "m.image" %} {{ message.body_text_plain }} -{% elif message.msgtype == "m.file" %} +{%- elif message.msgtype == "m.file" %} {{ message.body_text_plain }} -{% endif %} -{% endfor %} +{%- else %} +A message with unrecognised content. +{%- endif %} +{%- endif %} +{%- endfor %} View {{ room.title }} at {{ notif.link }} diff --git a/synapse/res/templates/notif_mail.html b/synapse/res/templates/notif_mail.html index a2dfeb9e9f..27d4182790 100644 --- a/synapse/res/templates/notif_mail.html +++ b/synapse/res/templates/notif_mail.html @@ -2,8 +2,8 @@ <html lang="en"> <head> <style type="text/css"> - {% include 'mail.css' without context %} - {% include "mail-%s.css" % app_name ignore missing without context %} + {%- include 'mail.css' without context %} + {%- include "mail-%s.css" % app_name ignore missing without context %} </style> </head> <body> @@ -18,21 +18,21 @@ <div class="summarytext">{{ summary_text }}</div> </td> <td class="logo"> - {% if app_name == "Riot" %} + {%- if app_name == "Riot" %} <img src="http://riot.im/img/external/riot-logo-email.png" width="83" height="83" alt="[Riot]"/> - {% elif app_name == "Vector" %} + {%- elif app_name == "Vector" %} <img src="http://matrix.org/img/vector-logo-email.png" width="64" height="83" alt="[Vector]"/> - {% elif app_name == "Element" %} + {%- elif app_name == "Element" %} <img src="https://static.element.io/images/email-logo.png" width="83" height="83" alt="[Element]"/> - {% else %} + {%- else %} <img src="http://matrix.org/img/matrix-120x51.png" width="120" height="51" alt="[matrix]"/> - {% endif %} + {%- endif %} </td> </tr> </table> - {% for room in rooms %} - {% include 'room.html' with context %} - {% endfor %} + {%- for room in rooms %} + {%- include 'room.html' with context %} + {%- endfor %} <div class="footer"> <a href="{{ unsubscribe_link }}">Unsubscribe</a> <br/> @@ -41,12 +41,12 @@ Sending email at {{ reason.now|format_ts("%c") }} due to activity in room {{ reason.room_name }} because an event was received at {{ reason.received_at|format_ts("%c") }} which is more than {{ "%.1f"|format(reason.delay_before_mail_ms / (60*1000)) }} ({{ reason.delay_before_mail_ms }}) mins ago, - {% if reason.last_sent_ts %} + {%- if reason.last_sent_ts %} and the last time we sent a mail for this room was {{ reason.last_sent_ts|format_ts("%c") }}, which is more than {{ "%.1f"|format(reason.throttle_ms / (60*1000)) }} (current throttle_ms) mins ago. - {% else %} + {%- else %} and we don't have a last time we sent a mail for this room. - {% endif %} + {%- endif %} </div> </div> </td> diff --git a/synapse/res/templates/notif_mail.txt b/synapse/res/templates/notif_mail.txt index 24843042a5..df3c253979 100644 --- a/synapse/res/templates/notif_mail.txt +++ b/synapse/res/templates/notif_mail.txt @@ -2,9 +2,9 @@ Hi {{ user_display_name }}, {{ summary_text }} -{% for room in rooms %} -{% include 'room.txt' with context %} -{% endfor %} +{%- for room in rooms %} +{%- include 'room.txt' with context %} +{%- endfor %} You can disable these notifications at {{ unsubscribe_link }} diff --git a/synapse/res/templates/room.html b/synapse/res/templates/room.html index b8525fef88..4fc6f6ac9b 100644 --- a/synapse/res/templates/room.html +++ b/synapse/res/templates/room.html @@ -1,23 +1,23 @@ <table class="room"> <tr class="room_header"> <td class="room_avatar"> - {% if room.avatar_url %} + {%- if room.avatar_url %} <img alt="" src="{{ room.avatar_url|mxc_to_http(48,48) }}" /> - {% else %} - {% if room.hash % 3 == 0 %} + {%- else %} + {%- if room.hash % 3 == 0 %} <img alt="" src="https://riot.im/img/external/avatar-1.png" /> - {% elif room.hash % 3 == 1 %} + {%- elif room.hash % 3 == 1 %} <img alt="" src="https://riot.im/img/external/avatar-2.png" /> - {% else %} + {%- else %} <img alt="" src="https://riot.im/img/external/avatar-3.png" /> - {% endif %} - {% endif %} + {%- endif %} + {%- endif %} </td> <td class="room_name" colspan="2"> {{ room.title }} </td> </tr> - {% if room.invite %} + {%- if room.invite %} <tr> <td></td> <td> @@ -25,9 +25,9 @@ </td> <td></td> </tr> - {% else %} - {% for notif in room.notifs %} - {% include 'notif.html' with context %} - {% endfor %} - {% endif %} + {%- else %} + {%- for notif in room.notifs %} + {%- include 'notif.html' with context %} + {%- endfor %} + {%- endif %} </table> diff --git a/synapse/res/templates/room.txt b/synapse/res/templates/room.txt index 84648c710e..df841e9e6f 100644 --- a/synapse/res/templates/room.txt +++ b/synapse/res/templates/room.txt @@ -1,9 +1,9 @@ {{ room.title }} -{% if room.invite %} +{%- if room.invite %} You've been invited, join at {{ room.link }} -{% else %} - {% for notif in room.notifs %} - {% include 'notif.txt' with context %} - {% endfor %} -{% endif %} +{%- else %} + {%- for notif in room.notifs %} + {%- include 'notif.txt' with context %} + {%- endfor %} +{%- endif %} diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index d7deb9300d..b82a4e978a 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -110,6 +110,8 @@ class LoginRestServlet(RestServlet): ({"type": t} for t in self.auth_handler.get_supported_login_types()) ) + flows.append({"type": LoginRestServlet.APPSERVICE_TYPE}) + return 200, {"flows": flows} def on_OPTIONS(self, request: SynapseRequest): diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index ab49d227de..2b196ded1b 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -76,14 +76,16 @@ class SQLBaseStore(metaclass=ABCMeta): """ try: - if key is None: - getattr(self, cache_name).invalidate_all() - else: - getattr(self, cache_name).invalidate(tuple(key)) + cache = getattr(self, cache_name) except AttributeError: # We probably haven't pulled in the cache in this worker, # which is fine. - pass + return + + if key is None: + cache.invalidate_all() + else: + cache.invalidate(tuple(key)) def db_to_json(db_content): diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 763722d6bc..0217e63108 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -160,7 +160,7 @@ class LoggingDatabaseConnection: self.conn.__enter__() return self - def __exit__(self, exc_type, exc_value, traceback) -> bool: + def __exit__(self, exc_type, exc_value, traceback) -> Optional[bool]: return self.conn.__exit__(exc_type, exc_value, traceback) # Proxy through any unknown lookups to the DB conn class. diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py index ec8260e906..2408432738 100644 --- a/synapse/storage/databases/main/client_ips.py +++ b/synapse/storage/databases/main/client_ips.py @@ -19,7 +19,7 @@ from typing import Dict, Optional, Tuple from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool, make_tuple_comparison_clause -from synapse.util.caches.deferred_cache import DeferredCache +from synapse.util.caches.lrucache import LruCache logger = logging.getLogger(__name__) @@ -410,8 +410,8 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore): class ClientIpStore(ClientIpWorkerStore): def __init__(self, database: DatabasePool, db_conn, hs): - self.client_ip_last_seen = DeferredCache( - name="client_ip_last_seen", keylen=4, max_entries=50000 + self.client_ip_last_seen = LruCache( + cache_name="client_ip_last_seen", keylen=4, max_size=50000 ) super().__init__(database, db_conn, hs) @@ -442,7 +442,7 @@ class ClientIpStore(ClientIpWorkerStore): if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY: return - self.client_ip_last_seen.prefill(key, now) + self.client_ip_last_seen.set(key, now) self._batch_row_update[key] = (user_agent, device_id, now) diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index e662a20d24..dfb4f87b8f 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -34,8 +34,8 @@ from synapse.storage.database import ( ) from synapse.types import Collection, JsonDict, get_verify_key_from_cross_signing_key from synapse.util import json_decoder, json_encoder -from synapse.util.caches.deferred_cache import DeferredCache from synapse.util.caches.descriptors import cached, cachedList +from synapse.util.caches.lrucache import LruCache from synapse.util.iterutils import batch_iter from synapse.util.stringutils import shortstr @@ -1005,8 +1005,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): # Map of (user_id, device_id) -> bool. If there is an entry that implies # the device exists. - self.device_id_exists_cache = DeferredCache( - name="device_id_exists", keylen=2, max_entries=10000 + self.device_id_exists_cache = LruCache( + cache_name="device_id_exists", keylen=2, max_size=10000 ) async def store_device( @@ -1052,7 +1052,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): ) if hidden: raise StoreError(400, "The device ID is in use", Codes.FORBIDDEN) - self.device_id_exists_cache.prefill(key, True) + self.device_id_exists_cache.set(key, True) return inserted except StoreError: raise diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index ba3b1769b0..87808c1483 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1051,9 +1051,7 @@ class PersistEventsStore: def prefill(): for cache_entry in to_prefill: - self.store._get_event_cache.prefill( - (cache_entry[0].event_id,), cache_entry - ) + self.store._get_event_cache.set((cache_entry[0].event_id,), cache_entry) txn.call_after(prefill) diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index ff150f0be7..6e7f16f39c 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -33,7 +33,10 @@ from synapse.api.room_versions import ( from synapse.events import EventBase, make_event_from_dict from synapse.events.utils import prune_event from synapse.logging.context import PreserveLoggingContext, current_context -from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.metrics.background_process_metrics import ( + run_as_background_process, + wrap_as_background_process, +) from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import BackfillStream from synapse.replication.tcp.streams.events import EventsStream @@ -42,8 +45,8 @@ from synapse.storage.database import DatabasePool from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator from synapse.types import Collection, get_domain_from_id -from synapse.util.caches.deferred_cache import DeferredCache from synapse.util.caches.descriptors import cached +from synapse.util.caches.lrucache import LruCache from synapse.util.iterutils import batch_iter from synapse.util.metrics import Measure @@ -137,20 +140,16 @@ class EventsWorkerStore(SQLBaseStore): db_conn, "events", "stream_ordering", step=-1 ) - if not hs.config.worker.worker_app: + if hs.config.run_background_tasks: # We periodically clean out old transaction ID mappings self._clock.looping_call( - run_as_background_process, - 5 * 60 * 1000, - "_cleanup_old_transaction_ids", - self._cleanup_old_transaction_ids, + self._cleanup_old_transaction_ids, 5 * 60 * 1000, ) - self._get_event_cache = DeferredCache( - "*getEvent*", + self._get_event_cache = LruCache( + cache_name="*getEvent*", keylen=3, - max_entries=hs.config.caches.event_cache_size, - apply_cache_factor_from_config=False, + max_size=hs.config.caches.event_cache_size, ) self._event_fetch_lock = threading.Condition() @@ -749,7 +748,7 @@ class EventsWorkerStore(SQLBaseStore): event=original_ev, redacted_event=redacted_event ) - self._get_event_cache.prefill((event_id,), cache_entry) + self._get_event_cache.set((event_id,), cache_entry) result_map[event_id] = cache_entry return result_map @@ -1375,6 +1374,7 @@ class EventsWorkerStore(SQLBaseStore): return mapping + @wrap_as_background_process("_cleanup_old_transaction_ids") async def _cleanup_old_transaction_ids(self): """Cleans out transaction id mappings older than 24hrs. """ diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py index 1681caa1f0..a6d1eb908a 100644 --- a/synapse/storage/databases/main/profile.py +++ b/synapse/storage/databases/main/profile.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional from synapse.api.errors import StoreError from synapse.storage._base import SQLBaseStore @@ -72,7 +72,7 @@ class ProfileWorkerStore(SQLBaseStore): ) async def set_profile_displayname( - self, user_localpart: str, new_displayname: str + self, user_localpart: str, new_displayname: Optional[str] ) -> None: await self.db_pool.simple_update_one( table="profiles", @@ -144,7 +144,7 @@ class ProfileWorkerStore(SQLBaseStore): async def get_remote_profile_cache_entries_that_expire( self, last_checked: int - ) -> Dict[str, str]: + ) -> List[Dict[str, str]]: """Get all users who haven't been checked since `last_checked` """ diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py index df8609b97b..7997242d90 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py @@ -303,7 +303,7 @@ class PusherStore(PusherWorkerStore): lock=False, ) - user_has_pusher = self.get_if_user_has_pusher.cache.get( + user_has_pusher = self.get_if_user_has_pusher.cache.get_immediate( (user_id,), None, update_metrics=False ) diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 5cdf16521c..ca7917c989 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -25,7 +25,6 @@ from synapse.storage.database import DatabasePool from synapse.storage.util.id_generators import StreamIdGenerator from synapse.types import JsonDict from synapse.util import json_encoder -from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -413,18 +412,10 @@ class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta): if receipt_type != "m.read": return - # Returns either an ObservableDeferred or the raw result - res = self.get_users_with_read_receipts_in_room.cache.get( + res = self.get_users_with_read_receipts_in_room.cache.get_immediate( room_id, None, update_metrics=False ) - # first handle the ObservableDeferred case - if isinstance(res, ObservableDeferred): - if res.has_called(): - res = res.get_result() - else: - res = None - if res and user_id in res: # We'd only be adding to the set, so no point invalidating if the # user is already there diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 20fcdaa529..01d9dbb36f 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -20,7 +20,10 @@ from synapse.api.constants import EventTypes, Membership from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.metrics import LaterGauge -from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.metrics.background_process_metrics import ( + run_as_background_process, + wrap_as_background_process, +) from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import DatabasePool from synapse.storage.databases.main.events_worker import EventsWorkerStore @@ -67,16 +70,10 @@ class RoomMemberWorkerStore(EventsWorkerStore): ): self._known_servers_count = 1 self.hs.get_clock().looping_call( - run_as_background_process, - 60 * 1000, - "_count_known_servers", - self._count_known_servers, + self._count_known_servers, 60 * 1000, ) self.hs.get_clock().call_later( - 1000, - run_as_background_process, - "_count_known_servers", - self._count_known_servers, + 1000, self._count_known_servers, ) LaterGauge( "synapse_federation_known_servers", @@ -85,6 +82,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): lambda: self._known_servers_count, ) + @wrap_as_background_process("_count_known_servers") async def _count_known_servers(self): """ Count the servers that this server knows about. @@ -531,7 +529,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): # If we do then we can reuse that result and simply update it with # any membership changes in `delta_ids` if context.prev_group and context.delta_ids: - prev_res = self._get_joined_users_from_context.cache.get( + prev_res = self._get_joined_users_from_context.cache.get_immediate( (room_id, context.prev_group), None ) if prev_res and isinstance(prev_res, dict): diff --git a/synapse/storage/databases/main/schema/delta/59/19as_device_stream.sql b/synapse/storage/databases/main/schema/delta/58/21as_device_stream.sql index 20f5a95a24..7b84a207fd 100644 --- a/synapse/storage/databases/main/schema/delta/59/19as_device_stream.sql +++ b/synapse/storage/databases/main/schema/delta/58/21as_device_stream.sql @@ -13,6 +13,5 @@ * limitations under the License. */ -ALTER TABLE application_services_state - ADD COLUMN read_receipt_stream_id INT, - ADD COLUMN presence_stream_id INT; \ No newline at end of file +ALTER TABLE application_services_state ADD COLUMN read_receipt_stream_id INT; +ALTER TABLE application_services_state ADD COLUMN presence_stream_id INT; \ No newline at end of file diff --git a/synapse/storage/databases/main/schema/delta/58/21drop_device_max_stream_id.sql b/synapse/storage/databases/main/schema/delta/58/21drop_device_max_stream_id.sql new file mode 100644 index 0000000000..01ea6eddcf --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/21drop_device_max_stream_id.sql @@ -0,0 +1 @@ +DROP TABLE device_max_stream_id; diff --git a/synapse/storage/types.py b/synapse/storage/types.py index 970bb1b9da..9cadcba18f 100644 --- a/synapse/storage/types.py +++ b/synapse/storage/types.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Iterable, Iterator, List, Tuple +from typing import Any, Iterable, Iterator, List, Optional, Tuple from typing_extensions import Protocol @@ -65,5 +65,5 @@ class Connection(Protocol): def __enter__(self) -> "Connection": ... - def __exit__(self, exc_type, exc_value, traceback) -> bool: + def __exit__(self, exc_type, exc_value, traceback) -> Optional[bool]: ... diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py index 8fc05be278..89f0b38535 100644 --- a/synapse/util/caches/__init__.py +++ b/synapse/util/caches/__init__.py @@ -16,7 +16,7 @@ import logging from sys import intern -from typing import Callable, Dict, Optional +from typing import Callable, Dict, Optional, Sized import attr from prometheus_client.core import Gauge @@ -92,7 +92,7 @@ class CacheMetric: def register_cache( cache_type: str, cache_name: str, - cache, + cache: Sized, collect_callback: Optional[Callable] = None, resizable: bool = True, resize_callback: Optional[Callable] = None, @@ -100,12 +100,15 @@ def register_cache( """Register a cache object for metric collection and resizing. Args: - cache_type + cache_type: a string indicating the "type" of the cache. This is used + only for deduplication so isn't too important provided it's constant. cache_name: name of the cache - cache: cache itself + cache: cache itself, which must implement __len__(), and may optionally implement + a max_size property collect_callback: If given, a function which is called during metric collection to update additional metrics. - resizable: Whether this cache supports being resized. + resizable: Whether this cache supports being resized, in which case either + resize_callback must be provided, or the cache must support set_max_size(). resize_callback: A function which can be called to resize the cache. Returns: diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py index f728cd2cf2..601305487c 100644 --- a/synapse/util/caches/deferred_cache.py +++ b/synapse/util/caches/deferred_cache.py @@ -17,14 +17,23 @@ import enum import threading -from typing import Callable, Generic, Iterable, MutableMapping, Optional, TypeVar, cast +from typing import ( + Callable, + Generic, + Iterable, + MutableMapping, + Optional, + TypeVar, + Union, + cast, +) from prometheus_client import Gauge from twisted.internet import defer +from twisted.python import failure from synapse.util.async_helpers import ObservableDeferred -from synapse.util.caches import register_cache from synapse.util.caches.lrucache import LruCache from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry @@ -34,7 +43,7 @@ cache_pending_metric = Gauge( ["name"], ) - +T = TypeVar("T") KT = TypeVar("KT") VT = TypeVar("VT") @@ -49,15 +58,12 @@ class DeferredCache(Generic[KT, VT]): """Wraps an LruCache, adding support for Deferred results. It expects that each entry added with set() will be a Deferred; likewise get() - may return an ObservableDeferred. + will return a Deferred. """ __slots__ = ( "cache", - "name", - "keylen", "thread", - "metrics", "_pending_deferred_cache", ) @@ -89,37 +95,27 @@ class DeferredCache(Generic[KT, VT]): cache_type() ) # type: MutableMapping[KT, CacheEntry] + def metrics_cb(): + cache_pending_metric.labels(name).set(len(self._pending_deferred_cache)) + # cache is used for completed results and maps to the result itself, rather than # a Deferred. self.cache = LruCache( max_size=max_entries, keylen=keylen, + cache_name=name, cache_type=cache_type, size_callback=(lambda d: len(d)) if iterable else None, - evicted_callback=self._on_evicted, + metrics_collection_callback=metrics_cb, apply_cache_factor_from_config=apply_cache_factor_from_config, - ) + ) # type: LruCache[KT, VT] - self.name = name - self.keylen = keylen self.thread = None # type: Optional[threading.Thread] - self.metrics = register_cache( - "cache", - name, - self.cache, - collect_callback=self._metrics_collection_callback, - ) @property def max_entries(self): return self.cache.max_size - def _on_evicted(self, evicted_count): - self.metrics.inc_evictions(evicted_count) - - def _metrics_collection_callback(self): - cache_pending_metric.labels(self.name).set(len(self._pending_deferred_cache)) - def check_thread(self): expected_thread = self.thread if expected_thread is None: @@ -133,62 +129,113 @@ class DeferredCache(Generic[KT, VT]): def get( self, key: KT, - default=_Sentinel.sentinel, callback: Optional[Callable[[], None]] = None, update_metrics: bool = True, - ): + ) -> defer.Deferred: """Looks the key up in the caches. + For symmetry with set(), this method does *not* follow the synapse logcontext + rules: the logcontext will not be cleared on return, and the Deferred will run + its callbacks in the sentinel context. In other words: wrap the result with + make_deferred_yieldable() before `await`ing it. + Args: - key(tuple) - default: What is returned if key is not in the caches. If not - specified then function throws KeyError instead - callback(fn): Gets called when the entry in the cache is invalidated + key: + callback: Gets called when the entry in the cache is invalidated update_metrics (bool): whether to update the cache hit rate metrics Returns: - Either an ObservableDeferred or the result itself + A Deferred which completes with the result. Note that this may later fail + if there is an ongoing set() operation which later completes with a failure. + + Raises: + KeyError if the key is not found in the cache """ callbacks = [callback] if callback else [] val = self._pending_deferred_cache.get(key, _Sentinel.sentinel) if val is not _Sentinel.sentinel: val.callbacks.update(callbacks) if update_metrics: - self.metrics.inc_hits() - return val.deferred - - val = self.cache.get(key, _Sentinel.sentinel, callbacks=callbacks) - if val is not _Sentinel.sentinel: - self.metrics.inc_hits() - return val + m = self.cache.metrics + assert m # we always have a name, so should always have metrics + m.inc_hits() + return val.deferred.observe() - if update_metrics: - self.metrics.inc_misses() - - if default is _Sentinel.sentinel: + val2 = self.cache.get( + key, _Sentinel.sentinel, callbacks=callbacks, update_metrics=update_metrics + ) + if val2 is _Sentinel.sentinel: raise KeyError() else: - return default + return defer.succeed(val2) + + def get_immediate( + self, key: KT, default: T, update_metrics: bool = True + ) -> Union[VT, T]: + """If we have a *completed* cached value, return it.""" + return self.cache.get(key, default, update_metrics=update_metrics) def set( self, key: KT, value: defer.Deferred, callback: Optional[Callable[[], None]] = None, - ) -> ObservableDeferred: + ) -> defer.Deferred: + """Adds a new entry to the cache (or updates an existing one). + + The given `value` *must* be a Deferred. + + First any existing entry for the same key is invalidated. Then a new entry + is added to the cache for the given key. + + Until the `value` completes, calls to `get()` for the key will also result in an + incomplete Deferred, which will ultimately complete with the same result as + `value`. + + If `value` completes successfully, subsequent calls to `get()` will then return + a completed deferred with the same result. If it *fails*, the cache is + invalidated and subequent calls to `get()` will raise a KeyError. + + If another call to `set()` happens before `value` completes, then (a) any + invalidation callbacks registered in the interim will be called, (b) any + `get()`s in the interim will continue to complete with the result from the + *original* `value`, (c) any future calls to `get()` will complete with the + result from the *new* `value`. + + It is expected that `value` does *not* follow the synapse logcontext rules - ie, + if it is incomplete, it runs its callbacks in the sentinel context. + + Args: + key: Key to be set + value: a deferred which will complete with a result to add to the cache + callback: An optional callback to be called when the entry is invalidated + """ if not isinstance(value, defer.Deferred): raise TypeError("not a Deferred") callbacks = [callback] if callback else [] self.check_thread() - observable = ObservableDeferred(value, consumeErrors=True) - observer = observable.observe() - entry = CacheEntry(deferred=observable, callbacks=callbacks) existing_entry = self._pending_deferred_cache.pop(key, None) if existing_entry: existing_entry.invalidate() + # XXX: why don't we invalidate the entry in `self.cache` yet? + + # we can save a whole load of effort if the deferred is ready. + if value.called: + result = value.result + if not isinstance(result, failure.Failure): + self.cache.set(key, result, callbacks) + return value + + # otherwise, we'll add an entry to the _pending_deferred_cache for now, + # and add callbacks to add it to the cache properly later. + + observable = ObservableDeferred(value, consumeErrors=True) + observer = observable.observe() + entry = CacheEntry(deferred=observable, callbacks=callbacks) + self._pending_deferred_cache[key] = entry def compare_and_pop(): @@ -232,7 +279,9 @@ class DeferredCache(Generic[KT, VT]): # _pending_deferred_cache to the real cache. # observer.addCallbacks(cb, eb) - return observable + + # we return a new Deferred which will be called before any subsequent observers. + return observable.observe() def prefill(self, key: KT, value: VT, callback: Callable[[], None] = None): callbacks = [callback] if callback else [] @@ -257,11 +306,12 @@ class DeferredCache(Generic[KT, VT]): self.check_thread() if not isinstance(key, tuple): raise TypeError("The cache key must be a tuple not %r" % (type(key),)) + key = cast(KT, key) self.cache.del_multi(key) # if we have a pending lookup for this key, remove it from the # _pending_deferred_cache, as above - entry_dict = self._pending_deferred_cache.pop(cast(KT, key), None) + entry_dict = self._pending_deferred_cache.pop(key, None) if entry_dict is not None: for entry in iterate_tree_cache_entry(entry_dict): entry.invalidate() diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 1f43886804..5d7fffee66 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -23,7 +23,6 @@ from twisted.internet import defer from synapse.logging.context import make_deferred_yieldable, preserve_fn from synapse.util import unwrapFirstError -from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches.deferred_cache import DeferredCache logger = logging.getLogger(__name__) @@ -156,7 +155,7 @@ class CacheDescriptor(_CacheDescriptorBase): keylen=self.num_args, tree=self.tree, iterable=self.iterable, - ) # type: DeferredCache[Tuple, Any] + ) # type: DeferredCache[CacheKey, Any] def get_cache_key_gen(args, kwargs): """Given some args/kwargs return a generator that resolves into @@ -202,32 +201,20 @@ class CacheDescriptor(_CacheDescriptorBase): cache_key = get_cache_key(args, kwargs) - # Add our own `cache_context` to argument list if the wrapped function - # has asked for one - if self.add_cache_context: - kwargs["cache_context"] = _CacheContext.get_instance(cache, cache_key) - try: - cached_result_d = cache.get(cache_key, callback=invalidate_callback) - - if isinstance(cached_result_d, ObservableDeferred): - observer = cached_result_d.observe() - else: - observer = defer.succeed(cached_result_d) - + ret = cache.get(cache_key, callback=invalidate_callback) except KeyError: - ret = defer.maybeDeferred(preserve_fn(self.orig), obj, *args, **kwargs) + # Add our own `cache_context` to argument list if the wrapped function + # has asked for one + if self.add_cache_context: + kwargs["cache_context"] = _CacheContext.get_instance( + cache, cache_key + ) - def onErr(f): - cache.invalidate(cache_key) - return f - - ret.addErrback(onErr) - - result_d = cache.set(cache_key, ret, callback=invalidate_callback) - observer = result_d.observe() + ret = defer.maybeDeferred(preserve_fn(self.orig), obj, *args, **kwargs) + ret = cache.set(cache_key, ret, callback=invalidate_callback) - return make_deferred_yieldable(observer) + return make_deferred_yieldable(ret) wrapped = cast(_CachedFunction, _wrapped) @@ -286,7 +273,7 @@ class CacheListDescriptor(_CacheDescriptorBase): def __get__(self, obj, objtype=None): cached_method = getattr(obj, self.cached_method_name) - cache = cached_method.cache + cache = cached_method.cache # type: DeferredCache[CacheKey, Any] num_args = cached_method.num_args @functools.wraps(self.orig) @@ -326,14 +313,11 @@ class CacheListDescriptor(_CacheDescriptorBase): for arg in list_args: try: res = cache.get(arg_to_cache_key(arg), callback=invalidate_callback) - if not isinstance(res, ObservableDeferred): - results[arg] = res - elif not res.has_succeeded(): - res = res.observe() + if not res.called: res.addCallback(update_results_dict, arg) cached_defers.append(res) else: - results[arg] = res.get_result() + results[arg] = res.result except KeyError: missing.add(arg) diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py index 8592b93689..588d2d49f2 100644 --- a/synapse/util/caches/dictionary_cache.py +++ b/synapse/util/caches/dictionary_cache.py @@ -12,15 +12,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import enum import logging import threading from collections import namedtuple +from typing import Any from synapse.util.caches.lrucache import LruCache -from . import register_cache - logger = logging.getLogger(__name__) @@ -40,24 +39,25 @@ class DictionaryEntry(namedtuple("DictionaryEntry", ("full", "known_absent", "va return len(self.value) +class _Sentinel(enum.Enum): + # defining a sentinel in this way allows mypy to correctly handle the + # type of a dictionary lookup. + sentinel = object() + + class DictionaryCache: """Caches key -> dictionary lookups, supporting caching partial dicts, i.e. fetching a subset of dictionary keys for a particular key. """ def __init__(self, name, max_entries=1000): - self.cache = LruCache(max_size=max_entries, size_callback=len) + self.cache = LruCache( + max_size=max_entries, cache_name=name, size_callback=len + ) # type: LruCache[Any, DictionaryEntry] self.name = name self.sequence = 0 self.thread = None - # caches_by_name[name] = self.cache - - class Sentinel: - __slots__ = [] - - self.sentinel = Sentinel() - self.metrics = register_cache("dictionary", name, self.cache) def check_thread(self): expected_thread = self.thread @@ -80,10 +80,8 @@ class DictionaryCache: Returns: DictionaryEntry """ - entry = self.cache.get(key, self.sentinel) - if entry is not self.sentinel: - self.metrics.inc_hits() - + entry = self.cache.get(key, _Sentinel.sentinel) + if entry is not _Sentinel.sentinel: if dict_keys is None: return DictionaryEntry( entry.full, entry.known_absent, dict(entry.value) @@ -95,7 +93,6 @@ class DictionaryCache: {k: entry.value[k] for k in dict_keys if k in entry.value}, ) - self.metrics.inc_misses() return DictionaryEntry(False, set(), {}) def invalidate(self, key): diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index 33eae2b7c4..60bb6ff642 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -15,11 +15,35 @@ import threading from functools import wraps -from typing import Callable, Optional, Type, Union +from typing import ( + Any, + Callable, + Generic, + Iterable, + Optional, + Type, + TypeVar, + Union, + cast, + overload, +) + +from typing_extensions import Literal from synapse.config import cache as cache_config +from synapse.util.caches import CacheMetric, register_cache from synapse.util.caches.treecache import TreeCache +# Function type: the type used for invalidation callbacks +FT = TypeVar("FT", bound=Callable[..., Any]) + +# Key and Value type for the cache +KT = TypeVar("KT") +VT = TypeVar("VT") + +# a general type var, distinct from either KT or VT +T = TypeVar("T") + def enumerate_leaves(node, depth): if depth == 0: @@ -41,29 +65,31 @@ class _Node: self.callbacks = callbacks -class LruCache: +class LruCache(Generic[KT, VT]): """ - Least-recently-used cache. + Least-recently-used cache, supporting prometheus metrics and invalidation callbacks. + Supports del_multi only if cache_type=TreeCache If cache_type=TreeCache, all keys must be tuples. - - Can also set callbacks on objects when getting/setting which are fired - when that key gets invalidated/evicted. """ def __init__( self, max_size: int, + cache_name: Optional[str] = None, keylen: int = 1, cache_type: Type[Union[dict, TreeCache]] = dict, size_callback: Optional[Callable] = None, - evicted_callback: Optional[Callable] = None, + metrics_collection_callback: Optional[Callable[[], None]] = None, apply_cache_factor_from_config: bool = True, ): """ Args: max_size: The maximum amount of entries the cache can hold + cache_name: The name of this cache, for the prometheus metrics. If unset, + no metrics will be reported on this cache. + keylen: The length of the tuple used as the cache key. Ignored unless cache_type is `TreeCache`. @@ -73,9 +99,13 @@ class LruCache: size_callback (func(V) -> int | None): - evicted_callback (func(int)|None): - if not None, called on eviction with the size of the evicted - entry + metrics_collection_callback: + metrics collection callback. This is called early in the metrics + collection process, before any of the metrics registered with the + prometheus Registry are collected, so can be used to update any dynamic + metrics. + + Ignored if cache_name is None. apply_cache_factor_from_config (bool): If true, `max_size` will be multiplied by a cache factor derived from the homeserver config @@ -94,6 +124,23 @@ class LruCache: else: self.max_size = int(max_size) + # register_cache might call our "set_cache_factor" callback; there's nothing to + # do yet when we get resized. + self._on_resize = None # type: Optional[Callable[[],None]] + + if cache_name is not None: + metrics = register_cache( + "lru_cache", + cache_name, + self, + collect_callback=metrics_collection_callback, + ) # type: Optional[CacheMetric] + else: + metrics = None + + # this is exposed for access from outside this class + self.metrics = metrics + list_root = _Node(None, None, None, None) list_root.next_node = list_root list_root.prev_node = list_root @@ -105,16 +152,16 @@ class LruCache: todelete = list_root.prev_node evicted_len = delete_node(todelete) cache.pop(todelete.key, None) - if evicted_callback: - evicted_callback(evicted_len) + if metrics: + metrics.inc_evictions(evicted_len) - def synchronized(f): + def synchronized(f: FT) -> FT: @wraps(f) def inner(*args, **kwargs): with lock: return f(*args, **kwargs) - return inner + return cast(FT, inner) cached_cache_len = [0] if size_callback is not None: @@ -168,18 +215,45 @@ class LruCache: node.callbacks.clear() return deleted_len + @overload + def cache_get( + key: KT, + default: Literal[None] = None, + callbacks: Iterable[Callable[[], None]] = ..., + update_metrics: bool = ..., + ) -> Optional[VT]: + ... + + @overload + def cache_get( + key: KT, + default: T, + callbacks: Iterable[Callable[[], None]] = ..., + update_metrics: bool = ..., + ) -> Union[T, VT]: + ... + @synchronized - def cache_get(key, default=None, callbacks=[]): + def cache_get( + key: KT, + default: Optional[T] = None, + callbacks: Iterable[Callable[[], None]] = [], + update_metrics: bool = True, + ): node = cache.get(key, None) if node is not None: move_node_to_front(node) node.callbacks.update(callbacks) + if update_metrics and metrics: + metrics.inc_hits() return node.value else: + if update_metrics and metrics: + metrics.inc_misses() return default @synchronized - def cache_set(key, value, callbacks=[]): + def cache_set(key: KT, value: VT, callbacks: Iterable[Callable[[], None]] = []): node = cache.get(key, None) if node is not None: # We sometimes store large objects, e.g. dicts, which cause @@ -208,7 +282,7 @@ class LruCache: evict() @synchronized - def cache_set_default(key, value): + def cache_set_default(key: KT, value: VT) -> VT: node = cache.get(key, None) if node is not None: return node.value @@ -217,8 +291,16 @@ class LruCache: evict() return value + @overload + def cache_pop(key: KT, default: Literal[None] = None) -> Optional[VT]: + ... + + @overload + def cache_pop(key: KT, default: T) -> Union[T, VT]: + ... + @synchronized - def cache_pop(key, default=None): + def cache_pop(key: KT, default: Optional[T] = None): node = cache.get(key, None) if node: delete_node(node) @@ -228,18 +310,18 @@ class LruCache: return default @synchronized - def cache_del_multi(key): + def cache_del_multi(key: KT) -> None: """ This will only work if constructed with cache_type=TreeCache """ popped = cache.pop(key) if popped is None: return - for leaf in enumerate_leaves(popped, keylen - len(key)): + for leaf in enumerate_leaves(popped, keylen - len(cast(tuple, key))): delete_node(leaf) @synchronized - def cache_clear(): + def cache_clear() -> None: list_root.next_node = list_root list_root.prev_node = list_root for node in cache.values(): @@ -250,15 +332,21 @@ class LruCache: cached_cache_len[0] = 0 @synchronized - def cache_contains(key): + def cache_contains(key: KT) -> bool: return key in cache self.sentinel = object() + + # make sure that we clear out any excess entries after we get resized. self._on_resize = evict + self.get = cache_get self.set = cache_set self.setdefault = cache_set_default self.pop = cache_pop + # `invalidate` is exposed for consistency with DeferredCache, so that it can be + # invalidated by the cache invalidation replication stream. + self.invalidate = cache_pop if cache_type is TreeCache: self.del_multi = cache_del_multi self.len = synchronized(cache_len) @@ -302,6 +390,7 @@ class LruCache: new_size = int(self._original_max_size * factor) if new_size != self.max_size: self.max_size = new_size - self._on_resize() + if self._on_resize: + self._on_resize() return True return False |