Diffstat (limited to 'synapse')
51 files changed, 809 insertions, 257 deletions
diff --git a/synapse/api/constants.py b/synapse/api/constants.py index af8d59cf87..691f8f9adf 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -98,11 +98,14 @@ class EventTypes: Retention = "m.room.retention" - Presence = "m.presence" - Dummy = "org.matrix.dummy_event" +class EduTypes: + Presence = "m.presence" + RoomKeyRequest = "m.room_key_request" + + class RejectedReason: AUTH_ERROR = "auth_error" diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py index 5d9d5a228f..c3f07bc1a3 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py @@ -14,7 +14,7 @@ # limitations under the License. from collections import OrderedDict -from typing import Any, Optional, Tuple +from typing import Hashable, Optional, Tuple from synapse.api.errors import LimitExceededError from synapse.types import Requester @@ -42,7 +42,9 @@ class Ratelimiter: # * How many times an action has occurred since a point in time # * The point in time # * The rate_hz of this particular entry. This can vary per request - self.actions = OrderedDict() # type: OrderedDict[Any, Tuple[float, int, float]] + self.actions = ( + OrderedDict() + ) # type: OrderedDict[Hashable, Tuple[float, int, float]] def can_requester_do_action( self, @@ -82,7 +84,7 @@ class Ratelimiter: def can_do_action( self, - key: Any, + key: Hashable, rate_hz: Optional[float] = None, burst_count: Optional[int] = None, update: bool = True, @@ -175,7 +177,7 @@ class Ratelimiter: def ratelimit( self, - key: Any, + key: Hashable, rate_hz: Optional[float] = None, burst_count: Optional[int] = None, update: bool = True, diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py index b4bd4d8e7a..9f99651aa2 100644 --- a/synapse/app/admin_cmd.py +++ b/synapse/app/admin_cmd.py @@ -210,7 +210,9 @@ def start(config_options): config.update_user_directory = False config.run_background_tasks = False config.start_pushers = False + config.pusher_shard_config.instances = [] config.send_federation = False + config.federation_shard_config.instances = [] synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 6526acb2f2..dc0d3eb725 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -645,9 +645,6 @@ class GenericWorkerServer(HomeServer): self.get_tcp_replication().start_replication(self) - async def remove_pusher(self, app_id, push_key, user_id): - self.get_tcp_replication().send_remove_pusher(app_id, push_key, user_id) - @cache_in_self def get_replication_data_handler(self): return GenericWorkerReplicationHandler(self) @@ -922,22 +919,6 @@ def start(config_options): # For other worker types we force this to off. config.appservice.notify_appservices = False - if config.worker_app == "synapse.app.pusher": - if config.server.start_pushers: - sys.stderr.write( - "\nThe pushers must be disabled in the main synapse process" - "\nbefore they can be run in a separate worker." - "\nPlease add ``start_pushers: false`` to the main config" - "\n" - ) - sys.exit(1) - - # Force the pushers to start since they will be disabled in the main config - config.server.start_pushers = True - else: - # For other worker types we force this to off. - config.server.start_pushers = False - if config.worker_app == "synapse.app.user_dir": if config.server.update_user_directory: sys.stderr.write( @@ -954,22 +935,6 @@ def start(config_options): # For other worker types we force this to off. 
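The `Any` to `Hashable` tightening in `Ratelimiter` above matters because keys are stored as dictionary keys and, elsewhere in this changeset, are compound tuples such as `(user ID, device ID)`. A minimal self-contained sketch of the pattern (not Synapse's actual `Ratelimiter`, which also tracks a per-entry `rate_hz` and supports non-updating checks):

```python
import time
from collections import OrderedDict
from typing import Hashable, Tuple


class SimpleRatelimiter:
    """Allow up to `burst_count` actions, forgotten at `rate_hz` per second."""

    def __init__(self, rate_hz: float, burst_count: int):
        self.rate_hz = rate_hz
        self.burst_count = burst_count
        # key -> (time of last update, actions counted since then)
        self.actions = OrderedDict()  # type: OrderedDict[Hashable, Tuple[float, int]]

    def can_do_action(self, key: Hashable) -> bool:
        now = time.monotonic()
        last_ts, count = self.actions.get(key, (now, 0))
        # Forget old actions at rate_hz per second of elapsed time.
        count = max(0, count - int((now - last_ts) * self.rate_hz))
        if count >= self.burst_count:
            return False  # over the limit: the caller should drop or 429
        self.actions[key] = (now, count + 1)
        return True


# Compound keys work because tuples of hashables are hashable:
limiter = SimpleRatelimiter(rate_hz=20, burst_count=100)
assert limiter.can_do_action(("@alice:example.com", "ADEVICEID"))
```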
config.server.update_user_directory = False - if config.worker_app == "synapse.app.federation_sender": - if config.worker.send_federation: - sys.stderr.write( - "\nThe send_federation must be disabled in the main synapse process" - "\nbefore they can be run in a separate worker." - "\nPlease add ``send_federation: false`` to the main config" - "\n" - ) - sys.exit(1) - - # Force the pushers to start since they will be disabled in the main config - config.worker.send_federation = True - else: - # For other worker types we force this to off. - config.worker.send_federation = False - synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts hs = GenericWorkerServer( diff --git a/synapse/config/_base.py b/synapse/config/_base.py index 97399eb9ba..4026966711 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -21,7 +21,7 @@ import os from collections import OrderedDict from hashlib import sha256 from textwrap import dedent -from typing import Any, Iterable, List, MutableMapping, Optional +from typing import Any, Iterable, List, MutableMapping, Optional, Union import attr import jinja2 @@ -147,7 +147,20 @@ class Config: return int(value) * size @staticmethod - def parse_duration(value): + def parse_duration(value: Union[str, int]) -> int: + """Convert a duration as a string or integer to a number of milliseconds. + + If an integer is provided it is treated as milliseconds and is unchanged. + + String durations can have a suffix of 's', 'm', 'h', 'd', 'w', or 'y'. + No suffix is treated as milliseconds. + + Args: + value: The duration to parse. + + Returns: + The number of milliseconds in the duration. + """ if isinstance(value, int): return value second = 1000 @@ -831,22 +844,23 @@ class ShardedWorkerHandlingConfig: def should_handle(self, instance_name: str, key: str) -> bool: """Whether this instance is responsible for handling the given key.""" - # If multiple instances are not defined we always return true - if not self.instances or len(self.instances) == 1: - return True + # If no instances are defined we assume some other worker is handling + # this. + if not self.instances: + return False - return self.get_instance(key) == instance_name + return self._get_instance(key) == instance_name - def get_instance(self, key: str) -> str: + def _get_instance(self, key: str) -> str: """Get the instance responsible for handling the given key. - Note: For things like federation sending the config for which instance - is sending is known only to the sender instance if there is only one. - Therefore `should_handle` should be used where possible. + Note: For federation sending and pushers the config for which instance + is sending is known only to the sender instance, so we don't expose this + method by default. """ if not self.instances: - return "master" + raise Exception("Unknown worker") if len(self.instances) == 1: return self.instances[0] @@ -863,4 +877,21 @@ class ShardedWorkerHandlingConfig: return self.instances[remainder] +@attr.s +class RoutableShardedWorkerHandlingConfig(ShardedWorkerHandlingConfig): + """A version of `ShardedWorkerHandlingConfig` that is used for config + options where all instances know which instances are responsible for the + sharded work. + """ + + def __attrs_post_init__(self): + # We require that `self.instances` is non-empty. 
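The new `parse_duration` docstring above pins the behaviour down. A standalone sketch consistent with it, following the same shape as the `parse_size` helper earlier in `Config`; the exact suffix table is inferred from the docstring:

```python
from typing import Union


def parse_duration(value: Union[str, int]) -> int:
    """Convert a duration to milliseconds, per the documented rules."""
    if isinstance(value, int):
        # Integers are already milliseconds.
        return value
    second = 1000
    sizes = {
        "s": second,
        "m": 60 * second,
        "h": 60 * 60 * second,
        "d": 24 * 60 * 60 * second,
        "w": 7 * 24 * 60 * 60 * second,
        "y": 365 * 24 * 60 * 60 * second,
    }
    size = 1
    if value and value[-1] in sizes:
        size = sizes[value[-1]]
        value = value[:-1]
    return int(value) * size


assert parse_duration("5m") == 300_000
assert parse_duration("250") == 250  # no suffix: already milliseconds
assert parse_duration(250) == 250
```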
+ if not self.instances: + raise Exception("Got empty list of instances for shard config") + + def get_instance(self, key: str) -> str: + """Get the instance responsible for handling the given key.""" + return self._get_instance(key) + + __all__ = ["Config", "RootConfig", "ShardedWorkerHandlingConfig"] diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi index 70025b5d60..db16c86f50 100644 --- a/synapse/config/_base.pyi +++ b/synapse/config/_base.pyi @@ -149,4 +149,6 @@ class ShardedWorkerHandlingConfig: instances: List[str] def __init__(self, instances: List[str]) -> None: ... def should_handle(self, instance_name: str, key: str) -> bool: ... + +class RoutableShardedWorkerHandlingConfig(ShardedWorkerHandlingConfig): def get_instance(self, key: str) -> str: ... diff --git a/synapse/config/federation.py b/synapse/config/federation.py index 9f3c57e6a1..55e4db5442 100644 --- a/synapse/config/federation.py +++ b/synapse/config/federation.py @@ -41,6 +41,10 @@ class FederationConfig(Config): ) self.federation_metrics_domains = set(federation_metrics_domains) + self.allow_profile_lookup_over_federation = config.get( + "allow_profile_lookup_over_federation", True + ) + def generate_config_section(self, config_dir_path, server_name, **kwargs): return """\ ## Federation ## @@ -66,6 +70,12 @@ class FederationConfig(Config): #federation_metrics_domains: # - matrix.org # - example.com + + # Uncomment to disable profile lookup over federation. By default, the + # Federation API allows other homeservers to obtain profile data of any user + # on this homeserver. Defaults to 'true'. + # + #allow_profile_lookup_over_federation: false """ diff --git a/synapse/config/push.py b/synapse/config/push.py index 3adbfb73e6..7831a2ef79 100644 --- a/synapse/config/push.py +++ b/synapse/config/push.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import Config, ShardedWorkerHandlingConfig +from ._base import Config class PushConfig(Config): @@ -27,9 +27,6 @@ "group_unread_count_by_room", True ) - pusher_instances = config.get("pusher_instances") or [] - self.pusher_shard_config = ShardedWorkerHandlingConfig(pusher_instances) - # There was a 'redact_content' setting, but it was mistakenly read from # the 'email' section. Check for the flag in the 'push' section, and log, # but do not honour it to avoid nasty surprises when people upgrade. diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py index def33a60ad..847d25122c 100644 --- a/synapse/config/ratelimiting.py +++ b/synapse/config/ratelimiting.py @@ -102,6 +102,16 @@ class RatelimitConfig(Config): defaults={"per_second": 0.01, "burst_count": 3}, ) + # Ratelimit cross-user key requests: + # * For local requests this is keyed by the sending device. + # * For requests received over federation this is keyed by the origin. + # + # Note that this isn't exposed in the configuration as it is obscure.
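`_get_instance` must give every process the same answer for the same key; its hashing body is elided above, down to `return self.instances[remainder]`. A sketch of that contract, assuming the sha256-mod-n assignment the surrounding code implies (the exact hash construction may differ):

```python
from hashlib import sha256
from typing import List


def get_instance(instances: List[str], key: str) -> str:
    """Deterministically map a key (e.g. a user ID) to one instance."""
    if not instances:
        raise Exception("Unknown worker")
    if len(instances) == 1:
        return instances[0]
    # Hashing makes the assignment stable across restarts and identical
    # on every process that shares the same instance list.
    dest_int = int.from_bytes(sha256(key.encode("utf8")).digest(), "little")
    remainder = dest_int % len(instances)
    return instances[remainder]


assert get_instance(["fed1", "fed2"], "@alice:example.com") == get_instance(
    ["fed1", "fed2"], "@alice:example.com"
)
```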
+ self.rc_key_requests = RateLimitConfig( + config.get("rc_key_requests", {}), + defaults={"per_second": 20, "burst_count": 100}, + ) + self.rc_3pid_validation = RateLimitConfig( config.get("rc_3pid_validation") or {}, defaults={"per_second": 0.003, "burst_count": 5}, diff --git a/synapse/config/repository.py b/synapse/config/repository.py index 52849c3256..69d9de5a43 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -206,7 +206,6 @@ class ContentRepositoryConfig(Config): def generate_config_section(self, data_dir_path, **kwargs): media_store = os.path.join(data_dir_path, "media_store") - uploads_path = os.path.join(data_dir_path, "uploads") formatted_thumbnail_sizes = "".join( THUMBNAIL_SIZE_YAML % s for s in DEFAULT_THUMBNAIL_SIZES diff --git a/synapse/config/server.py b/synapse/config/server.py index 6f3325ff81..2afca36e7d 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -263,6 +263,12 @@ class ServerConfig(Config): False, ) + # Whether to retrieve and display profile data for a user when they + # are invited to a room + self.include_profile_data_on_invite = config.get( + "include_profile_data_on_invite", True + ) + if "restrict_public_rooms_to_local_users" in config and ( "allow_public_rooms_without_auth" in config or "allow_public_rooms_over_federation" in config @@ -391,7 +397,6 @@ class ServerConfig(Config): if self.public_baseurl is not None: if self.public_baseurl[-1] != "/": self.public_baseurl += "/" - self.start_pushers = config.get("start_pushers", True) # (undocumented) option for torturing the worker-mode replication a bit, # for testing. The value defines the number of milliseconds to pause before @@ -848,6 +853,14 @@ class ServerConfig(Config): # #limit_profile_requests_to_users_who_share_rooms: true + # Uncomment to prevent a user's profile data from being retrieved and + # displayed in a room until they have joined it. By default, a user's + # profile data is included in an invite event, regardless of the values + # of the above two settings, and whether or not the users share a server. + # Defaults to 'true'. + # + #include_profile_data_on_invite: false + # If set to 'true', removes the need for authentication to access the server's # public rooms directory through the client API, meaning that anyone can # query the room directory. Defaults to 'false'. 
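Although `rc_key_requests` is deliberately left out of the sample config, it is read through the same defaults pattern as the documented ratelimit options, so it can still be overridden. A sketch of that pattern; `RateLimitConfig` here is a stand-in with the shape the calls above imply:

```python
class RateLimitConfig:
    """Stand-in for Synapse's helper of the same name (assumed shape)."""

    def __init__(self, config: dict, defaults: dict):
        self.per_second = config.get("per_second", defaults["per_second"])
        self.burst_count = config.get("burst_count", defaults["burst_count"])


parsed_yaml = {}  # e.g. a homeserver.yaml with no rc_key_requests section
rc_key_requests = RateLimitConfig(
    parsed_yaml.get("rc_key_requests", {}),
    defaults={"per_second": 20, "burst_count": 100},
)
assert rc_key_requests.per_second == 20
assert rc_key_requests.burst_count == 100
```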
diff --git a/synapse/config/user_directory.py b/synapse/config/user_directory.py index c8d19c5d6b..8d05ef173c 100644 --- a/synapse/config/user_directory.py +++ b/synapse/config/user_directory.py @@ -24,32 +24,46 @@ class UserDirectoryConfig(Config): section = "userdirectory" def read_config(self, config, **kwargs): - self.user_directory_search_enabled = True - self.user_directory_search_all_users = False - user_directory_config = config.get("user_directory", None) - if user_directory_config: - self.user_directory_search_enabled = user_directory_config.get( - "enabled", True - ) - self.user_directory_search_all_users = user_directory_config.get( - "search_all_users", False - ) + user_directory_config = config.get("user_directory") or {} + self.user_directory_search_enabled = user_directory_config.get("enabled", True) + self.user_directory_search_all_users = user_directory_config.get( + "search_all_users", False + ) + self.user_directory_search_prefer_local_users = user_directory_config.get( + "prefer_local_users", False + ) def generate_config_section(self, config_dir_path, server_name, **kwargs): return """ # User Directory configuration # - # 'enabled' defines whether users can search the user directory. If - # false then empty responses are returned to all queries. Defaults to - # true. - # - # 'search_all_users' defines whether to search all users visible to your HS - # when searching the user directory, rather than limiting to users visible - # in public rooms. Defaults to false. If you set it True, you'll have to - # rebuild the user_directory search indexes, see - # https://github.com/matrix-org/synapse/blob/master/docs/user_directory.md - # - #user_directory: - # enabled: true - # search_all_users: false + user_directory: + # Defines whether users can search the user directory. If false then + # empty responses are returned to all queries. Defaults to true. + # + # Uncomment to disable the user directory. + # + #enabled: false + + # Defines whether to search all users visible to your HS when searching + # the user directory, rather than limiting to users visible in public + # rooms. Defaults to false. + # + # If you set it true, you'll have to rebuild the user_directory search + # indexes, see: + # https://github.com/matrix-org/synapse/blob/master/docs/user_directory.md + # + # Uncomment to return search results containing all known users, even if that + # user does not share a room with the requester. + # + #search_all_users: true + + # Defines whether to prefer local users in search query results. + # If True, local users are more likely to appear above remote users + # when searching the user directory. Defaults to false. + # + # Uncomment to prefer local over remote users in user directory search + # results. + # + #prefer_local_users: true """ diff --git a/synapse/config/workers.py b/synapse/config/workers.py index 7a0ca16da8..ac92375a85 100644 --- a/synapse/config/workers.py +++ b/synapse/config/workers.py @@ -17,9 +17,28 @@ from typing import List, Union import attr -from ._base import Config, ConfigError, ShardedWorkerHandlingConfig +from ._base import ( + Config, + ConfigError, + RoutableShardedWorkerHandlingConfig, + ShardedWorkerHandlingConfig, +) from .server import ListenerConfig, parse_listener_def +_FEDERATION_SENDER_WITH_SEND_FEDERATION_ENABLED_ERROR = """ +The send_federation config option must be disabled in the main +synapse process before they can be run in a separate worker. 
+ +Please add ``send_federation: false`` to the main config +""" + +_PUSHER_WITH_START_PUSHERS_ENABLED_ERROR = """ +The start_pushers config option must be disabled in the main +synapse process before they can be run in a separate worker. + +Please add ``start_pushers: false`` to the main config +""" + def _instance_to_list_converter(obj: Union[str, List[str]]) -> List[str]: """Helper for allowing parsing a string or list of strings to a config @@ -103,6 +122,7 @@ class WorkerConfig(Config): self.worker_replication_secret = config.get("worker_replication_secret", None) self.worker_name = config.get("worker_name", self.worker_app) + self.instance_name = self.worker_name or "master" self.worker_main_http_uri = config.get("worker_main_http_uri", None) @@ -118,12 +138,41 @@ class WorkerConfig(Config): ) ) - # Whether to send federation traffic out in this process. This only - # applies to some federation traffic, and so shouldn't be used to - # "disable" federation - self.send_federation = config.get("send_federation", True) + # Handle federation sender configuration. + # + # There are two ways of configuring which instances handle federation + # sending: + # 1. The old way where "send_federation" is set to false and running a + # `synapse.app.federation_sender` worker app. + # 2. Specifying the workers sending federation in + # `federation_sender_instances`. + # + + send_federation = config.get("send_federation", True) + + federation_sender_instances = config.get("federation_sender_instances") + if federation_sender_instances is None: + # Default to an empty list, which means "another, unknown, worker is + # responsible for it". + federation_sender_instances = [] - federation_sender_instances = config.get("federation_sender_instances") or [] + # If no federation sender instances are set we check if + # `send_federation` is set, which means use master + if send_federation: + federation_sender_instances = ["master"] + + if self.worker_app == "synapse.app.federation_sender": + if send_federation: + # If we're running federation senders, and not using + # `federation_sender_instances`, then we should have + # explicitly set `send_federation` to false. + raise ConfigError( + _FEDERATION_SENDER_WITH_SEND_FEDERATION_ENABLED_ERROR + ) + + federation_sender_instances = [self.worker_name] + + self.send_federation = self.instance_name in federation_sender_instances self.federation_shard_config = ShardedWorkerHandlingConfig( federation_sender_instances ) @@ -164,7 +213,37 @@ class WorkerConfig(Config): "Must only specify one instance to handle `receipts` messages." ) - self.events_shard_config = ShardedWorkerHandlingConfig(self.writers.events) + if len(self.writers.events) == 0: + raise ConfigError("Must specify at least one instance to handle `events`.") + + self.events_shard_config = RoutableShardedWorkerHandlingConfig( + self.writers.events + ) + + # Handle sharded push + start_pushers = config.get("start_pushers", True) + pusher_instances = config.get("pusher_instances") + if pusher_instances is None: + # Default to an empty list, which means "another, unknown, worker is + # responsible for it". + pusher_instances = [] + + # If no pushers instances are set we check if `start_pushers` is + # set, which means use master + if start_pushers: + pusher_instances = ["master"] + + if self.worker_app == "synapse.app.pusher": + if start_pushers: + # If we're running pushers, and not using + # `pusher_instances`, then we should have explicitly set + # `start_pushers` to false. 
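Flattened into a diff, the precedence here is easy to misread. The sketch below restates the resolution for federation senders (the pusher logic that follows has the same shape). Treat the nesting as an assumption: the legacy fallbacks plausibly apply only when `federation_sender_instances` is absent, but the diff's indentation is lost:

```python
from typing import List, Optional


def resolve_federation_senders(
    send_federation: bool,
    federation_sender_instances: Optional[List[str]],
    worker_app: Optional[str],
    worker_name: str,
) -> List[str]:
    """Return the instances responsible for sending federation traffic."""
    if federation_sender_instances is None:
        # Nothing configured explicitly: fall back to legacy behaviour.
        federation_sender_instances = []
        if send_federation:
            # Default single-process setup: the main process sends.
            federation_sender_instances = ["master"]
        if worker_app == "synapse.app.federation_sender":
            if send_federation:
                raise ValueError(
                    "send_federation must be disabled in the main process first"
                )
            # Legacy dedicated sender worker: this worker sends.
            federation_sender_instances = [worker_name]
    return federation_sender_instances


# Legacy single process: master sends federation.
assert resolve_federation_senders(True, None, None, "master") == ["master"]
# Sharded setup: the explicit instance list wins.
assert resolve_federation_senders(False, ["fed1", "fed2"], None, "master") == [
    "fed1",
    "fed2",
]
```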
+ raise ConfigError(_PUSHER_WITH_START_PUSHERS_ENABLED_ERROR) + + pusher_instances = [self.instance_name] + + self.start_pushers = self.instance_name in pusher_instances + self.pusher_shard_config = ShardedWorkerHandlingConfig(pusher_instances) # Whether this worker should run background tasks or not. # diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 8d4bb621e7..2f832b47f6 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -34,7 +34,7 @@ from twisted.internet import defer from twisted.internet.abstract import isIPAddress from twisted.python import failure -from synapse.api.constants import EventTypes, Membership +from synapse.api.constants import EduTypes, EventTypes, Membership from synapse.api.errors import ( AuthError, Codes, @@ -44,6 +44,7 @@ from synapse.api.errors import ( SynapseError, UnsupportedRoomVersionError, ) +from synapse.api.ratelimiting import Ratelimiter from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import EventBase from synapse.federation.federation_base import FederationBase, event_from_pdu_json @@ -869,6 +870,13 @@ class FederationHandlerRegistry: # EDU received. self._edu_type_to_instance = {} # type: Dict[str, List[str]] + # A rate limiter for incoming room key requests per origin. + self._room_key_request_rate_limiter = Ratelimiter( + clock=self.clock, + rate_hz=self.config.rc_key_requests.per_second, + burst_count=self.config.rc_key_requests.burst_count, + ) + def register_edu_handler( self, edu_type: str, handler: Callable[[str, JsonDict], Awaitable[None]] ): @@ -917,7 +925,15 @@ class FederationHandlerRegistry: self._edu_type_to_instance[edu_type] = instance_names async def on_edu(self, edu_type: str, origin: str, content: dict): - if not self.config.use_presence and edu_type == "m.presence": + if not self.config.use_presence and edu_type == EduTypes.Presence: + return + + # If the incoming room key requests from a particular origin are over + # the limit, drop them. + if ( + edu_type == EduTypes.RoomKeyRequest + and not self._room_key_request_rate_limiter.can_do_action(origin) + ): return # Check if we have a handler on this instance diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index 97fc4d0a82..24ebc4b803 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -474,7 +474,7 @@ class FederationSender: self._processing_pending_presence = False def send_presence_to_destinations( - self, states: List[UserPresenceState], destinations: List[str] + self, states: Iterable[UserPresenceState], destinations: Iterable[str] ) -> None: """Send the given presence states to the given destinations. 
destinations (list[str]) diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index cce83704d4..2cf935f38d 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -484,10 +484,9 @@ class FederationQueryServlet(BaseFederationServlet): # This is when we receive a server-server Query async def on_GET(self, origin, content, query, query_type): - return await self.handler.on_query_request( - query_type, - {k.decode("utf8"): v[0].decode("utf-8") for k, v in query.items()}, - ) + args = {k.decode("utf8"): v[0].decode("utf-8") for k, v in query.items()} + args["origin"] = origin + return await self.handler.on_query_request(query_type, args) class FederationMakeJoinServlet(BaseFederationServlet): diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py index 1aa7d803b5..7db4f48965 100644 --- a/synapse/handlers/devicemessage.py +++ b/synapse/handlers/devicemessage.py @@ -16,7 +16,9 @@ import logging from typing import TYPE_CHECKING, Any, Dict +from synapse.api.constants import EduTypes from synapse.api.errors import SynapseError +from synapse.api.ratelimiting import Ratelimiter from synapse.logging.context import run_in_background from synapse.logging.opentracing import ( get_active_span_text_map, @@ -25,7 +27,7 @@ from synapse.logging.opentracing import ( start_active_span, ) from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet -from synapse.types import JsonDict, UserID, get_domain_from_id +from synapse.types import JsonDict, Requester, UserID, get_domain_from_id from synapse.util import json_encoder from synapse.util.stringutils import random_string @@ -78,6 +80,12 @@ class DeviceMessageHandler: ReplicationUserDevicesResyncRestServlet.make_client(hs) ) + self._ratelimiter = Ratelimiter( + clock=hs.get_clock(), + rate_hz=hs.config.rc_key_requests.per_second, + burst_count=hs.config.rc_key_requests.burst_count, + ) + async def on_direct_to_device_edu(self, origin: str, content: JsonDict) -> None: local_messages = {} sender_user_id = content["sender"] @@ -168,15 +176,27 @@ class DeviceMessageHandler: async def send_device_message( self, - sender_user_id: str, + requester: Requester, message_type: str, messages: Dict[str, Dict[str, JsonDict]], ) -> None: + sender_user_id = requester.user.to_string() + set_tag("number_of_messages", len(messages)) set_tag("sender", sender_user_id) local_messages = {} remote_messages = {} # type: Dict[str, Dict[str, Dict[str, JsonDict]]] for user_id, by_device in messages.items(): + # Ratelimit local cross-user key requests by the sending device. 
+ if ( + message_type == EduTypes.RoomKeyRequest + and user_id != sender_user_id + and not self._ratelimiter.can_do_action( + (sender_user_id, requester.device_id) + ) + ): + continue + # we use UserID.from_string to catch invalid user ids if self.is_mine(UserID.from_string(user_id)): messages_by_device = { diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index 3e23f82cf7..f46cab7325 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -17,7 +17,7 @@ import logging import random from typing import TYPE_CHECKING, Iterable, List, Optional -from synapse.api.constants import EventTypes, Membership +from synapse.api.constants import EduTypes, EventTypes, Membership from synapse.api.errors import AuthError, SynapseError from synapse.events import EventBase from synapse.handlers.presence import format_user_presence_state @@ -113,7 +113,7 @@ class EventStreamHandler(BaseHandler): states = await presence_handler.get_states(users) to_add.extend( { - "type": EventTypes.Presence, + "type": EduTypes.Presence, "content": format_user_presence_state(state, time_now), } for state in states diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 78c3e5a10b..71a5076672 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Optional, Tuple from twisted.internet import defer -from synapse.api.constants import EventTypes, Membership +from synapse.api.constants import EduTypes, EventTypes, Membership from synapse.api.errors import SynapseError from synapse.events.validator import EventValidator from synapse.handlers.presence import format_user_presence_state @@ -412,7 +412,7 @@ class InitialSyncHandler(BaseHandler): return [ { - "type": EventTypes.Presence, + "type": EduTypes.Presence, "content": format_user_presence_state(s, time_now), } for s in states diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index c03f6c997b..1b7c065b34 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -387,6 +387,12 @@ class EventCreationHandler: self.room_invite_state_types = self.hs.config.room_invite_state_types + self.membership_types_to_include_profile_data_in = ( + {Membership.JOIN, Membership.INVITE} + if self.hs.config.include_profile_data_on_invite + else {Membership.JOIN} + ) + self.send_event = ReplicationSendEventRestServlet.make_client(hs) # This is only used to get at ratelimit function, and maybe_kick_guest_users @@ -500,7 +506,7 @@ class EventCreationHandler: membership = builder.content.get("membership", None) target = UserID.from_string(builder.state_key) - if membership in {Membership.JOIN, Membership.INVITE}: + if membership in self.membership_types_to_include_profile_data_in: # If event doesn't include a display name, add one. profile = self.profile_handler content = builder.content diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index fb85b19770..b6a9ce4f38 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -849,6 +849,9 @@ class PresenceHandler(BasePresenceHandler): """Process current state deltas to find new joins that need to be handled.
""" + # A map of destination to a set of user state that they should receive + presence_destinations = {} # type: Dict[str, Set[UserPresenceState]] + for delta in deltas: typ = delta["type"] state_key = delta["state_key"] @@ -858,6 +861,7 @@ class PresenceHandler(BasePresenceHandler): logger.debug("Handling: %r %r, %s", typ, state_key, event_id) + # Drop any event that isn't a membership join if typ != EventTypes.Member: continue @@ -880,13 +884,38 @@ class PresenceHandler(BasePresenceHandler): # Ignore changes to join events. continue - await self._on_user_joined_room(room_id, state_key) + # Retrieve any user presence state updates that need to be sent as a result, + # and the destinations that need to receive it + destinations, user_presence_states = await self._on_user_joined_room( + room_id, state_key + ) + + # Insert the destinations and respective updates into our destinations dict + for destination in destinations: + presence_destinations.setdefault(destination, set()).update( + user_presence_states + ) + + # Send out user presence updates for each destination + for destination, user_state_set in presence_destinations.items(): + self.federation.send_presence_to_destinations( + destinations=[destination], states=user_state_set + ) - async def _on_user_joined_room(self, room_id: str, user_id: str) -> None: + async def _on_user_joined_room( + self, room_id: str, user_id: str + ) -> Tuple[List[str], List[UserPresenceState]]: """Called when we detect a user joining the room via the current state - delta stream. - """ + delta stream. Returns the destinations that need to be updated and the + presence updates to send to them. + + Args: + room_id: The ID of the room that the user has joined. + user_id: The ID of the user that has joined the room. + Returns: + A tuple of destinations and presence updates to send to them. + """ if self.is_mine_id(user_id): # If this is a local user then we need to send their presence # out to hosts in the room (who don't already have it) @@ -894,15 +923,15 @@ class PresenceHandler(BasePresenceHandler): # TODO: We should be able to filter the hosts down to those that # haven't previously seen the user - state = await self.current_state_for_user(user_id) - hosts = await self.state.get_current_hosts_in_room(room_id) + remote_hosts = await self.state.get_current_hosts_in_room(room_id) # Filter out ourselves. - hosts = {host for host in hosts if host != self.server_name} + filtered_remote_hosts = [ + host for host in remote_hosts if host != self.server_name + ] - self.federation.send_presence_to_destinations( - states=[state], destinations=hosts - ) + state = await self.current_state_for_user(user_id) + return filtered_remote_hosts, [state] else: # A remote user has joined the room, so we need to: # 1. Check if this is a new server in the room @@ -915,6 +944,8 @@ class PresenceHandler(BasePresenceHandler): # TODO: Check that this is actually a new server joining the # room. 
+ remote_host = get_domain_from_id(user_id) + users = await self.state.get_current_users_in_room(room_id) user_ids = list(filter(self.is_mine_id, users)) @@ -934,10 +965,7 @@ class PresenceHandler(BasePresenceHandler): or state.status_msg is not None ] - if states: - self.federation.send_presence_to_destinations( - states=states, destinations=[get_domain_from_id(user_id)] - ) + return [remote_host], states def should_notify(old_state, new_state): diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 2f62d84fb5..dd59392bda 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -310,6 +310,15 @@ class ProfileHandler(BaseHandler): await self._update_join_states(requester, target_user) async def on_profile_query(self, args: JsonDict) -> JsonDict: + """Handles federation profile query requests.""" + + if not self.hs.config.allow_profile_lookup_over_federation: + raise SynapseError( + 403, + "Profile lookup over federation is disabled on this homeserver", + Codes.FORBIDDEN, + ) + user = UserID.from_string(args["user_id"]) if not self.hs.is_mine(user): raise SynapseError(400, "User is not hosted on this homeserver") diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 4e8ed7b33f..ce644e01ad 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -277,8 +277,9 @@ class SyncHandler: user_id = sync_config.user.to_string() await self.auth.check_auth_blocking(requester=requester) - res = await self.response_cache.wrap( + res = await self.response_cache.wrap_conditional( sync_config.request_key, + lambda result: since_token != result.next_batch, self._wait_for_sync_for_user, sync_config, since_token, diff --git a/synapse/http/site.py b/synapse/http/site.py index 4a4fb5ef26..30153237e3 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -16,6 +16,10 @@ import logging import time from typing import Optional, Union +import attr +from zope.interface import implementer + +from twisted.internet.interfaces import IAddress from twisted.python.failure import Failure from twisted.web.server import Request, Site @@ -333,26 +337,77 @@ class SynapseRequest(Request): class XForwardedForRequest(SynapseRequest): - def __init__(self, *args, **kw): - SynapseRequest.__init__(self, *args, **kw) + """Request object which honours proxy headers + Extends SynapseRequest to replace getClientIP, getClientAddress, and isSecure with + information from request headers. """ - Add a layer on top of another request that only uses the value of an - X-Forwarded-For header as the result of C{getClientIP}. - """ - def getClientIP(self): + # the client IP and ssl flag, as extracted from the headers. + _forwarded_for = None # type: Optional[_XForwardedForAddress] + _forwarded_https = False # type: bool + + def requestReceived(self, command, path, version): + # this method is called by the Channel once the full request has been + # received, to dispatch the request to a resource. + # We can use it to set the IP address and protocol according to the + # headers. + self._process_forwarded_headers() + return super().requestReceived(command, path, version) + + def _process_forwarded_headers(self): + headers = self.requestHeaders.getRawHeaders(b"x-forwarded-for") + if not headers: + return + + # for now, we just use the first x-forwarded-for header. Really, we ought + # to start from the client IP address, and check whether it is trusted; if it + # is, work backwards through the headers until we find an untrusted address. 
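`wrap_conditional`, used for `/sync` above, deduplicates concurrent identical requests but only keeps the result if the supplied test passes; per the lambda, a response whose `next_batch` still equals the client's `since` token is judged not worth caching. A toy asyncio sketch of that contract (Synapse's real `ResponseCache` is Twisted-deferred based and adds timed eviction, which this omits):

```python
import asyncio
from typing import Awaitable, Callable, Dict, TypeVar

T = TypeVar("T")


class ConditionalResponseCache:
    """Share in-flight results by key; keep them only if a test passes."""

    def __init__(self) -> None:
        self._tasks = {}  # type: Dict[str, asyncio.Task]

    async def wrap_conditional(
        self,
        key: str,
        should_cache: Callable[[T], bool],
        callback: Callable[[], Awaitable[T]],
    ) -> T:
        task = self._tasks.get(key)
        if task is None:
            task = asyncio.ensure_future(callback())
            self._tasks[key] = task
        try:
            result = await asyncio.shield(task)
        except Exception:
            self._tasks.pop(key, None)  # never cache failures
            raise
        if not should_cache(result):
            self._tasks.pop(key, None)
        return result
```

A duplicate request arriving while the first is still being computed shares the same task either way; the condition only decides whether later requests may reuse the finished result.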
+ # see https://github.com/matrix-org/synapse/issues/9471 + self._forwarded_for = _XForwardedForAddress( + headers[0].split(b",")[0].strip().decode("ascii") + ) + + # if we got an x-forwarded-for header, also look for an x-forwarded-proto header + header = self.getHeader(b"x-forwarded-proto") + if header is not None: + self._forwarded_https = header.lower() == b"https" + else: + # this is done largely for backwards-compatibility so that people that + # haven't set an x-forwarded-proto header don't get a redirect loop. + logger.warning( + "forwarded request lacks an x-forwarded-proto header: assuming https" + ) + self._forwarded_https = True + + def isSecure(self): + if self._forwarded_https: + return True + return super().isSecure() + + def getClientIP(self) -> str: """ - @return: The client address (the first address) in the value of the - I{X-Forwarded-For header}. If the header is not present, return - C{b"-"}. + Return the IP address of the client who submitted this request. + + This method is deprecated. Use getClientAddress() instead. """ - return ( - self.requestHeaders.getRawHeaders(b"x-forwarded-for", [b"-"])[0] - .split(b",")[0] - .strip() - .decode("ascii") - ) + if self._forwarded_for is not None: + return self._forwarded_for.host + return super().getClientIP() + + def getClientAddress(self) -> IAddress: + """ + Return the address of the client who submitted this request. + """ + if self._forwarded_for is not None: + return self._forwarded_for + return super().getClientAddress() + + +@implementer(IAddress) +@attr.s(frozen=True, slots=True) +class _XForwardedForAddress: + host = attr.ib(type=str) class SynapseSite(Site): diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index b9d3da2e0a..f4d7e199e9 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -74,6 +74,7 @@ class HttpPusher(Pusher): self.timed_call = None self._is_processing = False self._group_unread_count_by_room = hs.config.push_group_unread_count_by_room + self._pusherpool = hs.get_pusherpool() self.data = pusher_config.data if self.data is None: @@ -299,7 +300,7 @@ class HttpPusher(Pusher): ) else: logger.info("Pushkey %s was rejected: removing", pk) - await self.hs.remove_pusher(self.app_id, pk, self.user_id) + await self._pusherpool.remove_pusher(self.app_id, pk, self.user_id) return True async def _build_notification_dict( diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index ae1145be0e..21f14f05f0 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -25,6 +25,7 @@ from synapse.metrics.background_process_metrics import ( ) from synapse.push import Pusher, PusherConfig, PusherConfigException from synapse.push.pusher import PusherFactory +from synapse.replication.http.push import ReplicationRemovePusherRestServlet from synapse.types import JsonDict, RoomStreamToken from synapse.util.async_helpers import concurrently_execute @@ -58,7 +59,6 @@ class PusherPool: def __init__(self, hs: "HomeServer"): self.hs = hs self.pusher_factory = PusherFactory(hs) - self._should_start_pushers = hs.config.start_pushers self.store = self.hs.get_datastore() self.clock = self.hs.get_clock() @@ -67,6 +67,16 @@ class PusherPool: # We shard the handling of push notifications by user ID. self._pusher_shard_config = hs.config.push.pusher_shard_config self._instance_name = hs.get_instance_name() + self._should_start_pushers = ( + self._instance_name in self._pusher_shard_config.instances + ) + + # We can only delete pushers on master. 
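For reference, the forwarded-header handling above reduces to a small pure function: take the leftmost `X-Forwarded-For` hop, and trust `X-Forwarded-Proto` when present. A standalone restatement (the helper name is illustrative, not part of the module):

```python
from typing import Optional, Tuple


def parse_forwarded(
    x_forwarded_for: Optional[bytes], x_forwarded_proto: Optional[bytes]
) -> Tuple[Optional[str], bool]:
    """Return (client_ip, is_https) as described by proxy headers."""
    if not x_forwarded_for:
        return None, False
    # "client, proxy1, proxy2" -> take the leftmost hop for now; see
    # https://github.com/matrix-org/synapse/issues/9471 for why walking
    # the chain from the right would be more robust.
    ip = x_forwarded_for.split(b",")[0].strip().decode("ascii")
    if x_forwarded_proto is not None:
        https = x_forwarded_proto.lower() == b"https"
    else:
        # Assume https when the proxy doesn't say, to avoid redirect
        # loops for existing deployments (matching the warning above).
        https = True
    return ip, https


assert parse_forwarded(b"10.1.2.3, 192.168.0.1", b"https") == ("10.1.2.3", True)
assert parse_forwarded(None, None) == (None, False)
```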
+ self._remove_pusher_client = None + if hs.config.worker.worker_app: + self._remove_pusher_client = ReplicationRemovePusherRestServlet.make_client( + hs + ) # Record the last stream ID that we were poked about so we can get # changes since then. We set this to the current max stream ID on @@ -175,9 +185,6 @@ class PusherPool: user_id: user to remove pushers for access_tokens: access token *ids* to remove pushers for """ - if not self._pusher_shard_config.should_handle(self._instance_name, user_id): - return - tokens = set(access_tokens) for p in await self.store.get_pushers_by_user_id(user_id): if p.access_token in tokens: @@ -380,6 +387,12 @@ class PusherPool: synapse_pushers.labels(type(pusher).__name__, pusher.app_id).dec() - await self.store.delete_pusher_by_app_id_pushkey_user_id( - app_id, pushkey, user_id - ) + # We can only delete pushers on master. + if self._remove_pusher_client: + await self._remove_pusher_client( + app_id=app_id, pushkey=pushkey, user_id=user_id + ) + else: + await self.store.delete_pusher_by_app_id_pushkey_user_id( + app_id, pushkey, user_id + ) diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index 8a2b73b75e..321a333820 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -106,6 +106,9 @@ CONDITIONAL_REQUIREMENTS = { "pysaml2>=4.5.0;python_version>='3.6'", ], "oidc": ["authlib>=0.14.0"], + # systemd-python is necessary for logging to the systemd journal via + # `systemd.journal.JournalHandler`, as is documented in + # `contrib/systemd/log_config.yaml`. "systemd": ["systemd-python>=231"], "url_preview": ["lxml>=3.5.0"], "sentry": ["sentry-sdk>=0.7.2"], diff --git a/synapse/replication/http/__init__.py b/synapse/replication/http/__init__.py index dd527e807f..cb4a52dbe9 100644 --- a/synapse/replication/http/__init__.py +++ b/synapse/replication/http/__init__.py @@ -21,6 +21,7 @@ from synapse.replication.http import ( login, membership, presence, + push, register, send_event, streams, @@ -42,6 +43,7 @@ class ReplicationRestResource(JsonResource): membership.register_servlets(hs, self) streams.register_servlets(hs, self) account_data.register_servlets(hs, self) + push.register_servlets(hs, self) # The following can't currently be instantiated on workers. if hs.config.worker.worker_app is None: diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py index 7a0dbb5b1a..8af53b4f28 100644 --- a/synapse/replication/http/federation.py +++ b/synapse/replication/http/federation.py @@ -213,8 +213,9 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint): content = parse_json_object_from_request(request) args = content["args"] + args["origin"] = content["origin"] - logger.info("Got %r query", query_type) + logger.info("Got %r query from %s", query_type, args["origin"]) result = await self.registry.on_query(query_type, args) diff --git a/synapse/replication/http/push.py b/synapse/replication/http/push.py new file mode 100644 index 0000000000..054ed64d34 --- /dev/null +++ b/synapse/replication/http/push.py @@ -0,0 +1,72 @@ +# -*- coding: utf-8 -*- +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import TYPE_CHECKING + +from synapse.http.servlet import parse_json_object_from_request +from synapse.replication.http._base import ReplicationEndpoint + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class ReplicationRemovePusherRestServlet(ReplicationEndpoint): + """Deletes the given pusher. + + Request format: + + POST /_synapse/replication/remove_pusher/:user_id + + { + "app_id": "<some_id>", + "pushkey": "<some_key>" + } + + """ + + NAME = "remove_pusher" + PATH_ARGS = ("user_id",) + CACHE = False + + def __init__(self, hs: "HomeServer"): + super().__init__(hs) + + self.pusher_pool = hs.get_pusherpool() + + @staticmethod + async def _serialize_payload(app_id, pushkey, user_id): + payload = { + "app_id": app_id, + "pushkey": pushkey, + } + + return payload + + async def _handle_request(self, request, user_id): + content = parse_json_object_from_request(request) + + app_id = content["app_id"] + pushkey = content["pushkey"] + + await self.pusher_pool.remove_pusher(app_id, pushkey, user_id) + + return 200, {} + + +def register_servlets(hs, http_server): + ReplicationRemovePusherRestServlet(hs).register(http_server) diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index 0a9da79c32..bb447f75b4 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -325,31 +325,6 @@ class FederationAckCommand(Command): return "%s %s" % (self.instance_name, self.token) -class RemovePusherCommand(Command): - """Sent by the client to request the master remove the given pusher. - - Format:: - - REMOVE_PUSHER <app_id> <push_key> <user_id> - """ - - NAME = "REMOVE_PUSHER" - - def __init__(self, app_id, push_key, user_id): - self.user_id = user_id - self.app_id = app_id - self.push_key = push_key - - @classmethod - def from_line(cls, line): - app_id, push_key, user_id = line.split(" ", 2) - - return cls(app_id, push_key, user_id) - - def to_line(self): - return " ".join((self.app_id, self.push_key, self.user_id)) - - class UserIpCommand(Command): """Sent periodically when a worker sees activity from a client.
@@ -416,7 +391,6 @@ _COMMANDS = ( ReplicateCommand, UserSyncCommand, FederationAckCommand, - RemovePusherCommand, UserIpCommand, RemoteServerUpCommand, ClearUserSyncsCommand, @@ -443,7 +417,6 @@ VALID_CLIENT_COMMANDS = ( UserSyncCommand.NAME, ClearUserSyncsCommand.NAME, FederationAckCommand.NAME, - RemovePusherCommand.NAME, UserIpCommand.NAME, ErrorCommand.NAME, RemoteServerUpCommand.NAME, diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index d1d00c3717..a7245da152 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -44,7 +44,6 @@ from synapse.replication.tcp.commands import ( PositionCommand, RdataCommand, RemoteServerUpCommand, - RemovePusherCommand, ReplicateCommand, UserIpCommand, UserSyncCommand, @@ -373,23 +372,6 @@ class ReplicationCommandHandler: if self._federation_sender: self._federation_sender.federation_ack(cmd.instance_name, cmd.token) - def on_REMOVE_PUSHER( - self, conn: AbstractConnection, cmd: RemovePusherCommand - ) -> Optional[Awaitable[None]]: - remove_pusher_counter.inc() - - if self._is_master: - return self._handle_remove_pusher(cmd) - else: - return None - - async def _handle_remove_pusher(self, cmd: RemovePusherCommand): - await self._store.delete_pusher_by_app_id_pushkey_user_id( - app_id=cmd.app_id, pushkey=cmd.push_key, user_id=cmd.user_id - ) - - self._notifier.on_new_replication_data() - def on_USER_IP( self, conn: AbstractConnection, cmd: UserIpCommand ) -> Optional[Awaitable[None]]: @@ -684,11 +666,6 @@ class ReplicationCommandHandler: UserSyncCommand(instance_id, user_id, is_syncing, last_sync_ms) ) - def send_remove_pusher(self, app_id: str, push_key: str, user_id: str): - """Poke the master to remove a pusher for a user""" - cmd = RemovePusherCommand(app_id, push_key, user_id) - self.send_command(cmd) - def send_user_ip( self, user_id: str, diff --git a/synapse/res/templates/sso_auth_account_details.html b/synapse/res/templates/sso_auth_account_details.html index f4fdc40b22..00e1dcdbb8 100644 --- a/synapse/res/templates/sso_auth_account_details.html +++ b/synapse/res/templates/sso_auth_account_details.html @@ -145,7 +145,7 @@ <input type="submit" value="Continue" class="primary-button"> {% if user_attributes.avatar_url or user_attributes.display_name or user_attributes.emails %} <section class="idp-pick-details"> - <h2><img src="{{ idp.idp_icon | mxc_to_http(24, 24) }}"/>Information from {{ idp.idp_name }}</h2> + <h2>{% if idp.idp_icon %}<img src="{{ idp.idp_icon | mxc_to_http(24, 24) }}"/>{% endif %}Information from {{ idp.idp_name }}</h2> {% if user_attributes.avatar_url %} <label class="idp-detail idp-avatar" for="idp-avatar"> <div class="check-row"> diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 998a0ef671..9c701c7348 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -35,6 +35,7 @@ from synapse.rest.admin._base import ( assert_user_is_admin, ) from synapse.rest.client.v2_alpha._base import client_patterns +from synapse.storage.databases.main.media_repository import MediaSortOrder from synapse.types import JsonDict, UserID if TYPE_CHECKING: @@ -832,8 +833,33 @@ class UserMediaRestServlet(RestServlet): errcode=Codes.INVALID_PARAM, ) + # If neither `order_by` nor `dir` is set, default the ordering so that + # the newest media is on top, for backward compatibility.
+ if b"order_by" not in request.args and b"dir" not in request.args: + order_by = MediaSortOrder.CREATED_TS.value + direction = "b" + else: + order_by = parse_string( + request, + "order_by", + default=MediaSortOrder.CREATED_TS.value, + allowed_values=( + MediaSortOrder.MEDIA_ID.value, + MediaSortOrder.UPLOAD_NAME.value, + MediaSortOrder.CREATED_TS.value, + MediaSortOrder.LAST_ACCESS_TS.value, + MediaSortOrder.MEDIA_LENGTH.value, + MediaSortOrder.MEDIA_TYPE.value, + MediaSortOrder.QUARANTINED_BY.value, + MediaSortOrder.SAFE_FROM_QUARANTINE.value, + ), + ) + direction = parse_string( + request, "dir", default="f", allowed_values=("f", "b") + ) + media, total = await self.store.get_local_media_by_user_paginate( - start, limit, user_id + start, limit, user_id, order_by, direction ) ret = {"media": media, "total": total} diff --git a/synapse/rest/client/v2_alpha/sendtodevice.py b/synapse/rest/client/v2_alpha/sendtodevice.py index a3dee14ed4..79c1b526ee 100644 --- a/synapse/rest/client/v2_alpha/sendtodevice.py +++ b/synapse/rest/client/v2_alpha/sendtodevice.py @@ -56,10 +56,8 @@ class SendToDeviceRestServlet(servlet.RestServlet): content = parse_json_object_from_request(request) assert_params_in_dict(content, ("messages",)) - sender_user_id = requester.user.to_string() - await self.device_message_handler.send_device_message( - sender_user_id, message_type, content["messages"] + requester, message_type, content["messages"] ) response = (200, {}) # type: Tuple[int, dict] diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index a0162d4255..3375455c43 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -509,7 +509,7 @@ class MediaRepository: t_height: int, t_method: str, t_type: str, - url_cache: str, + url_cache: Optional[str], ) -> Optional[str]: input_path = await self.media_storage.ensure_media_is_in_local_cache( FileInfo(None, media_id, url_cache=url_cache) diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py index 1057e638be..b1b1c9e6ec 100644 --- a/synapse/rest/media/v1/media_storage.py +++ b/synapse/rest/media/v1/media_storage.py @@ -244,7 +244,7 @@ class MediaStorage: await consumer.wait() return local_path - raise Exception("file could not be found") + raise NotFoundError() def _file_info_to_path(self, file_info: FileInfo) -> str: """Converts file_info into a relative path. 
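With the two new parameters wired through, the user-media admin API becomes sortable. A usage sketch; the endpoint path (`/_synapse/admin/v1/users/<user_id>/media` in current Synapse documentation), the server name, and the token are assumptions, not part of this diff:

```python
import requests

resp = requests.get(
    # %40alice%3Aexample.com is the URL-encoded @alice:example.com
    "https://homeserver.example.com/_synapse/admin/v1/users/%40alice%3Aexample.com/media",
    headers={"Authorization": "Bearer <admin_access_token>"},
    params={
        "from": 0,
        "limit": 10,
        "order_by": "media_length",  # any MediaSortOrder value
        "dir": "b",  # "b" = descending, "f" = ascending
    },
)
resp.raise_for_status()
body = resp.json()
print(body["total"], [m["media_id"] for m in body["media"]])
```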
diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py index d653a58be9..3ab90e9f9b 100644 --- a/synapse/rest/media/v1/thumbnail_resource.py +++ b/synapse/rest/media/v1/thumbnail_resource.py @@ -114,6 +114,7 @@ class ThumbnailResource(DirectServeJsonResource): m_type, thumbnail_infos, media_id, + media_id, url_cache=media_info["url_cache"], server_name=None, ) @@ -269,6 +270,7 @@ class ThumbnailResource(DirectServeJsonResource): method, m_type, thumbnail_infos, + media_id, media_info["filesystem_id"], url_cache=None, server_name=server_name, ) @@ -282,6 +284,7 @@ desired_method: str, desired_type: str, thumbnail_infos: List[Dict[str, Any]], + media_id: str, file_id: str, url_cache: Optional[str] = None, server_name: Optional[str] = None, @@ -317,8 +320,59 @@ class ThumbnailResource(DirectServeJsonResource): return responder = await self.media_storage.fetch_media(file_info) + if responder: + await respond_with_responder( + request, + responder, + file_info.thumbnail_type, + file_info.thumbnail_length, + ) + return + + # If we can't find the thumbnail we regenerate it. This can happen + # if e.g. we've deleted the thumbnails but still have the original + # image somewhere. + # + # Since we have an entry for the thumbnail in the DB we a) know we + # have successfully generated the thumbnail in the past (so we + # don't need to worry about repeatedly failing to generate + # thumbnails), and b) have already calculated the appropriate + # width/height/method so we can just call the "generate exact" + # methods. + + # First let's check that we do actually have the original image + # still. This will throw a 404 if we don't. + # TODO: We should refetch the thumbnails for remote media. + await self.media_storage.ensure_media_is_in_local_cache( + FileInfo(server_name, file_id, url_cache=url_cache) + ) + + if server_name: + await self.media_repo.generate_remote_exact_thumbnail( + server_name, + file_id=file_id, + media_id=media_id, + t_width=file_info.thumbnail_width, + t_height=file_info.thumbnail_height, + t_method=file_info.thumbnail_method, + t_type=file_info.thumbnail_type, + ) + else: + await self.media_repo.generate_local_exact_thumbnail( + media_id=media_id, + t_width=file_info.thumbnail_width, + t_height=file_info.thumbnail_height, + t_method=file_info.thumbnail_method, + t_type=file_info.thumbnail_type, + url_cache=url_cache, + ) + + responder = await self.media_storage.fetch_media(file_info) await respond_with_responder( - request, responder, file_info.thumbnail_type, file_info.thumbnail_length + request, + responder, + file_info.thumbnail_type, + file_info.thumbnail_length, ) else: logger.info("Failed to find any generated thumbnails") diff --git a/synapse/server.py b/synapse/server.py index 6b3892e3cd..4b9ec7f0ae 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -248,7 +248,7 @@ class HomeServer(metaclass=abc.ABCMeta): self.start_time = None # type: Optional[int] self._instance_id = random_string(5) - self._instance_name = config.worker_name or "master" + self._instance_name = config.worker.instance_name self.version_string = version_string @@ -758,12 +758,6 @@ class HomeServer(metaclass=abc.ABCMeta): reconnect=True, ) - async def remove_pusher(self, app_id: str, push_key: str, user_id: str): - return await self.get_pusherpool().remove_pusher(app_id, push_key, user_id) - def should_send_federation(self) -> bool: "Should this server be sending federation traffic directly?"
- return self.config.send_federation and ( - not self.config.worker_app - or self.config.worker_app == "synapse.app.federation_sender" - ) + return self.config.send_federation diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 4646926449..f1ba529a2d 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -49,7 +49,6 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.background_updates import BackgroundUpdater from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine from synapse.storage.types import Connection, Cursor -from synapse.storage.util.sequence import build_sequence_generator from synapse.types import Collection # python 3 does not have a maximum int value @@ -381,7 +380,10 @@ class DatabasePool: _TXN_ID = 0 def __init__( - self, hs, database_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine + self, + hs, + database_config: DatabaseConnectionConfig, + engine: BaseDatabaseEngine, ): self.hs = hs self._clock = hs.get_clock() @@ -420,16 +422,6 @@ class DatabasePool: self._check_safe_to_upsert, ) - # We define this sequence here so that it can be referenced from both - # the DataStore and PersistEventStore. - def get_chain_id_txn(txn): - txn.execute("SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains") - return txn.fetchone()[0] - - self.event_chain_id_gen = build_sequence_generator( - engine, get_chain_id_txn, "event_auth_chain_id" - ) - def is_running(self) -> bool: """Is the database pool currently running""" return self._db_pool.running diff --git a/synapse/storage/databases/__init__.py b/synapse/storage/databases/__init__.py index e84f8b42f7..379c78bb83 100644 --- a/synapse/storage/databases/__init__.py +++ b/synapse/storage/databases/__init__.py @@ -79,7 +79,7 @@ class Databases: # If we're on a process that can persist events also # instantiate a `PersistEventsStore` if hs.get_instance_name() in hs.config.worker.writers.events: - persist_events = PersistEventsStore(hs, database, main) + persist_events = PersistEventsStore(hs, database, main, db_conn) if "state" in database_config.databases: logger.info( diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 287606cb4f..cd1ceac50e 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -42,7 +42,9 @@ from synapse.logging.utils import log_function from synapse.storage._base import db_to_json, make_in_list_sql_clause from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.databases.main.search import SearchEntry +from synapse.storage.types import Connection from synapse.storage.util.id_generators import MultiWriterIdGenerator +from synapse.storage.util.sequence import SequenceGenerator from synapse.types import StateMap, get_domain_from_id from synapse.util import json_encoder from synapse.util.iterutils import batch_iter, sorted_topologically @@ -90,7 +92,11 @@ class PersistEventsStore: """ def __init__( - self, hs: "HomeServer", db: DatabasePool, main_data_store: "DataStore" + self, + hs: "HomeServer", + db: DatabasePool, + main_data_store: "DataStore", + db_conn: Connection, ): self.hs = hs self.db_pool = db @@ -474,6 +480,7 @@ class PersistEventsStore: self._add_chain_cover_index( txn, self.db_pool, + self.store.event_chain_id_gen, event_to_room_id, event_to_types, event_to_auth_chain, @@ -484,6 +491,7 @@ class PersistEventsStore: cls, txn, db_pool: 
DatabasePool, + event_chain_id_gen: SequenceGenerator, event_to_room_id: Dict[str, str], event_to_types: Dict[str, Tuple[str, str]], event_to_auth_chain: Dict[str, List[str]], @@ -630,6 +638,7 @@ class PersistEventsStore: new_chain_tuples = cls._allocate_chain_ids( txn, db_pool, + event_chain_id_gen, event_to_room_id, event_to_types, event_to_auth_chain, @@ -768,6 +777,7 @@ class PersistEventsStore: def _allocate_chain_ids( txn, db_pool: DatabasePool, + event_chain_id_gen: SequenceGenerator, event_to_room_id: Dict[str, str], event_to_types: Dict[str, Tuple[str, str]], event_to_auth_chain: Dict[str, List[str]], @@ -880,7 +890,7 @@ class PersistEventsStore: chain_to_max_seq_no[new_chain_tuple[0]] = new_chain_tuple[1] # Generate new chain IDs for all unallocated chain IDs. - newly_allocated_chain_ids = db_pool.event_chain_id_gen.get_next_mult_txn( + newly_allocated_chain_ids = event_chain_id_gen.get_next_mult_txn( txn, len(unallocated_chain_ids) ) diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py index 89274e75f7..c1626ccf28 100644 --- a/synapse/storage/databases/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py @@ -917,6 +917,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): PersistEventsStore._add_chain_cover_index( txn, self.db_pool, + self.event_chain_id_gen, event_to_room_id, event_to_types, event_to_auth_chain, diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index c8850a4707..edbe42f2bf 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -45,6 +45,7 @@ from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_cla from synapse.storage.database import DatabasePool from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator +from synapse.storage.util.sequence import build_sequence_generator from synapse.types import Collection, JsonDict, get_domain_from_id from synapse.util.caches.descriptors import cached from synapse.util.caches.lrucache import LruCache @@ -156,6 +157,21 @@ class EventsWorkerStore(SQLBaseStore): self._event_fetch_list = [] self._event_fetch_ongoing = 0 + # We define this sequence here so that it can be referenced from both + # the DataStore and PersistEventStore. + def get_chain_id_txn(txn): + txn.execute("SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains") + return txn.fetchone()[0] + + self.event_chain_id_gen = build_sequence_generator( + db_conn, + database.engine, + get_chain_id_txn, + "event_auth_chain_id", + table="event_auth_chains", + id_column="chain_id", + ) + def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == EventsStream.NAME: self._stream_id_gen.advance(instance_name, token) diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index a0313c3ccf..274f8de595 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -13,6 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
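`build_sequence_generator` is moved into `EventsWorkerStore` above so a single `event_chain_id_gen` can be shared by the reader and persister stores. On engines without native sequences it relies on a "current maximum" callback like `get_chain_id_txn`. A toy approximation of that fallback, similar in spirit to Synapse's `LocalSequenceGenerator` (Postgres uses a real database sequence instead):

```python
import threading
from typing import Callable, List, Optional


class LocalSequenceGenerator:
    """Hand out increasing IDs, seeded lazily from a "current max" query."""

    def __init__(self, get_current_max_txn: Callable[..., int]):
        self._get_current_max_txn = get_current_max_txn
        self._current = None  # type: Optional[int]
        self._lock = threading.Lock()

    def get_next_mult_txn(self, txn, n: int) -> List[int]:
        with self._lock:
            if self._current is None:
                # First use, e.g.:
                # SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains
                self._current = self._get_current_max_txn(txn)
            ids = list(range(self._current + 1, self._current + n + 1))
            self._current += n
            return ids


# Seed from a fake "database" whose current max chain ID is 7:
gen = LocalSequenceGenerator(lambda txn: 7)
assert gen.get_next_mult_txn(None, 3) == [8, 9, 10]
```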
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index a0313c3ccf..274f8de595 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -13,6 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from enum import Enum
 from typing import Any, Dict, Iterable, List, Optional, Tuple

 from synapse.storage._base import SQLBaseStore
@@ -23,6 +24,22 @@ BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD = (
 )


+class MediaSortOrder(Enum):
+    """
+    Enum to define the sorting method used when returning media with
+    get_local_media_by_user_paginate
+    """
+
+    MEDIA_ID = "media_id"
+    UPLOAD_NAME = "upload_name"
+    CREATED_TS = "created_ts"
+    LAST_ACCESS_TS = "last_access_ts"
+    MEDIA_LENGTH = "media_length"
+    MEDIA_TYPE = "media_type"
+    QUARANTINED_BY = "quarantined_by"
+    SAFE_FROM_QUARANTINE = "safe_from_quarantine"
+
+
 class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
         super().__init__(database, db_conn, hs)
@@ -118,7 +135,12 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         )

     async def get_local_media_by_user_paginate(
-        self, start: int, limit: int, user_id: str
+        self,
+        start: int,
+        limit: int,
+        user_id: str,
+        order_by: str = MediaSortOrder.CREATED_TS.value,
+        direction: str = "f",
     ) -> Tuple[List[Dict[str, Any]], int]:
         """Get a paginated list of metadata for a local piece of media
         which a user_id has uploaded
@@ -127,6 +149,8 @@
             start: offset in the list
             limit: maximum amount of media_ids to retrieve
             user_id: fully-qualified user id
+            order_by: the sort order of the returned list (a MediaSortOrder value)
+            direction: sort ascending ("f") or descending ("b")
         Returns:
             A paginated list of all metadata of user's media,
             plus the total count of all the user's media
         """

         def get_local_media_by_user_paginate_txn(txn):
+            # Set ordering
+            order_by_column = MediaSortOrder(order_by).value
+
+            if direction == "b":
+                order = "DESC"
+            else:
+                order = "ASC"
+
             args = [user_id]
             sql = """
                 SELECT COUNT(*) as total_media
@@ -155,9 +187,12 @@
                     "safe_from_quarantine"
                 FROM local_media_repository
                 WHERE user_id = ?
-                ORDER BY created_ts DESC, media_id DESC
+                ORDER BY {order_by_column} {order}, media_id ASC
                 LIMIT ? OFFSET ?
-            """
+            """.format(
+                order_by_column=order_by_column,
+                order=order,
+            )

             args += [limit, start]
             txn.execute(sql, args)
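A note on why the ORDER BY column is interpolated with .format() rather than bound as a parameter: SQL placeholders can only carry values, not identifiers. The pattern stays injection-safe because order_by is first round-tripped through the MediaSortOrder enum, so only whitelisted column names can ever reach the query string. A self-contained sketch of the same idea:

    from enum import Enum

    class SortColumn(Enum):
        CREATED_TS = "created_ts"
        UPLOAD_NAME = "upload_name"

    def build_query(order_by: str, direction: str) -> str:
        column = SortColumn(order_by).value  # raises ValueError on unknown input
        order = "DESC" if direction == "b" else "ASC"
        return "SELECT media_id FROM local_media_repository ORDER BY {} {}".format(
            column, order
        )

    print(build_query("created_ts", "b"))
    # SELECT media_id FROM local_media_repository ORDER BY created_ts DESC
    try:
        build_query("created_ts; DROP TABLE media", "f")
    except ValueError:
        print("rejected")  # injection attempts never reach the SQL string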
- """ + """.format( + order_by_column=order_by_column, + order=order, + ) args += [limit, start] txn.execute(sql, args) @@ -344,16 +379,16 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): thumbnail_method, thumbnail_length, ): - await self.db_pool.simple_insert( - "local_media_repository_thumbnails", - { + await self.db_pool.simple_upsert( + table="local_media_repository_thumbnails", + keyvalues={ "media_id": media_id, "thumbnail_width": thumbnail_width, "thumbnail_height": thumbnail_height, "thumbnail_method": thumbnail_method, "thumbnail_type": thumbnail_type, - "thumbnail_length": thumbnail_length, }, + values={"thumbnail_length": thumbnail_length}, desc="store_local_thumbnail", ) @@ -498,18 +533,18 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): thumbnail_method, thumbnail_length, ): - await self.db_pool.simple_insert( - "remote_media_cache_thumbnails", - { + await self.db_pool.simple_upsert( + table="remote_media_cache_thumbnails", + keyvalues={ "media_origin": origin, "media_id": media_id, "thumbnail_width": thumbnail_width, "thumbnail_height": thumbnail_height, "thumbnail_method": thumbnail_method, "thumbnail_type": thumbnail_type, - "thumbnail_length": thumbnail_length, - "filesystem_id": filesystem_id, }, + values={"thumbnail_length": thumbnail_length}, + insertion_values={"filesystem_id": filesystem_id}, desc="store_remote_media_thumbnail", ) diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index d5b5507815..61a7556e56 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -23,7 +23,7 @@ import attr from synapse.api.constants import UserTypes from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError from synapse.metrics.background_process_metrics import wrap_as_background_process -from synapse.storage.database import DatabasePool +from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.databases.main.stats import StatsStore from synapse.storage.types import Connection, Cursor @@ -70,7 +70,12 @@ class TokenLookupResult: class RegistrationWorkerStore(CacheInvalidationWorkerStore): - def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) self.config = hs.config @@ -79,9 +84,12 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): # call `find_max_generated_user_id_localpart` each time, which is # expensive if there are many entries. 
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index d5b5507815..61a7556e56 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -23,7 +23,7 @@ import attr
 from synapse.api.constants import UserTypes
 from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
 from synapse.metrics.background_process_metrics import wrap_as_background_process
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
 from synapse.storage.databases.main.stats import StatsStore
 from synapse.storage.types import Connection, Cursor
@@ -70,7 +70,12 @@ class TokenLookupResult:


 class RegistrationWorkerStore(CacheInvalidationWorkerStore):
-    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)

         self.config = hs.config
@@ -79,9 +84,12 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
         # call `find_max_generated_user_id_localpart` each time, which is
         # expensive if there are many entries.
         self._user_id_seq = build_sequence_generator(
+            db_conn,
             database.engine,
             find_max_generated_user_id_localpart,
             "user_id_seq",
+            table=None,
+            id_column=None,
         )

         self._account_validity = hs.config.account_validity
@@ -1036,7 +1044,12 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):


 class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
-    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)

         self._clock = hs.get_clock()
diff --git a/synapse/storage/databases/main/schema/delta/59/08delete_stale_pushers.sql b/synapse/storage/databases/main/schema/delta/59/08delete_stale_pushers.sql
new file mode 100644
index 0000000000..2442eea6bc
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/59/08delete_stale_pushers.sql
@@ -0,0 +1,19 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+-- Delete all pushers associated with deleted devices. This is to clear up after
+-- a bug where they weren't correctly deleted when using workers.
+DELETE FROM pushers WHERE access_token NOT IN (SELECT id FROM access_tokens);
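The delta above removes orphaned rows with an anti-join: any pusher whose access_token no longer exists in access_tokens is stale. Because id is the access_tokens primary key (and therefore non-null), the NOT IN subquery avoids the classic NULL-comparison pitfall. The same set operation, modelled in Python with invented sample data:

    pushers = [
        {"id": 1, "access_token": 10},
        {"id": 2, "access_token": 99},  # token 99 was deleted -> stale pusher
    ]
    access_tokens = {10, 11}

    # Keep only pushers whose token still exists; the SQL DELETE is the inverse.
    live = [p for p in pushers if p["access_token"] in access_tokens]
    assert [p["id"] for p in live] == [1]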
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index 63f88eac51..1026f321e5 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -497,8 +497,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
     async def add_users_in_public_rooms(
         self, room_id: str, user_ids: Iterable[str]
     ) -> None:
-        """Insert entries into the users_who_share_private_rooms table. The first
-        user should be a local user.
+        """Insert entries into the users_in_public_rooms table.

         Args:
             room_id
@@ -556,6 +555,11 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
         super().__init__(database, db_conn, hs)

+        self._prefer_local_users_in_search = (
+            hs.config.user_directory_search_prefer_local_users
+        )
+        self._server_name = hs.config.server_name
+
     async def remove_from_user_dir(self, user_id: str) -> None:
         def _remove_from_user_dir_txn(txn):
             self.db_pool.simple_delete_txn(
@@ -665,7 +669,6 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
             users.update(rows)
         return list(users)

-    @cached()
     async def get_shared_rooms_for_users(
         self, user_id: str, other_user_id: str
     ) -> Set[str]:
@@ -754,9 +757,24 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
             )
         """

+        # We allow manipulating the ranking algorithm by injecting statements
+        # based on config options.
+        additional_ordering_statements = []
+        ordering_arguments = ()
+
         if isinstance(self.database_engine, PostgresEngine):
             full_query, exact_query, prefix_query = _parse_query_postgres(search_term)

+            # If enabled, this config option ranks local users higher than
+            # those on remote instances.
+            if self._prefer_local_users_in_search:
+                # This statement checks whether a given user's user ID ends with
+                # the local server's name, i.e. whether the user is local.
+                statement = "* (CASE WHEN user_id LIKE ? THEN 2.0 ELSE 1.0 END)"
+                additional_ordering_statements.append(statement)
+
+                ordering_arguments += ("%:" + self._server_name,)
+
             # We order by rank and then if they have profile info
             # The ranking algorithm is hand tweaked for "best" results. Broadly
             # the idea is we give a higher weight to exact matches.
@@ -767,7 +785,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
                 FROM user_directory_search as t
                 INNER JOIN user_directory AS d USING (user_id)
                 WHERE
-                    %s
+                    %(where_clause)s
                     AND vector @@ to_tsquery('simple', ?)
                 ORDER BY
                     (CASE WHEN d.user_id IS NOT NULL THEN 4.0 ELSE 1.0 END)
@@ -787,33 +805,54 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
                         8
                     )
                 )
+                %(order_case_statements)s
                 DESC,
                 display_name IS NULL,
                 avatar_url IS NULL
                 LIMIT ?
-            """ % (
-                where_clause,
+            """ % {
+                "where_clause": where_clause,
+                "order_case_statements": " ".join(additional_ordering_statements),
+            }
+            args = (
+                join_args
+                + (full_query, exact_query, prefix_query)
+                + ordering_arguments
+                + (limit + 1,)
             )
-            args = join_args + (full_query, exact_query, prefix_query, limit + 1)
         elif isinstance(self.database_engine, Sqlite3Engine):
             search_query = _parse_query_sqlite(search_term)

+            # If enabled, this config option ranks local users higher than
+            # those on remote instances.
+            if self._prefer_local_users_in_search:
+                # This statement checks whether a given user's user ID ends with
+                # the local server's name, i.e. whether the user is local.
+                #
+                # Note that we need to include a comma at the end for valid SQL
+                statement = "user_id LIKE ? DESC,"
+                additional_ordering_statements.append(statement)
+
+                ordering_arguments += ("%:" + self._server_name,)
+
             sql = """
                 SELECT d.user_id AS user_id, display_name, avatar_url
                 FROM user_directory_search as t
                 INNER JOIN user_directory AS d USING (user_id)
                 WHERE
-                    %s
+                    %(where_clause)s
                     AND value MATCH ?
                 ORDER BY
                     rank(matchinfo(user_directory_search)) DESC,
+                    %(order_statements)s
                     display_name IS NULL,
                     avatar_url IS NULL
                 LIMIT ?
-            """ % (
-                where_clause,
-            )
-            args = join_args + (search_query, limit + 1)
+            """ % {
+                "where_clause": where_clause,
+                "order_statements": " ".join(additional_ordering_statements),
+            }
+            args = join_args + (search_query,) + ordering_arguments + (limit + 1,)
         else:
             # This should be unreachable.
             raise Exception("Unrecognized database engine")
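On the Postgres branch, the extra ordering statement multiplies the base rank by 2.0 for user IDs ending in the local server name, so local users float upwards without making remote users unfindable. The same behaviour, modelled in plain Python (rank values invented for illustration):

    server_name = "example.org"

    def effective_rank(user_id: str, base_rank: float) -> float:
        # Mirrors "* (CASE WHEN user_id LIKE '%:example.org' THEN 2.0 ELSE 1.0 END)"
        multiplier = 2.0 if user_id.endswith(":" + server_name) else 1.0
        return base_rank * multiplier

    results = [
        ("@alice:example.org", 1.0),
        ("@alice:elsewhere.net", 1.5),
    ]
    ranked = sorted(results, key=lambda r: effective_rank(r[0], r[1]), reverse=True)
    assert ranked[0][0] == "@alice:example.org"  # 2.0 beats 1.5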
raise Exception("Unrecognized database engine") diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index b16b9905d8..e2240703a7 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -97,10 +97,12 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): return txn.fetchone()[0] self._state_group_seq_gen = build_sequence_generator( - self.database_engine, get_max_state_group_txn, "state_group_id_seq" - ) - self._state_group_seq_gen.check_consistency( - db_conn, table="state_groups", id_column="id" + db_conn, + self.database_engine, + get_max_state_group_txn, + "state_group_id_seq", + table="state_groups", + id_column="id", ) @cached(max_entries=10000, iterable=True) diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py index 3ea637b281..36a67e7019 100644 --- a/synapse/storage/util/sequence.py +++ b/synapse/storage/util/sequence.py @@ -251,9 +251,14 @@ class LocalSequenceGenerator(SequenceGenerator): def build_sequence_generator( + db_conn: "LoggingDatabaseConnection", database_engine: BaseDatabaseEngine, get_first_callback: GetFirstCallbackType, sequence_name: str, + table: Optional[str], + id_column: Optional[str], + stream_name: Optional[str] = None, + positive: bool = True, ) -> SequenceGenerator: """Get the best impl of SequenceGenerator available @@ -265,8 +270,23 @@ def build_sequence_generator( get_first_callback: a callback which gets the next sequence ID. Used if we're on sqlite. sequence_name: the name of a postgres sequence to use. + table, id_column, stream_name, positive: If set then `check_consistency` + is called on the created sequence. See docstring for + `check_consistency` details. """ if isinstance(database_engine, PostgresEngine): - return PostgresSequenceGenerator(sequence_name) + seq = PostgresSequenceGenerator(sequence_name) # type: SequenceGenerator else: - return LocalSequenceGenerator(get_first_callback) + seq = LocalSequenceGenerator(get_first_callback) + + if table: + assert id_column + seq.check_consistency( + db_conn=db_conn, + table=table, + id_column=id_column, + stream_name=stream_name, + positive=positive, + ) + + return seq diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py index 32228f42ee..53f85195a7 100644 --- a/synapse/util/caches/response_cache.py +++ b/synapse/util/caches/response_cache.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Optional, TypeVar +from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Optional, Set, TypeVar from twisted.internet import defer @@ -40,6 +40,7 @@ class ResponseCache(Generic[T]): def __init__(self, hs: "HomeServer", name: str, timeout_ms: float = 0): # Requests that haven't finished yet. 
diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py
index 32228f42ee..53f85195a7 100644
--- a/synapse/util/caches/response_cache.py
+++ b/synapse/util/caches/response_cache.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Optional, TypeVar
+from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Optional, Set, TypeVar

 from twisted.internet import defer

@@ -40,6 +40,7 @@ class ResponseCache(Generic[T]):
     def __init__(self, hs: "HomeServer", name: str, timeout_ms: float = 0):
         # Requests that haven't finished yet.
         self.pending_result_cache = {}  # type: Dict[T, ObservableDeferred]
+        self.pending_conditionals = {}  # type: Dict[T, Set[Callable[[Any], bool]]]

         self.clock = hs.get_clock()
         self.timeout_sec = timeout_ms / 1000.0
@@ -101,7 +102,11 @@ class ResponseCache(Generic[T]):
         self.pending_result_cache[key] = result

         def remove(r):
-            if self.timeout_sec:
+            should_cache = all(
+                func(r) for func in self.pending_conditionals.pop(key, [])
+            )
+
+            if self.timeout_sec and should_cache:
                 self.clock.call_later(
                     self.timeout_sec, self.pending_result_cache.pop, key, None
                 )
@@ -112,6 +117,31 @@ class ResponseCache(Generic[T]):
         result.addBoth(remove)
         return result.observe()

+    def add_conditional(self, key: T, conditional: Callable[[Any], bool]):
+        self.pending_conditionals.setdefault(key, set()).add(conditional)
+
+    def wrap_conditional(
+        self,
+        key: T,
+        should_cache: Callable[[Any], bool],
+        callback: "Callable[..., Any]",
+        *args: Any,
+        **kwargs: Any
+    ) -> defer.Deferred:
+        """The same as wrap(), but adds a conditional to the final execution.
+
+        When the final execution completes, *all* conditionals must return True
+        for the result to be cached; otherwise it is evicted as soon as it completes.
+        """
+
+        # See if there's already a result on this key that hasn't yet completed.
+        # Because Python is single-threaded, adding the conditional here, in the
+        # same execution thread, cannot race with remove() above.
+        result = self.get(key)
+        if not result or (isinstance(result, defer.Deferred) and not result.called):
+            self.add_conditional(key, should_cache)
+
+        return self.wrap(key, callback, *args, **kwargs)
+
     def wrap(
         self, key: T, callback: "Callable[..., Any]", *args: Any, **kwargs: Any
     ) -> defer.Deferred:
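Intended usage of the new wrap_conditional(), sketched below: concurrent callers still share one in-flight execution, but the finished result is only retained for the timeout window if it passes the conditional, so, for example, error payloads are never cached. The handler and its callback are hypothetical; ResponseCache's real callers live elsewhere in Synapse.

    async def _do_expensive_lookup(room_id):
        return {"ok": True, "room_id": room_id}

    def handle_request(cache, room_id):
        return cache.wrap_conditional(
            room_id,                      # key: concurrent callers share one run
            lambda result: result["ok"],  # only retain successful responses
            _do_expensive_lookup,
            room_id,
        )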