diff --git a/synapse/__init__.py b/synapse/__init__.py
index 1bd03462ac..aa9a3269c0 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -47,7 +47,7 @@ try:
except ImportError:
pass
-__version__ = "1.37.1"
+__version__ = "1.38.0rc1"
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
# We import here so that we don't have to install a bunch of deps when
diff --git a/synapse/_scripts/review_recent_signups.py b/synapse/_scripts/review_recent_signups.py
new file mode 100644
index 0000000000..01dc0c4237
--- /dev/null
+++ b/synapse/_scripts/review_recent_signups.py
@@ -0,0 +1,175 @@
+#!/usr/bin/env python
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import sys
+import time
+from datetime import datetime
+from typing import List
+
+import attr
+
+from synapse.config._base import RootConfig, find_config_files, read_config_files
+from synapse.config.database import DatabaseConfig
+from synapse.storage.database import DatabasePool, LoggingTransaction, make_conn
+from synapse.storage.engines import create_engine
+
+
+class ReviewConfig(RootConfig):
+ "A config class that just pulls out the database config"
+ config_classes = [DatabaseConfig]
+
+
+@attr.s(auto_attribs=True)
+class UserInfo:
+ user_id: str
+ creation_ts: int
+ emails: List[str] = attr.Factory(list)
+ private_rooms: List[str] = attr.Factory(list)
+ public_rooms: List[str] = attr.Factory(list)
+ ips: List[str] = attr.Factory(list)
+
+
+def get_recent_users(txn: LoggingTransaction, since_ms: int) -> List[UserInfo]:
+ """Fetches recently registered users and some info on them."""
+
+ sql = """
+ SELECT name, creation_ts FROM users
+ WHERE
+ ? <= creation_ts
+ AND deactivated = 0
+ """
+
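+    # `creation_ts` is stored in seconds, so convert the millisecond timestamp.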
+ txn.execute(sql, (since_ms / 1000,))
+
+ user_infos = [UserInfo(user_id, creation_ts) for user_id, creation_ts in txn]
+
+ for user_info in user_infos:
+ user_info.emails = DatabasePool.simple_select_onecol_txn(
+ txn,
+ table="user_threepids",
+ keyvalues={"user_id": user_info.user_id, "medium": "email"},
+ retcol="address",
+ )
+
+ sql = """
+ SELECT room_id, canonical_alias, name, join_rules
+ FROM local_current_membership
+ INNER JOIN room_stats_state USING (room_id)
+ WHERE user_id = ? AND membership = 'join'
+ """
+
+ txn.execute(sql, (user_info.user_id,))
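+        # Bucket each joined room as public or private based on its join rules,
+        # preferring the canonical alias or name over the bare room ID.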
+ for room_id, canonical_alias, name, join_rules in txn:
+ if join_rules == "public":
+ user_info.public_rooms.append(canonical_alias or name or room_id)
+ else:
+ user_info.private_rooms.append(canonical_alias or name or room_id)
+
+ user_info.ips = DatabasePool.simple_select_onecol_txn(
+ txn,
+ table="user_ips",
+ keyvalues={"user_id": user_info.user_id},
+ retcol="ip",
+ )
+
+ return user_infos
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "-c",
+ "--config-path",
+ action="append",
+ metavar="CONFIG_FILE",
+ help="The config files for Synapse.",
+ required=True,
+ )
+ parser.add_argument(
+ "-s",
+ "--since",
+ metavar="duration",
+ help="Specify how far back to review user registrations for, defaults to 7d (i.e. 7 days).",
+ default="7d",
+ )
+ parser.add_argument(
+ "-e",
+ "--exclude-emails",
+ action="store_true",
+ help="Exclude users that have validated email addresses",
+ )
+ parser.add_argument(
+ "-u",
+ "--only-users",
+ action="store_true",
+ help="Only print user IDs that match.",
+ )
+
+ config = ReviewConfig()
+
+ config_args = parser.parse_args(sys.argv[1:])
+ config_files = find_config_files(search_paths=config_args.config_path)
+ config_dict = read_config_files(config_files)
+ config.parse_config_dict(
+ config_dict,
+ )
+
+ since_ms = time.time() * 1000 - config.parse_duration(config_args.since)
+ exclude_users_with_email = config_args.exclude_emails
+ include_context = not config_args.only_users
+
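+    # Pick the database connection config that hosts the "main" data store.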
+ for database_config in config.database.databases:
+ if "main" in database_config.databases:
+ break
+
+ engine = create_engine(database_config.config)
+
+ with make_conn(database_config, engine, "review_recent_signups") as db_conn:
+ user_infos = get_recent_users(db_conn.cursor(), since_ms)
+
+ for user_info in user_infos:
+ if exclude_users_with_email and user_info.emails:
+ continue
+
+ if include_context:
+ print_public_rooms = ""
+ if user_info.public_rooms:
+ print_public_rooms = "(" + ", ".join(user_info.public_rooms[:3])
+
+ if len(user_info.public_rooms) > 3:
+ print_public_rooms += ", ..."
+
+ print_public_rooms += ")"
+
+ print("# Created:", datetime.fromtimestamp(user_info.creation_ts))
+ print("# Email:", ", ".join(user_info.emails) or "None")
+ print("# IPs:", ", ".join(user_info.ips))
+ print(
+ "# Number joined public rooms:",
+ len(user_info.public_rooms),
+ print_public_rooms,
+ )
+ print("# Number joined private rooms:", len(user_info.private_rooms))
+ print("#")
+
+ print(user_info.user_id)
+
+ if include_context:
+ print()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index f8b068e563..307f5f9a94 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Optional, Tuple
import pymacaroons
from netaddr import IPAddress
@@ -28,10 +28,8 @@ from synapse.api.errors import (
InvalidClientTokenError,
MissingClientTokenError,
)
-from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.appservice import ApplicationService
from synapse.events import EventBase
-from synapse.events.builder import EventBuilder
from synapse.http import get_request_user_agent
from synapse.http.site import SynapseRequest
from synapse.logging import opentracing as opentracing
@@ -39,7 +37,6 @@ from synapse.storage.databases.main.registration import TokenLookupResult
from synapse.types import Requester, StateMap, UserID, create_requester
from synapse.util.caches.lrucache import LruCache
from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
-from synapse.util.metrics import Measure
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -47,15 +44,6 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-AuthEventTypes = (
- EventTypes.Create,
- EventTypes.Member,
- EventTypes.PowerLevels,
- EventTypes.JoinRules,
- EventTypes.RoomHistoryVisibility,
- EventTypes.ThirdPartyInvite,
-)
-
# guests always get this device id.
GUEST_DEVICE_ID = "guest_device"
@@ -66,9 +54,7 @@ class _InvalidMacaroonException(Exception):
class Auth:
"""
- FIXME: This class contains a mix of functions for authenticating users
- of our client-server API and authenticating events added to room graphs.
- The latter should be moved to synapse.handlers.event_auth.EventAuthHandler.
+ This class contains functions for authenticating users of our client-server API.
"""
def __init__(self, hs: "HomeServer"):
@@ -90,18 +76,6 @@ class Auth:
self._macaroon_secret_key = hs.config.macaroon_secret_key
self._force_tracing_for_users = hs.config.tracing.force_tracing_for_users
- async def check_from_context(
- self, room_version: str, event, context, do_sig_check=True
- ) -> None:
- auth_event_ids = event.auth_event_ids()
- auth_events_by_id = await self.store.get_events(auth_event_ids)
- auth_events = {(e.type, e.state_key): e for e in auth_events_by_id.values()}
-
- room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
- event_auth.check(
- room_version_obj, event, auth_events=auth_events, do_sig_check=do_sig_check
- )
-
async def check_user_in_room(
self,
room_id: str,
@@ -152,13 +126,6 @@ class Auth:
raise AuthError(403, "User %s not in room %s" % (user_id, room_id))
- async def check_host_in_room(self, room_id: str, host: str) -> bool:
- with Measure(self.clock, "check_host_in_room"):
- return await self.store.is_host_joined(room_id, host)
-
- def get_public_keys(self, invite_event: EventBase) -> List[Dict[str, Any]]:
- return event_auth.get_public_keys(invite_event)
-
async def get_user_by_req(
self,
request: SynapseRequest,
@@ -489,44 +456,6 @@ class Auth:
"""
return await self.store.is_server_admin(user)
- def compute_auth_events(
- self,
- event: Union[EventBase, EventBuilder],
- current_state_ids: StateMap[str],
- for_verification: bool = False,
- ) -> List[str]:
- """Given an event and current state return the list of event IDs used
- to auth an event.
-
- If `for_verification` is False then only return auth events that
- should be added to the event's `auth_events`.
-
- Returns:
- List of event IDs.
- """
-
- if event.type == EventTypes.Create:
- return []
-
- # Currently we ignore the `for_verification` flag even though there are
- # some situations where we can drop particular auth events when adding
- # to the event's `auth_events` (e.g. joins pointing to previous joins
- # when room is publicly joinable). Dropping event IDs has the
- # advantage that the auth chain for the room grows slower, but we use
- # the auth chain in state resolution v2 to order events, which means
- # care must be taken if dropping events to ensure that it doesn't
- # introduce undesirable "state reset" behaviour.
- #
- # All of which sounds a bit tricky so we don't bother for now.
-
- auth_ids = []
- for etype, state_key in event_auth.auth_types_for_event(event):
- auth_ev_id = current_state_ids.get((etype, state_key))
- if auth_ev_id:
- auth_ids.append(auth_ev_id)
-
- return auth_ids
-
async def check_can_change_room_list(self, room_id: str, user: UserID) -> bool:
"""Determine whether the user is allowed to edit the room's entry in the
published room list.
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 8879136881..b30571fe49 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -21,7 +21,7 @@ import socket
import sys
import traceback
import warnings
-from typing import Awaitable, Callable, Iterable
+from typing import TYPE_CHECKING, Awaitable, Callable, Iterable
from cryptography.utils import CryptographyDeprecationWarning
from typing_extensions import NoReturn
@@ -41,10 +41,14 @@ from synapse.events.spamcheck import load_legacy_spam_checkers
from synapse.logging.context import PreserveLoggingContext
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.metrics.jemalloc import setup_jemalloc_stats
+from synapse.util.caches.lrucache import setup_expire_lru_cache_entries
from synapse.util.daemonize import daemonize_process
from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
# list of tuples of function, args list, kwargs dict
@@ -312,7 +316,7 @@ def refresh_certificate(hs):
logger.info("Context factories updated.")
-async def start(hs: "synapse.server.HomeServer"):
+async def start(hs: "HomeServer"):
"""
Start a Synapse server or worker.
@@ -365,6 +369,9 @@ async def start(hs: "synapse.server.HomeServer"):
load_legacy_spam_checkers(hs)
+ # If we've configured an expiry time for caches, start the background job now.
+ setup_expire_lru_cache_entries(hs)
+
# It is now safe to start your Synapse.
hs.start_listening()
hs.get_datastore().db_pool.start_profiling()
diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi
index 23ca0c83c1..06fbd1166b 100644
--- a/synapse/config/_base.pyi
+++ b/synapse/config/_base.pyi
@@ -5,6 +5,7 @@ from synapse.config import (
api,
appservice,
auth,
+ cache,
captcha,
cas,
consent,
@@ -88,6 +89,7 @@ class RootConfig:
tracer: tracer.TracerConfig
redis: redis.RedisConfig
modules: modules.ModulesConfig
+ caches: cache.CacheConfig
federation: federation.FederationConfig
config_classes: List = ...
diff --git a/synapse/config/cache.py b/synapse/config/cache.py
index 91165ee1ce..7789b40323 100644
--- a/synapse/config/cache.py
+++ b/synapse/config/cache.py
@@ -116,35 +116,41 @@ class CacheConfig(Config):
#event_cache_size: 10K
caches:
- # Controls the global cache factor, which is the default cache factor
- # for all caches if a specific factor for that cache is not otherwise
- # set.
- #
- # This can also be set by the "SYNAPSE_CACHE_FACTOR" environment
- # variable. Setting by environment variable takes priority over
- # setting through the config file.
- #
- # Defaults to 0.5, which will half the size of all caches.
- #
- #global_factor: 1.0
-
- # A dictionary of cache name to cache factor for that individual
- # cache. Overrides the global cache factor for a given cache.
- #
- # These can also be set through environment variables comprised
- # of "SYNAPSE_CACHE_FACTOR_" + the name of the cache in capital
- # letters and underscores. Setting by environment variable
- # takes priority over setting through the config file.
- # Ex. SYNAPSE_CACHE_FACTOR_GET_USERS_WHO_SHARE_ROOM_WITH_USER=2.0
- #
- # Some caches have '*' and other characters that are not
- # alphanumeric or underscores. These caches can be named with or
- # without the special characters stripped. For example, to specify
- # the cache factor for `*stateGroupCache*` via an environment
- # variable would be `SYNAPSE_CACHE_FACTOR_STATEGROUPCACHE=2.0`.
- #
- per_cache_factors:
- #get_users_who_share_room_with_user: 2.0
+ # Controls the global cache factor, which is the default cache factor
+ # for all caches if a specific factor for that cache is not otherwise
+ # set.
+ #
+ # This can also be set by the "SYNAPSE_CACHE_FACTOR" environment
+ # variable. Setting by environment variable takes priority over
+ # setting through the config file.
+ #
+  # Defaults to 0.5, which will halve the size of all caches.
+ #
+ #global_factor: 1.0
+
+ # A dictionary of cache name to cache factor for that individual
+ # cache. Overrides the global cache factor for a given cache.
+ #
+  # These can also be set through environment variables composed
+ # of "SYNAPSE_CACHE_FACTOR_" + the name of the cache in capital
+ # letters and underscores. Setting by environment variable
+ # takes priority over setting through the config file.
+ # Ex. SYNAPSE_CACHE_FACTOR_GET_USERS_WHO_SHARE_ROOM_WITH_USER=2.0
+ #
+ # Some caches have '*' and other characters that are not
+ # alphanumeric or underscores. These caches can be named with or
+ # without the special characters stripped. For example, to specify
+ # the cache factor for `*stateGroupCache*` via an environment
+ # variable would be `SYNAPSE_CACHE_FACTOR_STATEGROUPCACHE=2.0`.
+ #
+ per_cache_factors:
+ #get_users_who_share_room_with_user: 2.0
+
+ # Controls how long an entry can be in a cache without having been
+ # accessed before being evicted. Defaults to None, which means
+ # entries are never evicted based on time.
+ #
+ #expiry_time: 30m
"""
def read_config(self, config, **kwargs):
@@ -200,6 +206,12 @@ class CacheConfig(Config):
e.message # noqa: B306, DependencyException.message is a property
)
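+        # Optional time-based eviction: convert the configured duration to milliseconds.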
+ expiry_time = cache_config.get("expiry_time")
+ if expiry_time:
+ self.expiry_time_msec = self.parse_duration(expiry_time)
+ else:
+ self.expiry_time_msec = None
+
# Resize all caches (if necessary) with the new factors we've loaded
self.resize_all_caches()
diff --git a/synapse/events/builder.py b/synapse/events/builder.py
index fb48ec8541..26e3950859 100644
--- a/synapse/events/builder.py
+++ b/synapse/events/builder.py
@@ -34,7 +34,7 @@ from synapse.util import Clock
from synapse.util.stringutils import random_string
if TYPE_CHECKING:
- from synapse.api.auth import Auth
+ from synapse.handlers.event_auth import EventAuthHandler
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -66,7 +66,7 @@ class EventBuilder:
"""
_state: StateHandler
- _auth: "Auth"
+ _event_auth_handler: "EventAuthHandler"
_store: DataStore
_clock: Clock
_hostname: str
@@ -125,7 +125,9 @@ class EventBuilder:
state_ids = await self._state.get_current_state_ids(
self.room_id, prev_event_ids
)
- auth_event_ids = self._auth.compute_auth_events(self, state_ids)
+ auth_event_ids = self._event_auth_handler.compute_auth_events(
+ self, state_ids
+ )
format_version = self.room_version.event_format
if format_version == EventFormatVersions.V1:
@@ -193,7 +195,7 @@ class EventBuilderFactory:
self.store = hs.get_datastore()
self.state = hs.get_state_handler()
- self.auth = hs.get_auth()
+ self._event_auth_handler = hs.get_event_auth_handler()
def new(self, room_version: str, key_values: dict) -> EventBuilder:
"""Generate an event builder appropriate for the given room version
@@ -229,7 +231,7 @@ class EventBuilderFactory:
return EventBuilder(
store=self.store,
state=self.state,
- auth=self.auth,
+ event_auth_handler=self._event_auth_handler,
clock=self.clock,
hostname=self.hostname,
signing_key=self.signing_key,
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index e93b7577fe..bf67d0f574 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -108,9 +108,9 @@ class FederationServer(FederationBase):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
- self.auth = hs.get_auth()
self.handler = hs.get_federation_handler()
self.state = hs.get_state_handler()
+ self._event_auth_handler = hs.get_event_auth_handler()
self.device_handler = hs.get_device_handler()
@@ -148,6 +148,41 @@ class FederationServer(FederationBase):
self._room_prejoin_state_types = hs.config.api.room_prejoin_state
+ # Whether we have started handling old events in the staging area.
+ self._started_handling_of_staged_events = False
+
+ @wrap_as_background_process("_handle_old_staged_events")
+ async def _handle_old_staged_events(self) -> None:
+ """Handle old staged events by fetching all rooms that have staged
+ events and start the processing of each of those rooms.
+ """
+
+        # Get all the room IDs with staged events.
+ room_ids = await self.store.get_all_rooms_with_staged_incoming_events()
+
+ # We then shuffle them so that if there are multiple instances doing
+ # this work they're less likely to collide.
+ random.shuffle(room_ids)
+
+ for room_id in room_ids:
+ room_version = await self.store.get_room_version(room_id)
+
+            # Try to acquire the processing lock for the room; if we get it, start a
+            # background process to handle the events in the room.
+ lock = await self.store.try_acquire_lock(
+ _INBOUND_EVENT_HANDLING_LOCK_NAME, room_id
+ )
+ if lock:
+ logger.info("Handling old staged inbound events in %s", room_id)
+ self._process_incoming_pdus_in_room_inner(
+ room_id,
+ room_version,
+ lock,
+ )
+
+ # We pause a bit so that we don't start handling all rooms at once.
+ await self._clock.sleep(random.uniform(0, 0.1))
+
async def on_backfill_request(
self, origin: str, room_id: str, versions: List[str], limit: int
) -> Tuple[int, Dict[str, Any]]:
@@ -166,6 +201,12 @@ class FederationServer(FederationBase):
async def on_incoming_transaction(
self, origin: str, transaction_data: JsonDict
) -> Tuple[int, Dict[str, Any]]:
+        # If we receive a transaction we should make sure that we kick off handling
+ # any old events in the staging area.
+ if not self._started_handling_of_staged_events:
+ self._started_handling_of_staged_events = True
+ self._handle_old_staged_events()
+
# keep this as early as possible to make the calculated origin ts as
# accurate as possible.
request_time = self._clock.time_msec()
@@ -420,7 +461,7 @@ class FederationServer(FederationBase):
origin_host, _ = parse_server_name(origin)
await self.check_server_matches_acl(origin_host, room_id)
- in_room = await self.auth.check_host_in_room(room_id, origin)
+ in_room = await self._event_auth_handler.check_host_in_room(room_id, origin)
if not in_room:
raise AuthError(403, "Host not in room.")
@@ -453,7 +494,7 @@ class FederationServer(FederationBase):
origin_host, _ = parse_server_name(origin)
await self.check_server_matches_acl(origin_host, room_id)
- in_room = await self.auth.check_host_in_room(room_id, origin)
+ in_room = await self._event_auth_handler.check_host_in_room(room_id, origin)
if not in_room:
raise AuthError(403, "Host not in room.")
@@ -882,25 +923,28 @@ class FederationServer(FederationBase):
room_id: str,
room_version: RoomVersion,
lock: Lock,
- latest_origin: str,
- latest_event: EventBase,
+ latest_origin: Optional[str] = None,
+ latest_event: Optional[EventBase] = None,
) -> None:
"""Process events in the staging area for the given room.
The latest_origin and latest_event args are the latest origin and event
- received.
+ received (or None to simply pull the next event from the database).
"""
# The common path is for the event we just received be the only event in
# the room, so instead of pulling the event out of the DB and parsing
# the event we just pull out the next event ID and check if that matches.
- next_origin, next_event_id = await self.store.get_next_staged_event_id_for_room(
- room_id
- )
- if next_origin == latest_origin and next_event_id == latest_event.event_id:
- origin = latest_origin
- event = latest_event
- else:
+ if latest_event is not None and latest_origin is not None:
+ (
+ next_origin,
+ next_event_id,
+ ) = await self.store.get_next_staged_event_id_for_room(room_id)
+ if next_origin != latest_origin or next_event_id != latest_event.event_id:
+ latest_origin = None
+ latest_event = None
+
+ if latest_origin is None or latest_event is None:
next = await self.store.get_next_staged_event_for_room(
room_id, room_version
)
@@ -908,6 +952,9 @@ class FederationServer(FederationBase):
return
origin, event = next
+ else:
+ origin = latest_origin
+ event = latest_event
# We loop round until there are no more events in the room in the
# staging area, or we fail to get the lock (which means another process
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index f72ded038e..d75a8b15c3 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -62,9 +62,16 @@ class AdminHandler(BaseHandler):
if ret:
profile = await self.store.get_profileinfo(user.localpart)
threepids = await self.store.user_get_threepids(user.to_string())
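+            # Collect any external IDs (e.g. from SSO identity providers) linked to the user.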
+ external_ids = [
+ ({"auth_provider": auth_provider, "external_id": external_id})
+ for auth_provider, external_id in await self.store.get_external_ids_by_user(
+ user.to_string()
+ )
+ ]
ret["displayname"] = profile.display_name
ret["avatar_url"] = profile.avatar_url
ret["threepids"] = threepids
+ ret["external_ids"] = external_ids
return ret
async def export_user_data(self, user_id: str, writer: "ExfiltrationWriter") -> Any:
diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py
index 989996b628..41dbdfd0a1 100644
--- a/synapse/handlers/event_auth.py
+++ b/synapse/handlers/event_auth.py
@@ -11,8 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import TYPE_CHECKING, Collection, Optional
+from typing import TYPE_CHECKING, Collection, List, Optional, Union
+from synapse import event_auth
from synapse.api.constants import (
EventTypes,
JoinRules,
@@ -20,9 +21,11 @@ from synapse.api.constants import (
RestrictedJoinRuleTypes,
)
from synapse.api.errors import AuthError
-from synapse.api.room_versions import RoomVersion
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
from synapse.events import EventBase
+from synapse.events.builder import EventBuilder
from synapse.types import StateMap
+from synapse.util.metrics import Measure
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -34,8 +37,63 @@ class EventAuthHandler:
"""
def __init__(self, hs: "HomeServer"):
+ self._clock = hs.get_clock()
self._store = hs.get_datastore()
+ async def check_from_context(
+ self, room_version: str, event, context, do_sig_check=True
+ ) -> None:
+ auth_event_ids = event.auth_event_ids()
+ auth_events_by_id = await self._store.get_events(auth_event_ids)
+ auth_events = {(e.type, e.state_key): e for e in auth_events_by_id.values()}
+
+ room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
+ event_auth.check(
+ room_version_obj, event, auth_events=auth_events, do_sig_check=do_sig_check
+ )
+
+ def compute_auth_events(
+ self,
+ event: Union[EventBase, EventBuilder],
+ current_state_ids: StateMap[str],
+ for_verification: bool = False,
+ ) -> List[str]:
+ """Given an event and current state return the list of event IDs used
+ to auth an event.
+
+ If `for_verification` is False then only return auth events that
+ should be added to the event's `auth_events`.
+
+ Returns:
+ List of event IDs.
+ """
+
+ if event.type == EventTypes.Create:
+ return []
+
+ # Currently we ignore the `for_verification` flag even though there are
+ # some situations where we can drop particular auth events when adding
+ # to the event's `auth_events` (e.g. joins pointing to previous joins
+ # when room is publicly joinable). Dropping event IDs has the
+ # advantage that the auth chain for the room grows slower, but we use
+ # the auth chain in state resolution v2 to order events, which means
+ # care must be taken if dropping events to ensure that it doesn't
+ # introduce undesirable "state reset" behaviour.
+ #
+ # All of which sounds a bit tricky so we don't bother for now.
+
+ auth_ids = []
+ for etype, state_key in event_auth.auth_types_for_event(event):
+ auth_ev_id = current_state_ids.get((etype, state_key))
+ if auth_ev_id:
+ auth_ids.append(auth_ev_id)
+
+ return auth_ids
+
+ async def check_host_in_room(self, room_id: str, host: str) -> bool:
+ with Measure(self._clock, "check_host_in_room"):
+ return await self._store.is_host_joined(room_id, host)
+
async def check_restricted_join_rules(
self,
state_ids: StateMap[str],
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index d929c65131..991ec9919a 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -250,7 +250,9 @@ class FederationHandler(BaseHandler):
#
# Note that if we were never in the room then we would have already
# dropped the event, since we wouldn't know the room version.
- is_in_room = await self.auth.check_host_in_room(room_id, self.server_name)
+ is_in_room = await self._event_auth_handler.check_host_in_room(
+ room_id, self.server_name
+ )
if not is_in_room:
logger.info(
"Ignoring PDU from %s as we're not in the room",
@@ -1674,7 +1676,9 @@ class FederationHandler(BaseHandler):
room_version = await self.store.get_room_version_id(room_id)
# now check that we are *still* in the room
- is_in_room = await self.auth.check_host_in_room(room_id, self.server_name)
+ is_in_room = await self._event_auth_handler.check_host_in_room(
+ room_id, self.server_name
+ )
if not is_in_room:
logger.info(
"Got /make_join request for room %s we are no longer in",
@@ -1705,7 +1709,7 @@ class FederationHandler(BaseHandler):
# The remote hasn't signed it yet, obviously. We'll do the full checks
# when we get the event back in `on_send_join_request`
- await self.auth.check_from_context(
+ await self._event_auth_handler.check_from_context(
room_version, event, context, do_sig_check=False
)
@@ -1877,7 +1881,7 @@ class FederationHandler(BaseHandler):
try:
# The remote hasn't signed it yet, obviously. We'll do the full checks
# when we get the event back in `on_send_leave_request`
- await self.auth.check_from_context(
+ await self._event_auth_handler.check_from_context(
room_version, event, context, do_sig_check=False
)
except AuthError as e:
@@ -1939,7 +1943,7 @@ class FederationHandler(BaseHandler):
try:
# The remote hasn't signed it yet, obviously. We'll do the full checks
# when we get the event back in `on_send_knock_request`
- await self.auth.check_from_context(
+ await self._event_auth_handler.check_from_context(
room_version, event, context, do_sig_check=False
)
except AuthError as e:
@@ -2111,7 +2115,7 @@ class FederationHandler(BaseHandler):
async def on_backfill_request(
self, origin: str, room_id: str, pdu_list: List[str], limit: int
) -> List[EventBase]:
- in_room = await self.auth.check_host_in_room(room_id, origin)
+ in_room = await self._event_auth_handler.check_host_in_room(room_id, origin)
if not in_room:
raise AuthError(403, "Host not in room.")
@@ -2146,7 +2150,9 @@ class FederationHandler(BaseHandler):
)
if event:
- in_room = await self.auth.check_host_in_room(event.room_id, origin)
+ in_room = await self._event_auth_handler.check_host_in_room(
+ event.room_id, origin
+ )
if not in_room:
raise AuthError(403, "Host not in room.")
@@ -2499,7 +2505,7 @@ class FederationHandler(BaseHandler):
latest_events: List[str],
limit: int,
) -> List[EventBase]:
- in_room = await self.auth.check_host_in_room(room_id, origin)
+ in_room = await self._event_auth_handler.check_host_in_room(room_id, origin)
if not in_room:
raise AuthError(403, "Host not in room.")
@@ -2562,7 +2568,7 @@ class FederationHandler(BaseHandler):
if not auth_events:
prev_state_ids = await context.get_prev_state_ids()
- auth_events_ids = self.auth.compute_auth_events(
+ auth_events_ids = self._event_auth_handler.compute_auth_events(
event, prev_state_ids, for_verification=True
)
auth_events_x = await self.store.get_events(auth_events_ids)
@@ -2991,7 +2997,7 @@ class FederationHandler(BaseHandler):
"state_key": target_user_id,
}
- if await self.auth.check_host_in_room(room_id, self.hs.hostname):
+ if await self._event_auth_handler.check_host_in_room(room_id, self.hs.hostname):
room_version = await self.store.get_room_version_id(room_id)
builder = self.event_builder_factory.new(room_version, event_dict)
@@ -3011,7 +3017,9 @@ class FederationHandler(BaseHandler):
event.internal_metadata.send_on_behalf_of = self.hs.hostname
try:
- await self.auth.check_from_context(room_version, event, context)
+ await self._event_auth_handler.check_from_context(
+ room_version, event, context
+ )
except AuthError as e:
logger.warning("Denying new third party invite %r because %s", event, e)
raise e
@@ -3054,7 +3062,9 @@ class FederationHandler(BaseHandler):
)
try:
- await self.auth.check_from_context(room_version, event, context)
+ await self._event_auth_handler.check_from_context(
+ room_version, event, context
+ )
except AuthError as e:
logger.warning("Denying third party invite %r because %s", event, e)
raise e
@@ -3142,7 +3152,7 @@ class FederationHandler(BaseHandler):
last_exception = None # type: Optional[Exception]
# for each public key in the 3pid invite event
- for public_key_object in self.hs.get_auth().get_public_keys(invite_event):
+ for public_key_object in event_auth.get_public_keys(invite_event):
try:
# for each sig on the third_party_invite block of the actual invite
for server, signature_block in signed["signatures"].items():
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 2a7a6e6982..3f783947bd 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -385,6 +385,7 @@ class EventCreationHandler:
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
+ self._event_auth_handler = hs.get_event_auth_handler()
self.store = hs.get_datastore()
self.storage = hs.get_storage()
self.state = hs.get_state_handler()
@@ -597,7 +598,7 @@ class EventCreationHandler:
(e.type, e.state_key): e.event_id for e in auth_events
}
# Actually strip down and use the necessary auth events
- auth_event_ids = self.auth.compute_auth_events(
+ auth_event_ids = self._event_auth_handler.compute_auth_events(
event=temp_event,
current_state_ids=auth_event_state_map,
for_verification=False,
@@ -1056,7 +1057,9 @@ class EventCreationHandler:
assert event.content["membership"] == Membership.LEAVE
else:
try:
- await self.auth.check_from_context(room_version, event, context)
+ await self._event_auth_handler.check_from_context(
+ room_version, event, context
+ )
except AuthError as err:
logger.warning("Denying new event %r because %s", event, err)
raise err
@@ -1381,7 +1384,7 @@ class EventCreationHandler:
raise AuthError(403, "Redacting server ACL events is not permitted")
prev_state_ids = await context.get_prev_state_ids()
- auth_events_ids = self.auth.compute_auth_events(
+ auth_events_ids = self._event_auth_handler.compute_auth_events(
event, prev_state_ids, for_verification=True
)
auth_events_map = await self.store.get_events(auth_events_ids)
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 835d874cee..579b1b93c5 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -83,6 +83,7 @@ class RoomCreationHandler(BaseHandler):
self.spam_checker = hs.get_spam_checker()
self.event_creation_handler = hs.get_event_creation_handler()
self.room_member_handler = hs.get_room_member_handler()
+ self._event_auth_handler = hs.get_event_auth_handler()
self.config = hs.config
# Room state based off defined presets
@@ -226,7 +227,7 @@ class RoomCreationHandler(BaseHandler):
},
)
old_room_version = await self.store.get_room_version_id(old_room_id)
- await self.auth.check_from_context(
+ await self._event_auth_handler.check_from_context(
old_room_version, tombstone_event, tombstone_context
)
diff --git a/synapse/handlers/space_summary.py b/synapse/handlers/space_summary.py
index 266f369883..b585057ec3 100644
--- a/synapse/handlers/space_summary.py
+++ b/synapse/handlers/space_summary.py
@@ -472,7 +472,7 @@ class SpaceSummaryHandler:
# If this is a request over federation, check if the host is in the room or
# is in one of the spaces specified via the join rules.
elif origin:
- if await self._auth.check_host_in_room(room_id, origin):
+ if await self._event_auth_handler.check_host_in_room(room_id, origin):
return True
# Alternately, if the host has a user in any of the spaces specified
@@ -485,7 +485,9 @@ class SpaceSummaryHandler:
await self._event_auth_handler.get_rooms_that_allow_join(state_ids)
)
for space_id in allowed_rooms:
- if await self._auth.check_host_in_room(space_id, origin):
+ if await self._event_auth_handler.check_host_in_room(
+ space_id, origin
+ ):
return True
# otherwise, check if the room is peekable
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 350646f458..669ea462e2 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -104,7 +104,7 @@ class BulkPushRuleEvaluator:
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.store = hs.get_datastore()
- self.auth = hs.get_auth()
+ self._event_auth_handler = hs.get_event_auth_handler()
# Used by `RulesForRoom` to ensure only one thing mutates the cache at a
# time. Keyed off room_id.
@@ -172,7 +172,7 @@ class BulkPushRuleEvaluator:
# not having a power level event is an extreme edge case
auth_events = {POWER_KEY: await self.store.get_event(pl_event_id)}
else:
- auth_events_ids = self.auth.compute_auth_events(
+ auth_events_ids = self._event_auth_handler.compute_auth_events(
event, prev_state_ids, for_verification=False
)
auth_events_dict = await self.store.get_events(auth_events_ids)
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index d470cdacde..33c42cf95a 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -111,7 +111,7 @@ def make_conn(
db_config: DatabaseConnectionConfig,
engine: BaseDatabaseEngine,
default_txn_name: str,
-) -> Connection:
+) -> "LoggingDatabaseConnection":
"""Make a new connection to the database and return it.
Returns:
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index f2d27ee893..c4474df975 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -16,6 +16,8 @@ import logging
from queue import Empty, PriorityQueue
from typing import Collection, Dict, Iterable, List, Optional, Set, Tuple
+from prometheus_client import Gauge
+
from synapse.api.constants import MAX_DEPTH
from synapse.api.errors import StoreError
from synapse.api.room_versions import RoomVersion
@@ -32,6 +34,16 @@ from synapse.util.caches.descriptors import cached
from synapse.util.caches.lrucache import LruCache
from synapse.util.iterutils import batch_iter
+oldest_pdu_in_federation_staging = Gauge(
+ "synapse_federation_server_oldest_inbound_pdu_in_staging",
+ "The age in seconds since we received the oldest pdu in the federation staging area",
+)
+
+number_pdus_in_federation_queue = Gauge(
+ "synapse_federation_server_number_inbound_pdu_in_staging",
+ "The total number of events in the inbound federation staging",
+)
+
logger = logging.getLogger(__name__)
@@ -54,6 +66,8 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
500000, "_event_auth_cache", size_callback=len
) # type: LruCache[str, List[Tuple[str, int]]]
+ self._clock.looping_call(self._get_stats_for_federation_staging, 30 * 1000)
+
async def get_auth_chain(
self, room_id: str, event_ids: Collection[str], include_given: bool = False
) -> List[EventBase]:
@@ -1193,6 +1207,40 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
return origin, event
+ async def get_all_rooms_with_staged_incoming_events(self) -> List[str]:
+ """Get the room IDs of all events currently staged."""
+ return await self.db_pool.simple_select_onecol(
+ table="federation_inbound_events_staging",
+ keyvalues={},
+ retcol="DISTINCT room_id",
+ desc="get_all_rooms_with_staged_incoming_events",
+ )
+
+ @wrap_as_background_process("_get_stats_for_federation_staging")
+ async def _get_stats_for_federation_staging(self):
+ """Update the prometheus metrics for the inbound federation staging area."""
+
+ def _get_stats_for_federation_staging_txn(txn):
+ txn.execute(
+ "SELECT coalesce(count(*), 0) FROM federation_inbound_events_staging"
+ )
+ (count,) = txn.fetchone()
+
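+            # Also fetch the received_ts of the oldest staged event, or 0 if none are staged.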
+ txn.execute(
+ "SELECT coalesce(min(received_ts), 0) FROM federation_inbound_events_staging"
+ )
+
+ (age,) = txn.fetchone()
+
+ return count, age
+
+ count, age = await self.db_pool.runInteraction(
+ "_get_stats_for_federation_staging", _get_stats_for_federation_staging_txn
+ )
+
+ number_pdus_in_federation_queue.set(count)
+ oldest_pdu_in_federation_staging.set(age)
+
class EventFederationStore(EventFederationWorkerStore):
"""Responsible for storing and serving up the various graphs associated
diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py
index 9b4e95e134..ba7075caa5 100644
--- a/synapse/storage/databases/main/profile.py
+++ b/synapse/storage/databases/main/profile.py
@@ -73,20 +73,20 @@ class ProfileWorkerStore(SQLBaseStore):
async def set_profile_displayname(
self, user_localpart: str, new_displayname: Optional[str]
) -> None:
- await self.db_pool.simple_update_one(
+ await self.db_pool.simple_upsert(
table="profiles",
keyvalues={"user_id": user_localpart},
- updatevalues={"displayname": new_displayname},
+ values={"displayname": new_displayname},
desc="set_profile_displayname",
)
async def set_profile_avatar_url(
self, user_localpart: str, new_avatar_url: Optional[str]
) -> None:
- await self.db_pool.simple_update_one(
+ await self.db_pool.simple_upsert(
table="profiles",
keyvalues={"user_id": user_localpart},
- updatevalues={"avatar_url": new_avatar_url},
+ values={"avatar_url": new_avatar_url},
desc="set_profile_avatar_url",
)
diff --git a/synapse/storage/schema/main/delta/60/01recreate_stream_ordering.sql.postgres b/synapse/storage/schema/main/delta/60/01recreate_stream_ordering.sql.postgres
index b5fb763ddd..0edc9fe7a2 100644
--- a/synapse/storage/schema/main/delta/60/01recreate_stream_ordering.sql.postgres
+++ b/synapse/storage/schema/main/delta/60/01recreate_stream_ordering.sql.postgres
@@ -42,4 +42,4 @@ INSERT INTO background_updates (ordering, update_name, progress_json, depends_on
-- ... and another to do the switcheroo
INSERT INTO background_updates (ordering, update_name, progress_json, depends_on) VALUES
- (6003, 'replace_stream_ordering_column', '{}', 'index_stream_ordering2_ts');
+ (6001, 'replace_stream_ordering_column', '{}', 'index_stream_ordering2_ts');
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index d89e9d9b1d..4b9d0433ff 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -12,9 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
import threading
+import weakref
from functools import wraps
from typing import (
+ TYPE_CHECKING,
Any,
Callable,
Collection,
@@ -31,10 +34,19 @@ from typing import (
from typing_extensions import Literal
+from twisted.internet import reactor
+
from synapse.config import cache as cache_config
-from synapse.util import caches
+from synapse.metrics.background_process_metrics import wrap_as_background_process
+from synapse.util import Clock, caches
from synapse.util.caches import CacheMetric, register_cache
from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
+from synapse.util.linked_list import ListNode
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
try:
from pympler.asizeof import Asizer
@@ -82,19 +94,126 @@ def enumerate_leaves(node, depth):
yield m
+P = TypeVar("P")
+
+
+class _TimedListNode(ListNode[P]):
+ """A `ListNode` that tracks last access time."""
+
+ __slots__ = ["last_access_ts_secs"]
+
+ def update_last_access(self, clock: Clock):
+ self.last_access_ts_secs = int(clock.time())
+
+
+# Whether to insert new cache entries into the global list. We only add to it if
+# time-based eviction is enabled.
+USE_GLOBAL_LIST = False
+
+# A linked list of all cache entries, allowing efficient time based eviction.
+GLOBAL_ROOT = ListNode["_Node"].create_root_node()
+
+
+@wrap_as_background_process("LruCache._expire_old_entries")
+async def _expire_old_entries(clock: Clock, expiry_seconds: int):
+ """Walks the global cache list to find cache entries that haven't been
+ accessed in the given number of seconds.
+ """
+
+ now = int(clock.time())
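+    # Start at the tail of the global list, i.e. the least recently accessed entry.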
+ node = GLOBAL_ROOT.prev_node
+ assert node is not None
+
+ i = 0
+
+ logger.debug("Searching for stale caches")
+
+ while node is not GLOBAL_ROOT:
+ # Only the root node isn't a `_TimedListNode`.
+ assert isinstance(node, _TimedListNode)
+
+ if node.last_access_ts_secs > now - expiry_seconds:
+ break
+
+ cache_entry = node.get_cache_entry()
+ next_node = node.prev_node
+
+ # The node should always have a reference to a cache entry and a valid
+ # `prev_node`, as we only drop them when we remove the node from the
+ # list.
+ assert next_node is not None
+ assert cache_entry is not None
+ cache_entry.drop_from_cache()
+
+ # If we do lots of work at once we yield to allow other stuff to happen.
+ if (i + 1) % 10000 == 0:
+ logger.debug("Waiting during drop")
+ await clock.sleep(0)
+ logger.debug("Waking during drop")
+
+ node = next_node
+
+ # If we've yielded then our current node may have been evicted, so we
+        # need to check that it's still valid.
+ if node.prev_node is None:
+ break
+
+ i += 1
+
+ logger.info("Dropped %d items from caches", i)
+
+
+def setup_expire_lru_cache_entries(hs: "HomeServer"):
+ """Start a background job that expires all cache entries if they have not
+ been accessed for the given number of seconds.
+ """
+ if not hs.config.caches.expiry_time_msec:
+ return
+
+ logger.info(
+ "Expiring LRU caches after %d seconds", hs.config.caches.expiry_time_msec / 1000
+ )
+
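+    # Flip the module-level flag so that new cache nodes are added to the global list.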
+ global USE_GLOBAL_LIST
+ USE_GLOBAL_LIST = True
+
+ clock = hs.get_clock()
+ clock.looping_call(
+ _expire_old_entries, 30 * 1000, clock, hs.config.caches.expiry_time_msec / 1000
+ )
+
+
class _Node:
- __slots__ = ["prev_node", "next_node", "key", "value", "callbacks", "memory"]
+ __slots__ = [
+ "_list_node",
+ "_global_list_node",
+ "_cache",
+ "key",
+ "value",
+ "callbacks",
+ "memory",
+ ]
def __init__(
self,
- prev_node,
- next_node,
+ root: "ListNode[_Node]",
key,
value,
+ cache: "weakref.ReferenceType[LruCache]",
+ clock: Clock,
callbacks: Collection[Callable[[], None]] = (),
):
- self.prev_node = prev_node
- self.next_node = next_node
+ self._list_node = ListNode.insert_after(self, root)
+ self._global_list_node = None
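+        # Nodes only join the global list when time-based eviction is enabled.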
+ if USE_GLOBAL_LIST:
+ self._global_list_node = _TimedListNode.insert_after(self, GLOBAL_ROOT)
+ self._global_list_node.update_last_access(clock)
+
+ # We store a weak reference to the cache object so that this _Node can
+ # remove itself from the cache. If the cache is dropped we ensure we
+        # remove our entries from the lists.
+ self._cache = cache
+
self.key = key
self.value = value
@@ -116,11 +235,16 @@ class _Node:
self.memory = (
_get_size_of(key)
+ _get_size_of(value)
+ + _get_size_of(self._list_node, recurse=False)
+ _get_size_of(self.callbacks, recurse=False)
+ _get_size_of(self, recurse=False)
)
self.memory += _get_size_of(self.memory, recurse=False)
+ if self._global_list_node:
+ self.memory += _get_size_of(self._global_list_node, recurse=False)
+ self.memory += _get_size_of(self._global_list_node.last_access_ts_secs)
+
def add_callbacks(self, callbacks: Collection[Callable[[], None]]) -> None:
"""Add to stored list of callbacks, removing duplicates."""
@@ -147,6 +271,32 @@ class _Node:
self.callbacks = None
+ def drop_from_cache(self) -> None:
+ """Drop this node from the cache.
+
+ Ensures that the entry gets removed from the cache and that we get
+ removed from all lists.
+ """
+ cache = self._cache()
+ if not cache or not cache.pop(self.key, None):
+ # `cache.pop` should call `drop_from_lists()`, unless this Node had
+ # already been removed from the cache.
+ self.drop_from_lists()
+
+ def drop_from_lists(self) -> None:
+ """Remove this node from the cache lists."""
+ self._list_node.remove_from_list()
+
+ if self._global_list_node:
+ self._global_list_node.remove_from_list()
+
+ def move_to_front(self, clock: Clock, cache_list_root: ListNode) -> None:
+ """Moves this node to the front of all the lists its in."""
+ self._list_node.move_after(cache_list_root)
+ if self._global_list_node:
+ self._global_list_node.move_after(GLOBAL_ROOT)
+ self._global_list_node.update_last_access(clock)
+
class LruCache(Generic[KT, VT]):
"""
@@ -163,6 +313,7 @@ class LruCache(Generic[KT, VT]):
size_callback: Optional[Callable] = None,
metrics_collection_callback: Optional[Callable[[], None]] = None,
apply_cache_factor_from_config: bool = True,
+ clock: Optional[Clock] = None,
):
"""
Args:
@@ -188,6 +339,13 @@ class LruCache(Generic[KT, VT]):
apply_cache_factor_from_config (bool): If true, `max_size` will be
multiplied by a cache factor derived from the homeserver config
"""
+ # Default `clock` to something sensible. Note that we rename it to
+        # `real_clock` so that mypy doesn't think it's still `Optional`.
+ if clock is None:
+ real_clock = Clock(reactor)
+ else:
+ real_clock = clock
+
cache = cache_type()
self.cache = cache # Used for introspection.
self.apply_cache_factor_from_config = apply_cache_factor_from_config
@@ -219,17 +377,31 @@ class LruCache(Generic[KT, VT]):
# this is exposed for access from outside this class
self.metrics = metrics
- list_root = _Node(None, None, None, None)
- list_root.next_node = list_root
- list_root.prev_node = list_root
+ # We create a single weakref to self here so that we don't need to keep
+ # creating more each time we create a `_Node`.
+ weak_ref_to_self = weakref.ref(self)
+
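+        # Root of this cache's own list of entries, kept in most-recently-used order.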
+ list_root = ListNode[_Node].create_root_node()
lock = threading.Lock()
def evict():
while cache_len() > self.max_size:
+ # Get the last node in the list (i.e. the oldest node).
todelete = list_root.prev_node
- evicted_len = delete_node(todelete)
- cache.pop(todelete.key, None)
+
+ # The list root should always have a valid `prev_node` if the
+ # cache is not empty.
+ assert todelete is not None
+
+ # The node should always have a reference to a cache entry, as
+ # we only drop the cache entry when we remove the node from the
+ # list.
+ node = todelete.get_cache_entry()
+ assert node is not None
+
+ evicted_len = delete_node(node)
+ cache.pop(node.key, None)
if metrics:
metrics.inc_evictions(evicted_len)
@@ -255,11 +427,7 @@ class LruCache(Generic[KT, VT]):
self.len = synchronized(cache_len)
def add_node(key, value, callbacks: Collection[Callable[[], None]] = ()):
- prev_node = list_root
- next_node = prev_node.next_node
- node = _Node(prev_node, next_node, key, value, callbacks)
- prev_node.next_node = node
- next_node.prev_node = node
+ node = _Node(list_root, key, value, weak_ref_to_self, real_clock, callbacks)
cache[key] = node
if size_callback:
@@ -268,23 +436,11 @@ class LruCache(Generic[KT, VT]):
if caches.TRACK_MEMORY_USAGE and metrics:
metrics.inc_memory_usage(node.memory)
- def move_node_to_front(node):
- prev_node = node.prev_node
- next_node = node.next_node
- prev_node.next_node = next_node
- next_node.prev_node = prev_node
- prev_node = list_root
- next_node = prev_node.next_node
- node.prev_node = prev_node
- node.next_node = next_node
- prev_node.next_node = node
- next_node.prev_node = node
-
- def delete_node(node):
- prev_node = node.prev_node
- next_node = node.next_node
- prev_node.next_node = next_node
- next_node.prev_node = prev_node
+ def move_node_to_front(node: _Node):
+ node.move_to_front(real_clock, list_root)
+
+ def delete_node(node: _Node) -> int:
+ node.drop_from_lists()
deleted_len = 1
if size_callback:
@@ -411,10 +567,13 @@ class LruCache(Generic[KT, VT]):
@synchronized
def cache_clear() -> None:
- list_root.next_node = list_root
- list_root.prev_node = list_root
for node in cache.values():
node.run_and_clear_callbacks()
+ node.drop_from_lists()
+
+ assert list_root.next_node == list_root
+ assert list_root.prev_node == list_root
+
cache.clear()
if size_callback:
cached_cache_len[0] = 0
@@ -484,3 +643,11 @@ class LruCache(Generic[KT, VT]):
self._on_resize()
return True
return False
+
+ def __del__(self) -> None:
+ # We're about to be deleted, so we make sure to clear up all the nodes
+ # and run callbacks, etc.
+ #
+ # This happens e.g. in the sync code where we have an expiring cache of
+ # lru caches.
+ self.clear()
diff --git a/synapse/util/linked_list.py b/synapse/util/linked_list.py
new file mode 100644
index 0000000000..a456b136f0
--- /dev/null
+++ b/synapse/util/linked_list.py
@@ -0,0 +1,150 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""A circular doubly linked list implementation.
+"""
+
+import threading
+from typing import Generic, Optional, Type, TypeVar
+
+P = TypeVar("P")
+LN = TypeVar("LN", bound="ListNode")
+
+
+class ListNode(Generic[P]):
+ """A node in a circular doubly linked list, with an (optional) reference to
+ a cache entry.
+
+ The reference should only be `None` for the root node or if the node has
+ been removed from the list.
+ """
+
+ # A lock to protect mutating the list prev/next pointers.
+ _LOCK = threading.Lock()
+
+ # We don't use attrs here as in py3.6 you can't have `attr.s(slots=True)`
+    # and inherit from `Generic` for some reason.
+ __slots__ = [
+ "cache_entry",
+ "prev_node",
+ "next_node",
+ ]
+
+ def __init__(self, cache_entry: Optional[P] = None) -> None:
+ self.cache_entry = cache_entry
+ self.prev_node: Optional[ListNode[P]] = None
+ self.next_node: Optional[ListNode[P]] = None
+
+ @classmethod
+ def create_root_node(cls: Type["ListNode[P]"]) -> "ListNode[P]":
+ """Create a new linked list by creating a "root" node, which is a node
+ that has prev_node/next_node pointing to itself and no associated cache
+ entry.
+ """
+ root = cls()
+ root.prev_node = root
+ root.next_node = root
+ return root
+
+ @classmethod
+ def insert_after(
+ cls: Type[LN],
+ cache_entry: P,
+ node: "ListNode[P]",
+ ) -> LN:
+ """Create a new list node that is placed after the given node.
+
+ Args:
+ cache_entry: The associated cache entry.
+ node: The existing node in the list to insert the new entry after.
+ """
+ new_node = cls(cache_entry)
+ with cls._LOCK:
+ new_node._refs_insert_after(node)
+ return new_node
+
+ def remove_from_list(self):
+ """Remove this node from the list."""
+ with self._LOCK:
+ self._refs_remove_node_from_list()
+
+ # We drop the reference to the cache entry to break the reference cycle
+ # between the list node and cache entry, allowing the two to be dropped
+ # immediately rather than at the next GC.
+ self.cache_entry = None
+
+ def move_after(self, node: "ListNode"):
+ """Move this node from its current location in the list to after the
+ given node.
+ """
+ with self._LOCK:
+            # We assert that both this node and the target node are still "alive".
+ assert self.prev_node
+ assert self.next_node
+ assert node.prev_node
+ assert node.next_node
+
+ assert self is not node
+
+ # Remove self from the list
+ self._refs_remove_node_from_list()
+
+ # Insert self back into the list, after target node
+ self._refs_insert_after(node)
+
+ def _refs_remove_node_from_list(self):
+ """Internal method to *just* remove the node from the list, without
+ e.g. clearing out the cache entry.
+ """
+ if self.prev_node is None or self.next_node is None:
+ # We've already been removed from the list.
+ return
+
+ prev_node = self.prev_node
+ next_node = self.next_node
+
+ prev_node.next_node = next_node
+ next_node.prev_node = prev_node
+
+ # We set these to None so that we don't get circular references,
+ # allowing us to be dropped without having to go via the GC.
+ self.prev_node = None
+ self.next_node = None
+
+ def _refs_insert_after(self, node: "ListNode"):
+ """Internal method to insert the node after the given node."""
+
+ # This method should only be called when we're not already in the list.
+ assert self.prev_node is None
+ assert self.next_node is None
+
+ # We expect the given node to be in the list and thus have valid
+ # prev/next refs.
+ assert node.next_node
+ assert node.prev_node
+
+ prev_node = node
+ next_node = node.next_node
+
+ self.prev_node = prev_node
+ self.next_node = next_node
+
+ prev_node.next_node = self
+ next_node.prev_node = self
+
+ def get_cache_entry(self) -> Optional[P]:
+ """Get the cache entry, returns None if this is the root node (i.e.
+ cache_entry is None) or if the entry has been dropped.
+ """
+ return self.cache_entry
|