author     Mathieu Velten <mathieuv@matrix.org>  2023-03-21 12:52:12 +0100
committer  Mathieu Velten <mathieuv@matrix.org>  2023-03-21 12:52:12 +0100
commit     6400f0302956986cafac6a99d07d5c3b684989d7 (patch)
tree       da342d2bdbfe2faede0cc2481c7a7f49abaaae06 /synapse
parent     Merge remote-tracking branch 'origin/release-v1.79' into matrix-org-hotfixes (diff)
parent     Document that our Docker images are mirrored to GHCR. (#15282) (diff)
download   synapse-6400f0302956986cafac6a99d07d5c3b684989d7.tar.xz

Merge remote-tracking branch 'origin/release-v1.80' into matrix-org-hotfixes
Diffstat (limited to 'synapse')
-rw-r--r--  synapse/api/errors.py | 5
-rw-r--r--  synapse/appservice/api.py | 13
-rw-r--r--  synapse/config/experimental.py | 23
-rw-r--r--  synapse/event_auth.py | 23
-rw-r--r--  synapse/events/snapshot.py | 1
-rw-r--r--  synapse/handlers/account_validity.py | 99
-rw-r--r--  synapse/handlers/event_auth.py | 13
-rw-r--r--  synapse/handlers/message.py | 17
-rw-r--r--  synapse/handlers/pagination.py | 2
-rw-r--r--  synapse/handlers/profile.py | 4
-rw-r--r--  synapse/handlers/register.py | 10
-rw-r--r--  synapse/handlers/room.py | 112
-rw-r--r--  synapse/handlers/sync.py | 9
-rw-r--r--  synapse/handlers/user_directory.py | 323
-rw-r--r--  synapse/http/client.py | 8
-rw-r--r--  synapse/http/federation/matrix_federation_agent.py | 2
-rw-r--r--  synapse/http/server.py | 4
-rw-r--r--  synapse/media/url_previewer.py | 833
-rw-r--r--  synapse/module_api/__init__.py | 18
-rw-r--r--  synapse/module_api/callbacks/__init__.py | 22
-rw-r--r--  synapse/module_api/callbacks/account_validity_callbacks.py | 93
-rw-r--r--  synapse/push/bulk_push_rule_evaluator.py | 34
-rw-r--r--  synapse/replication/tcp/client.py | 51
-rw-r--r--  synapse/replication/tcp/handler.py | 26
-rw-r--r--  synapse/rest/__init__.py | 2
-rw-r--r--  synapse/rest/admin/server_notice_servlet.py | 34
-rw-r--r--  synapse/rest/admin/users.py | 17
-rw-r--r--  synapse/rest/client/appservice_ping.py | 115
-rw-r--r--  synapse/rest/client/register.py | 2
-rw-r--r--  synapse/rest/client/room.py | 174
-rw-r--r--  synapse/rest/client/sendtodevice.py | 25
-rw-r--r--  synapse/rest/client/transactions.py | 55
-rw-r--r--  synapse/rest/client/versions.py | 4
-rw-r--r--  synapse/rest/media/preview_url_resource.py | 796
-rw-r--r--  synapse/server.py | 34
-rw-r--r--  synapse/storage/controllers/purge_events.py | 22
-rw-r--r--  synapse/storage/database.py | 21
-rw-r--r--  synapse/storage/databases/main/purge_events.py | 11
-rw-r--r--  synapse/storage/databases/main/transactions.py | 56
-rw-r--r--  synapse/storage/databases/main/user_directory.py | 114
-rw-r--r--  synapse/storage/databases/state/store.py | 2
-rw-r--r--  synapse/storage/schema/main/delta/74/01_user_directory_stale_remote_users.sql | 39
-rw-r--r--  synapse/storage/schema/main/delta/74/90COMMENTS_destinations.sql.postgres | 52
43 files changed, 2000 insertions, 1320 deletions
diff --git a/synapse/api/errors.py b/synapse/api/errors.py

index e1737de59b..8c6822f3c6 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -108,6 +108,11 @@ class Codes(str, Enum):
     USER_AWAITING_APPROVAL = "ORG.MATRIX.MSC3866_USER_AWAITING_APPROVAL"
 
+    AS_PING_URL_NOT_SET = "FI.MAU.MSC2659_URL_NOT_SET"
+    AS_PING_BAD_STATUS = "FI.MAU.MSC2659_BAD_STATUS"
+    AS_PING_CONNECTION_TIMEOUT = "FI.MAU.MSC2659_CONNECTION_TIMEOUT"
+    AS_PING_CONNECTION_FAILED = "FI.MAU.MSC2659_CONNECTION_FAILED"
+
     # Attempt to send a second annotation with the same event type & annotation key
     # MSC2677
     DUPLICATE_ANNOTATION = "M_DUPLICATE_ANNOTATION"
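For context (not part of the diff): values from the Codes enum are returned to clients inside the standard Matrix error envelope, so a failed MSC2659 ping would surface roughly as follows. This is illustrative only; the human-readable message text is an assumption.

    # Illustrative Matrix error envelope carrying one of the new MSC2659 codes.
    example_ping_failure = {
        "errcode": "FI.MAU.MSC2659_CONNECTION_TIMEOUT",
        "error": "Connection to the application service timed out",
    }
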
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index 1a6f69e7d3..4812fb4496 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -266,6 +266,19 @@ class ApplicationServiceApi(SimpleHttpClient):
         key = (service.id, protocol)
         return await self.protocol_meta_cache.wrap(key, _get)
 
+    async def ping(self, service: "ApplicationService", txn_id: Optional[str]) -> None:
+        # The caller should check that url is set
+        assert service.url is not None, "ping called without URL being set"
+
+        # This is required by the configuration.
+        assert service.hs_token is not None
+
+        await self.post_json_get_json(
+            uri=service.url + "/_matrix/app/unstable/fi.mau.msc2659/ping",
+            post_json={"transaction_id": txn_id},
+            headers={"Authorization": [f"Bearer {service.hs_token}"]},
+        )
+
     async def push_bulk(
         self,
         service: "ApplicationService",
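To make the new ping() helper concrete, this is roughly the HTTP request it issues, sketched here with the third-party `requests` library instead of Synapse's SimpleHttpClient; the appservice URL, hs_token and transaction ID are placeholders, not values from the diff.

    import requests

    as_url = "https://appservice.example.com"       # service.url (placeholder)
    hs_token = "hs_token_from_the_registration"     # service.hs_token (placeholder)

    resp = requests.post(
        f"{as_url}/_matrix/app/unstable/fi.mau.msc2659/ping",
        json={"transaction_id": "m1234567890"},     # txn_id chosen by the caller
        headers={"Authorization": f"Bearer {hs_token}"},
        timeout=60,
    )
    resp.raise_for_status()  # a non-2xx status maps to AS_PING_BAD_STATUS
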
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index 489f2601ac..99dcd27c74 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -166,20 +166,9 @@ class ExperimentalConfig(Config): # MSC3391: Removing account data. self.msc3391_enabled = experimental.get("msc3391_enabled", False) - # MSC3873: Disambiguate event_match keys. - self.msc3873_escape_event_match_key = experimental.get( - "msc3873_escape_event_match_key", False - ) - - # MSC3966: exact_event_property_contains push rule condition. - self.msc3966_exact_event_property_contains = experimental.get( - "msc3966_exact_event_property_contains", False - ) - # MSC3952: Intentional mentions, this depends on MSC3966. - self.msc3952_intentional_mentions = ( - experimental.get("msc3952_intentional_mentions", False) - and self.msc3966_exact_event_property_contains + self.msc3952_intentional_mentions = experimental.get( + "msc3952_intentional_mentions", False ) # MSC3959: Do not generate notifications for edits. @@ -187,10 +176,8 @@ class ExperimentalConfig(Config): "msc3958_supress_edit_notifs", False ) - # MSC3966: exact_event_property_contains push rule condition. - self.msc3966_exact_event_property_contains = experimental.get( - "msc3966_exact_event_property_contains", False - ) - # MSC3967: Do not require UIA when first uploading cross signing keys self.msc3967_enabled = experimental.get("msc3967_enabled", False) + + # MSC2659: Application service ping endpoint + self.msc2659_enabled = experimental.get("msc2659_enabled", False) diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index 4d6d1b8ebd..af55874b5c 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py
@@ -168,13 +168,24 @@ async def check_state_independent_auth_rules( return # 2. Reject if event has auth_events that: ... - auth_events = await store.get_events( - event.auth_event_ids(), - redact_behaviour=EventRedactBehaviour.as_is, - allow_rejected=True, - ) if batched_auth_events: - auth_events.update(batched_auth_events) + # Copy the batched auth events to avoid mutating them. + auth_events = dict(batched_auth_events) + needed_auth_event_ids = set(event.auth_event_ids()) - batched_auth_events.keys() + if needed_auth_event_ids: + auth_events.update( + await store.get_events( + needed_auth_event_ids, + redact_behaviour=EventRedactBehaviour.as_is, + allow_rejected=True, + ) + ) + else: + auth_events = await store.get_events( + event.auth_event_ids(), + redact_behaviour=EventRedactBehaviour.as_is, + allow_rejected=True, + ) room_id = event.room_id auth_dict: MutableStateMap[str] = {} diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index a91a5d1e3c..c04ad08cbb 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -293,6 +293,7 @@ class EventContext(UnpersistedEventContextBase):
             Maps a (type, state_key) to the event ID of the state event matching
             this tuple.
         """
+        assert self.state_group_before_event is not None
         return await self._storage.state.get_state_ids_for_group(
             self.state_group_before_event, state_filter

diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py
index 33e45e3a11..4aa4ebf7e4 100644 --- a/synapse/handlers/account_validity.py +++ b/synapse/handlers/account_validity.py
@@ -15,9 +15,7 @@ import email.mime.multipart import email.utils import logging -from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional, Tuple - -from twisted.web.http import Request +from typing import TYPE_CHECKING, List, Optional, Tuple from synapse.api.errors import AuthError, StoreError, SynapseError from synapse.metrics.background_process_metrics import wrap_as_background_process @@ -30,25 +28,17 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -# Types for callbacks to be registered via the module api -IS_USER_EXPIRED_CALLBACK = Callable[[str], Awaitable[Optional[bool]]] -ON_USER_REGISTRATION_CALLBACK = Callable[[str], Awaitable] -# Temporary hooks to allow for a transition from `/_matrix/client` endpoints -# to `/_synapse/client/account_validity`. See `register_account_validity_callbacks`. -ON_LEGACY_SEND_MAIL_CALLBACK = Callable[[str], Awaitable] -ON_LEGACY_RENEW_CALLBACK = Callable[[str], Awaitable[Tuple[bool, bool, int]]] -ON_LEGACY_ADMIN_REQUEST = Callable[[Request], Awaitable] - class AccountValidityHandler: def __init__(self, hs: "HomeServer"): self.hs = hs self.config = hs.config - self.store = self.hs.get_datastores().main - self.send_email_handler = self.hs.get_send_email_handler() - self.clock = self.hs.get_clock() + self.store = hs.get_datastores().main + self.send_email_handler = hs.get_send_email_handler() + self.clock = hs.get_clock() - self._app_name = self.hs.config.email.email_app_name + self._app_name = hs.config.email.email_app_name + self._module_api_callbacks = hs.get_module_api_callbacks().account_validity self._account_validity_enabled = ( hs.config.account_validity.account_validity_enabled @@ -78,69 +68,6 @@ class AccountValidityHandler: if hs.config.worker.run_background_tasks: self.clock.looping_call(self._send_renewal_emails, 30 * 60 * 1000) - self._is_user_expired_callbacks: List[IS_USER_EXPIRED_CALLBACK] = [] - self._on_user_registration_callbacks: List[ON_USER_REGISTRATION_CALLBACK] = [] - self._on_legacy_send_mail_callback: Optional[ - ON_LEGACY_SEND_MAIL_CALLBACK - ] = None - self._on_legacy_renew_callback: Optional[ON_LEGACY_RENEW_CALLBACK] = None - - # The legacy admin requests callback isn't a protected attribute because we need - # to access it from the admin servlet, which is outside of this handler. - self.on_legacy_admin_request_callback: Optional[ON_LEGACY_ADMIN_REQUEST] = None - - def register_account_validity_callbacks( - self, - is_user_expired: Optional[IS_USER_EXPIRED_CALLBACK] = None, - on_user_registration: Optional[ON_USER_REGISTRATION_CALLBACK] = None, - on_legacy_send_mail: Optional[ON_LEGACY_SEND_MAIL_CALLBACK] = None, - on_legacy_renew: Optional[ON_LEGACY_RENEW_CALLBACK] = None, - on_legacy_admin_request: Optional[ON_LEGACY_ADMIN_REQUEST] = None, - ) -> None: - """Register callbacks from module for each hook.""" - if is_user_expired is not None: - self._is_user_expired_callbacks.append(is_user_expired) - - if on_user_registration is not None: - self._on_user_registration_callbacks.append(on_user_registration) - - # The builtin account validity feature exposes 3 endpoints (send_mail, renew, and - # an admin one). As part of moving the feature into a module, we need to change - # the path from /_matrix/client/unstable/account_validity/... 
to - # /_synapse/client/account_validity, because: - # - # * the feature isn't part of the Matrix spec thus shouldn't live under /_matrix - # * the way we register servlets means that modules can't register resources - # under /_matrix/client - # - # We need to allow for a transition period between the old and new endpoints - # in order to allow for clients to update (and for emails to be processed). - # - # Once the email-account-validity module is loaded, it will take control of account - # validity by moving the rows from our `account_validity` table into its own table. - # - # Therefore, we need to allow modules (in practice just the one implementing the - # email-based account validity) to temporarily hook into the legacy endpoints so we - # can route the traffic coming into the old endpoints into the module, which is - # why we have the following three temporary hooks. - if on_legacy_send_mail is not None: - if self._on_legacy_send_mail_callback is not None: - raise RuntimeError("Tried to register on_legacy_send_mail twice") - - self._on_legacy_send_mail_callback = on_legacy_send_mail - - if on_legacy_renew is not None: - if self._on_legacy_renew_callback is not None: - raise RuntimeError("Tried to register on_legacy_renew twice") - - self._on_legacy_renew_callback = on_legacy_renew - - if on_legacy_admin_request is not None: - if self.on_legacy_admin_request_callback is not None: - raise RuntimeError("Tried to register on_legacy_admin_request twice") - - self.on_legacy_admin_request_callback = on_legacy_admin_request - async def is_user_expired(self, user_id: str) -> bool: """Checks if a user has expired against third-party modules. @@ -150,7 +77,7 @@ class AccountValidityHandler: Returns: Whether the user has expired. """ - for callback in self._is_user_expired_callbacks: + for callback in self._module_api_callbacks.is_user_expired_callbacks: expired = await delay_cancellation(callback(user_id)) if expired is not None: return expired @@ -168,7 +95,7 @@ class AccountValidityHandler: Args: user_id: The ID of the newly registered user. """ - for callback in self._on_user_registration_callbacks: + for callback in self._module_api_callbacks.on_user_registration_callbacks: await callback(user_id) @wrap_as_background_process("send_renewals") @@ -198,8 +125,8 @@ class AccountValidityHandler: """ # If a module supports sending a renewal email from here, do that, otherwise do # the legacy dance. - if self._on_legacy_send_mail_callback is not None: - await self._on_legacy_send_mail_callback(user_id) + if self._module_api_callbacks.on_legacy_send_mail_callback is not None: + await self._module_api_callbacks.on_legacy_send_mail_callback(user_id) return if not self._account_validity_renew_by_email_enabled: @@ -336,8 +263,10 @@ class AccountValidityHandler: """ # If a module supports triggering a renew from here, do that, otherwise do the # legacy dance. - if self._on_legacy_renew_callback is not None: - return await self._on_legacy_renew_callback(renewal_token) + if self._module_api_callbacks.on_legacy_renew_callback is not None: + return await self._module_api_callbacks.on_legacy_renew_callback( + renewal_token + ) try: ( diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py
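The callback registration removed above now lives on the module API callbacks object (see synapse/module_api/callbacks/account_validity_callbacks.py in the diffstat), but module authors keep using the same registration entry point on ModuleApi. A minimal sketch of such a module; the class name and callback bodies are hypothetical, only the registration call is taken from Synapse's module API.

    from typing import Optional

    from synapse.module_api import ModuleApi


    class ExampleAccountValidity:
        def __init__(self, config: dict, api: ModuleApi):
            api.register_account_validity_callbacks(
                is_user_expired=self.is_user_expired,
                on_user_registration=self.on_user_registration,
            )

        async def is_user_expired(self, user_id: str) -> Optional[bool]:
            # Return True/False, or None to let other modules decide.
            return None

        async def on_user_registration(self, user_id: str) -> None:
            # Called once for every newly registered user.
            pass
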
index c508861b6a..0db0bd7304 100644
--- a/synapse/handlers/event_auth.py
+++ b/synapse/handlers/event_auth.py
@@ -63,9 +63,18 @@ class EventAuthHandler:
             self._store, event, batched_auth_events
         )
         auth_event_ids = event.auth_event_ids()
-        auth_events_by_id = await self._store.get_events(auth_event_ids)
+
         if batched_auth_events:
-            auth_events_by_id.update(batched_auth_events)
+            # Copy the batched auth events to avoid mutating them.
+            auth_events_by_id = dict(batched_auth_events)
+            needed_auth_event_ids = set(auth_event_ids) - set(batched_auth_events)
+            if needed_auth_event_ids:
+                auth_events_by_id.update(
+                    await self._store.get_events(needed_auth_event_ids)
+                )
+        else:
+            auth_events_by_id = await self._store.get_events(auth_event_ids)
+
         check_state_dependent_auth_rules(event, auth_events_by_id.values())
 
     def compute_auth_events(
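Both this hunk and the equivalent change in synapse/event_auth.py above apply the same optimisation: treat the batched events as authoritative and only hit the store for IDs the batch does not cover. A generic, standalone sketch of that pattern (the names are illustrative, not Synapse APIs):

    from typing import Awaitable, Callable, Dict, Iterable, Set


    async def get_with_batch(
        wanted_ids: Iterable[str],
        batched: Dict[str, object],
        fetch: Callable[[Set[str]], Awaitable[Dict[str, object]]],
    ) -> Dict[str, object]:
        # Copy so the caller's batch is never mutated.
        result = dict(batched)
        missing = set(wanted_ids) - result.keys()
        if missing:
            result.update(await fetch(missing))
        return result
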
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index da129ec16a..4c75433a63 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -987,10 +987,11 @@ class EventCreationHandler: # a situation where event persistence can't keep up, causing # extremities to pile up, which in turn leads to state resolution # taking longer. - async with self.limiter.queue(event_dict["room_id"]): + room_id = event_dict["room_id"] + async with self.limiter.queue(room_id): if txn_id: event = await self.get_event_from_transaction( - requester, txn_id, event_dict["room_id"] + requester, txn_id, room_id ) if event: # we know it was persisted, so must have a stream ordering @@ -1000,6 +1001,18 @@ class EventCreationHandler: event.internal_metadata.stream_ordering, ) + # If we don't have any prev event IDs specified then we need to + # check that the host is in the room (as otherwise populating the + # prev events will fail), at which point we may as well check the + # local user is in the room. + if not prev_event_ids: + user_id = requester.user.to_string() + is_user_in_room = await self.store.check_local_user_in_room( + user_id, room_id + ) + if not is_user_in_room: + raise AuthError(403, f"User {user_id} not in room {room_id}") + # Try several times, it could fail with PartialStateConflictError # in handle_new_client_event, cf comment in except block. max_retries = 5 diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
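The get_event_from_transaction lookup above is what makes client retries idempotent: re-sending with the same transaction ID returns the already-persisted event instead of creating a duplicate. A client-side sketch of that behaviour (homeserver, token and room are placeholders):

    import requests

    homeserver = "https://matrix.example.org"     # placeholder
    access_token = "syt_placeholder_token"        # placeholder
    room_id = "!abc123:example.org"               # placeholder
    txn_id = "m1700000000.0"                      # reuse the same txn_id on retry

    def send_message(body: str) -> str:
        resp = requests.put(
            f"{homeserver}/_matrix/client/v3/rooms/{room_id}/send/m.room.message/{txn_id}",
            headers={"Authorization": f"Bearer {access_token}"},
            json={"msgtype": "m.text", "body": body},
            timeout=30,
        )
        resp.raise_for_status()
        return resp.json()["event_id"]

    first = send_message("hello")
    retried = send_message("hello")  # same txn_id, so same event_id as `first`
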
index 8c79c055ba..63b35c8d62 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -683,7 +683,7 @@ class PaginationHandler:
 
                 await self._storage_controllers.purge_events.purge_room(room_id)
 
-            logger.info("complete")
+            logger.info("purge complete for room_id %s", room_id)
             self._delete_by_id[delete_id].status = DeleteStatus.STATUS_COMPLETE
         except Exception:
             f = Failure()

diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 4bf9a047a3..9a81a77cbd 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -63,7 +63,7 @@ class ProfileHandler:
 
         self._third_party_rules = hs.get_third_party_event_rules()
 
-    async def get_profile(self, user_id: str) -> JsonDict:
+    async def get_profile(self, user_id: str, ignore_backoff: bool = True) -> JsonDict:
         target_user = UserID.from_string(user_id)
 
         if self.hs.is_mine(target_user):
@@ -81,7 +81,7 @@ class ProfileHandler:
                     destination=target_user.domain,
                     query_type="profile",
                     args={"user_id": user_id},
-                    ignore_backoff=True,
+                    ignore_backoff=ignore_backoff,
                 )
                 return result
             except RequestSendFailed as e:

diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index e4e506e62c..6b110dcb6e 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py
@@ -596,14 +596,20 @@ class RegistrationHandler: Args: user_id: The user to join """ + # If there are no rooms to auto-join, just bail. + if not self.hs.config.registration.auto_join_rooms: + return + # auto-join the user to any rooms we're supposed to dump them into # try to create the room if we're the first real user on the server. Note # that an auto-generated support or bot user is not a real user and will never be # the user to create the room should_auto_create_rooms = False - is_real_user = await self.store.is_real_user(user_id) - if self.hs.config.registration.autocreate_auto_join_rooms and is_real_user: + if ( + self.hs.config.registration.autocreate_auto_join_rooms + and await self.store.is_real_user(user_id) + ): count = await self.store.count_real_users() should_auto_create_rooms = count == 1 diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index b1784638f4..be120cb12f 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py
@@ -569,7 +569,7 @@ class RoomCreationHandler: new_room_id, # we expect to override all the presets with initial_state, so this is # somewhat arbitrary. - preset_config=RoomCreationPreset.PRIVATE_CHAT, + room_config={"preset": RoomCreationPreset.PRIVATE_CHAT}, invite_list=[], initial_state=initial_state, creation_content=creation_content, @@ -904,13 +904,6 @@ class RoomCreationHandler: check_membership=False, ) - preset_config = config.get( - "preset", - RoomCreationPreset.PRIVATE_CHAT - if visibility == "private" - else RoomCreationPreset.PUBLIC_CHAT, - ) - raw_initial_state = config.get("initial_state", []) initial_state = OrderedDict() @@ -929,7 +922,7 @@ class RoomCreationHandler: ) = await self._send_events_for_new_room( requester, room_id, - preset_config=preset_config, + room_config=config, invite_list=invite_list, initial_state=initial_state, creation_content=creation_content, @@ -938,48 +931,6 @@ class RoomCreationHandler: creator_join_profile=creator_join_profile, ) - if "name" in config: - name = config["name"] - ( - name_event, - last_stream_id, - ) = await self.event_creation_handler.create_and_send_nonmember_event( - requester, - { - "type": EventTypes.Name, - "room_id": room_id, - "sender": user_id, - "state_key": "", - "content": {"name": name}, - }, - ratelimit=False, - prev_event_ids=[last_sent_event_id], - depth=depth, - ) - last_sent_event_id = name_event.event_id - depth += 1 - - if "topic" in config: - topic = config["topic"] - ( - topic_event, - last_stream_id, - ) = await self.event_creation_handler.create_and_send_nonmember_event( - requester, - { - "type": EventTypes.Topic, - "room_id": room_id, - "sender": user_id, - "state_key": "", - "content": {"topic": topic}, - }, - ratelimit=False, - prev_event_ids=[last_sent_event_id], - depth=depth, - ) - last_sent_event_id = topic_event.event_id - depth += 1 - # we avoid dropping the lock between invites, as otherwise joins can # start coming in and making the createRoom slow. # @@ -1047,7 +998,7 @@ class RoomCreationHandler: self, creator: Requester, room_id: str, - preset_config: str, + room_config: JsonDict, invite_list: List[str], initial_state: MutableStateMap, creation_content: JsonDict, @@ -1064,11 +1015,33 @@ class RoomCreationHandler: Rate limiting should already have been applied by this point. + Args: + creator: + the user requesting the room creation + room_id: + room id for the room being created + room_config: + A dict of configuration options. This will be the body of + a /createRoom request; see + https://spec.matrix.org/latest/client-server-api/#post_matrixclientv3createroom + invite_list: + a list of user ids to invite to the room + initial_state: + A list of state events to set in the new room. + creation_content: + Extra keys, such as m.federate, to be added to the content of the m.room.create event. + room_alias: + alias for the room + power_level_content_override: + The power level content to override in the default power level event. + creator_join_profile: + Set to override the displayname and avatar for the creating + user in this room. + Returns: A tuple containing the stream ID, event ID and depth of the last event sent to the room. 
""" - creator_id = creator.user.to_string() event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""} depth = 1 @@ -1079,9 +1052,6 @@ class RoomCreationHandler: # created (but not persisted to the db) to determine state for future created events # (as this info can't be pulled from the db) state_map: MutableStateMap[str] = {} - # current_state_group of last event created. Used for computing event context of - # events to be batched - current_state_group: Optional[int] = None def create_event_dict(etype: str, content: JsonDict, **kwargs: Any) -> JsonDict: e = {"type": etype, "content": content} @@ -1123,7 +1093,9 @@ class RoomCreationHandler: event_dict, prev_event_ids=prev_event, depth=depth, - state_map=state_map, + # Take a copy to ensure each event gets a unique copy of + # state_map since it is modified below. + state_map=dict(state_map), for_batch=for_batch, ) @@ -1133,6 +1105,14 @@ class RoomCreationHandler: return new_event, new_unpersisted_context + visibility = room_config.get("visibility", "private") + preset_config = room_config.get( + "preset", + RoomCreationPreset.PRIVATE_CHAT + if visibility == "private" + else RoomCreationPreset.PUBLIC_CHAT, + ) + try: config = self._presets_dict[preset_config] except KeyError: @@ -1284,6 +1264,24 @@ class RoomCreationHandler: ) events_to_send.append((encryption_event, encryption_context)) + if "name" in room_config: + name = room_config["name"] + name_event, name_context = await create_event( + EventTypes.Name, + {"name": name}, + True, + ) + events_to_send.append((name_event, name_context)) + + if "topic" in room_config: + topic = room_config["topic"] + topic_event, topic_context = await create_event( + EventTypes.Topic, + {"topic": topic}, + True, + ) + events_to_send.append((topic_event, topic_context)) + datastore = self.hs.get_datastores().state events_and_context = ( await UnpersistedEventContext.batch_persist_unpersisted_contexts( diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index fd6d946c37..9f5b83ed54 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py
@@ -1226,6 +1226,10 @@ class SyncHandler: continue event_with_membership_auth = events_with_membership_auth[member] + is_create = ( + event_with_membership_auth.is_state() + and event_with_membership_auth.type == EventTypes.Create + ) is_join = ( event_with_membership_auth.is_state() and event_with_membership_auth.type == EventTypes.Member @@ -1233,9 +1237,10 @@ class SyncHandler: and event_with_membership_auth.content.get("membership") == Membership.JOIN ) - if not is_join: + if not is_create and not is_join: # The event must include the desired membership as an auth event, unless - # it's the first join event for a given user. + # it's the `m.room.create` event for a room or the first join event for + # a given user. missing_members.add(member) auth_event_ids.update(event_with_membership_auth.auth_event_ids()) diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index 3610b6bf78..28a92d41d6 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py
@@ -13,21 +13,52 @@ # limitations under the License. import logging +from http import HTTPStatus from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple +from twisted.internet.interfaces import IDelayedCall + import synapse.metrics from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules, Membership +from synapse.api.errors import Codes, SynapseError from synapse.handlers.state_deltas import MatchChange, StateDeltasHandler from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.databases.main.user_directory import SearchResult from synapse.storage.roommember import ProfileInfo +from synapse.types import UserID from synapse.util.metrics import Measure +from synapse.util.retryutils import NotRetryingDestination +from synapse.util.stringutils import non_null_str_or_none if TYPE_CHECKING: from synapse.server import HomeServer logger = logging.getLogger(__name__) +# Don't refresh a stale user directory entry, using a Federation /profile request, +# for 60 seconds. This gives time for other state events to arrive (which will +# then be coalesced such that only one /profile request is made). +USER_DIRECTORY_STALE_REFRESH_TIME_MS = 60 * 1000 + +# Maximum number of remote servers that we will attempt to refresh profiles for +# in one go. +MAX_SERVERS_TO_REFRESH_PROFILES_FOR_IN_ONE_GO = 5 + +# As long as we have servers to refresh (without backoff), keep adding more +# every 15 seconds. +INTERVAL_TO_ADD_MORE_SERVERS_TO_REFRESH_PROFILES = 15 + + +def calculate_time_of_next_retry(now_ts: int, retry_count: int) -> int: + """ + Calculates the time of a next retry given `now_ts` in ms and the number + of failures encountered thus far. + + Currently the sequence goes: + 1 min, 5 min, 25 min, 2 hour, 10 hour, 52 hour, 10 day, 7.75 week + """ + return now_ts + 60_000 * (5 ** min(retry_count, 7)) + class UserDirectoryHandler(StateDeltasHandler): """Handles queries and updates for the user_directory. @@ -64,12 +95,24 @@ class UserDirectoryHandler(StateDeltasHandler): self.update_user_directory = hs.config.worker.should_update_user_directory self.search_all_users = hs.config.userdirectory.user_directory_search_all_users self.spam_checker = hs.get_spam_checker() + self._hs = hs + # The current position in the current_state_delta stream self.pos: Optional[int] = None # Guard to ensure we only process deltas one at a time self._is_processing = False + # Guard to ensure we only have one process for refreshing remote profiles + self._is_refreshing_remote_profiles = False + # Handle to cancel the `call_later` of `kick_off_remote_profile_refresh_process` + self._refresh_remote_profiles_call_later: Optional[IDelayedCall] = None + + # Guard to ensure we only have one process for refreshing remote profiles + # for the given servers. + # Set of server names. 
+ self._is_refreshing_remote_profiles_for_servers: Set[str] = set() + if self.update_user_directory: self.notifier.add_replication_callback(self.notify_new_event) @@ -77,6 +120,11 @@ class UserDirectoryHandler(StateDeltasHandler): # we start populating the user directory self.clock.call_later(0, self.notify_new_event) + # Kick off the profile refresh process on startup + self._refresh_remote_profiles_call_later = self.clock.call_later( + 10, self.kick_off_remote_profile_refresh_process + ) + async def search_users( self, user_id: str, search_term: str, limit: int ) -> SearchResult: @@ -200,8 +248,8 @@ class UserDirectoryHandler(StateDeltasHandler): typ = delta["type"] state_key = delta["state_key"] room_id = delta["room_id"] - event_id = delta["event_id"] - prev_event_id = delta["prev_event_id"] + event_id: Optional[str] = delta["event_id"] + prev_event_id: Optional[str] = delta["prev_event_id"] logger.debug("Handling: %r %r, %s", typ, state_key, event_id) @@ -297,8 +345,8 @@ class UserDirectoryHandler(StateDeltasHandler): async def _handle_room_membership_event( self, room_id: str, - prev_event_id: str, - event_id: str, + prev_event_id: Optional[str], + event_id: Optional[str], state_key: str, ) -> None: """Process a single room membershp event. @@ -348,7 +396,8 @@ class UserDirectoryHandler(StateDeltasHandler): # Handle any profile changes for remote users. # (For local users the rest of the application calls # `handle_local_profile_change`.) - if is_remote: + # Only process if there is an event_id. + if is_remote and event_id is not None: await self._handle_possible_remote_profile_change( state_key, room_id, prev_event_id, event_id ) @@ -356,29 +405,13 @@ class UserDirectoryHandler(StateDeltasHandler): # This may be the first time we've seen a remote user. If # so, ensure we have a directory entry for them. (For local users, # the rest of the application calls `handle_local_profile_change`.) - if is_remote: - await self._upsert_directory_entry_for_remote_user(state_key, event_id) + # Only process if there is an event_id. + if is_remote and event_id is not None: + await self._handle_possible_remote_profile_change( + state_key, room_id, None, event_id + ) await self._track_user_joined_room(room_id, state_key) - async def _upsert_directory_entry_for_remote_user( - self, user_id: str, event_id: str - ) -> None: - """A remote user has just joined a room. Ensure they have an entry in - the user directory. The caller is responsible for making sure they're - remote. - """ - event = await self.store.get_event(event_id, allow_none=True) - # It isn't expected for this event to not exist, but we - # don't want the entire background process to break. - if event is None: - return - - logger.debug("Adding new user to dir, %r", user_id) - - await self.store.update_profile_in_user_dir( - user_id, event.content.get("displayname"), event.content.get("avatar_url") - ) - async def _track_user_joined_room(self, room_id: str, joining_user_id: str) -> None: """Someone's just joined a room. Update `users_in_public_rooms` or `users_who_share_private_rooms` as appropriate. @@ -460,14 +493,17 @@ class UserDirectoryHandler(StateDeltasHandler): user_id: str, room_id: str, prev_event_id: Optional[str], - event_id: Optional[str], + event_id: str, ) -> None: """Check member event changes for any profile changes and update the database if there are. This is intended for remote users only. The caller is responsible for checking that the given user is remote. 
""" - if not prev_event_id or not event_id: - return + + if not prev_event_id: + # If we don't have an older event to fall back on, just fetch the same + # event itself. + prev_event_id = event_id prev_event = await self.store.get_event(prev_event_id, allow_none=True) event = await self.store.get_event(event_id, allow_none=True) @@ -478,17 +514,236 @@ class UserDirectoryHandler(StateDeltasHandler): if event.membership != Membership.JOIN: return + is_public = await self.store.is_room_world_readable_or_publicly_joinable( + room_id + ) + if not is_public: + # Don't collect user profiles from private rooms as they are not guaranteed + # to be the same as the user's global profile. + now_ts = self.clock.time_msec() + await self.store.set_remote_user_profile_in_user_dir_stale( + user_id, + next_try_at_ms=now_ts + USER_DIRECTORY_STALE_REFRESH_TIME_MS, + retry_counter=0, + ) + # Schedule a wake-up to refresh the user directory for this server. + # We intentionally wake up this server directly because we don't want + # other servers ahead of it in the queue to get in the way of updating + # the profile if the server only just sent us an event. + self.clock.call_later( + USER_DIRECTORY_STALE_REFRESH_TIME_MS // 1000 + 1, + self.kick_off_remote_profile_refresh_process_for_remote_server, + UserID.from_string(user_id).domain, + ) + # Schedule a wake-up to handle any backoffs that may occur in the future. + self.clock.call_later( + 2 * USER_DIRECTORY_STALE_REFRESH_TIME_MS // 1000 + 1, + self.kick_off_remote_profile_refresh_process, + ) + return + prev_name = prev_event.content.get("displayname") new_name = event.content.get("displayname") - # If the new name is an unexpected form, do not update the directory. + # If the new name is an unexpected form, replace with None. if not isinstance(new_name, str): - new_name = prev_name + new_name = None prev_avatar = prev_event.content.get("avatar_url") new_avatar = event.content.get("avatar_url") - # If the new avatar is an unexpected form, do not update the directory. + # If the new avatar is an unexpected form, replace with None. if not isinstance(new_avatar, str): - new_avatar = prev_avatar + new_avatar = None - if prev_name != new_name or prev_avatar != new_avatar: + if ( + prev_name != new_name + or prev_avatar != new_avatar + or prev_event_id == event_id + ): + # Only update if something has changed, or we didn't have a previous event + # in the first place. await self.store.update_profile_in_user_dir(user_id, new_name, new_avatar) + + def kick_off_remote_profile_refresh_process(self) -> None: + """Called when there may be remote users with stale profiles to be refreshed""" + if not self.update_user_directory: + return + + if self._is_refreshing_remote_profiles: + return + + if self._refresh_remote_profiles_call_later: + if self._refresh_remote_profiles_call_later.active(): + self._refresh_remote_profiles_call_later.cancel() + self._refresh_remote_profiles_call_later = None + + async def process() -> None: + try: + await self._unsafe_refresh_remote_profiles() + finally: + self._is_refreshing_remote_profiles = False + + self._is_refreshing_remote_profiles = True + run_as_background_process("user_directory.refresh_remote_profiles", process) + + async def _unsafe_refresh_remote_profiles(self) -> None: + limit = MAX_SERVERS_TO_REFRESH_PROFILES_FOR_IN_ONE_GO - len( + self._is_refreshing_remote_profiles_for_servers + ) + if limit <= 0: + # nothing to do: already refreshing the maximum number of servers + # at once. + # Come back later. 
+ self._refresh_remote_profiles_call_later = self.clock.call_later( + INTERVAL_TO_ADD_MORE_SERVERS_TO_REFRESH_PROFILES, + self.kick_off_remote_profile_refresh_process, + ) + return + + servers_to_refresh = ( + await self.store.get_remote_servers_with_profiles_to_refresh( + now_ts=self.clock.time_msec(), limit=limit + ) + ) + + if not servers_to_refresh: + # Do we have any backing-off servers that we should try again + # for eventually? + # By setting `now` is a point in the far future, we can ask for + # which server/user is next to be refreshed, even though it is + # not actually refreshable *now*. + end_of_time = 1 << 62 + backing_off_servers = ( + await self.store.get_remote_servers_with_profiles_to_refresh( + now_ts=end_of_time, limit=1 + ) + ) + if backing_off_servers: + # Find out when the next user is refreshable and schedule a + # refresh then. + backing_off_server_name = backing_off_servers[0] + users = await self.store.get_remote_users_to_refresh_on_server( + backing_off_server_name, now_ts=end_of_time, limit=1 + ) + if not users: + return + _, _, next_try_at_ts = users[0] + self._refresh_remote_profiles_call_later = self.clock.call_later( + ((next_try_at_ts - self.clock.time_msec()) // 1000) + 2, + self.kick_off_remote_profile_refresh_process, + ) + + return + + for server_to_refresh in servers_to_refresh: + self.kick_off_remote_profile_refresh_process_for_remote_server( + server_to_refresh + ) + + self._refresh_remote_profiles_call_later = self.clock.call_later( + INTERVAL_TO_ADD_MORE_SERVERS_TO_REFRESH_PROFILES, + self.kick_off_remote_profile_refresh_process, + ) + + def kick_off_remote_profile_refresh_process_for_remote_server( + self, server_name: str + ) -> None: + """Called when there may be remote users with stale profiles to be refreshed + on the given server.""" + if not self.update_user_directory: + return + + if server_name in self._is_refreshing_remote_profiles_for_servers: + return + + async def process() -> None: + try: + await self._unsafe_refresh_remote_profiles_for_remote_server( + server_name + ) + finally: + self._is_refreshing_remote_profiles_for_servers.remove(server_name) + + self._is_refreshing_remote_profiles_for_servers.add(server_name) + run_as_background_process( + "user_directory.refresh_remote_profiles_for_remote_server", process + ) + + async def _unsafe_refresh_remote_profiles_for_remote_server( + self, server_name: str + ) -> None: + logger.info("Refreshing profiles in user directory for %s", server_name) + + while True: + # Get a handful of users to process. + next_batch = await self.store.get_remote_users_to_refresh_on_server( + server_name, now_ts=self.clock.time_msec(), limit=10 + ) + if not next_batch: + # Finished for now + return + + for user_id, retry_counter, _ in next_batch: + # Request the profile of the user. + try: + profile = await self._hs.get_profile_handler().get_profile( + user_id, ignore_backoff=False + ) + except NotRetryingDestination as e: + logger.info( + "Failed to refresh profile for %r because the destination is undergoing backoff", + user_id, + ) + # As a special-case, we back off until the destination is no longer + # backed off from. + await self.store.set_remote_user_profile_in_user_dir_stale( + user_id, + e.retry_last_ts + e.retry_interval, + retry_counter=retry_counter + 1, + ) + continue + except SynapseError as e: + if e.code == HTTPStatus.NOT_FOUND and e.errcode == Codes.NOT_FOUND: + # The profile doesn't exist. + # TODO Does this mean we should clear it from our user + # directory? 
+ await self.store.clear_remote_user_profile_in_user_dir_stale( + user_id + ) + logger.warning( + "Refresh of remote profile %r: not found (%r)", + user_id, + e.msg, + ) + continue + + logger.warning( + "Failed to refresh profile for %r because %r", user_id, e + ) + await self.store.set_remote_user_profile_in_user_dir_stale( + user_id, + calculate_time_of_next_retry( + self.clock.time_msec(), retry_counter + 1 + ), + retry_counter=retry_counter + 1, + ) + continue + except Exception: + logger.error( + "Failed to refresh profile for %r due to unhandled exception", + user_id, + exc_info=True, + ) + await self.store.set_remote_user_profile_in_user_dir_stale( + user_id, + calculate_time_of_next_retry( + self.clock.time_msec(), retry_counter + 1 + ), + retry_counter=retry_counter + 1, + ) + continue + + await self.store.update_profile_in_user_dir( + user_id, + display_name=non_null_str_or_none(profile.get("displayname")), + avatar_url=non_null_str_or_none(profile.get("avatar_url")), + ) diff --git a/synapse/http/client.py b/synapse/http/client.py
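The schedule quoted in calculate_time_of_next_retry's docstring follows directly from its 60_000 * 5 ** min(retry_count, 7) formula. A standalone check (the formula is copied from the hunk above, not imported from Synapse):

    def next_retry_delay_ms(retry_count: int) -> int:
        return 60_000 * (5 ** min(retry_count, 7))

    for retry_count in range(8):
        minutes = next_retry_delay_ms(retry_count) / 60_000
        print(f"after failure #{retry_count}: retry in {minutes:g} minutes")

    # Prints 1, 5, 25, 125 (~2 h), 625 (~10 h), 3125 (~52 h),
    # 15625 (~10.8 days) and 78125 minutes (~7.75 weeks).
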
index ae48e7c3f0..d777d59ccf 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py
@@ -268,8 +268,8 @@ class BlacklistingAgentWrapper(Agent): def __init__( self, agent: IAgent, + ip_blacklist: IPSet, ip_whitelist: Optional[IPSet] = None, - ip_blacklist: Optional[IPSet] = None, ): """ Args: @@ -291,7 +291,9 @@ class BlacklistingAgentWrapper(Agent): h = urllib.parse.urlparse(uri.decode("ascii")) try: - ip_address = IPAddress(h.hostname) + # h.hostname is Optional[str], None raises an AddrFormatError, so + # this is safe even though IPAddress requires a str. + ip_address = IPAddress(h.hostname) # type: ignore[arg-type] except AddrFormatError: # Not an IP pass @@ -388,8 +390,8 @@ class SimpleHttpClient: # by the DNS resolution. self.agent = BlacklistingAgentWrapper( self.agent, - ip_whitelist=self._ip_whitelist, ip_blacklist=self._ip_blacklist, + ip_whitelist=self._ip_whitelist, ) async def request( diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
index 0359231e7d..8d7d0a3875 100644
--- a/synapse/http/federation/matrix_federation_agent.py
+++ b/synapse/http/federation/matrix_federation_agent.py
@@ -87,7 +87,7 @@ class MatrixFederationAgent:
         reactor: ISynapseReactor,
         tls_client_options_factory: Optional[FederationPolicyForHTTPS],
         user_agent: bytes,
-        ip_whitelist: IPSet,
+        ip_whitelist: Optional[IPSet],
         ip_blacklist: IPSet,
         _srv_resolver: Optional[SrvResolver] = None,
         _well_known_resolver: Optional[WellKnownResolver] = None,

diff --git a/synapse/http/server.py b/synapse/http/server.py
index 9314454af1..7b760505b2 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -892,6 +892,10 @@ def set_cors_headers(request: SynapseRequest) -> None:
         b"Access-Control-Allow-Headers",
         b"X-Requested-With, Content-Type, Authorization, Date",
     )
+    request.setHeader(
+        b"Access-Control-Expose-Headers",
+        b"Synapse-Trace-Id",
+    )
 
 
 def set_corp_headers(request: Request) -> None:

diff --git a/synapse/media/url_previewer.py b/synapse/media/url_previewer.py
new file mode 100644
index 0000000000..c8a4a809f1 --- /dev/null +++ b/synapse/media/url_previewer.py
@@ -0,0 +1,833 @@ +# Copyright 2016 OpenMarket Ltd +# Copyright 2020-2023 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import datetime +import errno +import fnmatch +import logging +import os +import re +import shutil +import sys +import traceback +from typing import TYPE_CHECKING, BinaryIO, Iterable, Optional, Tuple +from urllib.parse import urljoin, urlparse, urlsplit +from urllib.request import urlopen + +import attr + +from twisted.internet.defer import Deferred +from twisted.internet.error import DNSLookupError + +from synapse.api.errors import Codes, SynapseError +from synapse.http.client import SimpleHttpClient +from synapse.logging.context import make_deferred_yieldable, run_in_background +from synapse.media._base import FileInfo, get_filename_from_headers +from synapse.media.media_storage import MediaStorage +from synapse.media.oembed import OEmbedProvider +from synapse.media.preview_html import decode_body, parse_html_to_open_graph +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.types import JsonDict, UserID +from synapse.util import json_encoder +from synapse.util.async_helpers import ObservableDeferred +from synapse.util.caches.expiringcache import ExpiringCache +from synapse.util.stringutils import random_string + +if TYPE_CHECKING: + from synapse.media.media_repository import MediaRepository + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + +OG_TAG_NAME_MAXLEN = 50 +OG_TAG_VALUE_MAXLEN = 1000 + +ONE_HOUR = 60 * 60 * 1000 +ONE_DAY = 24 * ONE_HOUR +IMAGE_CACHE_EXPIRY_MS = 2 * ONE_DAY + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class DownloadResult: + length: int + uri: str + response_code: int + media_type: str + download_name: Optional[str] + expires: int + etag: Optional[str] + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class MediaInfo: + """ + Information parsed from downloading media being previewed. + """ + + # The Content-Type header of the response. + media_type: str + # The length (in bytes) of the downloaded media. + media_length: int + # The media filename, according to the server. This is parsed from the + # returned headers, if possible. + download_name: Optional[str] + # The time of the preview. + created_ts_ms: int + # Information from the media storage provider about where the file is stored + # on disk. + filesystem_id: str + filename: str + # The URI being previewed. + uri: str + # The HTTP response code. + response_code: int + # The timestamp (in milliseconds) of when this preview expires. + expires: int + # The ETag header of the response. + etag: Optional[str] + + +class UrlPreviewer: + """ + Generates an Open Graph (https://ogp.me/) responses (with some Matrix + specific additions) for a given URL. + + When Synapse is asked to preview a URL it does the following: + + 1. Checks against a URL blacklist (defined as `url_preview_url_blacklist` in the + config). + 2. 
Checks the URL against an in-memory cache and returns the result if it exists. (This + is also used to de-duplicate processing of multiple in-flight requests at once.) + 3. Kicks off a background process to generate a preview: + 1. Checks URL and timestamp against the database cache and returns the result if it + has not expired and was successful (a 2xx return code). + 2. Checks if the URL matches an oEmbed (https://oembed.com/) pattern. If it + does, update the URL to download. + 3. Downloads the URL and stores it into a file via the media storage provider + and saves the local media metadata. + 4. If the media is an image: + 1. Generates thumbnails. + 2. Generates an Open Graph response based on image properties. + 5. If the media is HTML: + 1. Decodes the HTML via the stored file. + 2. Generates an Open Graph response from the HTML. + 3. If a JSON oEmbed URL was found in the HTML via autodiscovery: + 1. Downloads the URL and stores it into a file via the media storage provider + and saves the local media metadata. + 2. Convert the oEmbed response to an Open Graph response. + 3. Override any Open Graph data from the HTML with data from oEmbed. + 4. If an image exists in the Open Graph response: + 1. Downloads the URL and stores it into a file via the media storage + provider and saves the local media metadata. + 2. Generates thumbnails. + 3. Updates the Open Graph response based on image properties. + 6. If the media is JSON and an oEmbed URL was found: + 1. Convert the oEmbed response to an Open Graph response. + 2. If a thumbnail or image is in the oEmbed response: + 1. Downloads the URL and stores it into a file via the media storage + provider and saves the local media metadata. + 2. Generates thumbnails. + 3. Updates the Open Graph response based on image properties. + 7. Stores the result in the database cache. + 4. Returns the result. + + If any additional requests (e.g. from oEmbed autodiscovery, step 5.3 or + image thumbnailing, step 5.4 or 6.4) fails then the URL preview as a whole + does not fail. As much information as possible is returned. + + The in-memory cache expires after 1 hour. + + Expired entries in the database cache (and their associated media files) are + deleted every 10 seconds. The default expiration time is 1 hour from download. + """ + + def __init__( + self, + hs: "HomeServer", + media_repo: "MediaRepository", + media_storage: MediaStorage, + ): + self.clock = hs.get_clock() + self.filepaths = media_repo.filepaths + self.max_spider_size = hs.config.media.max_spider_size + self.server_name = hs.hostname + self.store = hs.get_datastores().main + self.client = SimpleHttpClient( + hs, + treq_args={"browser_like_redirects": True}, + ip_whitelist=hs.config.media.url_preview_ip_range_whitelist, + ip_blacklist=hs.config.media.url_preview_ip_range_blacklist, + use_proxy=True, + ) + self.media_repo = media_repo + self.primary_base_path = media_repo.primary_base_path + self.media_storage = media_storage + + self._oembed = OEmbedProvider(hs) + + # We run the background jobs if we're the instance specified (or no + # instance is specified, where we assume there is only one instance + # serving media). 
+ instance_running_jobs = hs.config.media.media_instance_running_background_jobs + self._worker_run_media_background_jobs = ( + instance_running_jobs is None + or instance_running_jobs == hs.get_instance_name() + ) + + self.url_preview_url_blacklist = hs.config.media.url_preview_url_blacklist + self.url_preview_accept_language = hs.config.media.url_preview_accept_language + + # memory cache mapping urls to an ObservableDeferred returning + # JSON-encoded OG metadata + self._cache: ExpiringCache[str, ObservableDeferred] = ExpiringCache( + cache_name="url_previews", + clock=self.clock, + # don't spider URLs more often than once an hour + expiry_ms=ONE_HOUR, + ) + + if self._worker_run_media_background_jobs: + self._cleaner_loop = self.clock.looping_call( + self._start_expire_url_cache_data, 10 * 1000 + ) + + async def preview(self, url: str, user: UserID, ts: int) -> bytes: + # XXX: we could move this into _do_preview if we wanted. + url_tuple = urlsplit(url) + for entry in self.url_preview_url_blacklist: + match = True + for attrib in entry: + pattern = entry[attrib] + value = getattr(url_tuple, attrib) + logger.debug( + "Matching attrib '%s' with value '%s' against pattern '%s'", + attrib, + value, + pattern, + ) + + if value is None: + match = False + continue + + # Some attributes might not be parsed as strings by urlsplit (such as the + # port, which is parsed as an int). Because we use match functions that + # expect strings, we want to make sure that's what we give them. + value_str = str(value) + + if pattern.startswith("^"): + if not re.match(pattern, value_str): + match = False + continue + else: + if not fnmatch.fnmatch(value_str, pattern): + match = False + continue + if match: + logger.warning("URL %s blocked by url_blacklist entry %s", url, entry) + raise SynapseError( + 403, "URL blocked by url pattern blacklist entry", Codes.UNKNOWN + ) + + # the in-memory cache: + # * ensures that only one request is active at a time + # * takes load off the DB for the thundering herds + # * also caches any failures (unlike the DB) so we don't keep + # requesting the same endpoint + + observable = self._cache.get(url) + + if not observable: + download = run_in_background(self._do_preview, url, user, ts) + observable = ObservableDeferred(download, consumeErrors=True) + self._cache[url] = observable + else: + logger.info("Returning cached response") + + return await make_deferred_yieldable(observable.observe()) + + async def _do_preview(self, url: str, user: UserID, ts: int) -> bytes: + """Check the db, and download the URL and build a preview + + Args: + url: The URL to preview. + user: The user requesting the preview. + ts: The timestamp requested for the preview. + + Returns: + json-encoded og data + """ + # check the URL cache in the DB (which will also provide us with + # historical previews, if we have any) + cache_result = await self.store.get_url_cache(url, ts) + if ( + cache_result + and cache_result["expires_ts"] > ts + and cache_result["response_code"] / 100 == 2 + ): + # It may be stored as text in the database, not as bytes (such as + # PostgreSQL). If so, encode it back before handing it on. + og = cache_result["og"] + if isinstance(og, str): + og = og.encode("utf8") + return og + + # If this URL can be accessed via oEmbed, use that instead. 
+ url_to_download = url + oembed_url = self._oembed.get_oembed_url(url) + if oembed_url: + url_to_download = oembed_url + + media_info = await self._handle_url(url_to_download, user) + + logger.debug("got media_info of '%s'", media_info) + + # The number of milliseconds that the response should be considered valid. + expiration_ms = media_info.expires + author_name: Optional[str] = None + + if _is_media(media_info.media_type): + file_id = media_info.filesystem_id + dims = await self.media_repo._generate_thumbnails( + None, file_id, file_id, media_info.media_type, url_cache=True + ) + + og = { + "og:description": media_info.download_name, + "og:image": f"mxc://{self.server_name}/{media_info.filesystem_id}", + "og:image:type": media_info.media_type, + "matrix:image:size": media_info.media_length, + } + + if dims: + og["og:image:width"] = dims["width"] + og["og:image:height"] = dims["height"] + else: + logger.warning("Couldn't get dims for %s" % url) + + # define our OG response for this media + elif _is_html(media_info.media_type): + # TODO: somehow stop a big HTML tree from exploding synapse's RAM + + with open(media_info.filename, "rb") as file: + body = file.read() + + tree = decode_body(body, media_info.uri, media_info.media_type) + if tree is not None: + # Check if this HTML document points to oEmbed information and + # defer to that. + oembed_url = self._oembed.autodiscover_from_html(tree) + og_from_oembed: JsonDict = {} + if oembed_url: + try: + oembed_info = await self._handle_url( + oembed_url, user, allow_data_urls=True + ) + except Exception as e: + # Fetching the oEmbed info failed, don't block the entire URL preview. + logger.warning( + "oEmbed fetch failed during URL preview: %s errored with %s", + oembed_url, + e, + ) + else: + ( + og_from_oembed, + author_name, + expiration_ms, + ) = await self._handle_oembed_response( + url, oembed_info, expiration_ms + ) + + # Parse Open Graph information from the HTML in case the oEmbed + # response failed or is incomplete. + og_from_html = parse_html_to_open_graph(tree) + + # Compile the Open Graph response by using the scraped + # information from the HTML and overlaying any information + # from the oEmbed response. + og = {**og_from_html, **og_from_oembed} + + await self._precache_image_url(user, media_info, og) + else: + og = {} + + elif oembed_url: + # Handle the oEmbed information. + og, author_name, expiration_ms = await self._handle_oembed_response( + url, media_info, expiration_ms + ) + await self._precache_image_url(user, media_info, og) + + else: + logger.warning("Failed to find any OG data in %s", url) + og = {} + + # If we don't have a title but we have author_name, copy it as + # title + if not og.get("og:title") and author_name: + og["og:title"] = author_name + + # filter out any stupidly long values + keys_to_remove = [] + for k, v in og.items(): + # values can be numeric as well as strings, hence the cast to str + if len(k) > OG_TAG_NAME_MAXLEN or len(str(v)) > OG_TAG_VALUE_MAXLEN: + logger.warning( + "Pruning overlong tag %s from OG data", k[:OG_TAG_NAME_MAXLEN] + ) + keys_to_remove.append(k) + for k in keys_to_remove: + del og[k] + + logger.debug("Calculated OG for %s as %s", url, og) + + jsonog = json_encoder.encode(og) + + # Cap the amount of time to consider a response valid. 
+ expiration_ms = min(expiration_ms, ONE_DAY) + + # store OG in history-aware DB cache + await self.store.store_url_cache( + url, + media_info.response_code, + media_info.etag, + media_info.created_ts_ms + expiration_ms, + jsonog, + media_info.filesystem_id, + media_info.created_ts_ms, + ) + + return jsonog.encode("utf8") + + async def _download_url(self, url: str, output_stream: BinaryIO) -> DownloadResult: + """ + Fetches a remote URL and parses the headers. + + Args: + url: The URL to fetch. + output_stream: The stream to write the content to. + + Returns: + A tuple of: + Media length, URL downloaded, the HTTP response code, + the media type, the downloaded file name, the number of + milliseconds the result is valid for, the etag header. + """ + + try: + logger.debug("Trying to get preview for url '%s'", url) + length, headers, uri, code = await self.client.get_file( + url, + output_stream=output_stream, + max_size=self.max_spider_size, + headers={ + b"Accept-Language": self.url_preview_accept_language, + # Use a custom user agent for the preview because some sites will only return + # Open Graph metadata to crawler user agents. Omit the Synapse version + # string to avoid leaking information. + b"User-Agent": [ + "Synapse (bot; +https://github.com/matrix-org/synapse)" + ], + }, + is_allowed_content_type=_is_previewable, + ) + except SynapseError: + # Pass SynapseErrors through directly, so that the servlet + # handler will return a SynapseError to the client instead of + # blank data or a 500. + raise + except DNSLookupError: + # DNS lookup returned no results + # Note: This will also be the case if one of the resolved IP + # addresses is blacklisted + raise SynapseError( + 502, + "DNS resolution failure during URL preview generation", + Codes.UNKNOWN, + ) + except Exception as e: + # FIXME: pass through 404s and other error messages nicely + logger.warning("Error downloading %s: %r", url, e) + + raise SynapseError( + 500, + "Failed to download content: %s" + % (traceback.format_exception_only(sys.exc_info()[0], e),), + Codes.UNKNOWN, + ) + + if b"Content-Type" in headers: + media_type = headers[b"Content-Type"][0].decode("ascii") + else: + media_type = "application/octet-stream" + + download_name = get_filename_from_headers(headers) + + # FIXME: we should calculate a proper expiration based on the + # Cache-Control and Expire headers. But for now, assume 1 hour. + expires = ONE_HOUR + etag = headers[b"ETag"][0].decode("ascii") if b"ETag" in headers else None + + return DownloadResult( + length, uri, code, media_type, download_name, expires, etag + ) + + async def _parse_data_url( + self, url: str, output_stream: BinaryIO + ) -> DownloadResult: + """ + Parses a data: URL. + + Args: + url: The URL to parse. + output_stream: The stream to write the content to. + + Returns: + A tuple of: + Media length, URL downloaded, the HTTP response code, + the media type, the downloaded file name, the number of + milliseconds the result is valid for, the etag header. + """ + + try: + logger.debug("Trying to parse data url '%s'", url) + with urlopen(url) as url_info: + # TODO Can this be more efficient. + output_stream.write(url_info.read()) + except Exception as e: + logger.warning("Error parsing data: URL %s: %r", url, e) + + raise SynapseError( + 500, + "Failed to parse data URL: %s" + % (traceback.format_exception_only(sys.exc_info()[0], e),), + Codes.UNKNOWN, + ) + + return DownloadResult( + # Read back the length that has been written. 
+ length=output_stream.tell(), + uri=url, + # If it was parsed, consider this a 200 OK. + response_code=200, + # urlopen shoves the media-type from the data URL into the content type + # header object. + media_type=url_info.headers.get_content_type(), + # Some features are not supported by data: URLs. + download_name=None, + expires=ONE_HOUR, + etag=None, + ) + + async def _handle_url( + self, url: str, user: UserID, allow_data_urls: bool = False + ) -> MediaInfo: + """ + Fetches content from a URL and parses the result to generate a MediaInfo. + + It uses the media storage provider to persist the fetched content and + stores the mapping into the database. + + Args: + url: The URL to fetch. + user: The user who ahs requested this URL. + allow_data_urls: True if data URLs should be allowed. + + Returns: + A MediaInfo object describing the fetched content. + """ + + # TODO: we should probably honour robots.txt... except in practice + # we're most likely being explicitly triggered by a human rather than a + # bot, so are we really a robot? + + file_id = datetime.date.today().isoformat() + "_" + random_string(16) + + file_info = FileInfo(server_name=None, file_id=file_id, url_cache=True) + + with self.media_storage.store_into_file(file_info) as (f, fname, finish): + if url.startswith("data:"): + if not allow_data_urls: + raise SynapseError( + 500, "Previewing of data: URLs is forbidden", Codes.UNKNOWN + ) + + download_result = await self._parse_data_url(url, f) + else: + download_result = await self._download_url(url, f) + + await finish() + + try: + time_now_ms = self.clock.time_msec() + + await self.store.store_local_media( + media_id=file_id, + media_type=download_result.media_type, + time_now_ms=time_now_ms, + upload_name=download_result.download_name, + media_length=download_result.length, + user_id=user, + url_cache=url, + ) + + except Exception as e: + logger.error("Error handling downloaded %s: %r", url, e) + # TODO: we really ought to delete the downloaded file in this + # case, since we won't have recorded it in the db, and will + # therefore not expire it. + raise + + return MediaInfo( + media_type=download_result.media_type, + media_length=download_result.length, + download_name=download_result.download_name, + created_ts_ms=time_now_ms, + filesystem_id=file_id, + filename=fname, + uri=download_result.uri, + response_code=download_result.response_code, + expires=download_result.expires, + etag=download_result.etag, + ) + + async def _precache_image_url( + self, user: UserID, media_info: MediaInfo, og: JsonDict + ) -> None: + """ + Pre-cache the image (if one exists) for posterity + + Args: + user: The user requesting the preview. + media_info: The media being previewed. + og: The Open Graph dictionary. This is modified with image information. + """ + # If there's no image or it is blank, there's nothing to do. + if "og:image" not in og: + return + + # Remove the raw image URL, this will be replaced with an MXC URL, if successful. + image_url = og.pop("og:image") + if not image_url: + return + + # The image URL from the HTML might be relative to the previewed page, + # convert it to an URL which can be requested directly. + url_parts = urlparse(image_url) + if url_parts.scheme != "data": + image_url = urljoin(media_info.uri, image_url) + + # FIXME: it might be cleaner to use the same flow as the main /preview_url + # request itself and benefit from the same caching etc. But for now we + # just rely on the caching on the master request to speed things up. 
+ try: + image_info = await self._handle_url(image_url, user, allow_data_urls=True) + except Exception as e: + # Pre-caching the image failed, don't block the entire URL preview. + logger.warning( + "Pre-caching image failed during URL preview: %s errored with %s", + image_url, + e, + ) + return + + if _is_media(image_info.media_type): + # TODO: make sure we don't choke on white-on-transparent images + file_id = image_info.filesystem_id + dims = await self.media_repo._generate_thumbnails( + None, file_id, file_id, image_info.media_type, url_cache=True + ) + if dims: + og["og:image:width"] = dims["width"] + og["og:image:height"] = dims["height"] + else: + logger.warning("Couldn't get dims for %s", image_url) + + og["og:image"] = f"mxc://{self.server_name}/{image_info.filesystem_id}" + og["og:image:type"] = image_info.media_type + og["matrix:image:size"] = image_info.media_length + + async def _handle_oembed_response( + self, url: str, media_info: MediaInfo, expiration_ms: int + ) -> Tuple[JsonDict, Optional[str], int]: + """ + Parse the downloaded oEmbed info. + + Args: + url: The URL which is being previewed (not the one which was + requested). + media_info: The media being previewed. + expiration_ms: The length of time, in milliseconds, the media is valid for. + + Returns: + A tuple of: + The Open Graph dictionary, if the oEmbed info can be parsed. + The author name if it could be retrieved from oEmbed. + The (possibly updated) length of time, in milliseconds, the media is valid for. + """ + # If JSON was not returned, there's nothing to do. + if not _is_json(media_info.media_type): + return {}, None, expiration_ms + + with open(media_info.filename, "rb") as file: + body = file.read() + + oembed_response = self._oembed.parse_oembed_response(url, body) + open_graph_result = oembed_response.open_graph_result + + # Use the cache age from the oEmbed result, if one was given. + if open_graph_result and oembed_response.cache_age is not None: + expiration_ms = oembed_response.cache_age + + return open_graph_result, oembed_response.author_name, expiration_ms + + def _start_expire_url_cache_data(self) -> Deferred: + return run_as_background_process( + "expire_url_cache_data", self._expire_url_cache_data + ) + + async def _expire_url_cache_data(self) -> None: + """Clean up expired url cache content, media and thumbnails.""" + + assert self._worker_run_media_background_jobs + + now = self.clock.time_msec() + + logger.debug("Running url preview cache expiry") + + def try_remove_parent_dirs(dirs: Iterable[str]) -> None: + """Attempt to remove the given chain of parent directories + + Args: + dirs: The list of directory paths to delete, with children appearing + before their parents. 
+ """ + for dir in dirs: + try: + os.rmdir(dir) + except FileNotFoundError: + # Already deleted, continue with deleting the rest + pass + except OSError as e: + # Failed, skip deleting the rest of the parent dirs + if e.errno != errno.ENOTEMPTY: + logger.warning( + "Failed to remove media directory while clearing url preview cache: %r: %s", + dir, + e, + ) + break + + # First we delete expired url cache entries + media_ids = await self.store.get_expired_url_cache(now) + + removed_media = [] + for media_id in media_ids: + fname = self.filepaths.url_cache_filepath(media_id) + try: + os.remove(fname) + except FileNotFoundError: + pass # If the path doesn't exist, meh + except OSError as e: + logger.warning( + "Failed to remove media while clearing url preview cache: %r: %s", + media_id, + e, + ) + continue + + removed_media.append(media_id) + + dirs = self.filepaths.url_cache_filepath_dirs_to_delete(media_id) + try_remove_parent_dirs(dirs) + + await self.store.delete_url_cache(removed_media) + + if removed_media: + logger.debug( + "Deleted %d entries from url preview cache", len(removed_media) + ) + else: + logger.debug("No entries removed from url preview cache") + + # Now we delete old images associated with the url cache. + # These may be cached for a bit on the client (i.e., they + # may have a room open with a preview url thing open). + # So we wait a couple of days before deleting, just in case. + expire_before = now - IMAGE_CACHE_EXPIRY_MS + media_ids = await self.store.get_url_cache_media_before(expire_before) + + removed_media = [] + for media_id in media_ids: + fname = self.filepaths.url_cache_filepath(media_id) + try: + os.remove(fname) + except FileNotFoundError: + pass # If the path doesn't exist, meh + except OSError as e: + logger.warning( + "Failed to remove media from url preview cache: %r: %s", media_id, e + ) + continue + + dirs = self.filepaths.url_cache_filepath_dirs_to_delete(media_id) + try_remove_parent_dirs(dirs) + + thumbnail_dir = self.filepaths.url_cache_thumbnail_directory(media_id) + try: + shutil.rmtree(thumbnail_dir) + except FileNotFoundError: + pass # If the path doesn't exist, meh + except OSError as e: + logger.warning( + "Failed to remove media from url preview cache: %r: %s", media_id, e + ) + continue + + removed_media.append(media_id) + + dirs = self.filepaths.url_cache_thumbnail_dirs_to_delete(media_id) + # Note that one of the directories to be deleted has already been + # removed by the `rmtree` above. + try_remove_parent_dirs(dirs) + + await self.store.delete_url_cache_media(removed_media) + + if removed_media: + logger.debug("Deleted %d media from url preview cache", len(removed_media)) + else: + logger.debug("No media removed from url preview cache") + + +def _is_media(content_type: str) -> bool: + return content_type.lower().startswith("image/") + + +def _is_html(content_type: str) -> bool: + content_type = content_type.lower() + return content_type.startswith("text/html") or content_type.startswith( + "application/xhtml" + ) + + +def _is_json(content_type: str) -> bool: + return content_type.lower().startswith("application/json") + + +def _is_previewable(content_type: str) -> bool: + """Returns True for content types for which we will perform URL preview and False + otherwise.""" + + return _is_html(content_type) or _is_media(content_type) or _is_json(content_type) diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 424239e3df..595c23e78d 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py
@@ -73,13 +73,6 @@ from synapse.events.third_party_rules import ( ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK, ) from synapse.handlers.account_data import ON_ACCOUNT_DATA_UPDATED_CALLBACK -from synapse.handlers.account_validity import ( - IS_USER_EXPIRED_CALLBACK, - ON_LEGACY_ADMIN_REQUEST, - ON_LEGACY_RENEW_CALLBACK, - ON_LEGACY_SEND_MAIL_CALLBACK, - ON_USER_REGISTRATION_CALLBACK, -) from synapse.handlers.auth import ( CHECK_3PID_AUTH_CALLBACK, CHECK_AUTH_CALLBACK, @@ -105,6 +98,13 @@ from synapse.logging.context import ( run_in_background, ) from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.module_api.callbacks.account_validity_callbacks import ( + IS_USER_EXPIRED_CALLBACK, + ON_LEGACY_ADMIN_REQUEST, + ON_LEGACY_RENEW_CALLBACK, + ON_LEGACY_SEND_MAIL_CALLBACK, + ON_USER_REGISTRATION_CALLBACK, +) from synapse.rest.client.login import LoginResponse from synapse.storage import DataStore from synapse.storage.background_updates import ( @@ -250,6 +250,7 @@ class ModuleApi: self._push_rules_handler = hs.get_push_rules_handler() self._device_handler = hs.get_device_handler() self.custom_template_dir = hs.config.server.custom_template_directory + self._callbacks = hs.get_module_api_callbacks() try: app_name = self._hs.config.email.email_app_name @@ -271,7 +272,6 @@ class ModuleApi: self._account_data_manager = AccountDataManager(hs) self._spam_checker = hs.get_spam_checker() - self._account_validity_handler = hs.get_account_validity_handler() self._third_party_event_rules = hs.get_third_party_event_rules() self._password_auth_provider = hs.get_password_auth_provider() self._presence_router = hs.get_presence_router() @@ -332,7 +332,7 @@ class ModuleApi: Added in Synapse v1.39.0. """ - return self._account_validity_handler.register_account_validity_callbacks( + return self._callbacks.account_validity.register_callbacks( is_user_expired=is_user_expired, on_user_registration=on_user_registration, on_legacy_send_mail=on_legacy_send_mail, diff --git a/synapse/module_api/callbacks/__init__.py b/synapse/module_api/callbacks/__init__.py new file mode 100644
index 0000000000..3d977bf655 --- /dev/null +++ b/synapse/module_api/callbacks/__init__.py
@@ -0,0 +1,22 @@ +# Copyright 2023 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.module_api.callbacks.account_validity_callbacks import ( + AccountValidityModuleApiCallbacks, +) + + +class ModuleApiCallbacks: + def __init__(self) -> None: + self.account_validity = AccountValidityModuleApiCallbacks() diff --git a/synapse/module_api/callbacks/account_validity_callbacks.py b/synapse/module_api/callbacks/account_validity_callbacks.py new file mode 100644
index 0000000000..531d0c9ddc --- /dev/null +++ b/synapse/module_api/callbacks/account_validity_callbacks.py
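The new file below is the container for account-validity callbacks that previously lived on the handler. For orientation, a minimal sketch of a module registering against it through the unchanged `ModuleApi.register_account_validity_callbacks` wrapper (the module class and its config handling are hypothetical):

    from typing import Any, Dict, Optional

    from synapse.module_api import ModuleApi


    class ExampleAccountValidityModule:
        def __init__(self, config: Dict[str, Any], api: ModuleApi):
            self._api = api
            # Modules keep calling the same ModuleApi method; after this change it
            # forwards to hs.get_module_api_callbacks().account_validity.register_callbacks().
            api.register_account_validity_callbacks(
                is_user_expired=self.is_user_expired,
            )

        async def is_user_expired(self, user_id: str) -> Optional[bool]:
            # True marks the account as expired, False as valid; None means
            # "no opinion", deferring to any other registered callback.
            return None

The `on_legacy_*` hooks in the new file exist only to bridge the old `/_matrix/client/.../account_validity` endpoints while clients migrate, as its comments explain.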
@@ -0,0 +1,93 @@ +# Copyright 2023 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import Awaitable, Callable, List, Optional, Tuple + +from twisted.web.http import Request + +logger = logging.getLogger(__name__) + +# Types for callbacks to be registered via the module api +IS_USER_EXPIRED_CALLBACK = Callable[[str], Awaitable[Optional[bool]]] +ON_USER_REGISTRATION_CALLBACK = Callable[[str], Awaitable] +# Temporary hooks to allow for a transition from `/_matrix/client` endpoints +# to `/_synapse/client/account_validity`. See `register_callbacks` below. +ON_LEGACY_SEND_MAIL_CALLBACK = Callable[[str], Awaitable] +ON_LEGACY_RENEW_CALLBACK = Callable[[str], Awaitable[Tuple[bool, bool, int]]] +ON_LEGACY_ADMIN_REQUEST = Callable[[Request], Awaitable] + + +class AccountValidityModuleApiCallbacks: + def __init__(self) -> None: + self.is_user_expired_callbacks: List[IS_USER_EXPIRED_CALLBACK] = [] + self.on_user_registration_callbacks: List[ON_USER_REGISTRATION_CALLBACK] = [] + self.on_legacy_send_mail_callback: Optional[ON_LEGACY_SEND_MAIL_CALLBACK] = None + self.on_legacy_renew_callback: Optional[ON_LEGACY_RENEW_CALLBACK] = None + + # The legacy admin requests callback isn't a protected attribute because we need + # to access it from the admin servlet, which is outside of this handler. + self.on_legacy_admin_request_callback: Optional[ON_LEGACY_ADMIN_REQUEST] = None + + def register_callbacks( + self, + is_user_expired: Optional[IS_USER_EXPIRED_CALLBACK] = None, + on_user_registration: Optional[ON_USER_REGISTRATION_CALLBACK] = None, + on_legacy_send_mail: Optional[ON_LEGACY_SEND_MAIL_CALLBACK] = None, + on_legacy_renew: Optional[ON_LEGACY_RENEW_CALLBACK] = None, + on_legacy_admin_request: Optional[ON_LEGACY_ADMIN_REQUEST] = None, + ) -> None: + """Register callbacks from module for each hook.""" + if is_user_expired is not None: + self.is_user_expired_callbacks.append(is_user_expired) + + if on_user_registration is not None: + self.on_user_registration_callbacks.append(on_user_registration) + + # The builtin account validity feature exposes 3 endpoints (send_mail, renew, and + # an admin one). As part of moving the feature into a module, we need to change + # the path from /_matrix/client/unstable/account_validity/... to + # /_synapse/client/account_validity, because: + # + # * the feature isn't part of the Matrix spec thus shouldn't live under /_matrix + # * the way we register servlets means that modules can't register resources + # under /_matrix/client + # + # We need to allow for a transition period between the old and new endpoints + # in order to allow for clients to update (and for emails to be processed). + # + # Once the email-account-validity module is loaded, it will take control of account + # validity by moving the rows from our `account_validity` table into its own table. 
+ # + # Therefore, we need to allow modules (in practice just the one implementing the + # email-based account validity) to temporarily hook into the legacy endpoints so we + # can route the traffic coming into the old endpoints into the module, which is + # why we have the following three temporary hooks. + if on_legacy_send_mail is not None: + if self.on_legacy_send_mail_callback is not None: + raise RuntimeError("Tried to register on_legacy_send_mail twice") + + self.on_legacy_send_mail_callback = on_legacy_send_mail + + if on_legacy_renew is not None: + if self.on_legacy_renew_callback is not None: + raise RuntimeError("Tried to register on_legacy_renew twice") + + self.on_legacy_renew_callback = on_legacy_renew + + if on_legacy_admin_request is not None: + if self.on_legacy_admin_request_callback is not None: + raise RuntimeError("Tried to register on_legacy_admin_request twice") + + self.on_legacy_admin_request_callback = on_legacy_admin_request diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index ba12b6d79a..199337673f 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py
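The hunks below drop the `msc3873_escape_event_match_key` flag and make the dotted-key escaping in `_flatten_dict` unconditional. A standalone sketch of the escaping rule (a simplified, hypothetical re-implementation, not the Synapse function itself):

    from typing import Any, Dict


    def flatten(d: Dict[str, Any], prefix: str = "") -> Dict[str, Any]:
        """Flatten a nested dict, escaping backslashes and periods in each key
        so that '.' can be used unambiguously as the path separator."""
        result: Dict[str, Any] = {}
        for key, value in d.items():
            # Same escaping as the diff below: backslashes first, then periods.
            escaped = key.replace("\\", "\\\\").replace(".", "\\.")
            path = f"{prefix}.{escaped}" if prefix else escaped
            if isinstance(value, dict):
                result.update(flatten(value, path))
            else:
                result[path] = value
        return result


    # Prints {'m\\.relates_to.rel_type': 'm.thread'}: the dot inside the event
    # field name is escaped, the separator dot is not.
    print(flatten({"m.relates_to": {"rel_type": "m.thread"}}))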
@@ -273,10 +273,7 @@ class BulkPushRuleEvaluator: related_event_id, allow_none=True ) if related_event is not None: - related_events[relation_type] = _flatten_dict( - related_event, - msc3873_escape_event_match_key=self.hs.config.experimental.msc3873_escape_event_match_key, - ) + related_events[relation_type] = _flatten_dict(related_event) reply_event_id = ( event.content.get("m.relates_to", {}) @@ -291,10 +288,7 @@ class BulkPushRuleEvaluator: ) if related_event is not None: - related_events["m.in_reply_to"] = _flatten_dict( - related_event, - msc3873_escape_event_match_key=self.hs.config.experimental.msc3873_escape_event_match_key, - ) + related_events["m.in_reply_to"] = _flatten_dict(related_event) # indicate that this is from a fallback relation. if relation_type == "m.thread" and event.content.get( @@ -401,10 +395,7 @@ class BulkPushRuleEvaluator: ) evaluator = PushRuleEvaluator( - _flatten_dict( - event, - msc3873_escape_event_match_key=self.hs.config.experimental.msc3873_escape_event_match_key, - ), + _flatten_dict(event), has_mentions, room_member_count, sender_power_level, @@ -413,7 +404,6 @@ class BulkPushRuleEvaluator: self._related_event_match_enabled, event.room_version.msc3931_push_features, self.hs.config.experimental.msc1767_enabled, # MSC3931 flag - self.hs.config.experimental.msc3966_exact_event_property_contains, ) users = rules_by_user.keys() @@ -495,8 +485,6 @@ def _flatten_dict( d: Union[EventBase, Mapping[str, Any]], prefix: Optional[List[str]] = None, result: Optional[Dict[str, JsonValue]] = None, - *, - msc3873_escape_event_match_key: bool = False, ) -> Dict[str, JsonValue]: """ Given a JSON dictionary (or event) which might contain sub dictionaries, @@ -525,11 +513,10 @@ def _flatten_dict( if result is None: result = {} for key, value in d.items(): - if msc3873_escape_event_match_key: - # Escape periods in the key with a backslash (and backslashes with an - # extra backslash). This is since a period is used as a separator between - # nested fields. - key = key.replace("\\", "\\\\").replace(".", "\\.") + # Escape periods in the key with a backslash (and backslashes with an + # extra backslash). This is since a period is used as a separator between + # nested fields. + key = key.replace("\\", "\\\\").replace(".", "\\.") if _is_simple_value(value): result[".".join(prefix + [key])] = value @@ -537,12 +524,7 @@ def _flatten_dict( result[".".join(prefix + [key])] = [v for v in value if _is_simple_value(v)] elif isinstance(value, Mapping): # do not set `room_version` due to recursion considerations below - _flatten_dict( - value, - prefix=(prefix + [key]), - result=result, - msc3873_escape_event_match_key=msc3873_escape_event_match_key, - ) + _flatten_dict(value, prefix=(prefix + [key]), result=result) # `room_version` should only ever be set when looking at the top level of an event if ( diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 424854efbe..200f667fdf 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py
@@ -18,16 +18,12 @@ from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple from twisted.internet import defer from twisted.internet.defer import Deferred -from twisted.internet.interfaces import IAddress, IConnector -from twisted.internet.protocol import ReconnectingClientFactory -from twisted.python.failure import Failure from synapse.api.constants import EventTypes, Membership, ReceiptTypes from synapse.federation import send_queue from synapse.federation.sender import FederationSender from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol from synapse.replication.tcp.streams import ( AccountDataStream, DeviceListsStream, @@ -53,7 +49,6 @@ from synapse.util.async_helpers import Linearizer, timeout_deferred from synapse.util.metrics import Measure if TYPE_CHECKING: - from synapse.replication.tcp.handler import ReplicationCommandHandler from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -62,52 +57,6 @@ logger = logging.getLogger(__name__) _WAIT_FOR_REPLICATION_TIMEOUT_SECONDS = 5 -class DirectTcpReplicationClientFactory(ReconnectingClientFactory): - """Factory for building connections to the master. Will reconnect if the - connection is lost. - - Accepts a handler that is passed to `ClientReplicationStreamProtocol`. - """ - - initialDelay = 0.1 - maxDelay = 1 # Try at least once every N seconds - - def __init__( - self, - hs: "HomeServer", - client_name: str, - command_handler: "ReplicationCommandHandler", - ): - self.client_name = client_name - self.command_handler = command_handler - self.server_name = hs.config.server.server_name - self.hs = hs - self._clock = hs.get_clock() # As self.clock is defined in super class - - hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.stopTrying) - - def startedConnecting(self, connector: IConnector) -> None: - logger.info("Connecting to replication: %r", connector.getDestination()) - - def buildProtocol(self, addr: IAddress) -> ClientReplicationStreamProtocol: - logger.info("Connected to replication: %r", addr) - return ClientReplicationStreamProtocol( - self.hs, - self.client_name, - self.server_name, - self._clock, - self.command_handler, - ) - - def clientConnectionLost(self, connector: IConnector, reason: Failure) -> None: - logger.error("Lost replication conn: %r", reason) - ReconnectingClientFactory.clientConnectionLost(self, connector, reason) - - def clientConnectionFailed(self, connector: IConnector, reason: Failure) -> None: - logger.error("Failed to connect to replication: %r", reason) - ReconnectingClientFactory.clientConnectionFailed(self, connector, reason) - - class ReplicationDataHandler: """Handles incoming stream updates from replication. diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index d03a53d764..2290b3e6fe 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py
@@ -625,23 +625,6 @@ class ReplicationCommandHandler: self._notifier.notify_remote_server_up(cmd.data) - # We relay to all other connections to ensure every instance gets the - # notification. - # - # When configured to use redis we'll always only have one connection and - # so this is a no-op (all instances will have already received the same - # REMOTE_SERVER_UP command). - # - # For direct TCP connections this will relay to all other connections - # connected to us. When on master this will correctly fan out to all - # other direct TCP clients and on workers there'll only be the one - # connection to master. - # - # (The logic here should also be sound if we have a mix of Redis and - # direct TCP connections so long as there is only one traffic route - # between two instances, but that is not currently supported). - self.send_command(cmd, ignore_conn=conn) - def new_connection(self, connection: IReplicationConnection) -> None: """Called when we have a new connection.""" self._connections.append(connection) @@ -689,21 +672,14 @@ class ReplicationCommandHandler: """ return bool(self._connections) - def send_command( - self, cmd: Command, ignore_conn: Optional[IReplicationConnection] = None - ) -> None: + def send_command(self, cmd: Command) -> None: """Send a command to all connected connections. Args: cmd - ignore_conn: If set don't send command to the given connection. - Used when relaying commands from one connection to all others. """ if self._connections: for connection in self._connections: - if connection == ignore_conn: - continue - try: connection.send_command(cmd) except Exception: diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py
index 2e19e055d3..55b448adfd 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py
@@ -20,6 +20,7 @@ from synapse.rest.client import ( account, account_data, account_validity, + appservice_ping, auth, capabilities, devices, @@ -140,6 +141,7 @@ class ClientRestResource(JsonResource): if is_main_process: password_policy.register_servlets(hs, client_resource) knock.register_servlets(hs, client_resource) + appservice_ping.register_servlets(hs, client_resource) # moving to /_synapse/admin if is_main_process: diff --git a/synapse/rest/admin/server_notice_servlet.py b/synapse/rest/admin/server_notice_servlet.py
index 15da9cd881..7dd1c10b91 100644 --- a/synapse/rest/admin/server_notice_servlet.py +++ b/synapse/rest/admin/server_notice_servlet.py
@@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from http import HTTPStatus -from typing import TYPE_CHECKING, Awaitable, Optional, Tuple +from typing import TYPE_CHECKING, Optional, Tuple from synapse.api.constants import EventTypes from synapse.api.errors import NotFoundError, SynapseError @@ -23,10 +23,10 @@ from synapse.http.servlet import ( parse_json_object_from_request, ) from synapse.http.site import SynapseRequest -from synapse.rest.admin import assert_requester_is_admin -from synapse.rest.admin._base import admin_patterns +from synapse.logging.opentracing import set_tag +from synapse.rest.admin._base import admin_patterns, assert_user_is_admin from synapse.rest.client.transactions import HttpTransactionCache -from synapse.types import JsonDict, UserID +from synapse.types import JsonDict, Requester, UserID if TYPE_CHECKING: from synapse.server import HomeServer @@ -70,10 +70,13 @@ class SendServerNoticeServlet(RestServlet): self.__class__.__name__, ) - async def on_POST( - self, request: SynapseRequest, txn_id: Optional[str] = None + async def _do( + self, + request: SynapseRequest, + requester: Requester, + txn_id: Optional[str], ) -> Tuple[int, JsonDict]: - await assert_requester_is_admin(self.auth, request) + await assert_user_is_admin(self.auth, requester) body = parse_json_object_from_request(request) assert_params_in_dict(body, ("user_id", "content")) event_type = body.get("type", EventTypes.Message) @@ -106,9 +109,18 @@ class SendServerNoticeServlet(RestServlet): return HTTPStatus.OK, {"event_id": event.event_id} - def on_PUT( + async def on_POST( + self, + request: SynapseRequest, + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + return await self._do(request, requester, None) + + async def on_PUT( self, request: SynapseRequest, txn_id: str - ) -> Awaitable[Tuple[int, JsonDict]]: - return self.txns.fetch_or_execute_request( - request, self.on_POST, request, txn_id + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + set_tag("txn_id", txn_id) + return await self.txns.fetch_or_execute_request( + request, requester, self._do, request, requester, txn_id ) diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index 357e9a574d..281e8fd0ad 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py
@@ -683,19 +683,18 @@ class AccountValidityRenewServlet(RestServlet): PATTERNS = admin_patterns("/account_validity/validity$") def __init__(self, hs: "HomeServer"): - self.account_activity_handler = hs.get_account_validity_handler() + self.account_validity_handler = hs.get_account_validity_handler() + self.account_validity_module_callbacks = ( + hs.get_module_api_callbacks().account_validity + ) self.auth = hs.get_auth() async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) - if self.account_activity_handler.on_legacy_admin_request_callback: - expiration_ts = ( - await ( - self.account_activity_handler.on_legacy_admin_request_callback( - request - ) - ) + if self.account_validity_module_callbacks.on_legacy_admin_request_callback: + expiration_ts = await self.account_validity_module_callbacks.on_legacy_admin_request_callback( + request ) else: body = parse_json_object_from_request(request) @@ -706,7 +705,7 @@ class AccountValidityRenewServlet(RestServlet): "Missing property 'user_id' in the request body", ) - expiration_ts = await self.account_activity_handler.renew_account_for_user( + expiration_ts = await self.account_validity_handler.renew_account_for_user( body["user_id"], body.get("expiration_ts"), not body.get("enable_renewal_emails", True), diff --git a/synapse/rest/client/appservice_ping.py b/synapse/rest/client/appservice_ping.py new file mode 100644
index 0000000000..31466a4ad4 --- /dev/null +++ b/synapse/rest/client/appservice_ping.py
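The new servlet below implements the MSC2659 appservice ping, gated on the `msc2659_enabled` experimental flag. A hedged sketch of how an appservice might call it (the homeserver URL, appservice ID and token are placeholders, and the unstable path is inferred from the `client_patterns` call below):

    import requests  # any HTTP client will do; requests is used for brevity

    HOMESERVER = "https://synapse.example.org"   # placeholder
    APPSERVICE_ID = "my_bridge"                  # must match the registration's `id`
    AS_TOKEN = "..."                             # the appservice's as_token

    resp = requests.post(
        f"{HOMESERVER}/_matrix/client/unstable/fi.mau.msc2659"
        f"/appservice/{APPSERVICE_ID}/ping",
        headers={"Authorization": f"Bearer {AS_TOKEN}"},
        json={"transaction_id": "ping-1"},  # optional; passed on to the AS ping call
        timeout=60,
    )
    if resp.ok:
        # On success the servlet returns the HS->AS round-trip time in milliseconds.
        print("round trip:", resp.json()["duration"], "ms")
    else:
        # Failures map to the FI.MAU.MSC2659_* error codes added in synapse/api/errors.py.
        print("ping failed:", resp.json().get("errcode"), resp.json().get("error"))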
@@ -0,0 +1,115 @@ +# Copyright 2023 Tulir Asokan +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import time +from http import HTTPStatus +from typing import TYPE_CHECKING, Any, Dict, Tuple + +from synapse.api.errors import ( + CodeMessageException, + Codes, + HttpResponseException, + SynapseError, +) +from synapse.http import RequestTimedOutError +from synapse.http.server import HttpServer +from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.http.site import SynapseRequest +from synapse.types import JsonDict + +from ._base import client_patterns + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class AppservicePingRestServlet(RestServlet): + PATTERNS = client_patterns( + "/fi.mau.msc2659/appservice/(?P<appservice_id>[^/]*)/ping", + unstable=True, + releases=(), + ) + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.as_api = hs.get_application_service_api() + self.auth = hs.get_auth() + + async def on_POST( + self, request: SynapseRequest, appservice_id: str + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + + if not requester.app_service: + raise SynapseError( + HTTPStatus.FORBIDDEN, + "Only application services can use the /appservice/ping endpoint", + Codes.FORBIDDEN, + ) + elif requester.app_service.id != appservice_id: + raise SynapseError( + HTTPStatus.FORBIDDEN, + "Mismatching application service ID in path", + Codes.FORBIDDEN, + ) + elif not requester.app_service.url: + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "The application service does not have a URL set", + Codes.AS_PING_URL_NOT_SET, + ) + + content = parse_json_object_from_request(request) + txn_id = content.get("transaction_id", None) + + start = time.monotonic() + try: + await self.as_api.ping(requester.app_service, txn_id) + except RequestTimedOutError as e: + raise SynapseError( + HTTPStatus.GATEWAY_TIMEOUT, + e.msg, + Codes.AS_PING_CONNECTION_TIMEOUT, + ) + except CodeMessageException as e: + additional_fields: Dict[str, Any] = {"status": e.code} + if isinstance(e, HttpResponseException): + try: + additional_fields["body"] = e.response.decode("utf-8") + except UnicodeDecodeError: + pass + raise SynapseError( + HTTPStatus.BAD_GATEWAY, + f"HTTP {e.code} {e.msg}", + Codes.AS_PING_BAD_STATUS, + additional_fields=additional_fields, + ) + except Exception as e: + raise SynapseError( + HTTPStatus.BAD_GATEWAY, + f"{type(e).__name__}: {e}", + Codes.AS_PING_CONNECTION_FAILED, + ) + + duration = time.monotonic() - start + + return HTTPStatus.OK, {"duration": int(duration * 1000)} + + +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: + if hs.config.experimental.msc2659_enabled: + AppservicePingRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py
index bce806f2bb..4adb5271d2 100644 --- a/synapse/rest/client/register.py +++ b/synapse/rest/client/register.py
@@ -956,7 +956,7 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: if hs.config.worker.worker_app is None: EmailRegisterRequestTokenRestServlet(hs).register(http_server) MsisdnRegisterRequestTokenRestServlet(hs).register(http_server) - UsernameAvailabilityRestServlet(hs).register(http_server) RegistrationSubmitTokenServlet(hs).register(http_server) + UsernameAvailabilityRestServlet(hs).register(http_server) RegistrationTokenValidityRestServlet(hs).register(http_server) RegisterRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py
index 61e4cf0213..129b6fe6b0 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py
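The room servlet hunks below split the `on_POST`/`on_PUT` paths so that the requester is resolved before the transaction cache is consulted. The client-visible idempotency is unchanged; a sketch of that behaviour from the client side (host, token and room ID are placeholders):

    import requests

    HOMESERVER = "https://synapse.example.org"  # placeholder
    TOKEN = "..."                               # a user's access token
    ROOM_ID = "!room:example.org"               # placeholder
    TXN_ID = "m1700000000.0"                    # client-chosen transaction id

    def send_hello() -> dict:
        return requests.put(
            f"{HOMESERVER}/_matrix/client/v3/rooms/{ROOM_ID}"
            f"/send/m.room.message/{TXN_ID}",
            headers={"Authorization": f"Bearer {TOKEN}"},
            json={"msgtype": "m.text", "body": "hello"},
            timeout=30,
        ).json()

    first = send_hello()
    retry = send_hello()
    # The retried PUT is served from the transaction cache (now keyed on the
    # requester rather than the raw token), so both calls return the same event.
    assert first["event_id"] == retry["event_id"]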
@@ -57,7 +57,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.rest.client._base import client_patterns from synapse.rest.client.transactions import HttpTransactionCache from synapse.streams.config import PaginationConfig -from synapse.types import JsonDict, StreamToken, ThirdPartyInstanceID, UserID +from synapse.types import JsonDict, Requester, StreamToken, ThirdPartyInstanceID, UserID from synapse.types.state import StateFilter from synapse.util import json_decoder from synapse.util.cancellation import cancellable @@ -151,15 +151,22 @@ class RoomCreateRestServlet(TransactionRestServlet): PATTERNS = "/createRoom" register_txn_path(self, PATTERNS, http_server) - def on_PUT( + async def on_PUT( self, request: SynapseRequest, txn_id: str - ) -> Awaitable[Tuple[int, JsonDict]]: + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) set_tag("txn_id", txn_id) - return self.txns.fetch_or_execute_request(request, self.on_POST, request) + return await self.txns.fetch_or_execute_request( + request, requester, self._do, request, requester + ) async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) + return await self._do(request, requester) + async def _do( + self, request: SynapseRequest, requester: Requester + ) -> Tuple[int, JsonDict]: room_id, _, _ = await self._room_creation_handler.create_room( requester, self.get_room_config(request) ) @@ -172,9 +179,9 @@ class RoomCreateRestServlet(TransactionRestServlet): # TODO: Needs unit testing for generic events -class RoomStateEventRestServlet(TransactionRestServlet): +class RoomStateEventRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): - super().__init__(hs) + super().__init__() self.event_creation_handler = hs.get_event_creation_handler() self.room_member_handler = hs.get_room_member_handler() self.message_handler = hs.get_message_handler() @@ -324,16 +331,16 @@ class RoomSendEventRestServlet(TransactionRestServlet): def register(self, http_server: HttpServer) -> None: # /rooms/$roomid/send/$event_type[/$txn_id] PATTERNS = "/rooms/(?P<room_id>[^/]*)/send/(?P<event_type>[^/]*)" - register_txn_path(self, PATTERNS, http_server, with_get=True) + register_txn_path(self, PATTERNS, http_server) - async def on_POST( + async def _do( self, request: SynapseRequest, + requester: Requester, room_id: str, event_type: str, - txn_id: Optional[str] = None, + txn_id: Optional[str], ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request, allow_guest=True) content = parse_json_object_from_request(request) event_dict: JsonDict = { @@ -362,18 +369,30 @@ class RoomSendEventRestServlet(TransactionRestServlet): set_tag("event_id", event_id) return 200, {"event_id": event_id} - def on_GET( - self, request: SynapseRequest, room_id: str, event_type: str, txn_id: str - ) -> Tuple[int, str]: - return 200, "Not implemented" + async def on_POST( + self, + request: SynapseRequest, + room_id: str, + event_type: str, + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request, allow_guest=True) + return await self._do(request, requester, room_id, event_type, None) - def on_PUT( + async def on_PUT( self, request: SynapseRequest, room_id: str, event_type: str, txn_id: str - ) -> Awaitable[Tuple[int, JsonDict]]: + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request, allow_guest=True) set_tag("txn_id", txn_id) - return 
self.txns.fetch_or_execute_request( - request, self.on_POST, request, room_id, event_type, txn_id + return await self.txns.fetch_or_execute_request( + request, + requester, + self._do, + request, + requester, + room_id, + event_type, + txn_id, ) @@ -389,14 +408,13 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet): PATTERNS = "/join/(?P<room_identifier>[^/]*)" register_txn_path(self, PATTERNS, http_server) - async def on_POST( + async def _do( self, request: SynapseRequest, + requester: Requester, room_identifier: str, - txn_id: Optional[str] = None, + txn_id: Optional[str], ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request, allow_guest=True) - content = parse_json_object_from_request(request, allow_empty_body=True) # twisted.web.server.Request.args is incorrectly defined as Optional[Any] @@ -420,22 +438,31 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet): return 200, {"room_id": room_id} - def on_PUT( + async def on_POST( + self, + request: SynapseRequest, + room_identifier: str, + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request, allow_guest=True) + return await self._do(request, requester, room_identifier, None) + + async def on_PUT( self, request: SynapseRequest, room_identifier: str, txn_id: str - ) -> Awaitable[Tuple[int, JsonDict]]: + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request, allow_guest=True) set_tag("txn_id", txn_id) - return self.txns.fetch_or_execute_request( - request, self.on_POST, request, room_identifier, txn_id + return await self.txns.fetch_or_execute_request( + request, requester, self._do, request, requester, room_identifier, txn_id ) # TODO: Needs unit testing -class PublicRoomListRestServlet(TransactionRestServlet): +class PublicRoomListRestServlet(RestServlet): PATTERNS = client_patterns("/publicRooms$", v1=True) def __init__(self, hs: "HomeServer"): - super().__init__(hs) + super().__init__() self.hs = hs self.auth = hs.get_auth() @@ -907,22 +934,25 @@ class RoomForgetRestServlet(TransactionRestServlet): PATTERNS = "/rooms/(?P<room_id>[^/]*)/forget" register_txn_path(self, PATTERNS, http_server) - async def on_POST( - self, request: SynapseRequest, room_id: str, txn_id: Optional[str] = None - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request, allow_guest=False) - + async def _do(self, requester: Requester, room_id: str) -> Tuple[int, JsonDict]: await self.room_member_handler.forget(user=requester.user, room_id=room_id) return 200, {} - def on_PUT( + async def on_POST( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request, allow_guest=False) + return await self._do(requester, room_id) + + async def on_PUT( self, request: SynapseRequest, room_id: str, txn_id: str - ) -> Awaitable[Tuple[int, JsonDict]]: + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request, allow_guest=False) set_tag("txn_id", txn_id) - return self.txns.fetch_or_execute_request( - request, self.on_POST, request, room_id, txn_id + return await self.txns.fetch_or_execute_request( + request, requester, self._do, requester, room_id ) @@ -941,15 +971,14 @@ class RoomMembershipRestServlet(TransactionRestServlet): ) register_txn_path(self, PATTERNS, http_server) - async def on_POST( + async def _do( self, request: SynapseRequest, + requester: Requester, room_id: str, membership_action: str, - txn_id: Optional[str] = None, + 
txn_id: Optional[str], ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request, allow_guest=True) - if requester.is_guest and membership_action not in { Membership.JOIN, Membership.LEAVE, @@ -1014,13 +1043,30 @@ class RoomMembershipRestServlet(TransactionRestServlet): return 200, return_value - def on_PUT( + async def on_POST( + self, + request: SynapseRequest, + room_id: str, + membership_action: str, + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request, allow_guest=True) + return await self._do(request, requester, room_id, membership_action, None) + + async def on_PUT( self, request: SynapseRequest, room_id: str, membership_action: str, txn_id: str - ) -> Awaitable[Tuple[int, JsonDict]]: + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request, allow_guest=True) set_tag("txn_id", txn_id) - return self.txns.fetch_or_execute_request( - request, self.on_POST, request, room_id, membership_action, txn_id + return await self.txns.fetch_or_execute_request( + request, + requester, + self._do, + request, + requester, + room_id, + membership_action, + txn_id, ) @@ -1036,14 +1082,14 @@ class RoomRedactEventRestServlet(TransactionRestServlet): PATTERNS = "/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)" register_txn_path(self, PATTERNS, http_server) - async def on_POST( + async def _do( self, request: SynapseRequest, + requester: Requester, room_id: str, event_id: str, - txn_id: Optional[str] = None, + txn_id: Optional[str], ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) content = parse_json_object_from_request(request) try: @@ -1094,13 +1140,23 @@ class RoomRedactEventRestServlet(TransactionRestServlet): set_tag("event_id", event_id) return 200, {"event_id": event_id} - def on_PUT( + async def on_POST( + self, + request: SynapseRequest, + room_id: str, + event_id: str, + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + return await self._do(request, requester, room_id, event_id, None) + + async def on_PUT( self, request: SynapseRequest, room_id: str, event_id: str, txn_id: str - ) -> Awaitable[Tuple[int, JsonDict]]: + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) set_tag("txn_id", txn_id) - return self.txns.fetch_or_execute_request( - request, self.on_POST, request, room_id, event_id, txn_id + return await self.txns.fetch_or_execute_request( + request, requester, self._do, request, requester, room_id, event_id, txn_id ) @@ -1224,7 +1280,6 @@ def register_txn_path( servlet: RestServlet, regex_string: str, http_server: HttpServer, - with_get: bool = False, ) -> None: """Registers a transaction-based path. @@ -1236,7 +1291,6 @@ def register_txn_path( regex_string: The regex string to register. Must NOT have a trailing $ as this string will be appended to. http_server: The http_server to register paths with. - with_get: True to also register respective GET paths for the PUTs. 
""" on_POST = getattr(servlet, "on_POST", None) on_PUT = getattr(servlet, "on_PUT", None) @@ -1254,18 +1308,6 @@ def register_txn_path( on_PUT, servlet.__class__.__name__, ) - on_GET = getattr(servlet, "on_GET", None) - if with_get: - if on_GET is None: - raise RuntimeError( - "register_txn_path called with with_get = True, but no on_GET method exists" - ) - http_server.register_paths( - "GET", - client_patterns(regex_string + "/(?P<txn_id>[^/]*)$", v1=True), - on_GET, - servlet.__class__.__name__, - ) class TimestampLookupRestServlet(RestServlet): diff --git a/synapse/rest/client/sendtodevice.py b/synapse/rest/client/sendtodevice.py
index 55d52f0b28..110af6df47 100644 --- a/synapse/rest/client/sendtodevice.py +++ b/synapse/rest/client/sendtodevice.py
@@ -13,7 +13,7 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Awaitable, Tuple +from typing import TYPE_CHECKING, Tuple from synapse.http import servlet from synapse.http.server import HttpServer @@ -21,7 +21,7 @@ from synapse.http.servlet import assert_params_in_dict, parse_json_object_from_r from synapse.http.site import SynapseRequest from synapse.logging.opentracing import set_tag from synapse.rest.client.transactions import HttpTransactionCache -from synapse.types import JsonDict +from synapse.types import JsonDict, Requester from ._base import client_patterns @@ -43,19 +43,26 @@ class SendToDeviceRestServlet(servlet.RestServlet): self.txns = HttpTransactionCache(hs) self.device_message_handler = hs.get_device_message_handler() - def on_PUT( + async def on_PUT( self, request: SynapseRequest, message_type: str, txn_id: str - ) -> Awaitable[Tuple[int, JsonDict]]: + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request, allow_guest=True) set_tag("txn_id", txn_id) - return self.txns.fetch_or_execute_request( - request, self._put, request, message_type, txn_id + return await self.txns.fetch_or_execute_request( + request, + requester, + self._put, + request, + requester, + message_type, ) async def _put( - self, request: SynapseRequest, message_type: str, txn_id: str + self, + request: SynapseRequest, + requester: Requester, + message_type: str, ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request, allow_guest=True) - content = parse_json_object_from_request(request) assert_params_in_dict(content, ("messages",)) diff --git a/synapse/rest/client/transactions.py b/synapse/rest/client/transactions.py
index 3f40f1874a..f2aaab6227 100644 --- a/synapse/rest/client/transactions.py +++ b/synapse/rest/client/transactions.py
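The rewrite below keys the HTTP transaction cache on the `Requester` instead of the raw access token. A simplified sketch of the resulting key shapes (a hypothetical free function mirroring `_get_transaction_key`):

    from typing import Hashable, Optional


    def transaction_key(
        path: str,
        *,
        access_token_id: Optional[int] = None,
        is_guest: bool = False,
        user_id: Optional[str] = None,
        app_service_id: Optional[str] = None,
    ) -> Hashable:
        # Guests key on their user ID, appservice users on the appservice ID,
        # and everyone else on the access token ID; the request path (which
        # includes the txn_id) is always part of the key.
        if is_guest:
            assert user_id is not None
            return (path, "guest", user_id)
        if app_service_id is not None:
            return (path, "appservice", app_service_id)
        assert access_token_id is not None
        return (path, "user", access_token_id)


    # ('/_matrix/client/v3/rooms/!r:x/send/m.room.message/txn1', 'user', 42)
    print(transaction_key(
        "/_matrix/client/v3/rooms/!r:x/send/m.room.message/txn1",
        access_token_id=42,
    ))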
@@ -15,16 +15,16 @@ """This module contains logic for storing HTTP PUT transactions. This is used to ensure idempotency when performing PUTs using the REST API.""" import logging -from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Tuple +from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Hashable, Tuple from typing_extensions import ParamSpec from twisted.internet.defer import Deferred from twisted.python.failure import Failure -from twisted.web.server import Request +from twisted.web.iweb import IRequest from synapse.logging.context import make_deferred_yieldable, run_in_background -from synapse.types import JsonDict +from synapse.types import JsonDict, Requester from synapse.util.async_helpers import ObservableDeferred if TYPE_CHECKING: @@ -41,53 +41,47 @@ P = ParamSpec("P") class HttpTransactionCache: def __init__(self, hs: "HomeServer"): self.hs = hs - self.auth = self.hs.get_auth() self.clock = self.hs.get_clock() # $txn_key: (ObservableDeferred<(res_code, res_json_body)>, timestamp) self.transactions: Dict[ - str, Tuple[ObservableDeferred[Tuple[int, JsonDict]], int] + Hashable, Tuple[ObservableDeferred[Tuple[int, JsonDict]], int] ] = {} # Try to clean entries every 30 mins. This means entries will exist # for at *LEAST* 30 mins, and at *MOST* 60 mins. self.cleaner = self.clock.looping_call(self._cleanup, CLEANUP_PERIOD_MS) - def _get_transaction_key(self, request: Request) -> str: + def _get_transaction_key(self, request: IRequest, requester: Requester) -> Hashable: """A helper function which returns a transaction key that can be used with TransactionCache for idempotent requests. Idempotency is based on the returned key being the same for separate requests to the same endpoint. The key is formed from the HTTP request - path and the access_token for the requesting user. + path and attributes from the requester: the access_token_id for regular users, + the user ID for guest users, and the appservice ID for appservice users. Args: - request: The incoming request. Must contain an access_token. + request: The incoming request. + requester: The requester doing the request. Returns: A transaction key """ assert request.path is not None - token = self.auth.get_access_token_from_request(request) - return request.path.decode("utf8") + "/" + token + path: str = request.path.decode("utf8") + if requester.is_guest: + assert requester.user is not None, "Guest requester must have a user ID set" + return (path, "guest", requester.user) + elif requester.app_service is not None: + return (path, "appservice", requester.app_service.id) + else: + assert ( + requester.access_token_id is not None + ), "Requester must have an access_token_id" + return (path, "user", requester.access_token_id) def fetch_or_execute_request( self, - request: Request, - fn: Callable[P, Awaitable[Tuple[int, JsonDict]]], - *args: P.args, - **kwargs: P.kwargs, - ) -> Awaitable[Tuple[int, JsonDict]]: - """A helper function for fetch_or_execute which extracts - a transaction key from the given request. - - See: - fetch_or_execute - """ - return self.fetch_or_execute( - self._get_transaction_key(request), fn, *args, **kwargs - ) - - def fetch_or_execute( - self, - txn_key: str, + request: IRequest, + requester: Requester, fn: Callable[P, Awaitable[Tuple[int, JsonDict]]], *args: P.args, **kwargs: P.kwargs, @@ -96,14 +90,15 @@ class HttpTransactionCache: to produce a response for this transaction. Args: - txn_key: A key to ensure idempotency should fetch_or_execute be - called again at a later point in time. 
+ request: + requester: fn: A function which returns a tuple of (response_code, response_dict). *args: Arguments to pass to fn. **kwargs: Keyword arguments to pass to fn. Returns: Deferred which resolves to a tuple of (response_code, response_dict). """ + txn_key = self._get_transaction_key(request, requester) if txn_key in self.transactions: observable = self.transactions[txn_key][0] else: diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index e19c0946c0..ec171582b3 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py
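The hunk below advertises two new unstable feature flags. A short sketch of a client feature-detecting them before using the corresponding endpoints (the homeserver URL is a placeholder):

    import requests

    HOMESERVER = "https://synapse.example.org"  # placeholder

    versions = requests.get(f"{HOMESERVER}/_matrix/client/versions", timeout=30).json()
    unstable = versions.get("unstable_features", {})

    if unstable.get("fi.mau.msc2659"):
        print("appservice ping endpoint is advertised")
    if unstable.get("org.matrix.msc3952_intentional_mentions"):
        print("intentional mentions behaviour is advertised")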
@@ -109,6 +109,8 @@ class VersionsRestServlet(RestServlet): "org.matrix.msc3773": self.config.experimental.msc3773_enabled, # Allows moderators to fetch redacted event content as described in MSC2815 "fi.mau.msc2815": self.config.experimental.msc2815_enabled, + # Adds a ping endpoint for appservices to check HS->AS connection + "fi.mau.msc2659": self.config.experimental.msc2659_enabled, # Adds support for login token requests as per MSC3882 "org.matrix.msc3882": self.config.experimental.msc3882_enabled, # Adds support for remotely enabling/disabling pushers, as per MSC3881 @@ -120,6 +122,8 @@ class VersionsRestServlet(RestServlet): is not None, # Adds support for relation-based redactions as per MSC3912. "org.matrix.msc3912": self.config.experimental.msc3912_enabled, + # Adds support for unstable "intentional mentions" behaviour. + "org.matrix.msc3952_intentional_mentions": self.config.experimental.msc3952_intentional_mentions, }, }, ) diff --git a/synapse/rest/media/preview_url_resource.py b/synapse/rest/media/preview_url_resource.py
index 7ada728757..58513c4be4 100644 --- a/synapse/rest/media/preview_url_resource.py +++ b/synapse/rest/media/preview_url_resource.py
@@ -12,26 +12,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import datetime -import errno -import fnmatch -import logging -import os -import re -import shutil -import sys -import traceback -from typing import TYPE_CHECKING, BinaryIO, Iterable, Optional, Tuple -from urllib.parse import urljoin, urlparse, urlsplit -from urllib.request import urlopen -import attr +from typing import TYPE_CHECKING -from twisted.internet.defer import Deferred -from twisted.internet.error import DNSLookupError - -from synapse.api.errors import Codes, SynapseError -from synapse.http.client import SimpleHttpClient from synapse.http.server import ( DirectServeJsonResource, respond_with_json, @@ -39,71 +22,13 @@ from synapse.http.server import ( ) from synapse.http.servlet import parse_integer, parse_string from synapse.http.site import SynapseRequest -from synapse.logging.context import make_deferred_yieldable, run_in_background -from synapse.media._base import FileInfo, get_filename_from_headers from synapse.media.media_storage import MediaStorage -from synapse.media.oembed import OEmbedProvider -from synapse.media.preview_html import decode_body, parse_html_to_open_graph -from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.types import JsonDict, UserID -from synapse.util import json_encoder -from synapse.util.async_helpers import ObservableDeferred -from synapse.util.caches.expiringcache import ExpiringCache -from synapse.util.stringutils import random_string +from synapse.media.url_previewer import UrlPreviewer if TYPE_CHECKING: from synapse.media.media_repository import MediaRepository from synapse.server import HomeServer -logger = logging.getLogger(__name__) - -OG_TAG_NAME_MAXLEN = 50 -OG_TAG_VALUE_MAXLEN = 1000 - -ONE_HOUR = 60 * 60 * 1000 -ONE_DAY = 24 * ONE_HOUR -IMAGE_CACHE_EXPIRY_MS = 2 * ONE_DAY - - -@attr.s(slots=True, frozen=True, auto_attribs=True) -class DownloadResult: - length: int - uri: str - response_code: int - media_type: str - download_name: Optional[str] - expires: int - etag: Optional[str] - - -@attr.s(slots=True, frozen=True, auto_attribs=True) -class MediaInfo: - """ - Information parsed from downloading media being previewed. - """ - - # The Content-Type header of the response. - media_type: str - # The length (in bytes) of the downloaded media. - media_length: int - # The media filename, according to the server. This is parsed from the - # returned headers, if possible. - download_name: Optional[str] - # The time of the preview. - created_ts_ms: int - # Information from the media storage provider about where the file is stored - # on disk. - filesystem_id: str - filename: str - # The URI being previewed. - uri: str - # The HTTP response code. - response_code: int - # The timestamp (in milliseconds) of when this preview expires. - expires: int - # The ETag header of the response. - etag: Optional[str] - class PreviewUrlResource(DirectServeJsonResource): """ @@ -121,54 +46,6 @@ class PreviewUrlResource(DirectServeJsonResource): * The URL metadata must be stored somewhere, rather than just using Matrix itself to store the media. * Matrix cannot be used to distribute the metadata between homeservers. - - When Synapse is asked to preview a URL it does the following: - - 1. Checks against a URL blacklist (defined as `url_preview_url_blacklist` in the - config). - 2. 
Checks the URL against an in-memory cache and returns the result if it exists. (This - is also used to de-duplicate processing of multiple in-flight requests at once.) - 3. Kicks off a background process to generate a preview: - 1. Checks URL and timestamp against the database cache and returns the result if it - has not expired and was successful (a 2xx return code). - 2. Checks if the URL matches an oEmbed (https://oembed.com/) pattern. If it - does, update the URL to download. - 3. Downloads the URL and stores it into a file via the media storage provider - and saves the local media metadata. - 4. If the media is an image: - 1. Generates thumbnails. - 2. Generates an Open Graph response based on image properties. - 5. If the media is HTML: - 1. Decodes the HTML via the stored file. - 2. Generates an Open Graph response from the HTML. - 3. If a JSON oEmbed URL was found in the HTML via autodiscovery: - 1. Downloads the URL and stores it into a file via the media storage provider - and saves the local media metadata. - 2. Convert the oEmbed response to an Open Graph response. - 3. Override any Open Graph data from the HTML with data from oEmbed. - 4. If an image exists in the Open Graph response: - 1. Downloads the URL and stores it into a file via the media storage - provider and saves the local media metadata. - 2. Generates thumbnails. - 3. Updates the Open Graph response based on image properties. - 6. If the media is JSON and an oEmbed URL was found: - 1. Convert the oEmbed response to an Open Graph response. - 2. If a thumbnail or image is in the oEmbed response: - 1. Downloads the URL and stores it into a file via the media storage - provider and saves the local media metadata. - 2. Generates thumbnails. - 3. Updates the Open Graph response based on image properties. - 7. Stores the result in the database cache. - 4. Returns the result. - - If any additional requests (e.g. from oEmbed autodiscovery, step 5.3 or - image thumbnailing, step 5.4 or 6.4) fails then the URL preview as a whole - does not fail. As much information as possible is returned. - - The in-memory cache expires after 1 hour. - - Expired entries in the database cache (and their associated media files) are - deleted every 10 seconds. The default expiration time is 1 hour from download. """ isLeaf = True @@ -183,48 +60,10 @@ class PreviewUrlResource(DirectServeJsonResource): self.auth = hs.get_auth() self.clock = hs.get_clock() - self.filepaths = media_repo.filepaths - self.max_spider_size = hs.config.media.max_spider_size - self.server_name = hs.hostname - self.store = hs.get_datastores().main - self.client = SimpleHttpClient( - hs, - treq_args={"browser_like_redirects": True}, - ip_whitelist=hs.config.media.url_preview_ip_range_whitelist, - ip_blacklist=hs.config.media.url_preview_ip_range_blacklist, - use_proxy=True, - ) self.media_repo = media_repo - self.primary_base_path = media_repo.primary_base_path self.media_storage = media_storage - self._oembed = OEmbedProvider(hs) - - # We run the background jobs if we're the instance specified (or no - # instance is specified, where we assume there is only one instance - # serving media). 
- instance_running_jobs = hs.config.media.media_instance_running_background_jobs - self._worker_run_media_background_jobs = ( - instance_running_jobs is None - or instance_running_jobs == hs.get_instance_name() - ) - - self.url_preview_url_blacklist = hs.config.media.url_preview_url_blacklist - self.url_preview_accept_language = hs.config.media.url_preview_accept_language - - # memory cache mapping urls to an ObservableDeferred returning - # JSON-encoded OG metadata - self._cache: ExpiringCache[str, ObservableDeferred] = ExpiringCache( - cache_name="url_previews", - clock=self.clock, - # don't spider URLs more often than once an hour - expiry_ms=ONE_HOUR, - ) - - if self._worker_run_media_background_jobs: - self._cleaner_loop = self.clock.looping_call( - self._start_expire_url_cache_data, 10 * 1000 - ) + self._url_previewer = UrlPreviewer(hs, media_repo, media_storage) async def _async_render_OPTIONS(self, request: SynapseRequest) -> None: request.setHeader(b"Allow", b"OPTIONS, GET") @@ -238,632 +77,5 @@ class PreviewUrlResource(DirectServeJsonResource): if ts is None: ts = self.clock.time_msec() - # XXX: we could move this into _do_preview if we wanted. - url_tuple = urlsplit(url) - for entry in self.url_preview_url_blacklist: - match = True - for attrib in entry: - pattern = entry[attrib] - value = getattr(url_tuple, attrib) - logger.debug( - "Matching attrib '%s' with value '%s' against pattern '%s'", - attrib, - value, - pattern, - ) - - if value is None: - match = False - continue - - # Some attributes might not be parsed as strings by urlsplit (such as the - # port, which is parsed as an int). Because we use match functions that - # expect strings, we want to make sure that's what we give them. - value_str = str(value) - - if pattern.startswith("^"): - if not re.match(pattern, value_str): - match = False - continue - else: - if not fnmatch.fnmatch(value_str, pattern): - match = False - continue - if match: - logger.warning("URL %s blocked by url_blacklist entry %s", url, entry) - raise SynapseError( - 403, "URL blocked by url pattern blacklist entry", Codes.UNKNOWN - ) - - # the in-memory cache: - # * ensures that only one request is active at a time - # * takes load off the DB for the thundering herds - # * also caches any failures (unlike the DB) so we don't keep - # requesting the same endpoint - - observable = self._cache.get(url) - - if not observable: - download = run_in_background(self._do_preview, url, requester.user, ts) - observable = ObservableDeferred(download, consumeErrors=True) - self._cache[url] = observable - else: - logger.info("Returning cached response") - - og = await make_deferred_yieldable(observable.observe()) + og = await self._url_previewer.preview(url, requester.user, ts) respond_with_json_bytes(request, 200, og, send_cors=True) - - async def _do_preview(self, url: str, user: UserID, ts: int) -> bytes: - """Check the db, and download the URL and build a preview - - Args: - url: The URL to preview. - user: The user requesting the preview. - ts: The timestamp requested for the preview. - - Returns: - json-encoded og data - """ - # check the URL cache in the DB (which will also provide us with - # historical previews, if we have any) - cache_result = await self.store.get_url_cache(url, ts) - if ( - cache_result - and cache_result["expires_ts"] > ts - and cache_result["response_code"] / 100 == 2 - ): - # It may be stored as text in the database, not as bytes (such as - # PostgreSQL). If so, encode it back before handing it on. 
- og = cache_result["og"] - if isinstance(og, str): - og = og.encode("utf8") - return og - - # If this URL can be accessed via oEmbed, use that instead. - url_to_download = url - oembed_url = self._oembed.get_oembed_url(url) - if oembed_url: - url_to_download = oembed_url - - media_info = await self._handle_url(url_to_download, user) - - logger.debug("got media_info of '%s'", media_info) - - # The number of milliseconds that the response should be considered valid. - expiration_ms = media_info.expires - author_name: Optional[str] = None - - if _is_media(media_info.media_type): - file_id = media_info.filesystem_id - dims = await self.media_repo._generate_thumbnails( - None, file_id, file_id, media_info.media_type, url_cache=True - ) - - og = { - "og:description": media_info.download_name, - "og:image": f"mxc://{self.server_name}/{media_info.filesystem_id}", - "og:image:type": media_info.media_type, - "matrix:image:size": media_info.media_length, - } - - if dims: - og["og:image:width"] = dims["width"] - og["og:image:height"] = dims["height"] - else: - logger.warning("Couldn't get dims for %s" % url) - - # define our OG response for this media - elif _is_html(media_info.media_type): - # TODO: somehow stop a big HTML tree from exploding synapse's RAM - - with open(media_info.filename, "rb") as file: - body = file.read() - - tree = decode_body(body, media_info.uri, media_info.media_type) - if tree is not None: - # Check if this HTML document points to oEmbed information and - # defer to that. - oembed_url = self._oembed.autodiscover_from_html(tree) - og_from_oembed: JsonDict = {} - if oembed_url: - try: - oembed_info = await self._handle_url( - oembed_url, user, allow_data_urls=True - ) - except Exception as e: - # Fetching the oEmbed info failed, don't block the entire URL preview. - logger.warning( - "oEmbed fetch failed during URL preview: %s errored with %s", - oembed_url, - e, - ) - else: - ( - og_from_oembed, - author_name, - expiration_ms, - ) = await self._handle_oembed_response( - url, oembed_info, expiration_ms - ) - - # Parse Open Graph information from the HTML in case the oEmbed - # response failed or is incomplete. - og_from_html = parse_html_to_open_graph(tree) - - # Compile the Open Graph response by using the scraped - # information from the HTML and overlaying any information - # from the oEmbed response. - og = {**og_from_html, **og_from_oembed} - - await self._precache_image_url(user, media_info, og) - else: - og = {} - - elif oembed_url: - # Handle the oEmbed information. - og, author_name, expiration_ms = await self._handle_oembed_response( - url, media_info, expiration_ms - ) - await self._precache_image_url(user, media_info, og) - - else: - logger.warning("Failed to find any OG data in %s", url) - og = {} - - # If we don't have a title but we have author_name, copy it as - # title - if not og.get("og:title") and author_name: - og["og:title"] = author_name - - # filter out any stupidly long values - keys_to_remove = [] - for k, v in og.items(): - # values can be numeric as well as strings, hence the cast to str - if len(k) > OG_TAG_NAME_MAXLEN or len(str(v)) > OG_TAG_VALUE_MAXLEN: - logger.warning( - "Pruning overlong tag %s from OG data", k[:OG_TAG_NAME_MAXLEN] - ) - keys_to_remove.append(k) - for k in keys_to_remove: - del og[k] - - logger.debug("Calculated OG for %s as %s", url, og) - - jsonog = json_encoder.encode(og) - - # Cap the amount of time to consider a response valid. 
- expiration_ms = min(expiration_ms, ONE_DAY) - - # store OG in history-aware DB cache - await self.store.store_url_cache( - url, - media_info.response_code, - media_info.etag, - media_info.created_ts_ms + expiration_ms, - jsonog, - media_info.filesystem_id, - media_info.created_ts_ms, - ) - - return jsonog.encode("utf8") - - async def _download_url(self, url: str, output_stream: BinaryIO) -> DownloadResult: - """ - Fetches a remote URL and parses the headers. - - Args: - url: The URL to fetch. - output_stream: The stream to write the content to. - - Returns: - A tuple of: - Media length, URL downloaded, the HTTP response code, - the media type, the downloaded file name, the number of - milliseconds the result is valid for, the etag header. - """ - - try: - logger.debug("Trying to get preview for url '%s'", url) - length, headers, uri, code = await self.client.get_file( - url, - output_stream=output_stream, - max_size=self.max_spider_size, - headers={ - b"Accept-Language": self.url_preview_accept_language, - # Use a custom user agent for the preview because some sites will only return - # Open Graph metadata to crawler user agents. Omit the Synapse version - # string to avoid leaking information. - b"User-Agent": [ - "Synapse (bot; +https://github.com/matrix-org/synapse)" - ], - }, - is_allowed_content_type=_is_previewable, - ) - except SynapseError: - # Pass SynapseErrors through directly, so that the servlet - # handler will return a SynapseError to the client instead of - # blank data or a 500. - raise - except DNSLookupError: - # DNS lookup returned no results - # Note: This will also be the case if one of the resolved IP - # addresses is blacklisted - raise SynapseError( - 502, - "DNS resolution failure during URL preview generation", - Codes.UNKNOWN, - ) - except Exception as e: - # FIXME: pass through 404s and other error messages nicely - logger.warning("Error downloading %s: %r", url, e) - - raise SynapseError( - 500, - "Failed to download content: %s" - % (traceback.format_exception_only(sys.exc_info()[0], e),), - Codes.UNKNOWN, - ) - - if b"Content-Type" in headers: - media_type = headers[b"Content-Type"][0].decode("ascii") - else: - media_type = "application/octet-stream" - - download_name = get_filename_from_headers(headers) - - # FIXME: we should calculate a proper expiration based on the - # Cache-Control and Expire headers. But for now, assume 1 hour. - expires = ONE_HOUR - etag = headers[b"ETag"][0].decode("ascii") if b"ETag" in headers else None - - return DownloadResult( - length, uri, code, media_type, download_name, expires, etag - ) - - async def _parse_data_url( - self, url: str, output_stream: BinaryIO - ) -> DownloadResult: - """ - Parses a data: URL. - - Args: - url: The URL to parse. - output_stream: The stream to write the content to. - - Returns: - A tuple of: - Media length, URL downloaded, the HTTP response code, - the media type, the downloaded file name, the number of - milliseconds the result is valid for, the etag header. - """ - - try: - logger.debug("Trying to parse data url '%s'", url) - with urlopen(url) as url_info: - # TODO Can this be more efficient. - output_stream.write(url_info.read()) - except Exception as e: - logger.warning("Error parsing data: URL %s: %r", url, e) - - raise SynapseError( - 500, - "Failed to parse data URL: %s" - % (traceback.format_exception_only(sys.exc_info()[0], e),), - Codes.UNKNOWN, - ) - - return DownloadResult( - # Read back the length that has been written. 
- length=output_stream.tell(), - uri=url, - # If it was parsed, consider this a 200 OK. - response_code=200, - # urlopen shoves the media-type from the data URL into the content type - # header object. - media_type=url_info.headers.get_content_type(), - # Some features are not supported by data: URLs. - download_name=None, - expires=ONE_HOUR, - etag=None, - ) - - async def _handle_url( - self, url: str, user: UserID, allow_data_urls: bool = False - ) -> MediaInfo: - """ - Fetches content from a URL and parses the result to generate a MediaInfo. - - It uses the media storage provider to persist the fetched content and - stores the mapping into the database. - - Args: - url: The URL to fetch. - user: The user who ahs requested this URL. - allow_data_urls: True if data URLs should be allowed. - - Returns: - A MediaInfo object describing the fetched content. - """ - - # TODO: we should probably honour robots.txt... except in practice - # we're most likely being explicitly triggered by a human rather than a - # bot, so are we really a robot? - - file_id = datetime.date.today().isoformat() + "_" + random_string(16) - - file_info = FileInfo(server_name=None, file_id=file_id, url_cache=True) - - with self.media_storage.store_into_file(file_info) as (f, fname, finish): - if url.startswith("data:"): - if not allow_data_urls: - raise SynapseError( - 500, "Previewing of data: URLs is forbidden", Codes.UNKNOWN - ) - - download_result = await self._parse_data_url(url, f) - else: - download_result = await self._download_url(url, f) - - await finish() - - try: - time_now_ms = self.clock.time_msec() - - await self.store.store_local_media( - media_id=file_id, - media_type=download_result.media_type, - time_now_ms=time_now_ms, - upload_name=download_result.download_name, - media_length=download_result.length, - user_id=user, - url_cache=url, - ) - - except Exception as e: - logger.error("Error handling downloaded %s: %r", url, e) - # TODO: we really ought to delete the downloaded file in this - # case, since we won't have recorded it in the db, and will - # therefore not expire it. - raise - - return MediaInfo( - media_type=download_result.media_type, - media_length=download_result.length, - download_name=download_result.download_name, - created_ts_ms=time_now_ms, - filesystem_id=file_id, - filename=fname, - uri=download_result.uri, - response_code=download_result.response_code, - expires=download_result.expires, - etag=download_result.etag, - ) - - async def _precache_image_url( - self, user: UserID, media_info: MediaInfo, og: JsonDict - ) -> None: - """ - Pre-cache the image (if one exists) for posterity - - Args: - user: The user requesting the preview. - media_info: The media being previewed. - og: The Open Graph dictionary. This is modified with image information. - """ - # If there's no image or it is blank, there's nothing to do. - if "og:image" not in og: - return - - # Remove the raw image URL, this will be replaced with an MXC URL, if successful. - image_url = og.pop("og:image") - if not image_url: - return - - # The image URL from the HTML might be relative to the previewed page, - # convert it to an URL which can be requested directly. - url_parts = urlparse(image_url) - if url_parts.scheme != "data": - image_url = urljoin(media_info.uri, image_url) - - # FIXME: it might be cleaner to use the same flow as the main /preview_url - # request itself and benefit from the same caching etc. But for now we - # just rely on the caching on the master request to speed things up. 
- try: - image_info = await self._handle_url(image_url, user, allow_data_urls=True) - except Exception as e: - # Pre-caching the image failed, don't block the entire URL preview. - logger.warning( - "Pre-caching image failed during URL preview: %s errored with %s", - image_url, - e, - ) - return - - if _is_media(image_info.media_type): - # TODO: make sure we don't choke on white-on-transparent images - file_id = image_info.filesystem_id - dims = await self.media_repo._generate_thumbnails( - None, file_id, file_id, image_info.media_type, url_cache=True - ) - if dims: - og["og:image:width"] = dims["width"] - og["og:image:height"] = dims["height"] - else: - logger.warning("Couldn't get dims for %s", image_url) - - og["og:image"] = f"mxc://{self.server_name}/{image_info.filesystem_id}" - og["og:image:type"] = image_info.media_type - og["matrix:image:size"] = image_info.media_length - - async def _handle_oembed_response( - self, url: str, media_info: MediaInfo, expiration_ms: int - ) -> Tuple[JsonDict, Optional[str], int]: - """ - Parse the downloaded oEmbed info. - - Args: - url: The URL which is being previewed (not the one which was - requested). - media_info: The media being previewed. - expiration_ms: The length of time, in milliseconds, the media is valid for. - - Returns: - A tuple of: - The Open Graph dictionary, if the oEmbed info can be parsed. - The author name if it could be retrieved from oEmbed. - The (possibly updated) length of time, in milliseconds, the media is valid for. - """ - # If JSON was not returned, there's nothing to do. - if not _is_json(media_info.media_type): - return {}, None, expiration_ms - - with open(media_info.filename, "rb") as file: - body = file.read() - - oembed_response = self._oembed.parse_oembed_response(url, body) - open_graph_result = oembed_response.open_graph_result - - # Use the cache age from the oEmbed result, if one was given. - if open_graph_result and oembed_response.cache_age is not None: - expiration_ms = oembed_response.cache_age - - return open_graph_result, oembed_response.author_name, expiration_ms - - def _start_expire_url_cache_data(self) -> Deferred: - return run_as_background_process( - "expire_url_cache_data", self._expire_url_cache_data - ) - - async def _expire_url_cache_data(self) -> None: - """Clean up expired url cache content, media and thumbnails.""" - - assert self._worker_run_media_background_jobs - - now = self.clock.time_msec() - - logger.debug("Running url preview cache expiry") - - def try_remove_parent_dirs(dirs: Iterable[str]) -> None: - """Attempt to remove the given chain of parent directories - - Args: - dirs: The list of directory paths to delete, with children appearing - before their parents. 
- """ - for dir in dirs: - try: - os.rmdir(dir) - except FileNotFoundError: - # Already deleted, continue with deleting the rest - pass - except OSError as e: - # Failed, skip deleting the rest of the parent dirs - if e.errno != errno.ENOTEMPTY: - logger.warning( - "Failed to remove media directory while clearing url preview cache: %r: %s", - dir, - e, - ) - break - - # First we delete expired url cache entries - media_ids = await self.store.get_expired_url_cache(now) - - removed_media = [] - for media_id in media_ids: - fname = self.filepaths.url_cache_filepath(media_id) - try: - os.remove(fname) - except FileNotFoundError: - pass # If the path doesn't exist, meh - except OSError as e: - logger.warning( - "Failed to remove media while clearing url preview cache: %r: %s", - media_id, - e, - ) - continue - - removed_media.append(media_id) - - dirs = self.filepaths.url_cache_filepath_dirs_to_delete(media_id) - try_remove_parent_dirs(dirs) - - await self.store.delete_url_cache(removed_media) - - if removed_media: - logger.debug( - "Deleted %d entries from url preview cache", len(removed_media) - ) - else: - logger.debug("No entries removed from url preview cache") - - # Now we delete old images associated with the url cache. - # These may be cached for a bit on the client (i.e., they - # may have a room open with a preview url thing open). - # So we wait a couple of days before deleting, just in case. - expire_before = now - IMAGE_CACHE_EXPIRY_MS - media_ids = await self.store.get_url_cache_media_before(expire_before) - - removed_media = [] - for media_id in media_ids: - fname = self.filepaths.url_cache_filepath(media_id) - try: - os.remove(fname) - except FileNotFoundError: - pass # If the path doesn't exist, meh - except OSError as e: - logger.warning( - "Failed to remove media from url preview cache: %r: %s", media_id, e - ) - continue - - dirs = self.filepaths.url_cache_filepath_dirs_to_delete(media_id) - try_remove_parent_dirs(dirs) - - thumbnail_dir = self.filepaths.url_cache_thumbnail_directory(media_id) - try: - shutil.rmtree(thumbnail_dir) - except FileNotFoundError: - pass # If the path doesn't exist, meh - except OSError as e: - logger.warning( - "Failed to remove media from url preview cache: %r: %s", media_id, e - ) - continue - - removed_media.append(media_id) - - dirs = self.filepaths.url_cache_thumbnail_dirs_to_delete(media_id) - # Note that one of the directories to be deleted has already been - # removed by the `rmtree` above. - try_remove_parent_dirs(dirs) - - await self.store.delete_url_cache_media(removed_media) - - if removed_media: - logger.debug("Deleted %d media from url preview cache", len(removed_media)) - else: - logger.debug("No media removed from url preview cache") - - -def _is_media(content_type: str) -> bool: - return content_type.lower().startswith("image/") - - -def _is_html(content_type: str) -> bool: - content_type = content_type.lower() - return content_type.startswith("text/html") or content_type.startswith( - "application/xhtml" - ) - - -def _is_json(content_type: str) -> bool: - return content_type.lower().startswith("application/json") - - -def _is_previewable(content_type: str) -> bool: - """Returns True for content types for which we will perform URL preview and False - otherwise.""" - - return _is_html(content_type) or _is_media(content_type) or _is_json(content_type) diff --git a/synapse/server.py b/synapse/server.py
index df80fc1beb..a191c19993 100644 --- a/synapse/server.py +++ b/synapse/server.py
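The hunk below re-types the `cache_in_self` decorator so that it returns exactly the callable type it receives (keeping both mypy and PyCharm happy) and registers the new `ModuleApiCallbacks` container on the `HomeServer`. A minimal, self-contained sketch of that decorator pattern, using a toy `ToyHomeServer` class rather than Synapse's real one:

from typing import Callable, TypeVar, cast

# Annotate the decorator as F -> F so callers see the builder's own signature.
F = TypeVar("F", bound=Callable[..., object])


def cache_in_self(builder: F) -> F:
    depname = builder.__name__.replace("get_", "_cached_")

    def _get(self: "ToyHomeServer") -> object:
        # Build the dependency once, then hand back the cached instance.
        if getattr(self, depname, None) is None:
            setattr(self, depname, builder(self))
        return getattr(self, depname)

    # The wrapper is a different function object, hence the cast back to F.
    return cast(F, _get)


class ToyHomeServer:
    @cache_in_self
    def get_clock(self) -> object:
        return object()


hs = ToyHomeServer()
assert hs.get_clock() is hs.get_clock()  # built once, cached thereafter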
@@ -23,6 +23,8 @@ import functools import logging from typing import TYPE_CHECKING, Callable, Dict, List, Optional, TypeVar, cast +from typing_extensions import TypeAlias + from twisted.internet.interfaces import IOpenSSLContextFactory from twisted.internet.tcp import Port from twisted.web.iweb import IPolicyForHTTPS @@ -108,6 +110,7 @@ from synapse.http.matrixfederationclient import MatrixFederationHttpClient from synapse.media.media_repository import MediaRepository from synapse.metrics.common_usage_metrics import CommonUsageMetricsManager from synapse.module_api import ModuleApi +from synapse.module_api.callbacks import ModuleApiCallbacks from synapse.notifier import Notifier, ReplicationNotifier from synapse.push.bulk_push_rule_evaluator import BulkPushRuleEvaluator from synapse.push.pusherpool import PusherPool @@ -142,10 +145,31 @@ if TYPE_CHECKING: from synapse.handlers.saml import SamlHandler -T = TypeVar("T") +# The annotation for `cache_in_self` used to be +# def (builder: Callable[["HomeServer"],T]) -> Callable[["HomeServer"],T] +# which mypy was happy with. +# +# But PyCharm was confused by this. If `foo` was decorated by `@cache_in_self`, then +# an expression like `hs.foo()` +# +# - would erroneously warn that we hadn't provided a `hs` argument to foo (PyCharm +# confused about boundmethods and unbound methods?), and +# - would be considered to have type `Any`, making for a poor autocomplete and +# cross-referencing experience. +# +# Instead, use a typevar `F` to express that `@cache_in_self` returns exactly the +# same type it receives. This isn't strictly true [*], but it's more than good +# enough to keep PyCharm and mypy happy. +# +# [*]: (e.g. `builder` could be an object with a __call__ attribute rather than a +# types.FunctionType instance, whereas the return value is always a +# types.FunctionType instance.) + +T: TypeAlias = object +F = TypeVar("F", bound=Callable[["HomeServer"], T]) -def cache_in_self(builder: Callable[["HomeServer"], T]) -> Callable[["HomeServer"], T]: +def cache_in_self(builder: F) -> F: """Wraps a function called e.g. `get_foo`, checking if `self.foo` exists and returning if so. If not, calls the given function and sets `self.foo` to it. @@ -183,7 +207,7 @@ def cache_in_self(builder: Callable[["HomeServer"], T]) -> Callable[["HomeServer return dep - return _get + return cast(F, _get) class HomeServer(metaclass=abc.ABCMeta): @@ -778,6 +802,10 @@ class HomeServer(metaclass=abc.ABCMeta): return ModuleApi(self, self.get_auth_handler()) @cache_in_self + def get_module_api_callbacks(self) -> ModuleApiCallbacks: + return ModuleApiCallbacks() + + @cache_in_self def get_account_data_handler(self) -> AccountDataHandler: return AccountDataHandler(self) diff --git a/synapse/storage/controllers/purge_events.py b/synapse/storage/controllers/purge_events.py
index 9ca50d6a09..c599397b86 100644 --- a/synapse/storage/controllers/purge_events.py +++ b/synapse/storage/controllers/purge_events.py
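The hunks below wrap the purge calls in `nested_logging_context(room_id)` so that log lines emitted while a room is being purged carry the room ID. The same idea can be illustrated with only the standard library (this is a generic stand-in, not Synapse's `LoggingContext` machinery):

import contextvars
import logging
from contextlib import contextmanager
from typing import Iterator

_current_room = contextvars.ContextVar("current_room", default="-")


class RoomTagFilter(logging.Filter):
    def filter(self, record: logging.LogRecord) -> bool:
        # Attach the current room tag to every record passing through the logger.
        record.room = _current_room.get()
        return True


@contextmanager
def room_logging_context(room_id: str) -> Iterator[None]:
    token = _current_room.set(room_id)
    try:
        yield
    finally:
        _current_room.reset(token)


logging.basicConfig(format="%(levelname)s [room=%(room)s] %(message)s")
logger = logging.getLogger("purge")
logger.addFilter(RoomTagFilter())

with room_logging_context("!abc:example.org"):
    logger.warning("[purge] finding state groups that can be deleted")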
@@ -16,6 +16,7 @@ import itertools import logging from typing import TYPE_CHECKING, Set +from synapse.logging.context import nested_logging_context from synapse.storage.databases import Databases if TYPE_CHECKING: @@ -33,8 +34,9 @@ class PurgeEventsStorageController: async def purge_room(self, room_id: str) -> None: """Deletes all record of a room""" - state_groups_to_delete = await self.stores.main.purge_room(room_id) - await self.stores.state.purge_room_state(room_id, state_groups_to_delete) + with nested_logging_context(room_id): + state_groups_to_delete = await self.stores.main.purge_room(room_id) + await self.stores.state.purge_room_state(room_id, state_groups_to_delete) async def purge_history( self, room_id: str, token: str, delete_local_events: bool @@ -51,15 +53,17 @@ class PurgeEventsStorageController: (instead of just marking them as outliers and deleting their state groups). """ - state_groups = await self.stores.main.purge_history( - room_id, token, delete_local_events - ) - - logger.info("[purge] finding state groups that can be deleted") + with nested_logging_context(room_id): + state_groups = await self.stores.main.purge_history( + room_id, token, delete_local_events + ) - sg_to_delete = await self._find_unreferenced_groups(state_groups) + logger.info("[purge] finding state groups that can be deleted") + sg_to_delete = await self._find_unreferenced_groups(state_groups) - await self.stores.state.purge_unreferenced_state_groups(room_id, sg_to_delete) + await self.stores.state.purge_unreferenced_state_groups( + room_id, sg_to_delete + ) async def _find_unreferenced_groups(self, state_groups: Set[int]) -> Set[int]: """Used when purging history to figure out which state groups can be diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 5efe31aa19..fec4ae5b97 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py
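The hunk below tightens the typing of `runWithConnection`: the wrapped function's first parameter is a database connection and the remaining parameters are forwarded as-is, which `Concatenate` and `ParamSpec` express directly. A self-contained sketch of that signature shape (`FakeConnection` and `count_rows` are illustrative, not Synapse code):

from typing import Callable, TypeVar

from typing_extensions import Concatenate, ParamSpec

P = ParamSpec("P")
R = TypeVar("R")


class FakeConnection:
    def execute(self, sql: str) -> str:
        return f"ran: {sql}"


def run_with_connection(
    func: Callable[Concatenate[FakeConnection, P], R],
    *args: P.args,
    **kwargs: P.kwargs,
) -> R:
    # In Synapse the connection comes from the adbapi pool; here we just make one.
    conn = FakeConnection()
    return func(conn, *args, **kwargs)


def count_rows(conn: FakeConnection, table: str) -> str:
    return conn.execute(f"SELECT count(*) FROM {table}")


# The type checker verifies `table=` against count_rows' signature minus the connection.
print(run_with_connection(count_rows, table="destinations"))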
@@ -34,6 +34,7 @@ from typing import ( Tuple, Type, TypeVar, + Union, cast, overload, ) @@ -100,6 +101,15 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = { } +class _PoolConnection(Connection): + """ + A Connection from twisted.enterprise.adbapi.Connection. + """ + + def reconnect(self) -> None: + ... + + def make_pool( reactor: IReactorCore, db_config: DatabaseConnectionConfig, @@ -856,7 +866,8 @@ class DatabasePool: try: with opentracing.start_active_span(f"db.{desc}"): result = await self.runWithConnection( - self.new_transaction, + # mypy seems to have an issue with this, maybe a bug? + self.new_transaction, # type: ignore[arg-type] desc, after_callbacks, async_after_callbacks, @@ -892,7 +903,7 @@ class DatabasePool: async def runWithConnection( self, - func: Callable[..., R], + func: Callable[Concatenate[LoggingDatabaseConnection, P], R], *args: Any, db_autocommit: bool = False, isolation_level: Optional[int] = None, @@ -926,7 +937,7 @@ class DatabasePool: start_time = monotonic_time() - def inner_func(conn, *args, **kwargs): + def inner_func(conn: _PoolConnection, *args: P.args, **kwargs: P.kwargs) -> R: # We shouldn't be in a transaction. If we are then something # somewhere hasn't committed after doing work. (This is likely only # possible during startup, as `run*` will ensure changes are @@ -1019,7 +1030,7 @@ class DatabasePool: decoder: Optional[Callable[[Cursor], R]], query: str, *args: Any, - ) -> R: + ) -> Union[List[Tuple[Any, ...]], R]: """Runs a single query for a result set. Args: @@ -1032,7 +1043,7 @@ class DatabasePool: The result of decoder(results) """ - def interaction(txn): + def interaction(txn: LoggingTransaction) -> Union[List[Tuple[Any, ...]], R]: txn.execute(query, args) if decoder: return decoder(txn) diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index 9c41d01e13..7a7c0d9c75 100644 --- a/synapse/storage/databases/main/purge_events.py +++ b/synapse/storage/databases/main/purge_events.py
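The hunk below only adds progress logging around `purge_room`, but the surrounding context explains why the purge runs twice: the first interaction runs under READ COMMITTED isolation, and a second one mops up rows that landed while the first was in flight. A toy illustration of that two-pass shape, with an in-memory set standing in for the events table:

events = {"$event_a", "$event_b"}


def purge_pass(table: set) -> set:
    """Delete everything currently visible and report what was deleted."""
    deleted = set(table)
    table.clear()
    return deleted


deleted = purge_pass(events)      # pass 1/2
events.add("$late_arrival")       # a concurrent writer adds a row mid-purge
deleted |= purge_pass(events)     # pass 2/2 catches the straggler
assert not events and "$late_arrival" in deleted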
@@ -325,6 +325,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): # We then run the same purge a second time without this isolation level to # purge any of those rows which were added during the first. + logger.info("[purge] Starting initial main purge of [1/2]") state_groups_to_delete = await self.db_pool.runInteraction( "purge_room", self._purge_room_txn, @@ -332,6 +333,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): isolation_level=IsolationLevel.READ_COMMITTED, ) + logger.info("[purge] Starting secondary main purge of [2/2]") state_groups_to_delete.extend( await self.db_pool.runInteraction( "purge_room", @@ -339,6 +341,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): room_id=room_id, ), ) + logger.info("[purge] Done with main purge") return state_groups_to_delete @@ -376,7 +379,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): ) referenced_chain_id_tuples = list(txn) - logger.info("[purge] removing events from event_auth_chain_links") + logger.info("[purge] removing from event_auth_chain_links") txn.executemany( """ DELETE FROM event_auth_chain_links WHERE @@ -399,7 +402,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): "rejections", "state_events", ): - logger.info("[purge] removing %s from %s", room_id, table) + logger.info("[purge] removing from %s", table) txn.execute( """ @@ -454,7 +457,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): # happy "rooms", ): - logger.info("[purge] removing %s from %s", room_id, table) + logger.info("[purge] removing from %s", table) txn.execute("DELETE FROM %s WHERE room_id=?" % (table,), (room_id,)) # Other tables we do NOT need to clear out: @@ -486,6 +489,4 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): # that already exist. self._invalidate_cache_and_stream(txn, self.have_seen_event, (room_id,)) - logger.info("[purge] done") - return state_groups diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index 6d72bd9f67..c3bd36efc9 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py
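The hunk below deletes the "emulated" upsert path for destination retry timings, leaving only the native path (renamed `_set_destination_retry_timings_txn`). The native approach leans on the database's own upsert support; roughly, the statement has this shape (shown against an in-memory SQLite database with a simplified schema, not the code Synapse actually runs):

import sqlite3
from typing import Optional

# SQLite >= 3.24 and PostgreSQL both support INSERT ... ON CONFLICT DO UPDATE.
conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE destinations ("
    " destination TEXT PRIMARY KEY,"
    " failure_ts BIGINT, retry_last_ts BIGINT, retry_interval BIGINT)"
)


def set_retry_timings(
    destination: str,
    failure_ts: Optional[int],
    retry_last_ts: int,
    retry_interval: int,
) -> None:
    # One native upsert replaces the old lock_table + SELECT + INSERT/UPDATE dance.
    conn.execute(
        """
        INSERT INTO destinations (destination, failure_ts, retry_last_ts, retry_interval)
        VALUES (?, ?, ?, ?)
        ON CONFLICT (destination) DO UPDATE SET
            failure_ts = excluded.failure_ts,
            retry_last_ts = excluded.retry_last_ts,
            retry_interval = excluded.retry_interval
        """,
        (destination, failure_ts, retry_last_ts, retry_interval),
    )


set_retry_timings("remote.example.org", None, 1_679_000_000_000, 60_000)
set_retry_timings("remote.example.org", None, 1_679_000_060_000, 120_000)
print(conn.execute("SELECT * FROM destinations").fetchall())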
@@ -224,7 +224,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): await self.db_pool.runInteraction( "set_destination_retry_timings", - self._set_destination_retry_timings_native, + self._set_destination_retry_timings_txn, destination, failure_ts, retry_last_ts, @@ -232,7 +232,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): db_autocommit=True, # Safe as it's a single upsert ) - def _set_destination_retry_timings_native( + def _set_destination_retry_timings_txn( self, txn: LoggingTransaction, destination: str, @@ -266,58 +266,6 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): txn, self.get_destination_retry_timings, (destination,) ) - def _set_destination_retry_timings_emulated( - self, - txn: LoggingTransaction, - destination: str, - failure_ts: Optional[int], - retry_last_ts: int, - retry_interval: int, - ) -> None: - self.database_engine.lock_table(txn, "destinations") - - # We need to be careful here as the data may have changed from under us - # due to a worker setting the timings. - - prev_row = self.db_pool.simple_select_one_txn( - txn, - table="destinations", - keyvalues={"destination": destination}, - retcols=("failure_ts", "retry_last_ts", "retry_interval"), - allow_none=True, - ) - - if not prev_row: - self.db_pool.simple_insert_txn( - txn, - table="destinations", - values={ - "destination": destination, - "failure_ts": failure_ts, - "retry_last_ts": retry_last_ts, - "retry_interval": retry_interval, - }, - ) - elif ( - retry_interval == 0 - or prev_row["retry_interval"] is None - or prev_row["retry_interval"] < retry_interval - ): - self.db_pool.simple_update_one_txn( - txn, - "destinations", - keyvalues={"destination": destination}, - updatevalues={ - "failure_ts": failure_ts, - "retry_last_ts": retry_last_ts, - "retry_interval": retry_interval, - }, - ) - - self._invalidate_cache_and_stream( - txn, self.get_destination_retry_timings, (destination,) - ) - async def store_destination_rooms_entries( self, destinations: Iterable[str], diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index f16a509ac4..97f09b73dd 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py
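The hunks below add storage for "stale" remote user-directory profiles, keyed by user ID, with a `retry_counter` and a `next_try_at_ts` that the refresher uses for exponential backoff. The backoff constants themselves live in the user directory handler rather than in this store; an illustrative calculation (the one-minute base and one-day cap are made up for the example) might look like:

def next_refresh_attempt_ms(now_ms: int, retry_counter: int) -> int:
    """Return the timestamp (ms) after which the profile may be re-requested."""
    base_delay_ms = 60 * 1000           # hypothetical base: 1 minute
    max_delay_ms = 24 * 60 * 60 * 1000  # hypothetical cap: 1 day
    delay_ms = min(base_delay_ms * (2 ** retry_counter), max_delay_ms)
    return now_ms + delay_ms


# After three consecutive failures the next attempt is scheduled 8 minutes out.
assert next_refresh_attempt_ms(0, 3) == 8 * 60 * 1000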
@@ -54,6 +54,7 @@ from synapse.storage.databases.main.state_deltas import StateDeltasStore from synapse.storage.engines import PostgresEngine, Sqlite3Engine from synapse.types import ( JsonDict, + UserID, UserProfile, get_domain_from_id, get_localpart_from_id, @@ -473,11 +474,116 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): return False + async def set_remote_user_profile_in_user_dir_stale( + self, user_id: str, next_try_at_ms: int, retry_counter: int + ) -> None: + """ + Marks a remote user as having a possibly-stale user directory profile. + + Args: + user_id: the remote user who may have a stale profile on this server. + next_try_at_ms: timestamp in ms after which the user directory profile can be + refreshed. + retry_counter: number of failures in refreshing the profile so far. Used for + exponential backoff calculations. + """ + assert not self.hs.is_mine_id( + user_id + ), "Can't mark a local user as a stale remote user." + + server_name = UserID.from_string(user_id).domain + + await self.db_pool.simple_upsert( + table="user_directory_stale_remote_users", + keyvalues={"user_id": user_id}, + values={ + "next_try_at_ts": next_try_at_ms, + "retry_counter": retry_counter, + "user_server_name": server_name, + }, + desc="set_remote_user_profile_in_user_dir_stale", + ) + + async def clear_remote_user_profile_in_user_dir_stale(self, user_id: str) -> None: + """ + Marks a remote user as no longer having a possibly-stale user directory profile. + + Args: + user_id: the remote user who no longer has a stale profile on this server. + """ + await self.db_pool.simple_delete( + table="user_directory_stale_remote_users", + keyvalues={"user_id": user_id}, + desc="clear_remote_user_profile_in_user_dir_stale", + ) + + async def get_remote_servers_with_profiles_to_refresh( + self, now_ts: int, limit: int + ) -> List[str]: + """ + Get a list of up to `limit` server names which have users whose + locally-cached profiles we believe to be stale + and are refreshable given the current time `now_ts` in milliseconds. + """ + + def _get_remote_servers_with_refreshable_profiles_txn( + txn: LoggingTransaction, + ) -> List[str]: + sql = """ + SELECT user_server_name + FROM user_directory_stale_remote_users + WHERE next_try_at_ts < ? + GROUP BY user_server_name + ORDER BY MIN(next_try_at_ts), user_server_name + LIMIT ? + """ + txn.execute(sql, (now_ts, limit)) + return [row[0] for row in txn] + + return await self.db_pool.runInteraction( + "get_remote_servers_with_profiles_to_refresh", + _get_remote_servers_with_refreshable_profiles_txn, + ) + + async def get_remote_users_to_refresh_on_server( + self, server_name: str, now_ts: int, limit: int + ) -> List[Tuple[str, int, int]]: + """ + Get a list of up to `limit` user IDs from the server `server_name` + whose locally-cached profiles we believe to be stale + and are refreshable given the current time `now_ts` in milliseconds. + + Returns: + tuple of: + - User ID + - Retry counter (number of failures so far) + - Time the retry is scheduled for, in milliseconds + """ + + def _get_remote_users_to_refresh_on_server_txn( + txn: LoggingTransaction, + ) -> List[Tuple[str, int, int]]: + sql = """ + SELECT user_id, retry_counter, next_try_at_ts + FROM user_directory_stale_remote_users + WHERE user_server_name = ? AND next_try_at_ts < ? + ORDER BY next_try_at_ts + LIMIT ? 
+ """ + txn.execute(sql, (server_name, now_ts, limit)) + return cast(List[Tuple[str, int, int]], txn.fetchall()) + + return await self.db_pool.runInteraction( + "get_remote_users_to_refresh_on_server", + _get_remote_users_to_refresh_on_server_txn, + ) + async def update_profile_in_user_dir( self, user_id: str, display_name: Optional[str], avatar_url: Optional[str] ) -> None: """ Update or add a user's profile in the user directory. + If the user is remote, the profile will be marked as not stale. """ # If the display name or avatar URL are unexpected types, replace with None. display_name = non_null_str_or_none(display_name) @@ -491,6 +597,14 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): values={"display_name": display_name, "avatar_url": avatar_url}, ) + if not self.hs.is_mine_id(user_id): + # Remote users: Make sure the profile is not marked as stale anymore. + self.db_pool.simple_delete_txn( + txn, + table="user_directory_stale_remote_users", + keyvalues={"user_id": user_id}, + ) + # The display name that goes into the database index. index_display_name = display_name if index_display_name is not None: diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index bf4cdfdf29..29ff64e876 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py
@@ -805,12 +805,14 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): state_groups_to_delete: State groups to delete """ + logger.info("[purge] Starting state purge") await self.db_pool.runInteraction( "purge_room_state", self._purge_room_state_txn, room_id, state_groups_to_delete, ) + logger.info("[purge] Done with state purge") def _purge_room_state_txn( self, diff --git a/synapse/storage/schema/main/delta/74/01_user_directory_stale_remote_users.sql b/synapse/storage/schema/main/delta/74/01_user_directory_stale_remote_users.sql new file mode 100644
index 0000000000..dcb38f3d7b --- /dev/null +++ b/synapse/storage/schema/main/delta/74/01_user_directory_stale_remote_users.sql
@@ -0,0 +1,39 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Table containing a list of remote users whose profiles may have changed +-- since their last update in the user directory. +CREATE TABLE user_directory_stale_remote_users ( + -- The User ID of the remote user whose profile may be stale. + user_id TEXT NOT NULL PRIMARY KEY, + + -- The server name of the user. + user_server_name TEXT NOT NULL, + + -- The timestamp (in ms) after which we should next try to request the user's + -- latest profile. + next_try_at_ts BIGINT NOT NULL, + + -- The number of retries so far. + -- 0 means we have not yet attempted to refresh the profile. + -- Used for calculating exponential backoff. + retry_counter INTEGER NOT NULL +); + +-- Create an index so we can easily query upcoming servers to try. +CREATE INDEX user_directory_stale_remote_users_next_try_idx ON user_directory_stale_remote_users(next_try_at_ts, user_server_name); + +-- Create an index so we can easily query upcoming users to try for a particular server. +CREATE INDEX user_directory_stale_remote_users_next_try_by_server_idx ON user_directory_stale_remote_users(user_server_name, next_try_at_ts); diff --git a/synapse/storage/schema/main/delta/74/90COMMENTS_destinations.sql.postgres b/synapse/storage/schema/main/delta/74/90COMMENTS_destinations.sql.postgres new file mode 100644
index 0000000000..cc7dda1a11 --- /dev/null +++ b/synapse/storage/schema/main/delta/74/90COMMENTS_destinations.sql.postgres
@@ -0,0 +1,52 @@ +/* Copyright 2023 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +--- destinations +COMMENT ON TABLE destinations IS + 'Information about remote homeservers and the health of our connection to them.'; + +COMMENT ON COLUMN destinations.destination IS 'server name of remote homeserver in question'; + +COMMENT ON COLUMN destinations.last_successful_stream_ordering IS +$$Stream ordering of the most recently successfully sent PDU to this server, sent through normal send (not e.g. backfill). +In Catch-Up Mode, the original PDU persisted by us is represented here, even if we sent a later forward extremity in its stead. +See `destination_rooms` for more information about catch-up.$$; + +COMMENT ON COLUMN destinations.retry_last_ts IS +$$The last time we tried and failed to reach the remote server, in ms. +This field is reset to `0` when we succeed in connecting again.$$; + +COMMENT ON COLUMN destinations.retry_interval IS +$$How long, in milliseconds, to wait since the last time we tried to reach the remote server before trying again. +This field is reset to `0` when we succeed in connecting again.$$; + +COMMENT ON COLUMN destinations.failure_ts IS +$$The first time we tried and failed to reach the remote server, in ms. +This field is reset to `NULL` when we succeed in connecting again.$$; + + + +--- destination_rooms +COMMENT ON TABLE destination_rooms IS + 'Information about transmission of PDUs in a given room to a given remote homeserver.'; + +COMMENT ON COLUMN destination_rooms.destination IS 'server name of remote homeserver in question'; + +COMMENT ON COLUMN destination_rooms.room_id IS 'room ID in question'; + +COMMENT ON COLUMN destination_rooms.stream_ordering IS +$$`stream_ordering` of the most recent PDU in this room that needs to be sent (by us) to this homeserver. +This can only be pointing to our own PDU because we are only responsible for sending our own PDUs.$$;
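Once this migration has run, the comments live in the PostgreSQL catalog and can be read back with psql's `\d+ destinations` or programmatically. A sketch using psycopg2 (the DSN is a placeholder):

import psycopg2

conn = psycopg2.connect("dbname=synapse user=synapse_user")  # placeholder DSN
with conn.cursor() as cur:
    # Table-level comment.
    cur.execute("SELECT obj_description('destinations'::regclass, 'pg_class')")
    print(cur.fetchone()[0])

    # Column-level comment, looked up via pg_attribute.
    cur.execute(
        """
        SELECT col_description(a.attrelid, a.attnum)
        FROM pg_attribute a
        WHERE a.attrelid = 'destinations'::regclass
          AND a.attname = 'retry_interval'
        """
    )
    print(cur.fetchone()[0])
conn.close()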