author     H. Shay <hillerys@element.io>  2023-03-06 12:48:59 -0800
committer  H. Shay <hillerys@element.io>  2023-03-06 12:48:59 -0800
commit     7fc487421f19d8841c36481701655bc31bfcd79f
tree       e5c4e634fa957a561675156c868c561c36a71d66 /synapse
parent     add clearer return values
parent     Pass the requester during event serialization. (#15174)
download   synapse-7fc487421f19d8841c36481701655bc31bfcd79f.tar.xz
Merge branch 'develop' into shay/rework_module
Diffstat (limited to 'synapse')
-rw-r--r--  synapse/__init__.py  5
-rwxr-xr-x  synapse/_scripts/move_remote_media_to_new_store.py  2
-rw-r--r--  synapse/_scripts/register_new_matrix_user.py  2
-rwxr-xr-x  synapse/_scripts/synapse_port_db.py  86
-rwxr-xr-x  synapse/_scripts/synctl.py  1
-rw-r--r--  synapse/app/_base.py  2
-rw-r--r--  synapse/app/admin_cmd.py  39
-rw-r--r--  synapse/app/complement_fork_starter.py  2
-rw-r--r--  synapse/app/generic_worker.py  1
-rw-r--r--  synapse/app/homeserver.py  1
-rw-r--r--  synapse/config/consent.py  1
-rw-r--r--  synapse/config/database.py  1
-rw-r--r--  synapse/config/experimental.py  25
-rw-r--r--  synapse/config/homeserver.py  1
-rw-r--r--  synapse/config/ratelimiting.py  14
-rw-r--r--  synapse/config/repository.py  13
-rw-r--r--  synapse/config/server.py  2
-rw-r--r--  synapse/config/tls.py  1
-rw-r--r--  synapse/crypto/keyring.py  2
-rw-r--r--  synapse/events/snapshot.py  49
-rw-r--r--  synapse/events/spamcheck.py  4
-rw-r--r--  synapse/events/third_party_rules.py  65
-rw-r--r--  synapse/events/utils.py  92
-rw-r--r--  synapse/federation/send_queue.py  4
-rw-r--r--  synapse/handlers/account_data.py  7
-rw-r--r--  synapse/handlers/admin.py  87
-rw-r--r--  synapse/handlers/appservice.py  2
-rw-r--r--  synapse/handlers/auth.py  53
-rw-r--r--  synapse/handlers/deactivate_account.py  20
-rw-r--r--  synapse/handlers/directory.py  8
-rw-r--r--  synapse/handlers/e2e_keys.py  14
-rw-r--r--  synapse/handlers/e2e_room_keys.py  1
-rw-r--r--  synapse/handlers/event_auth.py  1
-rw-r--r--  synapse/handlers/events.py  20
-rw-r--r--  synapse/handlers/federation.py  16
-rw-r--r--  synapse/handlers/initial_sync.py  52
-rw-r--r--  synapse/handlers/message.py  30
-rw-r--r--  synapse/handlers/pagination.py  4
-rw-r--r--  synapse/handlers/presence.py  2
-rw-r--r--  synapse/handlers/register.py  4
-rw-r--r--  synapse/handlers/relations.py  88
-rw-r--r--  synapse/handlers/room.py  84
-rw-r--r--  synapse/handlers/room_batch.py  10
-rw-r--r--  synapse/handlers/room_member.py  11
-rw-r--r--  synapse/handlers/room_member_worker.py  4
-rw-r--r--  synapse/handlers/search.py  43
-rw-r--r--  synapse/handlers/sync.py  1
-rw-r--r--  synapse/handlers/ui_auth/checkers.py  18
-rw-r--r--  synapse/http/client.py  5
-rw-r--r--  synapse/http/matrixfederationclient.py  18
-rw-r--r--  synapse/logging/opentracing.py  25
-rw-r--r--  synapse/media/_base.py  479
-rw-r--r--  synapse/media/filepath.py (renamed from synapse/rest/media/v1/filepath.py)  0
-rw-r--r--  synapse/media/media_repository.py (renamed from synapse/rest/media/v1/media_repository.py)  94
-rw-r--r--  synapse/media/media_storage.py  374
-rw-r--r--  synapse/media/oembed.py (renamed from synapse/rest/media/v1/oembed.py)  2
-rw-r--r--  synapse/media/preview_html.py (renamed from synapse/rest/media/v1/preview_html.py)  0
-rw-r--r--  synapse/media/storage_provider.py  181
-rw-r--r--  synapse/media/thumbnailer.py (renamed from synapse/rest/media/v1/thumbnailer.py)  1
-rw-r--r--  synapse/metrics/__init__.py  1
-rw-r--r--  synapse/metrics/_gc.py  1
-rw-r--r--  synapse/module_api/__init__.py  16
-rw-r--r--  synapse/push/bulk_push_rule_evaluator.py  36
-rw-r--r--  synapse/push/clientformat.py  11
-rw-r--r--  synapse/replication/http/account_data.py  1
-rw-r--r--  synapse/replication/http/devices.py  1
-rw-r--r--  synapse/replication/http/membership.py  15
-rw-r--r--  synapse/replication/tcp/client.py  12
-rw-r--r--  synapse/replication/tcp/redis.py  1
-rw-r--r--  synapse/replication/tcp/resource.py  18
-rw-r--r--  synapse/replication/tcp/streams/events.py  1
-rw-r--r--  synapse/rest/__init__.py  5
-rw-r--r--  synapse/rest/admin/event_reports.py  41
-rw-r--r--  synapse/rest/admin/rooms.py  4
-rw-r--r--  synapse/rest/admin/users.py  19
-rw-r--r--  synapse/rest/client/account.py  9
-rw-r--r--  synapse/rest/client/auth.py  1
-rw-r--r--  synapse/rest/client/devices.py  2
-rw-r--r--  synapse/rest/client/events.py  16
-rw-r--r--  synapse/rest/client/filter.py  1
-rw-r--r--  synapse/rest/client/keys.py  32
-rw-r--r--  synapse/rest/client/knock.py  17
-rw-r--r--  synapse/rest/client/notifications.py  12
-rw-r--r--  synapse/rest/client/register.py  18
-rw-r--r--  synapse/rest/client/room.py  24
-rw-r--r--  synapse/rest/client/room_batch.py  16
-rw-r--r--  synapse/rest/client/sync.py  35
-rw-r--r--  synapse/rest/media/config_resource.py (renamed from synapse/rest/media/v1/config_resource.py)  0
-rw-r--r--  synapse/rest/media/download_resource.py (renamed from synapse/rest/media/v1/download_resource.py)  5
-rw-r--r--  synapse/rest/media/media_repository_resource.py  93
-rw-r--r--  synapse/rest/media/preview_url_resource.py (renamed from synapse/rest/media/v1/preview_url_resource.py)  45
-rw-r--r--  synapse/rest/media/thumbnail_resource.py (renamed from synapse/rest/media/v1/thumbnail_resource.py)  10
-rw-r--r--  synapse/rest/media/upload_resource.py (renamed from synapse/rest/media/v1/upload_resource.py)  4
-rw-r--r--  synapse/rest/media/v1/_base.py  468
-rw-r--r--  synapse/rest/media/v1/media_storage.py  365
-rw-r--r--  synapse/rest/media/v1/storage_provider.py  172
-rw-r--r--  synapse/server.py  8
-rw-r--r--  synapse/server_notices/server_notices_manager.py  3
-rw-r--r--  synapse/storage/databases/main/__init__.py  4
-rw-r--r--  synapse/storage/databases/main/account_data.py  76
-rw-r--r--  synapse/storage/databases/main/cache.py  3
-rw-r--r--  synapse/storage/databases/main/deviceinbox.py  5
-rw-r--r--  synapse/storage/databases/main/devices.py  11
-rw-r--r--  synapse/storage/databases/main/e2e_room_keys.py  2
-rw-r--r--  synapse/storage/databases/main/end_to_end_keys.py  8
-rw-r--r--  synapse/storage/databases/main/event_federation.py  1
-rw-r--r--  synapse/storage/databases/main/events.py  5
-rw-r--r--  synapse/storage/databases/main/events_bg_updates.py  7
-rw-r--r--  synapse/storage/databases/main/events_worker.py  7
-rw-r--r--  synapse/storage/databases/main/filtering.py  25
-rw-r--r--  synapse/storage/databases/main/media_repository.py  1
-rw-r--r--  synapse/storage/databases/main/push_rule.py  3
-rw-r--r--  synapse/storage/databases/main/pusher.py  6
-rw-r--r--  synapse/storage/databases/main/receipts.py  7
-rw-r--r--  synapse/storage/databases/main/registration.py  13
-rw-r--r--  synapse/storage/databases/main/relations.py  137
-rw-r--r--  synapse/storage/databases/main/room.py  391
-rw-r--r--  synapse/storage/databases/main/search.py  2
-rw-r--r--  synapse/storage/databases/main/state.py  1
-rw-r--r--  synapse/storage/databases/main/stats.py  2
-rw-r--r--  synapse/storage/databases/main/stream.py  1
-rw-r--r--  synapse/storage/databases/main/transactions.py  1
-rw-r--r--  synapse/storage/databases/main/user_directory.py  77
-rw-r--r--  synapse/storage/databases/state/bg_updates.py  1
-rw-r--r--  synapse/storage/databases/state/store.py  126
-rw-r--r--  synapse/storage/engines/__init__.py  4
-rw-r--r--  synapse/storage/prepare_database.py  4
-rw-r--r--  synapse/storage/types.py  74
-rw-r--r--  synapse/storage/util/id_generators.py  62
-rw-r--r--  synapse/storage/util/sequence.py  2
-rw-r--r--  synapse/streams/__init__.py  7
-rw-r--r--  synapse/types/state.py  2
-rw-r--r--  synapse/util/caches/__init__.py  1
-rw-r--r--  synapse/util/check_dependencies.py  2
-rw-r--r--  synapse/util/patch_inline_callbacks.py  1
135 files changed, 2656 insertions, 2139 deletions
diff --git a/synapse/__init__.py b/synapse/__init__.py

index fbfd506a43..a203ed533a 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py
@@ -1,5 +1,6 @@ # Copyright 2014-2016 OpenMarket Ltd -# Copyright 2018-9 New Vector Ltd +# Copyright 2018-2019 New Vector Ltd +# Copyright 2023 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" This is a reference implementation of a Matrix homeserver. +""" This is an implementation of a Matrix homeserver. """ import json diff --git a/synapse/_scripts/move_remote_media_to_new_store.py b/synapse/_scripts/move_remote_media_to_new_store.py
index 819afaaca6..0dd36bee20 100755 --- a/synapse/_scripts/move_remote_media_to_new_store.py +++ b/synapse/_scripts/move_remote_media_to_new_store.py
@@ -37,7 +37,7 @@ import os import shutil import sys -from synapse.rest.media.v1.filepath import MediaFilePaths +from synapse.media.filepath import MediaFilePaths logger = logging.getLogger() diff --git a/synapse/_scripts/register_new_matrix_user.py b/synapse/_scripts/register_new_matrix_user.py
index 2b74a40166..19ca399d44 100644 --- a/synapse/_scripts/register_new_matrix_user.py +++ b/synapse/_scripts/register_new_matrix_user.py
@@ -47,7 +47,6 @@ def request_registration( _print: Callable[[str], None] = print, exit: Callable[[int], None] = sys.exit, ) -> None: - url = "%s/_synapse/admin/v1/register" % (server_location.rstrip("/"),) # Get the nonce @@ -154,7 +153,6 @@ def register_new_user( def main() -> None: - logging.captureWarnings(True) parser = argparse.ArgumentParser( diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py
index 5e137dbbf7..2c9cbf8b27 100755 --- a/synapse/_scripts/synapse_port_db.py +++ b/synapse/_scripts/synapse_port_db.py
@@ -94,61 +94,80 @@ reactor = cast(ISynapseReactor, reactor_) logger = logging.getLogger("synapse_port_db") +# SQLite doesn't have a dedicated boolean type (it stores True/False as 1/0). This means +# portdb will read sqlite bools as integers, then try to insert them into postgres +# boolean columns---which fails. Lacking some Python-parseable metaschema, we must +# specify which integer columns should be inserted as booleans into postgres. BOOLEAN_COLUMNS = { - "events": ["processed", "outlier", "contains_url"], - "rooms": ["is_public", "has_auth_chain_index"], + "access_tokens": ["used"], + "account_validity": ["email_sent"], + "device_lists_changes_in_room": ["converted_to_destinations"], + "device_lists_outbound_pokes": ["sent"], + "devices": ["hidden"], + "e2e_fallback_keys_json": ["used"], + "e2e_room_keys": ["is_verified"], "event_edges": ["is_state"], + "events": ["processed", "outlier", "contains_url"], + "local_media_repository": ["safe_from_quarantine"], "presence_list": ["accepted"], "presence_stream": ["currently_active"], "public_room_list_stream": ["visibility"], - "devices": ["hidden"], - "device_lists_outbound_pokes": ["sent"], - "users_who_share_rooms": ["share_private"], - "e2e_room_keys": ["is_verified"], - "account_validity": ["email_sent"], + "pushers": ["enabled"], "redactions": ["have_censored"], "room_stats_state": ["is_federatable"], - "local_media_repository": ["safe_from_quarantine"], + "rooms": ["is_public", "has_auth_chain_index"], "users": ["shadow_banned", "approved"], - "e2e_fallback_keys_json": ["used"], - "access_tokens": ["used"], - "device_lists_changes_in_room": ["converted_to_destinations"], - "pushers": ["enabled"], + "un_partial_stated_event_stream": ["rejection_status_changed"], + "users_who_share_rooms": ["share_private"], } +# These tables are never deleted from in normal operation [*], so we can resume porting +# over rows from a previous attempt rather than starting from scratch. +# +# [*]: We do delete from many of these tables when purging a room, and +# presumably when purging old events. So we might e.g. +# +# 1. Run portdb and port half of some table. +# 2. Stop portdb. +# 3. Purge something, deleting some of the rows we've ported over. +# 4. Restart portdb. The rows deleted from sqlite are still present in postgres. +# +# But this isn't the end of the world: we should be able to repeat the purge +# on the postgres DB when porting completes. 
APPEND_ONLY_TABLES = [ + "cache_invalidation_stream_by_instance", + "event_auth", + "event_edges", + "event_json", "event_reference_hashes", + "event_search", + "event_to_state_groups", "events", - "event_json", - "state_events", - "room_memberships", - "topics", - "room_names", - "rooms", + "ex_outlier_stream", "local_media_repository", "local_media_repository_thumbnails", + "presence_stream", + "public_room_list_stream", + "push_rules_stream", + "received_transactions", + "redactions", + "rejections", "remote_media_cache", "remote_media_cache_thumbnails", - "redactions", - "event_edges", - "event_auth", - "received_transactions", + "room_memberships", + "room_names", + "rooms", "sent_transactions", - "transaction_id_to_pdu", - "users", + "state_events", + "state_group_edges", "state_groups", "state_groups_state", - "event_to_state_groups", - "rejections", - "event_search", - "presence_stream", - "push_rules_stream", - "ex_outlier_stream", - "cache_invalidation_stream_by_instance", - "public_room_list_stream", - "state_group_edges", "stream_ordering_to_exterm", + "topics", + "transaction_id_to_pdu", + "un_partial_stated_event_stream", + "users", ] @@ -1186,7 +1205,6 @@ class CursesProgress(Progress): if self.finished: status = "Time spent: %s (Done!)" % (duration_str,) else: - if self.total_processed > 0: left = float(self.total_remaining) / self.total_processed diff --git a/synapse/_scripts/synctl.py b/synapse/_scripts/synctl.py
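A note on the `BOOLEAN_COLUMNS` map in the `synapse_port_db` hunk above: because SQLite stores booleans as 0/1 integers and Postgres boolean columns reject bare integers, the porter has to coerce known-boolean columns in flight. Below is a minimal sketch of that coercion; the helper name and row format are hypothetical, not Synapse's actual porting code.

```python
from typing import Any, List, Tuple

# Illustration only: SQLite hands back 0/1 integers for boolean columns,
# so columns listed in BOOLEAN_COLUMNS must become real bools before they
# are inserted into Postgres. (Subset of the real map shown here.)
BOOLEAN_COLUMNS = {"events": ["processed", "outlier", "contains_url"]}

def coerce_booleans(
    table: str, headers: List[str], row: Tuple[Any, ...]
) -> Tuple[Any, ...]:
    """Convert integer values to bools for columns known to be boolean."""
    bool_cols = set(BOOLEAN_COLUMNS.get(table, ()))
    return tuple(
        bool(value) if header in bool_cols and value is not None else value
        for header, value in zip(headers, row)
    )

# coerce_booleans("events", ["event_id", "processed"], ("$abc", 1))
# -> ("$abc", True)
```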
index b4c96ad7f3..077b90935e 100755 --- a/synapse/_scripts/synctl.py +++ b/synapse/_scripts/synctl.py
@@ -167,7 +167,6 @@ Worker = collections.namedtuple( def main() -> None: - parser = argparse.ArgumentParser() parser.add_argument( diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index a5aa2185a2..28062dd69d 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py
@@ -213,7 +213,7 @@ def handle_startup_exception(e: Exception) -> NoReturn: def redirect_stdio_to_logs() -> None: streams = [("stdout", LogLevel.info), ("stderr", LogLevel.error)] - for (stream, level) in streams: + for stream, level in streams: oldStream = getattr(sys, stream) loggingFile = LoggingFile( logger=twisted.logger.Logger(namespace=stream), diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py
index fe7afb9475..b05fe2c589 100644 --- a/synapse/app/admin_cmd.py +++ b/synapse/app/admin_cmd.py
@@ -17,7 +17,7 @@ import logging import os import sys import tempfile -from typing import List, Optional +from typing import List, Mapping, Optional from twisted.internet import defer, task @@ -44,6 +44,7 @@ from synapse.storage.databases.main.event_push_actions import ( ) from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.filtering import FilteringWorkerStore +from synapse.storage.databases.main.media_repository import MediaRepositoryStore from synapse.storage.databases.main.profile import ProfileWorkerStore from synapse.storage.databases.main.push_rule import PushRulesWorkerStore from synapse.storage.databases.main.receipts import ReceiptsWorkerStore @@ -86,6 +87,7 @@ class AdminCmdSlavedStore( RegistrationWorkerStore, RoomWorkerStore, ProfileWorkerStore, + MediaRepositoryStore, ): def __init__( self, @@ -149,7 +151,7 @@ class FileExfiltrationWriter(ExfiltrationWriter): with open(events_file, "a") as f: for event in events: - print(json.dumps(event.get_pdu_json()), file=f) + json.dump(event.get_pdu_json(), fp=f) def write_state( self, room_id: str, event_id: str, state: StateMap[EventBase] @@ -162,7 +164,7 @@ class FileExfiltrationWriter(ExfiltrationWriter): with open(event_file, "a") as f: for event in state.values(): - print(json.dumps(event.get_pdu_json()), file=f) + json.dump(event.get_pdu_json(), fp=f) def write_invite( self, room_id: str, event: EventBase, state: StateMap[EventBase] @@ -178,7 +180,7 @@ class FileExfiltrationWriter(ExfiltrationWriter): with open(invite_state, "a") as f: for event in state.values(): - print(json.dumps(event), file=f) + json.dump(event, fp=f) def write_knock( self, room_id: str, event: EventBase, state: StateMap[EventBase] @@ -194,7 +196,7 @@ class FileExfiltrationWriter(ExfiltrationWriter): with open(knock_state, "a") as f: for event in state.values(): - print(json.dumps(event), file=f) + json.dump(event, fp=f) def write_profile(self, profile: JsonDict) -> None: user_directory = os.path.join(self.base_directory, "user_data") @@ -202,7 +204,7 @@ class FileExfiltrationWriter(ExfiltrationWriter): profile_file = os.path.join(user_directory, "profile") with open(profile_file, "a") as f: - print(json.dumps(profile), file=f) + json.dump(profile, fp=f) def write_devices(self, devices: List[JsonDict]) -> None: user_directory = os.path.join(self.base_directory, "user_data") @@ -211,7 +213,7 @@ class FileExfiltrationWriter(ExfiltrationWriter): for device in devices: with open(device_file, "a") as f: - print(json.dumps(device), file=f) + json.dump(device, fp=f) def write_connections(self, connections: List[JsonDict]) -> None: user_directory = os.path.join(self.base_directory, "user_data") @@ -220,7 +222,28 @@ class FileExfiltrationWriter(ExfiltrationWriter): for connection in connections: with open(connection_file, "a") as f: - print(json.dumps(connection), file=f) + json.dump(connection, fp=f) + + def write_account_data( + self, file_name: str, account_data: Mapping[str, JsonDict] + ) -> None: + account_data_directory = os.path.join( + self.base_directory, "user_data", "account_data" + ) + os.makedirs(account_data_directory, exist_ok=True) + + account_data_file = os.path.join(account_data_directory, file_name) + + with open(account_data_file, "a") as f: + json.dump(account_data, fp=f) + + def write_media_id(self, media_id: str, media_metadata: JsonDict) -> None: + file_directory = os.path.join(self.base_directory, "media_ids") + os.makedirs(file_directory, exist_ok=True) + media_id_file = 
os.path.join(file_directory, media_id) + + with open(media_id_file, "w") as f: + json.dump(media_metadata, fp=f) def finished(self) -> str: return self.base_directory diff --git a/synapse/app/complement_fork_starter.py b/synapse/app/complement_fork_starter.py
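One observable difference in the `FileExfiltrationWriter` hunk above: `print(json.dumps(obj), file=f)` appended a newline after each object (one JSON document per line), whereas `json.dump(obj, fp=f)` writes the bare JSON text, so successive appends to the same file now run together. A standalone illustration (not Synapse code):

```python
import io
import json

events = [{"event_id": "$a"}, {"event_id": "$b"}]

old = io.StringIO()
for event in events:
    print(json.dumps(event), file=old)  # one JSON object per line

new = io.StringIO()
for event in events:
    json.dump(event, fp=new)            # objects written back to back

print(repr(old.getvalue()))  # '{"event_id": "$a"}\n{"event_id": "$b"}\n'
print(repr(new.getvalue()))  # '{"event_id": "$a"}{"event_id": "$b"}'
```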
index 920538f44d..c8dc3f9d76 100644 --- a/synapse/app/complement_fork_starter.py +++ b/synapse/app/complement_fork_starter.py
@@ -219,7 +219,7 @@ def main() -> None: # memory space and don't need to repeat the work of loading the code! # Instead of using fork() directly, we use the multiprocessing library, # which uses fork() on Unix platforms. - for (func, worker_args) in zip(worker_functions, args_by_worker): + for func, worker_args in zip(worker_functions, args_by_worker): process = multiprocessing.Process( target=_worker_entrypoint, args=(func, proxy_reactor, worker_args) ) diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index 946f3a3807..0dec24369a 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py
@@ -157,7 +157,6 @@ class GenericWorkerServer(HomeServer): DATASTORE_CLASS = GenericWorkerSlavedStore # type: ignore def _listen_http(self, listener_config: ListenerConfig) -> None: - assert listener_config.http_options is not None # We always include a health resource. diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 6176a70eb2..b8830b1a9c 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py
@@ -321,7 +321,6 @@ def setup(config_options: List[str]) -> SynapseHomeServer: and not config.registration.registrations_require_3pid and not config.registration.registration_requires_token ): - raise ConfigError( "You have enabled open registration without any verification. This is a known vector for " "spam and abuse. If you would like to allow public registration, please consider adding email, " diff --git a/synapse/config/consent.py b/synapse/config/consent.py
index be74609dc4..5bfd0cbb71 100644 --- a/synapse/config/consent.py +++ b/synapse/config/consent.py
@@ -22,7 +22,6 @@ from ._base import Config class ConsentConfig(Config): - section = "consent" def __init__(self, *args: Any): diff --git a/synapse/config/database.py b/synapse/config/database.py
index 928fec8dfe..596d8769fe 100644 --- a/synapse/config/database.py +++ b/synapse/config/database.py
@@ -154,7 +154,6 @@ class DatabaseConfig(Config): logger.warning(NON_SQLITE_DATABASE_PATH_WARNING) def set_databasepath(self, database_path: str) -> None: - if database_path != ":memory:": database_path = self.abspath(database_path) diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index 1d294f8798..489f2601ac 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py
@@ -166,22 +166,20 @@ class ExperimentalConfig(Config): # MSC3391: Removing account data. self.msc3391_enabled = experimental.get("msc3391_enabled", False) - # MSC3925: do not replace events with their edits - self.msc3925_inhibit_edit = experimental.get("msc3925_inhibit_edit", False) - - # MSC3758: exact_event_match push rule condition - self.msc3758_exact_event_match = experimental.get( - "msc3758_exact_event_match", False + # MSC3873: Disambiguate event_match keys. + self.msc3873_escape_event_match_key = experimental.get( + "msc3873_escape_event_match_key", False ) - # MSC3873: Disambiguate event_match keys. - self.msc3783_escape_event_match_key = experimental.get( - "msc3783_escape_event_match_key", False + # MSC3966: exact_event_property_contains push rule condition. + self.msc3966_exact_event_property_contains = experimental.get( + "msc3966_exact_event_property_contains", False ) - # MSC3952: Intentional mentions - self.msc3952_intentional_mentions = experimental.get( - "msc3952_intentional_mentions", False + # MSC3952: Intentional mentions, this depends on MSC3966. + self.msc3952_intentional_mentions = ( + experimental.get("msc3952_intentional_mentions", False) + and self.msc3966_exact_event_property_contains ) # MSC3959: Do not generate notifications for edits. @@ -193,3 +191,6 @@ class ExperimentalConfig(Config): self.msc3966_exact_event_property_contains = experimental.get( "msc3966_exact_event_property_contains", False ) + + # MSC3967: Do not require UIA when first uploading cross signing keys + self.msc3967_enabled = experimental.get("msc3967_enabled", False) diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py
index 4d2b298a70..c205a78039 100644 --- a/synapse/config/homeserver.py +++ b/synapse/config/homeserver.py
@@ -56,7 +56,6 @@ from .workers import WorkerConfig class HomeServerConfig(RootConfig): - config_classes = [ ModulesConfig, ServerConfig, diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py
index 5c13fe428a..a5514e70a2 100644 --- a/synapse/config/ratelimiting.py +++ b/synapse/config/ratelimiting.py
@@ -46,7 +46,6 @@ class RatelimitConfig(Config): section = "ratelimiting" def read_config(self, config: JsonDict, **kwargs: Any) -> None: - # Load the new-style messages config if it exists. Otherwise fall back # to the old method. if "rc_message" in config: @@ -87,9 +86,18 @@ class RatelimitConfig(Config): defaults={"per_second": 0.1, "burst_count": 5}, ) + # It is reasonable to login with a bunch of devices at once (i.e. when + # setting up an account), but it is *not* valid to continually be + # logging into new devices. rc_login_config = config.get("rc_login", {}) - self.rc_login_address = RatelimitSettings(rc_login_config.get("address", {})) - self.rc_login_account = RatelimitSettings(rc_login_config.get("account", {})) + self.rc_login_address = RatelimitSettings( + rc_login_config.get("address", {}), + defaults={"per_second": 0.003, "burst_count": 5}, + ) + self.rc_login_account = RatelimitSettings( + rc_login_config.get("account", {}), + defaults={"per_second": 0.003, "burst_count": 5}, + ) self.rc_login_failed_attempts = RatelimitSettings( rc_login_config.get("failed_attempts", {}) ) diff --git a/synapse/config/repository.py b/synapse/config/repository.py
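The new `rc_login` defaults above (`per_second: 0.003`, `burst_count: 5`) amount to an initial burst of five login attempts, then roughly one new attempt every 1 / 0.003 ≈ 333 seconds. A minimal sketch of that leaky-bucket arithmetic — an illustration of the semantics, not Synapse's actual `Ratelimiter` class:

```python
import time

class SimpleRatelimiter:
    """Allow `burst_count` immediate actions, refilling at `per_second`."""

    def __init__(self, per_second: float = 0.003, burst_count: int = 5):
        self.per_second = per_second
        self.burst_count = burst_count
        self.tokens = float(burst_count)
        self.last = time.monotonic()

    def can_do_action(self) -> bool:
        now = time.monotonic()
        # Refill tokens for the elapsed time, capped at the burst size.
        self.tokens = min(
            self.burst_count, self.tokens + (now - self.last) * self.per_second
        )
        self.last = now
        if self.tokens >= 1.0:
            self.tokens -= 1.0
            return True
        return False

limiter = SimpleRatelimiter()
# The first five calls succeed immediately; afterwards a new token only
# accrues every ~333 seconds (about five and a half minutes).
print([limiter.can_do_action() for _ in range(6)])  # [True]*5 + [False]
```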
index e4759711ed..ecb3edbe3a 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py
@@ -116,7 +116,6 @@ class ContentRepositoryConfig(Config): section = "media" def read_config(self, config: JsonDict, **kwargs: Any) -> None: - # Only enable the media repo if either the media repo is enabled or the # current worker app is the media repo. if ( @@ -179,11 +178,13 @@ class ContentRepositoryConfig(Config): for i, provider_config in enumerate(storage_providers): # We special case the module "file_system" so as not to need to # expose FileStorageProviderBackend - if provider_config["module"] == "file_system": - provider_config["module"] = ( - "synapse.rest.media.v1.storage_provider" - ".FileStorageProviderBackend" - ) + if ( + provider_config["module"] == "file_system" + or provider_config["module"] == "synapse.rest.media.v1.storage_provider" + ): + provider_config[ + "module" + ] = "synapse.media.storage_provider.FileStorageProviderBackend" provider_class, parsed_config = load_module( provider_config, ("media_storage_providers", "<item %i>" % i) diff --git a/synapse/config/server.py b/synapse/config/server.py
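The `repository.py` change above widens the special case so that both the `file_system` shorthand and the old pre-move module path resolve to the relocated `synapse.media.storage_provider` module, keeping configs written before the media-repo move working. Condensed as a hypothetical helper (not the literal Synapse code, which mutates `provider_config` in place):

```python
_FILE_SYSTEM_ALIASES = {
    "file_system",
    "synapse.rest.media.v1.storage_provider",
}

def normalise_provider_module(module: str) -> str:
    """Map legacy storage-provider module names onto the new location."""
    if module in _FILE_SYSTEM_ALIASES:
        return "synapse.media.storage_provider.FileStorageProviderBackend"
    return module
```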
index ecdaa2d9dd..0e46b849cf 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py
@@ -177,6 +177,7 @@ KNOWN_RESOURCES = { "client", "consent", "federation", + "health", "keys", "media", "metrics", @@ -734,7 +735,6 @@ class ServerConfig(Config): listeners: Optional[List[dict]], **kwargs: Any, ) -> str: - _, bind_port = parse_and_validate_server_name(server_name) if bind_port is not None: unsecure_port = bind_port - 400 diff --git a/synapse/config/tls.py b/synapse/config/tls.py
index 336fe3e0da..318270ebb8 100644 --- a/synapse/config/tls.py +++ b/synapse/config/tls.py
@@ -30,7 +30,6 @@ class TlsConfig(Config): section = "tls" def read_config(self, config: JsonDict, **kwargs: Any) -> None: - self.tls_certificate_file = self.abspath(config.get("tls_certificate_path")) self.tls_private_key_file = self.abspath(config.get("tls_private_key_path")) diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 86cd4af9bd..d710607c63 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py
@@ -399,7 +399,7 @@ class Keyring: # We now convert the returned list of results into a map from server # name to key ID to FetchKeyResult, to return. to_return: Dict[str, Dict[str, FetchKeyResult]] = {} - for (request, results) in zip(deduped_requests, results_per_request): + for request, results in zip(deduped_requests, results_per_request): to_return_by_server = to_return.setdefault(request.server_name, {}) for key_id, key_result in results.items(): existing = to_return_by_server.get(key_id) diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index e0d82ad81c..a91a5d1e3c 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py
@@ -23,6 +23,7 @@ from synapse.types import JsonDict, StateMap if TYPE_CHECKING: from synapse.storage.controllers import StorageControllers + from synapse.storage.databases import StateGroupDataStore from synapse.storage.databases.main import DataStore from synapse.types.state import StateFilter @@ -348,6 +349,54 @@ class UnpersistedEventContext(UnpersistedEventContextBase): partial_state: bool state_map_before_event: Optional[StateMap[str]] = None + @classmethod + async def batch_persist_unpersisted_contexts( + cls, + events_and_context: List[Tuple[EventBase, "UnpersistedEventContextBase"]], + room_id: str, + last_known_state_group: int, + datastore: "StateGroupDataStore", + ) -> List[Tuple[EventBase, EventContext]]: + """ + Takes a list of events and their associated unpersisted contexts and persists + the unpersisted contexts, returning a list of events and persisted contexts. + Note that all the events must be in a linear chain (ie a <- b <- c). + + Args: + events_and_context: A list of events and their unpersisted contexts + room_id: the room_id for the events + last_known_state_group: the last persisted state group + datastore: a state datastore + """ + amended_events_and_context = await datastore.store_state_deltas_for_batched( + events_and_context, room_id, last_known_state_group + ) + + events_and_persisted_context = [] + for event, unpersisted_context in amended_events_and_context: + if event.is_state(): + context = EventContext( + storage=unpersisted_context._storage, + state_group=unpersisted_context.state_group_after_event, + state_group_before_event=unpersisted_context.state_group_before_event, + state_delta_due_to_event=unpersisted_context.state_delta_due_to_event, + partial_state=unpersisted_context.partial_state, + prev_group=unpersisted_context.state_group_before_event, + delta_ids=unpersisted_context.state_delta_due_to_event, + ) + else: + context = EventContext( + storage=unpersisted_context._storage, + state_group=unpersisted_context.state_group_after_event, + state_group_before_event=unpersisted_context.state_group_before_event, + state_delta_due_to_event=unpersisted_context.state_delta_due_to_event, + partial_state=unpersisted_context.partial_state, + prev_group=unpersisted_context.prev_group_for_state_group_before_event, + delta_ids=unpersisted_context.delta_ids_to_state_group_before_event, + ) + events_and_persisted_context.append((event, context)) + return events_and_persisted_context + async def get_prev_state_ids( self, state_filter: Optional["StateFilter"] = None ) -> StateMap[str]: diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py
index 623a2c71ea..765c15bb51 100644 --- a/synapse/events/spamcheck.py +++ b/synapse/events/spamcheck.py
@@ -33,8 +33,8 @@ from typing_extensions import Literal import synapse from synapse.api.errors import Codes from synapse.logging.opentracing import trace -from synapse.rest.media.v1._base import FileInfo -from synapse.rest.media.v1.media_storage import ReadableFileWrapper +from synapse.media._base import FileInfo +from synapse.media.media_storage import ReadableFileWrapper from synapse.spam_checker_api import RegistrationBehaviour from synapse.types import JsonDict, RoomAlias, UserProfile from synapse.util.async_helpers import delay_cancellation, maybe_awaitable diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py
index e7532a568a..1b7c6de974 100644 --- a/synapse/events/third_party_rules.py +++ b/synapse/events/third_party_rules.py
@@ -49,6 +49,8 @@ CHECK_CAN_DEACTIVATE_USER_CALLBACK = Callable[[str, bool], Awaitable[bool]] ON_PROFILE_UPDATE_CALLBACK = Callable[[str, ProfileInfo, bool, bool], Awaitable] ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK = Callable[[str, bool, bool], Awaitable] ON_THREEPID_BIND_CALLBACK = Callable[[str, str, str], Awaitable] +ON_ADD_USER_THIRD_PARTY_IDENTIFIER_CALLBACK = Callable[[str, str, str], Awaitable] +ON_REMOVE_USER_THIRD_PARTY_IDENTIFIER_CALLBACK = Callable[[str, str, str], Awaitable] def load_legacy_third_party_event_rules(hs: "HomeServer") -> None: @@ -82,7 +84,6 @@ def load_legacy_third_party_event_rules(hs: "HomeServer") -> None: # correctly, we need to await its result. Therefore it doesn't make a lot of # sense to make it go through the run() wrapper. if f.__name__ == "check_event_allowed": - # We need to wrap check_event_allowed because its old form would return either # a boolean or a dict, but now we want to return the dict separately from the # boolean. @@ -104,7 +105,6 @@ def load_legacy_third_party_event_rules(hs: "HomeServer") -> None: return wrap_check_event_allowed if f.__name__ == "on_create_room": - # We need to wrap on_create_room because its old form would return a boolean # if the room creation is denied, but now we just want it to raise an # exception. @@ -181,6 +181,12 @@ class ThirdPartyEventRules: ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK ] = [] self._on_threepid_bind_callbacks: List[ON_THREEPID_BIND_CALLBACK] = [] + self._on_add_user_third_party_identifier_callbacks: List[ + ON_ADD_USER_THIRD_PARTY_IDENTIFIER_CALLBACK + ] = [] + self._on_remove_user_third_party_identifier_callbacks: List[ + ON_REMOVE_USER_THIRD_PARTY_IDENTIFIER_CALLBACK + ] = [] def register_third_party_rules_callbacks( self, @@ -200,6 +206,12 @@ class ThirdPartyEventRules: ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK ] = None, on_threepid_bind: Optional[ON_THREEPID_BIND_CALLBACK] = None, + on_add_user_third_party_identifier: Optional[ + ON_ADD_USER_THIRD_PARTY_IDENTIFIER_CALLBACK + ] = None, + on_remove_user_third_party_identifier: Optional[ + ON_REMOVE_USER_THIRD_PARTY_IDENTIFIER_CALLBACK + ] = None, ) -> None: """Register callbacks from modules for each hook.""" if check_event_allowed is not None: @@ -237,6 +249,11 @@ class ThirdPartyEventRules: if on_threepid_bind is not None: self._on_threepid_bind_callbacks.append(on_threepid_bind) + if on_add_user_third_party_identifier is not None: + self._on_add_user_third_party_identifier_callbacks.append( + on_add_user_third_party_identifier + ) + async def check_event_allowed( self, event: EventBase, @@ -552,6 +569,9 @@ class ThirdPartyEventRules: local homeserver, not when it's created on an identity server (and then kept track of so that it can be unbound on the same IS later on). + THIS MODULE CALLBACK METHOD HAS BEEN DEPRECATED. Please use the + `on_add_user_third_party_identifier` callback method instead. + Args: user_id: the user being associated with the threepid. medium: the threepid's medium. @@ -564,3 +584,44 @@ class ThirdPartyEventRules: logger.exception( "Failed to run module API callback %s: %s", callback, e ) + + async def on_add_user_third_party_identifier( + self, user_id: str, medium: str, address: str + ) -> None: + """Called when an association between a user's Matrix ID and a third-party ID + (email, phone number) has successfully been registered on the homeserver. + + Args: + user_id: The User ID included in the association. + medium: The medium of the third-party ID (email, msisdn). 
+ address: The address of the third-party ID (i.e. an email address). + """ + for callback in self._on_add_user_third_party_identifier_callbacks: + try: + await callback(user_id, medium, address) + except Exception as e: + logger.exception( + "Failed to run module API callback %s: %s", callback, e + ) + + async def on_remove_user_third_party_identifier( + self, user_id: str, medium: str, address: str + ) -> None: + """Called when an association between a user's Matrix ID and a third-party ID + (email, phone number) has been successfully removed on the homeserver. + + This is called *after* any known bindings on identity servers for this + association have been removed. + + Args: + user_id: The User ID included in the removed association. + medium: The medium of the third-party ID (email, msisdn). + address: The address of the third-party ID (i.e. an email address). + """ + for callback in self._on_remove_user_third_party_identifier_callbacks: + try: + await callback(user_id, medium, address) + except Exception as e: + logger.exception( + "Failed to run module API callback %s: %s", callback, e + ) diff --git a/synapse/events/utils.py b/synapse/events/utils.py
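A sketch of a module consuming the two new hooks added in the `third_party_rules.py` hunk above. The module class is hypothetical, and it assumes the module API forwards the same keyword arguments (consistent with the `module_api/__init__.py` entry in this commit's diffstat):

```python
from typing import Any, Dict

from synapse.module_api import ModuleApi

class ThreepidAuditModule:
    """Hypothetical module that logs 3PID additions and removals."""

    def __init__(self, config: Dict[str, Any], api: ModuleApi):
        api.register_third_party_rules_callbacks(
            on_add_user_third_party_identifier=self.on_add,
            on_remove_user_third_party_identifier=self.on_remove,
        )

    async def on_add(self, user_id: str, medium: str, address: str) -> None:
        print(f"{user_id} associated {medium} {address}")

    async def on_remove(self, user_id: str, medium: str, address: str) -> None:
        print(f"{user_id} removed {medium} {address}")
```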
index ebf8c7ed83..b9c15ffcdb 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py
@@ -38,8 +38,7 @@ from synapse.api.constants import ( ) from synapse.api.errors import Codes, SynapseError from synapse.api.room_versions import RoomVersion -from synapse.types import JsonDict -from synapse.util.frozenutils import unfreeze +from synapse.types import JsonDict, Requester from . import EventBase @@ -317,8 +316,9 @@ class SerializeEventConfig: as_client_event: bool = True # Function to convert from federation format to client format event_format: Callable[[JsonDict], JsonDict] = format_event_for_client_v1 - # ID of the user's auth token - used for namespacing of transaction IDs - token_id: Optional[int] = None + # The entity that requested the event. This is used to determine whether to include + # the transaction_id in the unsigned section of the event. + requester: Optional[Requester] = None # List of event fields to include. If empty, all fields will be returned. only_event_fields: Optional[List[str]] = None # Some events can have stripped room state stored in the `unsigned` field. @@ -368,11 +368,24 @@ def serialize_event( e.unsigned["redacted_because"], time_now_ms, config=config ) - if config.token_id is not None: - if config.token_id == getattr(e.internal_metadata, "token_id", None): - txn_id = getattr(e.internal_metadata, "txn_id", None) - if txn_id is not None: - d["unsigned"]["transaction_id"] = txn_id + # If we have a txn_id saved in the internal_metadata, we should include it in the + # unsigned section of the event if it was sent by the same session as the one + # requesting the event. + # There is a special case for guests, because they only have one access token + # without associated access_token_id, so we always include the txn_id for events + # they sent. + txn_id = getattr(e.internal_metadata, "txn_id", None) + if txn_id is not None and config.requester is not None: + event_token_id = getattr(e.internal_metadata, "token_id", None) + if config.requester.user.to_string() == e.sender and ( + ( + event_token_id is not None + and config.requester.access_token_id is not None + and event_token_id == config.requester.access_token_id + ) + or config.requester.is_guest + ): + d["unsigned"]["transaction_id"] = txn_id # invite_room_state and knock_room_state are a list of stripped room state events # that are meant to provide metadata about a room to an invitee/knocker. They are @@ -403,14 +416,6 @@ class EventClientSerializer: clients. """ - def __init__(self, inhibit_replacement_via_edits: bool = False): - """ - Args: - inhibit_replacement_via_edits: If this is set to True, then events are - never replaced by their edits. - """ - self._inhibit_replacement_via_edits = inhibit_replacement_via_edits - def serialize_event( self, event: Union[JsonDict, EventBase], @@ -418,7 +423,6 @@ class EventClientSerializer: *, config: SerializeEventConfig = _DEFAULT_SERIALIZE_EVENT_CONFIG, bundle_aggregations: Optional[Dict[str, "BundledAggregations"]] = None, - apply_edits: bool = True, ) -> JsonDict: """Serializes a single event. @@ -428,10 +432,7 @@ class EventClientSerializer: config: Event serialization config bundle_aggregations: A map from event_id to the aggregations to be bundled into the event. - apply_edits: Whether the content of the event should be modified to reflect - any replacement in `bundle_aggregations[<event_id>].replace`. - See also the `inhibit_replacement_via_edits` constructor arg: if that is - set to True, then this argument is ignored. 
+ Returns: The serialized event """ @@ -450,38 +451,10 @@ class EventClientSerializer: config, bundle_aggregations, serialized_event, - apply_edits=apply_edits, ) return serialized_event - def _apply_edit( - self, orig_event: EventBase, serialized_event: JsonDict, edit: EventBase - ) -> None: - """Replace the content, preserving existing relations of the serialized event. - - Args: - orig_event: The original event. - serialized_event: The original event, serialized. This is modified. - edit: The event which edits the above. - """ - - # Ensure we take copies of the edit content, otherwise we risk modifying - # the original event. - edit_content = edit.content.copy() - - # Unfreeze the event content if necessary, so that we may modify it below - edit_content = unfreeze(edit_content) - serialized_event["content"] = edit_content.get("m.new_content", {}) - - # Check for existing relations - relates_to = orig_event.content.get("m.relates_to") - if relates_to: - # Keep the relations, ensuring we use a dict copy of the original - serialized_event["content"]["m.relates_to"] = relates_to.copy() - else: - serialized_event["content"].pop("m.relates_to", None) - def _inject_bundled_aggregations( self, event: EventBase, @@ -489,7 +462,6 @@ class EventClientSerializer: config: SerializeEventConfig, bundled_aggregations: Dict[str, "BundledAggregations"], serialized_event: JsonDict, - apply_edits: bool, ) -> None: """Potentially injects bundled aggregations into the unsigned portion of the serialized event. @@ -504,9 +476,6 @@ class EventClientSerializer: While serializing the bundled aggregations this map may be searched again for additional events in a recursive manner. serialized_event: The serialized event which may be modified. - apply_edits: Whether the content of the event should be modified to reflect - any replacement in `aggregations.replace` (subject to the - `inhibit_replacement_via_edits` constructor arg). """ # We have already checked that aggregations exist for this event. @@ -516,22 +485,12 @@ class EventClientSerializer: # being serialized. serialized_aggregations = {} - if event_aggregations.annotations: - serialized_aggregations[ - RelationTypes.ANNOTATION - ] = event_aggregations.annotations - if event_aggregations.references: serialized_aggregations[ RelationTypes.REFERENCE ] = event_aggregations.references if event_aggregations.replace: - # If there is an edit, optionally apply it to the event. - edit = event_aggregations.replace - if apply_edits and not self._inhibit_replacement_via_edits: - self._apply_edit(event, serialized_event, edit) - # Include information about it in the relations dict. # # Matrix spec v1.5 (https://spec.matrix.org/v1.5/client-server-api/#server-side-aggregation-of-mreplace-relationships) @@ -539,10 +498,7 @@ class EventClientSerializer: # `sender` of the edit; however MSC3925 proposes extending it to the whole # of the edit, which is what we do here. serialized_aggregations[RelationTypes.REPLACE] = self.serialize_event( - edit, - time_now, - config=config, - apply_edits=False, + event_aggregations.replace, time_now, config=config ) # Include any threaded replies to this event. diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py
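The serialization rule introduced in the `events/utils.py` hunk above boils down to: echo `transaction_id` only to the event's sender, and only when the requester's access token matches the one that sent the event — except for guests, who have no access token id and so always see it for their own events. Restated as a standalone predicate mirroring the diff's logic:

```python
from typing import Optional

def should_include_txn_id(
    sender: str,
    event_token_id: Optional[int],
    requester_user_id: str,
    requester_token_id: Optional[int],
    requester_is_guest: bool,
) -> bool:
    """Only the sending session (or a guest sender) sees its transaction_id."""
    if requester_user_id != sender:
        return False
    if requester_is_guest:
        return True
    return (
        event_token_id is not None
        and requester_token_id is not None
        and event_token_id == requester_token_id
    )
```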
index d720b5fd3f..3063df7990 100644 --- a/synapse/federation/send_queue.py +++ b/synapse/federation/send_queue.py
@@ -314,7 +314,7 @@ class FederationRemoteSendQueue(AbstractFederationSender): # stream position. keyed_edus = {v: k for k, v in self.keyed_edu_changed.items()[i:j]} - for ((destination, edu_key), pos) in keyed_edus.items(): + for (destination, edu_key), pos in keyed_edus.items(): rows.append( ( pos, @@ -329,7 +329,7 @@ class FederationRemoteSendQueue(AbstractFederationSender): j = self.edus.bisect_right(to_token) + 1 edus = self.edus.items()[i:j] - for (pos, edu) in edus: + for pos, edu in edus: rows.append((pos, EduRow(edu))) # Sort rows based on pos diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py
index 797de46dbc..7e01c18c6c 100644 --- a/synapse/handlers/account_data.py +++ b/synapse/handlers/account_data.py
@@ -155,9 +155,6 @@ class AccountDataHandler: max_stream_id = await self._store.remove_account_data_for_room( user_id, room_id, account_data_type ) - if max_stream_id is None: - # The referenced account data did not exist, so no delete occurred. - return None self._notifier.on_new_event( StreamKeyType.ACCOUNT_DATA, max_stream_id, users=[user_id] @@ -230,9 +227,6 @@ class AccountDataHandler: max_stream_id = await self._store.remove_account_data_for_user( user_id, account_data_type ) - if max_stream_id is None: - # The referenced account data did not exist, so no delete occurred. - return None self._notifier.on_new_event( StreamKeyType.ACCOUNT_DATA, max_stream_id, users=[user_id] @@ -248,7 +242,6 @@ class AccountDataHandler: instance_name=random.choice(self._account_data_writers), user_id=user_id, account_data_type=account_data_type, - content={}, ) return response["max_stream_id"] diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index b03c214b14..b06f25b03c 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py
@@ -14,7 +14,7 @@ import abc import logging -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set +from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set from synapse.api.constants import Direction, Membership from synapse.events import EventBase @@ -29,7 +29,7 @@ logger = logging.getLogger(__name__) class AdminHandler: def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastores().main + self._store = hs.get_datastores().main self._device_handler = hs.get_device_handler() self._storage_controllers = hs.get_storage_controllers() self._state_storage_controller = self._storage_controllers.state @@ -38,7 +38,7 @@ class AdminHandler: async def get_whois(self, user: UserID) -> JsonDict: connections = [] - sessions = await self.store.get_user_ip_and_agents(user) + sessions = await self._store.get_user_ip_and_agents(user) for session in sessions: connections.append( { @@ -57,7 +57,7 @@ class AdminHandler: async def get_user(self, user: UserID) -> Optional[JsonDict]: """Function to get user details""" - user_info_dict = await self.store.get_user_by_id(user.to_string()) + user_info_dict = await self._store.get_user_by_id(user.to_string()) if user_info_dict is None: return None @@ -89,11 +89,11 @@ class AdminHandler: } # Add additional user metadata - profile = await self.store.get_profileinfo(user.localpart) - threepids = await self.store.user_get_threepids(user.to_string()) + profile = await self._store.get_profileinfo(user.localpart) + threepids = await self._store.user_get_threepids(user.to_string()) external_ids = [ ({"auth_provider": auth_provider, "external_id": external_id}) - for auth_provider, external_id in await self.store.get_external_ids_by_user( + for auth_provider, external_id in await self._store.get_external_ids_by_user( user.to_string() ) ] @@ -101,7 +101,7 @@ class AdminHandler: user_info_dict["avatar_url"] = profile.avatar_url user_info_dict["threepids"] = threepids user_info_dict["external_ids"] = external_ids - user_info_dict["erased"] = await self.store.is_user_erased(user.to_string()) + user_info_dict["erased"] = await self._store.is_user_erased(user.to_string()) return user_info_dict @@ -117,7 +117,7 @@ class AdminHandler: The returned value is that returned by `writer.finished()`. """ # Get all rooms the user is in or has been in - rooms = await self.store.get_rooms_for_local_user_where_membership_is( + rooms = await self._store.get_rooms_for_local_user_where_membership_is( user_id, membership_list=( Membership.JOIN, @@ -131,7 +131,7 @@ class AdminHandler: # We only try and fetch events for rooms the user has been in. If # they've been e.g. invited to a room without joining then we handle # those separately. 
- rooms_user_has_been_in = await self.store.get_rooms_user_has_been_in(user_id) + rooms_user_has_been_in = await self._store.get_rooms_user_has_been_in(user_id) for index, room in enumerate(rooms): room_id = room.room_id @@ -140,7 +140,7 @@ class AdminHandler: "[%s] Handling room %s, %d/%d", user_id, room_id, index + 1, len(rooms) ) - forgotten = await self.store.did_forget(user_id, room_id) + forgotten = await self._store.did_forget(user_id, room_id) if forgotten: logger.info("[%s] User forgot room %d, ignoring", user_id, room_id) continue @@ -152,14 +152,14 @@ class AdminHandler: if room.membership == Membership.INVITE: event_id = room.event_id - invite = await self.store.get_event(event_id, allow_none=True) + invite = await self._store.get_event(event_id, allow_none=True) if invite: invited_state = invite.unsigned["invite_room_state"] writer.write_invite(room_id, invite, invited_state) if room.membership == Membership.KNOCK: event_id = room.event_id - knock = await self.store.get_event(event_id, allow_none=True) + knock = await self._store.get_event(event_id, allow_none=True) if knock: knock_state = knock.unsigned["knock_room_state"] writer.write_knock(room_id, knock, knock_state) @@ -170,7 +170,7 @@ class AdminHandler: # were joined. We estimate that point by looking at the # stream_ordering of the last membership if it wasn't a join. if room.membership == Membership.JOIN: - stream_ordering = self.store.get_room_max_stream_ordering() + stream_ordering = self._store.get_room_max_stream_ordering() else: stream_ordering = room.stream_ordering @@ -197,7 +197,7 @@ class AdminHandler: # events that we have and then filtering, this isn't the most # efficient method perhaps but it does guarantee we get everything. while True: - events, _ = await self.store.paginate_room_events( + events, _ = await self._store.paginate_room_events( room_id, from_key, to_key, limit=100, direction=Direction.FORWARDS ) if not events: @@ -252,16 +252,49 @@ class AdminHandler: profile = await self.get_user(UserID.from_string(user_id)) if profile is not None: writer.write_profile(profile) + logger.info("[%s] Written profile", user_id) # Get all devices the user has devices = await self._device_handler.get_devices_by_user(user_id) writer.write_devices(devices) + logger.info("[%s] Written %s devices", user_id, len(devices)) # Get all connections the user has connections = await self.get_whois(UserID.from_string(user_id)) writer.write_connections( connections["devices"][""]["sessions"][0]["connections"] ) + logger.info("[%s] Written %s connections", user_id, len(connections)) + + # Get all account data the user has global and in rooms + global_data = await self._store.get_global_account_data_for_user(user_id) + by_room_data = await self._store.get_room_account_data_for_user(user_id) + writer.write_account_data("global", global_data) + for room_id in by_room_data: + writer.write_account_data(room_id, by_room_data[room_id]) + logger.info( + "[%s] Written account data for %s rooms", user_id, len(by_room_data) + ) + + # Get all media ids the user has + limit = 100 + start = 0 + while True: + media_ids, total = await self._store.get_local_media_by_user_paginate( + start, limit, user_id + ) + for media in media_ids: + writer.write_media_id(media["media_id"], media) + + logger.info( + "[%s] Written %d media_ids of %s", + user_id, + (start + len(media_ids)), + total, + ) + if (start + limit) >= total: + break + start += limit return writer.finished() @@ -341,6 +374,30 @@ class ExfiltrationWriter(metaclass=abc.ABCMeta): raise 
NotImplementedError() @abc.abstractmethod + def write_account_data( + self, file_name: str, account_data: Mapping[str, JsonDict] + ) -> None: + """Write the account data of a user. + + Args: + file_name: file name to write data + account_data: mapping of global or room account_data + """ + raise NotImplementedError() + + @abc.abstractmethod + def write_media_id(self, media_id: str, media_metadata: JsonDict) -> None: + """Write the media's metadata of a user. + Exports only the metadata, as this can be fetched from the database via + read only. In order to access the files, a connection to the correct + media repository would be required. + + Args: + media_id: ID of the media. + media_metadata: Metadata of one media file. + """ + + @abc.abstractmethod def finished(self) -> Any: """Called when all data has successfully been exported and written. diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
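To make the new `ExfiltrationWriter` surface concrete, here is a hedged sketch of a writer implementing just the two methods added above; a real subclass must also implement `write_events`, `write_state`, `write_profile`, and the other abstract methods, which are omitted for brevity:

```python
from typing import Any, Mapping

from synapse.handlers.admin import ExfiltrationWriter
from synapse.types import JsonDict

class InMemoryExfiltrationWriter(ExfiltrationWriter):
    """Hypothetical writer collecting the new export categories in memory."""

    def __init__(self) -> None:
        self.account_data: dict = {}
        self.media: dict = {}

    def write_account_data(
        self, file_name: str, account_data: Mapping[str, JsonDict]
    ) -> None:
        # file_name is "global" or a room_id, matching the export loop above.
        self.account_data[file_name] = dict(account_data)

    def write_media_id(self, media_id: str, media_metadata: JsonDict) -> None:
        # Metadata only; the files themselves stay in the media repository.
        self.media[media_id] = media_metadata

    def finished(self) -> Any:
        return {"account_data": self.account_data, "media": self.media}
```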
index 5d1d21cdc8..ec3ab968e9 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py
@@ -737,7 +737,7 @@ class ApplicationServicesHandler: ) ret = [] - for (success, result) in results: + for success, result in results: if success: ret.extend(result) diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 57a6854b1e..308e38edea 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py
@@ -201,7 +201,7 @@ class AuthHandler: for auth_checker_class in INTERACTIVE_AUTH_CHECKERS: inst = auth_checker_class(hs) if inst.is_enabled(): - self.checkers[inst.AUTH_TYPE] = inst # type: ignore + self.checkers[inst.AUTH_TYPE] = inst self.bcrypt_rounds = hs.config.registration.bcrypt_rounds @@ -815,7 +815,6 @@ class AuthHandler: now_ms = self._clock.time_msec() if existing_token.expiry_ts is not None and existing_token.expiry_ts < now_ms: - raise SynapseError( HTTPStatus.FORBIDDEN, "The supplied refresh token has expired", @@ -1543,6 +1542,17 @@ class AuthHandler: async def add_threepid( self, user_id: str, medium: str, address: str, validated_at: int ) -> None: + """ + Adds an association between a user's Matrix ID and a third-party ID (email, + phone number). + + Args: + user_id: The ID of the user to associate. + medium: The medium of the third-party ID (email, msisdn). + address: The address of the third-party ID (i.e. an email address). + validated_at: The timestamp in ms of when the validation that the user owns + this third-party ID occurred. + """ # check if medium has a valid value if medium not in ["email", "msisdn"]: raise SynapseError( @@ -1567,42 +1577,44 @@ class AuthHandler: user_id, medium, address, validated_at, self.hs.get_clock().time_msec() ) + # Inform Synapse modules that a 3PID association has been created. + await self._third_party_rules.on_add_user_third_party_identifier( + user_id, medium, address + ) + + # Deprecated method for informing Synapse modules that a 3PID association + # has successfully been created. await self._third_party_rules.on_threepid_bind(user_id, medium, address) - async def delete_threepid( - self, user_id: str, medium: str, address: str, id_server: Optional[str] = None - ) -> bool: - """Attempts to unbind the 3pid on the identity servers and deletes it - from the local database. + async def delete_local_threepid( + self, user_id: str, medium: str, address: str + ) -> None: + """Deletes an association between a third-party ID and a user ID from the local + database. This method does not unbind the association from any identity servers. + + If `medium` is 'email' and a pusher is associated with this third-party ID, the + pusher will also be deleted. Args: user_id: ID of user to remove the 3pid from. medium: The medium of the 3pid being removed: "email" or "msisdn". address: The 3pid address to remove. - id_server: Use the given identity server when unbinding - any threepids. If None then will attempt to unbind using the - identity server specified when binding (if known). - - Returns: - Returns True if successfully unbound the 3pid on - the identity server, False if identity server doesn't support the - unbind API. """ - # 'Canonicalise' email addresses as per above if medium == "email": address = canonicalise_email(address) - result = await self.hs.get_identity_handler().try_unbind_threepid( - user_id, medium, address, id_server + await self.store.user_delete_threepid(user_id, medium, address) + + # Inform Synapse modules that a 3PID association has been deleted. + await self._third_party_rules.on_remove_user_third_party_identifier( + user_id, medium, address ) - await self.store.user_delete_threepid(user_id, medium, address) if medium == "email": await self.store.delete_pusher_by_app_id_pushkey_user_id( app_id="m.email", pushkey=address, user_id=user_id ) - return result async def hash(self, password: str) -> str: """Computes a secure hash of password. 
@@ -2259,7 +2271,6 @@ class PasswordAuthProvider: async def on_logged_out( self, user_id: str, device_id: Optional[str], access_token: str ) -> None: - # call all of the on_logged_out callbacks for callback in self.on_logged_out_callbacks: try: diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index d24f649382..d31263c717 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py
@@ -100,26 +100,28 @@ class DeactivateAccountHandler: # unbinding identity_server_supports_unbinding = True - # Retrieve the 3PIDs this user has bound to an identity server - threepids = await self.store.user_get_bound_threepids(user_id) - - for threepid in threepids: + # Attempt to unbind any known bound threepids to this account from identity + # server(s). + bound_threepids = await self.store.user_get_bound_threepids(user_id) + for threepid in bound_threepids: try: result = await self._identity_handler.try_unbind_threepid( user_id, threepid["medium"], threepid["address"], id_server ) - identity_server_supports_unbinding &= result except Exception: # Do we want this to be a fatal error or should we carry on? logger.exception("Failed to remove threepid from ID server") raise SynapseError(400, "Failed to remove threepid from ID server") - await self.store.user_delete_threepid( + + identity_server_supports_unbinding &= result + + # Remove any local threepid associations for this account. + local_threepids = await self.store.user_get_threepids(user_id) + for threepid in local_threepids: + await self._auth_handler.delete_local_threepid( user_id, threepid["medium"], threepid["address"] ) - # Remove all 3PIDs this user has bound to the homeserver - await self.store.user_delete_threepids(user_id) - # delete any devices belonging to the user, which will also # delete corresponding access tokens. await self._device_handler.delete_all_devices_for_user(user_id) diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index a5798e9483..1fb23cc9bf 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py
@@ -497,9 +497,11 @@ class DirectoryHandler: raise SynapseError(403, "Not allowed to publish room") # Check if publishing is blocked by a third party module - allowed_by_third_party_rules = await ( - self.third_party_event_rules.check_visibility_can_be_modified( - room_id, visibility + allowed_by_third_party_rules = ( + await ( + self.third_party_event_rules.check_visibility_can_be_modified( + room_id, visibility + ) ) ) if not allowed_by_third_party_rules: diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 43cbece21b..4e9c8d8db0 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py
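The new is_cross_signing_set_up_for_user helper below reduces "does this user have cross-signing?" to a single lookup of the stored master key. A caller might use it like this (hypothetical guard function; handler is an E2eKeysHandler):

    async def require_cross_signing(handler, user_id: str) -> None:
        # The master key is the root of a user's cross-signing key tree,
        # so its presence is a sufficient proxy for "set up".
        if not await handler.is_cross_signing_set_up_for_user(user_id):
            raise RuntimeError(f"{user_id} has not set up cross-signing")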
@@ -1301,6 +1301,20 @@ class E2eKeysHandler: return desired_key_data + async def is_cross_signing_set_up_for_user(self, user_id: str) -> bool: + """Checks if the user has cross-signing set up + + Args: + user_id: The user to check + + Returns: + True if the user has cross-signing set up, False otherwise + """ + existing_master_key = await self.store.get_e2e_cross_signing_key( + user_id, "master" + ) + return existing_master_key is not None + def _check_cross_signing_key( key: JsonDict, user_id: str, key_type: str, signing_key: Optional[VerifyKey] = None diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py
index 83f53ceb88..50317ec753 100644 --- a/synapse/handlers/e2e_room_keys.py +++ b/synapse/handlers/e2e_room_keys.py
@@ -188,7 +188,6 @@ class E2eRoomKeysHandler: # XXX: perhaps we should use a finer grained lock here? async with self._upload_linearizer.queue(user_id): - # Check that the version we're trying to upload is the current version try: version_info = await self.store.get_e2e_room_keys_version_info(user_id) diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py
index 46dd63c3f0..c508861b6a 100644 --- a/synapse/handlers/event_auth.py +++ b/synapse/handlers/event_auth.py
@@ -236,7 +236,6 @@ class EventAuthHandler: # in any of them. allowed_rooms = await self.get_rooms_that_allow_join(state_ids) if not await self.is_user_in_rooms(allowed_rooms, user_id): - # If this is a remote request, the user might be in an allowed room # that we do not know about. if get_domain_from_id(user_id) != self._server_name: diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index 949b69cb41..68c07f0265 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py
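The changes below collapse get_stream's separate (auth_user_id, is_guest) parameters into a single Requester, from which both values are recovered where needed. In miniature (a sketch; handler and pagin_config are placeholders):

    async def stream_for(handler, requester, pagin_config):
        # One Requester replaces the old (auth_user_id, is_guest) pair; both
        # values are still recoverable wherever a handler needs them.
        user_id = requester.user.to_string()  # where a plain string is needed
        is_guest = requester.is_guest
        return await handler.get_stream(requester, pagin_config)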
@@ -23,7 +23,7 @@ from synapse.events.utils import SerializeEventConfig from synapse.handlers.presence import format_user_presence_state from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.streams.config import PaginationConfig -from synapse.types import JsonDict, UserID +from synapse.types import JsonDict, Requester, UserID from synapse.visibility import filter_events_for_client if TYPE_CHECKING: @@ -46,13 +46,12 @@ class EventStreamHandler: async def get_stream( self, - auth_user_id: str, + requester: Requester, pagin_config: PaginationConfig, timeout: int = 0, as_client_event: bool = True, affect_presence: bool = True, room_id: Optional[str] = None, - is_guest: bool = False, ) -> JsonDict: """Fetches the events stream for a given user.""" @@ -62,13 +61,12 @@ class EventStreamHandler: raise SynapseError(403, "This room has been blocked on this server") # send any outstanding server notices to the user. - await self._server_notices_sender.on_user_syncing(auth_user_id) + await self._server_notices_sender.on_user_syncing(requester.user.to_string()) - auth_user = UserID.from_string(auth_user_id) presence_handler = self.hs.get_presence_handler() context = await presence_handler.user_syncing( - auth_user_id, + requester.user.to_string(), affect_presence=affect_presence, presence_state=PresenceState.ONLINE, ) @@ -82,10 +80,10 @@ class EventStreamHandler: timeout = random.randint(int(timeout * 0.9), int(timeout * 1.1)) stream_result = await self.notifier.get_events_for( - auth_user, + requester.user, pagin_config, timeout, - is_guest=is_guest, + is_guest=requester.is_guest, explicit_room_id=room_id, ) events = stream_result.events @@ -102,7 +100,7 @@ class EventStreamHandler: if event.membership != Membership.JOIN: continue # Send down presence. - if event.state_key == auth_user_id: + if event.state_key == requester.user.to_string(): # Send down presence for everyone in the room. users: Iterable[str] = await self.store.get_users_in_room( event.room_id @@ -124,7 +122,9 @@ class EventStreamHandler: chunks = self._event_serializer.serialize_events( events, time_now, - config=SerializeEventConfig(as_client_event=as_client_event), + config=SerializeEventConfig( + as_client_event=as_client_event, requester=requester + ), ) chunk = { diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 5371336529..50f8041f17 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py
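The long comment added below documents a race: the room's state can change between choosing prev events and validating the built join event. The mitigation is to snapshot the forward extremities immediately before reading the current state and to build against exactly those prev events. Schematically (a method sketch using the call names from the hunk; the auth-event derivation is elided):

    async def _make_join_event(self, builder, room_id: str):
        # Snapshot the forward extremities first, so the prev events given
        # to the builder line up with the state about to be fetched.
        prev_event_ids = await self.store.get_prev_events_for_room(room_id)
        state_ids = await self._state_storage_controller.get_current_state_ids(
            room_id
        )
        # ... derive join rules / auth events from state_ids ...
        (
            event,
            unpersisted_context,
            _,
        ) = await self.event_creation_handler.create_new_client_event(
            builder=builder, prev_event_ids=prev_event_ids
        )
        return event, unpersisted_context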
@@ -952,7 +952,20 @@ class FederationHandler: # # Note that this requires the /send_join request to come back to the # same server. + prev_event_ids = None if room_version.msc3083_join_rules: + # Note that the room's state can change out from under us and render our + # nice join rules-conformant event non-conformant by the time we build the + # event. When this happens, our validation at the end fails and we respond + # to the requesting server with a 403, which is misleading — it indicates + # that the user is not allowed to join the room and the joining server + # should not bother retrying via this homeserver or any others, when + # in fact we've just messed up with building the event. + # + # To reduce the likelihood of this race, we capture the forward extremities + # of the room (prev_event_ids) just before fetching the current state, and + # hope that the state we fetch corresponds to the prev events we chose. + prev_event_ids = await self.store.get_prev_events_for_room(room_id) state_ids = await self._state_storage_controller.get_current_state_ids( room_id ) @@ -995,7 +1008,8 @@ class FederationHandler: unpersisted_context, _, ) = await self.event_creation_handler.create_new_client_event( - builder=builder + builder=builder, + prev_event_ids=prev_event_ids, ) except SynapseError as e: logger.warning("Failed to create join to %s because %s", room_id, e) diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 1a29abde98..b3be7a86f0 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py
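Several hunks below thread the full Requester into the room initial-sync helpers and build one SerializeEventConfig that every serialize_events call in the response shares. The pattern in isolation (a sketch; serializer and events stand in for the handler's fields):

    from synapse.events.utils import SerializeEventConfig

    def serialize_for(requester, serializer, events, time_now_ms: int):
        # Build the config once per response; passing the requester lets the
        # serializer make per-user decisions (for example, only echoing an
        # event's transaction ID back to the client that sent it).
        config = SerializeEventConfig(requester=requester)
        return serializer.serialize_events(events, time_now_ms, config=config)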
@@ -124,7 +124,6 @@ class InitialSyncHandler: as_client_event: bool = True, include_archived: bool = False, ) -> JsonDict: - memberships = [Membership.INVITE, Membership.JOIN] if include_archived: memberships.append(Membership.LEAVE) @@ -319,11 +318,9 @@ class InitialSyncHandler: ) is_peeking = member_event_id is None - user_id = requester.user.to_string() - if membership == Membership.JOIN: result = await self._room_initial_sync_joined( - user_id, room_id, pagin_config, membership, is_peeking + requester, room_id, pagin_config, membership, is_peeking ) elif membership == Membership.LEAVE: # The member_event_id will always be available if membership is set @@ -331,10 +328,16 @@ class InitialSyncHandler: assert member_event_id result = await self._room_initial_sync_parted( - user_id, room_id, pagin_config, membership, member_event_id, is_peeking + requester, + room_id, + pagin_config, + membership, + member_event_id, + is_peeking, ) account_data_events = [] + user_id = requester.user.to_string() tags = await self.store.get_tags_for_room(user_id, room_id) if tags: account_data_events.append( @@ -351,7 +354,7 @@ class InitialSyncHandler: async def _room_initial_sync_parted( self, - user_id: str, + requester: Requester, room_id: str, pagin_config: PaginationConfig, membership: str, @@ -370,13 +373,17 @@ class InitialSyncHandler: ) messages = await filter_events_for_client( - self._storage_controllers, user_id, messages, is_peeking=is_peeking + self._storage_controllers, + requester.user.to_string(), + messages, + is_peeking=is_peeking, ) start_token = StreamToken.START.copy_and_replace(StreamKeyType.ROOM, token) end_token = StreamToken.START.copy_and_replace(StreamKeyType.ROOM, stream_token) time_now = self.clock.time_msec() + serialize_options = SerializeEventConfig(requester=requester) return { "membership": membership, @@ -384,14 +391,18 @@ class InitialSyncHandler: "messages": { "chunk": ( # Don't bundle aggregations as this is a deprecated API. - self._event_serializer.serialize_events(messages, time_now) + self._event_serializer.serialize_events( + messages, time_now, config=serialize_options + ) ), "start": await start_token.to_string(self.store), "end": await end_token.to_string(self.store), }, "state": ( # Don't bundle aggregations as this is a deprecated API. - self._event_serializer.serialize_events(room_state.values(), time_now) + self._event_serializer.serialize_events( + room_state.values(), time_now, config=serialize_options + ) ), "presence": [], "receipts": [], @@ -399,7 +410,7 @@ class InitialSyncHandler: async def _room_initial_sync_joined( self, - user_id: str, + requester: Requester, room_id: str, pagin_config: PaginationConfig, membership: str, @@ -411,9 +422,12 @@ class InitialSyncHandler: # TODO: These concurrently time_now = self.clock.time_msec() + serialize_options = SerializeEventConfig(requester=requester) # Don't bundle aggregations as this is a deprecated API. 
state = self._event_serializer.serialize_events( - current_state.values(), time_now + current_state.values(), + time_now, + config=serialize_options, ) now_token = self.hs.get_event_sources().get_current_token() @@ -451,7 +465,10 @@ class InitialSyncHandler: if not receipts: return [] - return ReceiptEventSource.filter_out_private_receipts(receipts, user_id) + return ReceiptEventSource.filter_out_private_receipts( + receipts, + requester.user.to_string(), + ) presence, receipts, (messages, token) = await make_deferred_yieldable( gather_results( @@ -470,20 +487,23 @@ class InitialSyncHandler: ) messages = await filter_events_for_client( - self._storage_controllers, user_id, messages, is_peeking=is_peeking + self._storage_controllers, + requester.user.to_string(), + messages, + is_peeking=is_peeking, ) start_token = now_token.copy_and_replace(StreamKeyType.ROOM, token) end_token = now_token - time_now = self.clock.time_msec() - ret = { "room_id": room_id, "messages": { "chunk": ( # Don't bundle aggregations as this is a deprecated API. - self._event_serializer.serialize_events(messages, time_now) + self._event_serializer.serialize_events( + messages, time_now, config=serialize_options + ) ), "start": await start_token.to_string(self.store), "end": await end_token.to_string(self.store), diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 77d92f1574..d283a938c0 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py
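create_event now hands back an unpersisted event context, so each caller decides when to persist it; the hunks below, and the matching call-site updates elsewhere in this diff, all follow the same two-step shape (condensed sketch):

    async def create_and_persist(handler, requester, event_dict):
        # Before this change, create_event resolved the context internally.
        event, unpersisted_context, third_party = await handler.create_event(
            requester, event_dict
        )
        # The caller now persists explicitly; room creation exploits this to
        # batch many contexts into a single state-group write.
        context = await unpersisted_context.persist(event)
        return event, context, third_party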
@@ -50,7 +51,7 @@ from synapse.event_auth import validate_event_for_room_version from synapse.events import EventBase, relation_from_event from synapse.events.builder import EventBuilder from synapse.events.snapshot import EventContext, UnpersistedEventContextBase -from synapse.events.utils import maybe_upsert_event_field +from synapse.events.utils import SerializeEventConfig, maybe_upsert_event_field from synapse.events.validator import EventValidator from synapse.handlers.directory import DirectoryHandler from synapse.logging import opentracing @@ -245,8 +246,11 @@ class MessageHandler: ) room_state = room_state_events[membership_event_id] - now = self.clock.time_msec() - events = self._event_serializer.serialize_events(room_state.values(), now) + events = self._event_serializer.serialize_events( + room_state.values(), + self.clock.time_msec(), + config=SerializeEventConfig(requester=requester), + ) return events async def _user_can_see_state_at_event( @@ -574,7 +578,7 @@ class EventCreationHandler: state_map: Optional[StateMap[str]] = None, for_batch: bool = False, current_state_group: Optional[int] = None, - ) -> Tuple[EventBase, EventContext, Optional[dict]]: + ) -> Tuple[EventBase, UnpersistedEventContextBase, Optional[dict]]: """ Given a dict from a client, create a new event. If bool for_batch is true, will create an event using the prev_event_ids, and will create an event context for @@ -723,8 +727,6 @@ class EventCreationHandler: current_state_group=current_state_group, ) - context = await unpersisted_context.persist(event) - # In an ideal world we wouldn't need the second part of this condition. However, # this behaviour isn't spec'd yet, meaning we should be able to deactivate this # behaviour. 
Another reason is that this code is also evaluated each time a new @@ -741,7 +743,7 @@ class EventCreationHandler: assert state_map is not None prev_event_id = state_map.get((EventTypes.Member, event.sender)) else: - prev_state_ids = await context.get_prev_state_ids( + prev_state_ids = await unpersisted_context.get_prev_state_ids( StateFilter.from_types([(EventTypes.Member, None)]) ) prev_event_id = prev_state_ids.get((EventTypes.Member, event.sender)) @@ -766,8 +768,7 @@ class EventCreationHandler: ) self.validator.validate_new(event, self.config) - - return event, context, new_event + return event, unpersisted_context, new_event async def _is_exempt_from_privacy_policy( self, builder: EventBuilder, requester: Requester @@ -1007,7 +1008,11 @@ class EventCreationHandler: max_retries = 5 for i in range(max_retries): try: - event, context, third_party_event_dict = await self.create_event( + ( + event, + unpersisted_context, + third_party_event_dict, + ) = await self.create_event( requester, event_dict, txn_id=txn_id, @@ -1018,6 +1023,7 @@ class EventCreationHandler: historical=historical, depth=depth, ) + context = await unpersisted_context.persist(event) assert self.hs.is_mine_id(event.sender), "User must be our own: %s" % ( event.sender, @@ -1209,7 +1215,6 @@ class EventCreationHandler: if for_batch: assert prev_event_ids is not None assert state_map is not None - assert current_state_group is not None auth_ids = self._event_auth_handler.compute_auth_events(builder, state_map) event = await builder.build( prev_event_ids=prev_event_ids, auth_event_ids=auth_ids, depth=depth @@ -2067,7 +2072,7 @@ class EventCreationHandler: max_retries = 5 for i in range(max_retries): try: - event, context, _ = await self.create_event( + event, unpersisted_context, _ = await self.create_event( requester, { "type": EventTypes.Dummy, @@ -2076,6 +2081,7 @@ class EventCreationHandler: "sender": user_id, }, ) + context = await unpersisted_context.persist(event) event.internal_metadata.proactively_send = False diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index ceefa16b49..8c79c055ba 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py
@@ -579,7 +579,9 @@ class PaginationHandler: time_now = self.clock.time_msec() - serialize_options = SerializeEventConfig(as_client_event=as_client_event) + serialize_options = SerializeEventConfig( + as_client_event=as_client_event, requester=requester + ) chunk = { "chunk": ( diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 87af31aa27..4ad2233573 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py
@@ -777,7 +777,6 @@ class PresenceHandler(BasePresenceHandler): ) if self.unpersisted_users_changes: - await self.store.update_presence( [ self.user_to_current_state[user_id] @@ -823,7 +822,6 @@ class PresenceHandler(BasePresenceHandler): now = self.clock.time_msec() with Measure(self.clock, "presence_update_states"): - # NOTE: We purposefully don't await between now and when we've # calculated what we want to do with the new states, to avoid races. diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index c611efb760..e4e506e62c 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py
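create_room now returns a (room_id, room_alias, stream_id) tuple rather than a result dict, so call sites unpack positionally. The migration at a typical call site (sketch):

    async def make_room(room_creation_handler, requester, config):
        # Old shape: info, _ = await ...create_room(...)
        #            room_id = info["room_id"]
        # New shape: take what you need, discard the rest.
        room_id, room_alias, stream_id = await room_creation_handler.create_room(
            requester, config=config, ratelimit=False
        )
        return room_id  # room_alias is None unless an alias was requested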
@@ -476,7 +476,7 @@ class RegistrationHandler: # create room expects the localpart of the room alias config["room_alias_name"] = room_alias.localpart - info, _ = await room_creation_handler.create_room( + room_id, _, _ = await room_creation_handler.create_room( fake_requester, config=config, ratelimit=False, @@ -490,7 +490,7 @@ class RegistrationHandler: user_id, authenticated_entity=self._server_name ), target=UserID.from_string(user_id), - room_id=info["room_id"], + room_id=room_id, # Since it was just created, there are no remote hosts. remote_room_hosts=[], action="join", diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py
index 0fb15391e0..1d09fdf135 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py
@@ -20,6 +20,7 @@ import attr from synapse.api.constants import Direction, EventTypes, RelationTypes from synapse.api.errors import SynapseError from synapse.events import EventBase, relation_from_event +from synapse.events.utils import SerializeEventConfig from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.opentracing import trace from synapse.storage.databases.main.relations import ThreadsNextBatch, _RelatedEvent @@ -60,13 +61,12 @@ class BundledAggregations: Some values require additional processing during serialization. """ - annotations: Optional[JsonDict] = None references: Optional[JsonDict] = None replace: Optional[EventBase] = None thread: Optional[_ThreadAggregation] = None def __bool__(self) -> bool: - return bool(self.annotations or self.references or self.replace or self.thread) + return bool(self.references or self.replace or self.thread) class RelationsHandler: @@ -152,16 +152,23 @@ class RelationsHandler: ) now = self._clock.time_msec() + serialize_options = SerializeEventConfig(requester=requester) return_value: JsonDict = { "chunk": self._event_serializer.serialize_events( - events, now, bundle_aggregations=aggregations + events, + now, + bundle_aggregations=aggregations, + config=serialize_options, ), } if include_original_event: # Do not bundle aggregations when retrieving the original event because # we want the content before relations are applied to it. return_value["original_event"] = self._event_serializer.serialize_event( - event, now, bundle_aggregations=None + event, + now, + bundle_aggregations=None, + config=serialize_options, ) if next_token: @@ -227,67 +234,6 @@ class RelationsHandler: e.msg, ) - async def get_annotations_for_events( - self, event_ids: Collection[str], ignored_users: FrozenSet[str] = frozenset() - ) -> Dict[str, List[JsonDict]]: - """Get a list of annotations to the given events, grouped by event type and - aggregation key, sorted by count. - - This is used e.g. to get the what and how many reactions have happened - on an event. - - Args: - event_ids: Fetch events that relate to these event IDs. - ignored_users: The users ignored by the requesting user. - - Returns: - A map of event IDs to a list of groups of annotations that match. - Each entry is a dict with `type`, `key` and `count` fields. - """ - # Get the base results for all users. - full_results = await self._main_store.get_aggregation_groups_for_events( - event_ids - ) - - # Avoid additional logic if there are no ignored users. - if not ignored_users: - return { - event_id: results - for event_id, results in full_results.items() - if results - } - - # Then subtract off the results for any ignored users. - ignored_results = await self._main_store.get_aggregation_groups_for_users( - [event_id for event_id, results in full_results.items() if results], - ignored_users, - ) - - filtered_results = {} - for event_id, results in full_results.items(): - # If no annotations, skip. - if not results: - continue - - # If there are not ignored results for this event, copy verbatim. - if event_id not in ignored_results: - filtered_results[event_id] = results - continue - - # Otherwise, subtract out the ignored results. - event_ignored_results = ignored_results[event_id] - for result in results: - key = (result["type"], result["key"]) - if key in event_ignored_results: - # Ensure to not modify the cache. 
- result = result.copy() - result["count"] -= event_ignored_results[key] - if result["count"] <= 0: - continue - filtered_results.setdefault(event_id, []).append(result) - - return filtered_results - async def get_references_for_events( self, event_ids: Collection[str], ignored_users: FrozenSet[str] = frozenset() ) -> Dict[str, List[_RelatedEvent]]: @@ -531,17 +477,6 @@ class RelationsHandler: # (as that is what makes it part of the thread). relations_by_id[latest_thread_event.event_id] = RelationTypes.THREAD - async def _fetch_annotations() -> None: - """Fetch any annotations (ie, reactions) to bundle with this event.""" - annotations_by_event_id = await self.get_annotations_for_events( - events_by_id.keys(), ignored_users=ignored_users - ) - for event_id, annotations in annotations_by_event_id.items(): - if annotations: - results.setdefault(event_id, BundledAggregations()).annotations = { - "chunk": annotations - } - async def _fetch_references() -> None: """Fetch any references to bundle with this event.""" references_by_event_id = await self.get_references_for_events( @@ -575,7 +510,6 @@ class RelationsHandler: await make_deferred_yieldable( gather_results( ( - run_in_background(_fetch_annotations), run_in_background(_fetch_references), run_in_background(_fetch_edits), ) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 9fb7209549..c70afa3176 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py
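Alongside the create_room return-type change, the hunks below stop threading current_state_group through every initial state event: the per-event `current_state_group = ..._state_group` assignments go away, and the accumulated unpersisted contexts are persisted in one batch before the events are sent. The batch step, isolated (a sketch using the names from the hunk; events_to_send is the list of (event, unpersisted context) pairs built while creating the room):

    from synapse.events.snapshot import UnpersistedEventContext

    async def persist_in_batch(hs, events_to_send, room_id, current_state_group):
        datastore = hs.get_datastores().state
        events_and_context = (
            await UnpersistedEventContext.batch_persist_unpersisted_contexts(
                events_to_send, room_id, current_state_group, datastore
            )
        )
        # Each event is now paired with a persisted EventContext, ready for
        # a single handle_new_client_event call.
        return events_and_context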
@@ -51,6 +51,7 @@ from synapse.api.filtering import Filter from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion from synapse.event_auth import validate_event_for_room_version from synapse.events import EventBase +from synapse.events.snapshot import UnpersistedEventContext from synapse.events.utils import copy_and_fixup_power_levels_contents from synapse.handlers.relations import BundledAggregations from synapse.module_api import NOT_SPAM @@ -211,7 +212,7 @@ class RoomCreationHandler: # the required power level to send the tombstone event. ( tombstone_event, - tombstone_context, + tombstone_unpersisted_context, _, ) = await self.event_creation_handler.create_event( requester, @@ -226,6 +227,9 @@ class RoomCreationHandler: }, }, ) + tombstone_context = await tombstone_unpersisted_context.persist( + tombstone_event + ) validate_event_for_room_version(tombstone_event) await self._event_auth_handler.check_auth_rules_from_context( tombstone_event @@ -691,13 +695,14 @@ class RoomCreationHandler: config: JsonDict, ratelimit: bool = True, creator_join_profile: Optional[JsonDict] = None, - ) -> Tuple[dict, int]: + ) -> Tuple[str, Optional[RoomAlias], int]: """Creates a new room. Args: - requester: - The user who requested the room creation. - config : A dict of configuration options. + requester: The user who requested the room creation. + config: A dict of configuration options. This will be the body of + a /createRoom request; see + https://spec.matrix.org/latest/client-server-api/#post_matrixclientv3createroom ratelimit: set to False to disable the rate limiter creator_join_profile: @@ -708,14 +713,17 @@ class RoomCreationHandler: `avatar_url` and/or `displayname`. Returns: - First, a dict containing the keys `room_id` and, if an alias - was, requested, `room_alias`. Secondly, the stream_id of the - last persisted event. + A 3-tuple containing: + - the room ID; + - if requested, the room alias, otherwise None; and + - the `stream_id` of the last persisted event. Raises: - SynapseError if the room ID couldn't be stored, 3pid invitation config - validation failed, or something went horribly wrong. - ResourceLimitError if server is blocked to some resource being - exceeded + SynapseError: + if the room ID couldn't be stored, 3pid invitation config + validation failed, or something went horribly wrong. 
+ ResourceLimitError: + if server is blocked to some resource being + exceeded """ user_id = requester.user.to_string() @@ -865,9 +873,11 @@ class RoomCreationHandler: ) # Check whether this visibility value is blocked by a third party module - allowed_by_third_party_rules = await ( - self.third_party_event_rules.check_visibility_can_be_modified( - room_id, visibility + allowed_by_third_party_rules = ( + await ( + self.third_party_event_rules.check_visibility_can_be_modified( + room_id, visibility + ) ) ) if not allowed_by_third_party_rules: @@ -1025,11 +1035,6 @@ class RoomCreationHandler: last_sent_event_id = member_event_id depth += 1 - result = {"room_id": room_id} - - if room_alias: - result["room_alias"] = room_alias.to_string() - # Always wait for room creation to propagate before returning await self._replication.wait_for_stream_position( self.hs.config.worker.events_shard_config.get_instance(room_id), @@ -1037,7 +1042,7 @@ class RoomCreationHandler: last_stream_id, ) - return result, last_stream_id + return room_id, room_alias, last_stream_id async def _send_events_for_new_room( self, @@ -1092,7 +1097,11 @@ class RoomCreationHandler: content: JsonDict, for_batch: bool, **kwargs: Any, - ) -> Tuple[EventBase, synapse.events.snapshot.EventContext, Optional[dict]]: + ) -> Tuple[ + EventBase, + synapse.events.snapshot.UnpersistedEventContextBase, + Optional[dict], + ]: """ Creates an event and associated event context. Args: @@ -1113,7 +1122,7 @@ class RoomCreationHandler: ( new_event, - new_context, + new_unpersisted_context, third_party_event, ) = await self.event_creation_handler.create_event( creator, @@ -1122,13 +1131,13 @@ class RoomCreationHandler: depth=depth, state_map=state_map, for_batch=for_batch, - current_state_group=current_state_group, ) + depth += 1 prev_event = [new_event.event_id] state_map[(new_event.type, new_event.state_key)] = new_event.event_id - return new_event, new_context, third_party_event + return new_event, new_unpersisted_context, third_party_event try: config = self._presets_dict[preset_config] @@ -1138,10 +1147,10 @@ class RoomCreationHandler: ) creation_content.update({"creator": creator_id}) - creation_event, creation_context, _ = await create_event( + creation_event, unpersisted_creation_context, _ = await create_event( EventTypes.Create, creation_content, False ) - + creation_context = await unpersisted_creation_context.persist(creation_event) logger.debug("Sending %s in new room", EventTypes.Member) ev = await self.event_creation_handler.handle_new_client_event( requester=creator, @@ -1186,7 +1195,6 @@ class RoomCreationHandler: power_event, power_context, power_tp_event = await create_event( EventTypes.PowerLevels, pl_content, True ) - current_state_group = power_context._state_group events_to_send.append((power_event, power_context)) if power_tp_event: third_party_events_to_append.append(power_tp_event) @@ -1237,7 +1245,6 @@ class RoomCreationHandler: power_level_content, True, ) - current_state_group = pl_context._state_group events_to_send.append((pl_event, pl_context)) if pl_tp_event: third_party_events_to_append.append(pl_tp_event) @@ -1246,7 +1253,6 @@ class RoomCreationHandler: room_alias_event, room_alias_context, ra_tp_event = await create_event( EventTypes.CanonicalAlias, {"alias": room_alias.to_string()}, True ) - current_state_group = room_alias_context._state_group events_to_send.append((room_alias_event, room_alias_context)) if ra_tp_event: third_party_events_to_append.append(ra_tp_event) @@ -1257,7 +1263,6 @@ class 
RoomCreationHandler: {"join_rule": config["join_rules"]}, True, ) - current_state_group = join_rules_context._state_group events_to_send.append((join_rules_event, join_rules_context)) if jr_tp_event: third_party_events_to_append.append(jr_tp_event) @@ -1268,7 +1273,6 @@ class RoomCreationHandler: {"history_visibility": config["history_visibility"]}, True, ) - current_state_group = visibility_context._state_group events_to_send.append((visibility_event, visibility_context)) if vis_tp_event: third_party_events_to_append.append(vis_tp_event) @@ -1284,7 +1288,6 @@ class RoomCreationHandler: {EventContentFields.GUEST_ACCESS: GuestAccess.CAN_JOIN}, True, ) - current_state_group = guest_access_context._state_group events_to_send.append((guest_access_event, guest_access_context)) if ga_tp_event: third_party_events_to_append.append(ga_tp_event) @@ -1293,7 +1296,6 @@ class RoomCreationHandler: event, context, tp_event = await create_event( etype, content, True, state_key=state_key ) - current_state_group = context._state_group events_to_send.append((event, context)) if tp_event: third_party_events_to_append.append(tp_event) @@ -1325,9 +1327,16 @@ class RoomCreationHandler: context = await unpersisted_context.persist(event) events_to_send.append((event, context)) + datastore = self.hs.get_datastores().state + events_and_context = ( + await UnpersistedEventContext.batch_persist_unpersisted_contexts( + events_to_send, room_id, current_state_group, datastore + ) + ) + last_event = await self.event_creation_handler.handle_new_client_event( creator, - events_to_send, + events_and_context, ignore_shadow_ban=True, ratelimit=False, ) @@ -1867,7 +1876,7 @@ class RoomShutdownHandler: new_room_user_id, authenticated_entity=requester_user_id ) - info, stream_id = await self._room_creation_handler.create_room( + new_room_id, _, stream_id = await self._room_creation_handler.create_room( room_creator_requester, config={ "preset": RoomCreationPreset.PUBLIC_CHAT, @@ -1876,7 +1885,6 @@ class RoomShutdownHandler: }, ratelimit=False, ) - new_room_id = info["room_id"] logger.info( "Shutting down room %r, joining to new room: %r", room_id, new_room_id @@ -1929,6 +1937,7 @@ class RoomShutdownHandler: # Join users to new room if new_room_user_id: + assert new_room_id is not None await self.room_member_handler.update_membership( requester=target_requester, target=target_requester.user, @@ -1961,6 +1970,7 @@ class RoomShutdownHandler: aliases_for_room = await self.store.get_aliases_for_room(room_id) + assert new_room_id is not None await self.store.update_aliases_for_room( room_id, new_room_id, requester_user_id ) diff --git a/synapse/handlers/room_batch.py b/synapse/handlers/room_batch.py
index 3d432ac295..8b5e02af17 100644 --- a/synapse/handlers/room_batch.py +++ b/synapse/handlers/room_batch.py
@@ -327,7 +327,11 @@ class RoomBatchHandler: # Mark all events as historical event_dict["content"][EventContentFields.MSC2716_HISTORICAL] = True - event, context, _ = await self.event_creation_handler.create_event( + ( + event, + unpersisted_context, + _, + ) = await self.event_creation_handler.create_event( await self.create_requester_for_user_id_from_app_service( ev["sender"], app_service_requester.app_service ), @@ -345,7 +349,7 @@ class RoomBatchHandler: historical=True, depth=inherited_depth, ) - + context = await unpersisted_context.persist(event) assert context._state_group # Normally this is done when persisting the event but we have to @@ -374,7 +378,7 @@ class RoomBatchHandler: # correct stream_ordering as they are backfilled (which decrements). # Events are sorted by (topological_ordering, stream_ordering) # where topological_ordering is just depth. - for (event, context) in reversed(events_to_persist): + for event, context in reversed(events_to_persist): # This call can't raise `PartialStateConflictError` since we forbid # use of the historical batch API during partial state await self.event_creation_handler.handle_new_client_event( diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index f136ee0e97..9d3096df8d 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py
@@ -207,6 +207,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): @abc.abstractmethod async def remote_knock( self, + requester: Requester, remote_room_hosts: List[str], room_id: str, user: UserID, @@ -416,7 +417,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): try: ( event, - context, + unpersisted_context, third_party_event, ) = await self.event_creation_handler.create_event( requester, @@ -439,7 +440,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): outlier=outlier, historical=historical, ) - + context = await unpersisted_context.persist(event) prev_state_ids = await context.get_prev_state_ids( StateFilter.from_types([(EventTypes.Member, None)]) ) @@ -1088,7 +1089,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): ) return await self.remote_knock( - remote_room_hosts, room_id, target, content + requester, remote_room_hosts, room_id, target, content ) return await self._local_membership_update( @@ -1964,7 +1965,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): try: ( event, - context, + unpersisted_context, third_party_event_dict, ) = await self.event_creation_handler.create_event( requester, @@ -1974,6 +1975,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): auth_event_ids=auth_event_ids, outlier=True, ) + context = await unpersisted_context.persist(event) event.internal_metadata.out_of_band_membership = True events_and_context = [(event, context)] @@ -2013,6 +2015,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): async def remote_knock( self, + requester: Requester, remote_room_hosts: List[str], room_id: str, user: UserID, diff --git a/synapse/handlers/room_member_worker.py b/synapse/handlers/room_member_worker.py
index ba261702d4..76e36b8a6d 100644 --- a/synapse/handlers/room_member_worker.py +++ b/synapse/handlers/room_member_worker.py
@@ -113,6 +113,7 @@ class RoomMemberWorkerHandler(RoomMemberHandler): async def remote_knock( self, + requester: Requester, remote_room_hosts: List[str], room_id: str, user: UserID, @@ -123,9 +124,10 @@ class RoomMemberWorkerHandler(RoomMemberHandler): Implements RoomMemberHandler.remote_knock """ ret = await self._remote_knock_client( + requester=requester, remote_room_hosts=remote_room_hosts, room_id=room_id, - user=user, + user_id=user.to_string(), content=content, ) return ret["event_id"], ret["stream_id"] diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index 9bbf83047d..aad4706f14 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py
@@ -23,7 +23,8 @@ from synapse.api.constants import EventTypes, Membership from synapse.api.errors import NotFoundError, SynapseError from synapse.api.filtering import Filter from synapse.events import EventBase -from synapse.types import JsonDict, StrCollection, StreamKeyType, UserID +from synapse.events.utils import SerializeEventConfig +from synapse.types import JsonDict, Requester, StrCollection, StreamKeyType, UserID from synapse.types.state import StateFilter from synapse.visibility import filter_events_for_client @@ -109,12 +110,12 @@ class SearchHandler: return historical_room_ids async def search( - self, user: UserID, content: JsonDict, batch: Optional[str] = None + self, requester: Requester, content: JsonDict, batch: Optional[str] = None ) -> JsonDict: """Performs a full text search for a user. Args: - user: The user performing the search. + requester: The user performing the search. content: Search parameters batch: The next_batch parameter. Used for pagination. @@ -199,7 +200,7 @@ class SearchHandler: ) return await self._search( - user, + requester, batch_group, batch_group_key, batch_token, @@ -217,7 +218,7 @@ class SearchHandler: async def _search( self, - user: UserID, + requester: Requester, batch_group: Optional[str], batch_group_key: Optional[str], batch_token: Optional[str], @@ -235,7 +236,7 @@ class SearchHandler: """Performs a full text search for a user. Args: - user: The user performing the search. + requester: The user performing the search. batch_group: Pagination information. batch_group_key: Pagination information. batch_token: Pagination information. @@ -269,7 +270,7 @@ class SearchHandler: # TODO: Search through left rooms too rooms = await self.store.get_rooms_for_local_user_where_membership_is( - user.to_string(), + requester.user.to_string(), membership_list=[Membership.JOIN], # membership_list=[Membership.JOIN, Membership.LEAVE, Membership.Ban], ) @@ -303,13 +304,13 @@ class SearchHandler: if order_by == "rank": search_result, sender_group = await self._search_by_rank( - user, room_ids, search_term, keys, search_filter + requester.user, room_ids, search_term, keys, search_filter ) # Unused return values for rank search. global_next_batch = None elif order_by == "recent": search_result, global_next_batch = await self._search_by_recent( - user, + requester.user, room_ids, search_term, keys, @@ -334,7 +335,7 @@ class SearchHandler: assert after_limit is not None contexts = await self._calculate_event_contexts( - user, + requester.user, search_result.allowed_events, before_limit, after_limit, @@ -363,27 +364,37 @@ class SearchHandler: # The returned events. search_result.allowed_events, ), - user.to_string(), + requester.user.to_string(), ) # We're now about to serialize the events. We should not make any # blocking calls after this. Otherwise, the 'age' will be wrong. 
time_now = self.clock.time_msec() + serialize_options = SerializeEventConfig(requester=requester) for context in contexts.values(): context["events_before"] = self._event_serializer.serialize_events( - context["events_before"], time_now, bundle_aggregations=aggregations + context["events_before"], + time_now, + bundle_aggregations=aggregations, + config=serialize_options, ) context["events_after"] = self._event_serializer.serialize_events( - context["events_after"], time_now, bundle_aggregations=aggregations + context["events_after"], + time_now, + bundle_aggregations=aggregations, + config=serialize_options, ) results = [ { "rank": search_result.rank_map[e.event_id], "result": self._event_serializer.serialize_event( - e, time_now, bundle_aggregations=aggregations + e, + time_now, + bundle_aggregations=aggregations, + config=serialize_options, ), "context": contexts.get(e.event_id, {}), } @@ -398,7 +409,9 @@ class SearchHandler: if state_results: rooms_cat_res["state"] = { - room_id: self._event_serializer.serialize_events(state_events, time_now) + room_id: self._event_serializer.serialize_events( + state_events, time_now, config=serialize_options + ) for room_id, state_events in state_results.items() } diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 4e4595312c..fd6d946c37 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py
@@ -1297,7 +1297,6 @@ class SyncHandler: return RoomNotifCounts.empty() with Measure(self.clock, "unread_notifs_for_room_id"): - return await self.store.get_unread_event_push_actions_by_room_for_user( room_id, sync_config.user.to_string(), diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py
index 332edcca24..78a75bfed6 100644 --- a/synapse/handlers/ui_auth/checkers.py +++ b/synapse/handlers/ui_auth/checkers.py
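UserInteractiveAuthChecker becomes a real ABC with AUTH_TYPE declared as a ClassVar, which is what lets auth.py drop its `# type: ignore` when registering checkers. A minimal hypothetical subclass showing the required surface:

    from typing import Any

    class ExampleAuthChecker(UserInteractiveAuthChecker):
        # Every concrete checker must name the login type it handles.
        AUTH_TYPE = "example.dummy"

        def is_enabled(self) -> bool:
            # A real checker would consult homeserver config here.
            return True

        async def check_auth(self, authdict: dict, clientip: str) -> Any:
            # Return a result on success; raise on failure.
            return True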
@@ -13,7 +13,8 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Any +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, ClassVar, Sequence, Type from twisted.web.client import PartialDownloadError @@ -27,19 +28,28 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class UserInteractiveAuthChecker: +class UserInteractiveAuthChecker(ABC): """Abstract base class for an interactive auth checker""" - def __init__(self, hs: "HomeServer"): + # This should really be an "abstract class property", i.e. it should + # be an error to instantiate a subclass that doesn't specify an AUTH_TYPE. + # But calling this a `ClassVar` is simpler than a decorator stack of + # @property @abstractmethod and @classmethod (if that's even the right order). + AUTH_TYPE: ClassVar[str] + + def __init__(self, hs: "HomeServer"): # noqa: B027 pass + @abstractmethod def is_enabled(self) -> bool: """Check if the configuration of the homeserver allows this checker to work Returns: True if this login type is enabled. """ + raise NotImplementedError() + @abstractmethod async def check_auth(self, authdict: dict, clientip: str) -> Any: """Given the authentication dict from the client, attempt to check this step @@ -304,7 +314,7 @@ class RegistrationTokenAuthChecker(UserInteractiveAuthChecker): ) -INTERACTIVE_AUTH_CHECKERS = [ +INTERACTIVE_AUTH_CHECKERS: Sequence[Type[UserInteractiveAuthChecker]] = [ DummyAuthChecker, TermsAuthChecker, RecaptchaAuthChecker, diff --git a/synapse/http/client.py b/synapse/http/client.py
index a05f297933..ae48e7c3f0 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py
@@ -44,6 +44,7 @@ from twisted.internet.interfaces import ( IAddress, IDelayedCall, IHostResolution, + IOpenSSLContextFactory, IReactorCore, IReactorPluggableNameResolver, IReactorTime, @@ -958,8 +959,8 @@ class InsecureInterceptableContextFactory(ssl.ContextFactory): self._context = SSL.Context(SSL.SSLv23_METHOD) self._context.set_verify(VERIFY_NONE, lambda *_: False) - def getContext(self, hostname=None, port=None): + def getContext(self) -> SSL.Context: return self._context - def creatorForNetloc(self, hostname: bytes, port: int): + def creatorForNetloc(self, hostname: bytes, port: int) -> IOpenSSLContextFactory: return self diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index b92f1d3d1a..3302d4e48a 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py
@@ -440,7 +440,7 @@ class MatrixFederationHttpClient: Args: request: details of request to be sent - retry_on_dns_fail: true if the request should be retied on DNS failures + retry_on_dns_fail: true if the request should be retried on DNS failures timeout: number of milliseconds to wait for the response headers (including connecting to the server), *for each attempt*. @@ -475,7 +475,7 @@ (except 429). NotRetryingDestination: If we are not yet ready to retry this server. - FederationDeniedError: If this destination is not on our + FederationDeniedError: If this destination is not on our federation whitelist RequestSendFailed: If there were problems connecting to the remote, due to e.g. DNS failures, connection timeouts etc. @@ -871,7 +871,7 @@ (except 429). NotRetryingDestination: If we are not yet ready to retry this server. - FederationDeniedError: If this destination is not on our + FederationDeniedError: If this destination is not on our federation whitelist RequestSendFailed: If there were problems connecting to the remote, due to e.g. DNS failures, connection timeouts etc. @@ -958,7 +958,7 @@ (except 429). NotRetryingDestination: If we are not yet ready to retry this server. - FederationDeniedError: If this destination is not on our + FederationDeniedError: If this destination is not on our federation whitelist RequestSendFailed: If there were problems connecting to the remote, due to e.g. DNS failures, connection timeouts etc. @@ -1036,6 +1036,8 @@ args: A dictionary used to create query strings, defaults to None. + retry_on_dns_fail: true if the request should be retried on DNS failures + timeout: number of milliseconds to wait for the response. self._default_timeout (60s) by default. @@ -1063,7 +1065,7 @@ (except 429). NotRetryingDestination: If we are not yet ready to retry this server. - FederationDeniedError: If this destination is not on our + FederationDeniedError: If this destination is not on our federation whitelist RequestSendFailed: If there were problems connecting to the remote, due to e.g. DNS failures, connection timeouts etc. @@ -1141,7 +1143,7 @@ (except 429). NotRetryingDestination: If we are not yet ready to retry this server. - FederationDeniedError: If this destination is not on our + FederationDeniedError: If this destination is not on our federation whitelist RequestSendFailed: If there were problems connecting to the remote, due to e.g. DNS failures, connection timeouts etc. @@ -1197,7 +1199,7 @@ (except 429). NotRetryingDestination: If we are not yet ready to retry this server. - FederationDeniedError: If this destination is not on our + FederationDeniedError: If this destination is not on our federation whitelist RequestSendFailed: If there were problems connecting to the remote, due to e.g. DNS failures, connection timeouts etc. @@ -1267,7 +1269,7 @@ def _flatten_response_never_received(e: BaseException) -> str: if hasattr(e, "reasons"): reasons = ", ".join( - _flatten_response_never_received(f.value) for f in e.reasons # type: ignore[attr-defined] + _flatten_response_never_received(f.value) for f in e.reasons ) return "%s:[%s]" % (type(e).__name__, reasons) diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index 6c7cf1b294..c70eee649c 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py
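The Concatenate[Callable[P, R], P] annotation below says: the wrapping logic receives the wrapped function followed by exactly the arguments that function itself takes. The same typing pattern in a self-contained form (an illustration of the idiom, not Synapse code):

    from contextlib import contextmanager
    from typing import Callable, ContextManager, Iterator, TypeVar

    from typing_extensions import Concatenate, ParamSpec

    P = ParamSpec("P")
    R = TypeVar("R")

    def wrap(
        func: Callable[P, R],
        logic: Callable[Concatenate[Callable[P, R], P], ContextManager[None]],
    ) -> Callable[P, R]:
        def inner(*args: P.args, **kwargs: P.kwargs) -> R:
            # `logic` sees the function plus its real arguments, so it can,
            # for example, record them on a tracing span around the call.
            with logic(func, *args, **kwargs):
                return func(*args, **kwargs)
        return inner

    @contextmanager
    def log_call(func: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> Iterator[None]:
        print(f"calling {func.__name__}")
        yield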
@@ -188,7 +188,7 @@ from typing import ( ) import attr -from typing_extensions import ParamSpec +from typing_extensions import Concatenate, ParamSpec from twisted.internet import defer from twisted.web.http import Request @@ -445,7 +445,7 @@ def init_tracer(hs: "HomeServer") -> None: opentracing = None # type: ignore[assignment] return - if not opentracing or not JaegerConfig: + if opentracing is None or JaegerConfig is None: raise ConfigError( "The server has been configured to use opentracing but opentracing is not " "installed." @@ -524,6 +524,7 @@ def whitelisted_homeserver(destination: str) -> bool: # Start spans and scopes + # Could use kwargs but I want these to be explicit def start_active_span( operation_name: str, @@ -872,7 +873,7 @@ def extract_text_map(carrier: Dict[str, str]) -> Optional["opentracing.SpanConte def _custom_sync_async_decorator( func: Callable[P, R], - wrapping_logic: Callable[[Callable[P, R], Any, Any], ContextManager[None]], + wrapping_logic: Callable[Concatenate[Callable[P, R], P], ContextManager[None]], ) -> Callable[P, R]: """ Decorates a function that is sync or async (coroutines), or that returns a Twisted @@ -902,10 +903,14 @@ def _custom_sync_async_decorator( """ if inspect.iscoroutinefunction(func): - + # In this branch, R = Awaitable[RInner], for some other type RInner @wraps(func) - async def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R: + async def _wrapper( + *args: P.args, **kwargs: P.kwargs + ) -> Any: # Return type is RInner with wrapping_logic(func, *args, **kwargs): + # type-ignore: func() returns R, but mypy doesn't know that R is + # Awaitable here. return await func(*args, **kwargs) # type: ignore[misc] else: @@ -972,7 +977,11 @@ def trace_with_opname( if not opentracing: return func - return _custom_sync_async_decorator(func, _wrapping_logic) + # type-ignore: mypy seems to be confused by the ParamSpecs here. + # I think the problem is https://github.com/python/mypy/issues/12909 + return _custom_sync_async_decorator( + func, _wrapping_logic # type: ignore[arg-type] + ) return _decorator @@ -1018,7 +1027,9 @@ def tag_args(func: Callable[P, R]) -> Callable[P, R]: set_tag(SynapseTags.FUNC_KWARGS, str(kwargs)) yield - return _custom_sync_async_decorator(func, _wrapping_logic) + # type-ignore: mypy seems to be confused by the ParamSpecs here. + # I think the problem is https://github.com/python/mypy/issues/12909 + return _custom_sync_async_decorator(func, _wrapping_logic) # type: ignore[arg-type] @contextlib.contextmanager diff --git a/synapse/media/_base.py b/synapse/media/_base.py new file mode 100644
index 0000000000..ef8334ae25 --- /dev/null +++ b/synapse/media/_base.py
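Most of the new synapse/media/_base.py is moved verbatim from synapse/rest/media/v1/_base.py. One detail worth noting is the Content-Disposition logic: names that qualify as RFC 2616 tokens may be sent as a bare filename=, while anything else must use the RFC 5987 filename*=utf-8''... form. The two paths, condensed (a sketch mirroring _can_encode_filename_as_token and add_file_headers):

    import urllib.parse

    _SEPARATORS = set('()<>@,;:\\"/[]?={}')

    def can_encode_as_token(name: str) -> bool:
        # RFC 2616 tokens: printable US-ASCII, no separators, no SP/HT/CTLs.
        return all(32 < ord(c) < 127 and c not in _SEPARATORS for c in name)

    def content_disposition(upload_name: str) -> str:
        if can_encode_as_token(upload_name):
            return "inline; filename=%s" % (upload_name,)
        # Otherwise, %-encode the UTF-8 bytes per RFC 5987.
        quoted = urllib.parse.quote(upload_name.encode("utf-8"))
        return "inline; filename*=utf-8''%s" % (quoted,)

For example, content_disposition("å b.png") yields "inline; filename*=utf-8''%C3%A5%20b.png".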
@@ -0,0 +1,479 @@ +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2019-2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +import urllib +from abc import ABC, abstractmethod +from types import TracebackType +from typing import Awaitable, Dict, Generator, List, Optional, Tuple, Type + +import attr + +from twisted.internet.interfaces import IConsumer +from twisted.protocols.basic import FileSender +from twisted.web.server import Request + +from synapse.api.errors import Codes, SynapseError, cs_error +from synapse.http.server import finish_request, respond_with_json +from synapse.http.site import SynapseRequest +from synapse.logging.context import make_deferred_yieldable +from synapse.util.stringutils import is_ascii, parse_and_validate_server_name + +logger = logging.getLogger(__name__) + +# list all text content types that will have the charset default to UTF-8 when +# none is given +TEXT_CONTENT_TYPES = [ + "text/css", + "text/csv", + "text/html", + "text/calendar", + "text/plain", + "text/javascript", + "application/json", + "application/ld+json", + "application/rtf", + "image/svg+xml", + "text/xml", +] + + +def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]: + """Parses the server name, media ID and optional file name from the request URI + + Also performs some rough validation on the server name. + + Args: + request: The `Request`. + + Returns: + A tuple containing the parsed server name, media ID and optional file name. + + Raises: + SynapseError(404): if parsing or validation fail for any reason + """ + try: + # The type on postpath seems incorrect in Twisted 21.2.0. + postpath: List[bytes] = request.postpath # type: ignore + assert postpath + + # This allows users to append e.g. /test.png to the URL. Useful for + # clients that parse the URL to see content type. 
+ server_name_bytes, media_id_bytes = postpath[:2] + server_name = server_name_bytes.decode("utf-8") + media_id = media_id_bytes.decode("utf8") + + # Validate the server name, raising if invalid + parse_and_validate_server_name(server_name) + + file_name = None + if len(postpath) > 2: + try: + file_name = urllib.parse.unquote(postpath[-1].decode("utf-8")) + except UnicodeDecodeError: + pass + return server_name, media_id, file_name + except Exception: + raise SynapseError( + 404, "Invalid media id token %r" % (request.postpath,), Codes.UNKNOWN + ) + + +def respond_404(request: SynapseRequest) -> None: + respond_with_json( + request, + 404, + cs_error("Not found %r" % (request.postpath,), code=Codes.NOT_FOUND), + send_cors=True, + ) + + +async def respond_with_file( + request: SynapseRequest, + media_type: str, + file_path: str, + file_size: Optional[int] = None, + upload_name: Optional[str] = None, +) -> None: + logger.debug("Responding with %r", file_path) + + if os.path.isfile(file_path): + if file_size is None: + stat = os.stat(file_path) + file_size = stat.st_size + + add_file_headers(request, media_type, file_size, upload_name) + + with open(file_path, "rb") as f: + await make_deferred_yieldable(FileSender().beginFileTransfer(f, request)) + + finish_request(request) + else: + respond_404(request) + + +def add_file_headers( + request: Request, + media_type: str, + file_size: Optional[int], + upload_name: Optional[str], +) -> None: + """Adds the correct response headers in preparation for responding with the + media. + + Args: + request + media_type: The media/content type. + file_size: Size in bytes of the media, if known. + upload_name: The name of the requested file, if any. + """ + + def _quote(x: str) -> str: + return urllib.parse.quote(x.encode("utf-8")) + + # Default to a UTF-8 charset for text content types. + # ex, uses UTF-8 for 'text/css' but not 'text/css; charset=UTF-16' + if media_type.lower() in TEXT_CONTENT_TYPES: + content_type = media_type + "; charset=UTF-8" + else: + content_type = media_type + + request.setHeader(b"Content-Type", content_type.encode("UTF-8")) + if upload_name: + # RFC6266 section 4.1 [1] defines both `filename` and `filename*`. + # + # `filename` is defined to be a `value`, which is defined by RFC2616 + # section 3.6 [2] to be a `token` or a `quoted-string`, where a `token` + # is (essentially) a single US-ASCII word, and a `quoted-string` is a + # US-ASCII string surrounded by double-quotes, using backslash as an + # escape character. Note that %-encoding is *not* permitted. + # + # `filename*` is defined to be an `ext-value`, which is defined in + # RFC5987 section 3.2.1 [3] to be `charset "'" [ language ] "'" value-chars`, + # where `value-chars` is essentially a %-encoded string in the given charset. + # + # [1]: https://tools.ietf.org/html/rfc6266#section-4.1 + # [2]: https://tools.ietf.org/html/rfc2616#section-3.6 + # [3]: https://tools.ietf.org/html/rfc5987#section-3.2.1 + + # We avoid the quoted-string version of `filename`, because (a) synapse didn't + # correctly interpret those as of 0.99.2 and (b) they are a bit of a pain and we + # may as well just do the filename* version. + if _can_encode_filename_as_token(upload_name): + disposition = "inline; filename=%s" % (upload_name,) + else: + disposition = "inline; filename*=utf-8''%s" % (_quote(upload_name),) + + request.setHeader(b"Content-Disposition", disposition.encode("ascii")) + + # cache for at least a day. 
+ # XXX: we might want to turn this off for data we don't want to + # recommend caching as it's sensitive or private - or at least + # select private. don't bother setting Expires as all our + # clients are smart enough to be happy with Cache-Control + request.setHeader(b"Cache-Control", b"public,max-age=86400,s-maxage=86400") + if file_size is not None: + request.setHeader(b"Content-Length", b"%d" % (file_size,)) + + # Tell web crawlers to not index, archive, or follow links in media. This + # should help to prevent things in the media repo from showing up in web + # search results. + request.setHeader(b"X-Robots-Tag", "noindex, nofollow, noarchive, noimageindex") + + +# separators as defined in RFC2616. SP and HT are handled separately. +# see _can_encode_filename_as_token. +_FILENAME_SEPARATOR_CHARS = { + "(", + ")", + "<", + ">", + "@", + ",", + ";", + ":", + "\\", + '"', + "/", + "[", + "]", + "?", + "=", + "{", + "}", +} + + +def _can_encode_filename_as_token(x: str) -> bool: + for c in x: + # from RFC2616: + # + # token = 1*<any CHAR except CTLs or separators> + # + # separators = "(" | ")" | "<" | ">" | "@" + # | "," | ";" | ":" | "\" | <"> + # | "/" | "[" | "]" | "?" | "=" + # | "{" | "}" | SP | HT + # + # CHAR = <any US-ASCII character (octets 0 - 127)> + # + # CTL = <any US-ASCII control character + # (octets 0 - 31) and DEL (127)> + # + if ord(c) >= 127 or ord(c) <= 32 or c in _FILENAME_SEPARATOR_CHARS: + return False + return True + + +async def respond_with_responder( + request: SynapseRequest, + responder: "Optional[Responder]", + media_type: str, + file_size: Optional[int], + upload_name: Optional[str] = None, +) -> None: + """Responds to the request with given responder. If responder is None then + returns 404. + + Args: + request + responder + media_type: The media/content type. + file_size: Size in bytes of the media. If not known it should be None + upload_name: The name of the requested file, if any. + """ + if not responder: + respond_404(request) + return + + # If we have a responder we *must* use it as a context manager. + with responder: + if request._disconnected: + logger.warning( + "Not sending response to request %s, already disconnected.", request + ) + return + + logger.debug("Responding to media request with responder %s", responder) + add_file_headers(request, media_type, file_size, upload_name) + try: + await responder.write_to_consumer(request) + except Exception as e: + # The majority of the time this will be due to the client having gone + # away. Unfortunately, Twisted simply throws a generic exception at us + # in that case. + logger.warning("Failed to write to consumer: %s %s", type(e), e) + + # Unregister the producer, if it has one, so Twisted doesn't complain + if request.producer: + request.unregisterProducer() + + finish_request(request) + + +class Responder(ABC): + """Represents a response that can be streamed to the requester. + + Responder is a context manager which *must* be used, so that any resources + held can be cleaned up. + """ + + @abstractmethod + def write_to_consumer(self, consumer: IConsumer) -> Awaitable: + """Stream response into consumer + + Args: + consumer: The consumer to stream into. 
+ + Returns: + Resolves once the response has finished being written + """ + raise NotImplementedError() + + def __enter__(self) -> None: # noqa: B027 + pass + + def __exit__( # noqa: B027 + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + pass + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class ThumbnailInfo: + """Details about a generated thumbnail.""" + + width: int + height: int + method: str + # Content type of thumbnail, e.g. image/png + type: str + # The size of the media file, in bytes. + length: Optional[int] = None + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class FileInfo: + """Details about a requested/uploaded file.""" + + # The server name where the media originated from, or None if local. + server_name: Optional[str] + # The local ID of the file. For local files this is the same as the media_id + file_id: str + # If the file is for the url preview cache + url_cache: bool = False + # Whether the file is a thumbnail or not. + thumbnail: Optional[ThumbnailInfo] = None + + # The below properties exist to maintain compatibility with third-party modules. + @property + def thumbnail_width(self) -> Optional[int]: + if not self.thumbnail: + return None + return self.thumbnail.width + + @property + def thumbnail_height(self) -> Optional[int]: + if not self.thumbnail: + return None + return self.thumbnail.height + + @property + def thumbnail_method(self) -> Optional[str]: + if not self.thumbnail: + return None + return self.thumbnail.method + + @property + def thumbnail_type(self) -> Optional[str]: + if not self.thumbnail: + return None + return self.thumbnail.type + + @property + def thumbnail_length(self) -> Optional[int]: + if not self.thumbnail: + return None + return self.thumbnail.length + + +def get_filename_from_headers(headers: Dict[bytes, List[bytes]]) -> Optional[str]: + """ + Get the filename of the downloaded file by inspecting the + Content-Disposition HTTP header. + + Args: + headers: The HTTP request headers. + + Returns: + The filename, or None. + """ + content_disposition = headers.get(b"Content-Disposition", [b""]) + + # No header, bail out. + if not content_disposition[0]: + return None + + _, params = _parse_header(content_disposition[0]) + + upload_name = None + + # First check if there is a valid UTF-8 filename + upload_name_utf8 = params.get(b"filename*", None) + if upload_name_utf8: + if upload_name_utf8.lower().startswith(b"utf-8''"): + upload_name_utf8 = upload_name_utf8[7:] + # We have a filename*= section. This MUST be ASCII, and any UTF-8 + # bytes are %-quoted. + try: + # Once it is decoded, we can then unquote the %-encoded + # parts strictly into a unicode string. + upload_name = urllib.parse.unquote( + upload_name_utf8.decode("ascii"), errors="strict" + ) + except UnicodeDecodeError: + # Incorrect UTF-8. + pass + + # If there isn't check for an ascii name. + if not upload_name: + upload_name_ascii = params.get(b"filename", None) + if upload_name_ascii and is_ascii(upload_name_ascii): + upload_name = upload_name_ascii.decode("ascii") + + # This may be None here, indicating we did not find a matching name. + return upload_name + + +def _parse_header(line: bytes) -> Tuple[bytes, Dict[bytes, bytes]]: + """Parse a Content-type like header. + + Cargo-culted from `cgi`, but works on bytes rather than strings. 
+ + Args: + line: header to be parsed + + Returns: + The main content-type, followed by the parameter dictionary + """ + parts = _parseparam(b";" + line) + key = next(parts) + pdict = {} + for p in parts: + i = p.find(b"=") + if i >= 0: + name = p[:i].strip().lower() + value = p[i + 1 :].strip() + + # strip double-quotes + if len(value) >= 2 and value[0:1] == value[-1:] == b'"': + value = value[1:-1] + value = value.replace(b"\\\\", b"\\").replace(b'\\"', b'"') + pdict[name] = value + + return key, pdict + + +def _parseparam(s: bytes) -> Generator[bytes, None, None]: + """Generator which splits the input on ;, respecting double-quoted sequences + + Cargo-culted from `cgi`, but works on bytes rather than strings. + + Args: + s: header to be parsed + + Returns: + The split input + """ + while s[:1] == b";": + s = s[1:] + + # look for the next ; + end = s.find(b";") + + # if there is an odd number of " marks between here and the next ;, skip to the + # next ; instead + while end > 0 and (s.count(b'"', 0, end) - s.count(b'\\"', 0, end)) % 2: + end = s.find(b";", end + 1) + + if end < 0: + end = len(s) + f = s[:end] + yield f.strip() + s = s[end:] diff --git a/synapse/rest/media/v1/filepath.py b/synapse/media/filepath.py
index 1f6441c412..1f6441c412 100644 --- a/synapse/rest/media/v1/filepath.py +++ b/synapse/media/filepath.py
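The Content-Disposition helpers above are easiest to see with concrete header values. A minimal sketch, using invented headers and the new synapse.media._base module path introduced by this commit:

# Sketch: how get_filename_from_headers() above resolves a filename.
from synapse.media._base import get_filename_from_headers

# An RFC 5987-style filename*= parameter wins and is %-decoded as UTF-8.
headers = {b"Content-Disposition": [b"inline; filename*=utf-8''caf%C3%A9.png"]}
print(get_filename_from_headers(headers))  # -> "café.png"

# Otherwise the plain quoted filename parameter is used, if it is ASCII.
headers = {b"Content-Disposition": [b'attachment; filename="picture.jpg"']}
print(get_filename_from_headers(headers))  # -> "picture.jpg"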
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/media/media_repository.py
index c70e1837af..b81e3c2b0c 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/media/media_repository.py
@@ -32,18 +32,10 @@ from synapse.api.errors import ( RequestSendFailed, SynapseError, ) -from synapse.config._base import ConfigError from synapse.config.repository import ThumbnailRequirement -from synapse.http.server import UnrecognizedRequestResource from synapse.http.site import SynapseRequest from synapse.logging.context import defer_to_thread -from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.types import UserID -from synapse.util.async_helpers import Linearizer -from synapse.util.retryutils import NotRetryingDestination -from synapse.util.stringutils import random_string - -from ._base import ( +from synapse.media._base import ( FileInfo, Responder, ThumbnailInfo, @@ -51,15 +43,15 @@ from ._base import ( respond_404, respond_with_responder, ) -from .config_resource import MediaConfigResource -from .download_resource import DownloadResource -from .filepath import MediaFilePaths -from .media_storage import MediaStorage -from .preview_url_resource import PreviewUrlResource -from .storage_provider import StorageProviderWrapper -from .thumbnail_resource import ThumbnailResource -from .thumbnailer import Thumbnailer, ThumbnailError -from .upload_resource import UploadResource +from synapse.media.filepath import MediaFilePaths +from synapse.media.media_storage import MediaStorage +from synapse.media.storage_provider import StorageProviderWrapper +from synapse.media.thumbnailer import Thumbnailer, ThumbnailError +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.types import UserID +from synapse.util.async_helpers import Linearizer +from synapse.util.retryutils import NotRetryingDestination +from synapse.util.stringutils import random_string if TYPE_CHECKING: from synapse.server import HomeServer @@ -1044,69 +1036,3 @@ class MediaRepository: removed_media.append(media_id) return removed_media, len(removed_media) - - -class MediaRepositoryResource(UnrecognizedRequestResource): - """File uploading and downloading. - - Uploads are POSTed to a resource which returns a token which is used to GET - the download:: - - => POST /_matrix/media/r0/upload HTTP/1.1 - Content-Type: <media-type> - Content-Length: <content-length> - - <media> - - <= HTTP/1.1 200 OK - Content-Type: application/json - - { "content_uri": "mxc://<server-name>/<media-id>" } - - => GET /_matrix/media/r0/download/<server-name>/<media-id> HTTP/1.1 - - <= HTTP/1.1 200 OK - Content-Type: <media-type> - Content-Disposition: attachment;filename=<upload-filename> - - <media> - - Clients can get thumbnails by supplying a desired width and height and - thumbnailing method:: - - => GET /_matrix/media/r0/thumbnail/<server_name> - /<media-id>?width=<w>&height=<h>&method=<m> HTTP/1.1 - - <= HTTP/1.1 200 OK - Content-Type: image/jpeg or image/png - - <thumbnail> - - The thumbnail methods are "crop" and "scale". "scale" tries to return an - image where either the width or the height is smaller than the requested - size. The client should then scale and letterbox the image if it needs to - fit within a given rectangle. "crop" tries to return an image where the - width and height are close to the requested size and the aspect matches - the requested size. The client should scale the image if it needs to fit - within a given rectangle. - """ - - def __init__(self, hs: "HomeServer"): - # If we're not configured to use it, raise if we somehow got here. 
- if not hs.config.media.can_load_media_repo: - raise ConfigError("Synapse is not configured to use a media repo.") - - super().__init__() - media_repo = hs.get_media_repository() - - self.putChild(b"upload", UploadResource(hs, media_repo)) - self.putChild(b"download", DownloadResource(hs, media_repo)) - self.putChild( - b"thumbnail", ThumbnailResource(hs, media_repo, media_repo.media_storage) - ) - if hs.config.media.url_preview_enabled: - self.putChild( - b"preview_url", - PreviewUrlResource(hs, media_repo, media_repo.media_storage), - ) - self.putChild(b"config", MediaConfigResource(hs)) diff --git a/synapse/media/media_storage.py b/synapse/media/media_storage.py new file mode 100644
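Since media_repository.py (and the modules below) now live under synapse.media, third-party code importing from the old package needs a mechanical rewrite. A rough sketch of the translation, limited to names this commit actually relocates:

# Old locations (before this commit):
#   from synapse.rest.media.v1.media_repository import MediaRepository
#   from synapse.rest.media.v1._base import FileInfo, Responder
# New locations:
from synapse.media.media_repository import MediaRepository
from synapse.media._base import FileInfo, Responder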
index 0000000000..a7e22a91e1 --- /dev/null +++ b/synapse/media/media_storage.py
@@ -0,0 +1,374 @@ +# Copyright 2018-2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import contextlib +import logging +import os +import shutil +from types import TracebackType +from typing import ( + IO, + TYPE_CHECKING, + Any, + Awaitable, + BinaryIO, + Callable, + Generator, + Optional, + Sequence, + Tuple, + Type, +) + +import attr + +from twisted.internet.defer import Deferred +from twisted.internet.interfaces import IConsumer +from twisted.protocols.basic import FileSender + +import synapse +from synapse.api.errors import NotFoundError +from synapse.logging.context import defer_to_thread, make_deferred_yieldable +from synapse.util import Clock +from synapse.util.file_consumer import BackgroundFileConsumer + +from ._base import FileInfo, Responder +from .filepath import MediaFilePaths + +if TYPE_CHECKING: + from synapse.media.storage_provider import StorageProvider + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class MediaStorage: + """Responsible for storing/fetching files from local sources. + + Args: + hs + local_media_directory: Base path where we store media on disk + filepaths + storage_providers: List of StorageProvider that are used to fetch and store files. + """ + + def __init__( + self, + hs: "HomeServer", + local_media_directory: str, + filepaths: MediaFilePaths, + storage_providers: Sequence["StorageProvider"], + ): + self.hs = hs + self.reactor = hs.get_reactor() + self.local_media_directory = local_media_directory + self.filepaths = filepaths + self.storage_providers = storage_providers + self.spam_checker = hs.get_spam_checker() + self.clock = hs.get_clock() + + async def store_file(self, source: IO, file_info: FileInfo) -> str: + """Write `source` to the on disk media store, and also any other + configured storage providers + + Args: + source: A file like object that should be written + file_info: Info about the file to store + + Returns: + the file path written to in the primary media store + """ + + with self.store_into_file(file_info) as (f, fname, finish_cb): + # Write to the main repository + await self.write_to_file(source, f) + await finish_cb() + + return fname + + async def write_to_file(self, source: IO, output: IO) -> None: + """Asynchronously write the `source` to `output`.""" + await defer_to_thread(self.reactor, _write_file_synchronously, source, output) + + @contextlib.contextmanager + def store_into_file( + self, file_info: FileInfo + ) -> Generator[Tuple[BinaryIO, str, Callable[[], Awaitable[None]]], None, None]: + """Context manager used to get a file like object to write into, as + described by file_info. + + Actually yields a 3-tuple (file, fname, finish_cb), where file is a file + like object that can be written to, fname is the absolute path of file + on disk, and finish_cb is a function that returns an awaitable. + + fname can be used to read the contents from after upload, e.g. to + generate thumbnails. 
+ + finish_cb must be called and waited on after the file has been + successfully written to. Should not be called if there was an + error. + + Args: + file_info: Info about the file to store + + Example: + + with media_storage.store_into_file(info) as (f, fname, finish_cb): + # .. write into f ... + await finish_cb() + """ + + path = self._file_info_to_path(file_info) + fname = os.path.join(self.local_media_directory, path) + + dirname = os.path.dirname(fname) + os.makedirs(dirname, exist_ok=True) + + finished_called = [False] + + try: + with open(fname, "wb") as f: + + async def finish() -> None: + # Ensure that all writes have been flushed and close the + # file. + f.flush() + f.close() + + spam_check = await self.spam_checker.check_media_file_for_spam( + ReadableFileWrapper(self.clock, fname), file_info + ) + if spam_check != synapse.module_api.NOT_SPAM: + logger.info("Blocking media due to spam checker") + # Note that we'll delete the stored media, due to the + # try/except below. The media also won't be stored in + # the DB. + # We currently ignore any additional field returned by + # the spam-check API. + raise SpamMediaException(errcode=spam_check[0]) + + for provider in self.storage_providers: + await provider.store_file(path, file_info) + + finished_called[0] = True + + yield f, fname, finish + except Exception as e: + try: + os.remove(fname) + except Exception: + pass + + raise e from None + + if not finished_called[0]: + raise Exception("Finished callback not called") + + async def fetch_media(self, file_info: FileInfo) -> Optional[Responder]: + """Attempts to fetch media described by file_info from the local cache + and configured storage providers. + + Args: + file_info + + Returns: + Returns a Responder if the file was found, otherwise None. + """ + paths = [self._file_info_to_path(file_info)] + + # fallback for remote thumbnails with no method in the filename + if file_info.thumbnail and file_info.server_name: + paths.append( + self.filepaths.remote_media_thumbnail_rel_legacy( + server_name=file_info.server_name, + file_id=file_info.file_id, + width=file_info.thumbnail.width, + height=file_info.thumbnail.height, + content_type=file_info.thumbnail.type, + ) + ) + + for path in paths: + local_path = os.path.join(self.local_media_directory, path) + if os.path.exists(local_path): + logger.debug("responding with local file %s", local_path) + return FileResponder(open(local_path, "rb")) + logger.debug("local file %s did not exist", local_path) + + for provider in self.storage_providers: + for path in paths: + res: Any = await provider.fetch(path, file_info) + if res: + logger.debug("Streaming %s from %s", path, provider) + return res + logger.debug("%s not found on %s", path, provider) + + return None + + async def ensure_media_is_in_local_cache(self, file_info: FileInfo) -> str: + """Ensures that the given file is in the local cache. Attempts to + download it from storage providers if it isn't.
+ + Args: + file_info + + Returns: + Full path to local file + """ + path = self._file_info_to_path(file_info) + local_path = os.path.join(self.local_media_directory, path) + if os.path.exists(local_path): + return local_path + + # Fallback for paths without method names + # Should be removed in the future + if file_info.thumbnail and file_info.server_name: + legacy_path = self.filepaths.remote_media_thumbnail_rel_legacy( + server_name=file_info.server_name, + file_id=file_info.file_id, + width=file_info.thumbnail.width, + height=file_info.thumbnail.height, + content_type=file_info.thumbnail.type, + ) + legacy_local_path = os.path.join(self.local_media_directory, legacy_path) + if os.path.exists(legacy_local_path): + return legacy_local_path + + dirname = os.path.dirname(local_path) + os.makedirs(dirname, exist_ok=True) + + for provider in self.storage_providers: + res: Any = await provider.fetch(path, file_info) + if res: + with res: + consumer = BackgroundFileConsumer( + open(local_path, "wb"), self.reactor + ) + await res.write_to_consumer(consumer) + await consumer.wait() + return local_path + + raise NotFoundError() + + def _file_info_to_path(self, file_info: FileInfo) -> str: + """Converts file_info into a relative path. + + The path is suitable for storing files under a directory, e.g. used to + store files on local FS under the base media repository directory. + """ + if file_info.url_cache: + if file_info.thumbnail: + return self.filepaths.url_cache_thumbnail_rel( + media_id=file_info.file_id, + width=file_info.thumbnail.width, + height=file_info.thumbnail.height, + content_type=file_info.thumbnail.type, + method=file_info.thumbnail.method, + ) + return self.filepaths.url_cache_filepath_rel(file_info.file_id) + + if file_info.server_name: + if file_info.thumbnail: + return self.filepaths.remote_media_thumbnail_rel( + server_name=file_info.server_name, + file_id=file_info.file_id, + width=file_info.thumbnail.width, + height=file_info.thumbnail.height, + content_type=file_info.thumbnail.type, + method=file_info.thumbnail.method, + ) + return self.filepaths.remote_media_filepath_rel( + file_info.server_name, file_info.file_id + ) + + if file_info.thumbnail: + return self.filepaths.local_media_thumbnail_rel( + media_id=file_info.file_id, + width=file_info.thumbnail.width, + height=file_info.thumbnail.height, + content_type=file_info.thumbnail.type, + method=file_info.thumbnail.method, + ) + return self.filepaths.local_media_filepath_rel(file_info.file_id) + + +def _write_file_synchronously(source: IO, dest: IO) -> None: + """Write `source` to the file like `dest` synchronously. Should be called + from a thread. + + Args: + source: A file like object that's to be written + dest: A file like object to be written to + """ + source.seek(0) # Ensure we read from the start of the file + shutil.copyfileobj(source, dest) + + +class FileResponder(Responder): + """Wraps an open file that can be sent to a request. + + Args: + open_file: A file like object to be streamed to the client, + is closed when finished streaming.
+ """ + + def __init__(self, open_file: IO): + self.open_file = open_file + + def write_to_consumer(self, consumer: IConsumer) -> Deferred: + return make_deferred_yieldable( + FileSender().beginFileTransfer(self.open_file, consumer) + ) + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + self.open_file.close() + + +class SpamMediaException(NotFoundError): + """The media was blocked by a spam checker, so we simply 404 the request (in + the same way as if it was quarantined). + """ + + +@attr.s(slots=True, auto_attribs=True) +class ReadableFileWrapper: + """Wrapper that allows reading a file in chunks, yielding to the reactor, + and writing to a callback. + + This is simplified `FileSender` that takes an IO object rather than an + `IConsumer`. + """ + + CHUNK_SIZE = 2**14 + + clock: Clock + path: str + + async def write_chunks_to(self, callback: Callable[[bytes], object]) -> None: + """Reads the file in chunks and calls the callback with each chunk.""" + + with open(self.path, "rb") as file: + while True: + chunk = file.read(self.CHUNK_SIZE) + if not chunk: + break + + callback(chunk) + + # We yield to the reactor by sleeping for 0 seconds. + await self.clock.sleep(0) diff --git a/synapse/rest/media/v1/oembed.py b/synapse/media/oembed.py
index 7592aa5d47..c0eaf04be5 100644 --- a/synapse/rest/media/v1/oembed.py +++ b/synapse/media/oembed.py
@@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, List, Optional import attr -from synapse.rest.media.v1.preview_html import parse_html_description +from synapse.media.preview_html import parse_html_description from synapse.types import JsonDict from synapse.util import json_decoder diff --git a/synapse/rest/media/v1/preview_html.py b/synapse/media/preview_html.py
index 516d0434f0..516d0434f0 100644 --- a/synapse/rest/media/v1/preview_html.py +++ b/synapse/media/preview_html.py
diff --git a/synapse/media/storage_provider.py b/synapse/media/storage_provider.py new file mode 100644
index 0000000000..1c9b71d69c --- /dev/null +++ b/synapse/media/storage_provider.py
@@ -0,0 +1,181 @@ +# Copyright 2018-2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc +import logging +import os +import shutil +from typing import TYPE_CHECKING, Callable, Optional + +from synapse.config._base import Config +from synapse.logging.context import defer_to_thread, run_in_background +from synapse.util.async_helpers import maybe_awaitable + +from ._base import FileInfo, Responder +from .media_storage import FileResponder + +logger = logging.getLogger(__name__) + +if TYPE_CHECKING: + from synapse.server import HomeServer + + +class StorageProvider(metaclass=abc.ABCMeta): + """A storage provider is a service that can store uploaded media and + retrieve them. + """ + + @abc.abstractmethod + async def store_file(self, path: str, file_info: FileInfo) -> None: + """Store the file described by file_info. The actual contents can be + retrieved by reading the file in file_info.upload_path. + + Args: + path: Relative path of file in local cache + file_info: The metadata of the file. + """ + + @abc.abstractmethod + async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]: + """Attempt to fetch the file described by file_info and stream it + into writer. + + Args: + path: Relative path of file in local cache + file_info: The metadata of the file. + + Returns: + Returns a Responder if the provider has the file, otherwise returns None. + """ + + +class StorageProviderWrapper(StorageProvider): + """Wraps a storage provider and provides various config options + + Args: + backend: The storage provider to wrap. + store_local: Whether to store new local files or not. + store_synchronous: Whether to wait for file to be successfully + uploaded, or to do the upload in the background. + store_remote: Whether remote media should be uploaded + """ + + def __init__( + self, + backend: StorageProvider, + store_local: bool, + store_synchronous: bool, + store_remote: bool, + ): + self.backend = backend + self.store_local = store_local + self.store_synchronous = store_synchronous + self.store_remote = store_remote + + def __str__(self) -> str: + return "StorageProviderWrapper[%s]" % (self.backend,) + + async def store_file(self, path: str, file_info: FileInfo) -> None: + if not file_info.server_name and not self.store_local: + return None + + if file_info.server_name and not self.store_remote: + return None + + if file_info.url_cache: + # The URL preview cache is short lived and not worth offloading or + # backing up. + return None + + if self.store_synchronous: + # store_file is supposed to return an Awaitable, but guard + # against improper implementations. + await maybe_awaitable(self.backend.store_file(path, file_info)) # type: ignore + else: + # TODO: Handle errors.
+ async def store() -> None: + try: + return await maybe_awaitable( + self.backend.store_file(path, file_info) + ) + except Exception: + logger.exception("Error storing file") + + run_in_background(store) + + async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]: + if file_info.url_cache: + # Files in the URL preview cache definitely aren't stored here, + # so avoid any potentially slow I/O or network access. + return None + + # store_file is supposed to return an Awaitable, but guard + # against improper implementations. + return await maybe_awaitable(self.backend.fetch(path, file_info)) + + +class FileStorageProviderBackend(StorageProvider): + """A storage provider that stores files in a directory on a filesystem. + + Args: + hs + config: The config returned by `parse_config`. + """ + + def __init__(self, hs: "HomeServer", config: str): + self.hs = hs + self.cache_directory = hs.config.media.media_store_path + self.base_directory = config + + def __str__(self) -> str: + return "FileStorageProviderBackend[%s]" % (self.base_directory,) + + async def store_file(self, path: str, file_info: FileInfo) -> None: + """See StorageProvider.store_file""" + + primary_fname = os.path.join(self.cache_directory, path) + backup_fname = os.path.join(self.base_directory, path) + + dirname = os.path.dirname(backup_fname) + os.makedirs(dirname, exist_ok=True) + + # mypy needs help inferring the type of the second parameter, which is generic + shutil_copyfile: Callable[[str, str], str] = shutil.copyfile + await defer_to_thread( + self.hs.get_reactor(), + shutil_copyfile, + primary_fname, + backup_fname, + ) + + async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]: + """See StorageProvider.fetch""" + + backup_fname = os.path.join(self.base_directory, path) + if os.path.isfile(backup_fname): + return FileResponder(open(backup_fname, "rb")) + + return None + + @staticmethod + def parse_config(config: dict) -> str: + """Called on startup to parse config supplied. This should parse + the config and raise if there is a problem. + + The returned value is passed into the constructor. + + In this case we only care about a single param, the directory, so let's + just pull that out. + """ + return Config.ensure_directory(config["directory"]) diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/media/thumbnailer.py
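FileStorageProviderBackend above doubles as a template for third-party providers: implement store_file() and fetch(), plus parse_config() if the provider takes configuration. A hypothetical mirror-to-directory provider, sketched against that interface (the class name and behaviour are invented, not part of this commit):

# Sketch of a custom StorageProvider.
import os
import shutil
from typing import Optional

from synapse.media._base import FileInfo, Responder
from synapse.media.media_storage import FileResponder
from synapse.media.storage_provider import StorageProvider


class MirrorStorageProvider(StorageProvider):
    def __init__(self, hs, config: str):
        self.cache_directory = hs.config.media.media_store_path
        self.base_directory = config

    async def store_file(self, path: str, file_info: FileInfo) -> None:
        src = os.path.join(self.cache_directory, path)
        dst = os.path.join(self.base_directory, path)
        os.makedirs(os.path.dirname(dst), exist_ok=True)
        # Blocking copy for brevity; a real provider should defer_to_thread.
        shutil.copyfile(src, dst)

    async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]:
        fname = os.path.join(self.base_directory, path)
        if os.path.isfile(fname):
            return FileResponder(open(fname, "rb"))
        return None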
index 9480cc5763..f909a4fb9a 100644 --- a/synapse/rest/media/v1/thumbnailer.py +++ b/synapse/media/thumbnailer.py
@@ -38,7 +38,6 @@ class ThumbnailError(Exception): class Thumbnailer: - FORMATS = {"image/jpeg": "JPEG", "image/png": "PNG"} @staticmethod diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py
index b01372565d..8ce5887229 100644 --- a/synapse/metrics/__init__.py +++ b/synapse/metrics/__init__.py
@@ -87,7 +87,6 @@ class LaterGauge(Collector): ] def collect(self) -> Iterable[Metric]: - g = GaugeMetricFamily(self.name, self.desc, labels=self.labels) try: diff --git a/synapse/metrics/_gc.py b/synapse/metrics/_gc.py
index b7d47ce3e7..a22c4e5bbd 100644 --- a/synapse/metrics/_gc.py +++ b/synapse/metrics/_gc.py
@@ -139,7 +139,6 @@ def install_gc_manager() -> None: class PyPyGCStats(Collector): def collect(self) -> Iterable[Metric]: - # @stats is a pretty-printer object with __str__() returning a nice table, # plus some fields that contain data from that table. # unfortunately, fields are pretty-printed themselves (i. e. '4.5MB'). diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index d22dd19d38..424239e3df 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py
@@ -64,9 +64,11 @@ from synapse.events.third_party_rules import ( CHECK_EVENT_ALLOWED_CALLBACK, CHECK_THREEPID_CAN_BE_INVITED_CALLBACK, CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK, + ON_ADD_USER_THIRD_PARTY_IDENTIFIER_CALLBACK, ON_CREATE_ROOM_CALLBACK, ON_NEW_EVENT_CALLBACK, ON_PROFILE_UPDATE_CALLBACK, + ON_REMOVE_USER_THIRD_PARTY_IDENTIFIER_CALLBACK, ON_THREEPID_BIND_CALLBACK, ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK, ) @@ -357,6 +359,12 @@ class ModuleApi: ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK ] = None, on_threepid_bind: Optional[ON_THREEPID_BIND_CALLBACK] = None, + on_add_user_third_party_identifier: Optional[ + ON_ADD_USER_THIRD_PARTY_IDENTIFIER_CALLBACK + ] = None, + on_remove_user_third_party_identifier: Optional[ + ON_REMOVE_USER_THIRD_PARTY_IDENTIFIER_CALLBACK + ] = None, ) -> None: """Registers callbacks for third party event rules capabilities. @@ -373,6 +381,8 @@ class ModuleApi: on_profile_update=on_profile_update, on_user_deactivation_status_changed=on_user_deactivation_status_changed, on_threepid_bind=on_threepid_bind, + on_add_user_third_party_identifier=on_add_user_third_party_identifier, + on_remove_user_third_party_identifier=on_remove_user_third_party_identifier, ) def register_presence_router_callbacks( @@ -1576,14 +1586,14 @@ class ModuleApi: ) requester = create_requester(user_id) - room_id_and_alias, _ = await self._hs.get_room_creation_handler().create_room( + room_id, room_alias, _ = await self._hs.get_room_creation_handler().create_room( requester=requester, config=config, ratelimit=ratelimit, creator_join_profile=creator_join_profile, ) - - return room_id_and_alias["room_id"], room_id_and_alias.get("room_alias", None) + room_alias_str = room_alias.to_string() if room_alias else None + return room_id, room_alias_str async def set_displayname( self, diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
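The module_api hunk above changes ModuleApi.create_room to return a plain tuple instead of a dict. A minimal sketch of a module consuming the new return shape (the user ID and room config are invented):

# Sketch: module code using the updated create_room return value.
async def make_demo_room(module_api) -> str:
    room_id, room_alias = await module_api.create_room(
        user_id="@alice:example.org",
        config={"preset": "public_chat", "room_alias_name": "demo"},
    )
    # room_alias is now a string such as "#demo:example.org", or None when
    # no alias was requested.
    return room_id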
index 2e917c90c4..ba12b6d79a 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -23,7 +23,6 @@ from typing import ( Mapping, Optional, Sequence, - Set, Tuple, Union, ) @@ -276,7 +275,7 @@ class BulkPushRuleEvaluator: if related_event is not None: related_events[relation_type] = _flatten_dict( related_event, - msc3783_escape_event_match_key=self.hs.config.experimental.msc3783_escape_event_match_key, + msc3873_escape_event_match_key=self.hs.config.experimental.msc3873_escape_event_match_key, ) reply_event_id = ( @@ -294,7 +293,7 @@ class BulkPushRuleEvaluator: if related_event is not None: related_events["m.in_reply_to"] = _flatten_dict( related_event, - msc3783_escape_event_match_key=self.hs.config.experimental.msc3783_escape_event_match_key, + msc3873_escape_event_match_key=self.hs.config.experimental.msc3873_escape_event_match_key, ) # indicate that this is from a fallback relation. @@ -330,7 +329,6 @@ class BulkPushRuleEvaluator: context: EventContext, event_id_to_event: Mapping[str, EventBase], ) -> None: - if ( not event.internal_metadata.is_notifiable() or event.internal_metadata.is_historical() @@ -397,30 +395,17 @@ class BulkPushRuleEvaluator: del notification_levels[key] # Pull out any user and room mentions. - mentions = event.content.get(EventContentFields.MSC3952_MENTIONS) - has_mentions = self._intentional_mentions_enabled and isinstance(mentions, dict) - user_mentions: Set[str] = set() - room_mention = False - if has_mentions: - # mypy seems to have lost the type even though it must be a dict here. - assert isinstance(mentions, dict) - # Remove out any non-string items and convert to a set. - user_mentions_raw = mentions.get("user_ids") - if isinstance(user_mentions_raw, list): - user_mentions = set( - filter(lambda item: isinstance(item, str), user_mentions_raw) - ) - # Room mention is only true if the value is exactly true. - room_mention = mentions.get("room") is True + has_mentions = ( + self._intentional_mentions_enabled + and EventContentFields.MSC3952_MENTIONS in event.content + ) evaluator = PushRuleEvaluator( _flatten_dict( event, - msc3783_escape_event_match_key=self.hs.config.experimental.msc3783_escape_event_match_key, + msc3873_escape_event_match_key=self.hs.config.experimental.msc3873_escape_event_match_key, ), has_mentions, - user_mentions, - room_mention, room_member_count, sender_power_level, notification_levels, @@ -428,7 +413,6 @@ class BulkPushRuleEvaluator: self._related_event_match_enabled, event.room_version.msc3931_push_features, self.hs.config.experimental.msc1767_enabled, # MSC3931 flag - self.hs.config.experimental.msc3758_exact_event_match, self.hs.config.experimental.msc3966_exact_event_property_contains, ) @@ -512,7 +496,7 @@ def _flatten_dict( prefix: Optional[List[str]] = None, result: Optional[Dict[str, JsonValue]] = None, *, - msc3783_escape_event_match_key: bool = False, + msc3873_escape_event_match_key: bool = False, ) -> Dict[str, JsonValue]: """ Given a JSON dictionary (or event) which might contain sub dictionaries, @@ -541,7 +525,7 @@ def _flatten_dict( if result is None: result = {} for key, value in d.items(): - if msc3783_escape_event_match_key: + if msc3873_escape_event_match_key: # Escape periods in the key with a backslash (and backslashes with an # extra backslash). This is since a period is used as a separator between # nested fields. 
@@ -557,7 +541,7 @@ def _flatten_dict( value, prefix=(prefix + [key]), result=result, - msc3783_escape_event_match_key=msc3783_escape_event_match_key, + msc3873_escape_event_match_key=msc3873_escape_event_match_key, ) # `room_version` should only ever be set when looking at the top level of an event diff --git a/synapse/push/clientformat.py b/synapse/push/clientformat.py
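The rename to msc3873_escape_event_match_key above is purely a flag rename, but the escaping it controls is subtle enough to merit an example. A sketch using the internal _flatten_dict helper (internal API, shown for illustration only):

# Sketch: key escaping in _flatten_dict with the MSC3873 flag enabled.
from synapse.push.bulk_push_rule_evaluator import _flatten_dict

flat = _flatten_dict(
    {"m.relates_to": {"rel_type": "m.thread"}},
    msc3873_escape_event_match_key=True,
)
# Nested keys are joined with "."; literal dots inside a key are escaped so
# the separator stays unambiguous:
#   {"m\\.relates_to.rel_type": "m.thread"}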
index bb76c169c6..222afbdcc8 100644 --- a/synapse/push/clientformat.py +++ b/synapse/push/clientformat.py
@@ -41,11 +41,12 @@ def format_push_rules_for_user( rulearray.append(template_rule) - pattern_type = template_rule.pop("pattern_type", None) - if pattern_type == "user_id": - template_rule["pattern"] = user.to_string() - elif pattern_type == "user_localpart": - template_rule["pattern"] = user.localpart + for type_key in ("pattern", "value"): + type_value = template_rule.pop(f"{type_key}_type", None) + if type_value == "user_id": + template_rule[type_key] = user.to_string() + elif type_value == "user_localpart": + template_rule[type_key] = user.localpart template_rule["enabled"] = enabled diff --git a/synapse/replication/http/account_data.py b/synapse/replication/http/account_data.py
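The clientformat change above generalises placeholder expansion from pattern-only templates to both pattern and value templates. Re-running the new loop standalone shows the effect for a user @alice:example.org (the rule body is invented):

# Sketch: how {pattern,value}_type placeholders are expanded per user.
template_rule = {"value_type": "user_id"}
for type_key in ("pattern", "value"):
    type_value = template_rule.pop(f"{type_key}_type", None)
    if type_value == "user_id":
        template_rule[type_key] = "@alice:example.org"  # user.to_string()
    elif type_value == "user_localpart":
        template_rule[type_key] = "alice"  # user.localpart
# template_rule is now {"value": "@alice:example.org"}; the "enabled" flag is
# filled in separately by format_push_rules_for_user.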
index 2374f810c9..111ec07e64 100644 --- a/synapse/replication/http/account_data.py +++ b/synapse/replication/http/account_data.py
@@ -265,7 +265,6 @@ class ReplicationRemoveTagRestServlet(ReplicationEndpoint): @staticmethod async def _serialize_payload(user_id: str, room_id: str, tag: str) -> JsonDict: # type: ignore[override] - return {} async def _handle_request( # type: ignore[override] diff --git a/synapse/replication/http/devices.py b/synapse/replication/http/devices.py
index ecea6fc915..cc3929dcf5 100644 --- a/synapse/replication/http/devices.py +++ b/synapse/replication/http/devices.py
@@ -195,7 +195,6 @@ class ReplicationUploadKeysForUserRestServlet(ReplicationEndpoint): async def _serialize_payload( # type: ignore[override] user_id: str, device_id: str, keys: JsonDict ) -> JsonDict: - return { "user_id": user_id, "device_id": device_id, diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py
index 9fa1060d48..67b01db67e 100644 --- a/synapse/replication/http/membership.py +++ b/synapse/replication/http/membership.py
@@ -142,17 +142,12 @@ class ReplicationRemoteKnockRestServlet(ReplicationEndpoint): } async def _handle_request( # type: ignore[override] - self, - request: SynapseRequest, - content: JsonDict, - room_id: str, - user_id: str, + self, request: SynapseRequest, content: JsonDict, room_id: str, user_id: str ) -> Tuple[int, JsonDict]: remote_room_hosts = content["remote_room_hosts"] event_content = content["content"] requester = Requester.deserialize(self.store, content["requester"]) - request.requester = requester logger.debug("remote_knock: %s on room: %s", user_id, room_id) @@ -277,16 +272,12 @@ class ReplicationRemoteRescindKnockRestServlet(ReplicationEndpoint): } async def _handle_request( # type: ignore[override] - self, - request: SynapseRequest, - content: JsonDict, - knock_event_id: str, + self, request: SynapseRequest, content: JsonDict, knock_event_id: str ) -> Tuple[int, JsonDict]: txn_id = content["txn_id"] event_content = content["content"] requester = Requester.deserialize(self.store, content["requester"]) - request.requester = requester # hopefully we're now on the master, so this won't recurse! @@ -363,3 +354,5 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: ReplicationRemoteJoinRestServlet(hs).register(http_server) ReplicationRemoteRejectInviteRestServlet(hs).register(http_server) ReplicationUserJoinedLeftRoomRestServlet(hs).register(http_server) + ReplicationRemoteKnockRestServlet(hs).register(http_server) + ReplicationRemoteRescindKnockRestServlet(hs).register(http_server) diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index cc0528bd8e..424854efbe 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py
@@ -370,15 +370,23 @@ class ReplicationDataHandler: # We measure here to get in flight counts and average waiting time. with Measure(self._clock, "repl.wait_for_stream_position"): logger.info( - "Waiting for repl stream %r to reach %s (%s)", + "Waiting for repl stream %r to reach %s (%s); currently at: %s", stream_name, position, instance_name, + current_position, ) try: await make_deferred_yieldable(deferred) except defer.TimeoutError: - logger.error("Timed out waiting for stream %s", stream_name) + logger.error( + "Timed out waiting for repl stream %r to reach %s (%s)" + "; currently at: %s", + stream_name, + position, + instance_name, + self._streams[stream_name].current_token(instance_name), + ) return logger.info( diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index fd1c0ec6af..dfc061eb5e 100644 --- a/synapse/replication/tcp/redis.py +++ b/synapse/replication/tcp/redis.py
@@ -328,7 +328,6 @@ class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory): outbound_redis_connection: txredisapi.ConnectionHandler, channel_names: List[str], ): - super().__init__( hs, uuid="subscriber", diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index 9d17eff714..347467d863 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py
@@ -238,6 +238,24 @@ class ReplicationStreamer: except Exception: logger.exception("Failed to replicate") + # The last token we send may not match the current + # token, in which case we want to send out a `POSITION` + # to tell other workers the actual current position. + if updates[-1][0] < current_token: + logger.info( + "Sending position: %s -> %s", + stream.NAME, + current_token, + ) + self.command_handler.send_command( + PositionCommand( + stream.NAME, + self._instance_name, + updates[-1][0], + current_token, + ) + ) + logger.debug("No more pending updates, breaking poke loop") finally: self.pending_updates = False diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index 14b6705862..ad9b760713 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py
@@ -139,7 +139,6 @@ class EventsStream(Stream): current_token: Token, target_row_count: int, ) -> StreamUpdateResult: - # the events stream merges together three separate sources: # * new events # * current_state changes diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py
index 14c4e6ebbb..2e19e055d3 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py
@@ -108,8 +108,7 @@ class ClientRestResource(JsonResource): if is_main_process: logout.register_servlets(hs, client_resource) sync.register_servlets(hs, client_resource) - if is_main_process: - filter.register_servlets(hs, client_resource) + filter.register_servlets(hs, client_resource) account.register_servlets(hs, client_resource) register.register_servlets(hs, client_resource) if is_main_process: @@ -140,7 +139,7 @@ class ClientRestResource(JsonResource): relations.register_servlets(hs, client_resource) if is_main_process: password_policy.register_servlets(hs, client_resource) - knock.register_servlets(hs, client_resource) + knock.register_servlets(hs, client_resource) # moving to /_synapse/admin if is_main_process: diff --git a/synapse/rest/admin/event_reports.py b/synapse/rest/admin/event_reports.py
index a3beb74e2c..c546ef7e23 100644 --- a/synapse/rest/admin/event_reports.py +++ b/synapse/rest/admin/event_reports.py
@@ -53,11 +53,11 @@ class EventReportsRestServlet(RestServlet): PATTERNS = admin_patterns("/event_reports$") def __init__(self, hs: "HomeServer"): - self.auth = hs.get_auth() - self.store = hs.get_datastores().main + self._auth = hs.get_auth() + self._store = hs.get_datastores().main async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - await assert_requester_is_admin(self.auth, request) + await assert_requester_is_admin(self._auth, request) start = parse_integer(request, "from", default=0) limit = parse_integer(request, "limit", default=100) @@ -79,7 +79,7 @@ class EventReportsRestServlet(RestServlet): errcode=Codes.INVALID_PARAM, ) - event_reports, total = await self.store.get_event_reports_paginate( + event_reports, total = await self._store.get_event_reports_paginate( start, limit, direction, user_id, room_id ) ret = {"event_reports": event_reports, "total": total} @@ -108,13 +108,13 @@ class EventReportDetailRestServlet(RestServlet): PATTERNS = admin_patterns("/event_reports/(?P<report_id>[^/]*)$") def __init__(self, hs: "HomeServer"): - self.auth = hs.get_auth() - self.store = hs.get_datastores().main + self._auth = hs.get_auth() + self._store = hs.get_datastores().main async def on_GET( self, request: SynapseRequest, report_id: str ) -> Tuple[int, JsonDict]: - await assert_requester_is_admin(self.auth, request) + await assert_requester_is_admin(self._auth, request) message = ( "The report_id parameter must be a string representing a positive integer." @@ -131,8 +131,33 @@ class EventReportDetailRestServlet(RestServlet): HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM ) - ret = await self.store.get_event_report(resolved_report_id) + ret = await self._store.get_event_report(resolved_report_id) if not ret: raise NotFoundError("Event report not found") return HTTPStatus.OK, ret + + async def on_DELETE( + self, request: SynapseRequest, report_id: str + ) -> Tuple[int, JsonDict]: + await assert_requester_is_admin(self._auth, request) + + message = ( + "The report_id parameter must be a string representing a positive integer." + ) + try: + resolved_report_id = int(report_id) + except ValueError: + raise SynapseError( + HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM + ) + + if resolved_report_id < 0: + raise SynapseError( + HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM + ) + + if await self._store.delete_event_report(resolved_report_id): + return HTTPStatus.OK, {} + + raise NotFoundError("Event report not found") diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
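The new on_DELETE handler above gives the admin API a way to drop an event report. A sketch of calling it over HTTP (the host, report ID and token are placeholders; the path follows the servlet's admin_patterns prefix):

# Sketch: deleting event report 42 via the admin API.
import requests  # third-party HTTP client, used here for brevity

resp = requests.delete(
    "https://homeserver.example.com/_synapse/admin/v1/event_reports/42",
    headers={"Authorization": "Bearer <admin_access_token>"},
)
# 200 with an empty JSON object if the report existed, 404 otherwise.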
index 1d6e4982d7..4de56bf13f 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py
@@ -75,7 +75,6 @@ class RoomRestV2Servlet(RestServlet): async def on_DELETE( self, request: SynapseRequest, room_id: str ) -> Tuple[int, JsonDict]: - requester = await self._auth.get_user_by_req(request) await assert_user_is_admin(self._auth, requester) @@ -144,7 +143,6 @@ class DeleteRoomStatusByRoomIdRestServlet(RestServlet): async def on_GET( self, request: SynapseRequest, room_id: str ) -> Tuple[int, JsonDict]: - await assert_requester_is_admin(self._auth, request) if not RoomID.is_valid(room_id): @@ -181,7 +179,6 @@ class DeleteRoomStatusByDeleteIdRestServlet(RestServlet): async def on_GET( self, request: SynapseRequest, delete_id: str ) -> Tuple[int, JsonDict]: - await assert_requester_is_admin(self._auth, request) delete_status = self._pagination_handler.get_delete_status(delete_id) @@ -438,7 +435,6 @@ class RoomStateRestServlet(RestServlet): class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet): - PATTERNS = admin_patterns("/join/(?P<room_identifier>[^/]*)$") def __init__(self, hs: "HomeServer"): diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index 0c0bf540b9..357e9a574d 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py
@@ -304,13 +304,20 @@ class UserRestServletV2(RestServlet): # remove old threepids for medium, address in del_threepids: try: - await self.auth_handler.delete_threepid( - user_id, medium, address, None + # Attempt to remove any known bindings of this third-party ID + # and user ID from identity servers. + await self.hs.get_identity_handler().try_unbind_threepid( + user_id, medium, address, id_server=None ) except Exception: logger.exception("Failed to remove threepids") raise SynapseError(500, "Failed to remove threepids") + # Delete the local association of this user ID and third-party ID. + await self.auth_handler.delete_local_threepid( + user_id, medium, address + ) + # add new threepids current_time = self.hs.get_clock().time_msec() for medium, address in add_threepids: @@ -683,8 +690,12 @@ class AccountValidityRenewServlet(RestServlet): await assert_requester_is_admin(self.auth, request) if self.account_activity_handler.on_legacy_admin_request_callback: - expiration_ts = await ( - self.account_activity_handler.on_legacy_admin_request_callback(request) + expiration_ts = ( + await ( + self.account_activity_handler.on_legacy_admin_request_callback( + request + ) + ) ) else: body = parse_json_object_from_request(request) diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py
index 662f5bf762..484d7440a4 100644 --- a/synapse/rest/client/account.py +++ b/synapse/rest/client/account.py
@@ -768,7 +768,9 @@ class ThreepidDeleteRestServlet(RestServlet): user_id = requester.user.to_string() try: - ret = await self.auth_handler.delete_threepid( + # Attempt to remove any known bindings of this third-party ID + # and user ID from identity servers. + ret = await self.hs.get_identity_handler().try_unbind_threepid( user_id, body.medium, body.address, body.id_server ) except Exception: @@ -783,6 +785,11 @@ class ThreepidDeleteRestServlet(RestServlet): else: id_server_unbind_result = "no-support" + # Delete the local association of this user ID and third-party ID. + await self.auth_handler.delete_local_threepid( + user_id, body.medium, body.address + ) + return 200, {"id_server_unbind_result": id_server_unbind_result} diff --git a/synapse/rest/client/auth.py b/synapse/rest/client/auth.py
index eb77337044..276a1b405d 100644 --- a/synapse/rest/client/auth.py +++ b/synapse/rest/client/auth.py
@@ -97,7 +97,6 @@ class AuthRestServlet(RestServlet): return None async def on_POST(self, request: Request, stagetype: str) -> None: - session = parse_string(request, "session") if not session: raise SynapseError(400, "No session supplied") diff --git a/synapse/rest/client/devices.py b/synapse/rest/client/devices.py
index 486c6dbbc5..dab4a77f7e 100644 --- a/synapse/rest/client/devices.py +++ b/synapse/rest/client/devices.py
@@ -255,7 +255,7 @@ class DehydratedDeviceServlet(RestServlet): """ - PATTERNS = client_patterns("/org.matrix.msc2697.v2/dehydrated_device", releases=()) + PATTERNS = client_patterns("/org.matrix.msc2697.v2/dehydrated_device$", releases=()) def __init__(self, hs: "HomeServer"): super().__init__() diff --git a/synapse/rest/client/events.py b/synapse/rest/client/events.py
index 782e7d14e8..694d77d287 100644 --- a/synapse/rest/client/events.py +++ b/synapse/rest/client/events.py
@@ -17,6 +17,7 @@ import logging from typing import TYPE_CHECKING, Dict, List, Tuple, Union from synapse.api.errors import SynapseError +from synapse.events.utils import SerializeEventConfig from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_string from synapse.http.site import SynapseRequest @@ -43,9 +44,8 @@ class EventStreamRestServlet(RestServlet): async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) - is_guest = requester.is_guest args: Dict[bytes, List[bytes]] = request.args # type: ignore - if is_guest: + if requester.is_guest: if b"room_id" not in args: raise SynapseError(400, "Guest users must specify room_id param") room_id = parse_string(request, "room_id") @@ -63,13 +63,12 @@ class EventStreamRestServlet(RestServlet): as_client_event = b"raw" not in args chunk = await self.event_stream_handler.get_stream( - requester.user.to_string(), + requester, pagin_config, timeout=timeout, as_client_event=as_client_event, - affect_presence=(not is_guest), + affect_presence=(not requester.is_guest), room_id=room_id, - is_guest=is_guest, ) return 200, chunk @@ -91,9 +90,12 @@ class EventRestServlet(RestServlet): requester = await self.auth.get_user_by_req(request) event = await self.event_handler.get_event(requester.user, None, event_id) - time_now = self.clock.time_msec() if event: - result = self._event_serializer.serialize_event(event, time_now) + result = self._event_serializer.serialize_event( + event, + self.clock.time_msec(), + config=SerializeEventConfig(requester=requester), + ) return 200, result else: return 404, "Event not found." diff --git a/synapse/rest/client/filter.py b/synapse/rest/client/filter.py
index cc1c2f9731..236199897c 100644 --- a/synapse/rest/client/filter.py +++ b/synapse/rest/client/filter.py
@@ -79,7 +79,6 @@ class CreateFilterRestServlet(RestServlet): async def on_POST( self, request: SynapseRequest, user_id: str ) -> Tuple[int, JsonDict]: - target_user = UserID.from_string(user_id) requester = await self.auth.get_user_by_req(request) diff --git a/synapse/rest/client/keys.py b/synapse/rest/client/keys.py
index 7873b363c0..32bb8b9a91 100644 --- a/synapse/rest/client/keys.py +++ b/synapse/rest/client/keys.py
@@ -312,15 +312,29 @@ class SigningKeyUploadServlet(RestServlet): user_id = requester.user.to_string() body = parse_json_object_from_request(request) - await self.auth_handler.validate_user_via_ui_auth( - requester, - request, - body, - "add a device signing key to your account", - # Allow skipping of UI auth since this is frequently called directly - # after login and it is silly to ask users to re-auth immediately. - can_skip_ui_auth=True, - ) + if self.hs.config.experimental.msc3967_enabled: + if await self.e2e_keys_handler.is_cross_signing_set_up_for_user(user_id): + # If we already have a master key then cross signing is set up and we require UIA to reset + await self.auth_handler.validate_user_via_ui_auth( + requester, + request, + body, + "reset the device signing key on your account", + # Do not allow skipping of UIA auth. + can_skip_ui_auth=False, + ) + # Otherwise we don't require UIA since we are setting up cross signing for first time + else: + # Previous behaviour is to always require UIA but allow it to be skipped + await self.auth_handler.validate_user_via_ui_auth( + requester, + request, + body, + "add a device signing key to your account", + # Allow skipping of UI auth since this is frequently called directly + # after login and it is silly to ask users to re-auth immediately. + can_skip_ui_auth=True, + ) result = await self.e2e_keys_handler.upload_signing_keys_for_user(user_id, body) return 200, result diff --git a/synapse/rest/client/knock.py b/synapse/rest/client/knock.py
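The MSC3967 branch above changes when user-interactive auth is demanded for cross-signing key uploads. A sketch of the client call it guards (endpoint path per the Matrix client-server spec; the body is a placeholder, real requests carry signed key objects):

# Sketch: uploading cross-signing keys. With MSC3967 enabled, the first
# upload needs no UIA, but replacing an existing master key always does.
import requests  # third-party HTTP client, used here for brevity

resp = requests.post(
    "https://homeserver.example.com/_matrix/client/v3/keys/device_signing/upload",
    headers={"Authorization": "Bearer <access_token>"},
    json={"master_key": {}},  # placeholder body
)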
index ad025c8a45..4fa66904ba 100644 --- a/synapse/rest/client/knock.py +++ b/synapse/rest/client/knock.py
@@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Awaitable, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Dict, List, Tuple from synapse.api.constants import Membership from synapse.api.errors import SynapseError @@ -24,8 +24,6 @@ from synapse.http.servlet import ( parse_strings_from_args, ) from synapse.http.site import SynapseRequest -from synapse.logging.opentracing import set_tag -from synapse.rest.client.transactions import HttpTransactionCache from synapse.types import JsonDict, RoomAlias, RoomID if TYPE_CHECKING: @@ -45,7 +43,6 @@ class KnockRoomAliasServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() - self.txns = HttpTransactionCache(hs) self.room_member_handler = hs.get_room_member_handler() self.auth = hs.get_auth() @@ -53,7 +50,6 @@ class KnockRoomAliasServlet(RestServlet): self, request: SynapseRequest, room_identifier: str, - txn_id: Optional[str] = None, ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) @@ -67,7 +63,6 @@ class KnockRoomAliasServlet(RestServlet): # twisted.web.server.Request.args is incorrectly defined as Optional[Any] args: Dict[bytes, List[bytes]] = request.args # type: ignore - remote_room_hosts = parse_strings_from_args( args, "server_name", required=False ) @@ -86,7 +81,6 @@ class KnockRoomAliasServlet(RestServlet): target=requester.user, room_id=room_id, action=Membership.KNOCK, - txn_id=txn_id, third_party_signed=None, remote_room_hosts=remote_room_hosts, content=event_content, @@ -94,15 +88,6 @@ class KnockRoomAliasServlet(RestServlet): return 200, {"room_id": room_id} - def on_PUT( - self, request: SynapseRequest, room_identifier: str, txn_id: str - ) -> Awaitable[Tuple[int, JsonDict]]: - set_tag("txn_id", txn_id) - - return self.txns.fetch_or_execute_request( - request, self.on_POST, request, room_identifier, txn_id - ) - def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: KnockRoomAliasServlet(hs).register(http_server) diff --git a/synapse/rest/client/notifications.py b/synapse/rest/client/notifications.py
index 61268e3af1..ea10042569 100644 --- a/synapse/rest/client/notifications.py +++ b/synapse/rest/client/notifications.py
@@ -72,6 +72,12 @@ class NotificationsServlet(RestServlet): next_token = None + serialize_options = SerializeEventConfig( + event_format=format_event_for_client_v2_without_room_id, + requester=requester, + ) + now = self.clock.time_msec() + for pa in push_actions: returned_pa = { "room_id": pa.room_id, @@ -81,10 +87,8 @@ class NotificationsServlet(RestServlet): "event": ( self._event_serializer.serialize_event( notif_events[pa.event_id], - self.clock.time_msec(), - config=SerializeEventConfig( - event_format=format_event_for_client_v2_without_room_id - ), + now, + config=serialize_options, ) ), } diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py
index 3cb1e7e375..bce806f2bb 100644 --- a/synapse/rest/client/register.py +++ b/synapse/rest/client/register.py
@@ -628,10 +628,12 @@ class RegisterRestServlet(RestServlet): if not password_hash: raise SynapseError(400, "Missing params: password", Codes.MISSING_PARAM) - desired_username = await ( - self.password_auth_provider.get_username_for_registration( - auth_result, - params, + desired_username = ( + await ( + self.password_auth_provider.get_username_for_registration( + auth_result, + params, + ) ) ) @@ -682,9 +684,11 @@ class RegisterRestServlet(RestServlet): session_id ) - display_name = await ( - self.password_auth_provider.get_displayname_for_registration( - auth_result, params + display_name = ( + await ( + self.password_auth_provider.get_displayname_for_registration( + auth_result, params + ) ) ) diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py
index d0db85cca7..61e4cf0213 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py
@@ -37,7 +37,7 @@ from synapse.api.errors import ( UnredactedContentDeletedError, ) from synapse.api.filtering import Filter -from synapse.events.utils import format_event_for_client_v2 +from synapse.events.utils import SerializeEventConfig, format_event_for_client_v2 from synapse.http.server import HttpServer from synapse.http.servlet import ( ResolveRoomIdMixin, @@ -160,11 +160,11 @@ class RoomCreateRestServlet(TransactionRestServlet): async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) - info, _ = await self._room_creation_handler.create_room( + room_id, _, _ = await self._room_creation_handler.create_room( requester, self.get_room_config(request) ) - return 200, info + return 200, {"room_id": room_id} def get_room_config(self, request: Request) -> JsonDict: user_supplied_config = parse_json_object_from_request(request) @@ -814,11 +814,13 @@ class RoomEventServlet(RestServlet): [event], requester.user.to_string() ) - time_now = self.clock.time_msec() # per MSC2676, /rooms/{roomId}/event/{eventId}, should return the # *original* event, rather than the edited version event_dict = self._event_serializer.serialize_event( - event, time_now, bundle_aggregations=aggregations, apply_edits=False + event, + self.clock.time_msec(), + bundle_aggregations=aggregations, + config=SerializeEventConfig(requester=requester), ) return 200, event_dict @@ -863,24 +865,30 @@ class RoomEventContextServlet(RestServlet): raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND) time_now = self.clock.time_msec() + serializer_options = SerializeEventConfig(requester=requester) results = { "events_before": self._event_serializer.serialize_events( event_context.events_before, time_now, bundle_aggregations=event_context.aggregations, + config=serializer_options, ), "event": self._event_serializer.serialize_event( event_context.event, time_now, bundle_aggregations=event_context.aggregations, + config=serializer_options, ), "events_after": self._event_serializer.serialize_events( event_context.events_after, time_now, bundle_aggregations=event_context.aggregations, + config=serializer_options, ), "state": self._event_serializer.serialize_events( - event_context.state, time_now + event_context.state, + time_now, + config=serializer_options, ), "start": event_context.start, "end": event_context.end, @@ -926,7 +934,7 @@ class RoomMembershipRestServlet(TransactionRestServlet): self.auth = hs.get_auth() def register(self, http_server: HttpServer) -> None: - # /rooms/$roomid/[invite|join|leave] + # /rooms/$roomid/[join|invite|leave|ban|unban|kick] PATTERNS = ( "/rooms/(?P<room_id>[^/]*)/" "(?P<membership_action>join|invite|leave|ban|unban|kick)" @@ -1192,7 +1200,7 @@ class SearchRestServlet(RestServlet): content = parse_json_object_from_request(request) batch = parse_string(request, "next_batch") - results = await self.search_handler.search(requester.user, content, batch) + results = await self.search_handler.search(requester, content, batch) return 200, results diff --git a/synapse/rest/client/room_batch.py b/synapse/rest/client/room_batch.py
index 10be4a781b..ef284ecc11 100644 --- a/synapse/rest/client/room_batch.py +++ b/synapse/rest/client/room_batch.py
@@ -15,9 +15,7 @@ import logging import re from http import HTTPStatus -from typing import TYPE_CHECKING, Awaitable, Tuple - -from twisted.web.server import Request +from typing import TYPE_CHECKING, Tuple from synapse.api.constants import EventContentFields from synapse.api.errors import AuthError, Codes, SynapseError @@ -30,7 +28,6 @@ from synapse.http.servlet import ( parse_strings_from_args, ) from synapse.http.site import SynapseRequest -from synapse.rest.client.transactions import HttpTransactionCache from synapse.types import JsonDict if TYPE_CHECKING: @@ -79,7 +76,6 @@ class RoomBatchSendEventRestServlet(RestServlet): self.event_creation_handler = hs.get_event_creation_handler() self.auth = hs.get_auth() self.room_batch_handler = hs.get_room_batch_handler() - self.txns = HttpTransactionCache(hs) async def on_POST( self, request: SynapseRequest, room_id: str @@ -249,16 +245,6 @@ class RoomBatchSendEventRestServlet(RestServlet): return HTTPStatus.OK, response_dict - def on_GET(self, request: Request, room_id: str) -> Tuple[int, str]: - return HTTPStatus.NOT_IMPLEMENTED, "Not implemented" - - def on_PUT( - self, request: SynapseRequest, room_id: str - ) -> Awaitable[Tuple[int, JsonDict]]: - return self.txns.fetch_or_execute_request( - request, self.on_POST, request, room_id - ) - def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: msc2716_enabled = hs.config.experimental.msc2716_enabled diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py
index f2013faeb2..e578b26fa3 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py
@@ -16,7 +16,7 @@ import logging from collections import defaultdict from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union -from synapse.api.constants import EduTypes, Membership, PresenceState +from synapse.api.constants import AccountDataTypes, EduTypes, Membership, PresenceState from synapse.api.errors import Codes, StoreError, SynapseError from synapse.api.filtering import FilterCollection from synapse.api.presence import UserPresenceState @@ -38,7 +38,7 @@ from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string from synapse.http.site import SynapseRequest from synapse.logging.opentracing import trace_with_opname -from synapse.types import JsonDict, StreamToken +from synapse.types import JsonDict, Requester, StreamToken from synapse.util import json_decoder from ._base import client_patterns, set_timeline_upper_limit @@ -139,7 +139,28 @@ class SyncRestServlet(RestServlet): device_id, ) - request_key = (user, timeout, since, filter_id, full_state, device_id) + # Stream position of the last ignored users account data event for this user, + # if we're initial syncing. + # We include this in the request key to invalidate an initial sync + # in the response cache once the set of ignored users has changed. + # (We filter out ignored users from timeline events, so our sync response + # is invalid once the set of ignored users changes.) + last_ignore_accdata_streampos: Optional[int] = None + if not since: + # No `since`, so this is an initial sync. + last_ignore_accdata_streampos = await self.store.get_latest_stream_id_for_global_account_data_by_type_for_user( + user.to_string(), AccountDataTypes.IGNORED_USER_LIST + ) + + request_key = ( + user, + timeout, + since, + filter_id, + full_state, + device_id, + last_ignore_accdata_streampos, + ) if filter_id is None: filter_collection = self.filtering.DEFAULT_FILTER_COLLECTION @@ -205,7 +226,7 @@ class SyncRestServlet(RestServlet): # We know that the requester has an access token since appservices # cannot use sync. response_content = await self.encode_response( - time_now, sync_result, requester.access_token_id, filter_collection + time_now, sync_result, requester, filter_collection ) logger.debug("Event formatting complete") @@ -216,7 +237,7 @@ class SyncRestServlet(RestServlet): self, time_now: int, sync_result: SyncResult, - access_token_id: Optional[int], + requester: Requester, filter: FilterCollection, ) -> JsonDict: logger.debug("Formatting events in sync response") @@ -229,12 +250,12 @@ class SyncRestServlet(RestServlet): serialize_options = SerializeEventConfig( event_format=event_formatter, - token_id=access_token_id, + requester=requester, only_event_fields=filter.event_fields, ) stripped_serialize_options = SerializeEventConfig( event_format=event_formatter, - token_id=access_token_id, + requester=requester, include_stripped_room_state=True, ) diff --git a/synapse/rest/media/v1/config_resource.py b/synapse/rest/media/config_resource.py
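Folding last_ignore_accdata_streampos into request_key works because stream positions only ever grow: when the ignored-user list changes, the position advances, new initial syncs compute a different key, and the stale cached response is simply never hit again. Reduced to its essentials, with compute_sync hypothetical and a plain dict standing in for the response cache::

    cache = {}

    async def cached_initial_sync(user, filter_id, ignore_streampos):
        # The stream position is part of the key, so an update to the
        # ignored-user list (which bumps the position) misses the cache.
        key = (user, filter_id, ignore_streampos)
        if key not in cache:
            cache[key] = await compute_sync(user, filter_id)  # hypothetical
        return cache[key]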
index a95804d327..a95804d327 100644 --- a/synapse/rest/media/v1/config_resource.py +++ b/synapse/rest/media/config_resource.py
diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/download_resource.py
index 048a042692..8f270cf4cc 100644 --- a/synapse/rest/media/v1/download_resource.py +++ b/synapse/rest/media/download_resource.py
@@ -22,11 +22,10 @@ from synapse.http.server import ( ) from synapse.http.servlet import parse_boolean from synapse.http.site import SynapseRequest - -from ._base import parse_media_id, respond_404 +from synapse.media._base import parse_media_id, respond_404 if TYPE_CHECKING: - from synapse.rest.media.v1.media_repository import MediaRepository + from synapse.media.media_repository import MediaRepository from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/rest/media/media_repository_resource.py b/synapse/rest/media/media_repository_resource.py new file mode 100644
index 0000000000..5ebaa3b032 --- /dev/null +++ b/synapse/rest/media/media_repository_resource.py
@@ -0,0 +1,93 @@ +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018-2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from synapse.config._base import ConfigError +from synapse.http.server import UnrecognizedRequestResource + +from .config_resource import MediaConfigResource +from .download_resource import DownloadResource +from .preview_url_resource import PreviewUrlResource +from .thumbnail_resource import ThumbnailResource +from .upload_resource import UploadResource + +if TYPE_CHECKING: + from synapse.server import HomeServer + + +class MediaRepositoryResource(UnrecognizedRequestResource): + """File uploading and downloading. + + Uploads are POSTed to a resource which returns a token which is used to GET + the download:: + + => POST /_matrix/media/r0/upload HTTP/1.1 + Content-Type: <media-type> + Content-Length: <content-length> + + <media> + + <= HTTP/1.1 200 OK + Content-Type: application/json + + { "content_uri": "mxc://<server-name>/<media-id>" } + + => GET /_matrix/media/r0/download/<server-name>/<media-id> HTTP/1.1 + + <= HTTP/1.1 200 OK + Content-Type: <media-type> + Content-Disposition: attachment;filename=<upload-filename> + + <media> + + Clients can get thumbnails by supplying a desired width and height and + thumbnailing method:: + + => GET /_matrix/media/r0/thumbnail/<server_name> + /<media-id>?width=<w>&height=<h>&method=<m> HTTP/1.1 + + <= HTTP/1.1 200 OK + Content-Type: image/jpeg or image/png + + <thumbnail> + + The thumbnail methods are "crop" and "scale". "scale" tries to return an + image where either the width or the height is smaller than the requested + size. The client should then scale and letterbox the image if it needs to + fit within a given rectangle. "crop" tries to return an image where the + width and height are close to the requested size and the aspect matches + the requested size. The client should scale the image if it needs to fit + within a given rectangle. + """ + + def __init__(self, hs: "HomeServer"): + # If we're not configured to use it, raise if we somehow got here. + if not hs.config.media.can_load_media_repo: + raise ConfigError("Synapse is not configured to use a media repo.") + + super().__init__() + media_repo = hs.get_media_repository() + + self.putChild(b"upload", UploadResource(hs, media_repo)) + self.putChild(b"download", DownloadResource(hs, media_repo)) + self.putChild( + b"thumbnail", ThumbnailResource(hs, media_repo, media_repo.media_storage) + ) + if hs.config.media.url_preview_enabled: + self.putChild( + b"preview_url", + PreviewUrlResource(hs, media_repo, media_repo.media_storage), + ) + self.putChild(b"config", MediaConfigResource(hs)) diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/preview_url_resource.py
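The exchanges in the MediaRepositoryResource docstring map onto a short client-side flow. A hedged sketch using the third-party requests library, with the homeserver URL, access token and file name invented for illustration::

    import requests

    base = "https://matrix.example.org/_matrix/media/r0"
    auth = {"Authorization": "Bearer <access-token>"}

    # Upload: POST the raw bytes, get back an mxc:// content URI.
    with open("cat.png", "rb") as f:
        content_uri = requests.post(
            f"{base}/upload",
            headers={**auth, "Content-Type": "image/png"},
            data=f.read(),
        ).json()["content_uri"]

    # Download: split mxc://<server-name>/<media-id> back apart.
    server_name, media_id = content_uri[len("mxc://"):].split("/", 1)
    media = requests.get(f"{base}/download/{server_name}/{media_id}").content

    # Thumbnail: ask for a 96x96 crop, as described above.
    thumb = requests.get(
        f"{base}/thumbnail/{server_name}/{media_id}",
        params={"width": 96, "height": 96, "method": "crop"},
    ).content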
index a8f6fd6b35..7ada728757 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/preview_url_resource.py
@@ -40,21 +40,19 @@ from synapse.http.server import ( from synapse.http.servlet import parse_integer, parse_string from synapse.http.site import SynapseRequest from synapse.logging.context import make_deferred_yieldable, run_in_background +from synapse.media._base import FileInfo, get_filename_from_headers +from synapse.media.media_storage import MediaStorage +from synapse.media.oembed import OEmbedProvider +from synapse.media.preview_html import decode_body, parse_html_to_open_graph from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.rest.media.v1._base import get_filename_from_headers -from synapse.rest.media.v1.media_storage import MediaStorage -from synapse.rest.media.v1.oembed import OEmbedProvider -from synapse.rest.media.v1.preview_html import decode_body, parse_html_to_open_graph from synapse.types import JsonDict, UserID from synapse.util import json_encoder from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.stringutils import random_string -from ._base import FileInfo - if TYPE_CHECKING: - from synapse.rest.media.v1.media_repository import MediaRepository + from synapse.media.media_repository import MediaRepository from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -163,6 +161,10 @@ class PreviewUrlResource(DirectServeJsonResource): 7. Stores the result in the database cache. 4. Returns the result. + If any additional request (e.g. from oEmbed autodiscovery, step 5.3 or + image thumbnailing, step 5.4 or 6.4) fails then the URL preview as a whole + does not fail. As much information as possible is returned. + The in-memory cache expires after 1 hour. Expired entries in the database cache (and their associated media files) are @@ -364,16 +366,25 @@ class PreviewUrlResource(DirectServeJsonResource): oembed_url = self._oembed.autodiscover_from_html(tree) og_from_oembed: JsonDict = {} if oembed_url: - oembed_info = await self._handle_url( - oembed_url, user, allow_data_urls=True - ) - ( - og_from_oembed, - author_name, - expiration_ms, - ) = await self._handle_oembed_response( - url, oembed_info, expiration_ms - ) + try: + oembed_info = await self._handle_url( + oembed_url, user, allow_data_urls=True + ) + except Exception as e: + # Fetching the oEmbed info failed; don't block the entire URL preview. + logger.warning( + "oEmbed fetch failed during URL preview: %s errored with %s", + oembed_url, + e, + ) + else: + ( + og_from_oembed, + author_name, + expiration_ms, + ) = await self._handle_oembed_response( + url, oembed_info, expiration_ms + ) # Parse Open Graph information from the HTML in case the oEmbed # response failed or is incomplete. diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/thumbnail_resource.py
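The rewritten autodiscovery block leans on try/except/else: the else branch runs only when the fetch raised nothing, so a failed oEmbed request logs a warning and execution falls through to the HTML-derived Open Graph data instead of aborting the whole preview. The shape of the pattern, with fetch and parse as hypothetical stand-ins and logger as in the module above::

    try:
        info = await fetch(oembed_url)  # may raise
    except Exception as e:
        logger.warning("oEmbed fetch failed: %s", e)
    else:
        og, author_name, ttl = parse(info)  # runs only if fetch succeeded
    # Either way, the HTML metadata below is still parsed and merged.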
index 5f725c7600..4ee2a0dbda 100644 --- a/synapse/rest/media/v1/thumbnail_resource.py +++ b/synapse/rest/media/thumbnail_resource.py
@@ -27,9 +27,7 @@ from synapse.http.server import ( ) from synapse.http.servlet import parse_integer, parse_string from synapse.http.site import SynapseRequest -from synapse.rest.media.v1.media_storage import MediaStorage - -from ._base import ( +from synapse.media._base import ( FileInfo, ThumbnailInfo, parse_media_id, @@ -37,9 +35,10 @@ from ._base import ( respond_with_file, respond_with_responder, ) +from synapse.media.media_storage import MediaStorage if TYPE_CHECKING: - from synapse.rest.media.v1.media_repository import MediaRepository + from synapse.media.media_repository import MediaRepository from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -69,7 +68,8 @@ class ThumbnailResource(DirectServeJsonResource): width = parse_integer(request, "width", required=True) height = parse_integer(request, "height", required=True) method = parse_string(request, "method", "scale") - m_type = parse_string(request, "type", "image/png") + # TODO Parse the Accept header to get a prioritised list of thumbnail types. + m_type = "image/png" if server_name == self.server_name: if self.dynamic_thumbnails: diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/upload_resource.py
index 97548b54e5..697348613b 100644 --- a/synapse/rest/media/v1/upload_resource.py +++ b/synapse/rest/media/upload_resource.py
@@ -20,10 +20,10 @@ from synapse.api.errors import Codes, SynapseError from synapse.http.server import DirectServeJsonResource, respond_with_json from synapse.http.servlet import parse_bytes_from_args from synapse.http.site import SynapseRequest -from synapse.rest.media.v1.media_storage import SpamMediaException +from synapse.media.media_storage import SpamMediaException if TYPE_CHECKING: - from synapse.rest.media.v1.media_repository import MediaRepository + from synapse.media.media_repository import MediaRepository from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py
index d30878f704..88427a5737 100644 --- a/synapse/rest/media/v1/_base.py +++ b/synapse/rest/media/v1/_base.py
@@ -1,5 +1,4 @@ -# Copyright 2014-2016 OpenMarket Ltd -# Copyright 2019-2021 The Matrix.org Foundation C.I.C. +# Copyright 2023 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,466 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# -import logging -import os -import urllib -from types import TracebackType -from typing import Awaitable, Dict, Generator, List, Optional, Tuple, Type - -import attr - -from twisted.internet.interfaces import IConsumer -from twisted.protocols.basic import FileSender -from twisted.web.server import Request - -from synapse.api.errors import Codes, SynapseError, cs_error -from synapse.http.server import finish_request, respond_with_json -from synapse.http.site import SynapseRequest -from synapse.logging.context import make_deferred_yieldable -from synapse.util.stringutils import is_ascii, parse_and_validate_server_name - -logger = logging.getLogger(__name__) - -# list all text content types that will have the charset default to UTF-8 when -# none is given -TEXT_CONTENT_TYPES = [ - "text/css", - "text/csv", - "text/html", - "text/calendar", - "text/plain", - "text/javascript", - "application/json", - "application/ld+json", - "application/rtf", - "image/svg+xml", - "text/xml", -] - - -def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]: - """Parses the server name, media ID and optional file name from the request URI - - Also performs some rough validation on the server name. - - Args: - request: The `Request`. - - Returns: - A tuple containing the parsed server name, media ID and optional file name. - - Raises: - SynapseError(404): if parsing or validation fail for any reason - """ - try: - # The type on postpath seems incorrect in Twisted 21.2.0. - postpath: List[bytes] = request.postpath # type: ignore - assert postpath - - # This allows users to append e.g. /test.png to the URL. Useful for - # clients that parse the URL to see content type. 
- server_name_bytes, media_id_bytes = postpath[:2] - server_name = server_name_bytes.decode("utf-8") - media_id = media_id_bytes.decode("utf8") - - # Validate the server name, raising if invalid - parse_and_validate_server_name(server_name) - - file_name = None - if len(postpath) > 2: - try: - file_name = urllib.parse.unquote(postpath[-1].decode("utf-8")) - except UnicodeDecodeError: - pass - return server_name, media_id, file_name - except Exception: - raise SynapseError( - 404, "Invalid media id token %r" % (request.postpath,), Codes.UNKNOWN - ) - - -def respond_404(request: SynapseRequest) -> None: - respond_with_json( - request, - 404, - cs_error("Not found %r" % (request.postpath,), code=Codes.NOT_FOUND), - send_cors=True, - ) - - -async def respond_with_file( - request: SynapseRequest, - media_type: str, - file_path: str, - file_size: Optional[int] = None, - upload_name: Optional[str] = None, -) -> None: - logger.debug("Responding with %r", file_path) - - if os.path.isfile(file_path): - if file_size is None: - stat = os.stat(file_path) - file_size = stat.st_size - - add_file_headers(request, media_type, file_size, upload_name) - - with open(file_path, "rb") as f: - await make_deferred_yieldable(FileSender().beginFileTransfer(f, request)) - - finish_request(request) - else: - respond_404(request) - - -def add_file_headers( - request: Request, - media_type: str, - file_size: Optional[int], - upload_name: Optional[str], -) -> None: - """Adds the correct response headers in preparation for responding with the - media. - - Args: - request - media_type: The media/content type. - file_size: Size in bytes of the media, if known. - upload_name: The name of the requested file, if any. - """ - - def _quote(x: str) -> str: - return urllib.parse.quote(x.encode("utf-8")) - - # Default to a UTF-8 charset for text content types. - # ex, uses UTF-8 for 'text/css' but not 'text/css; charset=UTF-16' - if media_type.lower() in TEXT_CONTENT_TYPES: - content_type = media_type + "; charset=UTF-8" - else: - content_type = media_type - - request.setHeader(b"Content-Type", content_type.encode("UTF-8")) - if upload_name: - # RFC6266 section 4.1 [1] defines both `filename` and `filename*`. - # - # `filename` is defined to be a `value`, which is defined by RFC2616 - # section 3.6 [2] to be a `token` or a `quoted-string`, where a `token` - # is (essentially) a single US-ASCII word, and a `quoted-string` is a - # US-ASCII string surrounded by double-quotes, using backslash as an - # escape character. Note that %-encoding is *not* permitted. - # - # `filename*` is defined to be an `ext-value`, which is defined in - # RFC5987 section 3.2.1 [3] to be `charset "'" [ language ] "'" value-chars`, - # where `value-chars` is essentially a %-encoded string in the given charset. - # - # [1]: https://tools.ietf.org/html/rfc6266#section-4.1 - # [2]: https://tools.ietf.org/html/rfc2616#section-3.6 - # [3]: https://tools.ietf.org/html/rfc5987#section-3.2.1 - - # We avoid the quoted-string version of `filename`, because (a) synapse didn't - # correctly interpret those as of 0.99.2 and (b) they are a bit of a pain and we - # may as well just do the filename* version. - if _can_encode_filename_as_token(upload_name): - disposition = "inline; filename=%s" % (upload_name,) - else: - disposition = "inline; filename*=utf-8''%s" % (_quote(upload_name),) - - request.setHeader(b"Content-Disposition", disposition.encode("ascii")) - - # cache for at least a day. 
- # XXX: we might want to turn this off for data we don't want to - # recommend caching as it's sensitive or private - or at least - # select private. don't bother setting Expires as all our - # clients are smart enough to be happy with Cache-Control - request.setHeader(b"Cache-Control", b"public,max-age=86400,s-maxage=86400") - if file_size is not None: - request.setHeader(b"Content-Length", b"%d" % (file_size,)) - - # Tell web crawlers to not index, archive, or follow links in media. This - # should help to prevent things in the media repo from showing up in web - # search results. - request.setHeader(b"X-Robots-Tag", "noindex, nofollow, noarchive, noimageindex") - - -# separators as defined in RFC2616. SP and HT are handled separately. -# see _can_encode_filename_as_token. -_FILENAME_SEPARATOR_CHARS = { - "(", - ")", - "<", - ">", - "@", - ",", - ";", - ":", - "\\", - '"', - "/", - "[", - "]", - "?", - "=", - "{", - "}", -} - - -def _can_encode_filename_as_token(x: str) -> bool: - for c in x: - # from RFC2616: - # - # token = 1*<any CHAR except CTLs or separators> - # - # separators = "(" | ")" | "<" | ">" | "@" - # | "," | ";" | ":" | "\" | <"> - # | "/" | "[" | "]" | "?" | "=" - # | "{" | "}" | SP | HT - # - # CHAR = <any US-ASCII character (octets 0 - 127)> - # - # CTL = <any US-ASCII control character - # (octets 0 - 31) and DEL (127)> - # - if ord(c) >= 127 or ord(c) <= 32 or c in _FILENAME_SEPARATOR_CHARS: - return False - return True - - -async def respond_with_responder( - request: SynapseRequest, - responder: "Optional[Responder]", - media_type: str, - file_size: Optional[int], - upload_name: Optional[str] = None, -) -> None: - """Responds to the request with given responder. If responder is None then - returns 404. - - Args: - request - responder - media_type: The media/content type. - file_size: Size in bytes of the media. If not known it should be None - upload_name: The name of the requested file, if any. - """ - if not responder: - respond_404(request) - return - - # If we have a responder we *must* use it as a context manager. - with responder: - if request._disconnected: - logger.warning( - "Not sending response to request %s, already disconnected.", request - ) - return - - logger.debug("Responding to media request with responder %s", responder) - add_file_headers(request, media_type, file_size, upload_name) - try: - - await responder.write_to_consumer(request) - except Exception as e: - # The majority of the time this will be due to the client having gone - # away. Unfortunately, Twisted simply throws a generic exception at us - # in that case. - logger.warning("Failed to write to consumer: %s %s", type(e), e) - - # Unregister the producer, if it has one, so Twisted doesn't complain - if request.producer: - request.unregisterProducer() - - finish_request(request) - - -class Responder: - """Represents a response that can be streamed to the requester. - - Responder is a context manager which *must* be used, so that any resources - held can be cleaned up. - """ - - def write_to_consumer(self, consumer: IConsumer) -> Awaitable: - """Stream response into consumer - - Args: - consumer: The consumer to stream into. 
- - Returns: - Resolves once the response has finished being written - """ - - def __enter__(self) -> None: - pass - - def __exit__( - self, - exc_type: Optional[Type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType], - ) -> None: - pass - - -@attr.s(slots=True, frozen=True, auto_attribs=True) -class ThumbnailInfo: - """Details about a generated thumbnail.""" - - width: int - height: int - method: str - # Content type of thumbnail, e.g. image/png - type: str - # The size of the media file, in bytes. - length: Optional[int] = None - - -@attr.s(slots=True, frozen=True, auto_attribs=True) -class FileInfo: - """Details about a requested/uploaded file.""" - - # The server name where the media originated from, or None if local. - server_name: Optional[str] - # The local ID of the file. For local files this is the same as the media_id - file_id: str - # If the file is for the url preview cache - url_cache: bool = False - # Whether the file is a thumbnail or not. - thumbnail: Optional[ThumbnailInfo] = None - - # The below properties exist to maintain compatibility with third-party modules. - @property - def thumbnail_width(self) -> Optional[int]: - if not self.thumbnail: - return None - return self.thumbnail.width - - @property - def thumbnail_height(self) -> Optional[int]: - if not self.thumbnail: - return None - return self.thumbnail.height - - @property - def thumbnail_method(self) -> Optional[str]: - if not self.thumbnail: - return None - return self.thumbnail.method - - @property - def thumbnail_type(self) -> Optional[str]: - if not self.thumbnail: - return None - return self.thumbnail.type - - @property - def thumbnail_length(self) -> Optional[int]: - if not self.thumbnail: - return None - return self.thumbnail.length - - -def get_filename_from_headers(headers: Dict[bytes, List[bytes]]) -> Optional[str]: - """ - Get the filename of the downloaded file by inspecting the - Content-Disposition HTTP header. - - Args: - headers: The HTTP request headers. - - Returns: - The filename, or None. - """ - content_disposition = headers.get(b"Content-Disposition", [b""]) - - # No header, bail out. - if not content_disposition[0]: - return None - - _, params = _parse_header(content_disposition[0]) - - upload_name = None - - # First check if there is a valid UTF-8 filename - upload_name_utf8 = params.get(b"filename*", None) - if upload_name_utf8: - if upload_name_utf8.lower().startswith(b"utf-8''"): - upload_name_utf8 = upload_name_utf8[7:] - # We have a filename*= section. This MUST be ASCII, and any UTF-8 - # bytes are %-quoted. - try: - # Once it is decoded, we can then unquote the %-encoded - # parts strictly into a unicode string. - upload_name = urllib.parse.unquote( - upload_name_utf8.decode("ascii"), errors="strict" - ) - except UnicodeDecodeError: - # Incorrect UTF-8. - pass - - # If there isn't check for an ascii name. - if not upload_name: - upload_name_ascii = params.get(b"filename", None) - if upload_name_ascii and is_ascii(upload_name_ascii): - upload_name = upload_name_ascii.decode("ascii") - - # This may be None here, indicating we did not find a matching name. - return upload_name - - -def _parse_header(line: bytes) -> Tuple[bytes, Dict[bytes, bytes]]: - """Parse a Content-type like header. - - Cargo-culted from `cgi`, but works on bytes rather than strings. 
- - Args: - line: header to be parsed - - Returns: - The main content-type, followed by the parameter dictionary - """ - parts = _parseparam(b";" + line) - key = next(parts) - pdict = {} - for p in parts: - i = p.find(b"=") - if i >= 0: - name = p[:i].strip().lower() - value = p[i + 1 :].strip() - - # strip double-quotes - if len(value) >= 2 and value[0:1] == value[-1:] == b'"': - value = value[1:-1] - value = value.replace(b"\\\\", b"\\").replace(b'\\"', b'"') - pdict[name] = value - - return key, pdict - - -def _parseparam(s: bytes) -> Generator[bytes, None, None]: - """Generator which splits the input on ;, respecting double-quoted sequences - - Cargo-culted from `cgi`, but works on bytes rather than strings. - - Args: - s: header to be parsed - - Returns: - The split input - """ - while s[:1] == b";": - s = s[1:] - - # look for the next ; - end = s.find(b";") - - # if there is an odd number of " marks between here and the next ;, skip to the - # next ; instead - while end > 0 and (s.count(b'"', 0, end) - s.count(b'\\"', 0, end)) % 2: - end = s.find(b";", end + 1) - - if end < 0: - end = len(s) - f = s[:end] - yield f.strip() - s = s[end:] +# This exists purely for backwards compatibility with media providers and spam checkers. +from synapse.media._base import FileInfo, Responder # noqa: F401 diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
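Gutting synapse/rest/media/v1/_base.py down to a re-export keeps old import paths working: a module or spam checker written against the previous layout resolves to exactly the same objects as code using the new one. For example, inside a checkout with these changes applied, both of the following imports should yield the same class::

    from synapse.rest.media.v1._base import FileInfo as OldFileInfo
    from synapse.media._base import FileInfo as NewFileInfo

    # The shim merely re-exports, so the two names are one class.
    assert OldFileInfo is NewFileInfo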
index db25848744..11b0e8e231 100644 --- a/synapse/rest/media/v1/media_storage.py +++ b/synapse/rest/media/v1/media_storage.py
@@ -1,4 +1,4 @@ -# Copyright 2018-2021 The Matrix.org Foundation C.I.C. +# Copyright 2023 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,364 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import contextlib -import logging -import os -import shutil -from types import TracebackType -from typing import ( - IO, - TYPE_CHECKING, - Any, - Awaitable, - BinaryIO, - Callable, - Generator, - Optional, - Sequence, - Tuple, - Type, -) - -import attr - -from twisted.internet.defer import Deferred -from twisted.internet.interfaces import IConsumer -from twisted.protocols.basic import FileSender - -import synapse -from synapse.api.errors import NotFoundError -from synapse.logging.context import defer_to_thread, make_deferred_yieldable -from synapse.util import Clock -from synapse.util.file_consumer import BackgroundFileConsumer - -from ._base import FileInfo, Responder -from .filepath import MediaFilePaths - -if TYPE_CHECKING: - from synapse.rest.media.v1.storage_provider import StorageProvider - from synapse.server import HomeServer - -logger = logging.getLogger(__name__) - - -class MediaStorage: - """Responsible for storing/fetching files from local sources. - - Args: - hs - local_media_directory: Base path where we store media on disk - filepaths - storage_providers: List of StorageProvider that are used to fetch and store files. - """ - - def __init__( - self, - hs: "HomeServer", - local_media_directory: str, - filepaths: MediaFilePaths, - storage_providers: Sequence["StorageProvider"], - ): - self.hs = hs - self.reactor = hs.get_reactor() - self.local_media_directory = local_media_directory - self.filepaths = filepaths - self.storage_providers = storage_providers - self.spam_checker = hs.get_spam_checker() - self.clock = hs.get_clock() - - async def store_file(self, source: IO, file_info: FileInfo) -> str: - """Write `source` to the on disk media store, and also any other - configured storage providers - - Args: - source: A file like object that should be written - file_info: Info about the file to store - - Returns: - the file path written to in the primary media store - """ - - with self.store_into_file(file_info) as (f, fname, finish_cb): - # Write to the main repository - await self.write_to_file(source, f) - await finish_cb() - - return fname - - async def write_to_file(self, source: IO, output: IO) -> None: - """Asynchronously write the `source` to `output`.""" - await defer_to_thread(self.reactor, _write_file_synchronously, source, output) - - @contextlib.contextmanager - def store_into_file( - self, file_info: FileInfo - ) -> Generator[Tuple[BinaryIO, str, Callable[[], Awaitable[None]]], None, None]: - """Context manager used to get a file like object to write into, as - described by file_info. - - Actually yields a 3-tuple (file, fname, finish_cb), where file is a file - like object that can be written to, fname is the absolute path of file - on disk, and finish_cb is a function that returns an awaitable. - - fname can be used to read the contents from after upload, e.g. to - generate thumbnails. - - finish_cb must be called and waited on after the file has been - successfully been written to. Should not be called if there was an - error. 
- - Args: - file_info: Info about the file to store - - Example: - - with media_storage.store_into_file(info) as (f, fname, finish_cb): - # .. write into f ... - await finish_cb() - """ - - path = self._file_info_to_path(file_info) - fname = os.path.join(self.local_media_directory, path) - - dirname = os.path.dirname(fname) - os.makedirs(dirname, exist_ok=True) - - finished_called = [False] - - try: - with open(fname, "wb") as f: - - async def finish() -> None: - # Ensure that all writes have been flushed and close the - # file. - f.flush() - f.close() - - spam_check = await self.spam_checker.check_media_file_for_spam( - ReadableFileWrapper(self.clock, fname), file_info - ) - if spam_check != synapse.module_api.NOT_SPAM: - logger.info("Blocking media due to spam checker") - # Note that we'll delete the stored media, due to the - # try/except below. The media also won't be stored in - # the DB. - # We currently ignore any additional field returned by - # the spam-check API. - raise SpamMediaException(errcode=spam_check[0]) - - for provider in self.storage_providers: - await provider.store_file(path, file_info) - - finished_called[0] = True - - yield f, fname, finish - except Exception as e: - try: - os.remove(fname) - except Exception: - pass - - raise e from None - - if not finished_called: - raise Exception("Finished callback not called") - - async def fetch_media(self, file_info: FileInfo) -> Optional[Responder]: - """Attempts to fetch media described by file_info from the local cache - and configured storage providers. - - Args: - file_info - - Returns: - Returns a Responder if the file was found, otherwise None. - """ - paths = [self._file_info_to_path(file_info)] - - # fallback for remote thumbnails with no method in the filename - if file_info.thumbnail and file_info.server_name: - paths.append( - self.filepaths.remote_media_thumbnail_rel_legacy( - server_name=file_info.server_name, - file_id=file_info.file_id, - width=file_info.thumbnail.width, - height=file_info.thumbnail.height, - content_type=file_info.thumbnail.type, - ) - ) - - for path in paths: - local_path = os.path.join(self.local_media_directory, path) - if os.path.exists(local_path): - logger.debug("responding with local file %s", local_path) - return FileResponder(open(local_path, "rb")) - logger.debug("local file %s did not exist", local_path) - - for provider in self.storage_providers: - for path in paths: - res: Any = await provider.fetch(path, file_info) - if res: - logger.debug("Streaming %s from %s", path, provider) - return res - logger.debug("%s not found on %s", path, provider) - - return None - - async def ensure_media_is_in_local_cache(self, file_info: FileInfo) -> str: - """Ensures that the given file is in the local cache. Attempts to - download it from storage providers if it isn't. 
- - Args: - file_info - - Returns: - Full path to local file - """ - path = self._file_info_to_path(file_info) - local_path = os.path.join(self.local_media_directory, path) - if os.path.exists(local_path): - return local_path - - # Fallback for paths without method names - # Should be removed in the future - if file_info.thumbnail and file_info.server_name: - legacy_path = self.filepaths.remote_media_thumbnail_rel_legacy( - server_name=file_info.server_name, - file_id=file_info.file_id, - width=file_info.thumbnail.width, - height=file_info.thumbnail.height, - content_type=file_info.thumbnail.type, - ) - legacy_local_path = os.path.join(self.local_media_directory, legacy_path) - if os.path.exists(legacy_local_path): - return legacy_local_path - - dirname = os.path.dirname(local_path) - os.makedirs(dirname, exist_ok=True) - - for provider in self.storage_providers: - res: Any = await provider.fetch(path, file_info) - if res: - with res: - consumer = BackgroundFileConsumer( - open(local_path, "wb"), self.reactor - ) - await res.write_to_consumer(consumer) - await consumer.wait() - return local_path - - raise NotFoundError() - - def _file_info_to_path(self, file_info: FileInfo) -> str: - """Converts file_info into a relative path. - - The path is suitable for storing files under a directory, e.g. used to - store files on local FS under the base media repository directory. - """ - if file_info.url_cache: - if file_info.thumbnail: - return self.filepaths.url_cache_thumbnail_rel( - media_id=file_info.file_id, - width=file_info.thumbnail.width, - height=file_info.thumbnail.height, - content_type=file_info.thumbnail.type, - method=file_info.thumbnail.method, - ) - return self.filepaths.url_cache_filepath_rel(file_info.file_id) - - if file_info.server_name: - if file_info.thumbnail: - return self.filepaths.remote_media_thumbnail_rel( - server_name=file_info.server_name, - file_id=file_info.file_id, - width=file_info.thumbnail.width, - height=file_info.thumbnail.height, - content_type=file_info.thumbnail.type, - method=file_info.thumbnail.method, - ) - return self.filepaths.remote_media_filepath_rel( - file_info.server_name, file_info.file_id - ) - - if file_info.thumbnail: - return self.filepaths.local_media_thumbnail_rel( - media_id=file_info.file_id, - width=file_info.thumbnail.width, - height=file_info.thumbnail.height, - content_type=file_info.thumbnail.type, - method=file_info.thumbnail.method, - ) - return self.filepaths.local_media_filepath_rel(file_info.file_id) - - -def _write_file_synchronously(source: IO, dest: IO) -> None: - """Write `source` to the file like `dest` synchronously. Should be called - from a thread. - - Args: - source: A file like object that's to be written - dest: A file like object to be written to - """ - source.seek(0) # Ensure we read from the start of the file - shutil.copyfileobj(source, dest) - - -class FileResponder(Responder): - """Wraps an open file that can be sent to a request. - - Args: - open_file: A file like object to be streamed ot the client, - is closed when finished streaming. 
- """ - - def __init__(self, open_file: IO): - self.open_file = open_file - - def write_to_consumer(self, consumer: IConsumer) -> Deferred: - return make_deferred_yieldable( - FileSender().beginFileTransfer(self.open_file, consumer) - ) - - def __exit__( - self, - exc_type: Optional[Type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType], - ) -> None: - self.open_file.close() - - -class SpamMediaException(NotFoundError): - """The media was blocked by a spam checker, so we simply 404 the request (in - the same way as if it was quarantined). - """ - - -@attr.s(slots=True, auto_attribs=True) -class ReadableFileWrapper: - """Wrapper that allows reading a file in chunks, yielding to the reactor, - and writing to a callback. - - This is simplified `FileSender` that takes an IO object rather than an - `IConsumer`. - """ - - CHUNK_SIZE = 2**14 - - clock: Clock - path: str - - async def write_chunks_to(self, callback: Callable[[bytes], object]) -> None: - """Reads the file in chunks and calls the callback with each chunk.""" - - with open(self.path, "rb") as file: - while True: - chunk = file.read(self.CHUNK_SIZE) - if not chunk: - break - - callback(chunk) +# - # We yield to the reactor by sleeping for 0 seconds. - await self.clock.sleep(0) +# This exists purely for backwards compatibility with spam checkers. +from synapse.media.media_storage import ReadableFileWrapper # noqa: F401 diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py
index 1c9b71d69c..d7653f30ae 100644 --- a/synapse/rest/media/v1/storage_provider.py +++ b/synapse/rest/media/v1/storage_provider.py
@@ -1,4 +1,4 @@ -# Copyright 2018-2021 The Matrix.org Foundation C.I.C. +# Copyright 2023 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,171 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# -import abc -import logging -import os -import shutil -from typing import TYPE_CHECKING, Callable, Optional - -from synapse.config._base import Config -from synapse.logging.context import defer_to_thread, run_in_background -from synapse.util.async_helpers import maybe_awaitable - -from ._base import FileInfo, Responder -from .media_storage import FileResponder - -logger = logging.getLogger(__name__) - -if TYPE_CHECKING: - from synapse.server import HomeServer - - -class StorageProvider(metaclass=abc.ABCMeta): - """A storage provider is a service that can store uploaded media and - retrieve them. - """ - - @abc.abstractmethod - async def store_file(self, path: str, file_info: FileInfo) -> None: - """Store the file described by file_info. The actual contents can be - retrieved by reading the file in file_info.upload_path. - - Args: - path: Relative path of file in local cache - file_info: The metadata of the file. - """ - - @abc.abstractmethod - async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]: - """Attempt to fetch the file described by file_info and stream it - into writer. - - Args: - path: Relative path of file in local cache - file_info: The metadata of the file. - - Returns: - Returns a Responder if the provider has the file, otherwise returns None. - """ - - -class StorageProviderWrapper(StorageProvider): - """Wraps a storage provider and provides various config options - - Args: - backend: The storage provider to wrap. - store_local: Whether to store new local files or not. - store_synchronous: Whether to wait for file to be successfully - uploaded, or todo the upload in the background. - store_remote: Whether remote media should be uploaded - """ - - def __init__( - self, - backend: StorageProvider, - store_local: bool, - store_synchronous: bool, - store_remote: bool, - ): - self.backend = backend - self.store_local = store_local - self.store_synchronous = store_synchronous - self.store_remote = store_remote - - def __str__(self) -> str: - return "StorageProviderWrapper[%s]" % (self.backend,) - - async def store_file(self, path: str, file_info: FileInfo) -> None: - if not file_info.server_name and not self.store_local: - return None - - if file_info.server_name and not self.store_remote: - return None - - if file_info.url_cache: - # The URL preview cache is short lived and not worth offloading or - # backing up. - return None - - if self.store_synchronous: - # store_file is supposed to return an Awaitable, but guard - # against improper implementations. - await maybe_awaitable(self.backend.store_file(path, file_info)) # type: ignore - else: - # TODO: Handle errors. - async def store() -> None: - try: - return await maybe_awaitable( - self.backend.store_file(path, file_info) - ) - except Exception: - logger.exception("Error storing file") - - run_in_background(store) - - async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]: - if file_info.url_cache: - # Files in the URL preview cache definitely aren't stored here, - # so avoid any potentially slow I/O or network access. 
- return None - - # store_file is supposed to return an Awaitable, but guard - # against improper implementations. - return await maybe_awaitable(self.backend.fetch(path, file_info)) - - -class FileStorageProviderBackend(StorageProvider): - """A storage provider that stores files in a directory on a filesystem. - - Args: - hs - config: The config returned by `parse_config`. - """ - - def __init__(self, hs: "HomeServer", config: str): - self.hs = hs - self.cache_directory = hs.config.media.media_store_path - self.base_directory = config - - def __str__(self) -> str: - return "FileStorageProviderBackend[%s]" % (self.base_directory,) - - async def store_file(self, path: str, file_info: FileInfo) -> None: - """See StorageProvider.store_file""" - - primary_fname = os.path.join(self.cache_directory, path) - backup_fname = os.path.join(self.base_directory, path) - - dirname = os.path.dirname(backup_fname) - os.makedirs(dirname, exist_ok=True) - - # mypy needs help inferring the type of the second parameter, which is generic - shutil_copyfile: Callable[[str, str], str] = shutil.copyfile - await defer_to_thread( - self.hs.get_reactor(), - shutil_copyfile, - primary_fname, - backup_fname, - ) - - async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]: - """See StorageProvider.fetch""" - - backup_fname = os.path.join(self.base_directory, path) - if os.path.isfile(backup_fname): - return FileResponder(open(backup_fname, "rb")) - - return None - - @staticmethod - def parse_config(config: dict) -> str: - """Called on startup to parse config supplied. This should parse - the config and raise if there is a problem. - - The returned value is passed into the constructor. - - In this case we only care about a single param, the directory, so let's - just pull that out. - """ - return Config.ensure_directory(config["directory"]) +# This exists purely for backwards compatibility with media providers. +from synapse.media.storage_provider import StorageProvider # noqa: F401 diff --git a/synapse/server.py b/synapse/server.py
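The StorageProvider interface deleted above (and re-exported for compatibility) is what custom media backends implement: parse_config validates the module's configuration at startup and its return value is handed to the constructor, store_file pushes a file out, and fetch returns a Responder or None. A minimal do-nothing sketch against that interface; the class name and behaviour are invented::

    from typing import Optional

    from synapse.media._base import FileInfo, Responder

    class DiscardingStorageProvider:
        """Hypothetical provider that stores nothing and serves nothing."""

        def __init__(self, hs, config: str):
            self.base_directory = config

        @staticmethod
        def parse_config(config: dict) -> str:
            # Raise here if the config is unusable; the return value
            # is passed back into __init__.
            return config["directory"]

        async def store_file(self, path: str, file_info: FileInfo) -> None:
            pass  # deliberately drop the file

        async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]:
            return None  # never has anything cached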
index e5a3475247..df80fc1beb 100644 --- a/synapse/server.py +++ b/synapse/server.py
@@ -105,6 +105,7 @@ from synapse.handlers.typing import FollowerTypingHandler, TypingWriterHandler from synapse.handlers.user_directory import UserDirectoryHandler from synapse.http.client import InsecureInterceptableContextFactory, SimpleHttpClient from synapse.http.matrixfederationclient import MatrixFederationHttpClient +from synapse.media.media_repository import MediaRepository from synapse.metrics.common_usage_metrics import CommonUsageMetricsManager from synapse.module_api import ModuleApi from synapse.notifier import Notifier, ReplicationNotifier @@ -115,10 +116,7 @@ from synapse.replication.tcp.external_cache import ExternalCache from synapse.replication.tcp.handler import ReplicationCommandHandler from synapse.replication.tcp.resource import ReplicationStreamer from synapse.replication.tcp.streams import STREAMS_MAP, Stream -from synapse.rest.media.v1.media_repository import ( - MediaRepository, - MediaRepositoryResource, -) +from synapse.rest.media.media_repository_resource import MediaRepositoryResource from synapse.server_notices.server_notices_manager import ServerNoticesManager from synapse.server_notices.server_notices_sender import ServerNoticesSender from synapse.server_notices.worker_server_notices_sender import ( @@ -745,7 +743,7 @@ class HomeServer(metaclass=abc.ABCMeta): @cache_in_self def get_event_client_serializer(self) -> EventClientSerializer: - return EventClientSerializer(self.config.experimental.msc3925_inhibit_edit) + return EventClientSerializer() @cache_in_self def get_password_policy_handler(self) -> PasswordPolicyHandler: diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py
index 564e3705c2..9732dbdb6e 100644 --- a/synapse/server_notices/server_notices_manager.py +++ b/synapse/server_notices/server_notices_manager.py
@@ -178,7 +178,7 @@ class ServerNoticesManager: "avatar_url": self._config.servernotices.server_notices_mxid_avatar_url, } - info, _ = await self._room_creation_handler.create_room( + room_id, _, _ = await self._room_creation_handler.create_room( requester, config={ "preset": RoomCreationPreset.PRIVATE_CHAT, @@ -188,7 +188,6 @@ class ServerNoticesManager: ratelimit=False, creator_join_profile=join_profile, ) - room_id = info["room_id"] self.maybe_get_notice_room_for_user.invalidate((user_id,)) diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 837dc7646e..dc3948c170 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py
@@ -43,7 +43,7 @@ from .event_federation import EventFederationStore from .event_push_actions import EventPushActionsStore from .events_bg_updates import EventsBackgroundUpdatesStore from .events_forward_extremities import EventForwardExtremitiesStore -from .filtering import FilteringStore +from .filtering import FilteringWorkerStore from .keys import KeyStore from .lock import LockStore from .media_repository import MediaRepositoryStore @@ -99,7 +99,7 @@ class DataStore( EventFederationStore, MediaRepositoryStore, RejectionsStore, - FilteringStore, + FilteringWorkerStore, PusherStore, PushRuleStore, ApplicationServiceTransactionStore, diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 95567826f2..a9843f6e17 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py
@@ -40,7 +40,6 @@ from synapse.storage.databases.main.push_rule import PushRulesWorkerStore from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import ( AbstractStreamIdGenerator, - AbstractStreamIdTracker, MultiWriterIdGenerator, StreamIdGenerator, ) @@ -64,14 +63,12 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) ): super().__init__(database, db_conn, hs) - # `_can_write_to_account_data` indicates whether the current worker is allowed - # to write account data. A value of `True` implies that `_account_data_id_gen` - # is an `AbstractStreamIdGenerator` and not just a tracker. - self._account_data_id_gen: AbstractStreamIdTracker self._can_write_to_account_data = ( self._instance_name in hs.config.worker.writers.account_data ) + self._account_data_id_gen: AbstractStreamIdGenerator + if isinstance(database.engine, PostgresEngine): self._account_data_id_gen = MultiWriterIdGenerator( db_conn=db_conn, @@ -237,6 +234,37 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) else: return None + async def get_latest_stream_id_for_global_account_data_by_type_for_user( + self, user_id: str, data_type: str + ) -> Optional[int]: + """ + Returns: + The stream ID of the account data, + or None if there is no such account data. + """ + + def get_latest_stream_id_for_global_account_data_by_type_for_user_txn( + txn: LoggingTransaction, + ) -> Optional[int]: + sql = """ + SELECT stream_id FROM account_data + WHERE user_id = ? AND account_data_type = ? + ORDER BY stream_id DESC + LIMIT 1 + """ + txn.execute(sql, (user_id, data_type)) + + row = txn.fetchone() + if row: + return row[0] + else: + return None + + return await self.db_pool.runInteraction( + "get_latest_stream_id_for_global_account_data_by_type_for_user", + get_latest_stream_id_for_global_account_data_by_type_for_user_txn, + ) + @cached(num_args=2, tree=True) async def get_account_data_for_room( self, user_id: str, room_id: str @@ -527,7 +555,6 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) The maximum stream ID. """ assert self._can_write_to_account_data - assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator) content_json = json_encoder.encode(content) @@ -554,7 +581,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) async def remove_account_data_for_room( self, user_id: str, room_id: str, account_data_type: str - ) -> Optional[int]: + ) -> int: """Delete the room account data for the user of a given type. Args: @@ -567,7 +594,6 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) data to delete. 
""" assert self._can_write_to_account_data - assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator) def _remove_account_data_for_room_txn( txn: LoggingTransaction, next_id: int @@ -606,15 +632,13 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) next_id, ) - if not row_updated: - return None - - self._account_data_stream_cache.entity_has_changed(user_id, next_id) - self.get_room_account_data_for_user.invalidate((user_id,)) - self.get_account_data_for_room.invalidate((user_id, room_id)) - self.get_account_data_for_room_and_type.prefill( - (user_id, room_id, account_data_type), {} - ) + if row_updated: + self._account_data_stream_cache.entity_has_changed(user_id, next_id) + self.get_room_account_data_for_user.invalidate((user_id,)) + self.get_account_data_for_room.invalidate((user_id, room_id)) + self.get_account_data_for_room_and_type.prefill( + (user_id, room_id, account_data_type), {} + ) return self._account_data_id_gen.get_current_token() @@ -632,7 +656,6 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) The maximum stream ID. """ assert self._can_write_to_account_data - assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator) async with self._account_data_id_gen.get_next() as next_id: await self.db_pool.runInteraction( @@ -722,7 +745,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) self, user_id: str, account_data_type: str, - ) -> Optional[int]: + ) -> int: """ Delete a single piece of user account data by type. @@ -739,7 +762,6 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) to delete. """ assert self._can_write_to_account_data - assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator) def _remove_account_data_for_user_txn( txn: LoggingTransaction, next_id: int @@ -809,14 +831,12 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) next_id, ) - if not row_updated: - return None - - self._account_data_stream_cache.entity_has_changed(user_id, next_id) - self.get_global_account_data_for_user.invalidate((user_id,)) - self.get_global_account_data_by_type_for_user.prefill( - (user_id, account_data_type), {} - ) + if row_updated: + self._account_data_stream_cache.entity_has_changed(user_id, next_id) + self.get_global_account_data_for_user.invalidate((user_id,)) + self.get_global_account_data_by_type_for_user.prefill( + (user_id, account_data_type), {} + ) return self._account_data_id_gen.get_current_token() diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 5b66431691..096dec7f87 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py
@@ -266,9 +266,6 @@ class CacheInvalidationWorkerStore(SQLBaseStore): if relates_to: self._attempt_to_invalidate_cache("get_relations_for_event", (relates_to,)) self._attempt_to_invalidate_cache("get_references_for_event", (relates_to,)) - self._attempt_to_invalidate_cache( - "get_aggregation_groups_for_event", (relates_to,) - ) self._attempt_to_invalidate_cache("get_applicable_edit", (relates_to,)) self._attempt_to_invalidate_cache("get_thread_summary", (relates_to,)) self._attempt_to_invalidate_cache("get_thread_participated", (relates_to,)) diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 8e61aba454..0d75d9739a 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py
@@ -721,8 +721,8 @@ class DeviceInboxWorkerStore(SQLBaseStore): ], ) - for (user_id, messages_by_device) in edu["messages"].items(): - for (device_id, msg) in messages_by_device.items(): + for user_id, messages_by_device in edu["messages"].items(): + for device_id, msg in messages_by_device.items(): with start_active_span("store_outgoing_to_device_message"): set_tag(SynapseTags.TO_DEVICE_EDU_ID, edu["sender"]) set_tag(SynapseTags.TO_DEVICE_EDU_ID, edu["message_id"]) @@ -959,7 +959,6 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore): def _remove_dead_devices_from_device_inbox_txn( txn: LoggingTransaction, ) -> Tuple[int, bool]: - if "max_stream_id" in progress: max_stream_id = progress["max_stream_id"] else: diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 1ca66d57d4..5503621ad6 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py
@@ -52,7 +52,6 @@ from synapse.storage.databases.main.roommember import RoomMemberWorkerStore from synapse.storage.types import Cursor from synapse.storage.util.id_generators import ( AbstractStreamIdGenerator, - AbstractStreamIdTracker, StreamIdGenerator, ) from synapse.types import JsonDict, StrCollection, get_verify_key_from_cross_signing_key @@ -91,7 +90,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): # In the worker store this is an ID tracker which we overwrite in the non-worker # class below that is used on the main process. - self._device_list_id_gen: AbstractStreamIdTracker = StreamIdGenerator( + self._device_list_id_gen = StreamIdGenerator( db_conn, hs.get_replication_notifier(), "device_lists_stream", @@ -512,7 +511,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): results.append(("org.matrix.signing_key_update", result)) if issue_8631_logger.isEnabledFor(logging.DEBUG): - for (user_id, edu) in results: + for user_id, edu in results: issue_8631_logger.debug( "device update to %s for %s from %s to %s: %s", destination, @@ -712,9 +711,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): The new stream ID. """ - # TODO: this looks like it's _writing_. Should this be on DeviceStore rather - # than DeviceWorkerStore? - async with self._device_list_id_gen.get_next() as stream_id: # type: ignore[attr-defined] + async with self._device_list_id_gen.get_next() as stream_id: await self.db_pool.runInteraction( "add_user_sig_change_to_streams", self._add_user_signature_change_txn, @@ -1316,7 +1313,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): ) """ count = 0 - for (destination, user_id, stream_id, device_id) in rows: + for destination, user_id, stream_id, device_id in rows: txn.execute( delete_sql, (destination, user_id, stream_id, stream_id, device_id) ) diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py
index 6240f9a75e..9f8d2e4bea 100644 --- a/synapse/storage/databases/main/e2e_room_keys.py +++ b/synapse/storage/databases/main/e2e_room_keys.py
@@ -108,7 +108,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): raise StoreError(404, "No backup with that version exists") values = [] - for (room_id, session_id, room_key) in room_keys: + for room_id, session_id, room_key in room_keys: values.append( ( user_id, diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 2c2d145666..b9c39b1718 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -268,7 +268,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker ) # add each cross-signing signature to the correct device in the result dict. - for (user_id, key_id, device_id, signature) in cross_sigs_result: + for user_id, key_id, device_id, signature in cross_sigs_result: target_device_result = result[user_id][device_id] # We've only looked up cross-signatures for non-deleted devices with key # data. @@ -311,7 +311,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker # devices. user_list = [] user_device_list = [] - for (user_id, device_id) in query_list: + for user_id, device_id in query_list: if device_id is None: user_list.append(user_id) else: @@ -353,7 +353,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker txn.execute(sql, query_params) - for (user_id, device_id, display_name, key_json) in txn: + for user_id, device_id, display_name, key_json in txn: assert device_id is not None if include_deleted_devices: deleted_devices.remove((user_id, device_id)) @@ -382,7 +382,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker signature_query_clauses = [] signature_query_params = [] - for (user_id, device_id) in device_query: + for user_id, device_id in device_query: signature_query_clauses.append( "target_user_id = ? AND target_device_id = ? AND user_id = ?" ) diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index ca780cca36..ff3edeb716 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py
@@ -1612,7 +1612,6 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas latest_events: List[str], limit: int, ) -> List[str]: - seen_events = set(earliest_events) front = set(latest_events) - seen_events event_results: List[str] = [] diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 7996cbb557..a8a4ed4436 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py
@@ -469,7 +469,6 @@ class PersistEventsStore: txn: LoggingTransaction, events: List[EventBase], ) -> None: - # We only care about state events, so this if there are no state events. if not any(e.is_state() for e in events): return @@ -2025,10 +2024,6 @@ class PersistEventsStore: self.store._invalidate_cache_and_stream( txn, self.store.get_relations_for_event, (redacted_relates_to,) ) - if rel_type == RelationTypes.ANNOTATION: - self.store._invalidate_cache_and_stream( - txn, self.store.get_aggregation_groups_for_event, (redacted_relates_to,) - ) if rel_type == RelationTypes.REFERENCE: self.store._invalidate_cache_and_stream( txn, self.store.get_references_for_event, (redacted_relates_to,) diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index 584536111d..daef3685b0 100644 --- a/synapse/storage/databases/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -709,7 +709,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): nbrows = 0 last_row_event_id = "" - for (event_id, event_json_raw) in results: + for event_id, event_json_raw in results: try: event_json = db_to_json(event_json_raw) @@ -1167,7 +1167,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): results = list(txn) # (event_id, parent_id, rel_type) for each relation relations_to_insert: List[Tuple[str, str, str]] = [] - for (event_id, event_json_raw) in results: + for event_id, event_json_raw in results: try: event_json = db_to_json(event_json_raw) except Exception as e: @@ -1220,9 +1220,6 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): txn, self.get_relations_for_event, cache_tuple # type: ignore[attr-defined] ) self._invalidate_cache_and_stream( # type: ignore[attr-defined] - txn, self.get_aggregation_groups_for_event, cache_tuple # type: ignore[attr-defined] - ) - self._invalidate_cache_and_stream( # type: ignore[attr-defined] txn, self.get_thread_summary, cache_tuple # type: ignore[attr-defined] ) diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 6d0ef10258..20b7a68362 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py
@@ -72,7 +72,6 @@ from synapse.storage.engines import PostgresEngine from synapse.storage.types import Cursor from synapse.storage.util.id_generators import ( AbstractStreamIdGenerator, - AbstractStreamIdTracker, MultiWriterIdGenerator, StreamIdGenerator, ) @@ -187,8 +186,8 @@ class EventsWorkerStore(SQLBaseStore): ): super().__init__(database, db_conn, hs) - self._stream_id_gen: AbstractStreamIdTracker - self._backfill_id_gen: AbstractStreamIdTracker + self._stream_id_gen: AbstractStreamIdGenerator + self._backfill_id_gen: AbstractStreamIdGenerator if isinstance(database.engine, PostgresEngine): # If we're using Postgres then we can use `MultiWriterIdGenerator` # regardless of whether this process writes to the streams or not. @@ -1493,7 +1492,7 @@ class EventsWorkerStore(SQLBaseStore): txn.execute(redactions_sql + clause, args) - for (redacter, redacted) in txn: + for redacter, redacted in txn: d = event_dict.get(redacted) if d: d.redactions.append(redacter) diff --git a/synapse/storage/databases/main/filtering.py b/synapse/storage/databases/main/filtering.py
index 12f3b601f1..8e57c8e5a0 100644 --- a/synapse/storage/databases/main/filtering.py +++ b/synapse/storage/databases/main/filtering.py
@@ -17,7 +17,7 @@ from typing import Optional, Tuple, Union, cast from canonicaljson import encode_canonical_json -from synapse.api.errors import Codes, SynapseError +from synapse.api.errors import Codes, StoreError, SynapseError from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import LoggingTransaction from synapse.types import JsonDict @@ -46,8 +46,6 @@ class FilteringWorkerStore(SQLBaseStore): return db_to_json(def_json) - -class FilteringStore(FilteringWorkerStore): async def add_user_filter(self, user_localpart: str, user_filter: JsonDict) -> int: def_json = encode_canonical_json(user_filter) @@ -79,4 +77,23 @@ class FilteringStore(FilteringWorkerStore): return filter_id - return await self.db_pool.runInteraction("add_user_filter", _do_txn) + attempts = 0 + while True: + # Try a few times. + # This is technically needed if a user tries to create two filters at once, + # leading to two concurrent transactions. + # The failure case would be: + # - SELECT filter_id ... filter_json = ? → both transactions return no rows + # - SELECT MAX(filter_id) ... → both transactions return e.g. 5 + # - INSERT INTO ... → both transactions insert filter_id = 6 + # One of the transactions will commit. The other will get a unique key + # constraint violation error (IntegrityError). This is not the same as a + # serialisability violation, which would be automatically retried by + # `runInteraction`. + try: + return await self.db_pool.runInteraction("add_user_filter", _do_txn) + except self.db_pool.engine.module.IntegrityError: + attempts += 1 + + if attempts >= 5: + raise StoreError(500, "Couldn't generate a filter ID.") diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
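The retry loop added to `add_user_filter` in synapse/storage/databases/main/filtering.py (above) handles the race its comment describes: two concurrent calls can both compute `MAX(filter_id) + 1` and collide on insert, and `runInteraction` only auto-retries serialisation failures, not constraint violations. A rough, self-contained sketch of the same pattern — the `user_filters` schema and `insert_filter` helper are invented for illustration, using sqlite3 directly rather than Synapse's database pool:

    import sqlite3

    def insert_filter(conn: sqlite3.Connection, filter_json: str) -> int:
        """Allocate the next filter_id, retrying on unique-key collisions.

        Assumes: CREATE TABLE user_filters
            (filter_id INTEGER PRIMARY KEY, filter_json TEXT)
        """
        attempts = 0
        while True:
            try:
                with conn:  # one transaction per attempt
                    row = conn.execute(
                        "SELECT MAX(filter_id) FROM user_filters"
                    ).fetchone()
                    next_id = (row[0] or 0) + 1
                    conn.execute(
                        "INSERT INTO user_filters (filter_id, filter_json)"
                        " VALUES (?, ?)",
                        (next_id, filter_json),
                    )
                    return next_id
            except sqlite3.IntegrityError:
                # A concurrent writer claimed the same filter_id first; this
                # ordinary constraint violation is not retried for us, so
                # retry by hand, with a bound.
                attempts += 1
                if attempts >= 5:
                    raise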
index b202c5eb87..fa8be214ce 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py
@@ -196,7 +196,6 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): def get_local_media_by_user_paginate_txn( txn: LoggingTransaction, ) -> Tuple[List[Dict[str, Any]], int]: - # Set ordering order_by_column = MediaSortOrder(order_by).value diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 9b2bbe060d..9f862f00c1 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py
@@ -46,7 +46,6 @@ from synapse.storage.engines import PostgresEngine, Sqlite3Engine from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException from synapse.storage.util.id_generators import ( AbstractStreamIdGenerator, - AbstractStreamIdTracker, IdGenerator, StreamIdGenerator, ) @@ -118,7 +117,7 @@ class PushRulesWorkerStore( # In the worker store this is an ID tracker which we overwrite in the non-worker # class below that is used on the main process. - self._push_rules_stream_id_gen: AbstractStreamIdTracker = StreamIdGenerator( + self._push_rules_stream_id_gen = StreamIdGenerator( db_conn, hs.get_replication_notifier(), "push_rules_stream", diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index df53e726e6..9a24f7a655 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py
@@ -36,7 +36,6 @@ from synapse.storage.database import ( ) from synapse.storage.util.id_generators import ( AbstractStreamIdGenerator, - AbstractStreamIdTracker, StreamIdGenerator, ) from synapse.types import JsonDict @@ -60,7 +59,7 @@ class PusherWorkerStore(SQLBaseStore): # In the worker store this is an ID tracker which we overwrite in the non-worker # class below that is used on the main process. - self._pushers_id_gen: AbstractStreamIdTracker = StreamIdGenerator( + self._pushers_id_gen = StreamIdGenerator( db_conn, hs.get_replication_notifier(), "pushers", @@ -344,7 +343,6 @@ class PusherWorkerStore(SQLBaseStore): last_user = progress.get("last_user", "") def _delete_pushers(txn: LoggingTransaction) -> int: - sql = """ SELECT name FROM users WHERE deactivated = ? and name > ? @@ -392,7 +390,6 @@ class PusherWorkerStore(SQLBaseStore): last_pusher = progress.get("last_pusher", 0) def _delete_pushers(txn: LoggingTransaction) -> int: - sql = """ SELECT p.id, access_token FROM pushers AS p LEFT JOIN access_tokens AS a ON (p.access_token = a.id) @@ -449,7 +446,6 @@ class PusherWorkerStore(SQLBaseStore): last_pusher = progress.get("last_pusher", 0) def _delete_pushers(txn: LoggingTransaction) -> int: - sql = """ SELECT p.id, p.user_name, p.app_id, p.pushkey FROM pushers AS p diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index dddf49c2d5..074942b167 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py
@@ -39,7 +39,7 @@ from synapse.storage.database import ( from synapse.storage.engines import PostgresEngine from synapse.storage.engines._base import IsolationLevel from synapse.storage.util.id_generators import ( - AbstractStreamIdTracker, + AbstractStreamIdGenerator, MultiWriterIdGenerator, StreamIdGenerator, ) @@ -65,7 +65,7 @@ class ReceiptsWorkerStore(SQLBaseStore): # In the worker store this is an ID tracker which we overwrite in the non-worker # class below that is used on the main process. - self._receipts_id_gen: AbstractStreamIdTracker + self._receipts_id_gen: AbstractStreamIdGenerator if isinstance(database.engine, PostgresEngine): self._can_write_to_receipts = ( @@ -768,7 +768,7 @@ class ReceiptsWorkerStore(SQLBaseStore): "insert_receipt_conv", self._graph_to_linear, room_id, event_ids ) - async with self._receipts_id_gen.get_next() as stream_id: # type: ignore[attr-defined] + async with self._receipts_id_gen.get_next() as stream_id: event_ts = await self.db_pool.runInteraction( "insert_linearized_receipt", self._insert_linearized_receipt_txn, @@ -887,7 +887,6 @@ class ReceiptsBackgroundUpdateStore(SQLBaseStore): def _populate_receipt_event_stream_ordering_txn( txn: LoggingTransaction, ) -> bool: - if "max_stream_id" in progress: max_stream_id = progress["max_stream_id"] else: diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 9a55e17624..717237e024 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py
@@ -1002,19 +1002,6 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): desc="user_delete_threepid", ) - async def user_delete_threepids(self, user_id: str) -> None: - """Delete all threepid this user has bound - - Args: - user_id: The user id to delete all threepids of - - """ - await self.db_pool.simple_delete( - "user_threepids", - keyvalues={"user_id": user_id}, - desc="user_delete_threepids", - ) - async def add_user_bound_threepid( self, user_id: str, medium: str, address: str, id_server: str ) -> None: diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index fa3266c081..bc3a83919c 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py
@@ -398,143 +398,6 @@ class RelationsWorkerStore(SQLBaseStore): return result is not None @cached() - async def get_aggregation_groups_for_event( - self, event_id: str - ) -> Sequence[JsonDict]: - raise NotImplementedError() - - @cachedList( - cached_method_name="get_aggregation_groups_for_event", list_name="event_ids" - ) - async def get_aggregation_groups_for_events( - self, event_ids: Collection[str] - ) -> Mapping[str, Optional[List[JsonDict]]]: - """Get a list of annotations on the given events, grouped by event type and - aggregation key, sorted by count. - - This is used e.g. to get the what and how many reactions have happend - on an event. - - Args: - event_ids: Fetch events that relate to these event IDs. - - Returns: - A map of event IDs to a list of groups of annotations that match. - Each entry is a dict with `type`, `key` and `count` fields. - """ - # The number of entries to return per event ID. - limit = 5 - - clause, args = make_in_list_sql_clause( - self.database_engine, "relates_to_id", event_ids - ) - args.append(RelationTypes.ANNOTATION) - - sql = f""" - SELECT - relates_to_id, - annotation.type, - aggregation_key, - COUNT(DISTINCT annotation.sender) - FROM events AS annotation - INNER JOIN event_relations USING (event_id) - INNER JOIN events AS parent ON - parent.event_id = relates_to_id - AND parent.room_id = annotation.room_id - WHERE - {clause} - AND relation_type = ? - GROUP BY relates_to_id, annotation.type, aggregation_key - ORDER BY relates_to_id, COUNT(*) DESC - """ - - def _get_aggregation_groups_for_events_txn( - txn: LoggingTransaction, - ) -> Mapping[str, List[JsonDict]]: - txn.execute(sql, args) - - result: Dict[str, List[JsonDict]] = {} - for event_id, type, key, count in cast( - List[Tuple[str, str, str, int]], txn - ): - event_results = result.setdefault(event_id, []) - - # Limit the number of results per event ID. - if len(event_results) == limit: - continue - - event_results.append({"type": type, "key": key, "count": count}) - - return result - - return await self.db_pool.runInteraction( - "get_aggregation_groups_for_events", _get_aggregation_groups_for_events_txn - ) - - async def get_aggregation_groups_for_users( - self, event_ids: Collection[str], users: FrozenSet[str] - ) -> Dict[str, Dict[Tuple[str, str], int]]: - """Fetch the partial aggregations for an event for specific users. - - This is used, in conjunction with get_aggregation_groups_for_event, to - remove information from the results for ignored users. - - Args: - event_ids: Fetch events that relate to these event IDs. - users: The users to fetch information for. - - Returns: - A map of event ID to a map of (event type, aggregation key) to a - count of users. - """ - - if not users: - return {} - - events_sql, args = make_in_list_sql_clause( - self.database_engine, "relates_to_id", event_ids - ) - - users_sql, users_args = make_in_list_sql_clause( - self.database_engine, "annotation.sender", users - ) - args.extend(users_args) - args.append(RelationTypes.ANNOTATION) - - sql = f""" - SELECT - relates_to_id, - annotation.type, - aggregation_key, - COUNT(DISTINCT annotation.sender) - FROM events AS annotation - INNER JOIN event_relations USING (event_id) - INNER JOIN events AS parent ON - parent.event_id = relates_to_id - AND parent.room_id = annotation.room_id - WHERE {events_sql} AND {users_sql} AND relation_type = ? 
- GROUP BY relates_to_id, annotation.type, aggregation_key - ORDER BY relates_to_id, COUNT(*) DESC - """ - - def _get_aggregation_groups_for_users_txn( - txn: LoggingTransaction, - ) -> Dict[str, Dict[Tuple[str, str], int]]: - txn.execute(sql, args) - - result: Dict[str, Dict[Tuple[str, str], int]] = {} - for event_id, type, key, count in cast( - List[Tuple[str, str, str, int]], txn - ): - result.setdefault(event_id, {})[(type, key)] = count - - return result - - return await self.db_pool.runInteraction( - "get_aggregation_groups_for_users", _get_aggregation_groups_for_users_txn - ) - - @cached() async def get_references_for_event(self, event_id: str) -> List[JsonDict]: raise NotImplementedError() diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 644bbb8878..3825bd6079 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py
@@ -1417,6 +1417,204 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): get_un_partial_stated_rooms_from_stream_txn, ) + async def get_event_report(self, report_id: int) -> Optional[Dict[str, Any]]: + """Retrieve an event report + + Args: + report_id: ID of reported event in database + Returns: + JSON dict of information from an event report or None if the + report does not exist. + """ + + def _get_event_report_txn( + txn: LoggingTransaction, report_id: int + ) -> Optional[Dict[str, Any]]: + sql = """ + SELECT + er.id, + er.received_ts, + er.room_id, + er.event_id, + er.user_id, + er.content, + events.sender, + room_stats_state.canonical_alias, + room_stats_state.name, + event_json.json AS event_json + FROM event_reports AS er + LEFT JOIN events + ON events.event_id = er.event_id + JOIN event_json + ON event_json.event_id = er.event_id + JOIN room_stats_state + ON room_stats_state.room_id = er.room_id + WHERE er.id = ? + """ + + txn.execute(sql, [report_id]) + row = txn.fetchone() + + if not row: + return None + + event_report = { + "id": row[0], + "received_ts": row[1], + "room_id": row[2], + "event_id": row[3], + "user_id": row[4], + "score": db_to_json(row[5]).get("score"), + "reason": db_to_json(row[5]).get("reason"), + "sender": row[6], + "canonical_alias": row[7], + "name": row[8], + "event_json": db_to_json(row[9]), + } + + return event_report + + return await self.db_pool.runInteraction( + "get_event_report", _get_event_report_txn, report_id + ) + + async def get_event_reports_paginate( + self, + start: int, + limit: int, + direction: Direction = Direction.BACKWARDS, + user_id: Optional[str] = None, + room_id: Optional[str] = None, + ) -> Tuple[List[Dict[str, Any]], int]: + """Retrieve a paginated list of event reports + + Args: + start: event offset to begin the query from + limit: number of rows to retrieve + direction: Whether to fetch the most recent first (backwards) or the + oldest first (forwards) + user_id: search for user_id. Ignored if user_id is None + room_id: search for room_id. Ignored if room_id is None + Returns: + Tuple of: + json list of event reports + total number of event reports matching the filter criteria + """ + + def _get_event_reports_paginate_txn( + txn: LoggingTransaction, + ) -> Tuple[List[Dict[str, Any]], int]: + filters = [] + args: List[object] = [] + + if user_id: + filters.append("er.user_id LIKE ?") + args.extend(["%" + user_id + "%"]) + if room_id: + filters.append("er.room_id LIKE ?") + args.extend(["%" + room_id + "%"]) + + if direction == Direction.BACKWARDS: + order = "DESC" + else: + order = "ASC" + + where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else "" + + # We join on room_stats_state despite not using any columns from it + # because the join can influence the number of rows returned; + # e.g. a room that doesn't have state, maybe because it was deleted. + # The query returning the total count should be consistent with + # the query returning the results. 
+ sql = """ + SELECT COUNT(*) as total_event_reports + FROM event_reports AS er + JOIN room_stats_state ON room_stats_state.room_id = er.room_id + {} + """.format( + where_clause + ) + txn.execute(sql, args) + count = cast(Tuple[int], txn.fetchone())[0] + + sql = """ + SELECT + er.id, + er.received_ts, + er.room_id, + er.event_id, + er.user_id, + er.content, + events.sender, + room_stats_state.canonical_alias, + room_stats_state.name + FROM event_reports AS er + LEFT JOIN events + ON events.event_id = er.event_id + JOIN room_stats_state + ON room_stats_state.room_id = er.room_id + {where_clause} + ORDER BY er.received_ts {order} + LIMIT ? + OFFSET ? + """.format( + where_clause=where_clause, + order=order, + ) + + args += [limit, start] + txn.execute(sql, args) + + event_reports = [] + for row in txn: + try: + s = db_to_json(row[5]).get("score") + r = db_to_json(row[5]).get("reason") + except Exception: + logger.error("Unable to parse json from event_reports: %s", row[0]) + continue + event_reports.append( + { + "id": row[0], + "received_ts": row[1], + "room_id": row[2], + "event_id": row[3], + "user_id": row[4], + "score": s, + "reason": r, + "sender": row[6], + "canonical_alias": row[7], + "name": row[8], + } + ) + + return event_reports, count + + return await self.db_pool.runInteraction( + "get_event_reports_paginate", _get_event_reports_paginate_txn + ) + + async def delete_event_report(self, report_id: int) -> bool: + """Remove an event report from database. + + Args: + report_id: Report to delete + + Returns: + Whether the report was successfully deleted or not. + """ + try: + await self.db_pool.simple_delete_one( + table="event_reports", + keyvalues={"id": report_id}, + desc="delete_event_report", + ) + except StoreError: + # Deletion failed because report does not exist + return False + + return True + class _BackgroundUpdates: REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory" @@ -2139,7 +2337,19 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): reason: Optional[str], content: JsonDict, received_ts: int, - ) -> None: + ) -> int: + """Add an event report + + Args: + room_id: Room that contains the reported event. + event_id: The reported event. + user_id: User who reports the event. + reason: Description that the user specifies. + content: Report request body (score and reason). + received_ts: Time when the user submitted the report (milliseconds). + Returns: + Id of the event report. + """ next_id = self._event_reports_id_gen.get_next() await self.db_pool.simple_insert( table="event_reports", @@ -2154,184 +2364,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): }, desc="add_event_report", ) - - async def get_event_report(self, report_id: int) -> Optional[Dict[str, Any]]: - """Retrieve an event report - - Args: - report_id: ID of reported event in database - Returns: - JSON dict of information from an event report or None if the - report does not exist. - """ - - def _get_event_report_txn( - txn: LoggingTransaction, report_id: int - ) -> Optional[Dict[str, Any]]: - - sql = """ - SELECT - er.id, - er.received_ts, - er.room_id, - er.event_id, - er.user_id, - er.content, - events.sender, - room_stats_state.canonical_alias, - room_stats_state.name, - event_json.json AS event_json - FROM event_reports AS er - LEFT JOIN events - ON events.event_id = er.event_id - JOIN event_json - ON event_json.event_id = er.event_id - JOIN room_stats_state - ON room_stats_state.room_id = er.room_id - WHERE er.id = ? 
- """ - - txn.execute(sql, [report_id]) - row = txn.fetchone() - - if not row: - return None - - event_report = { - "id": row[0], - "received_ts": row[1], - "room_id": row[2], - "event_id": row[3], - "user_id": row[4], - "score": db_to_json(row[5]).get("score"), - "reason": db_to_json(row[5]).get("reason"), - "sender": row[6], - "canonical_alias": row[7], - "name": row[8], - "event_json": db_to_json(row[9]), - } - - return event_report - - return await self.db_pool.runInteraction( - "get_event_report", _get_event_report_txn, report_id - ) - - async def get_event_reports_paginate( - self, - start: int, - limit: int, - direction: Direction = Direction.BACKWARDS, - user_id: Optional[str] = None, - room_id: Optional[str] = None, - ) -> Tuple[List[Dict[str, Any]], int]: - """Retrieve a paginated list of event reports - - Args: - start: event offset to begin the query from - limit: number of rows to retrieve - direction: Whether to fetch the most recent first (backwards) or the - oldest first (forwards) - user_id: search for user_id. Ignored if user_id is None - room_id: search for room_id. Ignored if room_id is None - Returns: - Tuple of: - json list of event reports - total number of event reports matching the filter criteria - """ - - def _get_event_reports_paginate_txn( - txn: LoggingTransaction, - ) -> Tuple[List[Dict[str, Any]], int]: - filters = [] - args: List[object] = [] - - if user_id: - filters.append("er.user_id LIKE ?") - args.extend(["%" + user_id + "%"]) - if room_id: - filters.append("er.room_id LIKE ?") - args.extend(["%" + room_id + "%"]) - - if direction == Direction.BACKWARDS: - order = "DESC" - else: - order = "ASC" - - where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else "" - - # We join on room_stats_state despite not using any columns from it - # because the join can influence the number of rows returned; - # e.g. a room that doesn't have state, maybe because it was deleted. - # The query returning the total count should be consistent with - # the query returning the results. - sql = """ - SELECT COUNT(*) as total_event_reports - FROM event_reports AS er - JOIN room_stats_state ON room_stats_state.room_id = er.room_id - {} - """.format( - where_clause - ) - txn.execute(sql, args) - count = cast(Tuple[int], txn.fetchone())[0] - - sql = """ - SELECT - er.id, - er.received_ts, - er.room_id, - er.event_id, - er.user_id, - er.content, - events.sender, - room_stats_state.canonical_alias, - room_stats_state.name - FROM event_reports AS er - LEFT JOIN events - ON events.event_id = er.event_id - JOIN room_stats_state - ON room_stats_state.room_id = er.room_id - {where_clause} - ORDER BY er.received_ts {order} - LIMIT ? - OFFSET ? - """.format( - where_clause=where_clause, - order=order, - ) - - args += [limit, start] - txn.execute(sql, args) - - event_reports = [] - for row in txn: - try: - s = db_to_json(row[5]).get("score") - r = db_to_json(row[5]).get("reason") - except Exception: - logger.error("Unable to parse json from event_reports: %s", row[0]) - continue - event_reports.append( - { - "id": row[0], - "received_ts": row[1], - "room_id": row[2], - "event_id": row[3], - "user_id": row[4], - "score": s, - "reason": r, - "sender": row[6], - "canonical_alias": row[7], - "name": row[8], - } - ) - - return event_reports, count - - return await self.db_pool.runInteraction( - "get_event_reports_paginate", _get_event_reports_paginate_txn - ) + return next_id async def block_room(self, room_id: str, user_id: str) -> None: """Marks the room as blocked. 
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index 3fe433f66c..a7aae661d8 100644 --- a/synapse/storage/databases/main/search.py +++ b/synapse/storage/databases/main/search.py
@@ -122,7 +122,6 @@ class SearchWorkerStore(SQLBaseStore): class SearchBackgroundUpdateStore(SearchWorkerStore): - EVENT_SEARCH_UPDATE_NAME = "event_search" EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order" EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin" @@ -615,7 +614,6 @@ class SearchStore(SearchBackgroundUpdateStore): """ count_args = [search_query] + count_args elif isinstance(self.database_engine, Sqlite3Engine): - # We use CROSS JOIN here to ensure we use the right indexes. # https://sqlite.org/optoverview.html#crossjoin # diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index ba325d390b..ebb2ae964f 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py
@@ -490,7 +490,6 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): class MainStateBackgroundUpdateStore(RoomMemberWorkerStore): - CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx" EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index" DELETE_CURRENT_STATE_UPDATE_NAME = "delete_old_current_state_events" diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index d7b7d0c3c9..d3393d8e49 100644 --- a/synapse/storage/databases/main/stats.py +++ b/synapse/storage/databases/main/stats.py
@@ -461,7 +461,7 @@ class StatsStore(StateDeltasStore): insert_cols = [] qargs = [] - for (key, val) in chain( + for key, val in chain( keyvalues.items(), absolutes.items(), additive_relatives.items() ): insert_cols.append(key) diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 818c46182e..ac5fbf6b86 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py
@@ -87,6 +87,7 @@ MAX_STREAM_SIZE = 1000 _STREAM_TOKEN = "stream" _TOPOLOGICAL_TOKEN = "topological" + # Used as return values for pagination APIs @attr.s(slots=True, frozen=True, auto_attribs=True) class _EventDictReturn: diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index 6b33d809b6..6d72bd9f67 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py
@@ -573,7 +573,6 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): def get_destination_rooms_paginate_txn( txn: LoggingTransaction, ) -> Tuple[List[JsonDict], int]: - if direction == Direction.BACKWARDS: order = "DESC" else: diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index f6a6fd4079..f16a509ac4 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py
@@ -14,6 +14,7 @@ import logging import re +import unicodedata from typing import ( TYPE_CHECKING, Iterable, @@ -98,7 +99,6 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): async def _populate_user_directory_createtables( self, progress: JsonDict, batch_size: int ) -> int: - # Get all the rooms that we want to process. def _make_staging_area(txn: LoggingTransaction) -> None: sql = ( @@ -491,6 +491,11 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): values={"display_name": display_name, "avatar_url": avatar_url}, ) + # The display name that goes into the database index. + index_display_name = display_name + if index_display_name is not None: + index_display_name = _filter_text_for_index(index_display_name) + if isinstance(self.database_engine, PostgresEngine): # We weight the localpart most highly, then display name and finally # server name @@ -508,11 +513,15 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): user_id, get_localpart_from_id(user_id), get_domain_from_id(user_id), - display_name, + index_display_name, ), ) elif isinstance(self.database_engine, Sqlite3Engine): - value = "%s %s" % (user_id, display_name) if display_name else user_id + value = ( + "%s %s" % (user_id, index_display_name) + if index_display_name + else user_id + ) self.db_pool.simple_upsert_txn( txn, table="user_directory_search", @@ -897,6 +906,41 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): return {"limited": limited, "results": results[0:limit]} +def _filter_text_for_index(text: str) -> str: + """Transforms text before it is inserted into the user directory index, or searched + for in the user directory index. + + Note that the user directory search table needs to be rebuilt whenever this function + changes. + """ + # Lowercase the text, to make searches case-insensitive. + # This is necessary for both PostgreSQL and SQLite. PostgreSQL's + # `to_tsquery/to_tsvector` functions don't lowercase non-ASCII characters when using + # the "C" collation, while SQLite just doesn't lowercase non-ASCII characters at + # all. + text = text.lower() + + # Normalize the text. NFKC normalization has two effects: + # 1. It canonicalizes the text, i.e. maps all visually identical strings to the same + # string. For example, ["e", "◌́"] is mapped to ["é"]. + # 2. It maps strings that are roughly equivalent to the same string. + # For example, ["dž"] is mapped to ["d", "ž"], ["①"] to ["1"] and ["i⁹"] to + # ["i", "9"]. + text = unicodedata.normalize("NFKC", text) + + # Note that nothing is done to make searches accent-insensitive. + # That could be achieved by converting to NFKD form instead (with combining accents + # split out) and filtering out combining accents using `unicodedata.combining(c)`. + # The downside of this may be noisier search results, since search terms with + # explicit accents will match characters with no accents, or completely different + # accents. + # + # text = unicodedata.normalize("NFKD", text) + # text = "".join([c for c in text if not unicodedata.combining(c)]) + + return text + + def _parse_query_sqlite(search_term: str) -> str: """Takes a plain unicode string from the user and converts it into a form that can be passed to the database. @@ -906,6 +950,7 @@ def _parse_query_sqlite(search_term: str) -> str: We specifically add both a prefix and non-prefix matching term so that exact matches get ranked higher. """ + search_term = _filter_text_for_index(search_term) # Pull out the individual words, discarding any non-word characters. 
results = _parse_words(search_term) @@ -918,11 +963,21 @@ def _parse_query_postgres(search_term: str) -> Tuple[str, str, str]: We use this so that we can add prefix matching, which isn't something that is supported by default. """ - results = _parse_words(search_term) + search_term = _filter_text_for_index(search_term) + + escaped_words = [] + for word in _parse_words(search_term): + # Postgres tsvector and tsquery quoting rules: + # words potentially containing punctuation should be quoted + # and then existing quotes and backslashes should be doubled + # See: https://www.postgresql.org/docs/current/datatype-textsearch.html#DATATYPE-TSQUERY - both = " & ".join("(%s:* | %s)" % (result, result) for result in results) - exact = " & ".join("%s" % (result,) for result in results) - prefix = " & ".join("%s:*" % (result,) for result in results) + quoted_word = word.replace("'", "''").replace("\\", "\\\\") + escaped_words.append(f"'{quoted_word}'") + + both = " & ".join("(%s:* | %s)" % (word, word) for word in escaped_words) + exact = " & ".join("%s" % (word,) for word in escaped_words) + prefix = " & ".join("%s:*" % (word,) for word in escaped_words) return both, exact, prefix @@ -944,6 +999,14 @@ def _parse_words(search_term: str) -> List[str]: if USE_ICU: return _parse_words_with_icu(search_term) + return _parse_words_with_regex(search_term) + + +def _parse_words_with_regex(search_term: str) -> List[str]: + """ + Break down search term into words, when we don't have ICU available. + See: `_parse_words` + """ return re.findall(r"([\w\-]+)", search_term, re.UNICODE) diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py
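To make the effect of `_filter_text_for_index` (synapse/storage/databases/main/user_directory.py, above) concrete, here is a small reproduction of its two steps — lowercase, then NFKC — using the examples from the code comments; the assertions are illustrative, not part of the patch. The Postgres query path additionally quotes each parsed word for `to_tsquery`, doubling any embedded quotes and backslashes, as the hunk above shows.

    import unicodedata

    def filter_text_for_index(text: str) -> str:
        # Mirrors the patch: case-fold first, then NFKC-normalise.
        return unicodedata.normalize("NFKC", text.lower())

    # Canonical equivalence: "e" + combining acute accent composes to "é".
    assert filter_text_for_index("e\u0301") == "\u00e9"
    # Compatibility mappings: "①" becomes "1", the "ﬁ" ligature becomes "fi".
    assert filter_text_for_index("\u2460\ufb01") == "1fi"
    # Accents are deliberately preserved: "É" lowercases to "é", not "e".
    assert filter_text_for_index("\u00c9") == "\u00e9"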
index d743282f13..097dea5182 100644 --- a/synapse/storage/databases/state/bg_updates.py +++ b/synapse/storage/databases/state/bg_updates.py
@@ -251,7 +251,6 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore): class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore): - STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication" STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index" STATE_GROUPS_ROOM_INDEX_UPDATE_NAME = "state_groups_room_id_idx" diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index 1a7232b276..bf4cdfdf29 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py
@@ -18,6 +18,8 @@ from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Se import attr from synapse.api.constants import EventTypes +from synapse.events import EventBase +from synapse.events.snapshot import UnpersistedEventContext, UnpersistedEventContextBase from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( DatabasePool, @@ -257,14 +259,11 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): member_filter, non_member_filter = state_filter.get_member_split() # Now we look them up in the member and non-member caches - ( - non_member_state, - incomplete_groups_nm, - ) = self._get_state_for_groups_using_cache( + non_member_state, incomplete_groups_nm = self._get_state_for_groups_using_cache( groups, self._state_group_cache, state_filter=non_member_filter ) - (member_state, incomplete_groups_m,) = self._get_state_for_groups_using_cache( + member_state, incomplete_groups_m = self._get_state_for_groups_using_cache( groups, self._state_group_members_cache, state_filter=member_filter ) @@ -404,6 +403,123 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): fetched_keys=non_member_types, ) + async def store_state_deltas_for_batched( + self, + events_and_context: List[Tuple[EventBase, UnpersistedEventContextBase]], + room_id: str, + prev_group: int, + ) -> List[Tuple[EventBase, UnpersistedEventContext]]: + """Generate and store state deltas for a group of events and contexts created to be + batch persisted. Note that all the events must be in a linear chain (i.e. a <- b <- c). + + Args: + events_and_context: the events to generate and store state groups for + and their associated contexts + room_id: the id of the room the events were created for + prev_group: the state group of the last event persisted before the batched events + were created + """ + + def insert_deltas_group_txn( + txn: LoggingTransaction, + events_and_context: List[Tuple[EventBase, UnpersistedEventContext]], + prev_group: int, + ) -> List[Tuple[EventBase, UnpersistedEventContext]]: + """Generate and store state groups for the provided events and contexts. + + Requires that we have the state as a delta from the last persisted state group. 
+ + Returns: + The events and contexts, with the newly allocated state groups + assigned to their contexts + """ + is_in_db = self.db_pool.simple_select_one_onecol_txn( + txn, + table="state_groups", + keyvalues={"id": prev_group}, + retcol="id", + allow_none=True, + ) + if not is_in_db: + raise Exception( + "Trying to persist state with unpersisted prev_group: %r" + % (prev_group,) + ) + + num_state_groups = sum( + 1 for event, _ in events_and_context if event.is_state() + ) + + state_groups = self._state_group_seq_gen.get_next_mult_txn( + txn, num_state_groups + ) + + sg_before = prev_group + state_group_iter = iter(state_groups) + for event, context in events_and_context: + if not event.is_state(): + context.state_group_after_event = sg_before + context.state_group_before_event = sg_before + continue + + sg_after = next(state_group_iter) + context.state_group_after_event = sg_after + context.state_group_before_event = sg_before + context.state_delta_due_to_event = { + (event.type, event.state_key): event.event_id + } + sg_before = sg_after + + self.db_pool.simple_insert_many_txn( + txn, + table="state_groups", + keys=("id", "room_id", "event_id"), + values=[ + (context.state_group_after_event, room_id, event.event_id) + for event, context in events_and_context + if event.is_state() + ], + ) + + self.db_pool.simple_insert_many_txn( + txn, + table="state_group_edges", + keys=("state_group", "prev_state_group"), + values=[ + ( + context.state_group_after_event, + context.state_group_before_event, + ) + for event, context in events_and_context + if event.is_state() + ], + ) + + self.db_pool.simple_insert_many_txn( + txn, + table="state_groups_state", + keys=("state_group", "room_id", "type", "state_key", "event_id"), + values=[ + ( + context.state_group_after_event, + room_id, + key[0], + key[1], + state_id, + ) + for event, context in events_and_context + if context.state_delta_due_to_event is not None + for key, state_id in context.state_delta_due_to_event.items() + ], + ) + return events_and_context + + return await self.db_pool.runInteraction( + "store_state_deltas_for_batched.insert_deltas_group", + insert_deltas_group_txn, + events_and_context, + prev_group, + ) + async def store_state_group( self, event_id: str, diff --git a/synapse/storage/engines/__init__.py b/synapse/storage/engines/__init__.py
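The chaining in `insert_deltas_group_txn` (synapse/storage/databases/state/store.py, above) can be summarised without a database: each state event consumes one freshly allocated state group whose edge points at the group before it, while non-state events keep the current group. A toy model of that invariant — all names here are invented for illustration:

    from typing import Iterable, List, Tuple

    def chain_state_groups(
        events: Iterable[Tuple[str, bool]],  # (event_id, is_state)
        prev_group: int,
        fresh_group_ids: Iterable[int],
    ) -> List[Tuple[str, int, int]]:
        """Return (event_id, group_before, group_after) for each event."""
        ids = iter(fresh_group_ids)
        sg_before = prev_group
        out = []
        for event_id, is_state in events:
            if not is_state:
                # Non-state events leave the room state untouched.
                out.append((event_id, sg_before, sg_before))
                continue
            sg_after = next(ids)  # one new state group per state event
            out.append((event_id, sg_before, sg_after))
            sg_before = sg_after  # the chain advances
        return out

    # Linear chain a <- b <- c where only a and c are state events:
    assert chain_state_groups(
        [("a", True), ("b", False), ("c", True)],
        prev_group=7,
        fresh_group_ids=[8, 9],
    ) == [("a", 7, 8), ("b", 8, 8), ("c", 8, 9)]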
index a182e8a098..d1ccb7390a 100644 --- a/synapse/storage/engines/__init__.py +++ b/synapse/storage/engines/__init__.py
@@ -25,7 +25,7 @@ try: except ImportError: class PostgresEngine(BaseDatabaseEngine): # type: ignore[no-redef] - def __new__(cls, *args: object, **kwargs: object) -> NoReturn: # type: ignore[misc] + def __new__(cls, *args: object, **kwargs: object) -> NoReturn: raise RuntimeError( f"Cannot create {cls.__name__} -- psycopg2 module is not installed" ) @@ -36,7 +36,7 @@ try: except ImportError: class Sqlite3Engine(BaseDatabaseEngine): # type: ignore[no-redef] - def __new__(cls, *args: object, **kwargs: object) -> NoReturn: # type: ignore[misc] + def __new__(cls, *args: object, **kwargs: object) -> NoReturn: raise RuntimeError( f"Cannot create {cls.__name__} -- sqlite3 module is not installed" ) diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
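The `__new__` stubs in synapse/storage/engines/__init__.py (above) exist so the module can always be imported: when a driver is missing, the engine class is still defined but fails loudly at instantiation rather than at import time. A sketch of the pattern in isolation, with the real engine body elided:

    from typing import NoReturn

    try:
        import psycopg2  # noqa: F401

        class PostgresEngine:
            """Stand-in for the real engine; details elided."""

    except ImportError:

        class PostgresEngine:  # type: ignore[no-redef]
            def __new__(cls, *args: object, **kwargs: object) -> NoReturn:
                raise RuntimeError(
                    f"Cannot create {cls.__name__} -- psycopg2 module is not installed"
                )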
index 6c335a9315..2a1c6fa31b 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py
@@ -563,7 +563,7 @@ def _apply_module_schemas( """ # This is the old way for password_auth_provider modules to make changes # to the database. This should instead be done using the module API - for (mod, _config) in config.authproviders.password_providers: + for mod, _config in config.authproviders.password_providers: if not hasattr(mod, "get_db_schema_files"): continue modname = ".".join((mod.__module__, mod.__name__)) @@ -591,7 +591,7 @@ def _apply_module_schema_files( (modname,), ) applied_deltas = {d for d, in cur} - for (name, stream) in names_and_streams: + for name, stream in names_and_streams: if name in applied_deltas: continue diff --git a/synapse/storage/types.py b/synapse/storage/types.py
index 0031df1e06..56a0048539 100644 --- a/synapse/storage/types.py +++ b/synapse/storage/types.py
@@ -12,7 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. from types import TracebackType -from typing import Any, Iterator, List, Mapping, Optional, Sequence, Tuple, Type, Union +from typing import ( + Any, + Callable, + Iterator, + List, + Mapping, + Optional, + Sequence, + Tuple, + Type, + Union, +) from typing_extensions import Protocol @@ -112,15 +123,35 @@ class DBAPI2Module(Protocol): # extends from this hierarchy. See # https://docs.python.org/3/library/sqlite3.html?highlight=sqlite3#exceptions # https://www.postgresql.org/docs/current/errcodes-appendix.html#ERRCODES-TABLE - Warning: Type[Exception] - Error: Type[Exception] + # + # Note: rather than + # x: T + # we write + # @property + # def x(self) -> T: ... + # which expresses that the protocol attribute `x` is read-only. The mypy docs + # https://mypy.readthedocs.io/en/latest/common_issues.html#covariant-subtyping-of-mutable-protocol-members-is-rejected + # explain why this is necessary for safety. TL;DR: we shouldn't be able to write + # to `x`, only read from it. See also https://github.com/python/mypy/issues/6002 . + @property + def Warning(self) -> Type[Exception]: + ... + + @property + def Error(self) -> Type[Exception]: + ... # Errors are divided into `InterfaceError`s (something went wrong in the database # driver) and `DatabaseError`s (something went wrong in the database). These are # both subclasses of `Error`, but we can't currently express this in type # annotations due to https://github.com/python/mypy/issues/8397 - InterfaceError: Type[Exception] - DatabaseError: Type[Exception] + @property + def InterfaceError(self) -> Type[Exception]: + ... + + @property + def DatabaseError(self) -> Type[Exception]: + ... # Everything below is a subclass of `DatabaseError`. @@ -128,7 +159,9 @@ class DBAPI2Module(Protocol): # - An integer was too big for its data type. # - An invalid date time was provided. # - A string contained a null code point. - DataError: Type[Exception] + @property + def DataError(self) -> Type[Exception]: + ... # Roughly: something went wrong in the database, but it's not within the application # programmer's control. Examples: @@ -138,28 +171,45 @@ class DBAPI2Module(Protocol): # - A serialisation failure occurred. # - The database ran out of resources, such as storage, memory, connections, etc. # - The database encountered an error from the operating system. - OperationalError: Type[Exception] + @property + def OperationalError(self) -> Type[Exception]: + ... # Roughly: we've given the database data which breaks a rule we asked it to enforce. # Examples: # - Stop, criminal scum! You violated the foreign key constraint # - Also check constraints, non-null constraints, etc. - IntegrityError: Type[Exception] + @property + def IntegrityError(self) -> Type[Exception]: + ... # Roughly: something went wrong within the database server itself. - InternalError: Type[Exception] + @property + def InternalError(self) -> Type[Exception]: + ... # Roughly: the application did something silly that needs to be fixed. Examples: # - We don't have permissions to do something. # - We tried to create a table with duplicate column names. # - We tried to use a reserved name. # - We referred to a column that doesn't exist. - ProgrammingError: Type[Exception] + @property + def ProgrammingError(self) -> Type[Exception]: + ... # Roughly: we've tried to do something that this database doesn't support. 
- NotSupportedError: Type[Exception] + @property + def NotSupportedError(self) -> Type[Exception]: + ... - def connect(self, **parameters: object) -> Connection: + # We originally wrote + # def connect(self, *args, **kwargs) -> Connection: ... + # But mypy doesn't seem to like that because sqlite3.connect takes a mandatory + # positional argument. We can't make that part of the signature though, because + # psycopg2.connect doesn't have a mandatory positional argument. Instead, we use + # the following slightly unusual workaround. + @property + def connect(self) -> Callable[..., Connection]: ... diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
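The switch to property-based members in synapse/storage/types.py (above) works around mypy's rule that mutable protocol attributes are invariant. A minimal reproduction of the difference — class names invented; the rejection only appears under mypy, not at runtime:

    from typing import Protocol, Type

    class MutableModule(Protocol):
        Error: Type[Exception]  # writable attribute: must match invariantly

    class ReadOnlyModule(Protocol):
        @property
        def Error(self) -> Type[Exception]:  # read-only: covariance is safe
            ...

    class FakeDriver:
        Error = ValueError  # narrower than Type[Exception]

    ok: ReadOnlyModule = FakeDriver()  # accepted: Error can only be read
    # bad: MutableModule = FakeDriver()  # mypy rejects this: a holder of
    #                                    # `bad` could assign bad.Error =
    #                                    # Exception, violating FakeDriver's
    #                                    # narrower Type[ValueError].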
index 9adff3f4f5..d2c874b9a8 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py
@@ -93,8 +93,11 @@ def _load_current_id( return res -class AbstractStreamIdTracker(metaclass=abc.ABCMeta): - """Tracks the "current" stream ID of a stream that may have multiple writers. +class AbstractStreamIdGenerator(metaclass=abc.ABCMeta): + """Generates or tracks stream IDs for a stream that may have multiple writers. + + Each stream ID represents a write transaction, whose completion is tracked + so that the "current" stream ID of the stream can be determined. Stream IDs are monotonically increasing or decreasing integers representing write transactions. The "current" stream ID is the stream ID such that all transactions @@ -130,16 +133,6 @@ class AbstractStreamIdTracker(metaclass=abc.ABCMeta): """ raise NotImplementedError() - -class AbstractStreamIdGenerator(AbstractStreamIdTracker): - """Generates stream IDs for a stream that may have multiple writers. - - Each stream ID represents a write transaction, whose completion is tracked - so that the "current" stream ID of the stream can be determined. - - See `AbstractStreamIdTracker` for more details. - """ - @abc.abstractmethod def get_next(self) -> AsyncContextManager[int]: """ @@ -158,6 +151,15 @@ class AbstractStreamIdGenerator(AbstractStreamIdTracker): """ raise NotImplementedError() + @abc.abstractmethod + def get_next_txn(self, txn: LoggingTransaction) -> int: + """ + Usage: + stream_id_gen.get_next_txn(txn) + # ... persist events ... + """ + raise NotImplementedError() + class StreamIdGenerator(AbstractStreamIdGenerator): """Generates and tracks stream IDs for a stream with a single writer. @@ -263,6 +265,40 @@ class StreamIdGenerator(AbstractStreamIdGenerator): return _AsyncCtxManagerWrapper(manager()) + def get_next_txn(self, txn: LoggingTransaction) -> int: + """ + Retrieve the next stream ID from within a database transaction. + + Clean-up functions will be called when the transaction finishes. + + Args: + txn: The database transaction object. + + Returns: + The next stream ID. + """ + if not self._is_writer: + raise Exception("Tried to allocate stream ID on non-writer") + + # Get the next stream ID. + with self._lock: + self._current += self._step + next_id = self._current + + self._unfinished_ids[next_id] = next_id + + def clear_unfinished_id(id_to_clear: int) -> None: + """A function to mark processing this ID as finished""" + with self._lock: + self._unfinished_ids.pop(id_to_clear) + + # Mark this ID as finished once the database transaction itself finishes. + txn.call_after(clear_unfinished_id, next_id) + txn.call_on_exception(clear_unfinished_id, next_id) + + # Return the new ID. + return next_id + def get_current_token(self) -> int: if not self._is_writer: return self._current @@ -568,7 +604,7 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator): """ Usage: - stream_id = stream_id_gen.get_next(txn) + stream_id = stream_id_gen.get_next_txn(txn) # ... persist event ... """ diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py
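The new `StreamIdGenerator.get_next_txn` in synapse/storage/util/id_generators.py (above) lets a caller allocate a stream ID from within an existing database transaction; the ID is only marked finished via `txn.call_after`/`txn.call_on_exception`, so `get_current_token()` cannot advance past an in-flight write. A hypothetical caller — the `things` table and the generator instance are invented for the sketch:

    from synapse.storage.database import LoggingTransaction
    from synapse.storage.util.id_generators import StreamIdGenerator

    def insert_thing_txn(txn: LoggingTransaction, id_gen: StreamIdGenerator) -> int:
        # Allocate the ID and write the row in the *same* transaction,
        # instead of wrapping the write in the async get_next() manager.
        stream_id = id_gen.get_next_txn(txn)
        txn.execute(
            "INSERT INTO things (stream_id, data) VALUES (?, ?)",
            (stream_id, "..."),
        )
        # Once the transaction commits or fails, the ID leaves the
        # unfinished set and the current token may advance over it.
        return stream_id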
index 75268cbe15..80915216de 100644 --- a/synapse/storage/util/sequence.py +++ b/synapse/storage/util/sequence.py
@@ -205,7 +205,7 @@ class LocalSequenceGenerator(SequenceGenerator): """ Args: get_first_callback: a callback which is called on the first call to - get_next_id_txn; should return the curreent maximum id + get_next_id_txn; should return the current maximum id """ # the callback. this is cleared after it is called, so that it can be GCed. self._callback: Optional[GetFirstCallbackType] = get_first_callback diff --git a/synapse/streams/__init__.py b/synapse/streams/__init__.py
index c6c8a0315c..8a48ffc48d 100644 --- a/synapse/streams/__init__.py +++ b/synapse/streams/__init__.py
@@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +from abc import ABC, abstractmethod from typing import Generic, List, Optional, Tuple, TypeVar from synapse.types import StrCollection, UserID @@ -22,7 +22,8 @@ K = TypeVar("K") R = TypeVar("R") -class EventSource(Generic[K, R]): +class EventSource(ABC, Generic[K, R]): + @abstractmethod async def get_new_events( self, user: UserID, @@ -32,4 +33,4 @@ class EventSource(Generic[K, R]): is_guest: bool, explicit_room_id: Optional[str] = None, ) -> Tuple[List[R], K]: - ... + raise NotImplementedError() diff --git a/synapse/types/state.py b/synapse/types/state.py
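Making `EventSource` abstract in synapse/streams/__init__.py (above) turns a silent bug into a loud one: the old `...` body meant a subclass that forgot `get_new_events` would quietly return None, whereas the ABC now refuses to instantiate it. A trimmed sketch — the subclass and the simplified signature are illustrative:

    from abc import ABC, abstractmethod
    from typing import Generic, List, Tuple, TypeVar

    K = TypeVar("K")
    R = TypeVar("R")

    class EventSource(ABC, Generic[K, R]):
        @abstractmethod
        async def get_new_events(self, from_key: K, limit: int) -> Tuple[List[R], K]:
            raise NotImplementedError()

    class ForgetfulSource(EventSource[int, str]):
        pass  # does not implement get_new_events

    try:
        ForgetfulSource()  # type: ignore[abstract]
    except TypeError as e:
        print(e)  # Can't instantiate abstract class ForgetfulSource ...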
index 743a4f9217..4b3071acce 100644 --- a/synapse/types/state.py +++ b/synapse/types/state.py
@@ -120,7 +120,7 @@ class StateFilter: def to_types(self) -> Iterable[Tuple[str, Optional[str]]]: """The inverse to `from_types`.""" - for (event_type, state_keys) in self.types.items(): + for event_type, state_keys in self.types.items(): if state_keys is None: yield event_type, None else: diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py
index 9387632d0d..6ffa56217e 100644 --- a/synapse/util/caches/__init__.py +++ b/synapse/util/caches/__init__.py
@@ -98,7 +98,6 @@ class EvictionReason(Enum): @attr.s(slots=True, auto_attribs=True) class CacheMetric: - _cache: Sized _cache_type: str _cache_name: str diff --git a/synapse/util/check_dependencies.py b/synapse/util/check_dependencies.py
index 3b1e205700..1c0fde4966 100644 --- a/synapse/util/check_dependencies.py +++ b/synapse/util/check_dependencies.py
@@ -183,7 +183,7 @@ def check_requirements(extra: Optional[str] = None) -> None: deps_unfulfilled = [] errors = [] - for (requirement, must_be_installed) in dependencies: + for requirement, must_be_installed in dependencies: try: dist: metadata.Distribution = metadata.distribution(requirement.name) except metadata.PackageNotFoundError: diff --git a/synapse/util/patch_inline_callbacks.py b/synapse/util/patch_inline_callbacks.py
index f97f98a057..d00d34e652 100644 --- a/synapse/util/patch_inline_callbacks.py +++ b/synapse/util/patch_inline_callbacks.py
@@ -211,7 +211,6 @@ def _check_yield_points( result = Failure() if current_context() != expected_context: - # This happens because the context is lost sometime *after* the # previous yield and *after* the current yield. E.g. the # deferred we waited on didn't follow the rules, or we forgot to