diff --git a/synapse/__init__.py b/synapse/__init__.py
index fbfd506a43..a203ed533a 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -1,5 +1,6 @@
# Copyright 2014-2016 OpenMarket Ltd
-# Copyright 2018-9 New Vector Ltd
+# Copyright 2018-2019 New Vector Ltd
+# Copyright 2023 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-""" This is a reference implementation of a Matrix homeserver.
+""" This is an implementation of a Matrix homeserver.
"""
import json
diff --git a/synapse/_scripts/move_remote_media_to_new_store.py b/synapse/_scripts/move_remote_media_to_new_store.py
index 819afaaca6..0dd36bee20 100755
--- a/synapse/_scripts/move_remote_media_to_new_store.py
+++ b/synapse/_scripts/move_remote_media_to_new_store.py
@@ -37,7 +37,7 @@ import os
import shutil
import sys
-from synapse.rest.media.v1.filepath import MediaFilePaths
+from synapse.media.filepath import MediaFilePaths
logger = logging.getLogger()
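
(Note: every `synapse.rest.media.v1.*` import touched by this PR moves to `synapse.media.*`. A hedged compatibility sketch for out-of-tree scripts that must run against both layouts; the media-store path below is illustrative.)

```python
try:
    from synapse.media.filepath import MediaFilePaths  # new location (this PR)
except ImportError:
    from synapse.rest.media.v1.filepath import MediaFilePaths  # pre-move layout

filepaths = MediaFilePaths("/var/lib/synapse/media_store")
print(filepaths.local_media_filepath("some_media_id"))
```
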
diff --git a/synapse/_scripts/register_new_matrix_user.py b/synapse/_scripts/register_new_matrix_user.py
index 2b74a40166..19ca399d44 100644
--- a/synapse/_scripts/register_new_matrix_user.py
+++ b/synapse/_scripts/register_new_matrix_user.py
@@ -47,7 +47,6 @@ def request_registration(
_print: Callable[[str], None] = print,
exit: Callable[[int], None] = sys.exit,
) -> None:
-
url = "%s/_synapse/admin/v1/register" % (server_location.rstrip("/"),)
# Get the nonce
@@ -154,7 +153,6 @@ def register_new_user(
def main() -> None:
-
logging.captureWarnings(True)
parser = argparse.ArgumentParser(
diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py
index 0d35e0af8f..2c9cbf8b27 100755
--- a/synapse/_scripts/synapse_port_db.py
+++ b/synapse/_scripts/synapse_port_db.py
@@ -1205,7 +1205,6 @@ class CursesProgress(Progress):
if self.finished:
status = "Time spent: %s (Done!)" % (duration_str,)
else:
-
if self.total_processed > 0:
left = float(self.total_remaining) / self.total_processed
diff --git a/synapse/_scripts/synctl.py b/synapse/_scripts/synctl.py
index b4c96ad7f3..077b90935e 100755
--- a/synapse/_scripts/synctl.py
+++ b/synapse/_scripts/synctl.py
@@ -167,7 +167,6 @@ Worker = collections.namedtuple(
def main() -> None:
-
parser = argparse.ArgumentParser()
parser.add_argument(
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index a5aa2185a2..28062dd69d 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -213,7 +213,7 @@ def handle_startup_exception(e: Exception) -> NoReturn:
def redirect_stdio_to_logs() -> None:
streams = [("stdout", LogLevel.info), ("stderr", LogLevel.error)]
- for (stream, level) in streams:
+ for stream, level in streams:
oldStream = getattr(sys, stream)
loggingFile = LoggingFile(
logger=twisted.logger.Logger(namespace=stream),
diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py
index ad51f33165..b05fe2c589 100644
--- a/synapse/app/admin_cmd.py
+++ b/synapse/app/admin_cmd.py
@@ -44,6 +44,7 @@ from synapse.storage.databases.main.event_push_actions import (
)
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.filtering import FilteringWorkerStore
+from synapse.storage.databases.main.media_repository import MediaRepositoryStore
from synapse.storage.databases.main.profile import ProfileWorkerStore
from synapse.storage.databases.main.push_rule import PushRulesWorkerStore
from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
@@ -86,6 +87,7 @@ class AdminCmdSlavedStore(
RegistrationWorkerStore,
RoomWorkerStore,
ProfileWorkerStore,
+ MediaRepositoryStore,
):
def __init__(
self,
@@ -149,7 +151,7 @@ class FileExfiltrationWriter(ExfiltrationWriter):
with open(events_file, "a") as f:
for event in events:
- print(json.dumps(event.get_pdu_json()), file=f)
+ json.dump(event.get_pdu_json(), fp=f)
def write_state(
self, room_id: str, event_id: str, state: StateMap[EventBase]
@@ -162,7 +164,7 @@ class FileExfiltrationWriter(ExfiltrationWriter):
with open(event_file, "a") as f:
for event in state.values():
- print(json.dumps(event.get_pdu_json()), file=f)
+ json.dump(event.get_pdu_json(), fp=f)
def write_invite(
self, room_id: str, event: EventBase, state: StateMap[EventBase]
@@ -178,7 +180,7 @@ class FileExfiltrationWriter(ExfiltrationWriter):
with open(invite_state, "a") as f:
for event in state.values():
- print(json.dumps(event), file=f)
+ json.dump(event, fp=f)
def write_knock(
self, room_id: str, event: EventBase, state: StateMap[EventBase]
@@ -194,7 +196,7 @@ class FileExfiltrationWriter(ExfiltrationWriter):
with open(knock_state, "a") as f:
for event in state.values():
- print(json.dumps(event), file=f)
+ json.dump(event, fp=f)
def write_profile(self, profile: JsonDict) -> None:
user_directory = os.path.join(self.base_directory, "user_data")
@@ -202,7 +204,7 @@ class FileExfiltrationWriter(ExfiltrationWriter):
profile_file = os.path.join(user_directory, "profile")
with open(profile_file, "a") as f:
- print(json.dumps(profile), file=f)
+ json.dump(profile, fp=f)
def write_devices(self, devices: List[JsonDict]) -> None:
user_directory = os.path.join(self.base_directory, "user_data")
@@ -211,7 +213,7 @@ class FileExfiltrationWriter(ExfiltrationWriter):
for device in devices:
with open(device_file, "a") as f:
- print(json.dumps(device), file=f)
+ json.dump(device, fp=f)
def write_connections(self, connections: List[JsonDict]) -> None:
user_directory = os.path.join(self.base_directory, "user_data")
@@ -220,7 +222,7 @@ class FileExfiltrationWriter(ExfiltrationWriter):
for connection in connections:
with open(connection_file, "a") as f:
- print(json.dumps(connection), file=f)
+ json.dump(connection, fp=f)
def write_account_data(
self, file_name: str, account_data: Mapping[str, JsonDict]
@@ -233,7 +235,15 @@ class FileExfiltrationWriter(ExfiltrationWriter):
account_data_file = os.path.join(account_data_directory, file_name)
with open(account_data_file, "a") as f:
- print(json.dumps(account_data), file=f)
+ json.dump(account_data, fp=f)
+
+ def write_media_id(self, media_id: str, media_metadata: JsonDict) -> None:
+ file_directory = os.path.join(self.base_directory, "media_ids")
+ os.makedirs(file_directory, exist_ok=True)
+ media_id_file = os.path.join(file_directory, media_id)
+
+ with open(media_id_file, "w") as f:
+ json.dump(media_metadata, fp=f)
def finished(self) -> str:
return self.base_directory
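
(Unlike `print(json.dumps(...), file=f)`, `json.dump(..., fp=f)` writes no trailing newline, so the appended export files now contain back-to-back JSON documents. A hedged, stdlib-only sketch of reading such a file back; the helper name is ours.)

```python
import json
from typing import Any, Iterator

def iter_concatenated_json(path: str) -> Iterator[Any]:
    """Yield each document from a file of back-to-back JSON objects."""
    decoder = json.JSONDecoder()
    with open(path) as f:
        buf = f.read()
    pos = 0
    while pos < len(buf):
        # raw_decode() rejects leading whitespace, so skip any between docs.
        while pos < len(buf) and buf[pos].isspace():
            pos += 1
        if pos == len(buf):
            break
        obj, pos = decoder.raw_decode(buf, pos)
        yield obj
```
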
diff --git a/synapse/app/complement_fork_starter.py b/synapse/app/complement_fork_starter.py
index 920538f44d..c8dc3f9d76 100644
--- a/synapse/app/complement_fork_starter.py
+++ b/synapse/app/complement_fork_starter.py
@@ -219,7 +219,7 @@ def main() -> None:
# memory space and don't need to repeat the work of loading the code!
# Instead of using fork() directly, we use the multiprocessing library,
# which uses fork() on Unix platforms.
- for (func, worker_args) in zip(worker_functions, args_by_worker):
+ for func, worker_args in zip(worker_functions, args_by_worker):
process = multiprocessing.Process(
target=_worker_entrypoint, args=(func, proxy_reactor, worker_args)
)
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index 946f3a3807..0dec24369a 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -157,7 +157,6 @@ class GenericWorkerServer(HomeServer):
DATASTORE_CLASS = GenericWorkerSlavedStore # type: ignore
def _listen_http(self, listener_config: ListenerConfig) -> None:
-
assert listener_config.http_options is not None
# We always include a health resource.
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 6176a70eb2..b8830b1a9c 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -321,7 +321,6 @@ def setup(config_options: List[str]) -> SynapseHomeServer:
and not config.registration.registrations_require_3pid
and not config.registration.registration_requires_token
):
-
raise ConfigError(
"You have enabled open registration without any verification. This is a known vector for "
"spam and abuse. If you would like to allow public registration, please consider adding email, "
diff --git a/synapse/config/consent.py b/synapse/config/consent.py
index be74609dc4..5bfd0cbb71 100644
--- a/synapse/config/consent.py
+++ b/synapse/config/consent.py
@@ -22,7 +22,6 @@ from ._base import Config
class ConsentConfig(Config):
-
section = "consent"
def __init__(self, *args: Any):
diff --git a/synapse/config/database.py b/synapse/config/database.py
index 928fec8dfe..596d8769fe 100644
--- a/synapse/config/database.py
+++ b/synapse/config/database.py
@@ -154,7 +154,6 @@ class DatabaseConfig(Config):
logger.warning(NON_SQLITE_DATABASE_PATH_WARNING)
def set_databasepath(self, database_path: str) -> None:
-
if database_path != ":memory:":
database_path = self.abspath(database_path)
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index 54c91953e1..bc38fae0b6 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -175,8 +175,8 @@ class ExperimentalConfig(Config):
)
# MSC3873: Disambiguate event_match keys.
- self.msc3783_escape_event_match_key = experimental.get(
- "msc3783_escape_event_match_key", False
+ self.msc3873_escape_event_match_key = experimental.get(
+ "msc3873_escape_event_match_key", False
)
# MSC3952: Intentional mentions, this depends on MSC3758.
diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py
index 4d2b298a70..c205a78039 100644
--- a/synapse/config/homeserver.py
+++ b/synapse/config/homeserver.py
@@ -56,7 +56,6 @@ from .workers import WorkerConfig
class HomeServerConfig(RootConfig):
-
config_classes = [
ModulesConfig,
ServerConfig,
diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py
index 5c13fe428a..a5514e70a2 100644
--- a/synapse/config/ratelimiting.py
+++ b/synapse/config/ratelimiting.py
@@ -46,7 +46,6 @@ class RatelimitConfig(Config):
section = "ratelimiting"
def read_config(self, config: JsonDict, **kwargs: Any) -> None:
-
# Load the new-style messages config if it exists. Otherwise fall back
# to the old method.
if "rc_message" in config:
@@ -87,9 +86,18 @@ class RatelimitConfig(Config):
defaults={"per_second": 0.1, "burst_count": 5},
)
+        # It is reasonable to log in with a bunch of devices at once (e.g. when
+        # setting up an account), but it is *not* valid to continually be
+        # logging into new devices.
rc_login_config = config.get("rc_login", {})
- self.rc_login_address = RatelimitSettings(rc_login_config.get("address", {}))
- self.rc_login_account = RatelimitSettings(rc_login_config.get("account", {}))
+ self.rc_login_address = RatelimitSettings(
+ rc_login_config.get("address", {}),
+ defaults={"per_second": 0.003, "burst_count": 5},
+ )
+ self.rc_login_account = RatelimitSettings(
+ rc_login_config.get("account", {}),
+ defaults={"per_second": 0.003, "burst_count": 5},
+ )
self.rc_login_failed_attempts = RatelimitSettings(
rc_login_config.get("failed_attempts", {})
)
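
(For scale: `per_second: 0.003` refills roughly one attempt every ~333 seconds on top of the initial burst of 5. A toy token-bucket sketch of those semantics; this is not Synapse's actual `Ratelimiter`.)

```python
import time

class TokenBucket:
    """Toy model: burst_count tokens, refilled at per_second tokens/sec."""

    def __init__(self, per_second: float, burst_count: int) -> None:
        self.rate = per_second
        self.capacity = float(burst_count)
        self.tokens = float(burst_count)
        self.last = time.monotonic()

    def allow(self) -> bool:
        now = time.monotonic()
        self.tokens = min(self.capacity, self.tokens + (now - self.last) * self.rate)
        self.last = now
        if self.tokens >= 1.0:
            self.tokens -= 1.0
            return True
        return False

login_limiter = TokenBucket(per_second=0.003, burst_count=5)  # the new defaults
```
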
diff --git a/synapse/config/repository.py b/synapse/config/repository.py
index e4759711ed..ecb3edbe3a 100644
--- a/synapse/config/repository.py
+++ b/synapse/config/repository.py
@@ -116,7 +116,6 @@ class ContentRepositoryConfig(Config):
section = "media"
def read_config(self, config: JsonDict, **kwargs: Any) -> None:
-
# Only enable the media repo if either the media repo is enabled or the
# current worker app is the media repo.
if (
@@ -179,11 +178,13 @@ class ContentRepositoryConfig(Config):
for i, provider_config in enumerate(storage_providers):
# We special case the module "file_system" so as not to need to
# expose FileStorageProviderBackend
- if provider_config["module"] == "file_system":
- provider_config["module"] = (
- "synapse.rest.media.v1.storage_provider"
- ".FileStorageProviderBackend"
- )
+ if (
+ provider_config["module"] == "file_system"
+ or provider_config["module"] == "synapse.rest.media.v1.storage_provider"
+ ):
+ provider_config[
+ "module"
+ ] = "synapse.media.storage_provider.FileStorageProviderBackend"
provider_class, parsed_config = load_module(
provider_config, ("media_storage_providers", "<item %i>" % i)
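
(The condition above now rewrites the old pre-move module path as well as the `file_system` shorthand. The same normalisation expressed as a lookup table, as a sketch only; Synapse itself keeps the inline conditional.)

```python
_STORAGE_PROVIDER_ALIASES = {
    "file_system": "synapse.media.storage_provider.FileStorageProviderBackend",
    # Config files written before the module move keep working:
    "synapse.rest.media.v1.storage_provider": (
        "synapse.media.storage_provider.FileStorageProviderBackend"
    ),
}

def normalise_provider_module(module: str) -> str:
    return _STORAGE_PROVIDER_ALIASES.get(module, module)
```
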
diff --git a/synapse/config/server.py b/synapse/config/server.py
index d4ef9930b0..0e46b849cf 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -735,7 +735,6 @@ class ServerConfig(Config):
listeners: Optional[List[dict]],
**kwargs: Any,
) -> str:
-
_, bind_port = parse_and_validate_server_name(server_name)
if bind_port is not None:
unsecure_port = bind_port - 400
diff --git a/synapse/config/tls.py b/synapse/config/tls.py
index 336fe3e0da..318270ebb8 100644
--- a/synapse/config/tls.py
+++ b/synapse/config/tls.py
@@ -30,7 +30,6 @@ class TlsConfig(Config):
section = "tls"
def read_config(self, config: JsonDict, **kwargs: Any) -> None:
-
self.tls_certificate_file = self.abspath(config.get("tls_certificate_path"))
self.tls_private_key_file = self.abspath(config.get("tls_private_key_path"))
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 86cd4af9bd..d710607c63 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -399,7 +399,7 @@ class Keyring:
# We now convert the returned list of results into a map from server
# name to key ID to FetchKeyResult, to return.
to_return: Dict[str, Dict[str, FetchKeyResult]] = {}
- for (request, results) in zip(deduped_requests, results_per_request):
+ for request, results in zip(deduped_requests, results_per_request):
to_return_by_server = to_return.setdefault(request.server_name, {})
for key_id, key_result in results.items():
existing = to_return_by_server.get(key_id)
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index e0d82ad81c..a91a5d1e3c 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -23,6 +23,7 @@ from synapse.types import JsonDict, StateMap
if TYPE_CHECKING:
from synapse.storage.controllers import StorageControllers
+ from synapse.storage.databases import StateGroupDataStore
from synapse.storage.databases.main import DataStore
from synapse.types.state import StateFilter
@@ -348,6 +349,54 @@ class UnpersistedEventContext(UnpersistedEventContextBase):
partial_state: bool
state_map_before_event: Optional[StateMap[str]] = None
+ @classmethod
+ async def batch_persist_unpersisted_contexts(
+ cls,
+ events_and_context: List[Tuple[EventBase, "UnpersistedEventContextBase"]],
+ room_id: str,
+ last_known_state_group: int,
+ datastore: "StateGroupDataStore",
+ ) -> List[Tuple[EventBase, EventContext]]:
+ """
+ Takes a list of events and their associated unpersisted contexts and persists
+ the unpersisted contexts, returning a list of events and persisted contexts.
+        Note that all the events must be in a linear chain (i.e. a <- b <- c).
+
+ Args:
+ events_and_context: A list of events and their unpersisted contexts
+ room_id: the room_id for the events
+ last_known_state_group: the last persisted state group
+ datastore: a state datastore
+ """
+ amended_events_and_context = await datastore.store_state_deltas_for_batched(
+ events_and_context, room_id, last_known_state_group
+ )
+
+ events_and_persisted_context = []
+ for event, unpersisted_context in amended_events_and_context:
+ if event.is_state():
+ context = EventContext(
+ storage=unpersisted_context._storage,
+ state_group=unpersisted_context.state_group_after_event,
+ state_group_before_event=unpersisted_context.state_group_before_event,
+ state_delta_due_to_event=unpersisted_context.state_delta_due_to_event,
+ partial_state=unpersisted_context.partial_state,
+ prev_group=unpersisted_context.state_group_before_event,
+ delta_ids=unpersisted_context.state_delta_due_to_event,
+ )
+ else:
+ context = EventContext(
+ storage=unpersisted_context._storage,
+ state_group=unpersisted_context.state_group_after_event,
+ state_group_before_event=unpersisted_context.state_group_before_event,
+ state_delta_due_to_event=unpersisted_context.state_delta_due_to_event,
+ partial_state=unpersisted_context.partial_state,
+ prev_group=unpersisted_context.prev_group_for_state_group_before_event,
+ delta_ids=unpersisted_context.delta_ids_to_state_group_before_event,
+ )
+ events_and_persisted_context.append((event, context))
+ return events_and_persisted_context
+
async def get_prev_state_ids(
self, state_filter: Optional["StateFilter"] = None
) -> StateMap[str]:
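
(A hedged caller-side sketch of the new batch helper above; the wrapper function is ours, and the argument order matches the signature in this hunk.)

```python
from typing import TYPE_CHECKING, List, Tuple

from synapse.events import EventBase
from synapse.events.snapshot import (
    EventContext,
    UnpersistedEventContext,
    UnpersistedEventContextBase,
)

if TYPE_CHECKING:
    from synapse.storage.databases import StateGroupDataStore

async def persist_linear_chain(
    unpersisted: List[Tuple[EventBase, UnpersistedEventContextBase]],
    room_id: str,
    last_known_state_group: int,
    datastore: "StateGroupDataStore",
) -> List[Tuple[EventBase, EventContext]]:
    # The events must form a linear chain (a <- b <- c); one call persists
    # every context rather than awaiting persist() once per event.
    return await UnpersistedEventContext.batch_persist_unpersisted_contexts(
        unpersisted, room_id, last_known_state_group, datastore
    )
```
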
diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py
index 623a2c71ea..765c15bb51 100644
--- a/synapse/events/spamcheck.py
+++ b/synapse/events/spamcheck.py
@@ -33,8 +33,8 @@ from typing_extensions import Literal
import synapse
from synapse.api.errors import Codes
from synapse.logging.opentracing import trace
-from synapse.rest.media.v1._base import FileInfo
-from synapse.rest.media.v1.media_storage import ReadableFileWrapper
+from synapse.media._base import FileInfo
+from synapse.media.media_storage import ReadableFileWrapper
from synapse.spam_checker_api import RegistrationBehaviour
from synapse.types import JsonDict, RoomAlias, UserProfile
from synapse.util.async_helpers import delay_cancellation, maybe_awaitable
diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py
index 97c61cc258..3e4d52c8d8 100644
--- a/synapse/events/third_party_rules.py
+++ b/synapse/events/third_party_rules.py
@@ -45,6 +45,8 @@ CHECK_CAN_DEACTIVATE_USER_CALLBACK = Callable[[str, bool], Awaitable[bool]]
ON_PROFILE_UPDATE_CALLBACK = Callable[[str, ProfileInfo, bool, bool], Awaitable]
ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK = Callable[[str, bool, bool], Awaitable]
ON_THREEPID_BIND_CALLBACK = Callable[[str, str, str], Awaitable]
+ON_ADD_USER_THIRD_PARTY_IDENTIFIER_CALLBACK = Callable[[str, str, str], Awaitable]
+ON_REMOVE_USER_THIRD_PARTY_IDENTIFIER_CALLBACK = Callable[[str, str, str], Awaitable]
def load_legacy_third_party_event_rules(hs: "HomeServer") -> None:
@@ -78,7 +80,6 @@ def load_legacy_third_party_event_rules(hs: "HomeServer") -> None:
# correctly, we need to await its result. Therefore it doesn't make a lot of
# sense to make it go through the run() wrapper.
if f.__name__ == "check_event_allowed":
-
# We need to wrap check_event_allowed because its old form would return either
# a boolean or a dict, but now we want to return the dict separately from the
# boolean.
@@ -100,7 +101,6 @@ def load_legacy_third_party_event_rules(hs: "HomeServer") -> None:
return wrap_check_event_allowed
if f.__name__ == "on_create_room":
-
# We need to wrap on_create_room because its old form would return a boolean
# if the room creation is denied, but now we just want it to raise an
# exception.
@@ -174,6 +174,12 @@ class ThirdPartyEventRules:
ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK
] = []
self._on_threepid_bind_callbacks: List[ON_THREEPID_BIND_CALLBACK] = []
+ self._on_add_user_third_party_identifier_callbacks: List[
+ ON_ADD_USER_THIRD_PARTY_IDENTIFIER_CALLBACK
+ ] = []
+ self._on_remove_user_third_party_identifier_callbacks: List[
+ ON_REMOVE_USER_THIRD_PARTY_IDENTIFIER_CALLBACK
+ ] = []
def register_third_party_rules_callbacks(
self,
@@ -193,6 +199,12 @@ class ThirdPartyEventRules:
ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK
] = None,
on_threepid_bind: Optional[ON_THREEPID_BIND_CALLBACK] = None,
+ on_add_user_third_party_identifier: Optional[
+ ON_ADD_USER_THIRD_PARTY_IDENTIFIER_CALLBACK
+ ] = None,
+ on_remove_user_third_party_identifier: Optional[
+ ON_REMOVE_USER_THIRD_PARTY_IDENTIFIER_CALLBACK
+ ] = None,
) -> None:
"""Register callbacks from modules for each hook."""
if check_event_allowed is not None:
@@ -230,6 +242,16 @@
if on_threepid_bind is not None:
self._on_threepid_bind_callbacks.append(on_threepid_bind)
+        if on_add_user_third_party_identifier is not None:
+            self._on_add_user_third_party_identifier_callbacks.append(
+                on_add_user_third_party_identifier
+            )
+
+        if on_remove_user_third_party_identifier is not None:
+            self._on_remove_user_third_party_identifier_callbacks.append(
+                on_remove_user_third_party_identifier
+            )
+
async def check_event_allowed(
self,
event: EventBase,
@@ -513,6 +530,9 @@ class ThirdPartyEventRules:
local homeserver, not when it's created on an identity server (and then kept track
of so that it can be unbound on the same IS later on).
+ THIS MODULE CALLBACK METHOD HAS BEEN DEPRECATED. Please use the
+ `on_add_user_third_party_identifier` callback method instead.
+
Args:
user_id: the user being associated with the threepid.
medium: the threepid's medium.
@@ -525,3 +545,44 @@ class ThirdPartyEventRules:
logger.exception(
"Failed to run module API callback %s: %s", callback, e
)
+
+ async def on_add_user_third_party_identifier(
+ self, user_id: str, medium: str, address: str
+ ) -> None:
+ """Called when an association between a user's Matrix ID and a third-party ID
+ (email, phone number) has successfully been registered on the homeserver.
+
+ Args:
+ user_id: The User ID included in the association.
+ medium: The medium of the third-party ID (email, msisdn).
+ address: The address of the third-party ID (i.e. an email address).
+ """
+ for callback in self._on_add_user_third_party_identifier_callbacks:
+ try:
+ await callback(user_id, medium, address)
+ except Exception as e:
+ logger.exception(
+ "Failed to run module API callback %s: %s", callback, e
+ )
+
+ async def on_remove_user_third_party_identifier(
+ self, user_id: str, medium: str, address: str
+ ) -> None:
+ """Called when an association between a user's Matrix ID and a third-party ID
+ (email, phone number) has been successfully removed on the homeserver.
+
+ This is called *after* any known bindings on identity servers for this
+ association have been removed.
+
+ Args:
+ user_id: The User ID included in the removed association.
+ medium: The medium of the third-party ID (email, msisdn).
+ address: The address of the third-party ID (i.e. an email address).
+ """
+ for callback in self._on_remove_user_third_party_identifier_callbacks:
+ try:
+ await callback(user_id, medium, address)
+ except Exception as e:
+ logger.exception(
+ "Failed to run module API callback %s: %s", callback, e
+ )
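
(A hedged sketch of a module consuming the two new hooks, assuming the module API forwards these keyword arguments the same way it does the existing third-party-rules callbacks; the module class itself is illustrative.)

```python
import logging

from synapse.module_api import ModuleApi

logger = logging.getLogger(__name__)

class ThreepidAuditModule:
    def __init__(self, config: dict, api: ModuleApi) -> None:
        api.register_third_party_rules_callbacks(
            on_add_user_third_party_identifier=self.on_add,
            on_remove_user_third_party_identifier=self.on_remove,
        )

    async def on_add(self, user_id: str, medium: str, address: str) -> None:
        logger.info("3PID %s/%s associated with %s", medium, address, user_id)

    async def on_remove(self, user_id: str, medium: str, address: str) -> None:
        logger.info("3PID %s/%s removed from %s", medium, address, user_id)
```
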
diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py
index d720b5fd3f..3063df7990 100644
--- a/synapse/federation/send_queue.py
+++ b/synapse/federation/send_queue.py
@@ -314,7 +314,7 @@ class FederationRemoteSendQueue(AbstractFederationSender):
# stream position.
keyed_edus = {v: k for k, v in self.keyed_edu_changed.items()[i:j]}
- for ((destination, edu_key), pos) in keyed_edus.items():
+ for (destination, edu_key), pos in keyed_edus.items():
rows.append(
(
pos,
@@ -329,7 +329,7 @@ class FederationRemoteSendQueue(AbstractFederationSender):
j = self.edus.bisect_right(to_token) + 1
edus = self.edus.items()[i:j]
- for (pos, edu) in edus:
+ for pos, edu in edus:
rows.append((pos, EduRow(edu)))
# Sort rows based on pos
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index 8b7760b2cc..b06f25b03c 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -252,16 +252,19 @@ class AdminHandler:
profile = await self.get_user(UserID.from_string(user_id))
if profile is not None:
writer.write_profile(profile)
+ logger.info("[%s] Written profile", user_id)
# Get all devices the user has
devices = await self._device_handler.get_devices_by_user(user_id)
writer.write_devices(devices)
+ logger.info("[%s] Written %s devices", user_id, len(devices))
# Get all connections the user has
connections = await self.get_whois(UserID.from_string(user_id))
writer.write_connections(
connections["devices"][""]["sessions"][0]["connections"]
)
+ logger.info("[%s] Written %s connections", user_id, len(connections))
# Get all account data the user has global and in rooms
global_data = await self._store.get_global_account_data_for_user(user_id)
@@ -269,6 +272,29 @@ class AdminHandler:
writer.write_account_data("global", global_data)
for room_id in by_room_data:
writer.write_account_data(room_id, by_room_data[room_id])
+ logger.info(
+ "[%s] Written account data for %s rooms", user_id, len(by_room_data)
+ )
+
+ # Get all media ids the user has
+ limit = 100
+ start = 0
+ while True:
+ media_ids, total = await self._store.get_local_media_by_user_paginate(
+ start, limit, user_id
+ )
+ for media in media_ids:
+ writer.write_media_id(media["media_id"], media)
+
+ logger.info(
+ "[%s] Written %d media_ids of %s",
+ user_id,
+ (start + len(media_ids)),
+ total,
+ )
+ if (start + limit) >= total:
+ break
+ start += limit
return writer.finished()
@@ -360,6 +386,18 @@ class ExfiltrationWriter(metaclass=abc.ABCMeta):
raise NotImplementedError()
@abc.abstractmethod
+ def write_media_id(self, media_id: str, media_metadata: JsonDict) -> None:
+        """Write the metadata for one of the user's media files.
+        Only the metadata is exported, since it can be fetched from the
+        database with read-only access; accessing the files themselves would
+        require a connection to the relevant media repository.
+
+ Args:
+ media_id: ID of the media.
+ media_metadata: Metadata of one media file.
+ """
+
+ @abc.abstractmethod
def finished(self) -> Any:
"""Called when all data has successfully been exported and written.
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 5d1d21cdc8..ec3ab968e9 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -737,7 +737,7 @@ class ApplicationServicesHandler:
)
ret = []
- for (success, result) in results:
+ for success, result in results:
if success:
ret.extend(result)
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index cf12b55d21..308e38edea 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -815,7 +815,6 @@ class AuthHandler:
now_ms = self._clock.time_msec()
if existing_token.expiry_ts is not None and existing_token.expiry_ts < now_ms:
-
raise SynapseError(
HTTPStatus.FORBIDDEN,
"The supplied refresh token has expired",
@@ -1543,6 +1542,17 @@ class AuthHandler:
async def add_threepid(
self, user_id: str, medium: str, address: str, validated_at: int
) -> None:
+ """
+ Adds an association between a user's Matrix ID and a third-party ID (email,
+ phone number).
+
+ Args:
+ user_id: The ID of the user to associate.
+ medium: The medium of the third-party ID (email, msisdn).
+ address: The address of the third-party ID (i.e. an email address).
+ validated_at: The timestamp in ms of when the validation that the user owns
+ this third-party ID occurred.
+ """
# check if medium has a valid value
if medium not in ["email", "msisdn"]:
raise SynapseError(
@@ -1567,42 +1577,44 @@ class AuthHandler:
user_id, medium, address, validated_at, self.hs.get_clock().time_msec()
)
+ # Inform Synapse modules that a 3PID association has been created.
+ await self._third_party_rules.on_add_user_third_party_identifier(
+ user_id, medium, address
+ )
+
+ # Deprecated method for informing Synapse modules that a 3PID association
+ # has successfully been created.
await self._third_party_rules.on_threepid_bind(user_id, medium, address)
- async def delete_threepid(
- self, user_id: str, medium: str, address: str, id_server: Optional[str] = None
- ) -> bool:
- """Attempts to unbind the 3pid on the identity servers and deletes it
- from the local database.
+ async def delete_local_threepid(
+ self, user_id: str, medium: str, address: str
+ ) -> None:
+ """Deletes an association between a third-party ID and a user ID from the local
+ database. This method does not unbind the association from any identity servers.
+
+ If `medium` is 'email' and a pusher is associated with this third-party ID, the
+ pusher will also be deleted.
Args:
user_id: ID of user to remove the 3pid from.
medium: The medium of the 3pid being removed: "email" or "msisdn".
address: The 3pid address to remove.
- id_server: Use the given identity server when unbinding
- any threepids. If None then will attempt to unbind using the
- identity server specified when binding (if known).
-
- Returns:
- Returns True if successfully unbound the 3pid on
- the identity server, False if identity server doesn't support the
- unbind API.
"""
-
# 'Canonicalise' email addresses as per above
if medium == "email":
address = canonicalise_email(address)
- result = await self.hs.get_identity_handler().try_unbind_threepid(
- user_id, medium, address, id_server
+ await self.store.user_delete_threepid(user_id, medium, address)
+
+ # Inform Synapse modules that a 3PID association has been deleted.
+ await self._third_party_rules.on_remove_user_third_party_identifier(
+ user_id, medium, address
)
- await self.store.user_delete_threepid(user_id, medium, address)
if medium == "email":
await self.store.delete_pusher_by_app_id_pushkey_user_id(
app_id="m.email", pushkey=address, user_id=user_id
)
- return result
async def hash(self, password: str) -> str:
"""Computes a secure hash of password.
@@ -2259,7 +2271,6 @@ class PasswordAuthProvider:
async def on_logged_out(
self, user_id: str, device_id: Optional[str], access_token: str
) -> None:
-
# call all of the on_logged_out callbacks
for callback in self.on_logged_out_callbacks:
try:
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index d24f649382..d31263c717 100644
--- a/synapse/handlers/deactivate_account.py
+++ b/synapse/handlers/deactivate_account.py
@@ -100,26 +100,28 @@ class DeactivateAccountHandler:
# unbinding
identity_server_supports_unbinding = True
- # Retrieve the 3PIDs this user has bound to an identity server
- threepids = await self.store.user_get_bound_threepids(user_id)
-
- for threepid in threepids:
+ # Attempt to unbind any known bound threepids to this account from identity
+ # server(s).
+ bound_threepids = await self.store.user_get_bound_threepids(user_id)
+ for threepid in bound_threepids:
try:
result = await self._identity_handler.try_unbind_threepid(
user_id, threepid["medium"], threepid["address"], id_server
)
- identity_server_supports_unbinding &= result
except Exception:
# Do we want this to be a fatal error or should we carry on?
logger.exception("Failed to remove threepid from ID server")
raise SynapseError(400, "Failed to remove threepid from ID server")
- await self.store.user_delete_threepid(
+
+ identity_server_supports_unbinding &= result
+
+ # Remove any local threepid associations for this account.
+ local_threepids = await self.store.user_get_threepids(user_id)
+ for threepid in local_threepids:
+ await self._auth_handler.delete_local_threepid(
user_id, threepid["medium"], threepid["address"]
)
- # Remove all 3PIDs this user has bound to the homeserver
- await self.store.user_delete_threepids(user_id)
-
# delete any devices belonging to the user, which will also
# delete corresponding access tokens.
await self._device_handler.delete_all_devices_for_user(user_id)
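
(A condensed restatement of the reordered flow above, as a hedged sketch; the store/handler parameters are stand-ins. Identity-server unbinding happens first, then every *local* association is dropped via the new `delete_local_threepid`, which also removes any email pusher.)

```python
from typing import Any

async def remove_threepids(
    user_id: str,
    id_server: Any,
    store: Any,
    identity_handler: Any,
    auth_handler: Any,
) -> bool:
    supports_unbinding = True
    # Step 1: unbind anything bound on identity servers.
    for threepid in await store.user_get_bound_threepids(user_id):
        supports_unbinding &= await identity_handler.try_unbind_threepid(
            user_id, threepid["medium"], threepid["address"], id_server
        )
    # Step 2: delete the local associations (replaces user_delete_threepids).
    for threepid in await store.user_get_threepids(user_id):
        await auth_handler.delete_local_threepid(
            user_id, threepid["medium"], threepid["address"]
        )
    return supports_unbinding
```
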
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index a5798e9483..1fb23cc9bf 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -497,9 +497,11 @@ class DirectoryHandler:
raise SynapseError(403, "Not allowed to publish room")
# Check if publishing is blocked by a third party module
- allowed_by_third_party_rules = await (
- self.third_party_event_rules.check_visibility_can_be_modified(
- room_id, visibility
+ allowed_by_third_party_rules = (
+ await (
+ self.third_party_event_rules.check_visibility_can_be_modified(
+ room_id, visibility
+ )
)
)
if not allowed_by_third_party_rules:
diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py
index 83f53ceb88..50317ec753 100644
--- a/synapse/handlers/e2e_room_keys.py
+++ b/synapse/handlers/e2e_room_keys.py
@@ -188,7 +188,6 @@ class E2eRoomKeysHandler:
# XXX: perhaps we should use a finer grained lock here?
async with self._upload_linearizer.queue(user_id):
-
# Check that the version we're trying to upload is the current version
try:
version_info = await self.store.get_e2e_room_keys_version_info(user_id)
diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py
index 46dd63c3f0..c508861b6a 100644
--- a/synapse/handlers/event_auth.py
+++ b/synapse/handlers/event_auth.py
@@ -236,7 +236,6 @@ class EventAuthHandler:
# in any of them.
allowed_rooms = await self.get_rooms_that_allow_join(state_ids)
if not await self.is_user_in_rooms(allowed_rooms, user_id):
-
# If this is a remote request, the user might be in an allowed room
# that we do not know about.
if get_domain_from_id(user_id) != self._server_name:
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 1a29abde98..aead0b44b9 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -124,7 +124,6 @@ class InitialSyncHandler:
as_client_event: bool = True,
include_archived: bool = False,
) -> JsonDict:
-
memberships = [Membership.INVITE, Membership.JOIN]
if include_archived:
memberships.append(Membership.LEAVE)
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index aa90d0000d..e433d6b01f 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -574,7 +574,7 @@ class EventCreationHandler:
state_map: Optional[StateMap[str]] = None,
for_batch: bool = False,
current_state_group: Optional[int] = None,
- ) -> Tuple[EventBase, EventContext]:
+ ) -> Tuple[EventBase, UnpersistedEventContextBase]:
"""
Given a dict from a client, create a new event. If bool for_batch is true, will
create an event using the prev_event_ids, and will create an event context for
@@ -721,8 +721,6 @@ class EventCreationHandler:
current_state_group=current_state_group,
)
- context = await unpersisted_context.persist(event)
-
# In an ideal world we wouldn't need the second part of this condition. However,
# this behaviour isn't spec'd yet, meaning we should be able to deactivate this
# behaviour. Another reason is that this code is also evaluated each time a new
@@ -739,7 +737,7 @@ class EventCreationHandler:
assert state_map is not None
prev_event_id = state_map.get((EventTypes.Member, event.sender))
else:
- prev_state_ids = await context.get_prev_state_ids(
+ prev_state_ids = await unpersisted_context.get_prev_state_ids(
StateFilter.from_types([(EventTypes.Member, None)])
)
prev_event_id = prev_state_ids.get((EventTypes.Member, event.sender))
@@ -764,8 +762,7 @@ class EventCreationHandler:
)
self.validator.validate_new(event, self.config)
-
- return event, context
+ return event, unpersisted_context
async def _is_exempt_from_privacy_policy(
self, builder: EventBuilder, requester: Requester
@@ -1005,7 +1002,7 @@ class EventCreationHandler:
max_retries = 5
for i in range(max_retries):
try:
- event, context = await self.create_event(
+ event, unpersisted_context = await self.create_event(
requester,
event_dict,
txn_id=txn_id,
@@ -1016,6 +1013,7 @@ class EventCreationHandler:
historical=historical,
depth=depth,
)
+ context = await unpersisted_context.persist(event)
assert self.hs.is_mine_id(event.sender), "User must be our own: %s" % (
event.sender,
@@ -1190,7 +1188,6 @@ class EventCreationHandler:
if for_batch:
assert prev_event_ids is not None
assert state_map is not None
- assert current_state_group is not None
auth_ids = self._event_auth_handler.compute_auth_events(builder, state_map)
event = await builder.build(
prev_event_ids=prev_event_ids, auth_event_ids=auth_ids, depth=depth
@@ -2046,7 +2043,7 @@ class EventCreationHandler:
max_retries = 5
for i in range(max_retries):
try:
- event, context = await self.create_event(
+ event, unpersisted_context = await self.create_event(
requester,
{
"type": EventTypes.Dummy,
@@ -2055,6 +2052,7 @@ class EventCreationHandler:
"sender": user_id,
},
)
+ context = await unpersisted_context.persist(event)
event.internal_metadata.proactively_send = False
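
(The same two-step pattern recurs at every `create_event` call site in this PR; a hedged distillation, with the helper function being ours.)

```python
from typing import Any, Tuple

from synapse.events import EventBase
from synapse.events.snapshot import EventContext

async def create_and_persist(
    event_creation_handler: Any, requester: Any, event_dict: dict
) -> Tuple[EventBase, EventContext]:
    # create_event() now returns an *unpersisted* context; callers that need
    # a persisted EventContext must opt in explicitly.
    event, unpersisted_context = await event_creation_handler.create_event(
        requester, event_dict
    )
    context = await unpersisted_context.persist(event)
    return event, context
```
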
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 87af31aa27..4ad2233573 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -777,7 +777,6 @@ class PresenceHandler(BasePresenceHandler):
)
if self.unpersisted_users_changes:
-
await self.store.update_presence(
[
self.user_to_current_state[user_id]
@@ -823,7 +822,6 @@ class PresenceHandler(BasePresenceHandler):
now = self.clock.time_msec()
with Measure(self.clock, "presence_update_states"):
-
# NOTE: We purposefully don't await between now and when we've
# calculated what we want to do with the new states, to avoid races.
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index c611efb760..e4e506e62c 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -476,7 +476,7 @@ class RegistrationHandler:
# create room expects the localpart of the room alias
config["room_alias_name"] = room_alias.localpart
- info, _ = await room_creation_handler.create_room(
+ room_id, _, _ = await room_creation_handler.create_room(
fake_requester,
config=config,
ratelimit=False,
@@ -490,7 +490,7 @@ class RegistrationHandler:
user_id, authenticated_entity=self._server_name
),
target=UserID.from_string(user_id),
- room_id=info["room_id"],
+ room_id=room_id,
# Since it was just created, there are no remote hosts.
remote_room_hosts=[],
action="join",
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 837dabb3b7..b1784638f4 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -51,6 +51,7 @@ from synapse.api.filtering import Filter
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
from synapse.event_auth import validate_event_for_room_version
from synapse.events import EventBase
+from synapse.events.snapshot import UnpersistedEventContext
from synapse.events.utils import copy_and_fixup_power_levels_contents
from synapse.handlers.relations import BundledAggregations
from synapse.module_api import NOT_SPAM
@@ -211,7 +212,7 @@ class RoomCreationHandler:
# the required power level to send the tombstone event.
(
tombstone_event,
- tombstone_context,
+ tombstone_unpersisted_context,
) = await self.event_creation_handler.create_event(
requester,
{
@@ -225,6 +226,9 @@ class RoomCreationHandler:
},
},
)
+ tombstone_context = await tombstone_unpersisted_context.persist(
+ tombstone_event
+ )
validate_event_for_room_version(tombstone_event)
await self._event_auth_handler.check_auth_rules_from_context(
tombstone_event
@@ -690,13 +694,14 @@ class RoomCreationHandler:
config: JsonDict,
ratelimit: bool = True,
creator_join_profile: Optional[JsonDict] = None,
- ) -> Tuple[dict, int]:
+ ) -> Tuple[str, Optional[RoomAlias], int]:
"""Creates a new room.
Args:
- requester:
- The user who requested the room creation.
- config : A dict of configuration options.
+ requester: The user who requested the room creation.
+ config: A dict of configuration options. This will be the body of
+ a /createRoom request; see
+ https://spec.matrix.org/latest/client-server-api/#post_matrixclientv3createroom
ratelimit: set to False to disable the rate limiter
creator_join_profile:
@@ -707,14 +712,17 @@ class RoomCreationHandler:
`avatar_url` and/or `displayname`.
Returns:
- First, a dict containing the keys `room_id` and, if an alias
- was, requested, `room_alias`. Secondly, the stream_id of the
- last persisted event.
+ A 3-tuple containing:
+ - the room ID;
+ - if requested, the room alias, otherwise None; and
+ - the `stream_id` of the last persisted event.
Raises:
- SynapseError if the room ID couldn't be stored, 3pid invitation config
- validation failed, or something went horribly wrong.
- ResourceLimitError if server is blocked to some resource being
- exceeded
+ SynapseError:
+ if the room ID couldn't be stored, 3pid invitation config
+ validation failed, or something went horribly wrong.
+            ResourceLimitError:
+                if the server is blocked due to a resource limit being
+                exceeded
"""
user_id = requester.user.to_string()
@@ -864,9 +872,11 @@ class RoomCreationHandler:
)
# Check whether this visibility value is blocked by a third party module
- allowed_by_third_party_rules = await (
- self.third_party_event_rules.check_visibility_can_be_modified(
- room_id, visibility
+ allowed_by_third_party_rules = (
+ await (
+ self.third_party_event_rules.check_visibility_can_be_modified(
+ room_id, visibility
+ )
)
)
if not allowed_by_third_party_rules:
@@ -1024,11 +1034,6 @@ class RoomCreationHandler:
last_sent_event_id = member_event_id
depth += 1
- result = {"room_id": room_id}
-
- if room_alias:
- result["room_alias"] = room_alias.to_string()
-
# Always wait for room creation to propagate before returning
await self._replication.wait_for_stream_position(
self.hs.config.worker.events_shard_config.get_instance(room_id),
@@ -1036,7 +1041,7 @@ class RoomCreationHandler:
last_stream_id,
)
- return result, last_stream_id
+ return room_id, room_alias, last_stream_id
async def _send_events_for_new_room(
self,
@@ -1091,7 +1096,7 @@ class RoomCreationHandler:
content: JsonDict,
for_batch: bool,
**kwargs: Any,
- ) -> Tuple[EventBase, synapse.events.snapshot.EventContext]:
+ ) -> Tuple[EventBase, synapse.events.snapshot.UnpersistedEventContextBase]:
"""
Creates an event and associated event context.
Args:
@@ -1110,20 +1115,23 @@ class RoomCreationHandler:
event_dict = create_event_dict(etype, content, **kwargs)
- new_event, new_context = await self.event_creation_handler.create_event(
+ (
+ new_event,
+ new_unpersisted_context,
+ ) = await self.event_creation_handler.create_event(
creator,
event_dict,
prev_event_ids=prev_event,
depth=depth,
state_map=state_map,
for_batch=for_batch,
- current_state_group=current_state_group,
)
+
depth += 1
prev_event = [new_event.event_id]
state_map[(new_event.type, new_event.state_key)] = new_event.event_id
- return new_event, new_context
+ return new_event, new_unpersisted_context
try:
config = self._presets_dict[preset_config]
@@ -1133,10 +1141,10 @@ class RoomCreationHandler:
)
creation_content.update({"creator": creator_id})
- creation_event, creation_context = await create_event(
+ creation_event, unpersisted_creation_context = await create_event(
EventTypes.Create, creation_content, False
)
-
+ creation_context = await unpersisted_creation_context.persist(creation_event)
logger.debug("Sending %s in new room", EventTypes.Member)
ev = await self.event_creation_handler.handle_new_client_event(
requester=creator,
@@ -1180,7 +1188,6 @@ class RoomCreationHandler:
power_event, power_context = await create_event(
EventTypes.PowerLevels, pl_content, True
)
- current_state_group = power_context._state_group
events_to_send.append((power_event, power_context))
else:
power_level_content: JsonDict = {
@@ -1229,14 +1236,12 @@ class RoomCreationHandler:
power_level_content,
True,
)
- current_state_group = pl_context._state_group
events_to_send.append((pl_event, pl_context))
if room_alias and (EventTypes.CanonicalAlias, "") not in initial_state:
room_alias_event, room_alias_context = await create_event(
EventTypes.CanonicalAlias, {"alias": room_alias.to_string()}, True
)
- current_state_group = room_alias_context._state_group
events_to_send.append((room_alias_event, room_alias_context))
if (EventTypes.JoinRules, "") not in initial_state:
@@ -1245,7 +1250,6 @@ class RoomCreationHandler:
{"join_rule": config["join_rules"]},
True,
)
- current_state_group = join_rules_context._state_group
events_to_send.append((join_rules_event, join_rules_context))
if (EventTypes.RoomHistoryVisibility, "") not in initial_state:
@@ -1254,7 +1258,6 @@ class RoomCreationHandler:
{"history_visibility": config["history_visibility"]},
True,
)
- current_state_group = visibility_context._state_group
events_to_send.append((visibility_event, visibility_context))
if config["guest_can_join"]:
@@ -1264,14 +1267,12 @@ class RoomCreationHandler:
{EventContentFields.GUEST_ACCESS: GuestAccess.CAN_JOIN},
True,
)
- current_state_group = guest_access_context._state_group
events_to_send.append((guest_access_event, guest_access_context))
for (etype, state_key), content in initial_state.items():
event, context = await create_event(
etype, content, True, state_key=state_key
)
- current_state_group = context._state_group
events_to_send.append((event, context))
if config["encrypted"]:
@@ -1283,9 +1284,16 @@ class RoomCreationHandler:
)
events_to_send.append((encryption_event, encryption_context))
+ datastore = self.hs.get_datastores().state
+ events_and_context = (
+ await UnpersistedEventContext.batch_persist_unpersisted_contexts(
+ events_to_send, room_id, current_state_group, datastore
+ )
+ )
+
last_event = await self.event_creation_handler.handle_new_client_event(
creator,
- events_to_send,
+ events_and_context,
ignore_shadow_ban=True,
ratelimit=False,
)
@@ -1825,7 +1833,7 @@ class RoomShutdownHandler:
new_room_user_id, authenticated_entity=requester_user_id
)
- info, stream_id = await self._room_creation_handler.create_room(
+ new_room_id, _, stream_id = await self._room_creation_handler.create_room(
room_creator_requester,
config={
"preset": RoomCreationPreset.PUBLIC_CHAT,
@@ -1834,7 +1842,6 @@ class RoomShutdownHandler:
},
ratelimit=False,
)
- new_room_id = info["room_id"]
logger.info(
"Shutting down room %r, joining to new room: %r", room_id, new_room_id
@@ -1887,6 +1894,7 @@ class RoomShutdownHandler:
# Join users to new room
if new_room_user_id:
+ assert new_room_id is not None
await self.room_member_handler.update_membership(
requester=target_requester,
target=target_requester.user,
@@ -1919,6 +1927,7 @@ class RoomShutdownHandler:
aliases_for_room = await self.store.get_aliases_for_room(room_id)
+ assert new_room_id is not None
await self.store.update_aliases_for_room(
room_id, new_room_id, requester_user_id
)
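
(Caller-side impact of the new `create_room` signature, as a hedged sketch; the handler/requester/config arguments are stand-ins. The alias is now a `RoomAlias` or `None` rather than a string in a dict.)

```python
from typing import Any, Optional

async def create_room_and_get_alias(
    room_creation_handler: Any, requester: Any, config: dict
) -> Optional[str]:
    # Old API: info, stream_id = await create_room(...); info["room_id"].
    # New API: a (room_id, room_alias, last_stream_id) 3-tuple.
    room_id, room_alias, last_stream_id = await room_creation_handler.create_room(
        requester, config=config, ratelimit=False
    )
    return room_alias.to_string() if room_alias is not None else None
```
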
diff --git a/synapse/handlers/room_batch.py b/synapse/handlers/room_batch.py
index c73d2adaad..bf9df60218 100644
--- a/synapse/handlers/room_batch.py
+++ b/synapse/handlers/room_batch.py
@@ -327,7 +327,7 @@ class RoomBatchHandler:
# Mark all events as historical
event_dict["content"][EventContentFields.MSC2716_HISTORICAL] = True
- event, context = await self.event_creation_handler.create_event(
+ event, unpersisted_context = await self.event_creation_handler.create_event(
await self.create_requester_for_user_id_from_app_service(
ev["sender"], app_service_requester.app_service
),
@@ -345,7 +345,7 @@ class RoomBatchHandler:
historical=True,
depth=inherited_depth,
)
-
+ context = await unpersisted_context.persist(event)
assert context._state_group
# Normally this is done when persisting the event but we have to
@@ -374,7 +374,7 @@ class RoomBatchHandler:
# correct stream_ordering as they are backfilled (which decrements).
# Events are sorted by (topological_ordering, stream_ordering)
# where topological_ordering is just depth.
- for (event, context) in reversed(events_to_persist):
+ for event, context in reversed(events_to_persist):
# This call can't raise `PartialStateConflictError` since we forbid
# use of the historical batch API during partial state
await self.event_creation_handler.handle_new_client_event(
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index a965c7ec76..de7476f300 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -414,7 +414,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
max_retries = 5
for i in range(max_retries):
try:
- event, context = await self.event_creation_handler.create_event(
+ (
+ event,
+ unpersisted_context,
+ ) = await self.event_creation_handler.create_event(
requester,
{
"type": EventTypes.Member,
@@ -435,7 +438,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
outlier=outlier,
historical=historical,
)
-
+ context = await unpersisted_context.persist(event)
prev_state_ids = await context.get_prev_state_ids(
StateFilter.from_types([(EventTypes.Member, None)])
)
@@ -1944,7 +1947,10 @@ class RoomMemberMasterHandler(RoomMemberHandler):
max_retries = 5
for i in range(max_retries):
try:
- event, context = await self.event_creation_handler.create_event(
+ (
+ event,
+ unpersisted_context,
+ ) = await self.event_creation_handler.create_event(
requester,
event_dict,
txn_id=txn_id,
@@ -1952,6 +1958,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
auth_event_ids=auth_event_ids,
outlier=True,
)
+ context = await unpersisted_context.persist(event)
event.internal_metadata.out_of_band_membership = True
result_event = (
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 4e4595312c..fd6d946c37 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -1297,7 +1297,6 @@ class SyncHandler:
return RoomNotifCounts.empty()
with Measure(self.clock, "unread_notifs_for_room_id"):
-
return await self.store.get_unread_event_push_actions_by_room_for_user(
room_id,
sync_config.user.to_string(),
diff --git a/synapse/http/client.py b/synapse/http/client.py
index a05f297933..ae48e7c3f0 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -44,6 +44,7 @@ from twisted.internet.interfaces import (
IAddress,
IDelayedCall,
IHostResolution,
+ IOpenSSLContextFactory,
IReactorCore,
IReactorPluggableNameResolver,
IReactorTime,
@@ -958,8 +959,8 @@ class InsecureInterceptableContextFactory(ssl.ContextFactory):
self._context = SSL.Context(SSL.SSLv23_METHOD)
self._context.set_verify(VERIFY_NONE, lambda *_: False)
- def getContext(self, hostname=None, port=None):
+ def getContext(self) -> SSL.Context:
return self._context
- def creatorForNetloc(self, hostname: bytes, port: int):
+ def creatorForNetloc(self, hostname: bytes, port: int) -> IOpenSSLContextFactory:
return self
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 312aab4dcc..3302d4e48a 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -440,7 +440,7 @@ class MatrixFederationHttpClient:
Args:
request: details of request to be sent
- retry_on_dns_fail: true if the request should be retied on DNS failures
+ retry_on_dns_fail: true if the request should be retried on DNS failures
timeout: number of milliseconds to wait for the response headers
(including connecting to the server), *for each attempt*.
@@ -475,7 +475,7 @@ class MatrixFederationHttpClient:
(except 429).
NotRetryingDestination: If we are not yet ready to retry this
server.
- FederationDeniedError: If this destination is not on our
+ FederationDeniedError: If this destination is not on our
federation whitelist
RequestSendFailed: If there were problems connecting to the
remote, due to e.g. DNS failures, connection timeouts etc.
@@ -871,7 +871,7 @@ class MatrixFederationHttpClient:
(except 429).
NotRetryingDestination: If we are not yet ready to retry this
server.
- FederationDeniedError: If this destination is not on our
+ FederationDeniedError: If this destination is not on our
federation whitelist
RequestSendFailed: If there were problems connecting to the
remote, due to e.g. DNS failures, connection timeouts etc.
@@ -958,7 +958,7 @@ class MatrixFederationHttpClient:
(except 429).
NotRetryingDestination: If we are not yet ready to retry this
server.
- FederationDeniedError: If this destination is not on our
+ FederationDeniedError: If this destination is not on our
federation whitelist
RequestSendFailed: If there were problems connecting to the
remote, due to e.g. DNS failures, connection timeouts etc.
@@ -1036,6 +1036,8 @@ class MatrixFederationHttpClient:
args: A dictionary used to create query strings, defaults to
None.
+ retry_on_dns_fail: true if the request should be retried on DNS failures
+
timeout: number of milliseconds to wait for the response.
self._default_timeout (60s) by default.
@@ -1063,7 +1065,7 @@ class MatrixFederationHttpClient:
(except 429).
NotRetryingDestination: If we are not yet ready to retry this
server.
- FederationDeniedError: If this destination is not on our
+ FederationDeniedError: If this destination is not on our
federation whitelist
RequestSendFailed: If there were problems connecting to the
remote, due to e.g. DNS failures, connection timeouts etc.
@@ -1141,7 +1143,7 @@ class MatrixFederationHttpClient:
(except 429).
NotRetryingDestination: If we are not yet ready to retry this
server.
- FederationDeniedError: If this destination is not on our
+ FederationDeniedError: If this destination is not on our
federation whitelist
RequestSendFailed: If there were problems connecting to the
remote, due to e.g. DNS failures, connection timeouts etc.
@@ -1197,7 +1199,7 @@ class MatrixFederationHttpClient:
(except 429).
NotRetryingDestination: If we are not yet ready to retry this
server.
- FederationDeniedError: If this destination is not on our
+ FederationDeniedError: If this destination is not on our
federation whitelist
RequestSendFailed: If there were problems connecting to the
remote, due to e.g. DNS failures, connection timeouts etc.
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index 5aed71262f..c70eee649c 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -524,6 +524,7 @@ def whitelisted_homeserver(destination: str) -> bool:
# Start spans and scopes
+
# Could use kwargs but I want these to be explicit
def start_active_span(
operation_name: str,
diff --git a/synapse/media/_base.py b/synapse/media/_base.py
new file mode 100644
index 0000000000..ef8334ae25
--- /dev/null
+++ b/synapse/media/_base.py
@@ -0,0 +1,479 @@
+# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import urllib
+from abc import ABC, abstractmethod
+from types import TracebackType
+from typing import Awaitable, Dict, Generator, List, Optional, Tuple, Type
+
+import attr
+
+from twisted.internet.interfaces import IConsumer
+from twisted.protocols.basic import FileSender
+from twisted.web.server import Request
+
+from synapse.api.errors import Codes, SynapseError, cs_error
+from synapse.http.server import finish_request, respond_with_json
+from synapse.http.site import SynapseRequest
+from synapse.logging.context import make_deferred_yieldable
+from synapse.util.stringutils import is_ascii, parse_and_validate_server_name
+
+logger = logging.getLogger(__name__)
+
+# list all text content types that will have the charset default to UTF-8 when
+# none is given
+TEXT_CONTENT_TYPES = [
+ "text/css",
+ "text/csv",
+ "text/html",
+ "text/calendar",
+ "text/plain",
+ "text/javascript",
+ "application/json",
+ "application/ld+json",
+ "application/rtf",
+ "image/svg+xml",
+ "text/xml",
+]
+
+
+def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]:
+ """Parses the server name, media ID and optional file name from the request URI
+
+ Also performs some rough validation on the server name.
+
+ Args:
+ request: The `Request`.
+
+ Returns:
+ A tuple containing the parsed server name, media ID and optional file name.
+
+ Raises:
+ SynapseError(404): if parsing or validation fail for any reason
+ """
+ try:
+ # The type on postpath seems incorrect in Twisted 21.2.0.
+ postpath: List[bytes] = request.postpath # type: ignore
+ assert postpath
+
+ # This allows users to append e.g. /test.png to the URL. Useful for
+ # clients that parse the URL to see content type.
+ server_name_bytes, media_id_bytes = postpath[:2]
+ server_name = server_name_bytes.decode("utf-8")
+ media_id = media_id_bytes.decode("utf8")
+
+ # Validate the server name, raising if invalid
+ parse_and_validate_server_name(server_name)
+
+ file_name = None
+ if len(postpath) > 2:
+ try:
+ file_name = urllib.parse.unquote(postpath[-1].decode("utf-8"))
+ except UnicodeDecodeError:
+ pass
+ return server_name, media_id, file_name
+ except Exception:
+ raise SynapseError(
+ 404, "Invalid media id token %r" % (request.postpath,), Codes.UNKNOWN
+ )
+
+
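+# Example (illustrative): a request for
+#   /_matrix/media/v3/download/matrix.org/abcd1234/test.png
+# parses to ("matrix.org", "abcd1234", "test.png"); the optional trailing
+# file name exists only so clients can infer a content type from the URL.
+
+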
+def respond_404(request: SynapseRequest) -> None:
+ respond_with_json(
+ request,
+ 404,
+ cs_error("Not found %r" % (request.postpath,), code=Codes.NOT_FOUND),
+ send_cors=True,
+ )
+
+
+async def respond_with_file(
+ request: SynapseRequest,
+ media_type: str,
+ file_path: str,
+ file_size: Optional[int] = None,
+ upload_name: Optional[str] = None,
+) -> None:
+ logger.debug("Responding with %r", file_path)
+
+ if os.path.isfile(file_path):
+ if file_size is None:
+ stat = os.stat(file_path)
+ file_size = stat.st_size
+
+ add_file_headers(request, media_type, file_size, upload_name)
+
+ with open(file_path, "rb") as f:
+ await make_deferred_yieldable(FileSender().beginFileTransfer(f, request))
+
+ finish_request(request)
+ else:
+ respond_404(request)
+
+
+def add_file_headers(
+ request: Request,
+ media_type: str,
+ file_size: Optional[int],
+ upload_name: Optional[str],
+) -> None:
+ """Adds the correct response headers in preparation for responding with the
+ media.
+
+ Args:
+ request
+ media_type: The media/content type.
+ file_size: Size in bytes of the media, if known.
+ upload_name: The name of the requested file, if any.
+ """
+
+ def _quote(x: str) -> str:
+ return urllib.parse.quote(x.encode("utf-8"))
+
+ # Default to a UTF-8 charset for text content types.
+    # e.g. this uses UTF-8 for 'text/css' but not 'text/css; charset=UTF-16'.
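+    # Illustrative examples given the list above:
+    #   "text/html"                  -> "text/html; charset=UTF-8"
+    #   "text/html; charset=UTF-16"  -> unchanged (a charset is already given)
+    #   "image/png"                  -> unchanged (not in TEXT_CONTENT_TYPES)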
+ if media_type.lower() in TEXT_CONTENT_TYPES:
+ content_type = media_type + "; charset=UTF-8"
+ else:
+ content_type = media_type
+
+ request.setHeader(b"Content-Type", content_type.encode("UTF-8"))
+ if upload_name:
+ # RFC6266 section 4.1 [1] defines both `filename` and `filename*`.
+ #
+ # `filename` is defined to be a `value`, which is defined by RFC2616
+ # section 3.6 [2] to be a `token` or a `quoted-string`, where a `token`
+ # is (essentially) a single US-ASCII word, and a `quoted-string` is a
+ # US-ASCII string surrounded by double-quotes, using backslash as an
+ # escape character. Note that %-encoding is *not* permitted.
+ #
+ # `filename*` is defined to be an `ext-value`, which is defined in
+ # RFC5987 section 3.2.1 [3] to be `charset "'" [ language ] "'" value-chars`,
+ # where `value-chars` is essentially a %-encoded string in the given charset.
+ #
+ # [1]: https://tools.ietf.org/html/rfc6266#section-4.1
+ # [2]: https://tools.ietf.org/html/rfc2616#section-3.6
+ # [3]: https://tools.ietf.org/html/rfc5987#section-3.2.1
+
+ # We avoid the quoted-string version of `filename`, because (a) synapse didn't
+ # correctly interpret those as of 0.99.2 and (b) they are a bit of a pain and we
+ # may as well just do the filename* version.
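+    #
+    # Illustrative examples of the resulting header value:
+    #   "foo.png"     -> inline; filename=foo.png
+    #   "my file.png" -> inline; filename*=utf-8''my%20file.png  (SP is a separator)
+    #   "ü.png"       -> inline; filename*=utf-8''%C3%BC.png     (non-ASCII)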
+ if _can_encode_filename_as_token(upload_name):
+ disposition = "inline; filename=%s" % (upload_name,)
+ else:
+ disposition = "inline; filename*=utf-8''%s" % (_quote(upload_name),)
+
+ request.setHeader(b"Content-Disposition", disposition.encode("ascii"))
+
+    # Cache for at least a day.
+    # XXX: we might want to turn this off for data we don't want to
+    # recommend caching as it's sensitive or private - or at least
+    # select Cache-Control: private. We don't bother setting Expires as all
+    # our clients are smart enough to be happy with Cache-Control.
+ request.setHeader(b"Cache-Control", b"public,max-age=86400,s-maxage=86400")
+ if file_size is not None:
+ request.setHeader(b"Content-Length", b"%d" % (file_size,))
+
+ # Tell web crawlers to not index, archive, or follow links in media. This
+ # should help to prevent things in the media repo from showing up in web
+ # search results.
+ request.setHeader(b"X-Robots-Tag", "noindex, nofollow, noarchive, noimageindex")
+
+
+# Separators as defined in RFC2616. SP and HT are handled separately;
+# see _can_encode_filename_as_token.
+_FILENAME_SEPARATOR_CHARS = {
+ "(",
+ ")",
+ "<",
+ ">",
+ "@",
+ ",",
+ ";",
+ ":",
+ "\\",
+ '"',
+ "/",
+ "[",
+ "]",
+ "?",
+ "=",
+ "{",
+ "}",
+}
+
+
+def _can_encode_filename_as_token(x: str) -> bool:
+ for c in x:
+ # from RFC2616:
+ #
+ # token = 1*<any CHAR except CTLs or separators>
+ #
+ # separators = "(" | ")" | "<" | ">" | "@"
+ # | "," | ";" | ":" | "\" | <">
+ # | "/" | "[" | "]" | "?" | "="
+ # | "{" | "}" | SP | HT
+ #
+ # CHAR = <any US-ASCII character (octets 0 - 127)>
+ #
+ # CTL = <any US-ASCII control character
+ # (octets 0 - 31) and DEL (127)>
+ #
+ if ord(c) >= 127 or ord(c) <= 32 or c in _FILENAME_SEPARATOR_CHARS:
+ return False
+ return True
+
+
+async def respond_with_responder(
+ request: SynapseRequest,
+ responder: "Optional[Responder]",
+ media_type: str,
+ file_size: Optional[int],
+ upload_name: Optional[str] = None,
+) -> None:
+ """Responds to the request with given responder. If responder is None then
+ returns 404.
+
+ Args:
+ request
+ responder
+ media_type: The media/content type.
+        file_size: Size in bytes of the media. If not known it should be None.
+ upload_name: The name of the requested file, if any.
+ """
+ if not responder:
+ respond_404(request)
+ return
+
+ # If we have a responder we *must* use it as a context manager.
+ with responder:
+ if request._disconnected:
+ logger.warning(
+ "Not sending response to request %s, already disconnected.", request
+ )
+ return
+
+ logger.debug("Responding to media request with responder %s", responder)
+ add_file_headers(request, media_type, file_size, upload_name)
+ try:
+ await responder.write_to_consumer(request)
+ except Exception as e:
+ # The majority of the time this will be due to the client having gone
+ # away. Unfortunately, Twisted simply throws a generic exception at us
+ # in that case.
+ logger.warning("Failed to write to consumer: %s %s", type(e), e)
+
+ # Unregister the producer, if it has one, so Twisted doesn't complain
+ if request.producer:
+ request.unregisterProducer()
+
+ finish_request(request)
+
+
+class Responder(ABC):
+ """Represents a response that can be streamed to the requester.
+
+ Responder is a context manager which *must* be used, so that any resources
+ held can be cleaned up.
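+
+    Illustrative usage (as in respond_with_responder above):
+
+        with responder:
+            await responder.write_to_consumer(request)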
+ """
+
+ @abstractmethod
+ def write_to_consumer(self, consumer: IConsumer) -> Awaitable:
+ """Stream response into consumer
+
+ Args:
+ consumer: The consumer to stream into.
+
+ Returns:
+ Resolves once the response has finished being written
+ """
+ raise NotImplementedError()
+
+ def __enter__(self) -> None: # noqa: B027
+ pass
+
+ def __exit__( # noqa: B027
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> None:
+ pass
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class ThumbnailInfo:
+ """Details about a generated thumbnail."""
+
+ width: int
+ height: int
+ method: str
+ # Content type of thumbnail, e.g. image/png
+ type: str
+    # The size of the thumbnail file, in bytes.
+ length: Optional[int] = None
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class FileInfo:
+ """Details about a requested/uploaded file."""
+
+ # The server name where the media originated from, or None if local.
+ server_name: Optional[str]
+    # The local ID of the file. For local files this is the same as the media_id.
+    file_id: str
+    # Whether the file is for the URL preview cache.
+    url_cache: bool = False
+ # Whether the file is a thumbnail or not.
+ thumbnail: Optional[ThumbnailInfo] = None
+
+ # The below properties exist to maintain compatibility with third-party modules.
+ @property
+ def thumbnail_width(self) -> Optional[int]:
+ if not self.thumbnail:
+ return None
+ return self.thumbnail.width
+
+ @property
+ def thumbnail_height(self) -> Optional[int]:
+ if not self.thumbnail:
+ return None
+ return self.thumbnail.height
+
+ @property
+ def thumbnail_method(self) -> Optional[str]:
+ if not self.thumbnail:
+ return None
+ return self.thumbnail.method
+
+ @property
+ def thumbnail_type(self) -> Optional[str]:
+ if not self.thumbnail:
+ return None
+ return self.thumbnail.type
+
+ @property
+ def thumbnail_length(self) -> Optional[int]:
+ if not self.thumbnail:
+ return None
+ return self.thumbnail.length
+
+
+def get_filename_from_headers(headers: Dict[bytes, List[bytes]]) -> Optional[str]:
+ """
+ Get the filename of the downloaded file by inspecting the
+ Content-Disposition HTTP header.
+
+ Args:
+        headers: The HTTP response headers.
+
+ Returns:
+ The filename, or None.
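+
+    Example (illustrative):
+        {b"Content-Disposition": [b"inline; filename*=utf-8''ch%C3%A8que.pdf"]}
+        yields "chèque.pdf".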
+ """
+ content_disposition = headers.get(b"Content-Disposition", [b""])
+
+ # No header, bail out.
+ if not content_disposition[0]:
+ return None
+
+ _, params = _parse_header(content_disposition[0])
+
+ upload_name = None
+
+ # First check if there is a valid UTF-8 filename
+ upload_name_utf8 = params.get(b"filename*", None)
+ if upload_name_utf8:
+ if upload_name_utf8.lower().startswith(b"utf-8''"):
+ upload_name_utf8 = upload_name_utf8[7:]
+ # We have a filename*= section. This MUST be ASCII, and any UTF-8
+ # bytes are %-quoted.
+ try:
+ # Once it is decoded, we can then unquote the %-encoded
+ # parts strictly into a unicode string.
+ upload_name = urllib.parse.unquote(
+ upload_name_utf8.decode("ascii"), errors="strict"
+ )
+ except UnicodeDecodeError:
+ # Incorrect UTF-8.
+ pass
+
+    # If there isn't one, check for an ASCII name.
+ if not upload_name:
+ upload_name_ascii = params.get(b"filename", None)
+ if upload_name_ascii and is_ascii(upload_name_ascii):
+ upload_name = upload_name_ascii.decode("ascii")
+
+ # This may be None here, indicating we did not find a matching name.
+ return upload_name
+
+
+def _parse_header(line: bytes) -> Tuple[bytes, Dict[bytes, bytes]]:
+ """Parse a Content-type like header.
+
+ Cargo-culted from `cgi`, but works on bytes rather than strings.
+
+ Args:
+ line: header to be parsed
+
+ Returns:
+ The main content-type, followed by the parameter dictionary
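+
+    Example (illustrative):
+        b'form-data; name="field"; filename="file.txt"' yields
+        (b"form-data", {b"name": b"field", b"filename": b"file.txt"}).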
+ """
+ parts = _parseparam(b";" + line)
+ key = next(parts)
+ pdict = {}
+ for p in parts:
+ i = p.find(b"=")
+ if i >= 0:
+ name = p[:i].strip().lower()
+ value = p[i + 1 :].strip()
+
+ # strip double-quotes
+ if len(value) >= 2 and value[0:1] == value[-1:] == b'"':
+ value = value[1:-1]
+ value = value.replace(b"\\\\", b"\\").replace(b'\\"', b'"')
+ pdict[name] = value
+
+ return key, pdict
+
+
+def _parseparam(s: bytes) -> Generator[bytes, None, None]:
+ """Generator which splits the input on ;, respecting double-quoted sequences
+
+ Cargo-culted from `cgi`, but works on bytes rather than strings.
+
+ Args:
+ s: header to be parsed
+
+ Returns:
+ The split input
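+
+    Example (illustrative):
+        b'; a=1; b="x;y"' yields b'a=1' and then b'b="x;y"'; the quoted
+        semicolon does not split.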
+ """
+ while s[:1] == b";":
+ s = s[1:]
+
+ # look for the next ;
+ end = s.find(b";")
+
+ # if there is an odd number of " marks between here and the next ;, skip to the
+ # next ; instead
+ while end > 0 and (s.count(b'"', 0, end) - s.count(b'\\"', 0, end)) % 2:
+ end = s.find(b";", end + 1)
+
+ if end < 0:
+ end = len(s)
+ f = s[:end]
+ yield f.strip()
+ s = s[end:]
diff --git a/synapse/rest/media/v1/filepath.py b/synapse/media/filepath.py
index 1f6441c412..1f6441c412 100644
--- a/synapse/rest/media/v1/filepath.py
+++ b/synapse/media/filepath.py
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/media/media_repository.py
index c70e1837af..b81e3c2b0c 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/media/media_repository.py
@@ -32,18 +32,10 @@ from synapse.api.errors import (
RequestSendFailed,
SynapseError,
)
-from synapse.config._base import ConfigError
from synapse.config.repository import ThumbnailRequirement
-from synapse.http.server import UnrecognizedRequestResource
from synapse.http.site import SynapseRequest
from synapse.logging.context import defer_to_thread
-from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.types import UserID
-from synapse.util.async_helpers import Linearizer
-from synapse.util.retryutils import NotRetryingDestination
-from synapse.util.stringutils import random_string
-
-from ._base import (
+from synapse.media._base import (
FileInfo,
Responder,
ThumbnailInfo,
@@ -51,15 +43,15 @@ from ._base import (
respond_404,
respond_with_responder,
)
-from .config_resource import MediaConfigResource
-from .download_resource import DownloadResource
-from .filepath import MediaFilePaths
-from .media_storage import MediaStorage
-from .preview_url_resource import PreviewUrlResource
-from .storage_provider import StorageProviderWrapper
-from .thumbnail_resource import ThumbnailResource
-from .thumbnailer import Thumbnailer, ThumbnailError
-from .upload_resource import UploadResource
+from synapse.media.filepath import MediaFilePaths
+from synapse.media.media_storage import MediaStorage
+from synapse.media.storage_provider import StorageProviderWrapper
+from synapse.media.thumbnailer import Thumbnailer, ThumbnailError
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.types import UserID
+from synapse.util.async_helpers import Linearizer
+from synapse.util.retryutils import NotRetryingDestination
+from synapse.util.stringutils import random_string
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -1044,69 +1036,3 @@ class MediaRepository:
removed_media.append(media_id)
return removed_media, len(removed_media)
-
-
-class MediaRepositoryResource(UnrecognizedRequestResource):
- """File uploading and downloading.
-
- Uploads are POSTed to a resource which returns a token which is used to GET
- the download::
-
- => POST /_matrix/media/r0/upload HTTP/1.1
- Content-Type: <media-type>
- Content-Length: <content-length>
-
- <media>
-
- <= HTTP/1.1 200 OK
- Content-Type: application/json
-
- { "content_uri": "mxc://<server-name>/<media-id>" }
-
- => GET /_matrix/media/r0/download/<server-name>/<media-id> HTTP/1.1
-
- <= HTTP/1.1 200 OK
- Content-Type: <media-type>
- Content-Disposition: attachment;filename=<upload-filename>
-
- <media>
-
- Clients can get thumbnails by supplying a desired width and height and
- thumbnailing method::
-
- => GET /_matrix/media/r0/thumbnail/<server_name>
- /<media-id>?width=<w>&height=<h>&method=<m> HTTP/1.1
-
- <= HTTP/1.1 200 OK
- Content-Type: image/jpeg or image/png
-
- <thumbnail>
-
- The thumbnail methods are "crop" and "scale". "scale" tries to return an
- image where either the width or the height is smaller than the requested
- size. The client should then scale and letterbox the image if it needs to
- fit within a given rectangle. "crop" tries to return an image where the
- width and height are close to the requested size and the aspect matches
- the requested size. The client should scale the image if it needs to fit
- within a given rectangle.
- """
-
- def __init__(self, hs: "HomeServer"):
- # If we're not configured to use it, raise if we somehow got here.
- if not hs.config.media.can_load_media_repo:
- raise ConfigError("Synapse is not configured to use a media repo.")
-
- super().__init__()
- media_repo = hs.get_media_repository()
-
- self.putChild(b"upload", UploadResource(hs, media_repo))
- self.putChild(b"download", DownloadResource(hs, media_repo))
- self.putChild(
- b"thumbnail", ThumbnailResource(hs, media_repo, media_repo.media_storage)
- )
- if hs.config.media.url_preview_enabled:
- self.putChild(
- b"preview_url",
- PreviewUrlResource(hs, media_repo, media_repo.media_storage),
- )
- self.putChild(b"config", MediaConfigResource(hs))
diff --git a/synapse/media/media_storage.py b/synapse/media/media_storage.py
new file mode 100644
index 0000000000..a7e22a91e1
--- /dev/null
+++ b/synapse/media/media_storage.py
@@ -0,0 +1,374 @@
+# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import contextlib
+import logging
+import os
+import shutil
+from types import TracebackType
+from typing import (
+ IO,
+ TYPE_CHECKING,
+ Any,
+ Awaitable,
+ BinaryIO,
+ Callable,
+ Generator,
+ Optional,
+ Sequence,
+ Tuple,
+ Type,
+)
+
+import attr
+
+from twisted.internet.defer import Deferred
+from twisted.internet.interfaces import IConsumer
+from twisted.protocols.basic import FileSender
+
+import synapse
+from synapse.api.errors import NotFoundError
+from synapse.logging.context import defer_to_thread, make_deferred_yieldable
+from synapse.util import Clock
+from synapse.util.file_consumer import BackgroundFileConsumer
+
+from ._base import FileInfo, Responder
+from .filepath import MediaFilePaths
+
+if TYPE_CHECKING:
+ from synapse.media.storage_provider import StorageProvider
+ from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+class MediaStorage:
+ """Responsible for storing/fetching files from local sources.
+
+ Args:
+ hs
+ local_media_directory: Base path where we store media on disk
+ filepaths
+        storage_providers: List of StorageProviders used to fetch and store files.
+ """
+
+ def __init__(
+ self,
+ hs: "HomeServer",
+ local_media_directory: str,
+ filepaths: MediaFilePaths,
+ storage_providers: Sequence["StorageProvider"],
+ ):
+ self.hs = hs
+ self.reactor = hs.get_reactor()
+ self.local_media_directory = local_media_directory
+ self.filepaths = filepaths
+ self.storage_providers = storage_providers
+ self.spam_checker = hs.get_spam_checker()
+ self.clock = hs.get_clock()
+
+ async def store_file(self, source: IO, file_info: FileInfo) -> str:
+ """Write `source` to the on disk media store, and also any other
+ configured storage providers
+
+ Args:
+ source: A file like object that should be written
+ file_info: Info about the file to store
+
+ Returns:
+ the file path written to in the primary media store
+ """
+
+ with self.store_into_file(file_info) as (f, fname, finish_cb):
+ # Write to the main repository
+ await self.write_to_file(source, f)
+ await finish_cb()
+
+ return fname
+
+ async def write_to_file(self, source: IO, output: IO) -> None:
+ """Asynchronously write the `source` to `output`."""
+ await defer_to_thread(self.reactor, _write_file_synchronously, source, output)
+
+ @contextlib.contextmanager
+ def store_into_file(
+ self, file_info: FileInfo
+ ) -> Generator[Tuple[BinaryIO, str, Callable[[], Awaitable[None]]], None, None]:
+ """Context manager used to get a file like object to write into, as
+ described by file_info.
+
+ Actually yields a 3-tuple (file, fname, finish_cb), where file is a file
+        like object that can be written to, fname is the absolute path of the
+        file on disk, and finish_cb is a function that returns an awaitable.
+
+ fname can be used to read the contents from after upload, e.g. to
+ generate thumbnails.
+
+        finish_cb must be called and waited on after the file has been
+        successfully written to. It should not be called if there was an
+        error.
+
+ Args:
+ file_info: Info about the file to store
+
+ Example:
+
+ with media_storage.store_into_file(info) as (f, fname, finish_cb):
+ # .. write into f ...
+ await finish_cb()
+ """
+
+ path = self._file_info_to_path(file_info)
+ fname = os.path.join(self.local_media_directory, path)
+
+ dirname = os.path.dirname(fname)
+ os.makedirs(dirname, exist_ok=True)
+
+ finished_called = [False]
+
+ try:
+ with open(fname, "wb") as f:
+
+ async def finish() -> None:
+ # Ensure that all writes have been flushed and close the
+ # file.
+ f.flush()
+ f.close()
+
+ spam_check = await self.spam_checker.check_media_file_for_spam(
+ ReadableFileWrapper(self.clock, fname), file_info
+ )
+ if spam_check != synapse.module_api.NOT_SPAM:
+ logger.info("Blocking media due to spam checker")
+ # Note that we'll delete the stored media, due to the
+ # try/except below. The media also won't be stored in
+ # the DB.
+ # We currently ignore any additional field returned by
+ # the spam-check API.
+ raise SpamMediaException(errcode=spam_check[0])
+
+ for provider in self.storage_providers:
+ await provider.store_file(path, file_info)
+
+ finished_called[0] = True
+
+ yield f, fname, finish
+ except Exception as e:
+ try:
+ os.remove(fname)
+ except Exception:
+ pass
+
+ raise e from None
+
+        if not finished_called[0]:
+ raise Exception("Finished callback not called")
+
+ async def fetch_media(self, file_info: FileInfo) -> Optional[Responder]:
+ """Attempts to fetch media described by file_info from the local cache
+ and configured storage providers.
+
+ Args:
+ file_info
+
+ Returns:
+            A Responder if the file was found, otherwise None.
+ """
+ paths = [self._file_info_to_path(file_info)]
+
+ # fallback for remote thumbnails with no method in the filename
+ if file_info.thumbnail and file_info.server_name:
+ paths.append(
+ self.filepaths.remote_media_thumbnail_rel_legacy(
+ server_name=file_info.server_name,
+ file_id=file_info.file_id,
+ width=file_info.thumbnail.width,
+ height=file_info.thumbnail.height,
+ content_type=file_info.thumbnail.type,
+ )
+ )
+
+ for path in paths:
+ local_path = os.path.join(self.local_media_directory, path)
+ if os.path.exists(local_path):
+ logger.debug("responding with local file %s", local_path)
+ return FileResponder(open(local_path, "rb"))
+ logger.debug("local file %s did not exist", local_path)
+
+ for provider in self.storage_providers:
+ for path in paths:
+ res: Any = await provider.fetch(path, file_info)
+ if res:
+ logger.debug("Streaming %s from %s", path, provider)
+ return res
+ logger.debug("%s not found on %s", path, provider)
+
+ return None
+
+ async def ensure_media_is_in_local_cache(self, file_info: FileInfo) -> str:
+ """Ensures that the given file is in the local cache. Attempts to
+ download it from storage providers if it isn't.
+
+ Args:
+ file_info
+
+ Returns:
+ Full path to local file
+ """
+ path = self._file_info_to_path(file_info)
+ local_path = os.path.join(self.local_media_directory, path)
+ if os.path.exists(local_path):
+ return local_path
+
+ # Fallback for paths without method names
+ # Should be removed in the future
+ if file_info.thumbnail and file_info.server_name:
+ legacy_path = self.filepaths.remote_media_thumbnail_rel_legacy(
+ server_name=file_info.server_name,
+ file_id=file_info.file_id,
+ width=file_info.thumbnail.width,
+ height=file_info.thumbnail.height,
+ content_type=file_info.thumbnail.type,
+ )
+ legacy_local_path = os.path.join(self.local_media_directory, legacy_path)
+ if os.path.exists(legacy_local_path):
+ return legacy_local_path
+
+ dirname = os.path.dirname(local_path)
+ os.makedirs(dirname, exist_ok=True)
+
+ for provider in self.storage_providers:
+ res: Any = await provider.fetch(path, file_info)
+ if res:
+ with res:
+ consumer = BackgroundFileConsumer(
+ open(local_path, "wb"), self.reactor
+ )
+ await res.write_to_consumer(consumer)
+ await consumer.wait()
+ return local_path
+
+ raise NotFoundError()
+
+ def _file_info_to_path(self, file_info: FileInfo) -> str:
+ """Converts file_info into a relative path.
+
+ The path is suitable for storing files under a directory, e.g. used to
+ store files on local FS under the base media repository directory.
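+
+        Example (illustrative): a plain local file with file_id "abcdefg" maps
+        to "local_content/ab/cd/efg"; the exact layout is defined by
+        MediaFilePaths in filepath.py.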
+ """
+ if file_info.url_cache:
+ if file_info.thumbnail:
+ return self.filepaths.url_cache_thumbnail_rel(
+ media_id=file_info.file_id,
+ width=file_info.thumbnail.width,
+ height=file_info.thumbnail.height,
+ content_type=file_info.thumbnail.type,
+ method=file_info.thumbnail.method,
+ )
+ return self.filepaths.url_cache_filepath_rel(file_info.file_id)
+
+ if file_info.server_name:
+ if file_info.thumbnail:
+ return self.filepaths.remote_media_thumbnail_rel(
+ server_name=file_info.server_name,
+ file_id=file_info.file_id,
+ width=file_info.thumbnail.width,
+ height=file_info.thumbnail.height,
+ content_type=file_info.thumbnail.type,
+ method=file_info.thumbnail.method,
+ )
+ return self.filepaths.remote_media_filepath_rel(
+ file_info.server_name, file_info.file_id
+ )
+
+ if file_info.thumbnail:
+ return self.filepaths.local_media_thumbnail_rel(
+ media_id=file_info.file_id,
+ width=file_info.thumbnail.width,
+ height=file_info.thumbnail.height,
+ content_type=file_info.thumbnail.type,
+ method=file_info.thumbnail.method,
+ )
+ return self.filepaths.local_media_filepath_rel(file_info.file_id)
+
+
+def _write_file_synchronously(source: IO, dest: IO) -> None:
+ """Write `source` to the file like `dest` synchronously. Should be called
+ from a thread.
+
+ Args:
+ source: A file like object that's to be written
+ dest: A file like object to be written to
+ """
+ source.seek(0) # Ensure we read from the start of the file
+ shutil.copyfileobj(source, dest)
+
+
+class FileResponder(Responder):
+ """Wraps an open file that can be sent to a request.
+
+ Args:
+        open_file: A file like object to be streamed to the client; it
+            is closed when finished streaming.
+ """
+
+ def __init__(self, open_file: IO):
+ self.open_file = open_file
+
+ def write_to_consumer(self, consumer: IConsumer) -> Deferred:
+ return make_deferred_yieldable(
+ FileSender().beginFileTransfer(self.open_file, consumer)
+ )
+
+ def __exit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> None:
+ self.open_file.close()
+
+
+class SpamMediaException(NotFoundError):
+ """The media was blocked by a spam checker, so we simply 404 the request (in
+ the same way as if it was quarantined).
+ """
+
+
+@attr.s(slots=True, auto_attribs=True)
+class ReadableFileWrapper:
+ """Wrapper that allows reading a file in chunks, yielding to the reactor,
+ and writing to a callback.
+
+    This is a simplified `FileSender` that takes an IO object rather than an
+ `IConsumer`.
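+
+    Illustrative usage (`hasher` is hypothetical):
+
+        wrapper = ReadableFileWrapper(clock, "/path/to/file")
+        await wrapper.write_chunks_to(hasher.update)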
+ """
+
+ CHUNK_SIZE = 2**14
+
+ clock: Clock
+ path: str
+
+ async def write_chunks_to(self, callback: Callable[[bytes], object]) -> None:
+ """Reads the file in chunks and calls the callback with each chunk."""
+
+ with open(self.path, "rb") as file:
+ while True:
+ chunk = file.read(self.CHUNK_SIZE)
+ if not chunk:
+ break
+
+ callback(chunk)
+
+ # We yield to the reactor by sleeping for 0 seconds.
+ await self.clock.sleep(0)
diff --git a/synapse/rest/media/v1/oembed.py b/synapse/media/oembed.py
index 7592aa5d47..c0eaf04be5 100644
--- a/synapse/rest/media/v1/oembed.py
+++ b/synapse/media/oembed.py
@@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, List, Optional
import attr
-from synapse.rest.media.v1.preview_html import parse_html_description
+from synapse.media.preview_html import parse_html_description
from synapse.types import JsonDict
from synapse.util import json_decoder
diff --git a/synapse/rest/media/v1/preview_html.py b/synapse/media/preview_html.py
index 516d0434f0..516d0434f0 100644
--- a/synapse/rest/media/v1/preview_html.py
+++ b/synapse/media/preview_html.py
diff --git a/synapse/media/storage_provider.py b/synapse/media/storage_provider.py
new file mode 100644
index 0000000000..1c9b71d69c
--- /dev/null
+++ b/synapse/media/storage_provider.py
@@ -0,0 +1,181 @@
+# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import abc
+import logging
+import os
+import shutil
+from typing import TYPE_CHECKING, Callable, Optional
+
+from synapse.config._base import Config
+from synapse.logging.context import defer_to_thread, run_in_background
+from synapse.util.async_helpers import maybe_awaitable
+
+from ._base import FileInfo, Responder
+from .media_storage import FileResponder
+
+logger = logging.getLogger(__name__)
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+
+class StorageProvider(metaclass=abc.ABCMeta):
+ """A storage provider is a service that can store uploaded media and
+    retrieve it.
+ """
+
+ @abc.abstractmethod
+ async def store_file(self, path: str, file_info: FileInfo) -> None:
+ """Store the file described by file_info. The actual contents can be
+        retrieved by reading the file at `path` in the local media store.
+
+ Args:
+ path: Relative path of file in local cache
+ file_info: The metadata of the file.
+ """
+
+ @abc.abstractmethod
+ async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]:
+ """Attempt to fetch the file described by file_info and stream it
+ into writer.
+
+ Args:
+ path: Relative path of file in local cache
+ file_info: The metadata of the file.
+
+ Returns:
+            A Responder if the provider has the file, otherwise None.
+ """
+
+
+class StorageProviderWrapper(StorageProvider):
+ """Wraps a storage provider and provides various config options
+
+ Args:
+ backend: The storage provider to wrap.
+ store_local: Whether to store new local files or not.
+        store_synchronous: Whether to wait for the file to be successfully
+            uploaded, or to do the upload in the background.
+        store_remote: Whether remote media should be uploaded.
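+
+        For example (illustrative): with store_local=True and
+        store_remote=False, newly uploaded local media is passed to the
+        backend, but cached remote media is not.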
+ """
+
+ def __init__(
+ self,
+ backend: StorageProvider,
+ store_local: bool,
+ store_synchronous: bool,
+ store_remote: bool,
+ ):
+ self.backend = backend
+ self.store_local = store_local
+ self.store_synchronous = store_synchronous
+ self.store_remote = store_remote
+
+ def __str__(self) -> str:
+ return "StorageProviderWrapper[%s]" % (self.backend,)
+
+ async def store_file(self, path: str, file_info: FileInfo) -> None:
+ if not file_info.server_name and not self.store_local:
+ return None
+
+ if file_info.server_name and not self.store_remote:
+ return None
+
+ if file_info.url_cache:
+ # The URL preview cache is short lived and not worth offloading or
+ # backing up.
+ return None
+
+ if self.store_synchronous:
+ # store_file is supposed to return an Awaitable, but guard
+ # against improper implementations.
+ await maybe_awaitable(self.backend.store_file(path, file_info)) # type: ignore
+ else:
+ # TODO: Handle errors.
+ async def store() -> None:
+ try:
+ return await maybe_awaitable(
+ self.backend.store_file(path, file_info)
+ )
+ except Exception:
+ logger.exception("Error storing file")
+
+ run_in_background(store)
+
+ async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]:
+ if file_info.url_cache:
+ # Files in the URL preview cache definitely aren't stored here,
+ # so avoid any potentially slow I/O or network access.
+ return None
+
+        # fetch is supposed to return an Awaitable, but guard
+        # against improper implementations.
+ return await maybe_awaitable(self.backend.fetch(path, file_info))
+
+
+class FileStorageProviderBackend(StorageProvider):
+ """A storage provider that stores files in a directory on a filesystem.
+
+ Args:
+ hs
+ config: The config returned by `parse_config`.
+ """
+
+ def __init__(self, hs: "HomeServer", config: str):
+ self.hs = hs
+ self.cache_directory = hs.config.media.media_store_path
+ self.base_directory = config
+
+ def __str__(self) -> str:
+ return "FileStorageProviderBackend[%s]" % (self.base_directory,)
+
+ async def store_file(self, path: str, file_info: FileInfo) -> None:
+ """See StorageProvider.store_file"""
+
+ primary_fname = os.path.join(self.cache_directory, path)
+ backup_fname = os.path.join(self.base_directory, path)
+
+ dirname = os.path.dirname(backup_fname)
+ os.makedirs(dirname, exist_ok=True)
+
+ # mypy needs help inferring the type of the second parameter, which is generic
+ shutil_copyfile: Callable[[str, str], str] = shutil.copyfile
+ await defer_to_thread(
+ self.hs.get_reactor(),
+ shutil_copyfile,
+ primary_fname,
+ backup_fname,
+ )
+
+ async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]:
+ """See StorageProvider.fetch"""
+
+ backup_fname = os.path.join(self.base_directory, path)
+ if os.path.isfile(backup_fname):
+ return FileResponder(open(backup_fname, "rb"))
+
+ return None
+
+ @staticmethod
+ def parse_config(config: dict) -> str:
+ """Called on startup to parse config supplied. This should parse
+ the config and raise if there is a problem.
+
+ The returned value is passed into the constructor.
+
+ In this case we only care about a single param, the directory, so let's
+ just pull that out.
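+
+        Illustrative YAML (assuming the standard media_storage_providers
+        option):
+
+            media_storage_providers:
+              - module: file_system
+                store_local: true
+                store_remote: true
+                store_synchronous: true
+                config:
+                  directory: /mnt/backup_media_store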
+ """
+ return Config.ensure_directory(config["directory"])
diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/media/thumbnailer.py
index 9480cc5763..f909a4fb9a 100644
--- a/synapse/rest/media/v1/thumbnailer.py
+++ b/synapse/media/thumbnailer.py
@@ -38,7 +38,6 @@ class ThumbnailError(Exception):
class Thumbnailer:
-
FORMATS = {"image/jpeg": "JPEG", "image/png": "PNG"}
@staticmethod
diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py
index b01372565d..8ce5887229 100644
--- a/synapse/metrics/__init__.py
+++ b/synapse/metrics/__init__.py
@@ -87,7 +87,6 @@ class LaterGauge(Collector):
]
def collect(self) -> Iterable[Metric]:
-
g = GaugeMetricFamily(self.name, self.desc, labels=self.labels)
try:
diff --git a/synapse/metrics/_gc.py b/synapse/metrics/_gc.py
index b7d47ce3e7..a22c4e5bbd 100644
--- a/synapse/metrics/_gc.py
+++ b/synapse/metrics/_gc.py
@@ -139,7 +139,6 @@ def install_gc_manager() -> None:
class PyPyGCStats(Collector):
def collect(self) -> Iterable[Metric]:
-
# @stats is a pretty-printer object with __str__() returning a nice table,
# plus some fields that contain data from that table.
# unfortunately, fields are pretty-printed themselves (i. e. '4.5MB').
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index d22dd19d38..424239e3df 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -64,9 +64,11 @@ from synapse.events.third_party_rules import (
CHECK_EVENT_ALLOWED_CALLBACK,
CHECK_THREEPID_CAN_BE_INVITED_CALLBACK,
CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK,
+ ON_ADD_USER_THIRD_PARTY_IDENTIFIER_CALLBACK,
ON_CREATE_ROOM_CALLBACK,
ON_NEW_EVENT_CALLBACK,
ON_PROFILE_UPDATE_CALLBACK,
+ ON_REMOVE_USER_THIRD_PARTY_IDENTIFIER_CALLBACK,
ON_THREEPID_BIND_CALLBACK,
ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK,
)
@@ -357,6 +359,12 @@ class ModuleApi:
ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK
] = None,
on_threepid_bind: Optional[ON_THREEPID_BIND_CALLBACK] = None,
+ on_add_user_third_party_identifier: Optional[
+ ON_ADD_USER_THIRD_PARTY_IDENTIFIER_CALLBACK
+ ] = None,
+ on_remove_user_third_party_identifier: Optional[
+ ON_REMOVE_USER_THIRD_PARTY_IDENTIFIER_CALLBACK
+ ] = None,
) -> None:
"""Registers callbacks for third party event rules capabilities.
@@ -373,6 +381,8 @@ class ModuleApi:
on_profile_update=on_profile_update,
on_user_deactivation_status_changed=on_user_deactivation_status_changed,
on_threepid_bind=on_threepid_bind,
+ on_add_user_third_party_identifier=on_add_user_third_party_identifier,
+ on_remove_user_third_party_identifier=on_remove_user_third_party_identifier,
)
def register_presence_router_callbacks(
@@ -1576,14 +1586,14 @@ class ModuleApi:
)
requester = create_requester(user_id)
- room_id_and_alias, _ = await self._hs.get_room_creation_handler().create_room(
+ room_id, room_alias, _ = await self._hs.get_room_creation_handler().create_room(
requester=requester,
config=config,
ratelimit=ratelimit,
creator_join_profile=creator_join_profile,
)
-
- return room_id_and_alias["room_id"], room_id_and_alias.get("room_alias", None)
+ room_alias_str = room_alias.to_string() if room_alias else None
+ return room_id, room_alias_str
async def set_displayname(
self,
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 5fc38431ba..3c4a152d6b 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -276,7 +276,7 @@ class BulkPushRuleEvaluator:
if related_event is not None:
related_events[relation_type] = _flatten_dict(
related_event,
- msc3783_escape_event_match_key=self.hs.config.experimental.msc3783_escape_event_match_key,
+ msc3873_escape_event_match_key=self.hs.config.experimental.msc3873_escape_event_match_key,
)
reply_event_id = (
@@ -294,7 +294,7 @@ class BulkPushRuleEvaluator:
if related_event is not None:
related_events["m.in_reply_to"] = _flatten_dict(
related_event,
- msc3783_escape_event_match_key=self.hs.config.experimental.msc3783_escape_event_match_key,
+ msc3873_escape_event_match_key=self.hs.config.experimental.msc3873_escape_event_match_key,
)
# indicate that this is from a fallback relation.
@@ -330,7 +330,6 @@ class BulkPushRuleEvaluator:
context: EventContext,
event_id_to_event: Mapping[str, EventBase],
) -> None:
-
if (
not event.internal_metadata.is_notifiable()
or event.internal_metadata.is_historical()
@@ -413,7 +412,7 @@ class BulkPushRuleEvaluator:
evaluator = PushRuleEvaluator(
_flatten_dict(
event,
- msc3783_escape_event_match_key=self.hs.config.experimental.msc3783_escape_event_match_key,
+ msc3873_escape_event_match_key=self.hs.config.experimental.msc3873_escape_event_match_key,
),
has_mentions,
user_mentions,
@@ -508,7 +507,7 @@ def _flatten_dict(
prefix: Optional[List[str]] = None,
result: Optional[Dict[str, JsonValue]] = None,
*,
- msc3783_escape_event_match_key: bool = False,
+ msc3873_escape_event_match_key: bool = False,
) -> Dict[str, JsonValue]:
"""
Given a JSON dictionary (or event) which might contain sub dictionaries,
@@ -537,7 +536,7 @@ def _flatten_dict(
if result is None:
result = {}
for key, value in d.items():
- if msc3783_escape_event_match_key:
+ if msc3873_escape_event_match_key:
# Escape periods in the key with a backslash (and backslashes with an
# extra backslash). This is since a period is used as a separator between
# nested fields.
@@ -553,7 +552,7 @@ def _flatten_dict(
value,
prefix=(prefix + [key]),
result=result,
- msc3783_escape_event_match_key=msc3783_escape_event_match_key,
+ msc3873_escape_event_match_key=msc3873_escape_event_match_key,
)
# `room_version` should only ever be set when looking at the top level of an event
diff --git a/synapse/replication/http/account_data.py b/synapse/replication/http/account_data.py
index 2374f810c9..111ec07e64 100644
--- a/synapse/replication/http/account_data.py
+++ b/synapse/replication/http/account_data.py
@@ -265,7 +265,6 @@ class ReplicationRemoveTagRestServlet(ReplicationEndpoint):
@staticmethod
async def _serialize_payload(user_id: str, room_id: str, tag: str) -> JsonDict: # type: ignore[override]
-
return {}
async def _handle_request( # type: ignore[override]
diff --git a/synapse/replication/http/devices.py b/synapse/replication/http/devices.py
index ecea6fc915..cc3929dcf5 100644
--- a/synapse/replication/http/devices.py
+++ b/synapse/replication/http/devices.py
@@ -195,7 +195,6 @@ class ReplicationUploadKeysForUserRestServlet(ReplicationEndpoint):
async def _serialize_payload( # type: ignore[override]
user_id: str, device_id: str, keys: JsonDict
) -> JsonDict:
-
return {
"user_id": user_id,
"device_id": device_id,
diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index fd1c0ec6af..dfc061eb5e 100644
--- a/synapse/replication/tcp/redis.py
+++ b/synapse/replication/tcp/redis.py
@@ -328,7 +328,6 @@ class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory):
outbound_redis_connection: txredisapi.ConnectionHandler,
channel_names: List[str],
):
-
super().__init__(
hs,
uuid="subscriber",
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index 14b6705862..ad9b760713 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -139,7 +139,6 @@ class EventsStream(Stream):
current_token: Token,
target_row_count: int,
) -> StreamUpdateResult:
-
# the events stream merges together three separate sources:
# * new events
# * current_state changes
diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py
index 14c4e6ebbb..c327f15043 100644
--- a/synapse/rest/__init__.py
+++ b/synapse/rest/__init__.py
@@ -108,8 +108,7 @@ class ClientRestResource(JsonResource):
if is_main_process:
logout.register_servlets(hs, client_resource)
sync.register_servlets(hs, client_resource)
- if is_main_process:
- filter.register_servlets(hs, client_resource)
+ filter.register_servlets(hs, client_resource)
account.register_servlets(hs, client_resource)
register.register_servlets(hs, client_resource)
if is_main_process:
diff --git a/synapse/rest/admin/event_reports.py b/synapse/rest/admin/event_reports.py
index a3beb74e2c..c546ef7e23 100644
--- a/synapse/rest/admin/event_reports.py
+++ b/synapse/rest/admin/event_reports.py
@@ -53,11 +53,11 @@ class EventReportsRestServlet(RestServlet):
PATTERNS = admin_patterns("/event_reports$")
def __init__(self, hs: "HomeServer"):
- self.auth = hs.get_auth()
- self.store = hs.get_datastores().main
+ self._auth = hs.get_auth()
+ self._store = hs.get_datastores().main
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
- await assert_requester_is_admin(self.auth, request)
+ await assert_requester_is_admin(self._auth, request)
start = parse_integer(request, "from", default=0)
limit = parse_integer(request, "limit", default=100)
@@ -79,7 +79,7 @@ class EventReportsRestServlet(RestServlet):
errcode=Codes.INVALID_PARAM,
)
- event_reports, total = await self.store.get_event_reports_paginate(
+ event_reports, total = await self._store.get_event_reports_paginate(
start, limit, direction, user_id, room_id
)
ret = {"event_reports": event_reports, "total": total}
@@ -108,13 +108,13 @@ class EventReportDetailRestServlet(RestServlet):
PATTERNS = admin_patterns("/event_reports/(?P<report_id>[^/]*)$")
def __init__(self, hs: "HomeServer"):
- self.auth = hs.get_auth()
- self.store = hs.get_datastores().main
+ self._auth = hs.get_auth()
+ self._store = hs.get_datastores().main
async def on_GET(
self, request: SynapseRequest, report_id: str
) -> Tuple[int, JsonDict]:
- await assert_requester_is_admin(self.auth, request)
+ await assert_requester_is_admin(self._auth, request)
message = (
"The report_id parameter must be a string representing a positive integer."
@@ -131,8 +131,33 @@ class EventReportDetailRestServlet(RestServlet):
HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM
)
- ret = await self.store.get_event_report(resolved_report_id)
+ ret = await self._store.get_event_report(resolved_report_id)
if not ret:
raise NotFoundError("Event report not found")
return HTTPStatus.OK, ret
+
+ async def on_DELETE(
+ self, request: SynapseRequest, report_id: str
+ ) -> Tuple[int, JsonDict]:
+ await assert_requester_is_admin(self._auth, request)
+
+ message = (
+ "The report_id parameter must be a string representing a positive integer."
+ )
+ try:
+ resolved_report_id = int(report_id)
+ except ValueError:
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM
+ )
+
+ if resolved_report_id < 0:
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM
+ )
+
+ if await self._store.delete_event_report(resolved_report_id):
+ return HTTPStatus.OK, {}
+
+ raise NotFoundError("Event report not found")
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index 1d6e4982d7..4de56bf13f 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -75,7 +75,6 @@ class RoomRestV2Servlet(RestServlet):
async def on_DELETE(
self, request: SynapseRequest, room_id: str
) -> Tuple[int, JsonDict]:
-
requester = await self._auth.get_user_by_req(request)
await assert_user_is_admin(self._auth, requester)
@@ -144,7 +143,6 @@ class DeleteRoomStatusByRoomIdRestServlet(RestServlet):
async def on_GET(
self, request: SynapseRequest, room_id: str
) -> Tuple[int, JsonDict]:
-
await assert_requester_is_admin(self._auth, request)
if not RoomID.is_valid(room_id):
@@ -181,7 +179,6 @@ class DeleteRoomStatusByDeleteIdRestServlet(RestServlet):
async def on_GET(
self, request: SynapseRequest, delete_id: str
) -> Tuple[int, JsonDict]:
-
await assert_requester_is_admin(self._auth, request)
delete_status = self._pagination_handler.get_delete_status(delete_id)
@@ -438,7 +435,6 @@ class RoomStateRestServlet(RestServlet):
class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet):
-
PATTERNS = admin_patterns("/join/(?P<room_identifier>[^/]*)$")
def __init__(self, hs: "HomeServer"):
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index 0c0bf540b9..357e9a574d 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -304,13 +304,20 @@ class UserRestServletV2(RestServlet):
# remove old threepids
for medium, address in del_threepids:
try:
- await self.auth_handler.delete_threepid(
- user_id, medium, address, None
+ # Attempt to remove any known bindings of this third-party ID
+ # and user ID from identity servers.
+ await self.hs.get_identity_handler().try_unbind_threepid(
+ user_id, medium, address, id_server=None
)
except Exception:
logger.exception("Failed to remove threepids")
raise SynapseError(500, "Failed to remove threepids")
+ # Delete the local association of this user ID and third-party ID.
+ await self.auth_handler.delete_local_threepid(
+ user_id, medium, address
+ )
+
# add new threepids
current_time = self.hs.get_clock().time_msec()
for medium, address in add_threepids:
@@ -683,8 +690,12 @@ class AccountValidityRenewServlet(RestServlet):
await assert_requester_is_admin(self.auth, request)
if self.account_activity_handler.on_legacy_admin_request_callback:
- expiration_ts = await (
- self.account_activity_handler.on_legacy_admin_request_callback(request)
+ expiration_ts = (
+ await (
+ self.account_activity_handler.on_legacy_admin_request_callback(
+ request
+ )
+ )
)
else:
body = parse_json_object_from_request(request)
diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py
index 662f5bf762..484d7440a4 100644
--- a/synapse/rest/client/account.py
+++ b/synapse/rest/client/account.py
@@ -768,7 +768,9 @@ class ThreepidDeleteRestServlet(RestServlet):
user_id = requester.user.to_string()
try:
- ret = await self.auth_handler.delete_threepid(
+ # Attempt to remove any known bindings of this third-party ID
+ # and user ID from identity servers.
+ ret = await self.hs.get_identity_handler().try_unbind_threepid(
user_id, body.medium, body.address, body.id_server
)
except Exception:
@@ -783,6 +785,11 @@ class ThreepidDeleteRestServlet(RestServlet):
else:
id_server_unbind_result = "no-support"
+ # Delete the local association of this user ID and third-party ID.
+ await self.auth_handler.delete_local_threepid(
+ user_id, body.medium, body.address
+ )
+
return 200, {"id_server_unbind_result": id_server_unbind_result}
diff --git a/synapse/rest/client/auth.py b/synapse/rest/client/auth.py
index eb77337044..276a1b405d 100644
--- a/synapse/rest/client/auth.py
+++ b/synapse/rest/client/auth.py
@@ -97,7 +97,6 @@ class AuthRestServlet(RestServlet):
return None
async def on_POST(self, request: Request, stagetype: str) -> None:
-
session = parse_string(request, "session")
if not session:
raise SynapseError(400, "No session supplied")
diff --git a/synapse/rest/client/filter.py b/synapse/rest/client/filter.py
index cc1c2f9731..236199897c 100644
--- a/synapse/rest/client/filter.py
+++ b/synapse/rest/client/filter.py
@@ -79,7 +79,6 @@ class CreateFilterRestServlet(RestServlet):
async def on_POST(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
-
target_user = UserID.from_string(user_id)
requester = await self.auth.get_user_by_req(request)
diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py
index 3cb1e7e375..bce806f2bb 100644
--- a/synapse/rest/client/register.py
+++ b/synapse/rest/client/register.py
@@ -628,10 +628,12 @@ class RegisterRestServlet(RestServlet):
if not password_hash:
raise SynapseError(400, "Missing params: password", Codes.MISSING_PARAM)
- desired_username = await (
- self.password_auth_provider.get_username_for_registration(
- auth_result,
- params,
+ desired_username = (
+ await (
+ self.password_auth_provider.get_username_for_registration(
+ auth_result,
+ params,
+ )
)
)
@@ -682,9 +684,11 @@ class RegisterRestServlet(RestServlet):
session_id
)
- display_name = await (
- self.password_auth_provider.get_displayname_for_registration(
- auth_result, params
+ display_name = (
+ await (
+ self.password_auth_provider.get_displayname_for_registration(
+ auth_result, params
+ )
)
)
diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py
index d0db85cca7..14b04810a1 100644
--- a/synapse/rest/client/room.py
+++ b/synapse/rest/client/room.py
@@ -160,11 +160,11 @@ class RoomCreateRestServlet(TransactionRestServlet):
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
- info, _ = await self._room_creation_handler.create_room(
+ room_id, _, _ = await self._room_creation_handler.create_room(
requester, self.get_room_config(request)
)
- return 200, info
+ return 200, {"room_id": room_id}
def get_room_config(self, request: Request) -> JsonDict:
user_supplied_config = parse_json_object_from_request(request)
diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py
index f2013faeb2..8fcb8ac3d9 100644
--- a/synapse/rest/client/sync.py
+++ b/synapse/rest/client/sync.py
@@ -16,7 +16,7 @@ import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
-from synapse.api.constants import EduTypes, Membership, PresenceState
+from synapse.api.constants import AccountDataTypes, EduTypes, Membership, PresenceState
from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.api.filtering import FilterCollection
from synapse.api.presence import UserPresenceState
@@ -139,7 +139,28 @@ class SyncRestServlet(RestServlet):
device_id,
)
- request_key = (user, timeout, since, filter_id, full_state, device_id)
+ # Stream position of the last ignored users account data event for this user,
+ # if we're initial syncing.
+ # We include this in the request key to invalidate an initial sync
+ # in the response cache once the set of ignored users has changed.
+ # (We filter out ignored users from timeline events, so our sync response
+ # is invalid once the set of ignored users changes.)
+ last_ignore_accdata_streampos: Optional[int] = None
+ if not since:
+ # No `since`, so this is an initial sync.
+ last_ignore_accdata_streampos = await self.store.get_latest_stream_id_for_global_account_data_by_type_for_user(
+ user.to_string(), AccountDataTypes.IGNORED_USER_LIST
+ )
+
+ request_key = (
+ user,
+ timeout,
+ since,
+ filter_id,
+ full_state,
+ device_id,
+ last_ignore_accdata_streampos,
+ )
if filter_id is None:
filter_collection = self.filtering.DEFAULT_FILTER_COLLECTION
diff --git a/synapse/rest/media/v1/config_resource.py b/synapse/rest/media/config_resource.py
index a95804d327..a95804d327 100644
--- a/synapse/rest/media/v1/config_resource.py
+++ b/synapse/rest/media/config_resource.py
diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/download_resource.py
index 048a042692..8f270cf4cc 100644
--- a/synapse/rest/media/v1/download_resource.py
+++ b/synapse/rest/media/download_resource.py
@@ -22,11 +22,10 @@ from synapse.http.server import (
)
from synapse.http.servlet import parse_boolean
from synapse.http.site import SynapseRequest
-
-from ._base import parse_media_id, respond_404
+from synapse.media._base import parse_media_id, respond_404
if TYPE_CHECKING:
- from synapse.rest.media.v1.media_repository import MediaRepository
+ from synapse.media.media_repository import MediaRepository
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
diff --git a/synapse/rest/media/media_repository_resource.py b/synapse/rest/media/media_repository_resource.py
new file mode 100644
index 0000000000..5ebaa3b032
--- /dev/null
+++ b/synapse/rest/media/media_repository_resource.py
@@ -0,0 +1,93 @@
+# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from synapse.config._base import ConfigError
+from synapse.http.server import UnrecognizedRequestResource
+
+from .config_resource import MediaConfigResource
+from .download_resource import DownloadResource
+from .preview_url_resource import PreviewUrlResource
+from .thumbnail_resource import ThumbnailResource
+from .upload_resource import UploadResource
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+
+class MediaRepositoryResource(UnrecognizedRequestResource):
+ """File uploading and downloading.
+
+ Uploads are POSTed to a resource which returns a token which is used to GET
+ the download::
+
+ => POST /_matrix/media/r0/upload HTTP/1.1
+ Content-Type: <media-type>
+ Content-Length: <content-length>
+
+ <media>
+
+ <= HTTP/1.1 200 OK
+ Content-Type: application/json
+
+ { "content_uri": "mxc://<server-name>/<media-id>" }
+
+ => GET /_matrix/media/r0/download/<server-name>/<media-id> HTTP/1.1
+
+ <= HTTP/1.1 200 OK
+ Content-Type: <media-type>
+ Content-Disposition: attachment;filename=<upload-filename>
+
+ <media>
+
+ Clients can get thumbnails by supplying a desired width and height and
+ thumbnailing method::
+
+ => GET /_matrix/media/r0/thumbnail/<server_name>
+ /<media-id>?width=<w>&height=<h>&method=<m> HTTP/1.1
+
+ <= HTTP/1.1 200 OK
+ Content-Type: image/jpeg or image/png
+
+ <thumbnail>
+
+ The thumbnail methods are "crop" and "scale". "scale" tries to return an
+ image where either the width or the height is smaller than the requested
+ size. The client should then scale and letterbox the image if it needs to
+ fit within a given rectangle. "crop" tries to return an image where the
+ width and height are close to the requested size and the aspect matches
+ the requested size. The client should scale the image if it needs to fit
+ within a given rectangle.
+ """
+
+ def __init__(self, hs: "HomeServer"):
+ # If we're not configured to use it, raise if we somehow got here.
+ if not hs.config.media.can_load_media_repo:
+ raise ConfigError("Synapse is not configured to use a media repo.")
+
+ super().__init__()
+ media_repo = hs.get_media_repository()
+
+ self.putChild(b"upload", UploadResource(hs, media_repo))
+ self.putChild(b"download", DownloadResource(hs, media_repo))
+ self.putChild(
+ b"thumbnail", ThumbnailResource(hs, media_repo, media_repo.media_storage)
+ )
+ if hs.config.media.url_preview_enabled:
+ self.putChild(
+ b"preview_url",
+ PreviewUrlResource(hs, media_repo, media_repo.media_storage),
+ )
+ self.putChild(b"config", MediaConfigResource(hs))
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/preview_url_resource.py
index a8f6fd6b35..7ada728757 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/preview_url_resource.py
@@ -40,21 +40,19 @@ from synapse.http.server import (
from synapse.http.servlet import parse_integer, parse_string
from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable, run_in_background
+from synapse.media._base import FileInfo, get_filename_from_headers
+from synapse.media.media_storage import MediaStorage
+from synapse.media.oembed import OEmbedProvider
+from synapse.media.preview_html import decode_body, parse_html_to_open_graph
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.rest.media.v1._base import get_filename_from_headers
-from synapse.rest.media.v1.media_storage import MediaStorage
-from synapse.rest.media.v1.oembed import OEmbedProvider
-from synapse.rest.media.v1.preview_html import decode_body, parse_html_to_open_graph
from synapse.types import JsonDict, UserID
from synapse.util import json_encoder
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.stringutils import random_string
-from ._base import FileInfo
-
if TYPE_CHECKING:
- from synapse.rest.media.v1.media_repository import MediaRepository
+ from synapse.media.media_repository import MediaRepository
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -163,6 +161,10 @@ class PreviewUrlResource(DirectServeJsonResource):
7. Stores the result in the database cache.
    8. Returns the result.
+    If any additional requests (e.g. from oEmbed autodiscovery, step 5.3, or
+    image thumbnailing, step 5.4 or 6.4) fail then the URL preview as a whole
+    does not fail. As much information as possible is returned.
+
The in-memory cache expires after 1 hour.
Expired entries in the database cache (and their associated media files) are
@@ -364,16 +366,25 @@ class PreviewUrlResource(DirectServeJsonResource):
oembed_url = self._oembed.autodiscover_from_html(tree)
og_from_oembed: JsonDict = {}
if oembed_url:
- oembed_info = await self._handle_url(
- oembed_url, user, allow_data_urls=True
- )
- (
- og_from_oembed,
- author_name,
- expiration_ms,
- ) = await self._handle_oembed_response(
- url, oembed_info, expiration_ms
- )
+ try:
+ oembed_info = await self._handle_url(
+ oembed_url, user, allow_data_urls=True
+ )
+ except Exception as e:
+ # Fetching the oEmbed info failed; don't block the entire URL preview.
+ logger.warning(
+ "oEmbed fetch failed during URL preview: %s errored with %s",
+ oembed_url,
+ e,
+ )
+ else:
+ (
+ og_from_oembed,
+ author_name,
+ expiration_ms,
+ ) = await self._handle_oembed_response(
+ url, oembed_info, expiration_ms
+ )
# Parse Open Graph information from the HTML in case the oEmbed
# response failed or is incomplete.
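
The try/except/else added above implements a best-effort policy: a failed oEmbed fetch downgrades the preview rather than aborting it. The same pattern in miniature (a hypothetical helper, not Synapse code):

```python
import logging
from typing import Any, Callable, Dict

logger = logging.getLogger(__name__)

def merge_best_effort(steps: Dict[str, Callable[[], Dict[str, Any]]]) -> Dict[str, Any]:
    # Run each enrichment step; keep whatever succeeds, log whatever fails.
    og: Dict[str, Any] = {}
    for name, step in steps.items():
        try:
            result = step()
        except Exception as e:
            logger.warning("Preview step %s failed: %s", name, e)
            continue
        for key, value in result.items():
            # Earlier (higher-priority) steps win on key conflicts.
            og.setdefault(key, value)
    return og
```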
diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/thumbnail_resource.py
index 5f725c7600..4ee2a0dbda 100644
--- a/synapse/rest/media/v1/thumbnail_resource.py
+++ b/synapse/rest/media/thumbnail_resource.py
@@ -27,9 +27,7 @@ from synapse.http.server import (
)
from synapse.http.servlet import parse_integer, parse_string
from synapse.http.site import SynapseRequest
-from synapse.rest.media.v1.media_storage import MediaStorage
-
-from ._base import (
+from synapse.media._base import (
FileInfo,
ThumbnailInfo,
parse_media_id,
@@ -37,9 +35,10 @@ from ._base import (
respond_with_file,
respond_with_responder,
)
+from synapse.media.media_storage import MediaStorage
if TYPE_CHECKING:
- from synapse.rest.media.v1.media_repository import MediaRepository
+ from synapse.media.media_repository import MediaRepository
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -69,7 +68,8 @@ class ThumbnailResource(DirectServeJsonResource):
width = parse_integer(request, "width", required=True)
height = parse_integer(request, "height", required=True)
method = parse_string(request, "method", "scale")
- m_type = parse_string(request, "type", "image/png")
+ # TODO Parse the Accept header to get a prioritised list of thumbnail types.
+ m_type = "image/png"
if server_name == self.server_name:
if self.dynamic_thumbnails:
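
For the TODO above, one plausible shape for the Accept-header parsing (a hedged sketch: the helper name and simplified q-value handling are assumptions, and media-range specificity tie-breaking is ignored):

```python
from typing import List

def parse_accept_thumbnail_types(accept: str) -> List[str]:
    # Split "image/webp,image/png;q=0.8,*/*;q=0.1" into media types ordered
    # by descending q-value, dropping anything with q=0.
    entries = []
    for part in accept.split(","):
        pieces = part.strip().split(";")
        media_type = pieces[0].strip()
        q = 1.0
        for param in pieces[1:]:
            name, _, value = param.strip().partition("=")
            if name.strip() == "q":
                try:
                    q = float(value)
                except ValueError:
                    q = 0.0
        if media_type:
            entries.append((q, media_type))
    entries.sort(key=lambda entry: entry[0], reverse=True)
    return [media_type for q, media_type in entries if q > 0]

# parse_accept_thumbnail_types("image/webp,image/png;q=0.8,*/*;q=0.1")
# -> ['image/webp', 'image/png', '*/*']
```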
diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/upload_resource.py
index 97548b54e5..697348613b 100644
--- a/synapse/rest/media/v1/upload_resource.py
+++ b/synapse/rest/media/upload_resource.py
@@ -20,10 +20,10 @@ from synapse.api.errors import Codes, SynapseError
from synapse.http.server import DirectServeJsonResource, respond_with_json
from synapse.http.servlet import parse_bytes_from_args
from synapse.http.site import SynapseRequest
-from synapse.rest.media.v1.media_storage import SpamMediaException
+from synapse.media.media_storage import SpamMediaException
if TYPE_CHECKING:
- from synapse.rest.media.v1.media_repository import MediaRepository
+ from synapse.media.media_repository import MediaRepository
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py
index 6e035afcce..88427a5737 100644
--- a/synapse/rest/media/v1/_base.py
+++ b/synapse/rest/media/v1/_base.py
@@ -1,5 +1,4 @@
-# Copyright 2014-2016 OpenMarket Ltd
-# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
+# Copyright 2023 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,469 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+#
-import logging
-import os
-import urllib
-from abc import ABC, abstractmethod
-from types import TracebackType
-from typing import Awaitable, Dict, Generator, List, Optional, Tuple, Type
-
-import attr
-
-from twisted.internet.interfaces import IConsumer
-from twisted.protocols.basic import FileSender
-from twisted.web.server import Request
-
-from synapse.api.errors import Codes, SynapseError, cs_error
-from synapse.http.server import finish_request, respond_with_json
-from synapse.http.site import SynapseRequest
-from synapse.logging.context import make_deferred_yieldable
-from synapse.util.stringutils import is_ascii, parse_and_validate_server_name
-
-logger = logging.getLogger(__name__)
-
-# list all text content types that will have the charset default to UTF-8 when
-# none is given
-TEXT_CONTENT_TYPES = [
- "text/css",
- "text/csv",
- "text/html",
- "text/calendar",
- "text/plain",
- "text/javascript",
- "application/json",
- "application/ld+json",
- "application/rtf",
- "image/svg+xml",
- "text/xml",
-]
-
-
-def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]:
- """Parses the server name, media ID and optional file name from the request URI
-
- Also performs some rough validation on the server name.
-
- Args:
- request: The `Request`.
-
- Returns:
- A tuple containing the parsed server name, media ID and optional file name.
-
- Raises:
- SynapseError(404): if parsing or validation fail for any reason
- """
- try:
- # The type on postpath seems incorrect in Twisted 21.2.0.
- postpath: List[bytes] = request.postpath # type: ignore
- assert postpath
-
- # This allows users to append e.g. /test.png to the URL. Useful for
- # clients that parse the URL to see content type.
- server_name_bytes, media_id_bytes = postpath[:2]
- server_name = server_name_bytes.decode("utf-8")
- media_id = media_id_bytes.decode("utf8")
-
- # Validate the server name, raising if invalid
- parse_and_validate_server_name(server_name)
-
- file_name = None
- if len(postpath) > 2:
- try:
- file_name = urllib.parse.unquote(postpath[-1].decode("utf-8"))
- except UnicodeDecodeError:
- pass
- return server_name, media_id, file_name
- except Exception:
- raise SynapseError(
- 404, "Invalid media id token %r" % (request.postpath,), Codes.UNKNOWN
- )
-
-
-def respond_404(request: SynapseRequest) -> None:
- respond_with_json(
- request,
- 404,
- cs_error("Not found %r" % (request.postpath,), code=Codes.NOT_FOUND),
- send_cors=True,
- )
-
-
-async def respond_with_file(
- request: SynapseRequest,
- media_type: str,
- file_path: str,
- file_size: Optional[int] = None,
- upload_name: Optional[str] = None,
-) -> None:
- logger.debug("Responding with %r", file_path)
-
- if os.path.isfile(file_path):
- if file_size is None:
- stat = os.stat(file_path)
- file_size = stat.st_size
-
- add_file_headers(request, media_type, file_size, upload_name)
-
- with open(file_path, "rb") as f:
- await make_deferred_yieldable(FileSender().beginFileTransfer(f, request))
-
- finish_request(request)
- else:
- respond_404(request)
-
-
-def add_file_headers(
- request: Request,
- media_type: str,
- file_size: Optional[int],
- upload_name: Optional[str],
-) -> None:
- """Adds the correct response headers in preparation for responding with the
- media.
-
- Args:
- request
- media_type: The media/content type.
- file_size: Size in bytes of the media, if known.
- upload_name: The name of the requested file, if any.
- """
-
- def _quote(x: str) -> str:
- return urllib.parse.quote(x.encode("utf-8"))
-
- # Default to a UTF-8 charset for text content types.
- # ex, uses UTF-8 for 'text/css' but not 'text/css; charset=UTF-16'
- if media_type.lower() in TEXT_CONTENT_TYPES:
- content_type = media_type + "; charset=UTF-8"
- else:
- content_type = media_type
-
- request.setHeader(b"Content-Type", content_type.encode("UTF-8"))
- if upload_name:
- # RFC6266 section 4.1 [1] defines both `filename` and `filename*`.
- #
- # `filename` is defined to be a `value`, which is defined by RFC2616
- # section 3.6 [2] to be a `token` or a `quoted-string`, where a `token`
- # is (essentially) a single US-ASCII word, and a `quoted-string` is a
- # US-ASCII string surrounded by double-quotes, using backslash as an
- # escape character. Note that %-encoding is *not* permitted.
- #
- # `filename*` is defined to be an `ext-value`, which is defined in
- # RFC5987 section 3.2.1 [3] to be `charset "'" [ language ] "'" value-chars`,
- # where `value-chars` is essentially a %-encoded string in the given charset.
- #
- # [1]: https://tools.ietf.org/html/rfc6266#section-4.1
- # [2]: https://tools.ietf.org/html/rfc2616#section-3.6
- # [3]: https://tools.ietf.org/html/rfc5987#section-3.2.1
-
- # We avoid the quoted-string version of `filename`, because (a) synapse didn't
- # correctly interpret those as of 0.99.2 and (b) they are a bit of a pain and we
- # may as well just do the filename* version.
- if _can_encode_filename_as_token(upload_name):
- disposition = "inline; filename=%s" % (upload_name,)
- else:
- disposition = "inline; filename*=utf-8''%s" % (_quote(upload_name),)
-
- request.setHeader(b"Content-Disposition", disposition.encode("ascii"))
-
- # cache for at least a day.
- # XXX: we might want to turn this off for data we don't want to
- # recommend caching as it's sensitive or private - or at least
- # select private. don't bother setting Expires as all our
- # clients are smart enough to be happy with Cache-Control
- request.setHeader(b"Cache-Control", b"public,max-age=86400,s-maxage=86400")
- if file_size is not None:
- request.setHeader(b"Content-Length", b"%d" % (file_size,))
-
- # Tell web crawlers to not index, archive, or follow links in media. This
- # should help to prevent things in the media repo from showing up in web
- # search results.
- request.setHeader(b"X-Robots-Tag", "noindex, nofollow, noarchive, noimageindex")
-
-
-# separators as defined in RFC2616. SP and HT are handled separately.
-# see _can_encode_filename_as_token.
-_FILENAME_SEPARATOR_CHARS = {
- "(",
- ")",
- "<",
- ">",
- "@",
- ",",
- ";",
- ":",
- "\\",
- '"',
- "/",
- "[",
- "]",
- "?",
- "=",
- "{",
- "}",
-}
-
-
-def _can_encode_filename_as_token(x: str) -> bool:
- for c in x:
- # from RFC2616:
- #
- # token = 1*<any CHAR except CTLs or separators>
- #
- # separators = "(" | ")" | "<" | ">" | "@"
- # | "," | ";" | ":" | "\" | <">
- # | "/" | "[" | "]" | "?" | "="
- # | "{" | "}" | SP | HT
- #
- # CHAR = <any US-ASCII character (octets 0 - 127)>
- #
- # CTL = <any US-ASCII control character
- # (octets 0 - 31) and DEL (127)>
- #
- if ord(c) >= 127 or ord(c) <= 32 or c in _FILENAME_SEPARATOR_CHARS:
- return False
- return True
-
-
-async def respond_with_responder(
- request: SynapseRequest,
- responder: "Optional[Responder]",
- media_type: str,
- file_size: Optional[int],
- upload_name: Optional[str] = None,
-) -> None:
- """Responds to the request with given responder. If responder is None then
- returns 404.
-
- Args:
- request
- responder
- media_type: The media/content type.
- file_size: Size in bytes of the media. If not known it should be None
- upload_name: The name of the requested file, if any.
- """
- if not responder:
- respond_404(request)
- return
-
- # If we have a responder we *must* use it as a context manager.
- with responder:
- if request._disconnected:
- logger.warning(
- "Not sending response to request %s, already disconnected.", request
- )
- return
-
- logger.debug("Responding to media request with responder %s", responder)
- add_file_headers(request, media_type, file_size, upload_name)
- try:
-
- await responder.write_to_consumer(request)
- except Exception as e:
- # The majority of the time this will be due to the client having gone
- # away. Unfortunately, Twisted simply throws a generic exception at us
- # in that case.
- logger.warning("Failed to write to consumer: %s %s", type(e), e)
-
- # Unregister the producer, if it has one, so Twisted doesn't complain
- if request.producer:
- request.unregisterProducer()
-
- finish_request(request)
-
-
-class Responder(ABC):
- """Represents a response that can be streamed to the requester.
-
- Responder is a context manager which *must* be used, so that any resources
- held can be cleaned up.
- """
-
- @abstractmethod
- def write_to_consumer(self, consumer: IConsumer) -> Awaitable:
- """Stream response into consumer
-
- Args:
- consumer: The consumer to stream into.
-
- Returns:
- Resolves once the response has finished being written
- """
- raise NotImplementedError()
-
- def __enter__(self) -> None: # noqa: B027
- pass
-
- def __exit__( # noqa: B027
- self,
- exc_type: Optional[Type[BaseException]],
- exc_val: Optional[BaseException],
- exc_tb: Optional[TracebackType],
- ) -> None:
- pass
-
-
-@attr.s(slots=True, frozen=True, auto_attribs=True)
-class ThumbnailInfo:
- """Details about a generated thumbnail."""
-
- width: int
- height: int
- method: str
- # Content type of thumbnail, e.g. image/png
- type: str
- # The size of the media file, in bytes.
- length: Optional[int] = None
-
-
-@attr.s(slots=True, frozen=True, auto_attribs=True)
-class FileInfo:
- """Details about a requested/uploaded file."""
-
- # The server name where the media originated from, or None if local.
- server_name: Optional[str]
- # The local ID of the file. For local files this is the same as the media_id
- file_id: str
- # If the file is for the url preview cache
- url_cache: bool = False
- # Whether the file is a thumbnail or not.
- thumbnail: Optional[ThumbnailInfo] = None
-
- # The below properties exist to maintain compatibility with third-party modules.
- @property
- def thumbnail_width(self) -> Optional[int]:
- if not self.thumbnail:
- return None
- return self.thumbnail.width
-
- @property
- def thumbnail_height(self) -> Optional[int]:
- if not self.thumbnail:
- return None
- return self.thumbnail.height
-
- @property
- def thumbnail_method(self) -> Optional[str]:
- if not self.thumbnail:
- return None
- return self.thumbnail.method
-
- @property
- def thumbnail_type(self) -> Optional[str]:
- if not self.thumbnail:
- return None
- return self.thumbnail.type
-
- @property
- def thumbnail_length(self) -> Optional[int]:
- if not self.thumbnail:
- return None
- return self.thumbnail.length
-
-
-def get_filename_from_headers(headers: Dict[bytes, List[bytes]]) -> Optional[str]:
- """
- Get the filename of the downloaded file by inspecting the
- Content-Disposition HTTP header.
-
- Args:
- headers: The HTTP request headers.
-
- Returns:
- The filename, or None.
- """
- content_disposition = headers.get(b"Content-Disposition", [b""])
-
- # No header, bail out.
- if not content_disposition[0]:
- return None
-
- _, params = _parse_header(content_disposition[0])
-
- upload_name = None
-
- # First check if there is a valid UTF-8 filename
- upload_name_utf8 = params.get(b"filename*", None)
- if upload_name_utf8:
- if upload_name_utf8.lower().startswith(b"utf-8''"):
- upload_name_utf8 = upload_name_utf8[7:]
- # We have a filename*= section. This MUST be ASCII, and any UTF-8
- # bytes are %-quoted.
- try:
- # Once it is decoded, we can then unquote the %-encoded
- # parts strictly into a unicode string.
- upload_name = urllib.parse.unquote(
- upload_name_utf8.decode("ascii"), errors="strict"
- )
- except UnicodeDecodeError:
- # Incorrect UTF-8.
- pass
-
- # If there isn't check for an ascii name.
- if not upload_name:
- upload_name_ascii = params.get(b"filename", None)
- if upload_name_ascii and is_ascii(upload_name_ascii):
- upload_name = upload_name_ascii.decode("ascii")
-
- # This may be None here, indicating we did not find a matching name.
- return upload_name
-
-
-def _parse_header(line: bytes) -> Tuple[bytes, Dict[bytes, bytes]]:
- """Parse a Content-type like header.
-
- Cargo-culted from `cgi`, but works on bytes rather than strings.
-
- Args:
- line: header to be parsed
-
- Returns:
- The main content-type, followed by the parameter dictionary
- """
- parts = _parseparam(b";" + line)
- key = next(parts)
- pdict = {}
- for p in parts:
- i = p.find(b"=")
- if i >= 0:
- name = p[:i].strip().lower()
- value = p[i + 1 :].strip()
-
- # strip double-quotes
- if len(value) >= 2 and value[0:1] == value[-1:] == b'"':
- value = value[1:-1]
- value = value.replace(b"\\\\", b"\\").replace(b'\\"', b'"')
- pdict[name] = value
-
- return key, pdict
-
-
-def _parseparam(s: bytes) -> Generator[bytes, None, None]:
- """Generator which splits the input on ;, respecting double-quoted sequences
-
- Cargo-culted from `cgi`, but works on bytes rather than strings.
-
- Args:
- s: header to be parsed
-
- Returns:
- The split input
- """
- while s[:1] == b";":
- s = s[1:]
-
- # look for the next ;
- end = s.find(b";")
-
- # if there is an odd number of " marks between here and the next ;, skip to the
- # next ; instead
- while end > 0 and (s.count(b'"', 0, end) - s.count(b'\\"', 0, end)) % 2:
- end = s.find(b";", end + 1)
-
- if end < 0:
- end = len(s)
- f = s[:end]
- yield f.strip()
- s = s[end:]
+# This exists purely for backwards compatibility with media providers and spam checkers.
+from synapse.media._base import FileInfo, Responder # noqa: F401
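
The removed `add_file_headers` carried the RFC 6266 reasoning for choosing between `filename` and `filename*`; that logic now lives in `synapse.media._base`. Condensed into a standalone sketch (hypothetical helper names, same rules):

```python
import urllib.parse

_SEPARATORS = set('()<>@,;:\\"/[]?={}')

def _is_token(name: str) -> bool:
    # RFC 2616 token: printable US-ASCII, no separators, no SP/HT/CTLs.
    return all(31 < ord(c) < 127 and c != " " and c not in _SEPARATORS for c in name)

def content_disposition(upload_name: str) -> str:
    if _is_token(upload_name):
        return "inline; filename=%s" % (upload_name,)
    # Otherwise use the RFC 5987 ext-value form with %-encoded UTF-8.
    return "inline; filename*=utf-8''%s" % (
        urllib.parse.quote(upload_name.encode("utf-8")),
    )

# content_disposition("cat.png") -> "inline; filename=cat.png"
# content_disposition("猫.png")  -> "inline; filename*=utf-8''%E7%8C%AB.png"
```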
diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
index db25848744..11b0e8e231 100644
--- a/synapse/rest/media/v1/media_storage.py
+++ b/synapse/rest/media/v1/media_storage.py
@@ -1,4 +1,4 @@
-# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
+# Copyright 2023 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,364 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import contextlib
-import logging
-import os
-import shutil
-from types import TracebackType
-from typing import (
- IO,
- TYPE_CHECKING,
- Any,
- Awaitable,
- BinaryIO,
- Callable,
- Generator,
- Optional,
- Sequence,
- Tuple,
- Type,
-)
-
-import attr
-
-from twisted.internet.defer import Deferred
-from twisted.internet.interfaces import IConsumer
-from twisted.protocols.basic import FileSender
-
-import synapse
-from synapse.api.errors import NotFoundError
-from synapse.logging.context import defer_to_thread, make_deferred_yieldable
-from synapse.util import Clock
-from synapse.util.file_consumer import BackgroundFileConsumer
-
-from ._base import FileInfo, Responder
-from .filepath import MediaFilePaths
-
-if TYPE_CHECKING:
- from synapse.rest.media.v1.storage_provider import StorageProvider
- from synapse.server import HomeServer
-
-logger = logging.getLogger(__name__)
-
-
-class MediaStorage:
- """Responsible for storing/fetching files from local sources.
-
- Args:
- hs
- local_media_directory: Base path where we store media on disk
- filepaths
- storage_providers: List of StorageProvider that are used to fetch and store files.
- """
-
- def __init__(
- self,
- hs: "HomeServer",
- local_media_directory: str,
- filepaths: MediaFilePaths,
- storage_providers: Sequence["StorageProvider"],
- ):
- self.hs = hs
- self.reactor = hs.get_reactor()
- self.local_media_directory = local_media_directory
- self.filepaths = filepaths
- self.storage_providers = storage_providers
- self.spam_checker = hs.get_spam_checker()
- self.clock = hs.get_clock()
-
- async def store_file(self, source: IO, file_info: FileInfo) -> str:
- """Write `source` to the on disk media store, and also any other
- configured storage providers
-
- Args:
- source: A file like object that should be written
- file_info: Info about the file to store
-
- Returns:
- the file path written to in the primary media store
- """
-
- with self.store_into_file(file_info) as (f, fname, finish_cb):
- # Write to the main repository
- await self.write_to_file(source, f)
- await finish_cb()
-
- return fname
-
- async def write_to_file(self, source: IO, output: IO) -> None:
- """Asynchronously write the `source` to `output`."""
- await defer_to_thread(self.reactor, _write_file_synchronously, source, output)
-
- @contextlib.contextmanager
- def store_into_file(
- self, file_info: FileInfo
- ) -> Generator[Tuple[BinaryIO, str, Callable[[], Awaitable[None]]], None, None]:
- """Context manager used to get a file like object to write into, as
- described by file_info.
-
- Actually yields a 3-tuple (file, fname, finish_cb), where file is a file
- like object that can be written to, fname is the absolute path of file
- on disk, and finish_cb is a function that returns an awaitable.
-
- fname can be used to read the contents from after upload, e.g. to
- generate thumbnails.
-
- finish_cb must be called and waited on after the file has been
- successfully been written to. Should not be called if there was an
- error.
-
- Args:
- file_info: Info about the file to store
-
- Example:
-
- with media_storage.store_into_file(info) as (f, fname, finish_cb):
- # .. write into f ...
- await finish_cb()
- """
-
- path = self._file_info_to_path(file_info)
- fname = os.path.join(self.local_media_directory, path)
-
- dirname = os.path.dirname(fname)
- os.makedirs(dirname, exist_ok=True)
-
- finished_called = [False]
-
- try:
- with open(fname, "wb") as f:
-
- async def finish() -> None:
- # Ensure that all writes have been flushed and close the
- # file.
- f.flush()
- f.close()
-
- spam_check = await self.spam_checker.check_media_file_for_spam(
- ReadableFileWrapper(self.clock, fname), file_info
- )
- if spam_check != synapse.module_api.NOT_SPAM:
- logger.info("Blocking media due to spam checker")
- # Note that we'll delete the stored media, due to the
- # try/except below. The media also won't be stored in
- # the DB.
- # We currently ignore any additional field returned by
- # the spam-check API.
- raise SpamMediaException(errcode=spam_check[0])
-
- for provider in self.storage_providers:
- await provider.store_file(path, file_info)
-
- finished_called[0] = True
-
- yield f, fname, finish
- except Exception as e:
- try:
- os.remove(fname)
- except Exception:
- pass
-
- raise e from None
-
- if not finished_called:
- raise Exception("Finished callback not called")
-
- async def fetch_media(self, file_info: FileInfo) -> Optional[Responder]:
- """Attempts to fetch media described by file_info from the local cache
- and configured storage providers.
-
- Args:
- file_info
-
- Returns:
- Returns a Responder if the file was found, otherwise None.
- """
- paths = [self._file_info_to_path(file_info)]
-
- # fallback for remote thumbnails with no method in the filename
- if file_info.thumbnail and file_info.server_name:
- paths.append(
- self.filepaths.remote_media_thumbnail_rel_legacy(
- server_name=file_info.server_name,
- file_id=file_info.file_id,
- width=file_info.thumbnail.width,
- height=file_info.thumbnail.height,
- content_type=file_info.thumbnail.type,
- )
- )
-
- for path in paths:
- local_path = os.path.join(self.local_media_directory, path)
- if os.path.exists(local_path):
- logger.debug("responding with local file %s", local_path)
- return FileResponder(open(local_path, "rb"))
- logger.debug("local file %s did not exist", local_path)
-
- for provider in self.storage_providers:
- for path in paths:
- res: Any = await provider.fetch(path, file_info)
- if res:
- logger.debug("Streaming %s from %s", path, provider)
- return res
- logger.debug("%s not found on %s", path, provider)
-
- return None
-
- async def ensure_media_is_in_local_cache(self, file_info: FileInfo) -> str:
- """Ensures that the given file is in the local cache. Attempts to
- download it from storage providers if it isn't.
-
- Args:
- file_info
-
- Returns:
- Full path to local file
- """
- path = self._file_info_to_path(file_info)
- local_path = os.path.join(self.local_media_directory, path)
- if os.path.exists(local_path):
- return local_path
-
- # Fallback for paths without method names
- # Should be removed in the future
- if file_info.thumbnail and file_info.server_name:
- legacy_path = self.filepaths.remote_media_thumbnail_rel_legacy(
- server_name=file_info.server_name,
- file_id=file_info.file_id,
- width=file_info.thumbnail.width,
- height=file_info.thumbnail.height,
- content_type=file_info.thumbnail.type,
- )
- legacy_local_path = os.path.join(self.local_media_directory, legacy_path)
- if os.path.exists(legacy_local_path):
- return legacy_local_path
-
- dirname = os.path.dirname(local_path)
- os.makedirs(dirname, exist_ok=True)
-
- for provider in self.storage_providers:
- res: Any = await provider.fetch(path, file_info)
- if res:
- with res:
- consumer = BackgroundFileConsumer(
- open(local_path, "wb"), self.reactor
- )
- await res.write_to_consumer(consumer)
- await consumer.wait()
- return local_path
-
- raise NotFoundError()
-
- def _file_info_to_path(self, file_info: FileInfo) -> str:
- """Converts file_info into a relative path.
-
- The path is suitable for storing files under a directory, e.g. used to
- store files on local FS under the base media repository directory.
- """
- if file_info.url_cache:
- if file_info.thumbnail:
- return self.filepaths.url_cache_thumbnail_rel(
- media_id=file_info.file_id,
- width=file_info.thumbnail.width,
- height=file_info.thumbnail.height,
- content_type=file_info.thumbnail.type,
- method=file_info.thumbnail.method,
- )
- return self.filepaths.url_cache_filepath_rel(file_info.file_id)
-
- if file_info.server_name:
- if file_info.thumbnail:
- return self.filepaths.remote_media_thumbnail_rel(
- server_name=file_info.server_name,
- file_id=file_info.file_id,
- width=file_info.thumbnail.width,
- height=file_info.thumbnail.height,
- content_type=file_info.thumbnail.type,
- method=file_info.thumbnail.method,
- )
- return self.filepaths.remote_media_filepath_rel(
- file_info.server_name, file_info.file_id
- )
-
- if file_info.thumbnail:
- return self.filepaths.local_media_thumbnail_rel(
- media_id=file_info.file_id,
- width=file_info.thumbnail.width,
- height=file_info.thumbnail.height,
- content_type=file_info.thumbnail.type,
- method=file_info.thumbnail.method,
- )
- return self.filepaths.local_media_filepath_rel(file_info.file_id)
-
-
-def _write_file_synchronously(source: IO, dest: IO) -> None:
- """Write `source` to the file like `dest` synchronously. Should be called
- from a thread.
-
- Args:
- source: A file like object that's to be written
- dest: A file like object to be written to
- """
- source.seek(0) # Ensure we read from the start of the file
- shutil.copyfileobj(source, dest)
-
-
-class FileResponder(Responder):
- """Wraps an open file that can be sent to a request.
-
- Args:
- open_file: A file like object to be streamed ot the client,
- is closed when finished streaming.
- """
-
- def __init__(self, open_file: IO):
- self.open_file = open_file
-
- def write_to_consumer(self, consumer: IConsumer) -> Deferred:
- return make_deferred_yieldable(
- FileSender().beginFileTransfer(self.open_file, consumer)
- )
-
- def __exit__(
- self,
- exc_type: Optional[Type[BaseException]],
- exc_val: Optional[BaseException],
- exc_tb: Optional[TracebackType],
- ) -> None:
- self.open_file.close()
-
-
-class SpamMediaException(NotFoundError):
- """The media was blocked by a spam checker, so we simply 404 the request (in
- the same way as if it was quarantined).
- """
-
-
-@attr.s(slots=True, auto_attribs=True)
-class ReadableFileWrapper:
- """Wrapper that allows reading a file in chunks, yielding to the reactor,
- and writing to a callback.
-
- This is simplified `FileSender` that takes an IO object rather than an
- `IConsumer`.
- """
-
- CHUNK_SIZE = 2**14
-
- clock: Clock
- path: str
-
- async def write_chunks_to(self, callback: Callable[[bytes], object]) -> None:
- """Reads the file in chunks and calls the callback with each chunk."""
-
- with open(self.path, "rb") as file:
- while True:
- chunk = file.read(self.CHUNK_SIZE)
- if not chunk:
- break
-
- callback(chunk)
+#
- # We yield to the reactor by sleeping for 0 seconds.
- await self.clock.sleep(0)
+# This exists purely for backwards compatibility with spam checkers.
+from synapse.media.media_storage import ReadableFileWrapper # noqa: F401
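
The `ReadableFileWrapper` removed here (and re-exported from `synapse.media.media_storage`) demonstrates a useful pattern: reading a file in fixed-size chunks and yielding to the reactor between chunks, so spam-checking a large upload cannot starve the event loop. The same idea in a standalone sketch, with asyncio standing in for Synapse's Twisted clock:

```python
import asyncio
from typing import Callable

CHUNK_SIZE = 2**14

async def write_chunks_to(path: str, callback: Callable[[bytes], object]) -> None:
    # Read the file in chunks, handing each chunk to the callback and
    # yielding control between chunks (a zero-second sleep).
    with open(path, "rb") as f:
        while True:
            chunk = f.read(CHUNK_SIZE)
            if not chunk:
                break
            callback(chunk)
            await asyncio.sleep(0)
```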
diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py
index 1c9b71d69c..d7653f30ae 100644
--- a/synapse/rest/media/v1/storage_provider.py
+++ b/synapse/rest/media/v1/storage_provider.py
@@ -1,4 +1,4 @@
-# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
+# Copyright 2023 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,171 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+#
-import abc
-import logging
-import os
-import shutil
-from typing import TYPE_CHECKING, Callable, Optional
-
-from synapse.config._base import Config
-from synapse.logging.context import defer_to_thread, run_in_background
-from synapse.util.async_helpers import maybe_awaitable
-
-from ._base import FileInfo, Responder
-from .media_storage import FileResponder
-
-logger = logging.getLogger(__name__)
-
-if TYPE_CHECKING:
- from synapse.server import HomeServer
-
-
-class StorageProvider(metaclass=abc.ABCMeta):
- """A storage provider is a service that can store uploaded media and
- retrieve them.
- """
-
- @abc.abstractmethod
- async def store_file(self, path: str, file_info: FileInfo) -> None:
- """Store the file described by file_info. The actual contents can be
- retrieved by reading the file in file_info.upload_path.
-
- Args:
- path: Relative path of file in local cache
- file_info: The metadata of the file.
- """
-
- @abc.abstractmethod
- async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]:
- """Attempt to fetch the file described by file_info and stream it
- into writer.
-
- Args:
- path: Relative path of file in local cache
- file_info: The metadata of the file.
-
- Returns:
- Returns a Responder if the provider has the file, otherwise returns None.
- """
-
-
-class StorageProviderWrapper(StorageProvider):
- """Wraps a storage provider and provides various config options
-
- Args:
- backend: The storage provider to wrap.
- store_local: Whether to store new local files or not.
- store_synchronous: Whether to wait for file to be successfully
- uploaded, or todo the upload in the background.
- store_remote: Whether remote media should be uploaded
- """
-
- def __init__(
- self,
- backend: StorageProvider,
- store_local: bool,
- store_synchronous: bool,
- store_remote: bool,
- ):
- self.backend = backend
- self.store_local = store_local
- self.store_synchronous = store_synchronous
- self.store_remote = store_remote
-
- def __str__(self) -> str:
- return "StorageProviderWrapper[%s]" % (self.backend,)
-
- async def store_file(self, path: str, file_info: FileInfo) -> None:
- if not file_info.server_name and not self.store_local:
- return None
-
- if file_info.server_name and not self.store_remote:
- return None
-
- if file_info.url_cache:
- # The URL preview cache is short lived and not worth offloading or
- # backing up.
- return None
-
- if self.store_synchronous:
- # store_file is supposed to return an Awaitable, but guard
- # against improper implementations.
- await maybe_awaitable(self.backend.store_file(path, file_info)) # type: ignore
- else:
- # TODO: Handle errors.
- async def store() -> None:
- try:
- return await maybe_awaitable(
- self.backend.store_file(path, file_info)
- )
- except Exception:
- logger.exception("Error storing file")
-
- run_in_background(store)
-
- async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]:
- if file_info.url_cache:
- # Files in the URL preview cache definitely aren't stored here,
- # so avoid any potentially slow I/O or network access.
- return None
-
- # store_file is supposed to return an Awaitable, but guard
- # against improper implementations.
- return await maybe_awaitable(self.backend.fetch(path, file_info))
-
-
-class FileStorageProviderBackend(StorageProvider):
- """A storage provider that stores files in a directory on a filesystem.
-
- Args:
- hs
- config: The config returned by `parse_config`.
- """
-
- def __init__(self, hs: "HomeServer", config: str):
- self.hs = hs
- self.cache_directory = hs.config.media.media_store_path
- self.base_directory = config
-
- def __str__(self) -> str:
- return "FileStorageProviderBackend[%s]" % (self.base_directory,)
-
- async def store_file(self, path: str, file_info: FileInfo) -> None:
- """See StorageProvider.store_file"""
-
- primary_fname = os.path.join(self.cache_directory, path)
- backup_fname = os.path.join(self.base_directory, path)
-
- dirname = os.path.dirname(backup_fname)
- os.makedirs(dirname, exist_ok=True)
-
- # mypy needs help inferring the type of the second parameter, which is generic
- shutil_copyfile: Callable[[str, str], str] = shutil.copyfile
- await defer_to_thread(
- self.hs.get_reactor(),
- shutil_copyfile,
- primary_fname,
- backup_fname,
- )
-
- async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]:
- """See StorageProvider.fetch"""
-
- backup_fname = os.path.join(self.base_directory, path)
- if os.path.isfile(backup_fname):
- return FileResponder(open(backup_fname, "rb"))
-
- return None
-
- @staticmethod
- def parse_config(config: dict) -> str:
- """Called on startup to parse config supplied. This should parse
- the config and raise if there is a problem.
-
- The returned value is passed into the constructor.
-
- In this case we only care about a single param, the directory, so let's
- just pull that out.
- """
- return Config.ensure_directory(config["directory"])
+# This exists purely for backwards compatibility with media providers.
+from synapse.media.storage_provider import StorageProvider # noqa: F401
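
Because this module now just re-exports `StorageProvider`, third-party providers importing from the old path keep working. A minimal custom provider might look like the sketch below; the import paths follow the re-exports in this patch, but `FileResponder`'s new home in `synapse.media.media_storage` is an assumption, and the class itself is purely illustrative:

```python
import os
import shutil
from typing import Optional

from synapse.media._base import FileInfo, Responder
from synapse.media.media_storage import FileResponder  # assumed new location
from synapse.media.storage_provider import StorageProvider

class FlatDirectoryBackend(StorageProvider):
    """Illustrative provider: mirrors media into a single flat directory."""

    def __init__(self, hs, config: str):
        self.cache_directory = hs.config.media.media_store_path
        self.base_directory = config

    @staticmethod
    def parse_config(config: dict) -> str:
        return config["directory"]

    async def store_file(self, path: str, file_info: FileInfo) -> None:
        # A real provider should defer blocking I/O to a thread, as the
        # removed FileStorageProviderBackend does with defer_to_thread.
        dest = os.path.join(self.base_directory, path.replace("/", "_"))
        shutil.copyfile(os.path.join(self.cache_directory, path), dest)

    async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]:
        dest = os.path.join(self.base_directory, path.replace("/", "_"))
        if os.path.isfile(dest):
            return FileResponder(open(dest, "rb"))
        return None
```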
diff --git a/synapse/server.py b/synapse/server.py
index e5a3475247..a7c32e9a60 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -105,6 +105,7 @@ from synapse.handlers.typing import FollowerTypingHandler, TypingWriterHandler
from synapse.handlers.user_directory import UserDirectoryHandler
from synapse.http.client import InsecureInterceptableContextFactory, SimpleHttpClient
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
+from synapse.media.media_repository import MediaRepository
from synapse.metrics.common_usage_metrics import CommonUsageMetricsManager
from synapse.module_api import ModuleApi
from synapse.notifier import Notifier, ReplicationNotifier
@@ -115,10 +116,7 @@ from synapse.replication.tcp.external_cache import ExternalCache
from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.replication.tcp.resource import ReplicationStreamer
from synapse.replication.tcp.streams import STREAMS_MAP, Stream
-from synapse.rest.media.v1.media_repository import (
- MediaRepository,
- MediaRepositoryResource,
-)
+from synapse.rest.media.media_repository_resource import MediaRepositoryResource
from synapse.server_notices.server_notices_manager import ServerNoticesManager
from synapse.server_notices.server_notices_sender import ServerNoticesSender
from synapse.server_notices.worker_server_notices_sender import (
diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py
index 564e3705c2..9732dbdb6e 100644
--- a/synapse/server_notices/server_notices_manager.py
+++ b/synapse/server_notices/server_notices_manager.py
@@ -178,7 +178,7 @@ class ServerNoticesManager:
"avatar_url": self._config.servernotices.server_notices_mxid_avatar_url,
}
- info, _ = await self._room_creation_handler.create_room(
+ room_id, _, _ = await self._room_creation_handler.create_room(
requester,
config={
"preset": RoomCreationPreset.PRIVATE_CHAT,
@@ -188,7 +188,6 @@ class ServerNoticesManager:
ratelimit=False,
creator_join_profile=join_profile,
)
- room_id = info["room_id"]
self.maybe_get_notice_room_for_user.invalidate((user_id,))
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 837dc7646e..dc3948c170 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -43,7 +43,7 @@ from .event_federation import EventFederationStore
from .event_push_actions import EventPushActionsStore
from .events_bg_updates import EventsBackgroundUpdatesStore
from .events_forward_extremities import EventForwardExtremitiesStore
-from .filtering import FilteringStore
+from .filtering import FilteringWorkerStore
from .keys import KeyStore
from .lock import LockStore
from .media_repository import MediaRepositoryStore
@@ -99,7 +99,7 @@ class DataStore(
EventFederationStore,
MediaRepositoryStore,
RejectionsStore,
- FilteringStore,
+ FilteringWorkerStore,
PusherStore,
PushRuleStore,
ApplicationServiceTransactionStore,
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 95567826f2..308d19440f 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -237,6 +237,37 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
else:
return None
+ async def get_latest_stream_id_for_global_account_data_by_type_for_user(
+ self, user_id: str, data_type: str
+ ) -> Optional[int]:
+ """
+ Returns:
+ The stream ID of the account data,
+ or None if there is no such account data.
+ """
+
+ def get_latest_stream_id_for_global_account_data_by_type_for_user_txn(
+ txn: LoggingTransaction,
+ ) -> Optional[int]:
+ sql = """
+ SELECT stream_id FROM account_data
+ WHERE user_id = ? AND account_data_type = ?
+ ORDER BY stream_id DESC
+ LIMIT 1
+ """
+ txn.execute(sql, (user_id, data_type))
+
+ row = txn.fetchone()
+ if row:
+ return row[0]
+ else:
+ return None
+
+ return await self.db_pool.runInteraction(
+ "get_latest_stream_id_for_global_account_data_by_type_for_user",
+ get_latest_stream_id_for_global_account_data_by_type_for_user_txn,
+ )
+
@cached(num_args=2, tree=True)
async def get_account_data_for_room(
self, user_id: str, room_id: str
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 8e61aba454..0d75d9739a 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -721,8 +721,8 @@ class DeviceInboxWorkerStore(SQLBaseStore):
],
)
- for (user_id, messages_by_device) in edu["messages"].items():
- for (device_id, msg) in messages_by_device.items():
+ for user_id, messages_by_device in edu["messages"].items():
+ for device_id, msg in messages_by_device.items():
with start_active_span("store_outgoing_to_device_message"):
set_tag(SynapseTags.TO_DEVICE_EDU_ID, edu["sender"])
set_tag(SynapseTags.TO_DEVICE_EDU_ID, edu["message_id"])
@@ -959,7 +959,6 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
def _remove_dead_devices_from_device_inbox_txn(
txn: LoggingTransaction,
) -> Tuple[int, bool]:
-
if "max_stream_id" in progress:
max_stream_id = progress["max_stream_id"]
else:
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 1ca66d57d4..0dd15f16ff 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -512,7 +512,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
results.append(("org.matrix.signing_key_update", result))
if issue_8631_logger.isEnabledFor(logging.DEBUG):
- for (user_id, edu) in results:
+ for user_id, edu in results:
issue_8631_logger.debug(
"device update to %s for %s from %s to %s: %s",
destination,
@@ -1316,7 +1316,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
)
"""
count = 0
- for (destination, user_id, stream_id, device_id) in rows:
+ for destination, user_id, stream_id, device_id in rows:
txn.execute(
delete_sql, (destination, user_id, stream_id, stream_id, device_id)
)
diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py
index 6240f9a75e..9f8d2e4bea 100644
--- a/synapse/storage/databases/main/e2e_room_keys.py
+++ b/synapse/storage/databases/main/e2e_room_keys.py
@@ -108,7 +108,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
raise StoreError(404, "No backup with that version exists")
values = []
- for (room_id, session_id, room_key) in room_keys:
+ for room_id, session_id, room_key in room_keys:
values.append(
(
user_id,
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 2c2d145666..b9c39b1718 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -268,7 +268,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
)
# add each cross-signing signature to the correct device in the result dict.
- for (user_id, key_id, device_id, signature) in cross_sigs_result:
+ for user_id, key_id, device_id, signature in cross_sigs_result:
target_device_result = result[user_id][device_id]
# We've only looked up cross-signatures for non-deleted devices with key
# data.
@@ -311,7 +311,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
# devices.
user_list = []
user_device_list = []
- for (user_id, device_id) in query_list:
+ for user_id, device_id in query_list:
if device_id is None:
user_list.append(user_id)
else:
@@ -353,7 +353,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
txn.execute(sql, query_params)
- for (user_id, device_id, display_name, key_json) in txn:
+ for user_id, device_id, display_name, key_json in txn:
assert device_id is not None
if include_deleted_devices:
deleted_devices.remove((user_id, device_id))
@@ -382,7 +382,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
signature_query_clauses = []
signature_query_params = []
- for (user_id, device_id) in device_query:
+ for user_id, device_id in device_query:
signature_query_clauses.append(
"target_user_id = ? AND target_device_id = ? AND user_id = ?"
)
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index ca780cca36..ff3edeb716 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -1612,7 +1612,6 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
latest_events: List[str],
limit: int,
) -> List[str]:
-
seen_events = set(earliest_events)
front = set(latest_events) - seen_events
event_results: List[str] = []
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 7996cbb557..73b8aea16c 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -469,7 +469,6 @@ class PersistEventsStore:
txn: LoggingTransaction,
events: List[EventBase],
) -> None:
-
# We only care about state events, so this if there are no state events.
if not any(e.is_state() for e in events):
return
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index 584536111d..0a275e6ce6 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -709,7 +709,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
nbrows = 0
last_row_event_id = ""
- for (event_id, event_json_raw) in results:
+ for event_id, event_json_raw in results:
try:
event_json = db_to_json(event_json_raw)
@@ -1167,7 +1167,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
results = list(txn)
# (event_id, parent_id, rel_type) for each relation
relations_to_insert: List[Tuple[str, str, str]] = []
- for (event_id, event_json_raw) in results:
+ for event_id, event_json_raw in results:
try:
event_json = db_to_json(event_json_raw)
except Exception as e:
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 6d0ef10258..b7e7498125 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -1493,7 +1493,7 @@ class EventsWorkerStore(SQLBaseStore):
txn.execute(redactions_sql + clause, args)
- for (redacter, redacted) in txn:
+ for redacter, redacted in txn:
d = event_dict.get(redacted)
if d:
d.redactions.append(redacter)
diff --git a/synapse/storage/databases/main/filtering.py b/synapse/storage/databases/main/filtering.py
index 12f3b601f1..8e57c8e5a0 100644
--- a/synapse/storage/databases/main/filtering.py
+++ b/synapse/storage/databases/main/filtering.py
@@ -17,7 +17,7 @@ from typing import Optional, Tuple, Union, cast
from canonicaljson import encode_canonical_json
-from synapse.api.errors import Codes, SynapseError
+from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import LoggingTransaction
from synapse.types import JsonDict
@@ -46,8 +46,6 @@ class FilteringWorkerStore(SQLBaseStore):
return db_to_json(def_json)
-
-class FilteringStore(FilteringWorkerStore):
async def add_user_filter(self, user_localpart: str, user_filter: JsonDict) -> int:
def_json = encode_canonical_json(user_filter)
@@ -79,4 +77,23 @@ class FilteringStore(FilteringWorkerStore):
return filter_id
- return await self.db_pool.runInteraction("add_user_filter", _do_txn)
+ attempts = 0
+ while True:
+ # Try a few times.
+ # This is technically needed if a user tries to create two filters at once,
+ # leading to two concurrent transactions.
+ # The failure case would be:
+ # - SELECT filter_id ... filter_json = ? → both transactions return no rows
+ # - SELECT MAX(filter_id) ... → both transactions return e.g. 5
+ # - INSERT INTO ... → both transactions insert filter_id = 6
+ # One of the transactions will commit. The other will get a unique key
+ # constraint violation error (IntegrityError). This is not the same as a
+ # serialisability violation, which would be automatically retried by
+ # `runInteraction`.
+ try:
+ return await self.db_pool.runInteraction("add_user_filter", _do_txn)
+ except self.db_pool.engine.module.IntegrityError:
+ attempts += 1
+
+ if attempts >= 5:
+ raise StoreError(500, "Couldn't generate a filter ID.")
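
This bounded retry is the standard recipe for unique-index races, which (unlike serialisation failures) `runInteraction` will not retry automatically. The same idea as a reusable helper, sketched outside this patch:

```python
from typing import Awaitable, Callable, Type, TypeVar

T = TypeVar("T")

async def retry_on_integrity_error(
    func: Callable[[], Awaitable[T]],
    integrity_error: Type[BaseException],
    max_attempts: int = 5,
) -> T:
    # Re-run a race-prone transaction when it loses a unique-constraint
    # race; give up (re-raising) after max_attempts tries.
    for attempt in range(1, max_attempts + 1):
        try:
            return await func()
        except integrity_error:
            if attempt == max_attempts:
                raise
    raise AssertionError("unreachable")
```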
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index b202c5eb87..fa8be214ce 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -196,7 +196,6 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
def get_local_media_by_user_paginate_txn(
txn: LoggingTransaction,
) -> Tuple[List[Dict[str, Any]], int]:
-
# Set ordering
order_by_column = MediaSortOrder(order_by).value
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index df53e726e6..fddbc07afa 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -344,7 +344,6 @@ class PusherWorkerStore(SQLBaseStore):
last_user = progress.get("last_user", "")
def _delete_pushers(txn: LoggingTransaction) -> int:
-
sql = """
SELECT name FROM users
WHERE deactivated = ? and name > ?
@@ -392,7 +391,6 @@ class PusherWorkerStore(SQLBaseStore):
last_pusher = progress.get("last_pusher", 0)
def _delete_pushers(txn: LoggingTransaction) -> int:
-
sql = """
SELECT p.id, access_token FROM pushers AS p
LEFT JOIN access_tokens AS a ON (p.access_token = a.id)
@@ -449,7 +447,6 @@ class PusherWorkerStore(SQLBaseStore):
last_pusher = progress.get("last_pusher", 0)
def _delete_pushers(txn: LoggingTransaction) -> int:
-
sql = """
SELECT p.id, p.user_name, p.app_id, p.pushkey
FROM pushers AS p
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index dddf49c2d5..92a82240ab 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -887,7 +887,6 @@ class ReceiptsBackgroundUpdateStore(SQLBaseStore):
def _populate_receipt_event_stream_ordering_txn(
txn: LoggingTransaction,
) -> bool:
-
if "max_stream_id" in progress:
max_stream_id = progress["max_stream_id"]
else:
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 9a55e17624..717237e024 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -1002,19 +1002,6 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
desc="user_delete_threepid",
)
- async def user_delete_threepids(self, user_id: str) -> None:
- """Delete all threepid this user has bound
-
- Args:
- user_id: The user id to delete all threepids of
-
- """
- await self.db_pool.simple_delete(
- "user_threepids",
- keyvalues={"user_id": user_id},
- desc="user_delete_threepids",
- )
-
async def add_user_bound_threepid(
self, user_id: str, medium: str, address: str, id_server: str
) -> None:
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 644bbb8878..a2e9519cb6 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -1417,6 +1417,27 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
get_un_partial_stated_rooms_from_stream_txn,
)
+ async def delete_event_report(self, report_id: int) -> bool:
+ """Remove an event report from database.
+
+ Args:
+ report_id: Report to delete
+
+ Returns:
+ Whether the report was successfully deleted or not.
+ """
+ try:
+ await self.db_pool.simple_delete_one(
+ table="event_reports",
+ keyvalues={"id": report_id},
+ desc="delete_event_report",
+ )
+ except StoreError:
+ # Deletion failed because report does not exist
+ return False
+
+ return True
+
class _BackgroundUpdates:
REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory"
@@ -2139,7 +2160,19 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
reason: Optional[str],
content: JsonDict,
received_ts: int,
- ) -> None:
+ ) -> int:
+ """Add an event report
+
+ Args:
+ room_id: Room that contains the reported event.
+ event_id: The reported event.
+ user_id: User who reports the event.
+ reason: Description that the user specifies.
+ content: Report request body (score and reason).
+ received_ts: Time when the user submitted the report (milliseconds).
+ Returns:
+ The ID of the event report.
+ """
next_id = self._event_reports_id_gen.get_next()
await self.db_pool.simple_insert(
table="event_reports",
@@ -2154,6 +2187,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
},
desc="add_event_report",
)
+ return next_id
async def get_event_report(self, report_id: int) -> Optional[Dict[str, Any]]:
"""Retrieve an event report
@@ -2168,7 +2202,6 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
def _get_event_report_txn(
txn: LoggingTransaction, report_id: int
) -> Optional[Dict[str, Any]]:
-
sql = """
SELECT
er.id,
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index 3fe433f66c..a7aae661d8 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -122,7 +122,6 @@ class SearchWorkerStore(SQLBaseStore):
class SearchBackgroundUpdateStore(SearchWorkerStore):
-
EVENT_SEARCH_UPDATE_NAME = "event_search"
EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order"
EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin"
@@ -615,7 +614,6 @@ class SearchStore(SearchBackgroundUpdateStore):
"""
count_args = [search_query] + count_args
elif isinstance(self.database_engine, Sqlite3Engine):
-
# We use CROSS JOIN here to ensure we use the right indexes.
# https://sqlite.org/optoverview.html#crossjoin
#
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index ba325d390b..ebb2ae964f 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -490,7 +490,6 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
-
CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index"
DELETE_CURRENT_STATE_UPDATE_NAME = "delete_old_current_state_events"
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index d7b7d0c3c9..d3393d8e49 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -461,7 +461,7 @@ class StatsStore(StateDeltasStore):
insert_cols = []
qargs = []
- for (key, val) in chain(
+ for key, val in chain(
keyvalues.items(), absolutes.items(), additive_relatives.items()
):
insert_cols.append(key)
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 818c46182e..ac5fbf6b86 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -87,6 +87,7 @@ MAX_STREAM_SIZE = 1000
_STREAM_TOKEN = "stream"
_TOPOLOGICAL_TOKEN = "topological"
+
# Used as return values for pagination APIs
@attr.s(slots=True, frozen=True, auto_attribs=True)
class _EventDictReturn:
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index 6b33d809b6..6d72bd9f67 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -573,7 +573,6 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
def get_destination_rooms_paginate_txn(
txn: LoggingTransaction,
) -> Tuple[List[JsonDict], int]:
-
if direction == Direction.BACKWARDS:
order = "DESC"
else:
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index 30af4b3b6c..f16a509ac4 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -14,6 +14,7 @@
import logging
import re
+import unicodedata
from typing import (
TYPE_CHECKING,
Iterable,
@@ -98,7 +99,6 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
async def _populate_user_directory_createtables(
self, progress: JsonDict, batch_size: int
) -> int:
-
# Get all the rooms that we want to process.
def _make_staging_area(txn: LoggingTransaction) -> None:
sql = (
@@ -491,6 +491,11 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
values={"display_name": display_name, "avatar_url": avatar_url},
)
+ # The display name that goes into the database index.
+ index_display_name = display_name
+ if index_display_name is not None:
+ index_display_name = _filter_text_for_index(index_display_name)
+
if isinstance(self.database_engine, PostgresEngine):
# We weight the localpart most highly, then display name and finally
# server name
@@ -508,11 +513,15 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
user_id,
get_localpart_from_id(user_id),
get_domain_from_id(user_id),
- display_name,
+ index_display_name,
),
)
elif isinstance(self.database_engine, Sqlite3Engine):
- value = "%s %s" % (user_id, display_name) if display_name else user_id
+ value = (
+ "%s %s" % (user_id, index_display_name)
+ if index_display_name
+ else user_id
+ )
self.db_pool.simple_upsert_txn(
txn,
table="user_directory_search",
@@ -897,6 +906,41 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
return {"limited": limited, "results": results[0:limit]}
+def _filter_text_for_index(text: str) -> str:
+ """Transforms text before it is inserted into the user directory index, or searched
+ for in the user directory index.
+
+ Note that the user directory search table needs to be rebuilt whenever this function
+ changes.
+ """
+ # Lowercase the text, to make searches case-insensitive.
+ # This is necessary for both PostgreSQL and SQLite. PostgreSQL's
+ # `to_tsquery/to_tsvector` functions don't lowercase non-ASCII characters when using
+ # the "C" collation, while SQLite just doesn't lowercase non-ASCII characters at
+ # all.
+ text = text.lower()
+
+ # Normalize the text. NFKC normalization has two effects:
+ # 1. It canonicalizes the text, i.e. maps all visually identical strings to the same
+ # string. For example, ["e", "◌́"] is mapped to ["é"].
+ # 2. It maps strings that are roughly equivalent to the same string.
+ # For example, ["dž"] is mapped to ["d", "ž"], ["①"] to ["1"] and ["i⁹"] to
+ # ["i", "9"].
+ text = unicodedata.normalize("NFKC", text)
+
+ # Note that nothing is done to make searches accent-insensitive.
+ # That could be achieved by converting to NFKD form instead (with combining accents
+ # split out) and filtering out combining accents using `unicodedata.combining(c)`.
+ # The downside of this may be noisier search results, since search terms with
+ # explicit accents will match characters with no accents, or completely different
+ # accents.
+ #
+ # text = unicodedata.normalize("NFKD", text)
+ # text = "".join([c for c in text if not unicodedata.combining(c)])
+
+ return text
+
+
def _parse_query_sqlite(search_term: str) -> str:
"""Takes a plain unicode string from the user and converts it into a form
that can be passed to database.
@@ -906,6 +950,7 @@ def _parse_query_sqlite(search_term: str) -> str:
    We specifically add both a prefix and non-prefix matching term so that
exact matches get ranked higher.
"""
+ search_term = _filter_text_for_index(search_term)
# Pull out the individual words, discarding any non-word characters.
results = _parse_words(search_term)
@@ -918,6 +963,8 @@ def _parse_query_postgres(search_term: str) -> Tuple[str, str, str]:
We use this so that we can add prefix matching, which isn't something
that is supported by default.
"""
+ search_term = _filter_text_for_index(search_term)
+
escaped_words = []
for word in _parse_words(search_term):
# Postgres tsvector and tsquery quoting rules:
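
A minimal standalone sketch of the folding this file now applies to both indexed
values and search terms (the same function has to run on both sides, which is why
the search table needs a rebuild whenever it changes). `filter_text_for_index`
below is an illustrative stand-in for the new `_filter_text_for_index`, not the
function itself:

    import unicodedata

    def filter_text_for_index(text: str) -> str:
        # Case-insensitivity: PostgreSQL's "C" collation and SQLite's FTS do
        # not lowercase non-ASCII characters, so fold case in Python first.
        text = text.lower()
        # NFKC canonicalizes ("e" + combining acute -> "é") and folds
        # compatibility characters ("①" -> "1", "ǆ" -> "d" + "ž").
        return unicodedata.normalize("NFKC", text)

    assert filter_text_for_index("e\u0301") == "é"    # accent composed
    assert filter_text_for_index("\u2460") == "1"     # "①" folded to "1"
    assert filter_text_for_index("ALICE") == "alice"  # case folded
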
diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py
index d743282f13..097dea5182 100644
--- a/synapse/storage/databases/state/bg_updates.py
+++ b/synapse/storage/databases/state/bg_updates.py
@@ -251,7 +251,6 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
-
STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
STATE_GROUPS_ROOM_INDEX_UPDATE_NAME = "state_groups_room_id_idx"
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index 1a7232b276..bf4cdfdf29 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -18,6 +18,8 @@ from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Se
import attr
from synapse.api.constants import EventTypes
+from synapse.events import EventBase
+from synapse.events.snapshot import UnpersistedEventContext, UnpersistedEventContextBase
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
DatabasePool,
@@ -257,14 +259,11 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
member_filter, non_member_filter = state_filter.get_member_split()
# Now we look them up in the member and non-member caches
- (
- non_member_state,
- incomplete_groups_nm,
- ) = self._get_state_for_groups_using_cache(
+ non_member_state, incomplete_groups_nm = self._get_state_for_groups_using_cache(
groups, self._state_group_cache, state_filter=non_member_filter
)
- (member_state, incomplete_groups_m,) = self._get_state_for_groups_using_cache(
+ member_state, incomplete_groups_m = self._get_state_for_groups_using_cache(
groups, self._state_group_members_cache, state_filter=member_filter
)
@@ -404,6 +403,123 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
fetched_keys=non_member_types,
)
+ async def store_state_deltas_for_batched(
+ self,
+ events_and_context: List[Tuple[EventBase, UnpersistedEventContextBase]],
+ room_id: str,
+ prev_group: int,
+ ) -> List[Tuple[EventBase, UnpersistedEventContext]]:
+ """Generate and store state deltas for a group of events and contexts created to be
+ batch persisted. Note that all the events must be in a linear chain (ie a <- b <- c).
+
+ Args:
+ events_and_context: the events to generate and store a state groups for
+ and their associated contexts
+ room_id: the id of the room the events were created for
+ prev_group: the state group of the last event persisted before the batched events
+ were created
+ """
+
+ def insert_deltas_group_txn(
+ txn: LoggingTransaction,
+ events_and_context: List[Tuple[EventBase, UnpersistedEventContext]],
+ prev_group: int,
+ ) -> List[Tuple[EventBase, UnpersistedEventContext]]:
+ """Generate and store state groups for the provided events and contexts.
+
+ Requires that we have the state as a delta from the last persisted state group.
+
+ Returns:
+ A list of state groups
+ """
+ is_in_db = self.db_pool.simple_select_one_onecol_txn(
+ txn,
+ table="state_groups",
+ keyvalues={"id": prev_group},
+ retcol="id",
+ allow_none=True,
+ )
+ if not is_in_db:
+ raise Exception(
+ "Trying to persist state with unpersisted prev_group: %r"
+ % (prev_group,)
+ )
+
+ num_state_groups = sum(
+ 1 for event, _ in events_and_context if event.is_state()
+ )
+
+ state_groups = self._state_group_seq_gen.get_next_mult_txn(
+ txn, num_state_groups
+ )
+
+ sg_before = prev_group
+ state_group_iter = iter(state_groups)
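+            # Walk the chain in order: each state event consumes a freshly
+            # allocated state group, while non-state events carry the previous
+            # group forward unchanged.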
+ for event, context in events_and_context:
+ if not event.is_state():
+ context.state_group_after_event = sg_before
+ context.state_group_before_event = sg_before
+ continue
+
+ sg_after = next(state_group_iter)
+ context.state_group_after_event = sg_after
+ context.state_group_before_event = sg_before
+ context.state_delta_due_to_event = {
+ (event.type, event.state_key): event.event_id
+ }
+ sg_before = sg_after
+
+ self.db_pool.simple_insert_many_txn(
+ txn,
+ table="state_groups",
+ keys=("id", "room_id", "event_id"),
+ values=[
+ (context.state_group_after_event, room_id, event.event_id)
+ for event, context in events_and_context
+ if event.is_state()
+ ],
+ )
+
+ self.db_pool.simple_insert_many_txn(
+ txn,
+ table="state_group_edges",
+ keys=("state_group", "prev_state_group"),
+ values=[
+ (
+ context.state_group_after_event,
+ context.state_group_before_event,
+ )
+ for event, context in events_and_context
+ if event.is_state()
+ ],
+ )
+
+ self.db_pool.simple_insert_many_txn(
+ txn,
+ table="state_groups_state",
+ keys=("state_group", "room_id", "type", "state_key", "event_id"),
+ values=[
+ (
+ context.state_group_after_event,
+ room_id,
+ key[0],
+ key[1],
+ state_id,
+ )
+ for event, context in events_and_context
+ if context.state_delta_due_to_event is not None
+ for key, state_id in context.state_delta_due_to_event.items()
+ ],
+ )
+ return events_and_context
+
+ return await self.db_pool.runInteraction(
+ "store_state_deltas_for_batched.insert_deltas_group",
+ insert_deltas_group_txn,
+ events_and_context,
+ prev_group,
+ )
+
async def store_state_group(
self,
event_id: str,
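
The heart of the new transaction is the walk over the linear chain of events. A
simplified, runnable sketch of just that assignment step — plain ints stand in
for the sequence generator and small dataclasses for
EventBase/UnpersistedEventContext; all names here are illustrative:

    from dataclasses import dataclass
    from itertools import count
    from typing import Dict, Optional, Tuple

    @dataclass
    class Event:
        event_id: str
        type: str = "m.room.message"
        state_key: Optional[str] = None  # only state events carry a state_key

        def is_state(self) -> bool:
            return self.state_key is not None

    @dataclass
    class Ctx:
        state_group_before_event: Optional[int] = None
        state_group_after_event: Optional[int] = None
        state_delta_due_to_event: Optional[Dict[Tuple[str, str], str]] = None

    def assign_groups(chain, prev_group, next_group_id):
        sg_before = prev_group
        for event, ctx in chain:
            if not event.is_state():
                # Non-state events create no group: before == after.
                ctx.state_group_before_event = sg_before
                ctx.state_group_after_event = sg_before
                continue
            sg_after = next(next_group_id)
            ctx.state_group_before_event = sg_before
            ctx.state_group_after_event = sg_after
            # The delta is exactly the one state entry this event introduces.
            ctx.state_delta_due_to_event = {
                (event.type, event.state_key): event.event_id
            }
            sg_before = sg_after  # linear chain: a <- b <- c

    chain = [
        (Event("$a", type="m.room.name", state_key=""), Ctx()),
        (Event("$b"), Ctx()),
    ]
    assign_groups(chain, prev_group=41, next_group_id=count(42))
    assert chain[0][1].state_group_after_event == 42
    assert chain[1][1].state_group_after_event == 42  # non-state event reuses it
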
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 6c335a9315..2a1c6fa31b 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -563,7 +563,7 @@ def _apply_module_schemas(
"""
# This is the old way for password_auth_provider modules to make changes
# to the database. This should instead be done using the module API
- for (mod, _config) in config.authproviders.password_providers:
+ for mod, _config in config.authproviders.password_providers:
if not hasattr(mod, "get_db_schema_files"):
continue
modname = ".".join((mod.__module__, mod.__name__))
@@ -591,7 +591,7 @@ def _apply_module_schema_files(
(modname,),
)
applied_deltas = {d for d, in cur}
- for (name, stream) in names_and_streams:
+ for name, stream in names_and_streams:
if name in applied_deltas:
continue
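
The parenthesis removal in these `for` targets (here and in the loops changed
above and below) is purely cosmetic; a quick standalone check that the two
spellings unpack identically:

    pairs = [("stdout", 1), ("stderr", 2)]

    for (name, level) in pairs:  # old spelling, redundant parentheses
        old = (name, level)

    for name, level in pairs:  # new spelling, same unpacking
        new = (name, level)

    assert old == new == ("stderr", 2)
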
diff --git a/synapse/types/state.py b/synapse/types/state.py
index 743a4f9217..4b3071acce 100644
--- a/synapse/types/state.py
+++ b/synapse/types/state.py
@@ -120,7 +120,7 @@ class StateFilter:
def to_types(self) -> Iterable[Tuple[str, Optional[str]]]:
"""The inverse to `from_types`."""
- for (event_type, state_keys) in self.types.items():
+ for event_type, state_keys in self.types.items():
if state_keys is None:
yield event_type, None
else:
diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py
index 9387632d0d..6ffa56217e 100644
--- a/synapse/util/caches/__init__.py
+++ b/synapse/util/caches/__init__.py
@@ -98,7 +98,6 @@ class EvictionReason(Enum):
@attr.s(slots=True, auto_attribs=True)
class CacheMetric:
-
_cache: Sized
_cache_type: str
_cache_name: str
diff --git a/synapse/util/check_dependencies.py b/synapse/util/check_dependencies.py
index 3b1e205700..1c0fde4966 100644
--- a/synapse/util/check_dependencies.py
+++ b/synapse/util/check_dependencies.py
@@ -183,7 +183,7 @@ def check_requirements(extra: Optional[str] = None) -> None:
deps_unfulfilled = []
errors = []
- for (requirement, must_be_installed) in dependencies:
+ for requirement, must_be_installed in dependencies:
try:
dist: metadata.Distribution = metadata.distribution(requirement.name)
except metadata.PackageNotFoundError:
diff --git a/synapse/util/patch_inline_callbacks.py b/synapse/util/patch_inline_callbacks.py
index f97f98a057..d00d34e652 100644
--- a/synapse/util/patch_inline_callbacks.py
+++ b/synapse/util/patch_inline_callbacks.py
@@ -211,7 +211,6 @@ def _check_yield_points(
result = Failure()
if current_context() != expected_context:
-
# This happens because the context is lost sometime *after* the
# previous yield and *after* the current yield. E.g. the
# deferred we waited on didn't follow the rules, or we forgot to
|